"""预警规则引擎 (MVP-1 / P2) YAML 驱动的预警规则引擎,对 ``CandidateEvent`` 列表应用规则, 满足条件的事件会被升级为 ``AlertEvent``。 规则 YAML 示例 (config/rules/fire.yaml):: name: fire_critical event_type: fire enabled: true min_confidence: 0.6 severity: critical # 覆盖默认严重性 (可选) description: 检测到火焰,立即触发预警 规则条件支持: - ``min_confidence``: 置信度阈值 - ``allowed_sources``: 允许的来源 (source_id 白名单,None 表示不限制) - ``required_labels``: 检测项 label 必须包含其中之一 - ``min_bbox_area``: 边界框最小面积 """ from __future__ import annotations import logging from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional import yaml from models.event_schemas import ( AlertEvent, CandidateEvent, EventType, SeverityLevel, ) logger = logging.getLogger(__name__) @dataclass class AlertRule: """单条预警规则。""" name: str event_type: EventType enabled: bool = True min_confidence: float = 0.0 severity: Optional[SeverityLevel] = None allowed_sources: Optional[List[str]] = None required_labels: Optional[List[str]] = None min_bbox_area: int = 0 description: str = "" @classmethod def from_dict(cls, data: Dict[str, Any]) -> "AlertRule": return cls( name=str(data["name"]), event_type=EventType(data["event_type"]), enabled=bool(data.get("enabled", True)), min_confidence=float(data.get("min_confidence", 0.0)), severity=SeverityLevel(data["severity"]) if data.get("severity") else None, allowed_sources=data.get("allowed_sources"), required_labels=data.get("required_labels"), min_bbox_area=int(data.get("min_bbox_area", 0)), description=str(data.get("description", "")), ) def matches(self, event: CandidateEvent) -> bool: """判断候选事件是否命中规则。""" if not self.enabled: return False if event.event_type != self.event_type: return False if event.confidence < self.min_confidence: return False if self.allowed_sources and event.source_id not in self.allowed_sources: return False if self.required_labels: labels = {event.detection.class_name, event.detection.label} if not any(lbl in labels for lbl in self.required_labels): return False if self.min_bbox_area > 0 and event.detection.bbox.area < self.min_bbox_area: return False return True @dataclass class _RuleStats: loaded: int = 0 enabled: int = 0 files: List[str] = field(default_factory=list) class AlertRuleEngine: """预警规则引擎。 支持从单个 YAML 文件或目录批量加载规则。 """ def __init__(self, rules: Optional[List[AlertRule]] = None) -> None: self.rules: List[AlertRule] = list(rules or []) self._stats = _RuleStats( loaded=len(self.rules), enabled=sum(1 for r in self.rules if r.enabled), ) # ------------------------------------------------------------------ # 加载 # ------------------------------------------------------------------ @classmethod def from_directory(cls, rules_dir: str | Path) -> "AlertRuleEngine": """从目录加载所有 ``*.yaml`` / ``*.yml`` 规则文件。""" engine = cls() engine.load_directory(rules_dir) return engine def load_directory(self, rules_dir: str | Path) -> int: path = Path(rules_dir) if not path.exists(): logger.warning("规则目录不存在: %s", path) return 0 count = 0 for file in sorted(path.glob("*.y*ml")): try: count += self.load_file(file) except Exception as exc: # noqa: BLE001 logger.error("加载规则文件失败 %s: %s", file, exc) logger.info("规则引擎加载完成: 共 %d 条规则 (来自 %s)", count, path) return count def load_file(self, rule_file: str | Path) -> int: path = Path(rule_file) with path.open("r", encoding="utf-8") as fp: data = yaml.safe_load(fp) if data is None: return 0 # 单条规则 dict 或 列表 或 {"rules": [...]} 三种格式都支持 if isinstance(data, dict) and "rules" in data: rule_items = data["rules"] elif isinstance(data, list): rule_items = data else: rule_items = [data] added = 0 for item in rule_items: rule = AlertRule.from_dict(item) self.rules.append(rule) self._stats.loaded += 1 if rule.enabled: self._stats.enabled += 1 added += 1 self._stats.files.append(str(path)) return added # ------------------------------------------------------------------ # 评估 # ------------------------------------------------------------------ def evaluate(self, events: List[CandidateEvent]) -> List[AlertEvent]: """对一批候选事件执行规则评估,返回触发的预警事件。""" alerts: List[AlertEvent] = [] for event in events: for rule in self.rules: if rule.matches(event): alerts.append(self._build_alert(event, rule)) event.triggered_rules.append(rule.name) break # 命中一条即可,避免重复 return alerts @staticmethod def _build_alert(event: CandidateEvent, rule: AlertRule) -> AlertEvent: severity = rule.severity or event.severity return AlertEvent( event_type=event.event_type, severity=severity, confidence=event.confidence, source_id=event.source_id, detections=[event.detection], rule_name=rule.name, metadata={"description": rule.description}, ) # ------------------------------------------------------------------ # 自省 # ------------------------------------------------------------------ @property def stats(self) -> Dict[str, Any]: return { "loaded": self._stats.loaded, "enabled": self._stats.enabled, "files": list(self._stats.files), } __all__ = ["AlertRule", "AlertRuleEngine"]