diff --git a/apps/server/services/adapters/__init__.py b/apps/server/services/adapters/__init__.py new file mode 100644 index 0000000..1a7aba6 --- /dev/null +++ b/apps/server/services/adapters/__init__.py @@ -0,0 +1,9 @@ +"""检测结果适配器子包。 + +将各检测服务 (YOLO / Paddle / Action Docker 等) 的原始输出 +适配为 ``models.event_schemas.DetectionResult``。 +""" + +from .detection_adapter import DetectionAdapter + +__all__ = ["DetectionAdapter"] diff --git a/apps/server/services/adapters/detection_adapter.py b/apps/server/services/adapters/detection_adapter.py new file mode 100644 index 0000000..4bc1229 --- /dev/null +++ b/apps/server/services/adapters/detection_adapter.py @@ -0,0 +1,212 @@ +"""检测结果适配器 (MVP-1 / P0) + +负责将各检测服务的原始 dict 输出,转换为统一的 +``DetectionResult`` 数据契约。 + +输入示例 (旧格式):: + + { + "success": True, + "message": "检测完成", + "detections": [ + {"class": "fire", "label": "火焰", + "confidence": 0.87, "bbox": [10, 20, 100, 200], + "track_id": 5 # 可选 + }, + ... + ], + "stats": {...} + } + +输出: ``DetectionResult`` +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional + +from models.event_schemas import ( + BBox, + DetectionResult, + DetectionSource, + DetectionStats, + UnifiedDetection, +) + +logger = logging.getLogger(__name__) + + +class DetectionAdapter: + """检测结果适配器。 + + 覆盖以下 4 种来源: + + 1. YOLO 系列 (``DetectionService.detect_image`` / ``detect_frame``) + 2. PaddleDetection 抽烟 (``PaddleDetectionService``) + 3. 车辆检测 (``VehicleDetectionService``,含 track_id) + 4. ppTSM 行为识别 Docker (``ActionDetectionService``) + """ + + @staticmethod + def from_yolo( + raw: Dict[str, Any], + model_id: Optional[str] = None, + source: DetectionSource = DetectionSource.YOLO, + ) -> DetectionResult: + """适配 YOLO / 复合检测产出的 dict。""" + + return DetectionAdapter._from_generic(raw, source=source, model_id=model_id) + + @staticmethod + def from_paddle( + raw: Dict[str, Any], model_id: Optional[str] = "smoking_detection" + ) -> DetectionResult: + """适配 PaddleDetection (抽烟) 输出。""" + + return DetectionAdapter._from_generic( + raw, source=DetectionSource.PADDLE, model_id=model_id + ) + + @staticmethod + def from_vehicle( + raw: Dict[str, Any], model_id: Optional[str] = "vehicle_detection" + ) -> DetectionResult: + """适配车辆检测 (含 track_id) 输出。""" + + return DetectionAdapter._from_generic( + raw, source=DetectionSource.PADDLE, model_id=model_id + ) + + @staticmethod + def from_action_docker( + raw: Dict[str, Any], model_id: Optional[str] = "ppTSM_fight" + ) -> DetectionResult: + """适配 Docker ppTSM 行为识别输出。 + + 其原始检测使用 ``class_id`` 字段而非 ``class``,此处统一归一化。 + """ + + normalized = dict(raw) + normalized_dets: List[Dict[str, Any]] = [] + for det in raw.get("detections", []) or []: + d = dict(det) + if "class" not in d: + cls_id = d.get("class_id", 0) + d["class"] = "fight" if int(cls_id) == 0 else "normal" + if "label" not in d: + d["label"] = "打架" if d["class"] == "fight" else "正常" + normalized_dets.append(d) + normalized["detections"] = normalized_dets + + return DetectionAdapter._from_generic( + normalized, source=DetectionSource.ACTION_DOCKER, model_id=model_id + ) + + # ------------------------------------------------------------------ + # 内部 + # ------------------------------------------------------------------ + + @staticmethod + def _from_generic( + raw: Dict[str, Any], + source: DetectionSource, + model_id: Optional[str], + ) -> DetectionResult: + """通用归一化逻辑。""" + + if not isinstance(raw, dict): + raise TypeError(f"raw 必须为 dict,收到 {type(raw)}") + + success = bool(raw.get("success", True)) + message = str(raw.get("message", "")) + raw_dets = raw.get("detections") or [] + + unified_dets: List[UnifiedDetection] = [] + for det in raw_dets: + try: + unified_dets.append( + DetectionAdapter._build_detection(det, source, model_id) + ) + except Exception as exc: # noqa: BLE001 - 单条失败不阻塞整体 + logger.warning("适配检测项失败,已跳过: %s, raw=%s", exc, det) + + stats = DetectionAdapter._build_stats(raw.get("stats"), unified_dets, model_id) + + return DetectionResult( + success=success, + message=message, + source=source, + model_id=model_id, + detections=unified_dets, + stats=stats, + ) + + @staticmethod + def _build_detection( + det: Dict[str, Any], + source: DetectionSource, + model_id: Optional[str], + ) -> UnifiedDetection: + """构建单个 UnifiedDetection。""" + + bbox_raw = det.get("bbox") or det.get("box") + if not bbox_raw or len(bbox_raw) < 4: + raise ValueError(f"缺少有效 bbox: {det}") + + x1, y1, x2, y2 = (int(v) for v in bbox_raw[:4]) + # 确保 x2 >= x1, y2 >= y1 + if x2 < x1: + x1, x2 = x2, x1 + if y2 < y1: + y1, y2 = y2, y1 + + class_name = det.get("class") or det.get("class_name") or "unknown" + label = det.get("label", class_name) + confidence = float(det.get("confidence", det.get("score", 0.0))) + confidence = max(0.0, min(1.0, confidence)) + + track_id = det.get("track_id") + if track_id is not None: + track_id = int(track_id) + + extra: Dict[str, Any] = {} + for key in ("center", "plate_number", "is_illegal_parking", "trajectory"): + if key in det: + extra[key] = det[key] + + return UnifiedDetection( + track_id=track_id, + class_name=str(class_name), + label=str(label), + confidence=confidence, + bbox=BBox(x1=x1, y1=y1, x2=x2, y2=y2), + source=source, + model_id=model_id, + extra=extra, + ) + + @staticmethod + def _build_stats( + raw_stats: Optional[Dict[str, Any]], + detections: List[UnifiedDetection], + model_id: Optional[str], + ) -> DetectionStats: + """构建 DetectionStats,若原始 stats 缺字段则按 detections 重算。""" + + if not detections: + avg_conf = 0.0 + else: + avg_conf = sum(d.confidence for d in detections) / len(detections) + + raw_stats = raw_stats or {} + return DetectionStats( + total_detections=int(raw_stats.get("total_detections", len(detections))), + avg_confidence=float(raw_stats.get("avg_confidence", round(avg_conf, 3))), + processing_time=float(raw_stats.get("processing_time", 0.0)), + model_used=raw_stats.get("model_used", model_id), + fps=raw_stats.get("fps"), + ) + + +__all__ = ["DetectionAdapter"]