Files
jc-video-recognize/apps/server/api/detection.py

456 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import cv2
import numpy as np
import base64
import logging
import json
import os
import time
import uuid
from typing import Optional
from fastapi import APIRouter, UploadFile, File, Form, Query
from models.schemas import ImageDetectionResult, VideoDetectionResult
# Router that collects the detection endpoints declared in this module.
router = APIRouter()
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(__name__)
@router.post("/detect/image", response_model=ImageDetectionResult)
async def detect_image(
    file: UploadFile = File(...),
    model_id: str = Query("fire_detection"),
    confidence: float = Query(0.5),
    iou: float = Query(0.45),
    algorithm_config: Optional[str] = Query(None, description="算法配置JSON字符串"),
    composite: bool = Query(False, description="是否启用复合检测(火灾检测时同时检测火焰和烟雾)")
):
    """Run detection on a single uploaded image.

    The upload is decoded with OpenCV, passed to the detection service, and
    the annotated frame is returned as a base64-encoded JPEG together with the
    detections and statistics produced by the service.

    Args:
        file: Uploaded image file.
        model_id: Identifier of the model to run.
        confidence: Confidence threshold forwarded to the detection service.
        iou: IoU threshold forwarded to the detection service.
        algorithm_config: Optional JSON object (as a string) with algorithm
            settings, for example::

                {
                    "enable_stationary_detection": true,
                    "enable_loitering_detection": false,
                    "stationary_threshold": 10.0,
                    "position_tolerance": 50,
                    "loitering_threshold": 300.0,
                    "movement_threshold": 5.0
                }

            Invalid JSON, or JSON that is not an object, is logged and ignored.
        composite: When true and ``model_id`` is ``fire_detection``, the
            composite fire detection path of the service is used instead of
            the regular one.

    Returns:
        ImageDetectionResult: ``success=True`` with ``detections``,
        ``image_base64``, ``stats``, ``alerts`` and ``behavior_stats`` in
        ``data``; otherwise ``success=False`` with an explanatory message and
        empty ``data``. Errors are reported in the result, never raised.
    """
    # Imported at call time rather than module import time
    # (presumably to avoid a circular import with `main` -- TODO confirm).
    from main import model_service
    from services.detection_service import DetectionService

    detection_service = DetectionService(model_service)

    # Parse the optional algorithm configuration; a bad value is non-fatal.
    algo_config = None
    if algorithm_config:
        try:
            algo_config = json.loads(algorithm_config)
        except json.JSONDecodeError as e:
            logger.warning(f"算法配置解析失败: {e}")
        else:
            # Only a JSON object (or null) makes sense as a configuration;
            # scalars and arrays are dropped instead of being passed on.
            if algo_config is not None and not isinstance(algo_config, dict):
                logger.warning(
                    f"算法配置解析失败: expected a JSON object, got {type(algo_config).__name__}"
                )
                algo_config = None

    try:
        contents = await file.read()

        # An empty upload can never decode; answer with the same message as
        # an undecodable image instead of surfacing a raw OpenCV error.
        frame = None
        if contents:
            nparr = np.frombuffer(contents, np.uint8)
            frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        if frame is None:
            return ImageDetectionResult(
                success=False,
                message="无法读取图片",
                data={}
            )

        # Composite fire detection is only available for the fire model.
        if composite and model_id == 'fire_detection':
            result = await detection_service.detect_fire_composite(
                frame, confidence, iou
            )
        else:
            result = await detection_service.detect_image(
                frame, model_id, confidence, iou, algorithm_config=algo_config
            )

        if not result['success']:
            return ImageDetectionResult(
                success=False,
                message=result['message'],
                data={}
            )

        annotated_frame = detection_service.draw_detections(
            frame, result['detections'], algorithm_config=algo_config
        )

        # Encode the annotated image as JPEG, then base64 for the JSON reply.
        encoded_ok, buffer = cv2.imencode('.jpg', annotated_frame)
        if not encoded_ok:
            raise RuntimeError("annotated image could not be JPEG-encoded")
        img_base64 = base64.b64encode(buffer).decode('utf-8')

        return ImageDetectionResult(
            success=True,
            message="检测完成",
            data={
                "detections": result['detections'],
                "image_base64": img_base64,
                "stats": result['stats'],
                "alerts": result.get('alerts', []),
                "behavior_stats": result.get('behavior_stats', {})
            }
        )
    except Exception as e:
        # logger.exception records the traceback as well, matching the level
        # of detail logged by the video endpoint.
        logger.exception(f"图片检测失败: {e}")
        return ImageDetectionResult(
            success=False,
            message=f"检测失败: {str(e)}",
            data={}
        )
