Files
jc-video-recognize/apps/server/services/model_service.py

252 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import logging
# PyTorch 2.6 compatibility fix: torch 2.6 defaults torch.load to weights_only=True,
# so classes stored in ultralytics checkpoints are registered as "safe globals" here.
# NOTE(review): not verified that torch actually reads this env var — confirm or drop.
os.environ.setdefault('TORCH_DISABLE_TORCH_GRAPH_OPTIMIZER', '1')
logger = logging.getLogger(__name__)
try:
    import torch
    # Feature-detect the API itself. The previous `hasattr(torch, 'serialization')`
    # guard was always true, so torch builds lacking `add_safe_globals` (added in
    # the 2.4 line) hit an AttributeError that was misreported as a failed fix.
    if hasattr(torch.serialization, 'add_safe_globals'):
        # Imported here so add_safe_globals can reference the ultralytics model class.
        from ultralytics.nn.tasks import DetectionModel
        # NOTE(review): an ultralytics checkpoint presumably also pickles
        # ultralytics' own layer classes (Conv, C2f, ...); this list may not be
        # sufficient for weights_only loading on its own — confirm.
        torch.serialization.add_safe_globals([
            DetectionModel,  # ultralytics model class
            torch.nn.modules.container.Sequential,  # torch sequential container
            torch.nn.modules.Conv2d,  # convolution layer
            torch.nn.modules.batchnorm.BatchNorm2d,  # batch normalization
            torch.nn.modules.activation.ReLU,  # activation
            torch.nn.modules.Linear,  # linear layer
            torch.nn.modules.Dropout,  # dropout layer
            torch.nn.modules.Upsample,  # upsampling
            torch.nn.modules.PixelShuffle,  # pixel shuffle
        ])
        logger.info("✅ 检测到 PyTorch 2.6+,应用 ultralytics 兼容性修复")
    else:
        # Older torch has no weights_only allow-list, so there is nothing to patch.
        logger.debug("当前 PyTorch 版本不支持 add_safe_globals,跳过 ultralytics 兼容性修复")
except (ImportError, AttributeError, NameError) as e:
    # Best-effort: a missing torch/ultralytics must not prevent the module from importing.
    logger.warning(f"⚠️ PyTorch 兼容性修复失败: {e}")
