Files
jc-video-recognize/apps/server/services/model_service.py
wwh a16e684e46 feat: 新增车辆检测Paddle模型及相关服务,优化依赖与代码兼容性
1. 新增3套PaddlePaddle车辆检测相关模型文件
2. 新增车辆检测服务类与违停检测功能
3. 更新服务依赖并添加环境初始化脚本与文档
4. 修复YOLO检测tensor转换兼容问题
5. 新增PyTorch版本兼容性修复逻辑
6. 扩展模型服务支持Paddle模型加载
2026-05-21 16:26:26 +08:00

208 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import logging
# PyTorch 2.6 compatibility fix.
# PyTorch 2.6 switched torch.load() to weights_only=True by default (see the
# torch.load docs), which refuses to unpickle classes that are not
# allow-listed via torch.serialization.add_safe_globals(). The classes an
# ultralytics checkpoint references are registered below so YOLO weights can
# still be loaded.
# NOTE(review): this env var does not appear to be a documented PyTorch
# setting; confirm it has any effect before relying on it.
os.environ.setdefault('TORCH_DISABLE_TORCH_GRAPH_OPTIMIZER', '1')

logger = logging.getLogger(__name__)

try:
    import torch

    # The old guard, hasattr(torch, 'serialization'), was always true.
    # add_safe_globals() itself is what is missing on older PyTorch releases,
    # so guard on it directly. That avoids reporting a spurious "fix failed"
    # warning when there is simply no allow-list to populate.
    if hasattr(torch.serialization, 'add_safe_globals'):
        # ultralytics model class, needed for add_safe_globals below.
        from ultralytics.nn.tasks import DetectionModel

        # Register every class the checkpoints need to unpickle.
        torch.serialization.add_safe_globals([
            DetectionModel,                            # ultralytics model class
            torch.nn.modules.container.Sequential,     # torch Sequential container
            torch.nn.modules.Conv2d,                   # convolution layer
            torch.nn.modules.batchnorm.BatchNorm2d,    # batch normalisation
            torch.nn.modules.activation.ReLU,          # activation
            torch.nn.modules.Linear,                   # linear layer
            torch.nn.modules.Dropout,                  # dropout layer
            torch.nn.modules.Upsample,                 # upsampling
            torch.nn.modules.PixelShuffle,             # pixel shuffle
        ])
        logger.info("✅ 检测到 PyTorch 2.6+,应用 ultralytics 兼容性修复")
    else:
        logger.info("当前 PyTorch 版本不支持 add_safe_globals,跳过兼容性修复")
except (ImportError, AttributeError, NameError) as e:
    # Best-effort: a failed fix is logged, never fatal at import time.
    logger.warning("⚠️ PyTorch 兼容性修复失败: %s", e)
