feat:新增基于YOLOv8的打架斗殴模型,支持上传视频识别以及摄像头识别
This commit is contained in:
@@ -3,9 +3,12 @@ import numpy as np
|
|||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from fastapi import APIRouter, UploadFile, File, Form, Query
|
from fastapi import APIRouter, UploadFile, File, Form, Query
|
||||||
from models.schemas import ImageDetectionResult
|
from models.schemas import ImageDetectionResult, VideoDetectionResult
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -159,3 +162,287 @@ async def get_algorithm_config():
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/detect/video", response_model=VideoDetectionResult)
|
||||||
|
async def detect_video(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
model_id: str = Query("fight_detection"),
|
||||||
|
confidence: float = Query(0.5),
|
||||||
|
iou: float = Query(0.45)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
视频检测接口 - 逐帧推理,返回标注后的视频文件及帧级检测结果
|
||||||
|
|
||||||
|
支持模型:fight_detection 及其他 YOLO 类型模型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
VideoDetectionResult: 包含标注视频URL、帧级检测结果、统计信息
|
||||||
|
"""
|
||||||
|
from main import model_service
|
||||||
|
from services.detection_service import DetectionService
|
||||||
|
|
||||||
|
detection_service = DetectionService(model_service)
|
||||||
|
|
||||||
|
# 校验模型是否支持视频检测
|
||||||
|
model_config = model_service.model_configs.get(model_id)
|
||||||
|
if not model_config:
|
||||||
|
return VideoDetectionResult(
|
||||||
|
success=False,
|
||||||
|
message=f"未知模型: {model_id}",
|
||||||
|
data={}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 仅允许 YOLO 类型模型进行视频检测
|
||||||
|
if model_config['type'] not in ('yolov8', 'yolov10'):
|
||||||
|
return VideoDetectionResult(
|
||||||
|
success=False,
|
||||||
|
message=f"模型 {model_id} 不支持视频检测(仅支持YOLO类型模型)",
|
||||||
|
data={}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 校验文件格式
|
||||||
|
filename = file.filename or ""
|
||||||
|
ext = os.path.splitext(filename)[1].lower()
|
||||||
|
allowed_exts = ['.mp4', '.avi', '.mov', '.mkv', '.webm']
|
||||||
|
if ext not in allowed_exts:
|
||||||
|
return VideoDetectionResult(
|
||||||
|
success=False,
|
||||||
|
message=f"不支持的视频格式: {ext},支持: {', '.join(allowed_exts)}",
|
||||||
|
data={}
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 将上传的视频保存到临时文件
|
||||||
|
temp_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static", "temp")
|
||||||
|
os.makedirs(temp_dir, exist_ok=True)
|
||||||
|
|
||||||
|
temp_input = os.path.join(temp_dir, f"input_{uuid.uuid4().hex}{ext}")
|
||||||
|
temp_output = os.path.join(temp_dir, f"output_{uuid.uuid4().hex}.mp4")
|
||||||
|
|
||||||
|
with open(temp_input, "wb") as f:
|
||||||
|
content = await file.read()
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
# 打开视频文件
|
||||||
|
cap = cv2.VideoCapture(temp_input)
|
||||||
|
if not cap.isOpened():
|
||||||
|
return VideoDetectionResult(
|
||||||
|
success=False,
|
||||||
|
message="无法打开视频文件",
|
||||||
|
data={}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取视频属性
|
||||||
|
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||||
|
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||||
|
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||||
|
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||||
|
|
||||||
|
# 尝试使用浏览器兼容的编码器(avc1/H.264优先,mp4v回退)
|
||||||
|
def try_create_writer(path, fps, size):
|
||||||
|
for codec in ['avc1', 'mp4v']:
|
||||||
|
fourcc = cv2.VideoWriter_fourcc(*codec)
|
||||||
|
writer = cv2.VideoWriter(path, fourcc, fps, size)
|
||||||
|
if writer.isOpened():
|
||||||
|
logger.info(f"[detect_video] 使用视频编码器: {codec}, 输出: {path}")
|
||||||
|
return writer, codec
|
||||||
|
writer.release()
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
out, used_codec = try_create_writer(temp_output, fps, (width, height))
|
||||||
|
if out is None:
|
||||||
|
cap.release()
|
||||||
|
return VideoDetectionResult(
|
||||||
|
success=False,
|
||||||
|
message="无法初始化视频写入器(无可用的编码器)",
|
||||||
|
data={}
|
||||||
|
)
|
||||||
|
|
||||||
|
frame_results = []
|
||||||
|
key_frames = [] # 关键帧截图(检测到目标的帧)
|
||||||
|
total_detections = 0
|
||||||
|
total_confidence = 0.0
|
||||||
|
frames_with_detections = 0
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
frame_index = 0
|
||||||
|
while True:
|
||||||
|
ret, frame = cap.read()
|
||||||
|
if not ret:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 逐帧推理
|
||||||
|
annotated_frame, result_data = await detection_service.detect_frame(
|
||||||
|
frame, model_id, confidence, iou, draw=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 写入标注后的帧
|
||||||
|
out.write(annotated_frame)
|
||||||
|
|
||||||
|
# 收集帧级检测结果 & 提取关键帧截图
|
||||||
|
if result_data['success'] and result_data['detections']:
|
||||||
|
frames_with_detections += 1
|
||||||
|
total_detections += len(result_data['detections'])
|
||||||
|
total_confidence += sum(d['confidence'] for d in result_data['detections'])
|
||||||
|
|
||||||
|
frame_results.append({
|
||||||
|
'frame_index': frame_index,
|
||||||
|
'timestamp': round(frame_index / fps, 2) if fps > 0 else 0,
|
||||||
|
'detections': result_data['detections'],
|
||||||
|
'detection_count': len(result_data['detections'])
|
||||||
|
})
|
||||||
|
|
||||||
|
# 提取关键帧截图(最多保留 20 张,防止响应过大)
|
||||||
|
if len(key_frames) < 20:
|
||||||
|
try:
|
||||||
|
# 缩放以提高传输效率(最大宽度 640)
|
||||||
|
h, w = annotated_frame.shape[:2]
|
||||||
|
max_w = 640
|
||||||
|
if w > max_w:
|
||||||
|
scale = max_w / w
|
||||||
|
new_w = int(w * scale)
|
||||||
|
new_h = int(h * scale)
|
||||||
|
thumb = cv2.resize(annotated_frame, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
||||||
|
else:
|
||||||
|
thumb = annotated_frame
|
||||||
|
|
||||||
|
_, buffer = cv2.imencode('.jpg', thumb, [cv2.IMWRITE_JPEG_QUALITY, 85])
|
||||||
|
img_base64 = base64.b64encode(buffer).decode('utf-8')
|
||||||
|
key_frames.append({
|
||||||
|
'frame_index': frame_index,
|
||||||
|
'timestamp': round(frame_index / fps, 2) if fps > 0 else 0,
|
||||||
|
'image': f'data:image/jpeg;base64,{img_base64}',
|
||||||
|
'detections': result_data['detections']
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"关键帧截图失败 frame={frame_index}: {e}")
|
||||||
|
|
||||||
|
frame_index += 1
|
||||||
|
|
||||||
|
cap.release()
|
||||||
|
out.release()
|
||||||
|
|
||||||
|
processing_time = time.time() - start_time
|
||||||
|
avg_confidence = total_confidence / total_detections if total_detections > 0 else 0
|
||||||
|
|
||||||
|
# 将输出视频移动到 results 目录
|
||||||
|
results_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "static", "results")
|
||||||
|
os.makedirs(results_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 先用 ffmpeg 转码为浏览器兼容的 H.264 MP4(如果 ffmpeg 可用)
|
||||||
|
ffmpeg_success = False
|
||||||
|
result_filename = None
|
||||||
|
result_path = None
|
||||||
|
video_codec_warning = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import subprocess
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
# 生成 ffmpeg 输出路径(始终用 .mp4,因为 ffmpeg 输出 H.264)
|
||||||
|
mp4_filename = f"fight_detect_{uuid.uuid4().hex}.mp4"
|
||||||
|
mp4_path = os.path.join(results_dir, mp4_filename)
|
||||||
|
|
||||||
|
ffmpeg_cmd = [
|
||||||
|
'ffmpeg', '-y', '-i', temp_output,
|
||||||
|
'-c:v', 'libx264', '-preset', 'fast',
|
||||||
|
'-pix_fmt', 'yuv420p',
|
||||||
|
'-movflags', '+faststart',
|
||||||
|
'-an', # 去除音频(避免音频编码问题)
|
||||||
|
mp4_path
|
||||||
|
]
|
||||||
|
logger.info(f"[detect_video] 开始 ffmpeg 转码: {' '.join(ffmpeg_cmd[:5])} ... -> {mp4_path}")
|
||||||
|
result = subprocess.run(ffmpeg_cmd, capture_output=True, timeout=300)
|
||||||
|
|
||||||
|
if result.returncode == 0:
|
||||||
|
# 检查 ffmpeg 输出文件是否有效
|
||||||
|
if os.path.exists(mp4_path) and os.path.getsize(mp4_path) > 1024:
|
||||||
|
ffmpeg_success = True
|
||||||
|
result_filename = mp4_filename
|
||||||
|
result_path = mp4_path
|
||||||
|
logger.info(f"[detect_video] ffmpeg 转码成功: {mp4_path} ({os.path.getsize(mp4_path)} bytes)")
|
||||||
|
# 清理 OpenCV 临时文件
|
||||||
|
try:
|
||||||
|
if os.path.exists(temp_output):
|
||||||
|
os.remove(temp_output)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.warning(f"[detect_video] ffmpeg 输出文件无效或为空")
|
||||||
|
# 回退到直接复制
|
||||||
|
raise RuntimeError("ffmpeg 输出文件无效")
|
||||||
|
else:
|
||||||
|
stderr = result.stderr.decode('utf-8', errors='ignore')[:800]
|
||||||
|
logger.warning(f"[detect_video] ffmpeg 转码失败 (rc={result.returncode}): {stderr}")
|
||||||
|
raise RuntimeError(f"ffmpeg 失败: {stderr}")
|
||||||
|
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
logger.warning(f"[detect_video] ffmpeg 命令未找到 (PATH可能未生效): {e}")
|
||||||
|
# ffmpeg 未安装,回退到直接复制原始 OpenCV 输出
|
||||||
|
import shutil
|
||||||
|
fallback_ext = '.mp4' if used_codec == 'avc1' else '.avi'
|
||||||
|
result_filename = f"fight_detect_{uuid.uuid4().hex}{fallback_ext}"
|
||||||
|
result_path = os.path.join(results_dir, result_filename)
|
||||||
|
shutil.move(temp_output, result_path)
|
||||||
|
video_codec_warning = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[detect_video] ffmpeg 异常,回退到原始文件: {e}")
|
||||||
|
import shutil
|
||||||
|
fallback_ext = '.mp4' if used_codec == 'avc1' else '.avi'
|
||||||
|
result_filename = f"fight_detect_{uuid.uuid4().hex}{fallback_ext}"
|
||||||
|
result_path = os.path.join(results_dir, result_filename)
|
||||||
|
shutil.move(temp_output, result_path)
|
||||||
|
video_codec_warning = True
|
||||||
|
|
||||||
|
# 检查输出文件是否有效
|
||||||
|
video_valid = result_path and os.path.exists(result_path) and os.path.getsize(result_path) > 1024
|
||||||
|
|
||||||
|
# 清理输入临时文件
|
||||||
|
for f in [temp_input, temp_output]:
|
||||||
|
try:
|
||||||
|
if os.path.exists(f):
|
||||||
|
os.remove(f)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
video_url = f"/static/results/{result_filename}"
|
||||||
|
|
||||||
|
return VideoDetectionResult(
|
||||||
|
success=True,
|
||||||
|
message="视频检测完成",
|
||||||
|
data={
|
||||||
|
"video_url": video_url,
|
||||||
|
"video_valid": video_valid,
|
||||||
|
"video_codec_warning": video_codec_warning,
|
||||||
|
"key_frames": key_frames,
|
||||||
|
"frame_results": frame_results,
|
||||||
|
"stats": {
|
||||||
|
"total_frames": total_frames,
|
||||||
|
"frames_with_detections": frames_with_detections,
|
||||||
|
"total_detections": total_detections,
|
||||||
|
"avg_confidence": round(avg_confidence, 3),
|
||||||
|
"processing_time": round(processing_time, 2),
|
||||||
|
"fps": round(total_frames / processing_time, 2) if processing_time > 0 else 0,
|
||||||
|
"model_used": model_id,
|
||||||
|
"video_info": {
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"original_fps": round(fps, 2),
|
||||||
|
"duration": round(total_frames / fps, 2) if fps > 0 else 0,
|
||||||
|
"used_codec": used_codec,
|
||||||
|
"ffmpeg_transcoded": ffmpeg_success
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"视频检测失败: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(f"详细错误: {traceback.format_exc()}")
|
||||||
|
return VideoDetectionResult(
|
||||||
|
success=False,
|
||||||
|
message=f"视频检测失败: {str(e)}",
|
||||||
|
data={}
|
||||||
|
)
|
||||||
|
|||||||
@@ -37,3 +37,8 @@ class DetectionConfig(BaseModel):
|
|||||||
model_id: str
|
model_id: str
|
||||||
confidence: float = Field(default=0.5, ge=0.1, le=1.0)
|
confidence: float = Field(default=0.5, ge=0.1, le=1.0)
|
||||||
iou: float = Field(default=0.45, ge=0.1, le=0.9)
|
iou: float = Field(default=0.45, ge=0.1, le=0.9)
|
||||||
|
|
||||||
|
class VideoDetectionResult(BaseModel):
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
data: Dict[str, Any]
|
||||||
|
|||||||
@@ -374,7 +374,12 @@ class DetectionService:
|
|||||||
'person': (0, 255, 0),
|
'person': (0, 255, 0),
|
||||||
'helmet': (255, 255, 0),
|
'helmet': (255, 255, 0),
|
||||||
'no_helmet': (255, 0, 255),
|
'no_helmet': (255, 0, 255),
|
||||||
'cigarette': (0, 165, 255)
|
'cigarette': (0, 165, 255),
|
||||||
|
# 兼容旧模型类别
|
||||||
|
'violence': (0, 0, 255),
|
||||||
|
'fight': (0, 0, 255),
|
||||||
|
'normal': (0, 200, 0),
|
||||||
|
'non_violence': (0, 200, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
for det in detections:
|
for det in detections:
|
||||||
|
|||||||
@@ -110,6 +110,16 @@ class ModelService:
|
|||||||
'description': '基于PaddlePaddle PP-YOLOE-l的违停检测模型,支持车牌识别',
|
'description': '基于PaddlePaddle PP-YOLOE-l的违停检测模型,支持车牌识别',
|
||||||
'name': '违停检测 (Paddle)'
|
'name': '违停检测 (Paddle)'
|
||||||
},
|
},
|
||||||
|
'fight_detection': {
|
||||||
|
'path': os.path.join(base_dir, 'models', 'fight_detection', 'yolov8n.pt'),
|
||||||
|
'type': 'yolov8',
|
||||||
|
'classes': ['violence', 'non_violence'],
|
||||||
|
'labels': {'violence': '暴力行为', 'non_violence': '正常'},
|
||||||
|
'size': '22MB',
|
||||||
|
'description': '基于YOLOv8的打架斗殴检测模型',
|
||||||
|
'name': '打架斗殴检测(YOLO)',
|
||||||
|
'supports_video': True
|
||||||
|
},
|
||||||
'action_detection': {
|
'action_detection': {
|
||||||
'path': 'docker_api',
|
'path': 'docker_api',
|
||||||
'type': 'docker_api',
|
'type': 'docker_api',
|
||||||
@@ -142,7 +152,7 @@ class ModelService:
|
|||||||
model_exists = os.path.exists(model_path)
|
model_exists = os.path.exists(model_path)
|
||||||
|
|
||||||
if model_exists:
|
if model_exists:
|
||||||
available_models.append({
|
model_info = {
|
||||||
'id': model_id,
|
'id': model_id,
|
||||||
'name': config['name'],
|
'name': config['name'],
|
||||||
'description': config['description'],
|
'description': config['description'],
|
||||||
@@ -150,7 +160,11 @@ class ModelService:
|
|||||||
'labels': config['labels'],
|
'labels': config['labels'],
|
||||||
'size': config['size'],
|
'size': config['size'],
|
||||||
'type': config['type']
|
'type': config['type']
|
||||||
})
|
}
|
||||||
|
# 支持 video 的模型增加标记
|
||||||
|
if config.get('supports_video'):
|
||||||
|
model_info['supports_video'] = True
|
||||||
|
available_models.append(model_info)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"模型文件不存在: {model_path}")
|
logger.warning(f"模型文件不存在: {model_path}")
|
||||||
return available_models
|
return available_models
|
||||||
|
|||||||
Reference in New Issue
Block a user