Files
jc-video-recognize/apps/server/services/model_service.py
wwh 8fb58c75fe Initial commit: Video detection platform with YOLO models
Features:
- Fire detection (YOLOv10)
- Helmet detection (YOLOv8)
- Crowd detection (YOLOv8)
- Smoking detection (YOLOv8)
- Loitering detection (YOLOv8)

Tech Stack:
- Frontend: Vue 3 + Vite + Element Plus
- Backend: FastAPI + WebSocket
- Monorepo: pnpm workspace + Turbo
- Docker support included
2026-05-18 10:54:10 +08:00

116 lines
4.5 KiB
Python

import asyncio
import logging
import os
from typing import Dict, List, Optional

from ultralytics import YOLO
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class ModelService:
    """Registry and lazy loader for the platform's YOLO detection models.

    ``model_configs`` holds static metadata for every known model (weights
    path, architecture, class names, display labels). ``models`` caches the
    ``YOLO`` instances that have actually been loaded. Nothing is loaded at
    construction time; models are loaded on demand via :meth:`load_model`.
    """

    def __init__(self, base_dir: Optional[str] = None):
        """Build the model registry.

        Args:
            base_dir: Directory that contains the ``models/`` folder. When
                omitted, defaults to the directory four levels above this
                file, i.e. the repository root for the layout
                ``<root>/apps/server/services/model_service.py``.
        """
        # Loaded models, keyed by model id.
        self.models: Dict[str, YOLO] = {}
        # One lock per model id so that concurrent load_model() calls for the
        # same model wait for a single load instead of each reading the weights.
        self._load_locks: Dict[str, asyncio.Lock] = {}

        if base_dir is None:
            # abspath() first: with a relative __file__ the four dirname()
            # calls would collapse to '' and paths would resolve against the
            # current working directory instead of the repository root.
            base_dir = os.path.dirname(os.path.dirname(os.path.dirname(
                os.path.dirname(os.path.abspath(__file__))
            )))
        models_dir = os.path.join(base_dir, 'models')

        self.model_configs = {
            'fire_detection': {
                'path': os.path.join(models_dir, 'fire_detection', 'best.pt'),
                'type': 'yolov10',
                'classes': ['Fire', 'Smoke'],
                'labels': {'Fire': '火焰', 'Smoke': '烟雾'},
                'size': '61MB',
                'description': '基于YOLOv10的火灾烟雾检测模型',
                'name': '火灾检测'
            },
            'helmet_detection': {
                'path': os.path.join(models_dir, 'helmet_detection', 'yolov8n.pt'),
                'type': 'yolov8',
                'classes': ['person', 'helmet'],
                'labels': {'person': '人员', 'helmet': '安全帽'},
                'size': '6MB',
                'description': '基于YOLOv8的安全帽检测模型',
                'name': '安全帽检测'
            },
            'crowd_detection': {
                'path': os.path.join(models_dir, 'crowd_detection', 'yolov8l.pt'),
                'type': 'yolov8',
                'classes': ['person'],
                'labels': {'person': '人员'},
                'size': '100MB',
                'description': '基于YOLOv8的人群聚集检测模型',
                'name': '人群检测'
            },
            'smoking_detection': {
                'path': os.path.join(models_dir, 'smoking_detection', 'smoking_yolov8n.pt'),
                'type': 'yolov8',
                'classes': ['cigarette', 'smoke'],
                'labels': {'cigarette': '香烟', 'smoke': '烟雾'},
                'size': '6MB',
                'description': '基于YOLOv8的抽烟检测模型',
                'name': '抽烟检测'
            },
            'loitering_detection': {
                'path': os.path.join(models_dir, 'loitering_detection', 'yolov8n.pt'),
                'type': 'yolov8',
                'classes': ['person'],
                'labels': {'person': '人员'},
                'size': '6MB',
                'description': '基于YOLOv8的徘徊检测模型',
                'name': '徘徊检测'
            }
        }

    def get_available_models(self) -> List[Dict]:
        """Describe every configured model whose weights file exists on disk.

        Models whose weights file is missing are skipped, and a warning is
        logged for each one.

        Returns:
            One dict per available model with keys ``id``, ``name``,
            ``description``, ``classes``, ``labels``, ``size`` and ``type``.
            ``classes`` and ``labels`` are copies, so callers may mutate them
            without corrupting the registry.
        """
        available_models = []
        for model_id, config in self.model_configs.items():
            if not os.path.exists(config['path']):
                logger.warning(f"模型文件不存在: {config['path']}")
                continue
            available_models.append({
                'id': model_id,
                'name': config['name'],
                'description': config['description'],
                # Copy the containers so the returned data cannot alias
                # (and accidentally mutate) the shared registry.
                'classes': list(config['classes']),
                'labels': dict(config['labels']),
                'size': config['size'],
                'type': config['type'],
            })
        return available_models

    async def load_model(self, model_id: str) -> Optional[YOLO]:
        """Load a model into the cache, or return the already-cached instance.

        The ``YOLO`` constructor runs in the event loop's default executor, so
        deserializing a large checkpoint does not stall the event loop serving
        the other (FastAPI/WebSocket) handlers. Concurrent calls for the same
        model id are serialized by a per-model lock, so the weights are read
        only once.

        Args:
            model_id: Key into ``model_configs``.

        Returns:
            The loaded ``YOLO`` instance. Returns ``None`` if the id is
            unknown, the weights file is missing, or loading raised. Failures
            are logged rather than raised.
        """
        if model_id not in self.model_configs:
            logger.error(f"未知模型ID: {model_id}")
            return None

        if model_id in self.models:
            logger.info(f"模型已加载: {model_id}")
            return self.models[model_id]

        # No await between the lookup and the insert, so this is atomic with
        # respect to other coroutines on the same loop.
        lock = self._load_locks.get(model_id)
        if lock is None:
            lock = self._load_locks[model_id] = asyncio.Lock()

        async with lock:
            # Another coroutine may have finished loading while we waited.
            if model_id in self.models:
                logger.info(f"模型已加载: {model_id}")
                return self.models[model_id]

            model_path = self.model_configs[model_id]['path']
            if not os.path.exists(model_path):
                logger.error(f"模型文件不存在: {model_path}")
                return None

            logger.info(f"正在加载模型: {model_id} from {model_path}")
            loop = asyncio.get_running_loop()
            try:
                model = await loop.run_in_executor(None, YOLO, model_path)
            except Exception as e:
                # Best-effort contract: report failure as None. Use
                # logger.exception so the traceback is kept in the log.
                logger.exception(f"模型加载失败: {model_id}, 错误: {e}")
                return None

            self.models[model_id] = model
            logger.info(f"模型加载成功: {model_id}")
            return model

    def get_model(self, model_id: str) -> Optional[YOLO]:
        """Return the cached model for ``model_id``, or ``None`` if not loaded.

        This never triggers a load; use :meth:`load_model` for that.
        """
        return self.models.get(model_id)

    async def unload_model(self, model_id: str) -> bool:
        """Remove a loaded model from the cache.

        Memory is reclaimed only once no other references to the model remain
        (e.g. a detector still holding it).

        Returns:
            ``True`` if the model was cached and has been removed, otherwise
            ``False``.
        """
        if model_id in self.models:
            del self.models[model_id]
            logger.info(f"模型已卸载: {model_id}")
            return True
        return False