"""Alert WebSocket endpoint (MVP-2 / D23-D24).

Exposes ``/ws/alerts`` so the frontend can subscribe to real-time alert
events.

How it works:

1. The frontend opens a WebSocket connection to ``/ws/alerts``.
2. The backend keeps a connection pool; alert events are broadcast to every
   subscriber through :class:`AlertBroadcaster`.
3. Subscriptions can be filtered by event type and/or ``source_id``.
4. ``AlertPublisher`` calls this module's broadcaster in addition to
   publishing to MQTT.

Outgoing alert message format (kept consistent with ``AlertPublisher``)::

    {
        "type": "alert",
        "data": {
            "alert_id": "...",
            "event_type": "fire",
            ...
        }
    }

Client control messages::

    {"action": "subscribe", "filter": {"event_types": ["fire"], "source_ids": ["cam-01"]}}
    {"action": "unsubscribe"}
    {"action": "ping"}

Malformed control messages (invalid JSON, a non-object message, or an invalid
filter) are answered with an ``{"type": "error", ...}`` frame; the connection
stays open.
"""

from __future__ import annotations

import asyncio
import json
import logging
from typing import Any, Dict, Iterable, List, Optional, Set, Union

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

logger = logging.getLogger(__name__)

router = APIRouter()


# ---------------------------------------------------------------------------
# AlertBroadcaster
# ---------------------------------------------------------------------------


class AlertBroadcaster:
    """Alert event broadcaster (shared via :func:`get_broadcaster`).

    Holds every active WebSocket connection and delivers each broadcast alert
    to the subscribers whose filter matches it.
    """

    def __init__(self) -> None:
        # Keyed by ``id(ws)``. Each entry keeps a reference to its WebSocket,
        # so the id cannot be reused by another object while the entry exists.
        self._connections: Dict[int, _Subscriber] = {}
        # Serializes add/remove and the snapshot taken in ``broadcast``.
        self._lock = asyncio.Lock()

    async def add_connection(self, ws: WebSocket) -> _Subscriber:
        """Register ``ws`` and return its (initially unfiltered) subscriber."""
        async with self._lock:
            sub = _Subscriber(ws=ws)
            self._connections[id(ws)] = sub
            logger.info("AlertBroadcaster 新增订阅者: %d (当前 %d)", id(ws), len(self._connections))
        return sub

    async def remove_connection(self, ws: WebSocket) -> None:
        """Unregister ``ws``; a no-op if it is not registered."""
        async with self._lock:
            self._connections.pop(id(ws), None)
            logger.info(
                "AlertBroadcaster 移除订阅者: %d (剩余 %d)",
                id(ws),
                len(self._connections),
            )

    async def broadcast(self, alert: Dict[str, Any]) -> int:
        """Send an alert to every subscriber whose filter matches it.

        Args:
            alert: Alert payload; it is wrapped as ``{"type": "alert", "data": alert}``.

        Returns:
            The number of subscribers the message was successfully sent to.
            Send failures are logged at DEBUG level and not counted.
        """
        message = {"type": "alert", "data": alert}
        delivered = 0

        async with self._lock:
            subs = list(self._connections.values())

        # Send outside the lock so a slow client cannot block connection
        # registration/removal.
        for sub in subs:
            if not sub.matches(alert):
                continue
            try:
                await sub.ws.send_json(message)
                delivered += 1
            except Exception as e:  # noqa: BLE001
                logger.debug("广播失败 (订阅者 %d): %s", id(sub.ws), e)

        return delivered

    @property
    def connection_count(self) -> int:
        """Number of currently registered connections."""
        return len(self._connections)

    def get_stats(self) -> Dict[str, Any]:
        """Return a JSON-serializable snapshot of subscribers and their filters.

        An empty filter dimension is reported as ``"all"``. This method does not
        await, so it cannot interleave with other coroutines on the same loop.
        """
        return {
            "total_connections": len(self._connections),
            "subscribers": [
                {
                    "id": sid,
                    "event_types": list(s.event_types) if s.event_types else "all",
                    "source_ids": list(s.source_ids) if s.source_ids else "all",
                }
                for sid, s in self._connections.items()
            ],
        }


# ---------------------------------------------------------------------------
# Subscriber
# ---------------------------------------------------------------------------


def _to_value_set(values: Optional[Union[str, Iterable[str]]]) -> Set[str]:
    """Normalize a filter value into a fresh set.

    ``None`` or an empty value yields an empty set, meaning "match all". A
    bare string is treated as a single value rather than being split into its
    characters, which is what ``set("fire")`` would do.
    """
    if not values:
        return set()
    if isinstance(values, str):
        return {values}
    return set(values)


