feat:新增 DetectionAdapter 统一 4 种检测服务输出
This commit is contained in:
212
apps/server/services/adapters/detection_adapter.py
Normal file
212
apps/server/services/adapters/detection_adapter.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""检测结果适配器 (MVP-1 / P0)
|
||||
|
||||
负责将各检测服务的原始 dict 输出,转换为统一的
|
||||
``DetectionResult`` 数据契约。
|
||||
|
||||
输入示例 (旧格式)::
|
||||
|
||||
{
|
||||
"success": True,
|
||||
"message": "检测完成",
|
||||
"detections": [
|
||||
{"class": "fire", "label": "火焰",
|
||||
"confidence": 0.87, "bbox": [10, 20, 100, 200],
|
||||
"track_id": 5 # 可选
|
||||
},
|
||||
...
|
||||
],
|
||||
"stats": {...}
|
||||
}
|
||||
|
||||
输出: ``DetectionResult``
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from models.event_schemas import (
|
||||
BBox,
|
||||
DetectionResult,
|
||||
DetectionSource,
|
||||
DetectionStats,
|
||||
UnifiedDetection,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DetectionAdapter:
|
||||
"""检测结果适配器。
|
||||
|
||||
覆盖以下 4 种来源:
|
||||
|
||||
1. YOLO 系列 (``DetectionService.detect_image`` / ``detect_frame``)
|
||||
2. PaddleDetection 抽烟 (``PaddleDetectionService``)
|
||||
3. 车辆检测 (``VehicleDetectionService``,含 track_id)
|
||||
4. ppTSM 行为识别 Docker (``ActionDetectionService``)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def from_yolo(
|
||||
raw: Dict[str, Any],
|
||||
model_id: Optional[str] = None,
|
||||
source: DetectionSource = DetectionSource.YOLO,
|
||||
) -> DetectionResult:
|
||||
"""适配 YOLO / 复合检测产出的 dict。"""
|
||||
|
||||
return DetectionAdapter._from_generic(raw, source=source, model_id=model_id)
|
||||
|
||||
@staticmethod
|
||||
def from_paddle(
|
||||
raw: Dict[str, Any], model_id: Optional[str] = "smoking_detection"
|
||||
) -> DetectionResult:
|
||||
"""适配 PaddleDetection (抽烟) 输出。"""
|
||||
|
||||
return DetectionAdapter._from_generic(
|
||||
raw, source=DetectionSource.PADDLE, model_id=model_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_vehicle(
|
||||
raw: Dict[str, Any], model_id: Optional[str] = "vehicle_detection"
|
||||
) -> DetectionResult:
|
||||
"""适配车辆检测 (含 track_id) 输出。"""
|
||||
|
||||
return DetectionAdapter._from_generic(
|
||||
raw, source=DetectionSource.PADDLE, model_id=model_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_action_docker(
|
||||
raw: Dict[str, Any], model_id: Optional[str] = "ppTSM_fight"
|
||||
) -> DetectionResult:
|
||||
"""适配 Docker ppTSM 行为识别输出。
|
||||
|
||||
其原始检测使用 ``class_id`` 字段而非 ``class``,此处统一归一化。
|
||||
"""
|
||||
|
||||
normalized = dict(raw)
|
||||
normalized_dets: List[Dict[str, Any]] = []
|
||||
for det in raw.get("detections", []) or []:
|
||||
d = dict(det)
|
||||
if "class" not in d:
|
||||
cls_id = d.get("class_id", 0)
|
||||
d["class"] = "fight" if int(cls_id) == 0 else "normal"
|
||||
if "label" not in d:
|
||||
d["label"] = "打架" if d["class"] == "fight" else "正常"
|
||||
normalized_dets.append(d)
|
||||
normalized["detections"] = normalized_dets
|
||||
|
||||
return DetectionAdapter._from_generic(
|
||||
normalized, source=DetectionSource.ACTION_DOCKER, model_id=model_id
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 内部
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _from_generic(
|
||||
raw: Dict[str, Any],
|
||||
source: DetectionSource,
|
||||
model_id: Optional[str],
|
||||
) -> DetectionResult:
|
||||
"""通用归一化逻辑。"""
|
||||
|
||||
if not isinstance(raw, dict):
|
||||
raise TypeError(f"raw 必须为 dict,收到 {type(raw)}")
|
||||
|
||||
success = bool(raw.get("success", True))
|
||||
message = str(raw.get("message", ""))
|
||||
raw_dets = raw.get("detections") or []
|
||||
|
||||
unified_dets: List[UnifiedDetection] = []
|
||||
for det in raw_dets:
|
||||
try:
|
||||
unified_dets.append(
|
||||
DetectionAdapter._build_detection(det, source, model_id)
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001 - 单条失败不阻塞整体
|
||||
logger.warning("适配检测项失败,已跳过: %s, raw=%s", exc, det)
|
||||
|
||||
stats = DetectionAdapter._build_stats(raw.get("stats"), unified_dets, model_id)
|
||||
|
||||
return DetectionResult(
|
||||
success=success,
|
||||
message=message,
|
||||
source=source,
|
||||
model_id=model_id,
|
||||
detections=unified_dets,
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_detection(
|
||||
det: Dict[str, Any],
|
||||
source: DetectionSource,
|
||||
model_id: Optional[str],
|
||||
) -> UnifiedDetection:
|
||||
"""构建单个 UnifiedDetection。"""
|
||||
|
||||
bbox_raw = det.get("bbox") or det.get("box")
|
||||
if not bbox_raw or len(bbox_raw) < 4:
|
||||
raise ValueError(f"缺少有效 bbox: {det}")
|
||||
|
||||
x1, y1, x2, y2 = (int(v) for v in bbox_raw[:4])
|
||||
# 确保 x2 >= x1, y2 >= y1
|
||||
if x2 < x1:
|
||||
x1, x2 = x2, x1
|
||||
if y2 < y1:
|
||||
y1, y2 = y2, y1
|
||||
|
||||
class_name = det.get("class") or det.get("class_name") or "unknown"
|
||||
label = det.get("label", class_name)
|
||||
confidence = float(det.get("confidence", det.get("score", 0.0)))
|
||||
confidence = max(0.0, min(1.0, confidence))
|
||||
|
||||
track_id = det.get("track_id")
|
||||
if track_id is not None:
|
||||
track_id = int(track_id)
|
||||
|
||||
extra: Dict[str, Any] = {}
|
||||
for key in ("center", "plate_number", "is_illegal_parking", "trajectory"):
|
||||
if key in det:
|
||||
extra[key] = det[key]
|
||||
|
||||
return UnifiedDetection(
|
||||
track_id=track_id,
|
||||
class_name=str(class_name),
|
||||
label=str(label),
|
||||
confidence=confidence,
|
||||
bbox=BBox(x1=x1, y1=y1, x2=x2, y2=y2),
|
||||
source=source,
|
||||
model_id=model_id,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_stats(
|
||||
raw_stats: Optional[Dict[str, Any]],
|
||||
detections: List[UnifiedDetection],
|
||||
model_id: Optional[str],
|
||||
) -> DetectionStats:
|
||||
"""构建 DetectionStats,若原始 stats 缺字段则按 detections 重算。"""
|
||||
|
||||
if not detections:
|
||||
avg_conf = 0.0
|
||||
else:
|
||||
avg_conf = sum(d.confidence for d in detections) / len(detections)
|
||||
|
||||
raw_stats = raw_stats or {}
|
||||
return DetectionStats(
|
||||
total_detections=int(raw_stats.get("total_detections", len(detections))),
|
||||
avg_confidence=float(raw_stats.get("avg_confidence", round(avg_conf, 3))),
|
||||
processing_time=float(raw_stats.get("processing_time", 0.0)),
|
||||
model_used=raw_stats.get("model_used", model_id),
|
||||
fps=raw_stats.get("fps"),
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["DetectionAdapter"]
|
||||
Reference in New Issue
Block a user