@router.get("/algorithms/config")
async def get_algorithm_config():
    """Return the available behaviour algorithms and their tunable parameters.

    The payload is static: a dict with one key, ``algorithms``, holding one
    entry per algorithm (``id``, ``name``, ``description``, ``params``). Every
    parameter entry describes a numeric setting with its default, bounds,
    unit and description.
    """

    def number_param(name, label, default, lower, upper, unit, description):
        # All parameters share the same shape and are of type "number".
        return {
            "name": name,
            "label": label,
            "type": "number",
            "default": default,
            "min": lower,
            "max": upper,
            "unit": unit,
            "description": description,
        }

    stationary = {
        "id": "stationary_detection",
        "name": "静止检测",
        "description": "检测人员在同一位置静止停留",
        "params": [
            number_param(
                "stationary_threshold", "静止阈值",
                10.0, 1.0, 300.0, "", "超过此时间视为静止",
            ),
            number_param(
                "position_tolerance", "位置容差",
                50, 10, 200, "像素", "位置匹配容差范围",
            ),
        ],
    }

    loitering = {
        "id": "loitering_detection",
        "name": "徘徊检测",
        "description": "检测人员长时间停留需要跟踪ID",
        "params": [
            number_param(
                "loitering_threshold", "徘徊阈值",
                300.0, 60.0, 1800.0, "", "超过此时间视为徘徊",
            ),
            number_param(
                "movement_threshold", "移动阈值",
                5.0, 1.0, 50.0, "像素", "小于此移动视为静止",
            ),
        ],
    }

    return {"algorithms": [stationary, loitering]}
# Upload is copied to disk in pieces of this size instead of all at once.
_UPLOAD_CHUNK_SIZE = 1024 * 1024
# Frame rate handed to the video writer when the container reports none.
_FALLBACK_WRITER_FPS = 25.0
# Key-frame thumbnails are capped in number and width to bound response size.
_MAX_KEY_FRAMES = 20
_KEY_FRAME_MAX_WIDTH = 640
# Output files at or below this size (bytes) are treated as invalid.
_MIN_VALID_VIDEO_BYTES = 1024


def _remove_quietly(path):
    """Best-effort removal of *path*; a failure is logged, never raised."""
    if not path:
        return
    try:
        if os.path.exists(path):
            os.remove(path)
    except OSError as e:
        logger.warning(f"[detect_video] failed to remove temporary file {path}: {e}")


def _create_video_writer(path, fps, size):
    """Open a ``cv2.VideoWriter`` on *path*, trying 'avc1' first, then 'mp4v'.

    Returns ``(writer, codec)`` for the first codec that opens, or
    ``(None, None)`` when neither does.
    """
    for codec in ('avc1', 'mp4v'):
        fourcc = cv2.VideoWriter_fourcc(*codec)
        writer = cv2.VideoWriter(path, fourcc, fps, size)
        if writer.isOpened():
            logger.info(f"[detect_video] 使用视频编码器: {codec}, 输出: {path}")
            return writer, codec
        writer.release()
    return None, None


