207 lines
6.5 KiB
Python
207 lines
6.5 KiB
Python
"""预警规则引擎 (MVP-1 / P2)
|
||
|
||
YAML 驱动的预警规则引擎,对 ``CandidateEvent`` 列表应用规则,
|
||
满足条件的事件会被升级为 ``AlertEvent``。
|
||
|
||
规则 YAML 示例 (config/rules/fire.yaml)::
|
||
|
||
name: fire_critical
|
||
event_type: fire
|
||
enabled: true
|
||
min_confidence: 0.6
|
||
severity: critical # 覆盖默认严重性 (可选)
|
||
description: 检测到火焰,立即触发预警
|
||
|
||
规则条件支持:
|
||
|
||
- ``min_confidence``: 置信度阈值
|
||
- ``allowed_sources``: 允许的来源 (source_id 白名单,None 表示不限制)
|
||
- ``required_labels``: 检测项 label 必须包含其中之一
|
||
- ``min_bbox_area``: 边界框最小面积
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from dataclasses import dataclass, field
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
import yaml
|
||
|
||
from models.event_schemas import (
|
||
AlertEvent,
|
||
CandidateEvent,
|
||
EventType,
|
||
SeverityLevel,
|
||
)
|
||
|
||
# Module-level logger; AlertRuleEngine uses it to report rule-loading progress and failures.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class AlertRule:
    """A single alert rule.

    ``matches`` returns True only when the rule is enabled, the event type
    equals ``event_type``, and every configured condition below holds.

    Attributes:
        name: Rule identifier.
        event_type: Event type this rule applies to.
        enabled: When False the rule never matches.
        min_confidence: Minimum ``event.confidence`` (inclusive).
        severity: Optional severity override; ``None`` means "no override".
        allowed_sources: Whitelist of ``event.source_id`` values. ``None`` or
            an empty list disables the check. A scalar (e.g. a single string
            written in YAML) is wrapped into a one-element list.
        required_labels: At least one entry must equal the detection's
            ``class_name`` or ``label``. ``None`` or empty disables the check.
            A scalar is wrapped into a one-element list.
        min_bbox_area: Minimum ``detection.bbox.area`` (inclusive); ``0``
            disables the check.
        description: Free-form human-readable description.
    """

    name: str
    event_type: EventType
    enabled: bool = True
    min_confidence: float = 0.0
    severity: Optional[SeverityLevel] = None
    allowed_sources: Optional[List[str]] = None
    required_labels: Optional[List[str]] = None
    min_bbox_area: int = 0
    description: str = ""

    def __post_init__(self) -> None:
        # YAML authors may write ``allowed_sources: cam_01`` instead of a list.
        # Left as a bare string, ``in`` would become a substring test (an event
        # from "cam_01" would pass a whitelist of "cam_01_backup"), and
        # iterating ``required_labels`` would yield single characters.
        self.allowed_sources = self._as_list(self.allowed_sources)
        self.required_labels = self._as_list(self.required_labels)

    @staticmethod
    def _as_list(value: Any) -> Optional[List[Any]]:
        """Return ``None`` unchanged, a list copy of a list/tuple/set, or ``[value]`` for a scalar."""
        if value is None:
            return None
        if isinstance(value, (list, tuple, set, frozenset)):
            return list(value)
        return [value]

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AlertRule":
        """Build a rule from a mapping (typically one parsed from YAML).

        ``name`` and ``event_type`` are required; every other key falls back
        to the field default.

        Raises:
            KeyError: if ``name`` or ``event_type`` is missing.
            Whatever ``EventType(...)`` / ``SeverityLevel(...)`` raise for an
            unknown value (presumably ``ValueError`` for standard Enums -
            confirm in models.event_schemas), and ``ValueError``/``TypeError``
            from ``float``/``int`` for non-numeric thresholds.
        """
        return cls(
            name=str(data["name"]),
            event_type=EventType(data["event_type"]),
            enabled=bool(data.get("enabled", True)),
            min_confidence=float(data.get("min_confidence", 0.0)),
            # An empty/falsy severity means "no override".
            severity=SeverityLevel(data["severity"]) if data.get("severity") else None,
            allowed_sources=data.get("allowed_sources"),
            required_labels=data.get("required_labels"),
            min_bbox_area=int(data.get("min_bbox_area", 0)),
            description=str(data.get("description", "")),
        )

    def matches(self, event: CandidateEvent) -> bool:
        """Return True when ``event`` satisfies every condition of this rule."""
        if not self.enabled:
            return False
        if event.event_type != self.event_type:
            return False
        if event.confidence < self.min_confidence:
            return False
        # Empty whitelist is treated the same as None: no source restriction.
        if self.allowed_sources and event.source_id not in self.allowed_sources:
            return False
        if self.required_labels:
            # Either the class name or the label of the detection may match.
            labels = {event.detection.class_name, event.detection.label}
            if not any(lbl in labels for lbl in self.required_labels):
                return False
        if self.min_bbox_area > 0 and event.detection.bbox.area < self.min_bbox_area:
            return False
        return True
|
||
|
||
|
||
@dataclass
|
||
class _RuleStats:
|
||
loaded: int = 0
|
||
enabled: int = 0
|
||
files: List[str] = field(default_factory=list)
|
||
|
||
|
||
class AlertRuleEngine:
    """Alert rule engine.

    Holds an ordered list of ``AlertRule`` objects and turns matching
    ``CandidateEvent`` objects into ``AlertEvent`` objects. Rules can be
    passed to the constructor, or loaded from a single YAML file or from all
    ``*.yaml`` / ``*.yml`` files directly inside a directory.
    """

    # Suffixes (compared case-insensitively) that mark a rule file.
    _RULE_SUFFIXES = frozenset({".yaml", ".yml"})

    def __init__(self, rules: Optional[List[AlertRule]] = None) -> None:
        """Create an engine; ``rules`` is copied so the caller's list is not shared."""
        self.rules: List[AlertRule] = list(rules or [])
        self._stats = _RuleStats(
            loaded=len(self.rules),
            enabled=sum(1 for r in self.rules if r.enabled),
        )

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    @classmethod
    def from_directory(cls, rules_dir: str | Path) -> "AlertRuleEngine":
        """Create an engine and load every rule file found in ``rules_dir``."""
        engine = cls()
        engine.load_directory(rules_dir)
        return engine

    @classmethod
    def _iter_rule_files(cls, directory: Path) -> List[Path]:
        """Return regular files directly in *directory* with a rule suffix, sorted by path."""
        # Exact suffix match: a wildcard pattern like ``*.y*ml`` would also
        # accept names such as ``x.yxml`` or ``x.y.html``.
        return sorted(
            entry
            for entry in directory.iterdir()
            if entry.suffix.lower() in cls._RULE_SUFFIXES and entry.is_file()
        )

    def load_directory(self, rules_dir: str | Path) -> int:
        """Load every rule file directly inside ``rules_dir`` (non-recursive).

        A file that fails to load is logged and skipped without aborting the
        others; because ``load_file`` commits atomically, none of a failing
        file's rules are added.

        Returns:
            The number of rules added across all files (0 if ``rules_dir`` is
            not an existing directory).
        """
        path = Path(rules_dir)
        if not path.is_dir():
            logger.warning("规则目录不存在: %s", path)
            return 0

        count = 0
        for file in self._iter_rule_files(path):
            try:
                count += self.load_file(file)
            except Exception as exc:  # noqa: BLE001 - one bad file must not stop the rest
                logger.error("加载规则文件失败 %s: %s", file, exc)
        logger.info("规则引擎加载完成: 共 %d 条规则 (来自 %s)", count, path)
        return count

    def load_file(self, rule_file: str | Path) -> int:
        """Load rules from one YAML file and append them to the engine.

        Accepted layouts: a single rule mapping, a list of rule mappings, or a
        mapping with a ``rules`` key holding either of those. An empty file
        adds nothing and is not recorded; ``rules:`` with no value is treated
        as an empty list.

        Every entry is parsed before any rule is added, so an invalid entry
        leaves the engine's rules and stats unchanged.

        Returns:
            The number of rules added.

        Raises:
            OSError: if the file cannot be opened.
            TypeError: if a rule entry is not a mapping.
            Anything raised by ``yaml.safe_load`` for malformed YAML or by
            ``AlertRule.from_dict`` for an invalid entry.
        """
        path = Path(rule_file)
        with path.open("r", encoding="utf-8") as fp:
            data = yaml.safe_load(fp)

        if data is None:
            return 0

        # Single rule dict, list of rules, or {"rules": ...}: all supported.
        if isinstance(data, dict) and "rules" in data:
            rule_items = data["rules"]
        else:
            rule_items = data
        if rule_items is None:
            rule_items = []
        elif not isinstance(rule_items, list):
            rule_items = [rule_items]

        parsed: List[AlertRule] = []
        for item in rule_items:
            if not isinstance(item, dict):
                raise TypeError(
                    f"rule entry must be a mapping, got {type(item).__name__}: {path}"
                )
            parsed.append(AlertRule.from_dict(item))

        # Commit only after every entry parsed successfully.
        self.rules.extend(parsed)
        self._stats.loaded += len(parsed)
        self._stats.enabled += sum(1 for r in parsed if r.enabled)
        self._stats.files.append(str(path))
        return len(parsed)

    # ------------------------------------------------------------------
    # Evaluation
    # ------------------------------------------------------------------

    def evaluate(self, events: List[CandidateEvent]) -> List[AlertEvent]:
        """Evaluate candidate events and return the alerts they trigger.

        Rules are tried in ``self.rules`` order; only the first matching rule
        produces an alert for a given event, so an event yields at most one
        alert. Side effect: the matching rule's name is appended to
        ``event.triggered_rules``.
        """
        alerts: List[AlertEvent] = []
        for event in events:
            for rule in self.rules:
                if rule.matches(event):
                    alerts.append(self._build_alert(event, rule))
                    event.triggered_rules.append(rule.name)
                    break  # first hit wins; avoids duplicate alerts per event
        return alerts

    @staticmethod
    def _build_alert(event: CandidateEvent, rule: AlertRule) -> AlertEvent:
        """Create the ``AlertEvent`` for ``event`` triggered by ``rule``.

        The rule's severity overrides the event's when set; the rule
        description is stored under ``metadata["description"]``.
        """
        severity = rule.severity or event.severity
        return AlertEvent(
            event_type=event.event_type,
            severity=severity,
            confidence=event.confidence,
            source_id=event.source_id,
            detections=[event.detection],
            rule_name=rule.name,
            metadata={"description": rule.description},
        )

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    @property
    def stats(self) -> Dict[str, Any]:
        """Return a snapshot of loading statistics (``files`` is a copy)."""
        return {
            "loaded": self._stats.loaded,
            "enabled": self._stats.enabled,
            "files": list(self._stats.files),
        }
|
||
|
||
|
||
__all__ = ["AlertRule", "AlertRuleEngine"]
|