"""统一应用配置 (MVP-1 / 新增)

基于 pydantic-settings 的多环境配置管理,集中收敛原先散落在
``main.py``、各 service 模块中的环境变量、路径、阈值等配置。

加载顺序 (优先级从高到低)::

    1. 显式构造参数
    2. 环境变量 (大小写不敏感)
    3. ``.env`` 文件 (位于 apps/server/.env)
    4. 默认值

使用方式::

    from core.settings import get_settings

    settings = get_settings()
    print(settings.api.port)
    print(settings.detection.default_confidence)
"""

from __future__ import annotations
|
||
|
||
import os
|
||
from functools import lru_cache
|
||
from pathlib import Path
|
||
from typing import List, Optional
|
||
|
||
from pydantic import Field
|
||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||
|
||
|
||
# Resolve symlinks first so both roots refer to the real on-disk location.
_settings_file = Path(__file__).resolve()

# Server root (apps/server): two levels above this module file
# (the module is imported as ``core.settings``).
SERVER_DIR: Path = _settings_file.parent.parent

# Repository root: two levels above the server root (apps/server -> apps -> root).
PROJECT_ROOT: Path = SERVER_DIR.parent.parent

# Drop the temporary so the module namespace is unchanged.
del _settings_file


# ---------------------------------------------------------------------------
# 子配置 (per-section settings)
# ---------------------------------------------------------------------------


class APISettings(BaseSettings):
    """API / HTTP server settings.

    Read from ``API_``-prefixed variables (case-insensitive), e.g. ``API_PORT``,
    both from the process environment and from ``apps/server/.env``. When a key
    is defined in both places, the process environment wins.

    Attributes:
        host: Address the server listens on.
        port: TCP port the server listens on (1-65535).
        reload: Auto-reload switch for development.
        cors_origins: Allowed CORS origins. pydantic-settings parses list-typed
            values as JSON, e.g. ``API_CORS_ORIGINS='["http://a.example"]'``.
    """

    # ``env_file`` is configured here, not only on ``Settings``. ``Settings``
    # builds this section via ``default_factory``, which consults only this
    # class's own sources. Without it, ``API_*`` keys in ``.env`` would be
    # silently ignored.
    model_config = SettingsConfigDict(
        env_prefix="API_",
        env_file=str(SERVER_DIR / ".env"),
        env_file_encoding="utf-8",
        extra="ignore",
    )

    host: str = Field(default="0.0.0.0", description="监听地址")
    # Same bounds as MQTTSettings.broker_port: reject impossible ports at load.
    port: int = Field(default=8000, ge=1, le=65535, description="监听端口")
    reload: bool = Field(default=True, description="开发模式自动重载")
    cors_origins: List[str] = Field(
        default_factory=lambda: ["*"], description="允许的跨域来源"
    )


class DetectionSettings(BaseSettings):
    """Global detection defaults.

    Read from ``DETECTION_``-prefixed variables (e.g. ``DETECTION_DEFAULT_IOU``)
    in the process environment and in ``apps/server/.env``. The process
    environment wins on conflicts.

    Attributes:
        default_confidence: Default confidence value, constrained to [0, 1].
        default_iou: Default IoU value, constrained to [0, 1].
        min_confidence: Lowest confidence kept when the decision engine
            filters results, constrained to [0, 1].
    """

    # ``env_file`` is repeated here because ``Settings`` builds this section
    # via ``default_factory``, which only consults this class's own sources.
    # Without it, ``DETECTION_*`` keys in ``.env`` would be ignored.
    model_config = SettingsConfigDict(
        env_prefix="DETECTION_",
        env_file=str(SERVER_DIR / ".env"),
        env_file_encoding="utf-8",
        extra="ignore",
    )

    default_confidence: float = Field(default=0.5, ge=0.0, le=1.0)
    default_iou: float = Field(default=0.45, ge=0.0, le=1.0)
    # Minimum confidence used by the decision engine to filter detections.
    min_confidence: float = Field(default=0.3, ge=0.0, le=1.0)