def _build_key_frame(annotated_frame, frame_index, timestamp, detections):
    """Build the key-frame entry (JPEG data URL thumbnail) for one frame.

    Frames wider than ``_KEY_FRAME_MAX_WIDTH`` are scaled down, keeping the
    aspect ratio. Raises if the thumbnail cannot be produced.
    """
    h, w = annotated_frame.shape[:2]
    if w > _KEY_FRAME_MAX_WIDTH:
        scale = _KEY_FRAME_MAX_WIDTH / w
        thumb = cv2.resize(
            annotated_frame, (int(w * scale), int(h * scale)),
            interpolation=cv2.INTER_AREA
        )
    else:
        thumb = annotated_frame

    encoded_ok, buffer = cv2.imencode('.jpg', thumb, [cv2.IMWRITE_JPEG_QUALITY, 85])
    if not encoded_ok:
        raise RuntimeError("thumbnail could not be JPEG-encoded")
    img_base64 = base64.b64encode(buffer).decode('utf-8')

    return {
        'frame_index': frame_index,
        'timestamp': timestamp,
        'image': f'data:image/jpeg;base64,{img_base64}',
        'detections': detections
    }


def _publish_result_video(temp_output, results_dir, used_codec):
    """Place the annotated video in *results_dir*, transcoding if possible.

    ffmpeg is asked to re-encode *temp_output* to H.264 MP4. When ffmpeg is
    missing, fails, times out or yields an unusable file, the OpenCV output
    is moved into *results_dir* unchanged instead.

    Returns ``(result_filename, result_path, ffmpeg_success)``.
    """
    import shutil
    import subprocess

    mp4_filename = f"fight_detect_{uuid.uuid4().hex}.mp4"
    mp4_path = os.path.join(results_dir, mp4_filename)
    ffmpeg_cmd = [
        'ffmpeg', '-y', '-i', temp_output,
        '-c:v', 'libx264', '-preset', 'fast',
        '-pix_fmt', 'yuv420p',
        '-movflags', '+faststart',
        '-an',  # drop audio to avoid audio-encoding problems
        mp4_path
    ]
    logger.info(f"[detect_video] 开始 ffmpeg 转码: {' '.join(ffmpeg_cmd[:5])} ... -> {mp4_path}")

    try:
        proc = subprocess.run(ffmpeg_cmd, capture_output=True, timeout=300)
        if proc.returncode != 0:
            stderr = proc.stderr.decode('utf-8', errors='ignore')[:800]
            logger.warning(f"[detect_video] ffmpeg 转码失败 (rc={proc.returncode}): {stderr}")
            raise RuntimeError(f"ffmpeg 失败: {stderr}")
        if not (os.path.exists(mp4_path)
                and os.path.getsize(mp4_path) > _MIN_VALID_VIDEO_BYTES):
            logger.warning("[detect_video] ffmpeg 输出文件无效或为空")
            raise RuntimeError("ffmpeg 输出文件无效")
    except FileNotFoundError as e:
        # The ffmpeg executable itself could not be started.
        logger.warning(f"[detect_video] ffmpeg 命令未找到 (PATH可能未生效): {e}")
    except Exception as e:
        logger.warning(f"[detect_video] ffmpeg 异常,回退到原始文件: {e}")
    else:
        logger.info(f"[detect_video] ffmpeg 转码成功: {mp4_path} ({os.path.getsize(mp4_path)} bytes)")
        _remove_quietly(temp_output)
        return mp4_filename, mp4_path, True

    # Fallback: discard whatever ffmpeg left behind and publish the raw
    # OpenCV output as-is.
    _remove_quietly(mp4_path)
    # NOTE(review): temp_output is written to a ".mp4" path, so naming the
    # mp4v fallback ".avi" looks like a container/extension mismatch. The
    # existing naming is kept here -- confirm what the client expects.
    fallback_ext = '.mp4' if used_codec == 'avc1' else '.avi'
    result_filename = f"fight_detect_{uuid.uuid4().hex}{fallback_ext}"
    result_path = os.path.join(results_dir, result_filename)
    shutil.move(temp_output, result_path)
    return result_filename, result_path, False


