From 40fd3089a72b0404d59944981c01e7850ed2839d Mon Sep 17 00:00:00 2001 From: wuzhuorong <973204353@qq.com> Date: Fri, 12 Jun 2026 13:56:35 +0800 Subject: [PATCH] =?UTF-8?q?feat(server):=20=E6=96=B0=E5=A2=9ERTSP=E5=A4=9A?= =?UTF-8?q?=E8=B7=AF=E8=A7=86=E9=A2=91=E6=B5=81=E6=8E=A5=E5=85=A5=E6=9C=8D?= =?UTF-8?q?=E5=8A=A1=E3=80=82-=20=E5=AE=9E=E7=8E=B0=E5=9F=BA=E4=BA=8ERing?= =?UTF-8?q?=20Buffer=E7=9A=84=E5=B8=A7=E7=BC=93=E5=86=B2=E5=8C=BA(frame=5F?= =?UTF-8?q?buffer)=EF=BC=8C=E6=94=AF=E6=8C=81=E7=BA=BF=E7=A8=8B=E5=AE=89?= =?UTF-8?q?=E5=85=A8=E8=AF=BB=E5=86=99=20-=20=E5=AE=9E=E7=8E=B0RTSP?= =?UTF-8?q?=E6=B5=81=E6=8E=A5=E5=85=A5=E6=9C=8D=E5=8A=A1(rtsp=5Fservice)?= =?UTF-8?q?=EF=BC=8C=E6=94=AF=E6=8C=81=E5=8D=95=E8=B7=AF=E6=B5=81=E8=BF=9E?= =?UTF-8?q?=E6=8E=A5/=E8=A7=A3=E7=A0=81/=E5=B8=A7=E9=87=87=E9=9B=86=20-=20?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E5=A4=9A=E8=B7=AF=E6=B5=81=E8=B0=83=E5=BA=A6?= =?UTF-8?q?=E7=AE=A1=E7=90=86=E5=99=A8(stream=5Fmanager)=EF=BC=8C=E7=BB=9F?= =?UTF-8?q?=E4=B8=80=E7=AE=A1=E7=90=86=E5=A4=9A=E8=B7=AFRTSP=E6=B5=81?= =?UTF-8?q?=E5=90=AF=E5=81=9C=E4=B8=8E=E7=8A=B6=E6=80=81=E7=9B=91=E6=8E=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/server/services/frame_buffer.py | 254 +++++++++++++++ apps/server/services/rtsp_service.py | 396 ++++++++++++++++++++++++ apps/server/services/stream_manager.py | 407 +++++++++++++++++++++++++ 3 files changed, 1057 insertions(+) create mode 100644 apps/server/services/frame_buffer.py create mode 100644 apps/server/services/rtsp_service.py create mode 100644 apps/server/services/stream_manager.py diff --git a/apps/server/services/frame_buffer.py b/apps/server/services/frame_buffer.py new file mode 100644 index 0000000..3b8e09b --- /dev/null +++ b/apps/server/services/frame_buffer.py @@ -0,0 +1,254 @@ +"""帧缓冲区 (MVP-2 / D15) + +基于 Ring Buffer 的帧缓冲,配合丢帧策略,避免多路 RTSP 流场景下 +内存无限增长。 + +核心设计: + +1. 固定容量的环形缓冲区,写满后自动覆盖最旧帧 +2. 支持按策略丢帧: 最新帧优先 (实时性) / 均匀采样 (覆盖率) +3. 线程安全: 使用 asyncio.Lock 保护并发读写 +4. 帧元数据: 每帧附带 stream_id / timestamp / frame_index +""" + +from __future__ import annotations + +import asyncio +import logging +import time +from collections import deque +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +logger = logging.getLogger(__name__) + + +class DropPolicy(str, Enum): + """丢帧策略。""" + + LATEST = "latest" # 保留最新帧,覆盖最旧帧 (默认,适合实时检测) + SAMPLE = "sample" # 均匀采样保留,丢弃中间帧 (适合回溯分析) + + +@dataclass +class FrameMeta: + """帧元数据。""" + + stream_id: str + frame_index: int + timestamp: float + width: int = 0 + height: int = 0 + + +@dataclass +class FrameItem: + """缓冲区中的帧条目。""" + + frame: np.ndarray + meta: FrameMeta + + +class FrameBuffer: + """环形帧缓冲区。 + + Args: + capacity: 缓冲区最大帧数 + drop_policy: 丢帧策略 + max_memory_mb: 内存上限 (MB),超过时强制丢帧;0 表示不限制 + """ + + def __init__( + self, + capacity: int = 300, + drop_policy: DropPolicy = DropPolicy.LATEST, + max_memory_mb: float = 0, + ) -> None: + self.capacity = max(1, capacity) + self.drop_policy = drop_policy + self.max_memory_mb = max(0.0, max_memory_mb) + self._buffer: deque[FrameItem] = deque(maxlen=self.capacity) + self._lock = asyncio.Lock() + self._total_written: int = 0 + self._total_dropped: int = 0 + + # ------------------------------------------------------------------ + # 写入 + # ------------------------------------------------------------------ + + async def write( + self, + frame: np.ndarray, + stream_id: str, + frame_index: int, + timestamp: Optional[float] = None, + ) -> None: + """写入一帧到缓冲区。 + + 当缓冲区已满时,根据 ``drop_policy`` 决定丢弃策略。 + """ + + meta = FrameMeta( + stream_id=stream_id, + frame_index=frame_index, + timestamp=timestamp or time.time(), + width=frame.shape[1] if frame.ndim >= 2 else 0, + height=frame.shape[0] if frame.ndim >= 2 else 0, + ) + item = FrameItem(frame=frame, meta=meta) + + async with self._lock: + self._total_written += 1 + + if len(self._buffer) >= self.capacity: + self._apply_drop_policy(item) + else: + self._buffer.append(item) + + # 内存上限检查 + if self.max_memory_mb > 0: + self._enforce_memory_limit() + + # ------------------------------------------------------------------ + # 读取 + # ------------------------------------------------------------------ + + async def read_latest(self) -> Optional[FrameItem]: + """读取最新一帧 (不消费)。""" + + async with self._lock: + if not self._buffer: + return None + return self._buffer[-1] + + async def read_oldest(self) -> Optional[FrameItem]: + """读取最旧一帧 (不消费)。""" + + async with self._lock: + if not self._buffer: + return None + return self._buffer[0] + + async def read_all(self) -> List[FrameItem]: + """读取缓冲区所有帧 (快照,不消费)。""" + + async with self._lock: + return list(self._buffer) + + async def read_range( + self, + start_index: int = 0, + count: Optional[int] = None, + ) -> List[FrameItem]: + """读取指定范围的帧 (快照)。 + + Args: + start_index: 从缓冲区开头的偏移量 + count: 读取帧数,None 表示到末尾 + """ + + async with self._lock: + items = list(self._buffer) + if start_index >= len(items): + return [] + end = len(items) if count is None else start_index + count + return items[start_index:end] + + async def pop_latest(self) -> Optional[FrameItem]: + """弹出最新一帧 (消费)。""" + + async with self._lock: + if not self._buffer: + return None + return self._buffer.pop() + + async def pop_oldest(self) -> Optional[FrameItem]: + """弹出最旧一帧 (消费)。""" + + async with self._lock: + if not self._buffer: + return None + return self._buffer.popleft() + + # ------------------------------------------------------------------ + # 状态 + # ------------------------------------------------------------------ + + async def clear(self) -> None: + """清空缓冲区。""" + + async with self._lock: + self._buffer.clear() + + @property + def size(self) -> int: + """当前缓冲区帧数。""" + + return len(self._buffer) + + @property + def stats(self) -> Dict[str, Any]: + """缓冲区统计信息。""" + + return { + "size": len(self._buffer), + "capacity": self.capacity, + "total_written": self._total_written, + "total_dropped": self._total_dropped, + "drop_policy": self.drop_policy.value, + "usage_percent": round(len(self._buffer) / self.capacity * 100, 1), + } + + def estimate_memory_mb(self) -> float: + """估算当前缓冲区占用内存 (MB)。""" + + if not self._buffer: + return 0.0 + # 取第一帧估算单帧大小 + sample = self._buffer[0].frame + frame_bytes = sample.nbytes if isinstance(sample, np.ndarray) else 0 + return len(self._buffer) * frame_bytes / (1024 * 1024) + + # ------------------------------------------------------------------ + # 内部 + # ------------------------------------------------------------------ + + def _apply_drop_policy(self, new_item: FrameItem) -> None: + """缓冲区满时应用丢帧策略。""" + + if self.drop_policy == DropPolicy.LATEST: + # 覆盖最旧帧 (deque maxlen 自动处理) + self._total_dropped += 1 + self._buffer.append(new_item) + elif self.drop_policy == DropPolicy.SAMPLE: + # 均匀采样: 丢弃偶数位置的帧,腾出空间 + sampled = deque(maxlen=self.capacity) + step = 2 + for i, item in enumerate(self._buffer): + if i % step != 0: + self._total_dropped += 1 + else: + sampled.append(item) + sampled.append(new_item) + self._buffer = sampled + + def _enforce_memory_limit(self) -> None: + """强制执行内存上限,超出时丢弃最旧帧。""" + + while self.max_memory_mb > 0 and self._buffer: + current_mb = self.estimate_memory_mb() + if current_mb <= self.max_memory_mb: + break + self._buffer.popleft() + self._total_dropped += 1 + logger.debug( + "FrameBuffer 内存超限 (%.1f > %.1f MB),丢弃最旧帧", + current_mb, + self.max_memory_mb, + ) + + +__all__ = ["FrameBuffer", "FrameItem", "FrameMeta", "DropPolicy"] diff --git a/apps/server/services/rtsp_service.py b/apps/server/services/rtsp_service.py new file mode 100644 index 0000000..ec1cae8 --- /dev/null +++ b/apps/server/services/rtsp_service.py @@ -0,0 +1,396 @@ +"""RTSP 流接入服务 (MVP-2 / D11-D12) + +负责单路 RTSP 流的连接、解码、自动重连和帧产出。 + +核心设计: + +1. 基于 OpenCV VideoCapture 的 RTSP 接入,兼容主流 IP 摄像头 +2. 后台线程解码帧,避免阻塞事件循环 +3. 自动重连: 断线后按指数退避策略重试 +4. 帧回调: 每解码一帧触发回调,由 StreamManager 分发到检测管道 +5. 优雅关闭: stop() 等待解码线程退出,释放资源 + +使用方式:: + + service = RTSPService( + stream_id="cam-01", + rtsp_url="rtsp://admin:pass@192.168.1.100:554/stream", + on_frame=handle_frame, + ) + await service.start() + ... + await service.stop() +""" + +from __future__ import annotations + +import asyncio +import enum +import logging +import threading +import time +from dataclasses import dataclass, field +from typing import Any, Callable, Coroutine, Dict, Optional + +import cv2 +import numpy as np + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# 数据模型 +# --------------------------------------------------------------------------- + + +class StreamStatus(str, enum.Enum): + """流状态。""" + + IDLE = "idle" # 未启动 + CONNECTING = "connecting" # 连接中 + CONNECTED = "connected" # 已连接,正在解码 + RECONNECTING = "reconnecting" # 断线重连中 + STOPPED = "stopped" # 已停止 + ERROR = "error" # 不可恢复错误 + + +@dataclass +class StreamConfig: + """单路 RTSP 流配置。""" + + stream_id: str + rtsp_url: str + # 解码参数 + reconnect_attempts: int = 10 # 最大重连次数,0 = 无限 + reconnect_interval_base: float = 2.0 # 首次重连间隔 (秒) + reconnect_interval_max: float = 60.0 # 最大重连间隔 (秒) + reconnect_backoff_factor: float = 2.0 # 退避因子 + # 帧采样 + frame_skip: int = 0 # 每隔 N 帧取 1 帧,0 = 每帧都取 + # OpenCV 参数 + buffer_size: int = 1 # FFmpeg 缓冲区大小 (越小延迟越低) + # 超时 + read_timeout: float = 5.0 # 单帧读取超时 (秒) + # 检测配置 + model_id: str = "fire_detection" + confidence: float = 0.5 + iou: float = 0.45 + + +@dataclass +class StreamInfo: + """流运行时信息。""" + + stream_id: str + status: StreamStatus = StreamStatus.IDLE + rtsp_url: str = "" + # 统计 + frames_decoded: int = 0 + frames_dropped: int = 0 + reconnect_count: int = 0 + last_frame_time: float = 0.0 + fps: float = 0.0 + # 时间 + connected_at: float = 0.0 + error_message: str = "" + + def to_dict(self) -> Dict[str, Any]: + return { + "stream_id": self.stream_id, + "status": self.status.value, + "rtsp_url": self._mask_url(self.rtsp_url), + "frames_decoded": self.frames_decoded, + "frames_dropped": self.frames_dropped, + "reconnect_count": self.reconnect_count, + "fps": round(self.fps, 2), + "connected_at": self.connected_at, + "error_message": self.error_message, + } + + @staticmethod + def _mask_url(url: str) -> str: + """遮蔽 RTSP URL 中的密码。""" + if "@" not in url: + return url + try: + prefix, rest = url.split("://", 1) + creds_host = rest.split("@", 1) + if len(creds_host) == 2: + creds, host_path = creds_host + if ":" in creds: + user, _ = creds.split(":", 1) + return f"{prefix}://{user}:****@{host_path}" + except Exception: + pass + return url + + +# --------------------------------------------------------------------------- +# 帧回调类型 +# --------------------------------------------------------------------------- + +# on_frame(stream_id, frame, frame_index, timestamp) -> None +FrameCallback = Callable[[str, np.ndarray, int, float], Coroutine[Any, Any, None]] + + +# --------------------------------------------------------------------------- +# RTSPService +# --------------------------------------------------------------------------- + + +class RTSPService: + """单路 RTSP 流接入服务。 + + 在后台线程中执行 OpenCV 解码循环,通过 asyncio 事件循环 + 将帧投递到异步回调,不阻塞主事件循环。 + """ + + def __init__( + self, + stream_id: str, + rtsp_url: str, + on_frame: Optional[FrameCallback] = None, + config: Optional[StreamConfig] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + ) -> None: + self.config = config or StreamConfig( + stream_id=stream_id, rtsp_url=rtsp_url + ) + self.config.stream_id = stream_id + self.config.rtsp_url = rtsp_url + + self._on_frame = on_frame + self._loop = loop + + self._info = StreamInfo( + stream_id=stream_id, + rtsp_url=rtsp_url, + ) + self._cap: Optional[cv2.VideoCapture] = None + self._thread: Optional[threading.Thread] = None + self._stop_event = threading.Event() + self._frame_index: int = 0 + + # ------------------------------------------------------------------ + # 生命周期 + # ------------------------------------------------------------------ + + async def start(self) -> None: + """启动 RTSP 流解码。""" + + if self._info.status in (StreamStatus.CONNECTED, StreamStatus.CONNECTING): + logger.warning("RTSP 流 %s 已在运行中", self.config.stream_id) + return + + if self._loop is None: + self._loop = asyncio.get_running_loop() + + self._stop_event.clear() + self._frame_index = 0 + self._info.status = StreamStatus.CONNECTING + + self._thread = threading.Thread( + target=self._decode_loop, + name=f"rtsp-{self.config.stream_id}", + daemon=True, + ) + self._thread.start() + logger.info("RTSP 流 %s 解码线程已启动", self.config.stream_id) + + async def stop(self) -> None: + """停止 RTSP 流解码,释放资源。""" + + if self._info.status == StreamStatus.STOPPED: + return + + self._stop_event.set() + self._info.status = StreamStatus.STOPPED + + if self._thread and self._thread.is_alive(): + self._thread.join(timeout=5.0) + + self._release_capture() + logger.info( + "RTSP 流 %s 已停止, 共解码 %d 帧", + self.config.stream_id, + self._info.frames_decoded, + ) + + # ------------------------------------------------------------------ + # 状态 + # ------------------------------------------------------------------ + + @property + def info(self) -> StreamInfo: + return self._info + + @property + def status(self) -> StreamStatus: + return self._info.status + + @property + def is_running(self) -> bool: + return self._info.status in ( + StreamStatus.CONNECTED, + StreamStatus.CONNECTING, + StreamStatus.RECONNECTING, + ) + + # ------------------------------------------------------------------ + # 解码循环 (后台线程) + # ------------------------------------------------------------------ + + def _decode_loop(self) -> None: + """后台线程: RTSP 解码 + 自动重连。""" + + attempt = 0 + + while not self._stop_event.is_set(): + # 尝试连接 + connected = self._connect() + if not connected: + if self._stop_event.is_set(): + break + attempt += 1 + if ( + self.config.reconnect_attempts > 0 + and attempt > self.config.reconnect_attempts + ): + self._info.status = StreamStatus.ERROR + self._info.error_message = ( + f"超过最大重连次数 ({self.config.reconnect_attempts})" + ) + logger.error( + "RTSP 流 %s %s", + self.config.stream_id, + self._info.error_message, + ) + break + + # 指数退避 + interval = min( + self.config.reconnect_interval_base + * (self.config.reconnect_backoff_factor ** (attempt - 1)), + self.config.reconnect_interval_max, + ) + self._info.status = StreamStatus.RECONNECTING + self._info.reconnect_count += 1 + logger.warning( + "RTSP 流 %s 连接失败,第 %d 次重连,等待 %.1fs", + self.config.stream_id, + attempt, + interval, + ) + self._stop_event.wait(timeout=interval) + continue + + # 连接成功,重置计数 + attempt = 0 + self._info.status = StreamStatus.CONNECTED + self._info.connected_at = time.time() + logger.info("RTSP 流 %s 已连接: %s", self.config.stream_id, self.config.rtsp_url) + + # 解码帧 + self._read_frames() + + # 如果 read_frames 退出且未被停止,说明断线了 + if not self._stop_event.is_set(): + self._release_capture() + self._info.status = StreamStatus.RECONNECTING + logger.warning("RTSP 流 %s 断线,准备重连", self.config.stream_id) + + def _connect(self) -> bool: + """尝试连接 RTSP 流。""" + + try: + cap = cv2.VideoCapture(self.config.rtsp_url, cv2.CAP_FFMPEG) + # 降低缓冲以减少延迟 + cap.set(cv2.CAP_PROP_BUFFERSIZE, self.config.buffer_size) + + if not cap.isOpened(): + return False + + # 验证: 尝试读取一帧 + ret, _ = cap.read() + if not ret: + cap.release() + return False + + self._cap = cap + return True + + except Exception as e: + logger.debug("RTSP 流 %s 连接异常: %s", self.config.stream_id, e) + return False + + def _read_frames(self) -> None: + """持续读取帧直到断线或停止信号。""" + + if self._cap is None: + return + + fps_counter_start = time.time() + fps_frame_count = 0 + + while not self._stop_event.is_set(): + try: + ret, frame = self._cap.read() + except Exception as e: + logger.warning("RTSP 流 %s 读取异常: %s", self.config.stream_id, e) + break + + if not ret or frame is None: + logger.warning("RTSP 流 %s 读取帧失败,可能断线", self.config.stream_id) + break + + self._frame_index += 1 + self._info.frames_decoded += 1 + self._info.last_frame_time = time.time() + + # 帧采样 + if self.config.frame_skip > 0 and self._frame_index % (self.config.frame_skip + 1) != 1: + self._info.frames_dropped += 1 + continue + + # FPS 统计 + fps_frame_count += 1 + elapsed = time.time() - fps_counter_start + if elapsed >= 1.0: + self._info.fps = fps_frame_count / elapsed + fps_frame_count = 0 + fps_counter_start = time.time() + + # 通过事件循环投递帧到异步回调 + if self._on_frame and self._loop and not self._loop.is_closed(): + try: + asyncio.run_coroutine_threadsafe( + self._on_frame( + self.config.stream_id, + frame, + self._frame_index, + self._info.last_frame_time, + ), + self._loop, + ) + except RuntimeError as e: + logger.debug("投递帧回调失败 (事件循环可能已关闭): %s", e) + break + + def _release_capture(self) -> None: + """释放 VideoCapture 资源。""" + + if self._cap is not None: + try: + self._cap.release() + except Exception: + pass + self._cap = None + + +__all__ = [ + "RTSPService", + "StreamConfig", + "StreamInfo", + "StreamStatus", + "FrameCallback", +] diff --git a/apps/server/services/stream_manager.py b/apps/server/services/stream_manager.py new file mode 100644 index 0000000..86f4222 --- /dev/null +++ b/apps/server/services/stream_manager.py @@ -0,0 +1,407 @@ +"""多路流调度管理器 (MVP-2 / D13-D14) + +负责管理多路 RTSP 流的生命周期、帧缓冲、状态监控和检测调度。 + +核心设计: + +1. 统一管理多路 RTSPService 实例 +2. 每路流对应一个 FrameBuffer,解耦解码与检测 +3. 检测调度: 轮询 / 事件驱动,按流优先级分配检测资源 +4. 状态监控: 汇总所有流状态,提供健康检查接口 +5. 优雅关闭: 按序停止所有流,等待资源释放 + +使用方式:: + + manager = StreamManager(model_service=model_service) + await manager.add_stream("cam-01", "rtsp://admin:pass@192.168.1.100:554/stream") + await manager.start_stream("cam-01") + ... + info = manager.get_stream_info("cam-01") + ... + await manager.stop_all() +""" + +from __future__ import annotations + +import asyncio +import logging +import time +from typing import Any, Callable, Coroutine, Dict, List, Optional + +import numpy as np + +from .frame_buffer import DropPolicy, FrameBuffer +from .rtsp_service import FrameCallback, RTSPService, StreamConfig, StreamStatus + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# 流条目 +# --------------------------------------------------------------------------- + + +class _StreamEntry: + """管理器内部: 单路流的完整上下文。""" + + __slots__ = ("service", "buffer", "config", "detect_task") + + def __init__( + self, + service: RTSPService, + buffer: FrameBuffer, + config: StreamConfig, + ) -> None: + self.service = service + self.buffer = buffer + self.config = config + self.detect_task: Optional[asyncio.Task] = None + + +# --------------------------------------------------------------------------- +# StreamManager +# --------------------------------------------------------------------------- + + +class StreamManager: + """多路 RTSP 流调度管理器。 + + Args: + model_service: 模型服务实例,用于创建 DetectionService + buffer_capacity: 每路流帧缓冲区容量 + buffer_drop_policy: 帧缓冲区丢帧策略 + max_streams: 最大流数量 + detect_interval: 检测轮询间隔 (秒),0 = 每帧检测 + """ + + def __init__( + self, + model_service: Any = None, + buffer_capacity: int = 300, + buffer_drop_policy: DropPolicy = DropPolicy.LATEST, + max_streams: int = 16, + detect_interval: float = 0.0, + ) -> None: + self._model_service = model_service + self._buffer_capacity = buffer_capacity + self._buffer_drop_policy = buffer_drop_policy + self._max_streams = max(1, max_streams) + self._detect_interval = detect_interval + + self._streams: Dict[str, _StreamEntry] = {} + self._lock = asyncio.Lock() + self._running = False + + # 帧回调: 写入缓冲区 + 触发检测 + self._on_detect: Optional[ + Callable[[str, np.ndarray, int, float], Coroutine[Any, Any, None]] + ] = None + + # ------------------------------------------------------------------ + # 流管理 + # ------------------------------------------------------------------ + + async def add_stream( + self, + stream_id: str, + rtsp_url: str, + config: Optional[StreamConfig] = None, + ) -> Dict[str, Any]: + """添加一路 RTSP 流 (不立即启动)。 + + Returns: + 操作结果 {"success": bool, "message": str} + """ + + async with self._lock: + if stream_id in self._streams: + return {"success": False, "message": f"流 {stream_id} 已存在"} + + if len(self._streams) >= self._max_streams: + return { + "success": False, + "message": f"已达最大流数量 ({self._max_streams})", + } + + stream_config = config or StreamConfig( + stream_id=stream_id, rtsp_url=rtsp_url + ) + stream_config.stream_id = stream_id + stream_config.rtsp_url = rtsp_url + + buffer = FrameBuffer( + capacity=self._buffer_capacity, + drop_policy=self._buffer_drop_policy, + ) + + service = RTSPService( + stream_id=stream_id, + rtsp_url=rtsp_url, + on_frame=self._handle_frame, + config=stream_config, + ) + + self._streams[stream_id] = _StreamEntry( + service=service, + buffer=buffer, + config=stream_config, + ) + + logger.info("已添加 RTSP 流: %s (%s)", stream_id, rtsp_url) + return {"success": True, "message": f"流 {stream_id} 已添加"} + + async def remove_stream(self, stream_id: str) -> Dict[str, Any]: + """移除一路 RTSP 流 (先停止再移除)。""" + + async with self._lock: + entry = self._streams.get(stream_id) + if entry is None: + return {"success": False, "message": f"流 {stream_id} 不存在"} + + # 停止流 + await entry.service.stop() + if entry.detect_task and not entry.detect_task.done(): + entry.detect_task.cancel() + try: + await entry.detect_task + except asyncio.CancelledError: + pass + + # 清空缓冲区 + await entry.buffer.clear() + + del self._streams[stream_id] + logger.info("已移除 RTSP 流: %s", stream_id) + return {"success": True, "message": f"流 {stream_id} 已移除"} + + async def start_stream(self, stream_id: str) -> Dict[str, Any]: + """启动一路 RTSP 流。""" + + entry = self._streams.get(stream_id) + if entry is None: + return {"success": False, "message": f"流 {stream_id} 不存在"} + + if entry.service.is_running: + return {"success": False, "message": f"流 {stream_id} 已在运行中"} + + await entry.service.start() + + # 启动检测轮询任务 + if self._on_detect: + entry.detect_task = asyncio.create_task( + self._detect_loop(stream_id), + name=f"detect-{stream_id}", + ) + + logger.info("已启动 RTSP 流: %s", stream_id) + return {"success": True, "message": f"流 {stream_id} 已启动"} + + async def stop_stream(self, stream_id: str) -> Dict[str, Any]: + """停止单路 RTSP 流 (不移除)。""" + + entry = self._streams.get(stream_id) + if entry is None: + return {"success": False, "message": f"流 {stream_id} 不存在"} + + await entry.service.stop() + + if entry.detect_task and not entry.detect_task.done(): + entry.detect_task.cancel() + try: + await entry.detect_task + except asyncio.CancelledError: + pass + entry.detect_task = None + + logger.info("已停止 RTSP 流: %s", stream_id) + return {"success": True, "message": f"流 {stream_id} 已停止"} + + async def start_all(self) -> None: + """启动所有已添加的流。""" + + self._running = True + for stream_id in list(self._streams.keys()): + await self.start_stream(stream_id) + + async def stop_all(self) -> None: + """停止所有流。""" + + self._running = False + for stream_id in list(self._streams.keys()): + await self.stop_stream(stream_id) + logger.info("所有 RTSP 流已停止") + + # ------------------------------------------------------------------ + # 检测调度 + # ------------------------------------------------------------------ + + def set_detect_callback( + self, + callback: Callable[[str, np.ndarray, int, float], Coroutine[Any, Any, None]], + ) -> None: + """设置检测回调函数。 + + 回调签名: ``callback(stream_id, frame, frame_index, timestamp)`` + """ + + self._on_detect = callback + + async def _detect_loop(self, stream_id: str) -> None: + """检测轮询循环: 从缓冲区取最新帧进行检测。""" + + entry = self._streams.get(stream_id) + if entry is None: + return + + while entry.service.is_running: + try: + item = await entry.buffer.read_latest() + if item is not None and self._on_detect: + await self._on_detect( + stream_id, + item.frame, + item.meta.frame_index, + item.meta.timestamp, + ) + + if self._detect_interval > 0: + await asyncio.sleep(self._detect_interval) + else: + await asyncio.sleep(0.03) # ~30fps + + except asyncio.CancelledError: + break + except Exception as e: + logger.error("检测循环异常 (stream=%s): %s", stream_id, e) + await asyncio.sleep(1.0) + + # ------------------------------------------------------------------ + # 帧回调 + # ------------------------------------------------------------------ + + async def _handle_frame( + self, + stream_id: str, + frame: np.ndarray, + frame_index: int, + timestamp: float, + ) -> None: + """RTSPService 帧回调: 写入对应流的缓冲区。""" + + entry = self._streams.get(stream_id) + if entry is None: + return + + await entry.buffer.write( + frame=frame, + stream_id=stream_id, + frame_index=frame_index, + timestamp=timestamp, + ) + + # ------------------------------------------------------------------ + # 状态查询 + # ------------------------------------------------------------------ + + def get_stream_info(self, stream_id: str) -> Optional[Dict[str, Any]]: + """获取单路流状态信息。""" + + entry = self._streams.get(stream_id) + if entry is None: + return None + + info = entry.service.info.to_dict() + info["buffer"] = entry.buffer.stats + info["config"] = { + "model_id": entry.config.model_id, + "confidence": entry.config.confidence, + "iou": entry.config.iou, + "frame_skip": entry.config.frame_skip, + } + return info + + def get_all_streams_info(self) -> List[Dict[str, Any]]: + """获取所有流状态信息。""" + + return [ + info + for sid in self._streams + if (info := self.get_stream_info(sid)) is not None + ] + + @property + def stream_count(self) -> int: + return len(self._streams) + + @property + def active_stream_count(self) -> int: + return sum( + 1 + for e in self._streams.values() + if e.service.status == StreamStatus.CONNECTED + ) + + def get_health(self) -> Dict[str, Any]: + """获取管理器健康状态。""" + + total = len(self._streams) + active = self.active_stream_count + reconnecting = sum( + 1 + for e in self._streams.values() + if e.service.status == StreamStatus.RECONNECTING + ) + errored = sum( + 1 + for e in self._streams.values() + if e.service.status == StreamStatus.ERROR + ) + + return { + "total_streams": total, + "active_streams": active, + "reconnecting_streams": reconnecting, + "error_streams": errored, + "max_streams": self._max_streams, + "healthy": errored == 0, + } + + # ------------------------------------------------------------------ + # 流配置更新 + # ------------------------------------------------------------------ + + async def update_stream_config( + self, + stream_id: str, + model_id: Optional[str] = None, + confidence: Optional[float] = None, + iou: Optional[float] = None, + frame_skip: Optional[int] = None, + ) -> Dict[str, Any]: + """更新流的检测配置 (运行时热更新)。""" + + entry = self._streams.get(stream_id) + if entry is None: + return {"success": False, "message": f"流 {stream_id} 不存在"} + + if model_id is not None: + entry.config.model_id = model_id + if confidence is not None: + entry.config.confidence = confidence + if iou is not None: + entry.config.iou = iou + if frame_skip is not None: + entry.config.frame_skip = frame_skip + + logger.info( + "流 %s 配置已更新: model=%s, conf=%.2f, iou=%.2f, skip=%d", + stream_id, + entry.config.model_id, + entry.config.confidence, + entry.config.iou, + entry.config.frame_skip, + ) + return {"success": True, "message": f"流 {stream_id} 配置已更新"} + + +__all__ = ["StreamManager"]