class ActionDetectionSettings(BaseSettings):
    """Settings for the ppTSM action-recognition service (runs in Docker).

    Read from ``ACTION_DETECTION_``-prefixed variables (e.g.
    ``ACTION_DETECTION_API_URL``) in the process environment and in
    ``apps/server/.env``. The process environment wins on conflicts.

    Attributes:
        api_url: Base URL of the action-recognition service.
        timeout: Request timeout, at least 1. Presumably seconds; confirm
            against the HTTP client that consumes it.
    """

    # ``env_file`` is repeated here because ``Settings`` builds this section
    # via ``default_factory``, which only consults this class's own sources.
    # Without it, ``ACTION_DETECTION_*`` keys in ``.env`` would be ignored.
    model_config = SettingsConfigDict(
        env_prefix="ACTION_DETECTION_",
        env_file=str(SERVER_DIR / ".env"),
        env_file_encoding="utf-8",
        extra="ignore",
    )

    api_url: str = Field(default="http://localhost:8081")
    timeout: int = Field(default=30, ge=1)


class EventEngineSettings(BaseSettings):
    """Event decision, aggregation and rule-engine settings.

    Read from ``EVENT_``-prefixed variables (e.g. ``EVENT_RULES_DIR``) in the
    process environment and in ``apps/server/.env``. The process environment
    wins on conflicts.
    """

    # ``env_file`` is repeated here because ``Settings`` builds this section
    # via ``default_factory``, which only consults this class's own sources.
    # Without it, ``EVENT_*`` keys in ``.env`` would be ignored.
    model_config = SettingsConfigDict(
        env_prefix="EVENT_",
        env_file=str(SERVER_DIR / ".env"),
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # De-duplication window in seconds: the same
    # (source_id, event_type, track_id) yields only one event per window.
    dedup_window_seconds: float = Field(default=30.0, ge=0.0)
    # Directory holding the rule YAML files, relative to the server root.
    rules_dir: str = Field(default="config/rules")
    # Maximum number of active events kept by the aggregator; beyond this,
    # events are evicted in LRU order.
    max_active_events: int = Field(default=1000, ge=1)


class RTSPSettings(BaseSettings):
    """RTSP stream ingestion settings (MVP-2).

    Read from ``RTSP_``-prefixed variables (e.g. ``RTSP_MAX_STREAMS``) in the
    process environment and in ``apps/server/.env``. The process environment
    wins on conflicts.

    Attributes:
        max_streams: Maximum number of concurrently ingested streams.
        buffer_capacity: Frame buffer capacity per stream.
        reconnect_attempts: Maximum reconnect attempts; 0 means unlimited.
        reconnect_interval_base: First reconnect delay, in seconds.
        reconnect_interval_max: Upper bound for the reconnect delay, in seconds.
        reconnect_backoff_factor: Backoff multiplier between attempts.
        frame_skip: Frame sampling interval; 0 takes every frame.
        read_timeout: Timeout for reading a single frame, in seconds.
        detect_interval: Detection polling interval in seconds; 0 runs
            detection on every frame.
    """

    # ``env_file`` is repeated here because ``Settings`` builds this section
    # via ``default_factory``, which only consults this class's own sources.
    # Without it, ``RTSP_*`` keys in ``.env`` would be ignored.
    model_config = SettingsConfigDict(
        env_prefix="RTSP_",
        env_file=str(SERVER_DIR / ".env"),
        env_file_encoding="utf-8",
        extra="ignore",
    )

    max_streams: int = Field(default=16, ge=1, description="最大同时接入流数量")
    buffer_capacity: int = Field(default=300, ge=1, description="每路流帧缓冲区容量")
    reconnect_attempts: int = Field(default=10, ge=0, description="最大重连次数,0=无限")
    reconnect_interval_base: float = Field(default=2.0, ge=0.5, description="首次重连间隔(秒)")
    reconnect_interval_max: float = Field(default=60.0, ge=1.0, description="最大重连间隔(秒)")
    reconnect_backoff_factor: float = Field(default=2.0, ge=1.0, description="退避因子")
    frame_skip: int = Field(default=0, ge=0, description="帧采样间隔,0=每帧都取")
    read_timeout: float = Field(default=5.0, ge=1.0, description="单帧读取超时(秒)")
    detect_interval: float = Field(default=0.0, ge=0.0, description="检测轮询间隔(秒),0=每帧检测")


class MQTTSettings(BaseSettings):
    """MQTT alert-publishing settings (MVP-2 / D16-D18).

    Read from ``MQTT_``-prefixed variables (e.g. ``MQTT_BROKER_HOST``) in the
    process environment and in ``apps/server/.env``. The process environment
    wins on conflicts.

    Attributes:
        enabled: Whether MQTT publishing is enabled.
        broker_host: Broker host name.
        broker_port: Broker port (1-65535).
        client_id: MQTT client identifier.
        username: Optional broker user name.
        password: Optional broker password; hidden from ``repr()``.
        keepalive: Keep-alive interval in seconds (>= 10).
        qos: QoS level, 0-2.
        retain: Whether published messages are retained.
        alert_topic_prefix: Prefix for alert topics.
        reconnect_min_delay: Minimum reconnect delay, in seconds.
        reconnect_max_delay: Maximum reconnect delay, in seconds.
    """

    # ``env_file`` is repeated here because ``Settings`` builds this section
    # via ``default_factory``, which only consults this class's own sources.
    # Without it, ``MQTT_*`` keys in ``.env`` would be ignored.
    model_config = SettingsConfigDict(
        env_prefix="MQTT_",
        env_file=str(SERVER_DIR / ".env"),
        env_file_encoding="utf-8",
        extra="ignore",
    )

    enabled: bool = Field(default=False, description="是否启用 MQTT 发布")
    broker_host: str = Field(default="localhost", description="MQTT broker 主机")
    broker_port: int = Field(default=1883, ge=1, le=65535, description="MQTT broker 端口")
    client_id: str = Field(default="jc-video-recognize", description="MQTT 客户端ID")
    username: Optional[str] = Field(default=None, description="MQTT 用户名")
    # repr=False keeps the secret out of repr()/str() output, e.g. when the
    # settings object is logged. The stored value is unchanged.
    password: Optional[str] = Field(default=None, repr=False, description="MQTT 密码")
    keepalive: int = Field(default=60, ge=10, description="心跳间隔(秒)")
    qos: int = Field(default=1, ge=0, le=2, description="QoS 等级")
    retain: bool = Field(default=False, description="是否保留消息")
    # Topic template.
    alert_topic_prefix: str = Field(default="video/alerts", description="预警主题前缀")
    # Reconnection back-off bounds.
    reconnect_min_delay: float = Field(default=1.0, ge=0.1, description="最小重连间隔(秒)")
    reconnect_max_delay: float = Field(default=60.0, ge=1.0, description="最大重连间隔(秒)")


class TrackingSettings(BaseSettings):
    """Object-tracking (ByteTrack) settings (MVP-2 / D19).

    Read from ``TRACKING_``-prefixed variables (e.g. ``TRACKING_MATCH_THRESH``)
    in the process environment and in ``apps/server/.env``. The process
    environment wins on conflicts.

    Attributes:
        enabled: Whether tracking is enabled.
        track_thresh: Tracking confidence threshold, in [0, 1].
        high_thresh: High-confidence threshold, in [0, 1].
        match_thresh: IoU matching threshold, in [0, 1].
        max_lost_frames: Maximum number of frames a target may be lost.
        min_box_area: Minimum bounding-box area.
    """

    # ``env_file`` is repeated here because ``Settings`` builds this section
    # via ``default_factory``, which only consults this class's own sources.
    # Without it, ``TRACKING_*`` keys in ``.env`` would be ignored.
    model_config = SettingsConfigDict(
        env_prefix="TRACKING_",
        env_file=str(SERVER_DIR / ".env"),
        env_file_encoding="utf-8",
        extra="ignore",
    )

    enabled: bool = Field(default=True, description="是否启用目标跟踪")
    track_thresh: float = Field(default=0.5, ge=0.0, le=1.0, description="跟踪置信度阈值")
    high_thresh: float = Field(default=0.6, ge=0.0, le=1.0, description="高置信度阈值")
    match_thresh: float = Field(default=0.8, ge=0.0, le=1.0, description="IOU 匹配阈值")
    max_lost_frames: int = Field(default=30, ge=1, description="目标丢失最大帧数")
    min_box_area: float = Field(default=10.0, ge=0.0, description="最小框面积")


class AggregatorSettings(BaseSettings):
    """Extended event-aggregator settings (MVP-2 / D20).

    Read from ``AGGREGATOR_``-prefixed variables (e.g.
    ``AGGREGATOR_SPATIAL_IOU_THRESHOLD``) in the process environment and in
    ``apps/server/.env``. The process environment wins on conflicts.

    Attributes:
        enable_spatial_merge: Whether spatially close events are merged.
        spatial_iou_threshold: IoU threshold for spatial merging, in [0, 1].
        confidence_fusion_strategy: Confidence fusion strategy
            (documented values: weighted / max / avg).
        fusion_decay_factor: Decay factor applied to historical confidence,
            in [0, 1].
    """

    # ``env_file`` is repeated here because ``Settings`` builds this section
    # via ``default_factory``, which only consults this class's own sources.
    # Without it, ``AGGREGATOR_*`` keys in ``.env`` would be ignored.
    model_config = SettingsConfigDict(
        env_prefix="AGGREGATOR_",
        env_file=str(SERVER_DIR / ".env"),
        env_file_encoding="utf-8",
        extra="ignore",
    )

    enable_spatial_merge: bool = Field(default=True, description="是否启用空间邻近合并")
    spatial_iou_threshold: float = Field(
        default=0.3, ge=0.0, le=1.0, description="空间合并 IOU 阈值"
    )
    # NOTE(review): the allowed values are only listed in the description and
    # are not validated here. Consider ``typing.Literal`` once the consumer's
    # handling (e.g. case sensitivity) is confirmed.
    confidence_fusion_strategy: str = Field(
        default="weighted",
        description="置信度融合策略: weighted/max/avg",
    )
    fusion_decay_factor: float = Field(
        default=0.9, ge=0.0, le=1.0, description="历史置信度衰减因子"
    )


class LoggingSettings(BaseSettings):
    """Logging settings.

    Read from ``LOG_``-prefixed variables (``LOG_LEVEL``, ``LOG_JSON_FORMAT``)
    in the process environment and in ``apps/server/.env``. The process
    environment wins on conflicts.

    Attributes:
        level: Log level name, e.g. ``"INFO"``; not validated here.
        json_format: Presumably switches log output to JSON. Confirm against
            the logging setup that consumes it.
    """

    # ``env_file`` is repeated here because ``Settings`` builds this section
    # via ``default_factory``, which only consults this class's own sources.
    # Without it, ``LOG_*`` keys in ``.env`` would be ignored.
    model_config = SettingsConfigDict(
        env_prefix="LOG_",
        env_file=str(SERVER_DIR / ".env"),
        env_file_encoding="utf-8",
        extra="ignore",
    )

    level: str = Field(default="INFO")
    json_format: bool = Field(default=False)


class PathSettings(BaseSettings):
    """Key filesystem locations (static assets, external models).

    Overridable via ``PATH_``-prefixed variables (e.g. ``PATH_STATIC_DIR``) in
    the process environment or in ``apps/server/.env``. The process
    environment wins on conflicts.

    NOTE: the defaults are computed once, at class-definition time, from
    ``SERVER_DIR`` / ``PROJECT_ROOT``. Overriding ``static_dir`` therefore
    does not move ``results_dir`` / ``temp_dir`` / ``uploads_dir``; set those
    explicitly as well. Overridden values are used as given; relative paths
    are not resolved against ``SERVER_DIR`` here.
    """

    # ``env_file`` is repeated here because ``Settings`` builds this section
    # via ``default_factory``, which only consults this class's own sources.
    # Without it, ``PATH_*`` keys in ``.env`` would be ignored.
    model_config = SettingsConfigDict(
        env_prefix="PATH_",
        env_file=str(SERVER_DIR / ".env"),
        env_file_encoding="utf-8",
        extra="ignore",
    )

    server_dir: Path = SERVER_DIR
    project_root: Path = PROJECT_ROOT
    static_dir: Path = SERVER_DIR / "static"
    results_dir: Path = SERVER_DIR / "static" / "results"
    temp_dir: Path = SERVER_DIR / "static" / "temp"
    uploads_dir: Path = SERVER_DIR / "static" / "uploads"
    # PaddlePaddle checkout under the repository's ``external/`` directory.
    external_paddle: Path = (
        PROJECT_ROOT / "external" / "video-recognition-system" / "PaddlePaddle"
    )

    def ensure(self) -> None:
        """Create the static, results, temp and uploads directories if missing.

        Meant to be called at startup. Missing parents are created and
        existing directories are left as-is. ``server_dir``, ``project_root``
        and ``external_paddle`` are not created.
        """
        for p in (self.static_dir, self.results_dir, self.temp_dir, self.uploads_dir):
            p.mkdir(parents=True, exist_ok=True)


# ---------------------------------------------------------------------------
# 总配置 (top-level settings)
# ---------------------------------------------------------------------------


class Settings(BaseSettings):
    """Top-level application settings.

    Collects every configuration section under a namespace attribute, so
    callers write ``settings.api.port`` or ``settings.mqtt.broker_host``.

    This class reads the process environment and ``apps/server/.env``
    case-insensitively, without a prefix. ``env`` and ``debug`` therefore come
    from plain ``ENV`` / ``DEBUG`` variables.

    Each section defaults to a fresh instance of its own settings class, which
    reads that class's prefixed variables (``API_PORT``, ``MQTT_QOS``, ...).
    Because ``env_nested_delimiter`` is ``"__"``, a section may also be
    supplied in nested form, e.g. ``API__PORT=9000``.
    """

    model_config = SettingsConfigDict(
        extra="ignore",
        case_sensitive=False,
        env_nested_delimiter="__",
        env_file_encoding="utf-8",
        env_file=os.fspath(SERVER_DIR / ".env"),
    )

    # Runtime environment name and debug switch.
    env: str = Field(default="development", description="运行环境")
    debug: bool = Field(default=True)

    # Sections: each default_factory builds the section from its own sources.
    api: APISettings = Field(default_factory=APISettings)
    detection: DetectionSettings = Field(default_factory=DetectionSettings)
    action_detection: ActionDetectionSettings = Field(default_factory=ActionDetectionSettings)
    event_engine: EventEngineSettings = Field(default_factory=EventEngineSettings)
    rtsp: RTSPSettings = Field(default_factory=RTSPSettings)
    mqtt: MQTTSettings = Field(default_factory=MQTTSettings)
    tracking: TrackingSettings = Field(default_factory=TrackingSettings)
    aggregator: AggregatorSettings = Field(default_factory=AggregatorSettings)
    logging: LoggingSettings = Field(default_factory=LoggingSettings)
    paths: PathSettings = Field(default_factory=PathSettings)


@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Return the process-wide :class:`Settings` instance.

    The first call builds the settings and creates the static, results, temp
    and uploads directories via ``PathSettings.ensure``. Later calls return
    the same cached object. Call ``get_settings.cache_clear()`` (e.g. in
    tests) to force a fresh load.
    """
    loaded = Settings()
    # Directory creation happens once, together with the first load.
    loaded.paths.ensure()
    return loaded


__all__ = [
    # Top-level settings object.
    "Settings",
    # Per-section settings classes.
    "APISettings", "DetectionSettings", "ActionDetectionSettings",
    "EventEngineSettings", "RTSPSettings", "MQTTSettings",
    "TrackingSettings", "AggregatorSettings", "LoggingSettings",
    "PathSettings",
    # Cached accessor and path constants.
    "get_settings", "SERVER_DIR", "PROJECT_ROOT",
]