213 lines
6.5 KiB
Python
213 lines
6.5 KiB
Python
"""检测结果适配器 (MVP-1 / P0)

负责将各检测服务的原始 dict 输出,转换为统一的
``DetectionResult`` 数据契约。

输入示例 (旧格式)::

    {
        "success": True,
        "message": "检测完成",
        "detections": [
            {"class": "fire", "label": "火焰",
             "confidence": 0.87, "bbox": [10, 20, 100, 200],
             "track_id": 5  # 可选
            },
            ...
        ],
        "stats": {...}
    }

输出: ``DetectionResult``
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from models.event_schemas import (
|
||
BBox,
|
||
DetectionResult,
|
||
DetectionSource,
|
||
DetectionStats,
|
||
UnifiedDetection,
|
||
)
|
||
|
||
# Module-level logger; the adapter uses it to report raw detections that are
# skipped because they could not be converted.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class DetectionAdapter:
    """Convert raw detection-service dicts into the unified ``DetectionResult``.

    Four sources are covered:

    1. YOLO family (``DetectionService.detect_image`` / ``detect_frame``)
    2. PaddleDetection smoking detection (``PaddleDetectionService``)
    3. Vehicle detection (``VehicleDetectionService``, carries ``track_id``)
    4. ppTSM action recognition running in Docker (``ActionDetectionService``)

    Every method is a ``staticmethod``; the class only serves as a namespace.
    A single malformed detection is logged and skipped instead of failing the
    whole result.
    """

    # Optional per-detection keys copied verbatim into ``UnifiedDetection.extra``.
    _EXTRA_KEYS = ("center", "plate_number", "is_illegal_parking", "trajectory")

    @staticmethod
    def from_yolo(
        raw: Dict[str, Any],
        model_id: Optional[str] = None,
        source: DetectionSource = DetectionSource.YOLO,
    ) -> DetectionResult:
        """Adapt the dict produced by YOLO / composite detection.

        Args:
            raw: Raw service output (shape shown in the module docstring).
            model_id: Model identifier stamped on the result and on every
                detection.
            source: Source tag to stamp; defaults to ``DetectionSource.YOLO``.

        Returns:
            The normalized ``DetectionResult``.

        Raises:
            TypeError: If ``raw`` is not a dict.
        """
        return DetectionAdapter._from_generic(raw, source=source, model_id=model_id)

    @staticmethod
    def from_paddle(
        raw: Dict[str, Any], model_id: Optional[str] = "smoking_detection"
    ) -> DetectionResult:
        """Adapt PaddleDetection (smoking) output, tagged ``DetectionSource.PADDLE``.

        Raises:
            TypeError: If ``raw`` is not a dict.
        """
        return DetectionAdapter._from_generic(
            raw, source=DetectionSource.PADDLE, model_id=model_id
        )

    @staticmethod
    def from_vehicle(
        raw: Dict[str, Any], model_id: Optional[str] = "vehicle_detection"
    ) -> DetectionResult:
        """Adapt vehicle-detection output (detections may carry ``track_id``).

        Keys such as ``plate_number`` and ``is_illegal_parking`` are kept in
        each detection's ``extra``.

        Raises:
            TypeError: If ``raw`` is not a dict.
        """
        # NOTE(review): results are tagged DetectionSource.PADDLE, the same as
        # smoking detection. Confirm this is intended (e.g. a PaddleDetection
        # based vehicle model) rather than a copy-paste of from_paddle.
        return DetectionAdapter._from_generic(
            raw, source=DetectionSource.PADDLE, model_id=model_id
        )

    @staticmethod
    def from_action_docker(
        raw: Dict[str, Any], model_id: Optional[str] = "ppTSM_fight"
    ) -> DetectionResult:
        """Adapt Docker ppTSM action-recognition output.

        Its detections carry a numeric ``class_id`` instead of ``class``; each
        item is normalized by ``_normalize_action_detection`` before the
        generic conversion. The caller's dict is not mutated.

        Raises:
            TypeError: If ``raw`` is not a mapping.
        """
        source = DetectionSource.ACTION_DOCKER
        try:
            normalized = dict(raw)
        except (TypeError, ValueError):
            # Not a mapping: let _from_generic raise its canonical TypeError
            # instead of leaking an opaque error from dict().
            return DetectionAdapter._from_generic(raw, source=source, model_id=model_id)

        normalized["detections"] = [
            DetectionAdapter._normalize_action_detection(det)
            for det in normalized.get("detections") or []
        ]
        return DetectionAdapter._from_generic(
            normalized, source=source, model_id=model_id
        )

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    @staticmethod
    def _normalize_action_detection(det: Any) -> Any:
        """Fill ``class`` / ``label`` of one ppTSM detection from ``class_id``.

        ``class_id == 0`` maps to ``"fight"`` / ``"打架"``; any other id maps
        to ``"normal"`` / ``"正常"``. Existing ``class`` / ``label`` values are
        kept. Items that cannot be normalized are returned unchanged so that
        ``_from_generic`` handles them one by one (skip or ``"unknown"``)
        instead of the whole batch failing.
        """
        try:
            d = dict(det)
        except (TypeError, ValueError):
            return det  # not a mapping; _build_detection will reject it

        if "class" not in d:
            # NOTE(review): a missing class_id defaults to 0, i.e. "fight".
            # Confirm that is the intended behaviour for id-less items.
            try:
                cls_id = int(d.get("class_id", 0))
            except (TypeError, ValueError, OverflowError):
                # Unparsable id (None, NaN, text...): leave the item as-is;
                # _build_detection falls back to class_name / "unknown".
                return d
            d["class"] = "fight" if cls_id == 0 else "normal"

        if "label" not in d:
            d["label"] = "打架" if d["class"] == "fight" else "正常"
        return d

    @staticmethod
    def _from_generic(
        raw: Dict[str, Any],
        source: DetectionSource,
        model_id: Optional[str],
    ) -> DetectionResult:
        """Shared normalization used by every public entry point.

        ``success`` defaults to ``True`` and ``message`` to ``""`` (also when
        it is ``None``). Each raw detection is converted by
        ``_build_detection``; failures are logged and the item is skipped.

        Raises:
            TypeError: If ``raw`` is not a dict.
        """
        if not isinstance(raw, dict):
            raise TypeError(f"raw 必须为 dict,收到 {type(raw)}")

        success = bool(raw.get("success", True))
        raw_message = raw.get("message")
        message = "" if raw_message is None else str(raw_message)

        unified_dets: List[UnifiedDetection] = []
        for det in raw.get("detections") or []:
            try:
                unified_dets.append(
                    DetectionAdapter._build_detection(det, source, model_id)
                )
            except Exception as exc:  # noqa: BLE001 - one bad item must not fail the batch
                logger.warning("适配检测项失败,已跳过: %s, raw=%s", exc, det)

        stats = DetectionAdapter._build_stats(raw.get("stats"), unified_dets, model_id)

        return DetectionResult(
            success=success,
            message=message,
            source=source,
            model_id=model_id,
            detections=unified_dets,
            stats=stats,
        )

    @staticmethod
    def _build_detection(
        det: Dict[str, Any],
        source: DetectionSource,
        model_id: Optional[str],
    ) -> UnifiedDetection:
        """Build one ``UnifiedDetection`` from a raw detection dict.

        Accepted aliases: ``bbox``/``box``, ``class``/``class_name``,
        ``confidence``/``score``. Coordinates are truncated with ``int`` and
        reordered so that ``x1 <= x2`` and ``y1 <= y2``. Confidence is clamped
        to ``[0, 1]``, with NaN treated as ``0.0``.

        Raises:
            ValueError: If no bbox with at least four values is present.
            Malformed values (e.g. a non-numeric coordinate) raise from the
            conversions; ``_from_generic`` logs and skips such items.
        """
        # ``is None`` / ``len`` instead of truthiness, so array-like boxes
        # (e.g. numpy arrays) don't raise "truth value is ambiguous".
        bbox_raw = det.get("bbox")
        if bbox_raw is None or len(bbox_raw) == 0:
            bbox_raw = det.get("box")
        if bbox_raw is None or len(bbox_raw) < 4:
            raise ValueError(f"缺少有效 bbox: {det}")

        x1, y1, x2, y2 = map(int, bbox_raw[:4])
        # Ensure x1 <= x2 and y1 <= y2 regardless of the producer's corner order.
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)

        class_name = det.get("class") or det.get("class_name") or "unknown"
        # A missing/None/empty label falls back to the class name instead of
        # becoming the literal string "None".
        label = det.get("label") or class_name

        raw_conf = det.get("confidence")
        if raw_conf is None:
            raw_conf = det.get("score", 0.0)
        confidence = float(raw_conf)
        # ``not (c >= 0)`` also catches NaN: a plain max/min clamp would turn
        # NaN into 1.0 because every comparison with NaN is False.
        if not (confidence >= 0.0):
            confidence = 0.0
        confidence = min(confidence, 1.0)

        track_id = det.get("track_id")
        if track_id is not None:
            track_id = int(track_id)

        extra: Dict[str, Any] = {
            key: det[key] for key in DetectionAdapter._EXTRA_KEYS if key in det
        }

        return UnifiedDetection(
            track_id=track_id,
            class_name=str(class_name),
            label=str(label),
            confidence=confidence,
            bbox=BBox(x1=x1, y1=y1, x2=x2, y2=y2),
            source=source,
            model_id=model_id,
            extra=extra,
        )

    @staticmethod
    def _build_stats(
        raw_stats: Optional[Dict[str, Any]],
        detections: List[UnifiedDetection],
        model_id: Optional[str],
    ) -> DetectionStats:
        """Build ``DetectionStats``, preferring values from the raw ``stats``.

        A field that is missing, ``None`` or not convertible falls back to a
        value derived locally. ``total_detections`` falls back to
        ``len(detections)``, and ``avg_confidence`` to their mean confidence
        rounded to 3 places. ``processing_time`` falls back to ``0.0`` and
        ``model_used`` to ``model_id``. A non-dict ``raw_stats`` is treated
        as empty. ``fps`` is passed through unchanged.

        NOTE: a reported ``total_detections`` is kept as-is even if some
        items were skipped during adaptation.
        """
        stats = raw_stats if isinstance(raw_stats, dict) else {}

        avg_conf = (
            sum(d.confidence for d in detections) / len(detections)
            if detections
            else 0.0
        )
        pick = DetectionAdapter._stat_field
        model_used = stats.get("model_used")

        return DetectionStats(
            total_detections=pick(stats, "total_detections", int, len(detections)),
            avg_confidence=pick(stats, "avg_confidence", float, round(avg_conf, 3)),
            processing_time=pick(stats, "processing_time", float, 0.0),
            model_used=model_id if model_used is None else model_used,
            fps=stats.get("fps"),
        )

    @staticmethod
    def _stat_field(stats: Dict[str, Any], key: str, cast: Any, default: Any) -> Any:
        """Return ``cast(stats[key])``, or ``default`` if missing/None/invalid."""
        value = stats.get(key)
        if value is None:
            return default
        try:
            return cast(value)
        except (TypeError, ValueError, OverflowError):
            logger.warning("stats.%s 取值无效,已使用默认值: %r", key, value)
            return default
|
||
|
||
|
||
# Public API: only the adapter class is exported by ``from ... import *``.
__all__ = ["DetectionAdapter"]
|