from ultralytics import YOLO
from typing import Dict, List, Optional, Union
class ModelService:
    """Registry, availability check and loader for the detection models.

    ``model_configs`` holds static per-model metadata keyed by model id.
    ``models`` caches model instances that have already been loaded.
    Entries whose ``type`` is ``'paddle'`` are served by dedicated service
    classes that are imported lazily on first load. Every other entry is
    loaded as an ``ultralytics.YOLO`` weights file.
    """

    # Files a Paddle inference export must contain (all in one directory)
    # for the model to count as present on disk.
    _PADDLE_REQUIRED_FILES = ('model.pdmodel', 'model.pdiparams', 'infer_cfg.yml')

    def __init__(self):
        # Cache of loaded models, keyed by model id.
        self.models: Dict[str, Union[YOLO, object]] = {}
        # Project root: four dirname() hops up from
        # apps/server/services/model_service.py, i.e. the directory that holds
        # both ``apps/`` and ``models/``. abspath() makes the result independent
        # of a relative __file__ and of the current working directory.
        base_dir = os.path.dirname(os.path.dirname(os.path.dirname(
            os.path.dirname(os.path.abspath(__file__)))))
        # The vehicle and illegal-parking entries share the same exported weights.
        vehicle_model_path = os.path.join(
            base_dir, 'models', 'vehicle_detection_paddle',
            'mot_ppyoloe_l_36e_ppvehicle', 'model.pdmodel')
        self.model_configs = {
            'fire_detection': {
                'path': os.path.join(base_dir, 'models', 'fire_detection', 'best.pt'),
                'type': 'yolov10',
                'classes': ['Fire', 'Smoke'],
                'labels': {'Fire': '火焰', 'Smoke': '烟雾'},
                'size': '61MB',
                'description': '基于YOLOv10的火灾烟雾检测模型',
                'name': '火灾检测'
            },
            'helmet_detection': {
                'path': os.path.join(base_dir, 'models', 'helmet_detection', 'yolov8n.pt'),
                'type': 'yolov8',
                'classes': ['person', 'helmet'],
                'labels': {'person': '人员', 'helmet': '安全帽'},
                'size': '6MB',
                'description': '基于YOLOv8的安全帽检测模型',
                'name': '安全帽检测'
            },
            'crowd_detection': {
                'path': os.path.join(base_dir, 'models', 'crowd_detection', 'yolov8l.pt'),
                'type': 'yolov8',
                'classes': ['person'],
                'labels': {'person': '人员'},
                'size': '100MB',
                'description': '基于YOLOv8的人群聚集检测模型',
                'name': '人群检测'
            },
            'smoking_detection': {
                'path': os.path.join(base_dir, 'models', 'smoking_detection', 'smoking_yolov8n.pt'),
                'type': 'yolov8',
                'classes': ['cigarette', 'smoke'],
                'labels': {'cigarette': '香烟', 'smoke': '烟雾'},
                'size': '6MB',
                'description': '基于YOLOv8的抽烟检测模型',
                'name': '抽烟检测 (YOLOv8)'
            },
            'smoking_detection_paddle': {
                'path': os.path.join(base_dir, 'models', 'smoking_detection_paddle', 'model.pdmodel'),
                'type': 'paddle',
                'classes': ['cigarette'],
                'labels': {'cigarette': '香烟'},
                'size': '27MB',
                'description': '基于PaddlePaddle PP-YOLOE-s的抽烟检测模型更高准确率',
                'name': '抽烟检测 (Paddle)'
            },
            'loitering_detection': {
                'path': os.path.join(base_dir, 'models', 'loitering_detection', 'yolov8n.pt'),
                'type': 'yolov8',
                'classes': ['person'],
                'labels': {'person': '人员'},
                'size': '6MB',
                'description': '基于YOLOv8的徘徊检测模型',
                'name': '徘徊检测'
            },
            'vehicle_detection': {
                'path': vehicle_model_path,
                'type': 'paddle',
                'classes': ['vehicle'],
                'labels': {'vehicle': '车辆'},
                'size': '181MB',
                'description': '基于PaddlePaddle PP-YOLOE-l的车辆检测和跟踪模型',
                'name': '车辆检测 (Paddle)'
            },
            'illegal_parking_detection': {
                'path': vehicle_model_path,
                'type': 'paddle',
                'classes': ['vehicle'],
                'labels': {'vehicle': '车辆'},
                'size': '200MB',
                'description': '基于PaddlePaddle PP-YOLOE-l的违停检测模型支持车牌识别',
                'name': '违停检测 (Paddle)'
            }
        }

    def _model_files_exist(self, config: Dict) -> bool:
        """Return True if the files needed to load *config*'s model exist.

        A Paddle model is an exported directory, so every name in
        ``_PADDLE_REQUIRED_FILES`` must exist next to ``config['path']``.
        Any other model is a single weights file at ``config['path']``.
        """
        model_path = config['path']
        if config['type'] == 'paddle':
            model_dir = os.path.dirname(model_path)
            return all(
                os.path.exists(os.path.join(model_dir, file_name))
                for file_name in self._PADDLE_REQUIRED_FILES
            )
        return os.path.exists(model_path)

    def get_available_models(self) -> List[Dict]:
        """Describe every configured model whose files are present on disk.

        Returns:
            One dict per available model, in ``model_configs`` order, with the
            keys ``id``, ``name``, ``description``, ``classes``, ``labels``,
            ``size`` and ``type``. Models with missing files are skipped and
            logged as a warning.
        """
        available_models = []
        for model_id, config in self.model_configs.items():
            if not self._model_files_exist(config):
                logger.warning("模型文件不存在: %s", config['path'])
                continue
            available_models.append({
                'id': model_id,
                'name': config['name'],
                'description': config['description'],
                'classes': config['classes'],
                'labels': config['labels'],
                'size': config['size'],
                'type': config['type'],
            })
        return available_models

    async def load_model(self, model_id: str) -> Optional[Union[YOLO, object]]:
        """Return the model for *model_id*, loading and caching it on first use.

        Args:
            model_id: a key of ``model_configs``.

        Returns:
            The cached or newly loaded model. That is an ``ultralytics.YOLO``
            instance for YOLO entries, or the service object for ``'paddle'``
            entries. Returns None if the id is unknown, the YOLO weights file
            is missing, or loading fails; failures are logged rather than
            raised.
        """
        if model_id not in self.model_configs:
            logger.error("未知模型ID: %s", model_id)
            return None
        if model_id in self.models:
            return self.models[model_id]
        config = self.model_configs[model_id]
        # NOTE(review): nothing below awaits, so model construction runs
        # synchronously and blocks the event loop while it runs. Consider
        # offloading it if load latency matters.
        if config['type'] == 'paddle':
            return self._load_paddle_model(model_id)
        return self._load_yolo_model(model_id, config['path'])

    def _load_paddle_model(self, model_id: str) -> Optional[object]:
        """Instantiate, cache and return the Paddle-backed service for *model_id*.

        The service modules are imported lazily, so their dependencies are
        only pulled in when a Paddle model is actually requested. Returns None
        (after logging) for an unrecognised id or a failed load.
        """
        try:
            if model_id == 'smoking_detection_paddle':
                from .paddle_detection_service import SmokingDetectionModel
                logger.info("正在加载 PaddlePaddle 抽烟检测服务: %s", model_id)
                model = SmokingDetectionModel()
            elif model_id in ('vehicle_detection', 'illegal_parking_detection'):
                from .vehicle_detection_service import VehicleDetectionModel
                logger.info("正在加载 PaddlePaddle 车辆检测服务: %s", model_id)
                model = VehicleDetectionModel()
            else:
                logger.error("未知的 Paddle 模型类型: %s", model_id)
                return None
        except Exception as e:
            # The broad catch is deliberate: one broken optional backend must
            # not take the service down. logger.exception keeps the traceback,
            # which the previous logger.error call discarded.
            logger.exception("PaddlePaddle 服务加载失败: %s, 错误: %s", model_id, e)
            return None
        self.models[model_id] = model
        logger.info("PaddlePaddle 服务加载成功: %s", model_id)
        return model

    def _load_yolo_model(self, model_id: str, model_path: str) -> Optional[YOLO]:
        """Load, cache and return the YOLO weights at *model_path*.

        Returns None (after logging) if the file is missing or loading fails.
        """
        if not os.path.exists(model_path):
            logger.warning("模型文件不存在: %s,跳过加载 %s", model_path, model_id)
            return None
        try:
            logger.info("正在加载 YOLO 模型: %s from %s", model_id, model_path)
            model = YOLO(model_path)
        except Exception as e:
            # Same best-effort policy as the Paddle branch; the traceback is
            # attached by logger.exception.
            logger.exception("YOLO 模型加载失败: %s, 错误: %s", model_id, e)
            logger.error("模型路径: %s", model_path)
            return None
        self.models[model_id] = model
        logger.info("YOLO 模型加载成功: %s", model_id)
        return model

    def get_model(self, model_id: str) -> Optional[Union[YOLO, object]]:
        """Return the already-loaded model for *model_id*, or None.

        This never triggers a load.
        """
        return self.models.get(model_id)

    async def unload_model(self, model_id: str) -> bool:
        """Remove *model_id* from the cache.

        Returns True if it was loaded, False otherwise.
        """
        # Only the cache's reference is dropped here. Memory is reclaimed only
        # if no other code still holds the model object.
        try:
            del self.models[model_id]
        except KeyError:
            return False
        logger.info("模型已卸载: %s", model_id)
        return True