335 lines
10 KiB
Python
335 lines
10 KiB
Python
"""MQTT 预警发布服务 (MVP-2 / D16-D17)

封装 paho-mqtt 客户端,提供:

1. 异步友好的连接 / 断开接口
2. QoS 0/1/2 支持
3. 自动重连 (指数退避)
4. 发布失败队列重试 (规划中,当前版本尚未实现)
5. 状态监控与统计

设计原则:

- paho-mqtt 的回调运行在内部线程,对外暴露 async 接口
- 发布操作非阻塞;失败时返回 False 并计入统计 (后台 Worker 重试尚未实现)
- 与 AlertPublisher 解耦,本类只负责传输层
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import threading
|
||
import time
|
||
from dataclasses import dataclass, field
|
||
from typing import Any, Callable, Dict, Optional
|
||
|
||
# Optional dependency: keep this module importable without paho-mqtt so that
# MQTTService.__init__ can raise an actionable error instead of ImportError.
try:
    import paho.mqtt.client as mqtt
except ImportError:  # pragma: no cover - only reached when paho-mqtt is absent
    mqtt = None  # type: ignore[assignment]
    _PAHO_AVAILABLE = False
else:
    _PAHO_AVAILABLE = True

logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# 数据模型
# ---------------------------------------------------------------------------
|
||
|
||
|
||
@dataclass
class MQTTConfig:
    """Connection settings for the MQTT client.

    All fields keep their original order and defaults, so positional and
    keyword construction behave exactly as before for valid values.

    Raises:
        ValueError: from ``__post_init__`` when a value lies outside what the
            MQTT protocol / paho-mqtt accepts, so a bad configuration fails at
            construction time instead of surfacing later as silent publish
            failures or an exception out of ``connect()``.
    """

    broker_host: str = "localhost"
    broker_port: int = 1883
    client_id: str = "jc-video-recognize"
    username: Optional[str] = None
    password: Optional[str] = None
    # MQTT keepalive interval, in seconds.
    keepalive: int = 60
    # Default QoS (0, 1 or 2) used when publish() is called without an override.
    qos: int = 1
    # Default retain flag used when publish() is called without an override.
    retain: bool = False
    # Lower / upper bounds, in seconds, for the reconnect back-off.
    reconnect_min_delay: float = 1.0
    reconnect_max_delay: float = 60.0
    # TLS is out of MVP scope; this flag is not applied to the client yet.
    use_tls: bool = False

    def __post_init__(self) -> None:
        """Validate field ranges; raise ValueError on the first bad value."""
        if self.qos not in (0, 1, 2):
            raise ValueError(f"qos 必须为 0、1 或 2,当前为 {self.qos!r}")
        if not 0 < self.broker_port <= 65535:
            raise ValueError(
                f"broker_port 必须在 1-65535 之间,当前为 {self.broker_port!r}"
            )
        if self.keepalive < 0:
            raise ValueError(f"keepalive 不能为负数,当前为 {self.keepalive!r}")
        if self.reconnect_min_delay <= 0:
            raise ValueError(
                f"reconnect_min_delay 必须大于 0,当前为 {self.reconnect_min_delay!r}"
            )
        if self.reconnect_min_delay > self.reconnect_max_delay:
            # paho's reconnect_delay_set() rejects min_delay > max_delay too.
            raise ValueError(
                "reconnect_min_delay 不能大于 reconnect_max_delay "
                f"({self.reconnect_min_delay!r} > {self.reconnect_max_delay!r})"
            )
|
||
|
||
|
||
@dataclass
class MQTTStats:
    """Connection / publish counters and the most recent error message."""

    connected: bool = False
    connect_count: int = 0
    disconnect_count: int = 0
    publish_count: int = 0
    publish_failed: int = 0
    # Wall-clock timestamp of the last successful publish; 0.0 means "never".
    last_publish_time: float = 0.0
    last_error: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict snapshot of every counter, in field order."""
        exported = (
            "connected",
            "connect_count",
            "disconnect_count",
            "publish_count",
            "publish_failed",
            "last_publish_time",
            "last_error",
        )
        return {name: getattr(self, name) for name in exported}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# MQTTService
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class MQTTService:
    """Asyncio-friendly wrapper around a paho-mqtt client (transport layer only).

    paho-mqtt runs its network loop in a background thread started with
    ``loop_start()``; the ``_on_*_cb`` callbacks below execute on that thread,
    while ``connect`` / ``disconnect`` / ``publish`` are exposed as coroutines.

    Args:
        config: Client configuration; ``MQTTConfig()`` defaults when omitted.
        loop: Event loop to associate with the service. When omitted, the
            running loop is captured by the first ``connect()`` call.

    Notes:
        A failed publish is counted in ``stats`` and reported through the
        return value of ``publish``; this class does not queue it for retry.
        Topics passed to ``subscribe`` are remembered and re-subscribed on
        every (re)connect.

    Usage::

        service = MQTTService(MQTTConfig(broker_host="localhost"))
        await service.connect()
        await service.publish("video/alerts/cam-01", {"event": "fire"})
        await service.disconnect()
    """

    def __init__(
        self,
        config: Optional[MQTTConfig] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        if not _PAHO_AVAILABLE:
            raise RuntimeError(
                "paho-mqtt 未安装,请执行: pip install paho-mqtt"
            )

        self.config = config or MQTTConfig()
        self._loop = loop
        self._stats = MQTTStats()
        self._client: Optional["mqtt.Client"] = None
        # Set by _on_connect_cb on a successful CONNACK; waited on by connect().
        self._connected_event = threading.Event()
        self._on_message: Optional[Callable[[str, bytes], None]] = None
        # True when the current client uses paho's VERSION2 callback API,
        # whose on_disconnect receives (disconnect_flags, reason_code, properties).
        self._callback_api_v2 = False
        # topic -> qos. The client is created with paho's default clean
        # session, so the broker drops subscriptions when the connection is
        # lost; _on_connect_cb re-issues them after every (re)connect.
        self._subscriptions: Dict[str, int] = {}

    # ------------------------------------------------------------------
    # Connect / disconnect
    # ------------------------------------------------------------------

    def _create_client(self) -> "mqtt.Client":
        """Create a paho client, preferring the VERSION2 callback API (paho >= 2.0)."""
        api_version = getattr(mqtt, "CallbackAPIVersion", None)
        self._callback_api_v2 = api_version is not None
        if api_version is None:
            # paho-mqtt 1.x only has the legacy callback API.
            return mqtt.Client(client_id=self.config.client_id)
        return mqtt.Client(
            callback_api_version=api_version.VERSION2,
            client_id=self.config.client_id,
        )

    async def connect(self, timeout: float = 5.0) -> bool:
        """Connect to the MQTT broker and wait for a successful CONNACK.

        A client left over from an earlier ``connect()`` call is shut down
        first, so repeated calls never leave two network threads running with
        the same ``client_id``.

        Args:
            timeout: Seconds to wait for the broker to accept the connection.

        Returns:
            True if the connection was accepted within ``timeout``. On a
            timeout the paho network thread keeps retrying in the background;
            call ``disconnect()`` (or ``connect()`` again) to stop it.
        """
        if self._loop is None:
            try:
                self._loop = asyncio.get_running_loop()
            except RuntimeError:
                self._loop = None

        if self._client is not None:
            await self.disconnect()

        client = self._create_client()
        self._client = client

        if self.config.username:
            client.username_pw_set(self.config.username, self.config.password)

        if self.config.use_tls:
            logger.warning("MQTT TLS 暂未支持,use_tls 配置被忽略 (将使用明文连接)")

        # paho's back-off doubles the delay from min_delay up to max_delay, so
        # a min_delay truncated to 0 (e.g. int(0.5)) would reconnect in a tight
        # loop; paho also raises ValueError when min_delay > max_delay.
        min_delay = max(1, int(self.config.reconnect_min_delay))
        max_delay = max(min_delay, int(self.config.reconnect_max_delay))
        client.reconnect_delay_set(min_delay=min_delay, max_delay=max_delay)

        client.on_connect = self._on_connect_cb
        client.on_disconnect = self._on_disconnect_cb
        client.on_publish = self._on_publish_cb
        client.on_message = self._on_message_cb

        self._connected_event.clear()

        try:
            client.connect_async(
                self.config.broker_host,
                self.config.broker_port,
                keepalive=self.config.keepalive,
            )
            client.loop_start()

            # Event.wait() blocks, so run it in the default executor to keep
            # the event loop responsive while paho performs the handshake.
            connected = await asyncio.get_running_loop().run_in_executor(
                None, self._connected_event.wait, timeout
            )
        except Exception as e:  # noqa: BLE001
            self._stats.last_error = f"连接异常: {e}"
            logger.error("MQTT 连接失败: %s", e)
            return False

        if not connected:
            self._stats.last_error = f"连接超时 ({timeout}s)"
            logger.warning(
                "MQTT 连接超时: %s:%d",
                self.config.broker_host,
                self.config.broker_port,
            )
            return False
        return True

    async def disconnect(self) -> None:
        """Disconnect from the broker and stop the paho network thread.

        Safe to call when not connected. DISCONNECT is sent while the network
        thread is still alive, then the thread is stopped; ``loop_stop`` runs
        even if ``disconnect`` raises so the thread is never leaked.
        """
        client = self._client
        if client is None:
            return
        try:
            client.disconnect()
        except Exception as e:  # noqa: BLE001
            logger.debug("MQTT 断开异常: %s", e)
        finally:
            try:
                client.loop_stop()
            except Exception as e:  # noqa: BLE001
                logger.debug("MQTT 断开异常: %s", e)
            self._stats.connected = False
            self._connected_event.clear()
            self._client = None

    # ------------------------------------------------------------------
    # Publish
    # ------------------------------------------------------------------

    async def publish(
        self,
        topic: str,
        payload: Any,
        qos: Optional[int] = None,
        retain: Optional[bool] = None,
    ) -> bool:
        """Publish a message.

        Args:
            topic: Topic to publish to.
            payload: Message body; dict/list values are JSON-serialized
                (non-JSON types via ``str``), anything else is passed to paho
                unchanged.
            qos: Overrides ``config.qos`` when given.
            retain: Overrides ``config.retain`` when given.

        Returns:
            True if paho accepted the message (rc == 0). This does not
            guarantee delivery to the broker. False when not connected, when
            paho returns a non-zero rc, or when paho raises; each failure
            increments ``publish_failed`` and sets ``last_error``.
        """
        if self._client is None or not self._stats.connected:
            self._stats.publish_failed += 1
            self._stats.last_error = "未连接"
            return False

        if isinstance(payload, (dict, list)):
            data: Any = json.dumps(payload, ensure_ascii=False, default=str)
        else:
            data = payload

        try:
            info = self._client.publish(
                topic,
                payload=data,
                qos=qos if qos is not None else self.config.qos,
                retain=retain if retain is not None else self.config.retain,
            )
        except Exception as e:  # noqa: BLE001
            self._stats.publish_failed += 1
            self._stats.last_error = f"发布异常: {e}"
            logger.error("MQTT 发布异常 topic=%s: %s", topic, e)
            return False

        if info.rc != 0:
            self._stats.publish_failed += 1
            self._stats.last_error = f"发布失败 rc={info.rc}"
            return False

        self._stats.publish_count += 1
        self._stats.last_publish_time = time.time()
        return True

    # ------------------------------------------------------------------
    # Subscribe (optional, for two-way communication)
    # ------------------------------------------------------------------

    def subscribe(
        self,
        topic: str,
        qos: int = 1,
        on_message: Optional[Callable[[str, bytes], None]] = None,
    ) -> bool:
        """Subscribe to ``topic``.

        Args:
            topic: Topic filter to subscribe to.
            qos: Requested subscription QoS.
            on_message: If given, replaces the single message handler, which is
                called as ``on_message(topic, payload)`` for every message.

        Returns:
            True if paho accepted the SUBSCRIBE request (rc == 0). Unless paho
            raised, the subscription is also remembered and re-issued on every
            later (re)connect.
        """
        if self._client is None:
            return False
        if on_message:
            self._on_message = on_message
        try:
            result, _ = self._client.subscribe(topic, qos=qos)
        except Exception as e:  # noqa: BLE001
            logger.error("MQTT 订阅异常 topic=%s: %s", topic, e)
            return False
        self._subscriptions[topic] = qos
        return result == 0

    # ------------------------------------------------------------------
    # Status
    # ------------------------------------------------------------------

    @property
    def is_connected(self) -> bool:
        """Connection state as last reported by the paho callbacks."""
        return self._stats.connected

    @property
    def stats(self) -> Dict[str, Any]:
        """Snapshot of the counters as a plain dict."""
        return self._stats.to_dict()

    # ------------------------------------------------------------------
    # paho callbacks (run on paho's network thread)
    # ------------------------------------------------------------------

    def _restore_subscriptions(self, client) -> None:
        """Re-issue every remembered subscription in a single SUBSCRIBE."""
        subscriptions = list(self._subscriptions.items())
        if not subscriptions:
            return
        try:
            client.subscribe(subscriptions)
        except Exception as e:  # noqa: BLE001
            logger.error("MQTT 重新订阅异常: %s", e)

    def _on_connect_cb(self, client, userdata, flags, rc, *args, **kwargs) -> None:
        # ``rc`` is an int (paho 1.x) or a ReasonCode (paho 2.x); both compare
        # equal to 0 on success.
        if rc == 0:
            self._stats.connected = True
            self._stats.connect_count += 1
            self._restore_subscriptions(client)
            self._connected_event.set()
            logger.info(
                "MQTT 已连接: %s:%d",
                self.config.broker_host,
                self.config.broker_port,
            )
        else:
            self._stats.connected = False
            self._stats.last_error = f"连接失败 rc={rc}"
            logger.warning("MQTT 连接被拒绝 rc=%s", rc)

    def _on_disconnect_cb(self, client, userdata, *args, **kwargs) -> None:
        self._stats.connected = False
        self._stats.disconnect_count += 1
        self._connected_event.clear()
        # Legacy API passes (rc[, properties]); VERSION2 passes
        # (disconnect_flags, reason_code, properties).
        if self._callback_api_v2 and len(args) >= 2:
            rc = args[1]
        else:
            rc = args[0] if args else 0
        logger.info("MQTT 已断开 rc=%s", rc)

    def _on_publish_cb(self, client, userdata, mid, *args, **kwargs) -> None:
        logger.debug("MQTT 发布完成 mid=%s", mid)

    def _on_message_cb(self, client, userdata, msg) -> None:
        if self._on_message:
            try:
                self._on_message(msg.topic, msg.payload)
            except Exception:  # noqa: BLE001
                # Never let a user handler exception escape into paho's thread.
                logger.exception("MQTT 消息回调异常")
|
||
|
||
|
||
__all__ = [
    "MQTTService",
    "MQTTConfig",
    "MQTTStats",
]