feat:新增事件决策/规则/聚合三段管道引擎
This commit is contained in:
206
apps/server/services/event/rule_engine.py
Normal file
206
apps/server/services/event/rule_engine.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""预警规则引擎 (MVP-1 / P2)
|
||||
|
||||
YAML 驱动的预警规则引擎,对 ``CandidateEvent`` 列表应用规则,
|
||||
满足条件的事件会被升级为 ``AlertEvent``。
|
||||
|
||||
规则 YAML 示例 (config/rules/fire.yaml)::
|
||||
|
||||
name: fire_critical
|
||||
event_type: fire
|
||||
enabled: true
|
||||
min_confidence: 0.6
|
||||
severity: critical # 覆盖默认严重性 (可选)
|
||||
description: 检测到火焰,立即触发预警
|
||||
|
||||
规则条件支持:
|
||||
|
||||
- ``min_confidence``: 置信度阈值
|
||||
- ``allowed_sources``: 允许的来源 (source_id 白名单,None 表示不限制)
|
||||
- ``required_labels``: 检测项 label 必须包含其中之一
|
||||
- ``min_bbox_area``: 边界框最小面积
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from models.event_schemas import (
|
||||
AlertEvent,
|
||||
CandidateEvent,
|
||||
EventType,
|
||||
SeverityLevel,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AlertRule:
|
||||
"""单条预警规则。"""
|
||||
|
||||
name: str
|
||||
event_type: EventType
|
||||
enabled: bool = True
|
||||
min_confidence: float = 0.0
|
||||
severity: Optional[SeverityLevel] = None
|
||||
allowed_sources: Optional[List[str]] = None
|
||||
required_labels: Optional[List[str]] = None
|
||||
min_bbox_area: int = 0
|
||||
description: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "AlertRule":
|
||||
return cls(
|
||||
name=str(data["name"]),
|
||||
event_type=EventType(data["event_type"]),
|
||||
enabled=bool(data.get("enabled", True)),
|
||||
min_confidence=float(data.get("min_confidence", 0.0)),
|
||||
severity=SeverityLevel(data["severity"]) if data.get("severity") else None,
|
||||
allowed_sources=data.get("allowed_sources"),
|
||||
required_labels=data.get("required_labels"),
|
||||
min_bbox_area=int(data.get("min_bbox_area", 0)),
|
||||
description=str(data.get("description", "")),
|
||||
)
|
||||
|
||||
def matches(self, event: CandidateEvent) -> bool:
|
||||
"""判断候选事件是否命中规则。"""
|
||||
|
||||
if not self.enabled:
|
||||
return False
|
||||
if event.event_type != self.event_type:
|
||||
return False
|
||||
if event.confidence < self.min_confidence:
|
||||
return False
|
||||
if self.allowed_sources and event.source_id not in self.allowed_sources:
|
||||
return False
|
||||
if self.required_labels:
|
||||
labels = {event.detection.class_name, event.detection.label}
|
||||
if not any(lbl in labels for lbl in self.required_labels):
|
||||
return False
|
||||
if self.min_bbox_area > 0 and event.detection.bbox.area < self.min_bbox_area:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class _RuleStats:
|
||||
loaded: int = 0
|
||||
enabled: int = 0
|
||||
files: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class AlertRuleEngine:
|
||||
"""预警规则引擎。
|
||||
|
||||
支持从单个 YAML 文件或目录批量加载规则。
|
||||
"""
|
||||
|
||||
def __init__(self, rules: Optional[List[AlertRule]] = None) -> None:
|
||||
self.rules: List[AlertRule] = list(rules or [])
|
||||
self._stats = _RuleStats(
|
||||
loaded=len(self.rules),
|
||||
enabled=sum(1 for r in self.rules if r.enabled),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 加载
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
def from_directory(cls, rules_dir: str | Path) -> "AlertRuleEngine":
|
||||
"""从目录加载所有 ``*.yaml`` / ``*.yml`` 规则文件。"""
|
||||
|
||||
engine = cls()
|
||||
engine.load_directory(rules_dir)
|
||||
return engine
|
||||
|
||||
def load_directory(self, rules_dir: str | Path) -> int:
|
||||
path = Path(rules_dir)
|
||||
if not path.exists():
|
||||
logger.warning("规则目录不存在: %s", path)
|
||||
return 0
|
||||
|
||||
count = 0
|
||||
for file in sorted(path.glob("*.y*ml")):
|
||||
try:
|
||||
count += self.load_file(file)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.error("加载规则文件失败 %s: %s", file, exc)
|
||||
logger.info("规则引擎加载完成: 共 %d 条规则 (来自 %s)", count, path)
|
||||
return count
|
||||
|
||||
def load_file(self, rule_file: str | Path) -> int:
|
||||
path = Path(rule_file)
|
||||
with path.open("r", encoding="utf-8") as fp:
|
||||
data = yaml.safe_load(fp)
|
||||
|
||||
if data is None:
|
||||
return 0
|
||||
|
||||
# 单条规则 dict 或 列表 或 {"rules": [...]} 三种格式都支持
|
||||
if isinstance(data, dict) and "rules" in data:
|
||||
rule_items = data["rules"]
|
||||
elif isinstance(data, list):
|
||||
rule_items = data
|
||||
else:
|
||||
rule_items = [data]
|
||||
|
||||
added = 0
|
||||
for item in rule_items:
|
||||
rule = AlertRule.from_dict(item)
|
||||
self.rules.append(rule)
|
||||
self._stats.loaded += 1
|
||||
if rule.enabled:
|
||||
self._stats.enabled += 1
|
||||
added += 1
|
||||
self._stats.files.append(str(path))
|
||||
return added
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 评估
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def evaluate(self, events: List[CandidateEvent]) -> List[AlertEvent]:
|
||||
"""对一批候选事件执行规则评估,返回触发的预警事件。"""
|
||||
|
||||
alerts: List[AlertEvent] = []
|
||||
for event in events:
|
||||
for rule in self.rules:
|
||||
if rule.matches(event):
|
||||
alerts.append(self._build_alert(event, rule))
|
||||
event.triggered_rules.append(rule.name)
|
||||
break # 命中一条即可,避免重复
|
||||
return alerts
|
||||
|
||||
@staticmethod
|
||||
def _build_alert(event: CandidateEvent, rule: AlertRule) -> AlertEvent:
|
||||
severity = rule.severity or event.severity
|
||||
return AlertEvent(
|
||||
event_type=event.event_type,
|
||||
severity=severity,
|
||||
confidence=event.confidence,
|
||||
source_id=event.source_id,
|
||||
detections=[event.detection],
|
||||
rule_name=rule.name,
|
||||
metadata={"description": rule.description},
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 自省
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def stats(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"loaded": self._stats.loaded,
|
||||
"enabled": self._stats.enabled,
|
||||
"files": list(self._stats.files),
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["AlertRule", "AlertRuleEngine"]
|
||||
Reference in New Issue
Block a user