"""事件决策引擎 (MVP-1 / P1)

职责:

1. 对统一 ``DetectionResult`` 中的每个检测应用置信度过滤
2. 将底层 ``class_name`` 映射到统一 ``EventType``
3. 产出 ``CandidateEvent`` 列表,供规则引擎与聚合器继续处理

设计原则 (MVP 简化版):

- 不做温度缩放 / 校准 (后续 MVP-3 再迭代)
- 不做场景分类 (后续按需引入)
- 类型映射规则可外部覆盖,避免硬编码
"""

from __future__ import annotations

import logging
from typing import Dict, Iterable, List, Mapping, Optional

from models.event_schemas import (
    CandidateEvent,
    DetectionResult,
    EventType,
    SeverityLevel,
    UnifiedDetection,
)

# Module-level logger, named after this module. EventDecisionEngine.decide
# reports how many candidate events it produced through it at DEBUG level.
logger: logging.Logger = logging.getLogger(name=__name__)


# Default ``class_name -> EventType`` mapping covering the models available
# today. Each event type lists the raw detector class names (aliases) that
# map to it. All keys are lower-case: EventDecisionEngine.map_event_type
# lower-cases and strips an incoming class name before looking it up.
DEFAULT_CLASS_TO_EVENT: Dict[str, EventType] = {
    alias: event_type
    for event_type, aliases in (
        # Fire
        (EventType.FIRE, ("fire", "flame")),
        (EventType.SMOKE, ("smoke",)),
        # Smoking
        (EventType.SMOKING, ("smoking", "cigarette")),
        # Fighting
        (EventType.FIGHT, ("fight", "fighting")),
        # Behaviour
        (EventType.LOITERING, ("loitering",)),
        (EventType.STATIONARY, ("stationary",)),
        (EventType.INTRUSION, ("intrusion",)),
        # Vehicles
        (EventType.VEHICLE, ("vehicle", "car", "truck", "bus")),
        (EventType.ILLEGAL_PARKING, ("illegal_parking",)),
        # People
        (EventType.PERSON, ("person",)),
    )
    for alias in aliases
}


# Default severity for each event type; the rule engine may override these.
# Event types missing from this table fall back to SeverityLevel.INFO when
# EventDecisionEngine builds a candidate event.
DEFAULT_SEVERITY: Dict[EventType, SeverityLevel] = dict(
    (
        # (event type,             default severity)
        (EventType.FIRE, SeverityLevel.CRITICAL),
        (EventType.SMOKE, SeverityLevel.HIGH),
        (EventType.SMOKING, SeverityLevel.MEDIUM),
        (EventType.FIGHT, SeverityLevel.HIGH),
        (EventType.LOITERING, SeverityLevel.MEDIUM),
        (EventType.STATIONARY, SeverityLevel.LOW),
        (EventType.INTRUSION, SeverityLevel.HIGH),
        (EventType.ILLEGAL_PARKING, SeverityLevel.MEDIUM),
        (EventType.VEHICLE, SeverityLevel.INFO),
        (EventType.PERSON, SeverityLevel.INFO),
        (EventType.UNKNOWN, SeverityLevel.INFO),
    )
)


class EventDecisionEngine:
    """Event decision engine (simplified MVP version).

    Turns a unified ``DetectionResult`` into a list of ``CandidateEvent``
    objects. It drops detections below a confidence threshold, maps each
    detection's ``class_name`` to an ``EventType``, skips ignored event
    types and attaches a default severity taken from ``DEFAULT_SEVERITY``.

    Args:
        min_confidence: Global minimum confidence. Detections whose
            confidence is below this value, or is NaN, are discarded. The
            value is clamped to ``[0.0, 1.0]``.
        class_to_event: Custom ``class_name -> EventType`` mapping, merged
            on top of ``DEFAULT_CLASS_TO_EVENT``. Custom entries win on
            conflicts. Keys are normalized (lower-cased, surrounding
            whitespace stripped) so they match the case-insensitive lookup
            performed by :meth:`map_event_type`.
        ignore_event_types: Event types that must never be emitted, for
            example ``EventType.PERSON`` when it fires too often. Any
            iterable is accepted.
    """

    def __init__(
        self,
        min_confidence: float = 0.5,
        class_to_event: Optional[Mapping[str, EventType]] = None,
        ignore_event_types: Optional[Iterable[EventType]] = None,
    ) -> None:
        # Clamp into [0, 1] so out-of-range configuration values still
        # behave like the nearest valid threshold.
        self.min_confidence = max(0.0, min(1.0, min_confidence))

        # Copy the defaults so per-instance overrides never mutate the
        # module-level table.
        merged: Dict[str, EventType] = dict(DEFAULT_CLASS_TO_EVENT)
        if class_to_event:
            # Normalize override keys exactly as map_event_type normalizes
            # its input. Without this, a key such as "Forklift" or " car "
            # could never be matched.
            merged.update(
                (self._normalize_class_name(name), event_type)
                for name, event_type in class_to_event.items()
            )
        self.class_to_event: Dict[str, EventType] = merged
        self.ignore_event_types = set(ignore_event_types or [])

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    def decide(
        self,
        result: DetectionResult,
        source_id: Optional[str] = None,
    ) -> List[CandidateEvent]:
        """Produce candidate events from a detection result.

        Args:
            result: Unified detection result.
            source_id: Camera or video-stream identifier. It is copied onto
                every candidate event for downstream aggregation.

        Returns:
            Candidate events in the same order as ``result.detections``.
            The list is empty when ``result.success`` is falsy, there are no
            detections, or every detection was filtered out.
        """
        if not result.success or not result.detections:
            return []

        events: List[CandidateEvent] = []
        for det in result.detections:
            event = self._build_candidate(det, source_id)
            if event is not None:
                events.append(event)

        if events:
            logger.debug(
                "DecisionEngine 产出 %d 条候选事件 (source_id=%s, model=%s)",
                len(events),
                source_id,
                result.model_id,
            )
        return events

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _build_candidate(
        self,
        det: UnifiedDetection,
        source_id: Optional[str],
    ) -> Optional[CandidateEvent]:
        """Convert one detection into a candidate event, or return ``None``.

        A detection is dropped when its confidence is below
        ``min_confidence`` (NaN counts as below) or when its mapped event
        type is listed in ``ignore_event_types``.
        """
        # Written as "not >=" rather than "<". Every comparison with NaN is
        # False, so a NaN confidence would otherwise slip past the threshold.
        if not (det.confidence >= self.min_confidence):
            return None

        event_type = self.map_event_type(det.class_name)
        if event_type in self.ignore_event_types:
            return None

        return CandidateEvent(
            event_type=event_type,
            severity=DEFAULT_SEVERITY.get(event_type, SeverityLevel.INFO),
            confidence=det.confidence,
            detection=det,
            source_id=source_id,
        )

    def map_event_type(self, class_name: str) -> EventType:
        """Map a raw ``class_name`` to an ``EventType`` (case-insensitive).

        Surrounding whitespace is ignored. Empty names and names with no
        mapping both yield ``EventType.UNKNOWN``.
        """
        if not class_name:
            return EventType.UNKNOWN
        return self.class_to_event.get(
            self._normalize_class_name(class_name), EventType.UNKNOWN
        )

    @staticmethod
    def _normalize_class_name(class_name: str) -> str:
        """Return the canonical lookup key for a detector class name."""
        return class_name.lower().strip()


# Public API of this module (controls ``from ... import *``).
__all__ = [
    "EventDecisionEngine",
    "DEFAULT_CLASS_TO_EVENT",
    "DEFAULT_SEVERITY",
]