Files
jc-video-recognize/apps/server/services/event/rule_engine.py

207 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Alert rule engine (MVP-1 / P2).

A YAML-driven alert rule engine: rules are applied to a list of
``CandidateEvent`` objects, and events that satisfy a rule are promoted to
``AlertEvent`` objects.

Example rule YAML (config/rules/fire.yaml)::

    name: fire_critical
    event_type: fire
    enabled: true
    min_confidence: 0.6
    severity: critical        # overrides the default severity (optional)
    description: Fire detected, raise an alert immediately

A rule file may contain a single rule mapping, a list of rule mappings, or a
mapping with a ``rules:`` key holding that list.

Supported rule conditions:

- ``min_confidence``: confidence threshold; events below it do not match.
- ``allowed_sources``: whitelist of ``source_id`` values; ``None`` means no
  restriction.
- ``required_labels``: the detection's ``class_name`` or ``label`` must equal
  at least one of these.
- ``min_bbox_area``: minimum bounding-box area.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional
import yaml
from models.event_schemas import (
AlertEvent,
CandidateEvent,
EventType,
SeverityLevel,
)
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
@dataclass
class AlertRule:
    """A single alert rule.

    :meth:`matches` returns True only when the rule is enabled and every
    configured condition holds for the candidate event.
    """

    # Rule name, used to identify the rule.
    name: str
    # Only events whose ``event_type`` equals this value can match.
    event_type: EventType
    # Disabled rules never match.
    enabled: bool = True
    # Events whose ``confidence`` is below this threshold do not match.
    min_confidence: float = 0.0
    # Optional severity override; None means the rule sets no severity.
    severity: Optional[SeverityLevel] = None
    # Whitelist of event ``source_id`` values; None or empty means any source.
    allowed_sources: Optional[List[str]] = None
    # The detection's ``class_name`` or ``label`` must equal one of these;
    # None or empty means no label requirement.
    required_labels: Optional[List[str]] = None
    # Minimum detection bounding-box area; 0 (or less) disables the check.
    min_bbox_area: int = 0
    # Free-form human-readable description.
    description: str = ""

    def __post_init__(self) -> None:
        """Normalise the list-valued conditions.

        A bare string such as ``allowed_sources: cam01`` in YAML would
        otherwise be used as a character sequence: ``x in "cam01"`` is a
        substring test, and iterating it yields single characters. Such a
        string is wrapped in a one-element list; other iterables are copied
        into a list; None is kept as "no restriction".
        """
        self.allowed_sources = self._as_list(self.allowed_sources)
        self.required_labels = self._as_list(self.required_labels)

    @staticmethod
    def _as_list(value: Any) -> Optional[List[Any]]:
        """Return None unchanged, ``[value]`` for a string, else ``list(value)``."""
        if value is None:
            return None
        if isinstance(value, str):
            return [value]
        return list(value)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AlertRule":
        """Build a rule from a parsed YAML mapping.

        ``name`` and ``event_type`` are required; every other key is optional
        and falls back to the field default.

        Raises:
            KeyError: ``name`` or ``event_type`` is missing.
            ValueError: a numeric field cannot be converted (e.g. a
                non-numeric ``min_confidence``); presumably also for an
                unknown ``event_type``/``severity`` value — TODO confirm
                against ``EventType``/``SeverityLevel``.
        """
        return cls(
            name=str(data["name"]),
            event_type=EventType(data["event_type"]),
            enabled=bool(data.get("enabled", True)),
            min_confidence=float(data.get("min_confidence", 0.0)),
            # A missing or empty severity means "no override".
            severity=SeverityLevel(data["severity"]) if data.get("severity") else None,
            allowed_sources=data.get("allowed_sources"),
            required_labels=data.get("required_labels"),
            min_bbox_area=int(data.get("min_bbox_area", 0)),
            description=str(data.get("description", "")),
        )

    def matches(self, event: CandidateEvent) -> bool:
        """Return True if ``event`` satisfies every condition of this rule.

        Conditions are checked in order and the first failing one returns
        False immediately.
        """
        if not self.enabled:
            return False
        if event.event_type != self.event_type:
            return False
        if event.confidence < self.min_confidence:
            return False
        if self.allowed_sources and event.source_id not in self.allowed_sources:
            return False
        if self.required_labels:
            # Either the class name or the label may satisfy the requirement.
            labels = {event.detection.class_name, event.detection.label}
            if not any(lbl in labels for lbl in self.required_labels):
                return False
        if self.min_bbox_area > 0 and event.detection.bbox.area < self.min_bbox_area:
            return False
        return True
