449 lines
17 KiB
Python
449 lines
17 KiB
Python
import cv2
|
||
import numpy as np
|
||
import base64
|
||
import logging
|
||
import json
|
||
import os
|
||
import time
|
||
import uuid
|
||
from typing import Optional
|
||
from fastapi import APIRouter, UploadFile, File, Form, Query
|
||
from models.schemas import ImageDetectionResult, VideoDetectionResult
|
||
|
||
# Router object that the endpoints defined in this module are registered on.
router = APIRouter()
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@router.post("/detect/image", response_model=ImageDetectionResult)
async def detect_image(
    file: UploadFile = File(...),
    model_id: str = Query("fire_detection"),
    confidence: float = Query(0.5),
    iou: float = Query(0.45),
    algorithm_config: Optional[str] = Query(None, description="算法配置JSON字符串")
):
    """
    Image detection endpoint.

    Decodes the uploaded image, runs it through the detection service and,
    on success, returns the detections together with a base64-encoded JPEG
    of the annotated frame. Every failure is reported as a result object
    with ``success=False``; no exception is propagated to the caller.

    Args:
        file: Uploaded image file.
        model_id: Identifier of the model, forwarded to the detection service.
        confidence: Confidence value, forwarded to the detection service.
        iou: IoU value, forwarded to the detection service.
        algorithm_config: Optional algorithm configuration as a JSON object
            string, for example:
            {
                "enable_stationary_detection": true,
                "enable_loitering_detection": false,
                "stationary_threshold": 10.0,
                "position_tolerance": 50,
                "loitering_threshold": 300.0,
                "movement_threshold": 5.0
            }
            A value that is not valid JSON, or is valid JSON but not an
            object, is logged and ignored.

    Returns:
        ImageDetectionResult: on success ``data`` holds ``detections``,
        ``image_base64``, ``stats``, ``alerts`` and ``behavior_stats``;
        on failure ``data`` is empty and ``message`` describes the problem.
    """
    # Imported lazily inside the handler rather than at module import time
    # (presumably to avoid a circular import with ``main`` - confirm).
    from main import model_service
    from services.detection_service import DetectionService

    detection_service = DetectionService(model_service)

    # Parse the optional algorithm configuration.
    algo_config = None
    if algorithm_config:
        try:
            algo_config = json.loads(algorithm_config)
        except json.JSONDecodeError as e:
            logger.warning(f"算法配置解析失败: {e}")
        else:
            # The documented format is a JSON object; anything else
            # (list, number, string, ...) is ignored like a parse failure.
            if algo_config is not None and not isinstance(algo_config, dict):
                logger.warning(
                    f"算法配置解析失败: 期望JSON对象, 实际为 {type(algo_config).__name__}"
                )
                algo_config = None

    try:
        contents = await file.read()

        # cv2.imdecode rejects an empty buffer with an assertion error, so
        # report an empty upload as an unreadable image instead.
        if not contents:
            return ImageDetectionResult(
                success=False,
                message="无法读取图片",
                data={}
            )

        nparr = np.frombuffer(contents, np.uint8)
        frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        if frame is None:
            return ImageDetectionResult(
                success=False,
                message="无法读取图片",
                data={}
            )

        result = await detection_service.detect_image(
            frame, model_id, confidence, iou, algorithm_config=algo_config
        )

        if not result['success']:
            return ImageDetectionResult(
                success=False,
                message=result['message'],
                data={}
            )

        annotated_frame = detection_service.draw_detections(
            frame, result['detections'], algorithm_config=algo_config
        )

        # Encode the annotated image as JPEG, then base64 for the JSON body.
        encoded_ok, buffer = cv2.imencode('.jpg', annotated_frame)
        if not encoded_ok:
            return ImageDetectionResult(
                success=False,
                message="检测失败: 标注图片编码失败",
                data={}
            )
        img_base64 = base64.b64encode(buffer).decode('utf-8')

        return ImageDetectionResult(
            success=True,
            message="检测完成",
            data={
                "detections": result['detections'],
                "image_base64": img_base64,
                "stats": result['stats'],
                "alerts": result.get('alerts', []),
                "behavior_stats": result.get('behavior_stats', {})
            }
        )

    except Exception as e:
        # Top-level boundary of the endpoint: log with traceback and convert
        # the error into a failure result.
        logger.error(f"图片检测失败: {e}", exc_info=True)
        return ImageDetectionResult(
            success=False,
            message=f"检测失败: {str(e)}",
            data={}
        )