@router.post("/detect/video", response_model=VideoDetectionResult)
async def detect_video(
    file: UploadFile = File(...),
    model_id: str = Query("fight_detection"),
    confidence: float = Query(0.5),
    iou: float = Query(0.45)
):
    """Run frame-by-frame detection on an uploaded video.

    Every frame is passed to the detection service, the annotated frames are
    written to a new video, and per-frame results plus up to
    ``_MAX_KEY_FRAMES`` thumbnails of frames with detections are collected.
    The finished video is published under ``static/results`` (H.264 MP4 when
    ffmpeg transcoding succeeds, the raw OpenCV output otherwise).

    Only models whose configured ``type`` is ``yolov8`` or ``yolov10`` are
    accepted, and only ``.mp4``, ``.avi``, ``.mov``, ``.mkv`` and ``.webm``
    uploads.

    Args:
        file: Uploaded video file.
        model_id: Identifier of the model to run.
        confidence: Confidence threshold forwarded to the detection service.
        iou: IoU threshold forwarded to the detection service.

    Returns:
        VideoDetectionResult: on success ``data`` holds ``video_url``,
        ``video_valid``, ``video_codec_warning``, ``key_frames``,
        ``frame_results`` and ``stats``; on failure ``success=False`` with a
        message and empty ``data``. Errors are reported in the result, never
        raised.
    """
    # Imported at call time rather than module import time
    # (presumably to avoid a circular import with `main` -- TODO confirm).
    from main import model_service
    from services.detection_service import DetectionService

    detection_service = DetectionService(model_service)

    # Validate that the model exists and supports video detection.
    model_config = model_service.model_configs.get(model_id)
    if not model_config:
        return VideoDetectionResult(
            success=False,
            message=f"未知模型: {model_id}",
            data={}
        )

    if model_config['type'] not in ('yolov8', 'yolov10'):
        return VideoDetectionResult(
            success=False,
            message=f"模型 {model_id} 不支持视频检测仅支持YOLO类型模型",
            data={}
        )

    # Validate the file extension.
    filename = file.filename or ""
    ext = os.path.splitext(filename)[1].lower()
    allowed_exts = ['.mp4', '.avi', '.mov', '.mkv', '.webm']
    if ext not in allowed_exts:
        return VideoDetectionResult(
            success=False,
            message=f"不支持的视频格式: {ext},支持: {', '.join(allowed_exts)}",
            data={}
        )

    # Tracked outside the try block so that `finally` can release and delete
    # them on every exit path (early return, exception, success).
    temp_input = None
    temp_output = None
    cap = None
    out = None

    try:
        import math

        base_dir = os.path.dirname(os.path.dirname(__file__))

        # Save the upload to a temporary file, streaming it in chunks so the
        # whole video never has to be held in memory.
        temp_dir = os.path.join(base_dir, "static", "temp")
        os.makedirs(temp_dir, exist_ok=True)
        temp_input = os.path.join(temp_dir, f"input_{uuid.uuid4().hex}{ext}")
        temp_output = os.path.join(temp_dir, f"output_{uuid.uuid4().hex}.mp4")

        with open(temp_input, "wb") as upload_file:
            while True:
                chunk = await file.read(_UPLOAD_CHUNK_SIZE)
                if not chunk:
                    break
                upload_file.write(chunk)

        cap = cv2.VideoCapture(temp_input)
        if not cap.isOpened():
            return VideoDetectionResult(
                success=False,
                message="无法打开视频文件",
                data={}
            )

        # Video properties. A missing / NaN frame rate is normalised to 0 so
        # the "fps > 0" guards below behave predictably.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not (math.isfinite(fps) and fps > 0):
            fps = 0.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # The writer needs a positive frame rate even when the source
        # reports none.
        out, used_codec = _create_video_writer(
            temp_output, fps or _FALLBACK_WRITER_FPS, (width, height)
        )
        if out is None:
            return VideoDetectionResult(
                success=False,
                message="无法初始化视频写入器(无可用的编码器)",
                data={}
            )

        frame_results = []
        key_frames = []  # thumbnails of frames in which something was detected
        total_detections = 0
        total_confidence = 0.0
        frames_with_detections = 0

        start_time = time.time()
        frame_index = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Per-frame inference; the service also draws the annotations.
            annotated_frame, result_data = await detection_service.detect_frame(
                frame, model_id, confidence, iou, draw=True
            )
            out.write(annotated_frame)

            detections = result_data['detections'] if result_data['success'] else None
            if detections:
                timestamp = round(frame_index / fps, 2) if fps > 0 else 0

                frames_with_detections += 1
                total_detections += len(detections)
                total_confidence += sum(d['confidence'] for d in detections)
                frame_results.append({
                    'frame_index': frame_index,
                    'timestamp': timestamp,
                    'detections': detections,
                    'detection_count': len(detections)
                })

                # A failed thumbnail must not abort the whole video.
                if len(key_frames) < _MAX_KEY_FRAMES:
                    try:
                        key_frames.append(
                            _build_key_frame(annotated_frame, frame_index, timestamp, detections)
                        )
                    except Exception as e:
                        logger.warning(f"关键帧截图失败 frame={frame_index}: {e}")

            frame_index += 1

        # Both handles must be closed before ffmpeg reads the output file.
        cap.release()
        cap = None
        out.release()
        out = None

        processing_time = time.time() - start_time
        avg_confidence = total_confidence / total_detections if total_detections > 0 else 0

        # Some containers report no (or a negative) frame count; fall back to
        # the number of frames actually read.
        if total_frames <= 0:
            total_frames = frame_index

        results_dir = os.path.join(base_dir, "static", "results")
        os.makedirs(results_dir, exist_ok=True)

        result_filename, result_path, ffmpeg_success = _publish_result_video(
            temp_output, results_dir, used_codec
        )
        # Without a successful transcode the published file keeps OpenCV's
        # codec, so the client is warned about it.
        video_codec_warning = not ffmpeg_success

        video_valid = bool(
            result_path
            and os.path.exists(result_path)
            and os.path.getsize(result_path) > _MIN_VALID_VIDEO_BYTES
        )

        video_url = f"/static/results/{result_filename}"

        return VideoDetectionResult(
            success=True,
            message="视频检测完成",
            data={
                "video_url": video_url,
                "video_valid": video_valid,
                "video_codec_warning": video_codec_warning,
                "key_frames": key_frames,
                "frame_results": frame_results,
                "stats": {
                    "total_frames": total_frames,
                    "frames_with_detections": frames_with_detections,
                    "total_detections": total_detections,
                    "avg_confidence": round(avg_confidence, 3),
                    "processing_time": round(processing_time, 2),
                    "fps": round(total_frames / processing_time, 2) if processing_time > 0 else 0,
                    "model_used": model_id,
                    "video_info": {
                        "width": width,
                        "height": height,
                        "original_fps": round(fps, 2),
                        "duration": round(total_frames / fps, 2) if fps > 0 else 0,
                        "used_codec": used_codec,
                        "ffmpeg_transcoded": ffmpeg_success
                    }
                }
            }
        )

    except Exception as e:
        logger.error(f"视频检测失败: {e}")
        import traceback
        logger.error(f"详细错误: {traceback.format_exc()}")
        return VideoDetectionResult(
            success=False,
            message=f"视频检测失败: {str(e)}",
            data={}
        )
    finally:
        # Runs on every exit path: release OpenCV handles that are still
        # open and delete the temporary input/output files.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        for temp_path in (temp_input, temp_output):
            _remove_quietly(temp_path)