@dataclass
class _RuleStats:
    """Running load counters kept by the rule engine."""

    # Number of rules registered so far.
    loaded: int = field(default=0)
    # How many of the registered rules are enabled.
    enabled: int = field(default=0)
    # Paths (as strings) of rule files that were loaded; a fresh list per instance.
    files: List[str] = field(default_factory=list)
class AlertRuleEngine:
    """Alert rule engine.

    Holds an ordered list of :class:`AlertRule` objects and turns matching
    ``CandidateEvent`` objects into ``AlertEvent`` objects. Rules can be
    passed in directly, loaded from a single YAML file, or loaded in bulk
    from a directory.
    """

    # Suffixes (compared case-insensitively) of files picked up by load_directory.
    _RULE_SUFFIXES = (".yaml", ".yml")

    def __init__(self, rules: Optional[List[AlertRule]] = None) -> None:
        """Create an engine, optionally seeded with ``rules``.

        Args:
            rules: Initial rules. They are copied into a new list, so later
                loads do not mutate the caller's list.
        """
        self.rules: List[AlertRule] = list(rules or [])
        self._stats = _RuleStats(
            loaded=len(self.rules),
            enabled=sum(1 for r in self.rules if r.enabled),
        )

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------
    @classmethod
    def from_directory(cls, rules_dir: str | Path) -> "AlertRuleEngine":
        """Create an engine and load every ``*.yaml`` / ``*.yml`` file in ``rules_dir``."""
        engine = cls()
        engine.load_directory(rules_dir)
        return engine

    def load_directory(self, rules_dir: str | Path) -> int:
        """Load every rule file found directly inside ``rules_dir``.

        Only regular files with a ``.yaml`` or ``.yml`` suffix are loaded
        (a bare ``*.y*ml`` glob would also match names like ``x.yhtml`` and
        directories). Files are processed in sorted path order; because
        :meth:`evaluate` stops at the first matching rule, this order also
        decides rule priority. A file that fails to load is logged and
        skipped, and contributes no rules.

        Returns:
            Number of rules added from this directory; 0 when ``rules_dir``
            is not an existing directory.
        """
        path = Path(rules_dir)
        if not path.is_dir():
            logger.warning("规则目录不存在: %s", path)
            return 0
        count = 0
        for file in self._rule_files(path):
            try:
                count += self.load_file(file)
            except Exception as exc:  # noqa: BLE001 - one bad file must not abort the others
                logger.error("加载规则文件失败 %s: %s", file, exc)
        logger.info("规则引擎加载完成: 共 %d 条规则 (来自 %s)", count, path)
        return count

    @classmethod
    def _rule_files(cls, directory: Path) -> List[Path]:
        """Return the YAML rule files directly inside ``directory``, sorted by path."""
        # The glob narrows the listing; the suffix check then drops look-alikes
        # such as ``x.yhtml`` / ``x.y.xml``, and is_file() drops directories.
        return sorted(
            entry
            for entry in directory.glob("*.y*ml")
            if entry.suffix.lower() in cls._RULE_SUFFIXES and entry.is_file()
        )

    def load_file(self, rule_file: str | Path) -> int:
        """Load the rules in one YAML file and append them to :attr:`rules`.

        Three layouts are accepted: a single rule mapping, a list of rule
        mappings, or a mapping with a ``rules`` key holding such a list.
        Loading is all-or-nothing: every entry is converted before any rule
        is added, so a malformed entry leaves the rules and stats unchanged.
        An empty file adds nothing and is not recorded in ``stats["files"]``.

        Returns:
            Number of rules added.

        Raises:
            OSError: the file cannot be opened or read.
            yaml.YAMLError: the file is not valid YAML.
            TypeError: ``rules`` is not a list, or an entry is not a mapping.
            Exception: whatever :meth:`AlertRule.from_dict` raises for an
                invalid entry (e.g. KeyError for a missing ``name``).
        """
        path = Path(rule_file)
        with path.open("r", encoding="utf-8") as fp:
            data = yaml.safe_load(fp)
        if data is None:  # empty document
            return 0
        # Build every rule before touching engine state so a bad entry cannot
        # leave a partially loaded file behind.
        new_rules = [AlertRule.from_dict(item) for item in self._rule_items(data, path)]
        self.rules.extend(new_rules)
        self._stats.loaded += len(new_rules)
        self._stats.enabled += sum(1 for rule in new_rules if rule.enabled)
        self._stats.files.append(str(path))
        return len(new_rules)

    @staticmethod
    def _rule_items(data: Any, path: Path) -> List[Dict[str, Any]]:
        """Normalise a parsed rule file into a list of rule mappings.

        Raises:
            TypeError: ``rules`` is not a list, or an entry is not a mapping.
        """
        if isinstance(data, dict) and "rules" in data:
            # ``rules:`` with no value parses as None; treat it as an empty list.
            items = data["rules"] or []
        elif isinstance(data, list):
            items = data
        else:
            items = [data]
        if not isinstance(items, list):
            raise TypeError(f"{path}: 'rules' must be a list, got {type(items).__name__}")
        for index, item in enumerate(items):
            if not isinstance(item, dict):
                raise TypeError(
                    f"{path}: rule #{index} must be a mapping, got {type(item).__name__}"
                )
        return items

    # ------------------------------------------------------------------
    # Evaluation
    # ------------------------------------------------------------------
    def evaluate(self, events: List[CandidateEvent]) -> List[AlertEvent]:
        """Evaluate a batch of candidate events and return the triggered alerts.

        Each event is checked against the rules in order, and only the first
        matching rule produces an alert, so one event yields at most one
        alert. The matching rule's name is appended to the event's
        ``triggered_rules`` list (the input events are mutated).
        """
        alerts: List[AlertEvent] = []
        for event in events:
            for rule in self.rules:
                if rule.matches(event):
                    alerts.append(self._build_alert(event, rule))
                    event.triggered_rules.append(rule.name)
                    break  # first match wins; avoids duplicate alerts
        return alerts

    @staticmethod
    def _build_alert(event: CandidateEvent, rule: AlertRule) -> AlertEvent:
        """Build the alert for ``event`` matched by ``rule``.

        The rule's severity is used when set (truthy), otherwise the event's
        own severity. The rule description goes into ``metadata``.
        """
        severity = rule.severity or event.severity
        return AlertEvent(
            event_type=event.event_type,
            severity=severity,
            confidence=event.confidence,
            source_id=event.source_id,
            detections=[event.detection],
            rule_name=rule.name,
            metadata={"description": rule.description},
        )

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------
    @property
    def stats(self) -> Dict[str, Any]:
        """Return a snapshot of the load statistics.

        Keys: ``loaded`` (rules registered), ``enabled`` (how many of them are
        enabled) and ``files`` (paths of files loaded via :meth:`load_file`;
        a copy, so mutating it does not affect the engine).
        """
        return {
            "loaded": self._stats.loaded,
            "enabled": self._stats.enabled,
            "files": list(self._stats.files),
        }
# Public API exported by ``from ... import *``.
__all__ = [
    "AlertRule",
    "AlertRuleEngine",
]