from ultralytics import YOLO
from typing import Dict, List, Optional, Union
class ModelService:
    """Registry and lazy loader for the server's detection models.

    ``model_configs`` describes every known model: weights path, backend
    ``type`` ('yolov8' / 'yolov10' / 'paddle' / 'docker_api'), class names,
    display labels and UI metadata. Models are created on the first
    :meth:`load_model` call for their id and cached in ``self.models`` until
    :meth:`unload_model` drops them.
    """

    # Files a PaddleDetection export directory must contain to count as present.
    _PADDLE_REQUIRED_FILES = ('model.pdmodel', 'model.pdiparams', 'infer_cfg.yml')

    def __init__(self):
        # Loaded model / service objects, keyed by model id.
        self.models: Dict[str, Union[YOLO, object]] = {}
        # Base path: four dirname() steps from apps/server/services/model_service.py
        # reach the repository root. abspath() keeps this correct even if __file__
        # is relative.
        base_dir = os.path.abspath(__file__)
        for _ in range(4):
            base_dir = os.path.dirname(base_dir)
        self.model_configs = {
            'fire_detection': {
                'path': os.path.join(base_dir, 'models', 'fire_detection', 'best.pt'),
                'type': 'yolov10',
                'classes': ['Fire', 'Smoke'],
                'labels': {'Fire': '火焰', 'Smoke': '烟雾'},
                'size': '61MB',
                'description': '基于YOLOv10的火灾烟雾检测模型',
                'name': '火灾检测'
            },
            'helmet_detection': {
                'path': os.path.join(base_dir, 'models', 'helmet_detection', 'yolov8n.pt'),
                'type': 'yolov8',
                'classes': ['person', 'helmet'],
                'labels': {'person': '人员', 'helmet': '安全帽'},
                'size': '6MB',
                'description': '基于YOLOv8的安全帽检测模型',
                'name': '安全帽检测'
            },
            'crowd_detection': {
                'path': os.path.join(base_dir, 'models', 'crowd_detection', 'yolov8l.pt'),
                'type': 'yolov8',
                'classes': ['person'],
                'labels': {'person': '人员'},
                'size': '100MB',
                'description': '基于YOLOv8的人群聚集检测模型',
                'name': '人群检测'
            },
            'smoking_detection': {
                'path': os.path.join(base_dir, 'models', 'smoking_detection', 'smoking_yolov8n.pt'),
                'type': 'yolov8',
                'classes': ['cigarette', 'smoke'],
                'labels': {'cigarette': '香烟', 'smoke': '烟雾'},
                'size': '6MB',
                'description': '基于YOLOv8的抽烟检测模型',
                'name': '抽烟检测 (YOLOv8)'
            },
            'smoking_detection_paddle': {
                'path': os.path.join(base_dir, 'models', 'smoking_detection_paddle', 'model.pdmodel'),
                'type': 'paddle',
                'classes': ['cigarette'],
                'labels': {'cigarette': '香烟'},
                'size': '27MB',
                'description': '基于PaddlePaddle PP-YOLOE-s的抽烟检测模型更高准确率',
                'name': '抽烟检测 (Paddle)'
            },
            'loitering_detection': {
                'path': os.path.join(base_dir, 'models', 'loitering_detection', 'yolov8n.pt'),
                'type': 'yolov8',
                'classes': ['person'],
                'labels': {'person': '人员'},
                'size': '6MB',
                'description': '基于YOLOv8的徘徊检测模型',
                'name': '徘徊检测'
            },
            'vehicle_detection': {
                'path': os.path.join(base_dir, 'models', 'vehicle_detection_paddle', 'mot_ppyoloe_l_36e_ppvehicle', 'model.pdmodel'),
                'type': 'paddle',
                'classes': ['vehicle'],
                'labels': {'vehicle': '车辆'},
                'size': '181MB',
                'description': '基于PaddlePaddle PP-YOLOE-l的车辆检测和跟踪模型',
                'name': '车辆检测 (Paddle)'
            },
            'illegal_parking_detection': {
                'path': os.path.join(base_dir, 'models', 'vehicle_detection_paddle', 'mot_ppyoloe_l_36e_ppvehicle', 'model.pdmodel'),
                'type': 'paddle',
                'classes': ['vehicle'],
                'labels': {'vehicle': '车辆'},
                'size': '200MB',
                'description': '基于PaddlePaddle PP-YOLOE-l的违停检测模型支持车牌识别',
                'name': '违停检测 (Paddle)'
            },
            'fight_detection': {
                'path': os.path.join(base_dir, 'models', 'fight_detection', 'yolov8n.pt'),
                'type': 'yolov8',
                'classes': ['violence', 'non_violence'],
                'labels': {'violence': '暴力行为', 'non_violence': '正常'},
                'size': '22MB',
                'description': '基于YOLOv8的打架斗殴检测模型',
                'name': '打架斗殴检测(YOLO)',
                'supports_video': True
            },
            'action_detection': {
                'path': 'docker_api',
                'type': 'docker_api',
                'classes': ['fight', 'normal'],
                'labels': {'fight': '打架', 'normal': '正常'},
                'size': 'Docker',
                'description': '基于PaddleVideo ppTSM的打架检测模型通过Docker API调用',
                'name': '打架检测 (Docker)'
            }
        }

    def _model_files_exist(self, config: Dict) -> bool:
        """Return True if the on-disk artifacts described by ``config`` are present.

        Docker-API models have no local files and always count as present.
        Paddle models need every file in ``_PADDLE_REQUIRED_FILES`` inside the
        directory of ``config['path']``. All other types need ``config['path']``.
        """
        model_type = config['type']
        if model_type == 'docker_api':
            return True
        if model_type == 'paddle':
            model_dir = os.path.dirname(config['path'])
            return all(
                os.path.exists(os.path.join(model_dir, file_name))
                for file_name in self._PADDLE_REQUIRED_FILES
            )
        return os.path.exists(config['path'])

    def get_available_models(self) -> List[Dict]:
        """List the models whose files are present, as JSON-friendly dicts.

        Each entry has 'id', 'name', 'description', 'classes', 'labels', 'size'
        and 'type'. 'supports_video': True is added only for models whose config
        sets it. A warning is logged for every configured model whose files are
        missing.
        """
        available_models = []
        for model_id, config in self.model_configs.items():
            if not self._model_files_exist(config):
                logger.warning("模型文件不存在: %s", config['path'])
                continue
            model_info = {
                'id': model_id,
                'name': config['name'],
                'description': config['description'],
                # Copies, so callers cannot mutate the shared model_configs.
                'classes': list(config['classes']),
                'labels': dict(config['labels']),
                'size': config['size'],
                'type': config['type'],
            }
            # Only models that support video input get this flag.
            if config.get('supports_video'):
                model_info['supports_video'] = True
            available_models.append(model_info)
        return available_models

    async def load_model(self, model_id: str) -> Optional[Union[YOLO, object]]:
        """Return the model for ``model_id``, loading and caching it on first use.

        Args:
            model_id: key into ``self.model_configs``.

        Returns:
            The loaded model or service object. Returns ``None`` if the id is
            unknown, the YOLO weights file is missing, or construction raised.
            Failures are logged and not cached, so a later call retries.
        """
        # NOTE(review): every loader below runs synchronously inside this
        # coroutine, so a large checkpoint blocks the event loop while it loads.
        # Consider offloading (e.g. asyncio.to_thread) together with a per-model
        # lock to avoid duplicate concurrent loads.
        if model_id not in self.model_configs:
            logger.error("未知模型ID: %s", model_id)
            return None
        if model_id in self.models:
            return self.models[model_id]
        config = self.model_configs[model_id]
        model_type = config['type']
        if model_type == 'docker_api':
            return self._load_docker_api_model(model_id)
        if model_type == 'paddle':
            return self._load_paddle_model(model_id)
        return self._load_yolo_model(model_id, config['path'])

    def _load_docker_api_model(self, model_id: str) -> Optional[object]:
        """Create and cache the Docker-API backed service for ``model_id``."""
        if model_id != 'action_detection':
            logger.error("未知的 Docker API 模型类型: %s", model_id)
            return None
        try:
            from .action_detection_service import ActionDetectionModel
            logger.info("正在加载 Docker API 行为识别服务: %s", model_id)
            model = ActionDetectionModel()
        except Exception as e:
            # Best-effort load: log with traceback and report "unavailable".
            logger.exception("Docker API 服务加载失败: %s, 错误: %s", model_id, e)
            return None
        self.models[model_id] = model
        logger.info("Docker API 服务加载成功: %s", model_id)
        return model

    def _load_paddle_model(self, model_id: str) -> Optional[object]:
        """Create and cache the PaddleDetection-backed service for ``model_id``."""
        try:
            if model_id == 'smoking_detection_paddle':
                from .paddle_detection_service import SmokingDetectionModel
                logger.info("正在加载 PaddlePaddle 抽烟检测服务: %s", model_id)
                model = SmokingDetectionModel()
            elif model_id in ('vehicle_detection', 'illegal_parking_detection'):
                # NOTE(review): both ids point at the same weights but each gets
                # its own VehicleDetectionModel instance — confirm that is intended.
                from .vehicle_detection_service import VehicleDetectionModel
                logger.info("正在加载 PaddlePaddle 车辆检测服务: %s", model_id)
                model = VehicleDetectionModel()
            else:
                logger.error("未知的 Paddle 模型类型: %s", model_id)
                return None
        except Exception as e:
            # Best-effort load: log with traceback and report "unavailable".
            logger.exception("PaddlePaddle 服务加载失败: %s, 错误: %s", model_id, e)
            return None
        self.models[model_id] = model
        logger.info("PaddlePaddle 服务加载成功: %s", model_id)
        return model

    def _load_yolo_model(self, model_id: str, model_path: str) -> Optional[YOLO]:
        """Load the ultralytics YOLO weights at ``model_path`` and cache them."""
        if not os.path.exists(model_path):
            logger.warning("模型文件不存在: %s,跳过加载 %s", model_path, model_id)
            return None
        try:
            logger.info("正在加载 YOLO 模型: %s from %s", model_id, model_path)
            model = YOLO(model_path)
        except Exception as e:
            # logger.exception attaches the traceback; no manual formatting needed.
            logger.exception("YOLO 模型加载失败: %s, 错误: %s", model_id, e)
            logger.error("模型路径: %s", model_path)
            return None
        self.models[model_id] = model
        logger.info("YOLO 模型加载成功: %s", model_id)
        return model

    def get_model(self, model_id: str) -> Optional[Union[YOLO, object]]:
        """Return the already-loaded model for ``model_id``, or ``None``. Never loads."""
        return self.models.get(model_id)

    async def unload_model(self, model_id: str) -> bool:
        """Drop ``model_id`` from the cache.

        Returns:
            True if the model was loaded and has been removed, False otherwise.
            Only this service's reference is dropped; memory is presumably
            reclaimed only once no other code still holds the model.
        """
        if model_id not in self.models:
            return False
        del self.models[model_id]
        logger.info("模型已卸载: %s", model_id)
        return True