|
||
|
||
|
||
@router.get("/algorithms/config")
async def get_algorithm_config():
    """Return the available algorithm configuration options.

    The payload is static: a dict with one key, ``algorithms``, listing each
    algorithm with its id, display name, description and numeric parameters.
    """

    def number_param(name, label, default, lower, upper, unit, description):
        # One numeric parameter descriptor; key order matches the response
        # layout (name, label, type, default, min, max, unit, description).
        return {
            "name": name,
            "label": label,
            "type": "number",
            "default": default,
            "min": lower,
            "max": upper,
            "unit": unit,
            "description": description,
        }

    stationary_detection = {
        "id": "stationary_detection",
        "name": "静止检测",
        "description": "检测人员在同一位置静止停留",
        "params": [
            number_param(
                "stationary_threshold", "静止阈值",
                10.0, 1.0, 300.0, "秒", "超过此时间视为静止",
            ),
            number_param(
                "position_tolerance", "位置容差",
                50, 10, 200, "像素", "位置匹配容差范围",
            ),
        ],
    }

    loitering_detection = {
        "id": "loitering_detection",
        "name": "徘徊检测",
        "description": "检测人员长时间停留(需要跟踪ID)",
        "params": [
            number_param(
                "loitering_threshold", "徘徊阈值",
                300.0, 60.0, 1800.0, "秒", "超过此时间视为徘徊",
            ),
            number_param(
                "movement_threshold", "移动阈值",
                5.0, 1.0, 50.0, "像素", "小于此移动视为静止",
            ),
        ],
    }

    return {"algorithms": [stationary_detection, loitering_detection]}
|
||
|
||
|
||
def _video_remove_quietly(path):
    """Best-effort removal of ``path``; a failure is logged, never raised."""
    if not path:
        return
    try:
        if os.path.exists(path):
            os.remove(path)
    except OSError as e:
        logger.debug(f"[detect_video] failed to remove temp file {path}: {e}")


def _video_open_writer(path, fps, size):
    """Open a cv2.VideoWriter on ``path``, trying 'avc1' first and then 'mp4v'.

    Returns ``(writer, codec)`` for the first codec that opens, or
    ``(None, None)`` when neither codec is available.
    """
    for codec in ('avc1', 'mp4v'):
        fourcc = cv2.VideoWriter_fourcc(*codec)
        writer = cv2.VideoWriter(path, fourcc, fps, size)
        if writer.isOpened():
            logger.info(f"[detect_video] 使用视频编码器: {codec}, 输出: {path}")
            return writer, codec
        writer.release()
    return None, None


def _video_build_key_frame(annotated_frame, frame_index, timestamp, detections, max_width=640):
    """Build a key-frame entry holding a JPEG data URI of ``annotated_frame``.

    Frames wider than ``max_width`` are scaled down (aspect ratio kept) to
    keep the response small. Raises ValueError when JPEG encoding fails.
    """
    h, w = annotated_frame.shape[:2]
    if w > max_width:
        scale = max_width / w
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        thumb = cv2.resize(annotated_frame, new_size, interpolation=cv2.INTER_AREA)
    else:
        thumb = annotated_frame

    encoded_ok, buffer = cv2.imencode('.jpg', thumb, [cv2.IMWRITE_JPEG_QUALITY, 85])
    if not encoded_ok:
        raise ValueError("JPEG encoding failed")
    img_base64 = base64.b64encode(buffer).decode('utf-8')
    return {
        'frame_index': frame_index,
        'timestamp': timestamp,
        'image': f'data:image/jpeg;base64,{img_base64}',
        'detections': detections
    }


def _video_transcode_to_h264(src_path, dst_path, timeout=300):
    """Transcode ``src_path`` into an H.264 MP4 at ``dst_path`` using ffmpeg.

    Raises:
        FileNotFoundError: the ``ffmpeg`` executable could not be found.
        subprocess.TimeoutExpired: ffmpeg ran longer than ``timeout`` seconds.
        RuntimeError: ffmpeg exited non-zero, or its output file is missing
            or not larger than 1024 bytes.
    """
    import subprocess

    ffmpeg_cmd = [
        'ffmpeg', '-y', '-i', src_path,
        '-c:v', 'libx264', '-preset', 'fast',
        '-pix_fmt', 'yuv420p',
        '-movflags', '+faststart',
        '-an',  # drop audio (avoids audio encoding problems)
        dst_path
    ]
    logger.info(f"[detect_video] 开始 ffmpeg 转码: {' '.join(ffmpeg_cmd[:5])} ... -> {dst_path}")
    result = subprocess.run(ffmpeg_cmd, capture_output=True, timeout=timeout)

    if result.returncode != 0:
        stderr = result.stderr.decode('utf-8', errors='ignore')[:800]
        logger.warning(f"[detect_video] ffmpeg 转码失败 (rc={result.returncode}): {stderr}")
        raise RuntimeError(f"ffmpeg 失败: {stderr}")

    # A zero exit code is not enough: make sure a non-trivial file exists.
    if not (os.path.exists(dst_path) and os.path.getsize(dst_path) > 1024):
        logger.warning(f"[detect_video] ffmpeg 输出文件无效或为空")
        raise RuntimeError("ffmpeg 输出文件无效")

    logger.info(f"[detect_video] ffmpeg 转码成功: {dst_path} ({os.path.getsize(dst_path)} bytes)")