class _Subscriber:
    """Context for a single WebSocket subscriber.

    An empty ``event_types`` / ``source_ids`` set places no restriction on
    that dimension (see :meth:`matches`).
    """

    __slots__ = ("ws", "event_types", "source_ids")

    def __init__(
        self,
        ws: WebSocket,
        event_types: Optional[Iterable[str]] = None,
        source_ids: Optional[Iterable[str]] = None,
    ) -> None:
        self.ws = ws
        # Copy the inputs so later mutation by the caller cannot alter this
        # subscriber's filter.
        self.event_types: Set[str] = _to_value_set(event_types)
        self.source_ids: Set[str] = _to_value_set(source_ids)

    def update_filter(
        self,
        event_types: Optional[Union[str, List[str]]] = None,
        source_ids: Optional[Union[str, List[str]]] = None,
    ) -> None:
        """Replace both filter dimensions; ``None``/empty clears a dimension.

        A bare string is accepted as a single value.
        """
        self.event_types = _to_value_set(event_types)
        self.source_ids = _to_value_set(source_ids)

    def matches(self, alert: Dict[str, Any]) -> bool:
        """Return True if ``alert`` passes this subscriber's filter.

        ``alert["event_type"]`` / ``alert["source_id"]`` are only inspected
        when the corresponding filter set is non-empty.
        """
        if self.event_types and alert.get("event_type") not in self.event_types:
            return False
        if self.source_ids and alert.get("source_id") not in self.source_ids:
            return False
        return True


# ---------------------------------------------------------------------------
# Global broadcaster instance
# ---------------------------------------------------------------------------

_broadcaster: Optional[AlertBroadcaster] = None


def get_broadcaster() -> AlertBroadcaster:
    """Return the module-wide :class:`AlertBroadcaster`, creating it lazily."""
    global _broadcaster
    if _broadcaster is None:
        _broadcaster = AlertBroadcaster()
    return _broadcaster


# ---------------------------------------------------------------------------
# WebSocket route
# ---------------------------------------------------------------------------


def _parse_filter_values(field: str, value: Any) -> Optional[List[Any]]:
    """Validate one filter field taken from a client ``subscribe`` message.

    Args:
        field: Field name, used only in the error message.
        value: Raw value decoded from the client's JSON.

    Returns:
        ``None`` when the value is missing or empty (match all), a one-item
        list for a bare string, otherwise the list itself.

    Raises:
        ValueError: If the value is neither a string nor a flat list. The only
            containers ``json.loads`` produces are ``list`` and ``dict``; both
            are unhashable and would make ``set(...)`` raise ``TypeError``.
    """
    if not value:
        return None
    if isinstance(value, str):
        return [value]
    if isinstance(value, list) and not any(isinstance(v, (list, dict)) for v in value):
        return value
    raise ValueError(f"{field} 必须是字符串或字符串列表")


async def _send_error(websocket: WebSocket, message: str) -> None:
    """Send an ``error`` frame to the client; the connection stays open."""
    await websocket.send_json({"type": "error", "data": {"message": message}})


async def _handle_subscribe(
    websocket: WebSocket, subscriber: _Subscriber, message: Dict[str, Any]
) -> None:
    """Apply a ``subscribe`` request and acknowledge it.

    The previous filter is kept if the request is invalid.
    """
    filter_cfg = message.get("filter") or {}
    if not isinstance(filter_cfg, dict):
        await _send_error(websocket, "filter 必须是 JSON 对象")
        return
    try:
        event_types = _parse_filter_values("event_types", filter_cfg.get("event_types"))
        source_ids = _parse_filter_values("source_ids", filter_cfg.get("source_ids"))
    except ValueError as exc:
        await _send_error(websocket, str(exc))
        return

    subscriber.update_filter(event_types=event_types, source_ids=source_ids)
    await websocket.send_json(
        {
            "type": "subscribed",
            "data": {
                "event_types": list(subscriber.event_types) or "all",
                "source_ids": list(subscriber.source_ids) or "all",
            },
        }
    )


@router.websocket("/ws/alerts")
async def alerts_websocket(websocket: WebSocket):
    """WebSocket endpoint through which the frontend subscribes to alerts.

    The connection is registered with the broadcaster on accept and is always
    unregistered when the handler exits.
    """
    await websocket.accept()
    broadcaster = get_broadcaster()
    subscriber = await broadcaster.add_connection(websocket)

    try:
        # Welcome message
        await websocket.send_json(
            {
                "type": "welcome",
                "data": {
                    "message": "已连接预警频道",
                    "subscriber_id": id(websocket),
                },
            }
        )

        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                await _send_error(websocket, "无效的 JSON")
                continue

            # Valid JSON may still be a list/number/string; reject it instead of
            # letting ``.get`` raise and tear down the connection.
            if not isinstance(message, dict):
                await _send_error(websocket, "消息必须是 JSON 对象")
                continue

            action = message.get("action")

            if action == "subscribe":
                await _handle_subscribe(websocket, subscriber, message)
            elif action == "unsubscribe":
                subscriber.update_filter()
                await websocket.send_json({"type": "unsubscribed", "data": {}})
            elif action == "ping":
                await websocket.send_json({"type": "pong", "data": {}})
            else:
                await _send_error(websocket, f"未知 action: {action}")

    except WebSocketDisconnect:
        logger.info("预警订阅者断开连接")
    except Exception as e:  # noqa: BLE001
        # Top-level boundary: keep the traceback for diagnosis.
        logger.exception("预警 WebSocket 异常: %s", e)
    finally:
        await broadcaster.remove_connection(websocket)


@router.get("/alerts/subscribers")
async def get_subscribers():
    """Return the current alert subscriber status (for debugging)."""
    return get_broadcaster().get_stats()


__all__ = ["router", "AlertBroadcaster", "get_broadcaster"]