Files

248 lines
8.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""事件聚合器 (MVP-1 / P5 + MVP-2 / D20 增强版)

MVP-1 能力:
  - 时间窗口去重: 对同一 (source_id, event_type, track_id_or_bbox_hash)
    在配置窗口内只保留一条预警事件

MVP-2 / D20 新增能力:
  - 空间邻近合并: 同一 (source_id, event_type) 且 bbox IOU 高的事件视为同一目标
  - 置信度加权融合: 多次命中时按指数衰减 (EMA) 加权融合置信度
    (衰减按命中次数逐次施加, 与命中之间的时间间隔无关)
  - 融合策略可选: weighted / max / avg
"""
from __future__ import annotations
import logging
import time
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, TypeAlias
from models.event_schemas import AlertEvent, BBox, UnifiedDetection
# Module-level logger for the aggregator.
logger = logging.getLogger(__name__)
# Aggregation key: (source_id, event_type value, target identity string),
# the tuple shape returned by EventAggregator._make_key.
_AggKey: TypeAlias = Tuple[Optional[str], str, str]
# ---------------------------------------------------------------------------
# IOU 工具
# ---------------------------------------------------------------------------
def _bbox_iou(b1: BBox, b2: BBox) -> float:
inter_x1 = max(b1.x1, b2.x1)
inter_y1 = max(b1.y1, b2.y1)
inter_x2 = min(b1.x2, b2.x2)
inter_y2 = min(b1.y2, b2.y2)
inter_w = max(0, inter_x2 - inter_x1)
inter_h = max(0, inter_y2 - inter_y1)
inter = inter_w * inter_h
if inter == 0:
return 0.0
union = b1.area + b2.area - inter
return inter / union if union > 0 else 0.0
# ---------------------------------------------------------------------------
# EventAggregator
# ---------------------------------------------------------------------------
class EventAggregator:
    """De-duplicate / fuse alert events using a time window plus spatial proximity.

    Alerts sharing an aggregation key ``(source_id, event_type.value,
    target_identity)`` are emitted once. Later hits on a key that is still
    active are fused into the stored (already emitted) event in place instead
    of being emitted again. At the start of every :meth:`aggregate` call, an
    active event is dropped once its ``first_seen`` is more than
    ``dedup_window_seconds`` old. A target that keeps reappearing is
    therefore re-emitted at most about once per window.

    Args:
        dedup_window_seconds: De-dup window in seconds, measured from the
            stored event's ``first_seen``. Negative values are clamped to 0.
            With 0, the active set is cleared on every call, so only
            duplicates inside a single batch are merged.
        max_active_events: Maximum number of events kept in memory (clamped
            to >= 1). When exceeded, the least recently inserted/hit event is
            evicted.
        enable_spatial_merge: When True, an alert whose exact key is not
            active may be merged into the active event with the same source
            and event type whose first detection's bbox has the highest IOU
            (MVP-2).
        spatial_iou_threshold: Minimum IOU for a spatial merge, clamped to
            [0, 1].
        fusion_strategy: Confidence fusion strategy: ``weighted`` / ``max`` /
            ``avg``.
        fusion_decay_factor: Weight kept by the historical confidence in the
            ``weighted`` strategy, clamped to [0, 1]. Smaller values favor the
            newest hit. The decay is applied once per hit, not per unit of
            time.

    Raises:
        ValueError: If ``fusion_strategy`` is not one of
            ``SUPPORTED_FUSION_STRATEGIES``.
    """

    SUPPORTED_FUSION_STRATEGIES = ("weighted", "max", "avg")

    # Severity rank from lowest to highest. A fused event's severity only
    # ever moves up this scale.
    _SEVERITY_RANK = {"info": 0, "low": 1, "medium": 2, "high": 3, "critical": 4}

    # Side length of the grid cell used to bucket untracked targets by bbox
    # center (same unit as the bbox coordinates, presumably pixels).
    _GRID_CELL_SIZE = 50

    def __init__(
        self,
        dedup_window_seconds: float = 30.0,
        max_active_events: int = 1000,
        enable_spatial_merge: bool = False,
        spatial_iou_threshold: float = 0.3,
        fusion_strategy: str = "max",
        fusion_decay_factor: float = 0.9,
    ) -> None:
        self.dedup_window_seconds = max(0.0, dedup_window_seconds)
        self.max_active_events = max(1, max_active_events)
        self.enable_spatial_merge = enable_spatial_merge
        self.spatial_iou_threshold = max(0.0, min(1.0, spatial_iou_threshold))
        if fusion_strategy not in self.SUPPORTED_FUSION_STRATEGIES:
            raise ValueError(
                f"不支持的融合策略: {fusion_strategy}, "
                f"支持: {self.SUPPORTED_FUSION_STRATEGIES}"
            )
        self.fusion_strategy = fusion_strategy
        self.fusion_decay_factor = max(0.0, min(1.0, fusion_decay_factor))
        # Ordered by insertion / last hit: the front is the LRU victim.
        self._active: "OrderedDict[_AggKey, AlertEvent]" = OrderedDict()

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------
    def aggregate(self, alerts: List[AlertEvent]) -> List[AlertEvent]:
        """Aggregate a batch of alerts and return the ones to emit.

        Expired events are evicted first. Each alert is then matched against
        the active set, by exact key first and, if that misses and spatial
        merge is enabled, by bbox IOU. An unmatched alert is stored and
        emitted. A matched alert is fused into the stored event and is not
        emitted.

        Args:
            alerts: Alerts to aggregate, processed in order.

        Returns:
            The alerts that survived de-duplication, in input order.
        """
        now = time.time()
        self._evict_expired(now)
        emitted: List[AlertEvent] = []
        for alert in alerts:
            key, existing = self._lookup(alert)
            if existing is None:
                self._insert(key, alert)
                emitted.append(alert)
            else:
                # Duplicate within the window: fold it into the stored event.
                self._fuse(existing, alert, now)
                self._active.move_to_end(key)
        return emitted

    def _lookup(self, alert: AlertEvent) -> Tuple[_AggKey, Optional[AlertEvent]]:
        """Return ``(key, active_event_or_None)`` for ``alert``.

        On a spatial match, the matched event's key is returned so the alert
        is fused under it.
        """
        key = self._make_key(alert)
        existing = self._active.get(key)
        if existing is None and self.enable_spatial_merge:
            spatial_key = self._find_spatial_match(alert)
            if spatial_key is not None:
                existing = self._active.get(spatial_key)
                key = spatial_key  # reuse the matched event's key
        return key, existing

    def _insert(self, key: _AggKey, alert: AlertEvent) -> None:
        """Store a new active event, evicting the LRU entry when over capacity."""
        self._active[key] = alert
        self._active.move_to_end(key)
        if len(self._active) > self.max_active_events:
            dropped_key, _ = self._active.popitem(last=False)
            logger.debug("聚合器 LRU 淘汰事件: %s", dropped_key)

    # ------------------------------------------------------------------
    # Fusion
    # ------------------------------------------------------------------
    def _fuse(self, existing: AlertEvent, new: AlertEvent, now: float) -> None:
        """Fold ``new`` into the already-active ``existing`` event, in place.

        Updates ``last_seen`` and ``occurrence_count``. Fuses confidence per
        ``fusion_strategy``, then clamps it to [0, 1] and rounds it to 4
        decimals. Escalates severity when ``new`` is more severe. Replaces
        the detections with ``new``'s when it has any.
        """
        existing.last_seen = now
        existing.occurrence_count += 1
        # Round to 4 decimals to avoid accumulating float drift.
        existing.confidence = round(
            max(0.0, min(1.0, self._fuse_confidence(existing, new))), 4
        )
        self._escalate_severity(existing, new)
        # Keep the latest detections (the LLM trigger wants the newest bbox).
        if new.detections:
            existing.detections = new.detections

    def _fuse_confidence(self, existing: AlertEvent, new: AlertEvent) -> float:
        """Return the fused, unclamped confidence.

        Expects ``existing.occurrence_count`` to already include ``new``.
        """
        if self.fusion_strategy == "max":
            return max(existing.confidence, new.confidence)
        if self.fusion_strategy == "avg":
            # Incremental running mean. Assumes occurrence_count was 1 for the
            # first hit (TODO confirm against the AlertEvent default).
            n = existing.occurrence_count
            return (existing.confidence * (n - 1) + new.confidence) / n
        # "weighted": exponential moving average over hits (not over time).
        decay = self.fusion_decay_factor
        return existing.confidence * decay + new.confidence * (1 - decay)

    def _escalate_severity(self, existing: AlertEvent, new: AlertEvent) -> None:
        """Raise ``existing.severity`` to ``new.severity`` when ``new`` ranks higher."""
        new_rank = self._SEVERITY_RANK.get(new.severity.value)
        old_rank = self._SEVERITY_RANK.get(existing.severity.value)
        if new_rank is None or old_rank is None:
            # Unknown severity name: leave severity unchanged, but say so.
            logger.debug(
                "unknown severity, skipping escalation: existing=%r new=%r",
                existing.severity.value,
                new.severity.value,
            )
            return
        if new_rank > old_rank:
            existing.severity = new.severity

    # ------------------------------------------------------------------
    # Spatial matching
    # ------------------------------------------------------------------
    def _find_spatial_match(self, alert: AlertEvent) -> Optional[_AggKey]:
        """Return the key of the best-overlapping active event, or ``None``.

        Only events with the same source and event type are considered. IOU
        is computed between the first detections' bboxes. The match must
        reach ``spatial_iou_threshold`` and be strictly positive.
        """
        if not alert.detections:
            return None
        target_bbox = alert.detections[0].bbox
        source_id = alert.source_id
        event_type = alert.event_type.value
        best_key: Optional[_AggKey] = None
        best_iou = 0.0
        for key, existing in self._active.items():
            # Must share source and event type.
            if key[0] != source_id or key[1] != event_type:
                continue
            if not existing.detections:
                continue
            iou = _bbox_iou(target_bbox, existing.detections[0].bbox)
            if iou >= self.spatial_iou_threshold and iou > best_iou:
                best_iou, best_key = iou, key
        return best_key

    # ------------------------------------------------------------------
    # Key construction
    # ------------------------------------------------------------------
    @staticmethod
    def _make_key(alert: AlertEvent) -> _AggKey:
        """Build ``(source_id, event_type.value, target_identity)`` for ``alert``.

        The target identity comes from the first detection, or is
        ``"no_target"`` when the alert has none.
        """
        if alert.detections:
            target_id = EventAggregator._target_identity(alert.detections[0])
        else:
            target_id = "no_target"
        return (alert.source_id, alert.event_type.value, target_id)

    @staticmethod
    def _target_identity(det: UnifiedDetection) -> str:
        """Return a stable target id: ``t<track_id>``, else a bbox-center grid cell."""
        if det.track_id is not None:
            return f"t{det.track_id}"
        cx, cy = det.bbox.center
        cell = EventAggregator._GRID_CELL_SIZE
        # Floor-divide the raw coordinate. ``int(cx) // cell`` truncated
        # toward zero first, which put centers in (-1, 0) into cell 0 together
        # with [0, cell). Results are identical for non-negative centers.
        return f"g{int(cx // cell)}_{int(cy // cell)}_{det.class_name}"

    # ------------------------------------------------------------------
    # Expiry
    # ------------------------------------------------------------------
    def _evict_expired(self, now: float) -> None:
        """Drop active events whose ``first_seen`` is older than the window.

        A non-positive window clears the whole active set.
        """
        if self.dedup_window_seconds <= 0:
            self._active.clear()
            return
        expired_keys = [
            key
            for key, alert in self._active.items()
            if now - alert.first_seen > self.dedup_window_seconds
        ]
        for key in expired_keys:
            self._active.pop(key, None)

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------
    @property
    def active_count(self) -> int:
        """Number of events currently held in the active set."""
        return len(self._active)

    def snapshot(self) -> Dict[str, AlertEvent]:
        """Return a new dict of the active events keyed by ``"source|type|target"``."""
        return {f"{k[0]}|{k[1]}|{k[2]}": v for k, v in self._active.items()}
# Public API of this module: only the aggregator class is exported.
__all__ = ["EventAggregator"]