@router.post("/detect/video", response_model=VideoDetectionResult)
async def detect_video(
    file: UploadFile = File(...),
    model_id: str = Query("fight_detection"),
    confidence: float = Query(0.5),
    iou: float = Query(0.45)
):
    """
    Video detection endpoint: runs inference frame by frame and returns the
    annotated video together with per-frame detection results.

    Only models whose configured type is 'yolov8' or 'yolov10' are accepted,
    and only uploads with a .mp4/.avi/.mov/.mkv/.webm extension.

    Args:
        file: Uploaded video file.
        model_id: Identifier of the model; must exist in
            ``model_service.model_configs``.
        confidence: Confidence value, forwarded to the detection service.
        iou: IoU value, forwarded to the detection service.

    Returns:
        VideoDetectionResult: on success ``data`` holds ``video_url``,
        ``video_valid``, ``video_codec_warning``, up to 20 ``key_frames``,
        ``frame_results`` (only frames with detections) and ``stats``.
        On failure ``success`` is False, ``data`` is empty and ``message``
        describes the problem; no exception is propagated to the caller.
    """
    import asyncio
    import shutil
    # Imported lazily inside the handler rather than at module import time
    # (presumably to avoid a circular import with ``main`` - confirm).
    from main import model_service
    from services.detection_service import DetectionService

    detection_service = DetectionService(model_service)

    # Validate that the model exists.
    model_config = model_service.model_configs.get(model_id)
    if not model_config:
        return VideoDetectionResult(
            success=False,
            message=f"未知模型: {model_id}",
            data={}
        )

    # Only YOLO-type models are allowed for video detection.
    if model_config['type'] not in ('yolov8', 'yolov10'):
        return VideoDetectionResult(
            success=False,
            message=f"模型 {model_id} 不支持视频检测(仅支持YOLO类型模型)",
            data={}
        )

    # Validate the file extension.
    filename = file.filename or ""
    ext = os.path.splitext(filename)[1].lower()
    allowed_exts = ['.mp4', '.avi', '.mov', '.mkv', '.webm']
    if ext not in allowed_exts:
        return VideoDetectionResult(
            success=False,
            message=f"不支持的视频格式: {ext},支持: {', '.join(allowed_exts)}",
            data={}
        )

    max_key_frames = 20           # cap on screenshots so the response stays small
    upload_chunk_size = 1024 * 1024
    fallback_writer_fps = 25.0    # used only when the container reports no FPS

    # Everything below is released / removed in ``finally`` so that early
    # returns and exceptions do not leak handles or temp files.
    temp_input = None
    temp_output = None
    cap = None
    out = None

    try:
        base_dir = os.path.dirname(os.path.dirname(__file__))

        # Save the uploaded video to a temporary file.
        temp_dir = os.path.join(base_dir, "static", "temp")
        os.makedirs(temp_dir, exist_ok=True)

        temp_input = os.path.join(temp_dir, f"input_{uuid.uuid4().hex}{ext}")
        temp_output = os.path.join(temp_dir, f"output_{uuid.uuid4().hex}.mp4")

        # Copy in chunks instead of loading the whole upload into memory.
        with open(temp_input, "wb") as upload_target:
            while True:
                chunk = await file.read(upload_chunk_size)
                if not chunk:
                    break
                upload_target.write(chunk)

        # Open the video file.
        cap = cv2.VideoCapture(temp_input)
        if not cap.isOpened():
            return VideoDetectionResult(
                success=False,
                message="无法打开视频文件",
                data={}
            )

        # Read the video properties.
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Prefer a browser-compatible encoder (avc1/H.264 first, mp4v as
        # fallback). A non-positive FPS cannot be used for writing, so a
        # default is substituted for the writer only; the reported
        # ``original_fps`` keeps the raw value.
        writer_fps = fps if fps > 0 else fallback_writer_fps
        out, used_codec = _video_open_writer(temp_output, writer_fps, (width, height))
        if out is None:
            return VideoDetectionResult(
                success=False,
                message="无法初始化视频写入器(无可用的编码器)",
                data={}
            )

        frame_results = []
        key_frames = []  # screenshots of frames in which something was detected
        total_detections = 0
        total_confidence = 0.0
        frames_with_detections = 0
        start_time = time.time()

        frame_index = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Per-frame inference.
            annotated_frame, result_data = await detection_service.detect_frame(
                frame, model_id, confidence, iou, draw=True
            )

            # Write the annotated frame to the output video.
            out.write(annotated_frame)

            # Collect per-frame results and key-frame screenshots.
            if result_data['success'] and result_data['detections']:
                detections = result_data['detections']
                timestamp = round(frame_index / fps, 2) if fps > 0 else 0

                frames_with_detections += 1
                total_detections += len(detections)
                total_confidence += sum(d['confidence'] for d in detections)

                frame_results.append({
                    'frame_index': frame_index,
                    'timestamp': timestamp,
                    'detections': detections,
                    'detection_count': len(detections)
                })

                if len(key_frames) < max_key_frames:
                    try:
                        key_frames.append(
                            _video_build_key_frame(
                                annotated_frame, frame_index, timestamp, detections
                            )
                        )
                    except Exception as e:
                        # A failed screenshot must not abort the whole run.
                        logger.warning(f"关键帧截图失败 frame={frame_index}: {e}")

            frame_index += 1

        # Release both handles before ffmpeg reads the output file, so the
        # writer has flushed and finalized the container.
        cap.release()
        cap = None
        out.release()
        out = None

        processing_time = time.time() - start_time
        avg_confidence = total_confidence / total_detections if total_detections > 0 else 0

        # Some containers do not report a frame count; fall back to the
        # number of frames actually processed so the stats stay meaningful.
        if total_frames <= 0:
            total_frames = frame_index

        # The final video is placed in the results directory.
        results_dir = os.path.join(base_dir, "static", "results")
        os.makedirs(results_dir, exist_ok=True)

        ffmpeg_success = False
        video_codec_warning = False

        # First try to transcode to a browser-compatible H.264 MP4 with
        # ffmpeg; the output always uses the .mp4 extension.
        mp4_filename = f"fight_detect_{uuid.uuid4().hex}.mp4"
        mp4_path = os.path.join(results_dir, mp4_filename)

        try:
            # Run ffmpeg in a worker thread: it may take minutes and would
            # otherwise block the event loop for every other request.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None, _video_transcode_to_h264, temp_output, mp4_path
            )
            ffmpeg_success = True
            result_filename = mp4_filename
            result_path = mp4_path
        except Exception as e:
            if isinstance(e, FileNotFoundError):
                logger.warning(f"[detect_video] ffmpeg 命令未找到 (PATH可能未生效): {e}")
            else:
                logger.warning(f"[detect_video] ffmpeg 异常,回退到原始文件: {e}")

            # Do not leave a partial transcode behind in the results dir.
            _video_remove_quietly(mp4_path)

            # Fall back to serving the raw OpenCV output.
            # NOTE(review): the OpenCV output was written to a ".mp4" path, so
            # the ".avi" name used for the mp4v case looks like it mislabels
            # the container; kept unchanged for client compatibility - confirm.
            fallback_ext = '.mp4' if used_codec == 'avc1' else '.avi'
            result_filename = f"fight_detect_{uuid.uuid4().hex}{fallback_ext}"
            result_path = os.path.join(results_dir, result_filename)
            shutil.move(temp_output, result_path)
            video_codec_warning = True

        # Check that the output file exists and is non-trivial.
        video_valid = bool(
            result_path
            and os.path.exists(result_path)
            and os.path.getsize(result_path) > 1024
        )

        video_url = f"/static/results/{result_filename}"

        return VideoDetectionResult(
            success=True,
            message="视频检测完成",
            data={
                "video_url": video_url,
                "video_valid": video_valid,
                "video_codec_warning": video_codec_warning,
                "key_frames": key_frames,
                "frame_results": frame_results,
                "stats": {
                    "total_frames": total_frames,
                    "frames_with_detections": frames_with_detections,
                    "total_detections": total_detections,
                    "avg_confidence": round(avg_confidence, 3),
                    "processing_time": round(processing_time, 2),
                    "fps": round(total_frames / processing_time, 2) if processing_time > 0 else 0,
                    "model_used": model_id,
                    "video_info": {
                        "width": width,
                        "height": height,
                        "original_fps": round(fps, 2),
                        "duration": round(total_frames / fps, 2) if fps > 0 else 0,
                        "used_codec": used_codec,
                        "ffmpeg_transcoded": ffmpeg_success
                    }
                }
            }
        )

    except Exception as e:
        # Top-level boundary of the endpoint: log with traceback and convert
        # the error into a failure result.
        logger.error(f"视频检测失败: {e}")
        import traceback
        logger.error(f"详细错误: {traceback.format_exc()}")
        return VideoDetectionResult(
            success=False,
            message=f"视频检测失败: {str(e)}",
            data={}
        )

    finally:
        # Runs on success, early return and error alike.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        for temp_path in (temp_input, temp_output):
            _video_remove_quietly(temp_path)
|