238 lines
10 KiB
Python
238 lines
10 KiB
Python
import os
|
||
import logging
|
||
|
||
# PyTorch 2.6 compatibility workaround: set the flag only if the environment
# has not already chosen a value for it (an explicit user setting wins).
# NOTE(review): TORCH_DISABLE_TORCH_GRAPH_OPTIMIZER does not appear in the
# public PyTorch docs — confirm that it has any effect.
if 'TORCH_DISABLE_TORCH_GRAPH_OPTIMIZER' not in os.environ:
    os.environ['TORCH_DISABLE_TORCH_GRAPH_OPTIMIZER'] = '1'

# Logger named after this module; used by the compatibility shim and ModelService.
logger = logging.getLogger(__name__)
|
||
|
||
# PyTorch 2.6 changed torch.load's default to weights_only=True, which refuses
# to unpickle classes that are not allow-listed. Register the classes that
# ultralytics checkpoints reference so their .pt files keep loading.
# NOTE(review): ultralytics checkpoints also reference ultralytics.nn.modules.*
# classes and activations such as SiLU; this list may be incomplete — confirm
# that every configured .pt file loads under weights_only=True.
try:
    import torch

    if not hasattr(torch.serialization, 'add_safe_globals'):
        # Older PyTorch has no allow-list API (and does not default to
        # weights_only=True), so there is nothing to patch.
        logger.info(f"PyTorch {torch.__version__} 无 add_safe_globals,跳过兼容性修复")
    else:
        # Imported here only so the class object can be allow-listed.
        from ultralytics.nn.tasks import DetectionModel

        # Register all classes the checkpoints need.
        torch.serialization.add_safe_globals([
            DetectionModel,  # ultralytics model class
            torch.nn.modules.container.Sequential,  # torch sequential container
            torch.nn.modules.Conv2d,  # convolution layer
            torch.nn.modules.batchnorm.BatchNorm2d,  # batch normalization
            torch.nn.modules.activation.ReLU,  # activation function
            torch.nn.modules.Linear,  # linear layer
            torch.nn.modules.Dropout,  # dropout layer
            torch.nn.modules.Upsample,  # upsampling
            torch.nn.modules.PixelShuffle,  # pixel shuffle
        ])
        logger.info(f"✅ 检测到 PyTorch {torch.__version__},应用 ultralytics 兼容性修复")
except (ImportError, AttributeError, NameError) as e:
    # Best effort: a failed patch must not prevent the module from importing.
    logger.warning(f"⚠️ PyTorch 兼容性修复失败: {e}")
