- 实现RTSP流接入服务(rtsp_service),支持单路流连接/解码/帧采集 - 实现多路流调度管理器(stream_manager),统一管理多路RTSP流启停与状态监控
408 lines
13 KiB
Python
408 lines
13 KiB
Python
"""多路流调度管理器 (MVP-2 / D13-D14)
|
||
|
||
负责管理多路 RTSP 流的生命周期、帧缓冲、状态监控和检测调度。
|
||
|
||
核心设计:
|
||
|
||
1. 统一管理多路 RTSPService 实例
|
||
2. 每路流对应一个 FrameBuffer,解耦解码与检测
|
||
3. 检测调度: 当前实现为按固定间隔轮询每路流缓冲区中的最新帧并交给检测回调 (事件驱动调度与按流优先级分配检测资源尚未实现)
|
||
4. 状态监控: 汇总所有流状态,提供健康检查接口
|
||
5. 优雅关闭: 按序停止所有流,等待资源释放
|
||
|
||
使用方式::
|
||
|
||
manager = StreamManager(model_service=model_service)
|
||
await manager.add_stream("cam-01", "rtsp://admin:pass@192.168.1.100:554/stream")
|
||
await manager.start_stream("cam-01")
|
||
...
|
||
info = manager.get_stream_info("cam-01")
|
||
...
|
||
await manager.stop_all()
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
import time
|
||
from typing import Any, Callable, Coroutine, Dict, List, Optional
|
||
|
||
import numpy as np
|
||
|
||
from .frame_buffer import DropPolicy, FrameBuffer
|
||
from .rtsp_service import FrameCallback, RTSPService, StreamConfig, StreamStatus
|
||
|
||
# Module-scoped logger; records are emitted under this module's dotted name.
logger: logging.Logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 流条目
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class _StreamEntry:
    """Internal per-stream record kept by the manager.

    Groups a stream's RTSP service, its frame buffer and its stream config,
    plus a slot for the stream's detect-loop task (``None`` until one is
    assigned). ``__slots__`` keeps instances free of a per-instance ``__dict__``.
    """

    __slots__ = ("service", "buffer", "config", "detect_task")

    service: RTSPService
    buffer: FrameBuffer
    config: StreamConfig
    detect_task: Optional[asyncio.Task]

    def __init__(
        self,
        service: RTSPService,
        buffer: FrameBuffer,
        config: StreamConfig,
    ) -> None:
        # No detect task exists until the manager starts one.
        self.detect_task = None
        self.service, self.buffer, self.config = service, buffer, config
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# StreamManager
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class StreamManager:
    """Scheduling manager for multiple RTSP streams.

    Each stream gets its own ``RTSPService`` (connection / decoding) and its
    own ``FrameBuffer``. Decoded frames are written into the stream's buffer
    by ``_handle_frame``. If a detect callback has been registered with
    ``set_detect_callback``, ``start_stream`` also spawns a per-stream
    ``_detect_loop`` task that polls the buffer and forwards the latest frame
    to that callback.

    Args:
        model_service: Model service instance. It is stored on the manager
            but not otherwise used by this class (presumably consumed when a
            DetectionService is created — TODO confirm).
        buffer_capacity: Capacity of each stream's frame buffer.
        buffer_drop_policy: Drop policy of each stream's frame buffer.
        max_streams: Maximum number of streams that may be added (values
            below 1 are clamped to 1).
        detect_interval: Detect-loop polling interval in seconds. ``0`` (the
            default) polls roughly every 0.03 s instead.
    """

    def __init__(
        self,
        model_service: Any = None,
        buffer_capacity: int = 300,
        buffer_drop_policy: DropPolicy = DropPolicy.LATEST,
        max_streams: int = 16,
        detect_interval: float = 0.0,
    ) -> None:
        self._model_service = model_service
        self._buffer_capacity = buffer_capacity
        self._buffer_drop_policy = buffer_drop_policy
        self._max_streams = max(1, max_streams)
        self._detect_interval = detect_interval

        self._streams: Dict[str, _StreamEntry] = {}
        # Held by add_stream / remove_stream while they change ``_streams``.
        self._lock = asyncio.Lock()
        # Set by start_all, cleared by stop_all; not read elsewhere in this class.
        self._running = False

        # Optional async detect callback:
        # callback(stream_id, frame, frame_index, timestamp).
        self._on_detect: Optional[
            Callable[[str, np.ndarray, int, float], Coroutine[Any, Any, None]]
        ] = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _redact_url(url: str) -> str:
        """Return *url* with any password in its userinfo replaced by ``***``.

        RTSP URLs commonly embed credentials (``rtsp://user:pass@host/...``);
        this keeps them out of log output. URLs without a password, or that
        cannot be parsed, are returned unchanged.
        """
        from urllib.parse import urlsplit, urlunsplit

        try:
            parts = urlsplit(url)
            if parts.password is None:
                return url
            # The host part follows the last '@' (urlsplit splits userinfo the same way).
            host = parts.netloc.rsplit("@", 1)[1]
            user = parts.username or ""
            return urlunsplit(parts._replace(netloc=f"{user}:***@{host}"))
        except ValueError:
            return url

    @staticmethod
    async def _cancel_detect_task(entry: _StreamEntry) -> None:
        """Cancel *entry*'s detect task if it is still running, wait for it, and clear the slot."""
        task = entry.detect_task
        entry.detect_task = None
        if task is None or task.done():
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    # ------------------------------------------------------------------
    # Stream management
    # ------------------------------------------------------------------

    async def add_stream(
        self,
        stream_id: str,
        rtsp_url: str,
        config: Optional[StreamConfig] = None,
    ) -> Dict[str, Any]:
        """Register an RTSP stream without starting it.

        Args:
            stream_id: Unique stream identifier.
            rtsp_url: Stream URL. It overrides any ``rtsp_url`` in *config*.
            config: Optional stream config template. It is copied, so the
                caller's object is never modified or shared between streams.

        Returns:
            Result dict ``{"success": bool, "message": str}``. Fails if the
            id already exists or ``max_streams`` has been reached.
        """
        import copy

        async with self._lock:
            if stream_id in self._streams:
                return {"success": False, "message": f"流 {stream_id} 已存在"}

            if len(self._streams) >= self._max_streams:
                return {
                    "success": False,
                    "message": f"已达最大流数量 ({self._max_streams})",
                }

            if config is None:
                stream_config = StreamConfig(stream_id=stream_id, rtsp_url=rtsp_url)
            else:
                # Copy so that one StreamConfig reused across several add_stream
                # calls does not end up shared (and overwritten) between entries.
                stream_config = copy.copy(config)
            stream_config.stream_id = stream_id
            stream_config.rtsp_url = rtsp_url

            buffer = FrameBuffer(
                capacity=self._buffer_capacity,
                drop_policy=self._buffer_drop_policy,
            )

            # The service and the entry share this config object, so runtime
            # updates via update_stream_config are visible to the service too.
            service = RTSPService(
                stream_id=stream_id,
                rtsp_url=rtsp_url,
                on_frame=self._handle_frame,
                config=stream_config,
            )

            self._streams[stream_id] = _StreamEntry(
                service=service,
                buffer=buffer,
                config=stream_config,
            )

            # Never log embedded credentials.
            logger.info("已添加 RTSP 流: %s (%s)", stream_id, self._redact_url(rtsp_url))
            return {"success": True, "message": f"流 {stream_id} 已添加"}

    async def remove_stream(self, stream_id: str) -> Dict[str, Any]:
        """Stop a stream, cancel its detect task, clear its buffer and unregister it.

        Returns:
            Result dict ``{"success": bool, "message": str}``.
        """
        async with self._lock:
            entry = self._streams.get(stream_id)
            if entry is None:
                return {"success": False, "message": f"流 {stream_id} 不存在"}

            await entry.service.stop()
            await self._cancel_detect_task(entry)
            await entry.buffer.clear()

            del self._streams[stream_id]
            logger.info("已移除 RTSP 流: %s", stream_id)
            return {"success": True, "message": f"流 {stream_id} 已移除"}

    async def start_stream(self, stream_id: str) -> Dict[str, Any]:
        """Start a registered stream and, if a detect callback is set, its detect loop.

        Returns:
            Result dict ``{"success": bool, "message": str}``. Fails if the
            stream is unknown or its service already reports ``is_running``.
        """
        entry = self._streams.get(stream_id)
        if entry is None:
            return {"success": False, "message": f"流 {stream_id} 不存在"}

        if entry.service.is_running:
            return {"success": False, "message": f"流 {stream_id} 已在运行中"}

        # A detect loop from a previous run may still be alive (for example
        # sleeping) if the service stopped on its own. It would resume once
        # is_running is True again, leaving two loops for one stream, so
        # retire it before starting.
        await self._cancel_detect_task(entry)

        await entry.service.start()

        if self._on_detect:
            entry.detect_task = asyncio.create_task(
                self._detect_loop(stream_id),
                name=f"detect-{stream_id}",
            )

        logger.info("已启动 RTSP 流: %s", stream_id)
        return {"success": True, "message": f"流 {stream_id} 已启动"}

    async def stop_stream(self, stream_id: str) -> Dict[str, Any]:
        """Stop a stream's service and detect loop, keeping it registered.

        Returns:
            Result dict ``{"success": bool, "message": str}``.
        """
        entry = self._streams.get(stream_id)
        if entry is None:
            return {"success": False, "message": f"流 {stream_id} 不存在"}

        await entry.service.stop()
        await self._cancel_detect_task(entry)

        logger.info("已停止 RTSP 流: %s", stream_id)
        return {"success": True, "message": f"流 {stream_id} 已停止"}

    async def start_all(self) -> None:
        """Start every registered stream, one after another."""
        self._running = True
        for stream_id in list(self._streams):
            await self.start_stream(stream_id)

    async def stop_all(self) -> None:
        """Stop every registered stream, one after another (they stay registered)."""
        self._running = False
        for stream_id in list(self._streams):
            await self.stop_stream(stream_id)
        logger.info("所有 RTSP 流已停止")

    # ------------------------------------------------------------------
    # Detection scheduling
    # ------------------------------------------------------------------

    def set_detect_callback(
        self,
        callback: Callable[[str, np.ndarray, int, float], Coroutine[Any, Any, None]],
    ) -> None:
        """Register the async detect callback.

        Signature: ``callback(stream_id, frame, frame_index, timestamp)``.
        Detect loops are only spawned by ``start_stream``, so streams that are
        already running do not get a loop from this call.
        """
        self._on_detect = callback

    async def _detect_loop(self, stream_id: str) -> None:
        """Poll *stream_id*'s buffer and pass the latest frame to the detect callback.

        Runs while the stream's service reports ``is_running``. Exceptions
        other than cancellation are logged with a traceback and followed by a
        1 s back-off. Cancellation propagates, so the task ends as cancelled.
        """
        entry = self._streams.get(stream_id)
        if entry is None:
            return

        while entry.service.is_running:
            try:
                # NOTE(review): if read_latest() does not consume the item, the
                # same frame is re-submitted on every poll until a newer one
                # arrives — confirm FrameBuffer semantics.
                item = await entry.buffer.read_latest()
                if item is not None and self._on_detect:
                    await self._on_detect(
                        stream_id,
                        item.frame,
                        item.meta.frame_index,
                        item.meta.timestamp,
                    )

                if self._detect_interval > 0:
                    await asyncio.sleep(self._detect_interval)
                else:
                    await asyncio.sleep(0.03)  # ~30 polls per second
            except Exception as e:
                # asyncio.CancelledError derives from BaseException (3.8+), so
                # it is not caught here and cancellation ends the task normally.
                logger.exception("检测循环异常 (stream=%s): %s", stream_id, e)
                await asyncio.sleep(1.0)

    # ------------------------------------------------------------------
    # Frame callback
    # ------------------------------------------------------------------

    async def _handle_frame(
        self,
        stream_id: str,
        frame: np.ndarray,
        frame_index: int,
        timestamp: float,
    ) -> None:
        """RTSPService frame callback: write the frame into that stream's buffer.

        Frames for streams that are no longer registered are ignored.
        """
        entry = self._streams.get(stream_id)
        if entry is None:
            return

        await entry.buffer.write(
            frame=frame,
            stream_id=stream_id,
            frame_index=frame_index,
            timestamp=timestamp,
        )

    # ------------------------------------------------------------------
    # Status queries
    # ------------------------------------------------------------------

    def get_stream_info(self, stream_id: str) -> Optional[Dict[str, Any]]:
        """Return the service info dict of *stream_id*, extended with buffer stats and detect config.

        Returns ``None`` for an unknown stream.
        """
        entry = self._streams.get(stream_id)
        if entry is None:
            return None

        info = entry.service.info.to_dict()
        info["buffer"] = entry.buffer.stats
        info["config"] = {
            "model_id": entry.config.model_id,
            "confidence": entry.config.confidence,
            "iou": entry.config.iou,
            "frame_skip": entry.config.frame_skip,
        }
        return info

    def get_all_streams_info(self) -> List[Dict[str, Any]]:
        """Return ``get_stream_info`` for every registered stream."""
        return [
            info
            for sid in self._streams
            if (info := self.get_stream_info(sid)) is not None
        ]

    @property
    def stream_count(self) -> int:
        """Number of registered streams."""
        return len(self._streams)

    @property
    def active_stream_count(self) -> int:
        """Number of streams whose service status is ``StreamStatus.CONNECTED``."""
        return sum(
            1
            for e in self._streams.values()
            if e.service.status == StreamStatus.CONNECTED
        )

    def get_health(self) -> Dict[str, Any]:
        """Summarize stream states; ``healthy`` is True when no stream is in ``ERROR``."""
        total = len(self._streams)
        active = self.active_stream_count
        reconnecting = sum(
            1
            for e in self._streams.values()
            if e.service.status == StreamStatus.RECONNECTING
        )
        errored = sum(
            1
            for e in self._streams.values()
            if e.service.status == StreamStatus.ERROR
        )

        return {
            "total_streams": total,
            "active_streams": active,
            "reconnecting_streams": reconnecting,
            "error_streams": errored,
            "max_streams": self._max_streams,
            "healthy": errored == 0,
        }

    # ------------------------------------------------------------------
    # Stream config updates
    # ------------------------------------------------------------------

    async def update_stream_config(
        self,
        stream_id: str,
        model_id: Optional[str] = None,
        confidence: Optional[float] = None,
        iou: Optional[float] = None,
        frame_skip: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Update a stream's detection settings at runtime.

        Only arguments that are not ``None`` are applied. The entry's config is
        the same object handed to the stream's RTSPService, so the service sees
        the new values (whether it re-reads them while running is up to
        RTSPService — TODO confirm).

        Returns:
            Result dict ``{"success": bool, "message": str}``.
        """
        entry = self._streams.get(stream_id)
        if entry is None:
            return {"success": False, "message": f"流 {stream_id} 不存在"}

        if model_id is not None:
            entry.config.model_id = model_id
        if confidence is not None:
            entry.config.confidence = confidence
        if iou is not None:
            entry.config.iou = iou
        if frame_skip is not None:
            entry.config.frame_skip = frame_skip

        logger.info(
            "流 %s 配置已更新: model=%s, conf=%.2f, iou=%.2f, skip=%d",
            stream_id,
            entry.config.model_id,
            entry.config.confidence,
            entry.config.iou,
            entry.config.frame_skip,
        )
        return {"success": True, "message": f"流 {stream_id} 配置已更新"}
|
||
|
||
|
||
# Public API of this module: only the manager; _StreamEntry stays internal.
__all__: List[str] = ["StreamManager"]
|