|
||
|
||
from ultralytics import YOLO
|
||
from typing import Dict, List, Optional, Union
|
||
|
||
class ModelService:
    """Registry and lazy loader for the detection models used by the server.

    ``model_configs`` maps a model id to its weight path, backend ``type``,
    class names, Chinese display labels and UI metadata (size, description,
    name). Models are instantiated on demand by :meth:`load_model` and cached
    in ``self.models`` until :meth:`unload_model` drops them.

    Backend types handled by :meth:`load_model`:

    * ``'docker_api'`` -- a project-local wrapper class; no local file is checked.
    * ``'paddle'`` -- a project-local PaddlePaddle service class per model id.
    * anything else (``'yolov8'``, ``'yolov10'``) -- ultralytics ``YOLO(path)``.
    """

    # Files that must all exist in the directory of a Paddle model's
    # configured path before the model is reported as available.
    _PADDLE_REQUIRED_FILES = ('model.pdmodel', 'model.pdiparams', 'infer_cfg.yml')

    def __init__(self):
        # Loaded model instances, keyed by model id.
        self.models: Dict[str, Union[YOLO, object]] = {}

        # Project (jc-video-web) root: this file lives at
        # apps/server/services/model_service.py, so walk up four directories.
        # abspath() keeps this correct even if __file__ is relative.
        base_dir = os.path.dirname(os.path.dirname(os.path.dirname(
            os.path.dirname(os.path.abspath(__file__)))))
        models_dir = os.path.join(base_dir, 'models')

        # vehicle_detection and illegal_parking_detection point at the same
        # Paddle model directory.
        vehicle_model_path = os.path.join(
            models_dir, 'vehicle_detection_paddle',
            'mot_ppyoloe_l_36e_ppvehicle', 'model.pdmodel')

        self.model_configs = {
            'fire_detection': {
                'path': os.path.join(models_dir, 'fire_detection', 'best.pt'),
                'type': 'yolov10',
                'classes': ['Fire', 'Smoke'],
                'labels': {'Fire': '火焰', 'Smoke': '烟雾'},
                'size': '61MB',
                'description': '基于YOLOv10的火灾烟雾检测模型',
                'name': '火灾检测',
            },
            'helmet_detection': {
                'path': os.path.join(models_dir, 'helmet_detection', 'yolov8n.pt'),
                'type': 'yolov8',
                'classes': ['person', 'helmet'],
                'labels': {'person': '人员', 'helmet': '安全帽'},
                'size': '6MB',
                'description': '基于YOLOv8的安全帽检测模型',
                'name': '安全帽检测',
            },
            'crowd_detection': {
                'path': os.path.join(models_dir, 'crowd_detection', 'yolov8l.pt'),
                'type': 'yolov8',
                'classes': ['person'],
                'labels': {'person': '人员'},
                'size': '100MB',
                'description': '基于YOLOv8的人群聚集检测模型',
                'name': '人群检测',
            },
            'smoking_detection': {
                'path': os.path.join(models_dir, 'smoking_detection', 'smoking_yolov8n.pt'),
                'type': 'yolov8',
                'classes': ['cigarette', 'smoke'],
                'labels': {'cigarette': '香烟', 'smoke': '烟雾'},
                'size': '6MB',
                'description': '基于YOLOv8的抽烟检测模型',
                'name': '抽烟检测 (YOLOv8)',
            },
            'smoking_detection_paddle': {
                'path': os.path.join(models_dir, 'smoking_detection_paddle', 'model.pdmodel'),
                'type': 'paddle',
                'classes': ['cigarette'],
                'labels': {'cigarette': '香烟'},
                'size': '27MB',
                'description': '基于PaddlePaddle PP-YOLOE-s的抽烟检测模型(更高准确率)',
                'name': '抽烟检测 (Paddle)',
            },
            'loitering_detection': {
                'path': os.path.join(models_dir, 'loitering_detection', 'yolov8n.pt'),
                'type': 'yolov8',
                'classes': ['person'],
                'labels': {'person': '人员'},
                'size': '6MB',
                'description': '基于YOLOv8的徘徊检测模型',
                'name': '徘徊检测',
            },
            'vehicle_detection': {
                'path': vehicle_model_path,
                'type': 'paddle',
                'classes': ['vehicle'],
                'labels': {'vehicle': '车辆'},
                'size': '181MB',
                'description': '基于PaddlePaddle PP-YOLOE-l的车辆检测和跟踪模型',
                'name': '车辆检测 (Paddle)',
            },
            'illegal_parking_detection': {
                'path': vehicle_model_path,
                'type': 'paddle',
                'classes': ['vehicle'],
                'labels': {'vehicle': '车辆'},
                # NOTE(review): same path as vehicle_detection but a different
                # size is advertised (181MB vs 200MB) — confirm which is right.
                'size': '200MB',
                'description': '基于PaddlePaddle PP-YOLOE-l的违停检测模型,支持车牌识别',
                'name': '违停检测 (Paddle)',
            },
            'action_detection': {
                # Served through a Docker API, so there is no local weight file.
                'path': 'docker_api',
                'type': 'docker_api',
                'classes': ['fight', 'normal'],
                'labels': {'fight': '打架', 'normal': '正常'},
                'size': 'Docker',
                'description': '基于PaddleVideo ppTSM的打架检测模型(通过Docker API调用)',
                'name': '打架检测 (Docker)',
            },
        }

    def _missing_model_files(self, config: Dict) -> List[str]:
        """Return the required files for *config* that are absent on disk (empty if none)."""
        if config['type'] == 'docker_api':
            # Provided by an external API; nothing to check locally.
            return []
        if config['type'] == 'paddle':
            model_dir = os.path.dirname(config['path'])
            required = [os.path.join(model_dir, name) for name in self._PADDLE_REQUIRED_FILES]
        else:
            required = [config['path']]
        return [path for path in required if not os.path.exists(path)]

    def get_available_models(self) -> List[Dict]:
        """Return metadata for every configured model whose files are present.

        Docker API models are always listed. Paddle models need every file in
        ``_PADDLE_REQUIRED_FILES`` in the directory of their configured path.
        Other (YOLO) models need their weight file. A model with missing files
        is skipped, and a warning names the missing file(s).

        Returns:
            A list (in config order) of dicts with keys ``id``, ``name``,
            ``description``, ``classes``, ``labels``, ``size`` and ``type``.
        """
        available_models = []
        for model_id, config in self.model_configs.items():
            missing = self._missing_model_files(config)
            if missing:
                logger.warning(f"模型文件不存在: {', '.join(missing)}")
                continue
            available_models.append({
                'id': model_id,
                'name': config['name'],
                'description': config['description'],
                'classes': config['classes'],
                'labels': config['labels'],
                'size': config['size'],
                'type': config['type'],
            })
        return available_models

    async def load_model(self, model_id: str) -> Optional[Union[YOLO, object]]:
        """Load *model_id*, or return its cached instance if already loaded.

        NOTE: loading happens synchronously inside this coroutine (there is
        no ``await``), so a large model blocks the event loop while it loads.

        Returns:
            The model object, or ``None`` in any of these cases: the id is
            unknown, the YOLO weight file is missing, or construction raised.
            Errors are logged, never raised.
        """
        if model_id not in self.model_configs:
            logger.error(f"未知模型ID: {model_id}")
            return None

        if model_id in self.models:
            return self.models[model_id]

        config = self.model_configs[model_id]
        if config['type'] == 'docker_api':
            return self._load_docker_api_model(model_id)
        if config['type'] == 'paddle':
            return self._load_paddle_model(model_id)
        return self._load_yolo_model(model_id, config['path'])

    def _load_docker_api_model(self, model_id: str) -> Optional[object]:
        """Instantiate and cache the wrapper for a Docker-API-backed model."""
        if model_id != 'action_detection':
            logger.error(f"未知的 Docker API 模型类型: {model_id}")
            return None
        try:
            # Imported lazily so the dependency is only needed when used.
            from .action_detection_service import ActionDetectionModel
            logger.info(f"正在加载 Docker API 行为识别服务: {model_id}")
            model = ActionDetectionModel()
        except Exception as e:
            # logger.exception also records the traceback.
            logger.exception(f"Docker API 服务加载失败: {model_id}, 错误: {e}")
            return None
        self.models[model_id] = model
        logger.info(f"Docker API 服务加载成功: {model_id}")
        return model

    def _load_paddle_model(self, model_id: str) -> Optional[object]:
        """Instantiate and cache the PaddlePaddle service class for *model_id*."""
        try:
            # Service modules are imported lazily so Paddle is only needed when used.
            if model_id == 'smoking_detection_paddle':
                from .paddle_detection_service import SmokingDetectionModel
                logger.info(f"正在加载 PaddlePaddle 抽烟检测服务: {model_id}")
                model = SmokingDetectionModel()
            elif model_id in ('vehicle_detection', 'illegal_parking_detection'):
                from .vehicle_detection_service import VehicleDetectionModel
                logger.info(f"正在加载 PaddlePaddle 车辆检测服务: {model_id}")
                model = VehicleDetectionModel()
            else:
                logger.error(f"未知的 Paddle 模型类型: {model_id}")
                return None
        except Exception as e:
            logger.exception(f"PaddlePaddle 服务加载失败: {model_id}, 错误: {e}")
            return None
        self.models[model_id] = model
        logger.info(f"PaddlePaddle 服务加载成功: {model_id}")
        return model

    def _load_yolo_model(self, model_id: str, model_path: str) -> Optional[YOLO]:
        """Load and cache an ultralytics YOLO model from *model_path*."""
        if not os.path.exists(model_path):
            logger.warning(f"模型文件不存在: {model_path},跳过加载 {model_id}")
            return None
        try:
            logger.info(f"正在加载 YOLO 模型: {model_id} from {model_path}")
            model = YOLO(model_path)
        except Exception as e:
            logger.exception(f"YOLO 模型加载失败: {model_id}, 错误: {e}")
            logger.error(f"模型路径: {model_path}")
            return None
        self.models[model_id] = model
        logger.info(f"YOLO 模型加载成功: {model_id}")
        return model

    def get_model(self, model_id: str) -> Optional[Union[YOLO, object]]:
        """Return the already-loaded model for *model_id*, or ``None``; never loads."""
        return self.models.get(model_id)

    async def unload_model(self, model_id: str) -> bool:
        """Drop *model_id* from the cache; return True if it was loaded."""
        if model_id in self.models:
            del self.models[model_id]
            logger.info(f"模型已卸载: {model_id}")
            return True
        return False
|