Initial commit: Video detection platform with YOLO models
Features: - Fire detection (YOLOv10) - Helmet detection (YOLOv8) - Crowd detection (YOLOv8) - Smoking detection (YOLOv8) - Loitering detection (YOLOv8) Tech Stack: - Frontend: Vue 3 + Vite + Element Plus - Backend: FastAPI + WebSocket - Monorepo: pnpm workspace + Turbo - Docker support included
This commit is contained in:
0
apps/server/api/__init__.py
Normal file
0
apps/server/api/__init__.py
Normal file
67
apps/server/api/detection.py
Normal file
67
apps/server/api/detection.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import base64
|
||||
import logging
|
||||
from fastapi import APIRouter, UploadFile, File, Form, Query
|
||||
from models.schemas import ImageDetectionResult
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@router.post("/detect/image", response_model=ImageDetectionResult)
|
||||
async def detect_image(
|
||||
file: UploadFile = File(...),
|
||||
model_id: str = Query("fire_detection"),
|
||||
confidence: float = Query(0.5),
|
||||
iou: float = Query(0.45)
|
||||
):
|
||||
from main import model_service
|
||||
from services.detection_service import DetectionService
|
||||
|
||||
detection_service = DetectionService(model_service)
|
||||
|
||||
try:
|
||||
contents = await file.read()
|
||||
nparr = np.frombuffer(contents, np.uint8)
|
||||
frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
|
||||
if frame is None:
|
||||
return ImageDetectionResult(
|
||||
success=False,
|
||||
message="无法读取图片",
|
||||
data={}
|
||||
)
|
||||
|
||||
result = await detection_service.detect_image(frame, model_id, confidence, iou)
|
||||
|
||||
if result['success']:
|
||||
annotated_frame = detection_service.draw_detections(frame, result['detections'])
|
||||
|
||||
import uuid
|
||||
result_filename = f"result_{uuid.uuid4().hex[:8]}.jpg"
|
||||
result_path = f"static/results/{result_filename}"
|
||||
cv2.imwrite(result_path, annotated_frame)
|
||||
|
||||
return ImageDetectionResult(
|
||||
success=True,
|
||||
message="检测完成",
|
||||
data={
|
||||
"detections": result['detections'],
|
||||
"image_url": f"/static/results/{result_filename}",
|
||||
"stats": result['stats']
|
||||
}
|
||||
)
|
||||
else:
|
||||
return ImageDetectionResult(
|
||||
success=False,
|
||||
message=result['message'],
|
||||
data={}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"图片检测失败: {e}")
|
||||
return ImageDetectionResult(
|
||||
success=False,
|
||||
message=f"检测失败: {str(e)}",
|
||||
data={}
|
||||
)
|
||||
28
apps/server/api/models.py
Normal file
28
apps/server/api/models.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from fastapi import APIRouter
|
||||
from models.schemas import ModelInfo
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/models", response_model=list[ModelInfo])
|
||||
async def get_models():
|
||||
from main import model_service
|
||||
models = model_service.get_available_models()
|
||||
return models
|
||||
|
||||
@router.get("/models/{model_id}", response_model=ModelInfo)
|
||||
async def get_model(model_id: str):
|
||||
from main import model_service
|
||||
models = model_service.get_available_models()
|
||||
for model in models:
|
||||
if model['id'] == model_id:
|
||||
return model
|
||||
return None
|
||||
|
||||
@router.post("/models/{model_id}/load")
|
||||
async def load_model(model_id: str):
|
||||
from main import model_service
|
||||
model = await model_service.load_model(model_id)
|
||||
if model:
|
||||
return {"success": True, "message": f"模型加载成功: {model_id}"}
|
||||
else:
|
||||
return {"success": False, "message": f"模型加载失败: {model_id}"}
|
||||
126
apps/server/main.py
Normal file
126
apps/server/main.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from fastapi import FastAPI, UploadFile, File, WebSocket, WebSocketDisconnect
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
import uvicorn
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import logging
|
||||
|
||||
from api import detection, models
|
||||
from services.model_service import ModelService
|
||||
from services.camera_service import CameraService
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
model_service = ModelService()
|
||||
camera_service = None
|
||||
|
||||
|
||||
def setup_signal_handlers():
|
||||
"""设置信号处理器,确保进程异常退出时能清理资源"""
|
||||
def signal_handler(signum, frame):
|
||||
sig_name = signal.Signals(signum).name
|
||||
logger.info(f"收到信号 {sig_name},正在清理资源...")
|
||||
|
||||
# 强制释放摄像头资源
|
||||
import subprocess
|
||||
try:
|
||||
# 查找并终止占用摄像头的Python进程(除了当前进程)
|
||||
current_pid = os.getpid()
|
||||
result = subprocess.run(
|
||||
['lsof', '+D', '/dev', '-a', '-c', 'python'],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
if result.returncode == 0:
|
||||
for line in result.stdout.split('\n'):
|
||||
if '/dev/video' in line:
|
||||
parts = line.split()
|
||||
if len(parts) >= 2:
|
||||
try:
|
||||
pid = int(parts[1])
|
||||
if pid != current_pid:
|
||||
logger.info(f"终止占用摄像头的进程: {pid}")
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
except (ValueError, ProcessLookupError):
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"清理摄像头资源失败: {e}")
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
# 注册信号处理器
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
if hasattr(signal, 'SIGQUIT'):
|
||||
signal.signal(signal.SIGQUIT, signal_handler)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global camera_service
|
||||
|
||||
camera_service = CameraService(model_service)
|
||||
yield
|
||||
|
||||
# 关闭时清理资源
|
||||
logger.info("正在关闭服务,清理资源...")
|
||||
if camera_service:
|
||||
await camera_service.stop()
|
||||
|
||||
app = FastAPI(
|
||||
title="视频模型检测平台",
|
||||
description="基于YOLO的实时视频检测平台",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
app.include_router(detection.router, prefix="/api")
|
||||
app.include_router(models.router, prefix="/api")
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "视频模型检测平台 API", "version": "1.0.0"}
|
||||
|
||||
@app.get("/api/health")
|
||||
async def health():
|
||||
return {"status": "healthy"}
|
||||
|
||||
@app.websocket("/ws/camera")
|
||||
async def camera_websocket_endpoint(websocket: WebSocket):
|
||||
await camera_service.handle_connection(websocket)
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.makedirs("static/uploads", exist_ok=True)
|
||||
os.makedirs("static/results", exist_ok=True)
|
||||
os.makedirs("static/temp", exist_ok=True)
|
||||
|
||||
# 设置信号处理器
|
||||
setup_signal_handlers()
|
||||
|
||||
# 检测是否处于uvicorn重载模式的子进程中
|
||||
is_reload_worker = os.environ.get('UVICORN_RELOAD') == 'true'
|
||||
|
||||
if is_reload_worker:
|
||||
logger.info("检测到uvicorn重载子进程,跳过摄像头预清理")
|
||||
|
||||
uvicorn.run(
|
||||
"main:app",
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
reload=True,
|
||||
reload_dirs=["./"],
|
||||
reload_includes=["*.py"]
|
||||
)
|
||||
0
apps/server/models/__init__.py
Normal file
0
apps/server/models/__init__.py
Normal file
39
apps/server/models/schemas.py
Normal file
39
apps/server/models/schemas.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Dict, Any
|
||||
from enum import Enum
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
classes: List[str]
|
||||
labels: Dict[str, str]
|
||||
size: str
|
||||
type: str
|
||||
|
||||
class Detection(BaseModel):
|
||||
class_name: str
|
||||
label: str
|
||||
confidence: float
|
||||
bbox: List[int]
|
||||
|
||||
class DetectionStats(BaseModel):
|
||||
total_detections: int
|
||||
avg_confidence: float
|
||||
processing_time: float
|
||||
model_used: str
|
||||
|
||||
class ImageDetectionResult(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
data: Dict[str, Any]
|
||||
|
||||
class VideoDetectionRequest(BaseModel):
|
||||
model_id: str
|
||||
confidence: float = Field(default=0.5, ge=0.1, le=1.0)
|
||||
iou: float = Field(default=0.45, ge=0.1, le=0.9)
|
||||
|
||||
class DetectionConfig(BaseModel):
|
||||
model_id: str
|
||||
confidence: float = Field(default=0.5, ge=0.1, le=1.0)
|
||||
iou: float = Field(default=0.45, ge=0.1, le=0.9)
|
||||
357
apps/server/models/smoking_yolo_adapter.py
Normal file
357
apps/server/models/smoking_yolo_adapter.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""
|
||||
YOLO 格式的抽烟检测模型适配器
|
||||
将 PaddleDetection 模型包装为 YOLO 接口
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import subprocess
|
||||
import tempfile
|
||||
import logging
|
||||
from typing import List, Dict, Union
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SmokingDetectionYOLO:
|
||||
"""
|
||||
模拟 YOLO 接口的抽烟检测模型
|
||||
底层使用 PaddleDetection Docker 容器
|
||||
"""
|
||||
|
||||
def __init__(self, model_path=None):
|
||||
"""
|
||||
初始化模型
|
||||
|
||||
Args:
|
||||
model_path: 模型路径(可选,仅用于兼容性)
|
||||
"""
|
||||
self.model_name = "smoking_detection"
|
||||
self.docker_image = "smoking-detection:test"
|
||||
self.model_dir = "output_inference/ppyoloe_crn_s_80e_smoking_visdrone"
|
||||
self.threshold = 0.1
|
||||
|
||||
# YOLO 兼容属性
|
||||
self.names = {0: 'cigarette'}
|
||||
self.model = self # 自我引用,保持与 YOLO 相同的接口
|
||||
|
||||
# 检查 Docker
|
||||
self._check_docker()
|
||||
|
||||
logger.info(f"抽烟检测模型初始化完成,Docker可用: {self.available}")
|
||||
|
||||
def _check_docker(self):
|
||||
"""检查 Docker 环境"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["docker", "info"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
self.available = result.returncode == 0
|
||||
|
||||
if self.available:
|
||||
# 检查镜像
|
||||
result = subprocess.run(
|
||||
["docker", "image", "inspect", self.docker_image],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
self.available = result.returncode == 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Docker 检查失败: {e}")
|
||||
self.available = False
|
||||
|
||||
def __call__(self, source, conf=0.1, iou=0.45, verbose=False, stream=False):
|
||||
"""
|
||||
模拟 YOLO 模型的调用接口
|
||||
|
||||
Args:
|
||||
source: 图片路径、OpenCV 图片、或图片列表
|
||||
conf: 置信度阈值
|
||||
iou: IoU 阈值(PaddleDetection 不支持,仅用于兼容)
|
||||
verbose: 是否输出详细信息
|
||||
stream: 是否流式输出(仅用于兼容)
|
||||
|
||||
Returns:
|
||||
YOLOResult 对象列表
|
||||
"""
|
||||
if not self.available:
|
||||
logger.error("Docker 不可用,无法运行检测")
|
||||
return [YOLOResult([])]
|
||||
|
||||
# 处理不同类型的输入
|
||||
if isinstance(source, str):
|
||||
# 图片路径
|
||||
image = cv2.imread(source)
|
||||
if image is None:
|
||||
logger.error(f"无法读取图片: {source}")
|
||||
return [YOLOResult([])]
|
||||
return self._detect_single(image, conf, verbose)
|
||||
|
||||
elif isinstance(source, np.ndarray):
|
||||
# OpenCV 图片
|
||||
return self._detect_single(source, conf, verbose)
|
||||
|
||||
elif isinstance(source, list):
|
||||
# 图片列表
|
||||
results = []
|
||||
for img in source:
|
||||
if isinstance(img, str):
|
||||
img = cv2.imread(img)
|
||||
if img is not None:
|
||||
results.extend(self._detect_single(img, conf, verbose))
|
||||
return results
|
||||
|
||||
else:
|
||||
logger.error(f"不支持的输入类型: {type(source)}")
|
||||
return [YOLOResult([])]
|
||||
|
||||
def _detect_single(self, image: np.ndarray, conf: float, verbose: bool) -> List['YOLOResult']:
|
||||
"""检测单张图片"""
|
||||
|
||||
try:
|
||||
# 创建临时文件
|
||||
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as f:
|
||||
temp_input = f.name
|
||||
|
||||
# 保存输入图片
|
||||
cv2.imwrite(temp_input, image)
|
||||
|
||||
if verbose:
|
||||
logger.info(f"正在检测: {temp_input}")
|
||||
|
||||
# 构建 Docker 命令
|
||||
cmd = [
|
||||
"docker", "run", "--rm",
|
||||
"-v", f"{temp_input}:/workspace/input.jpg",
|
||||
self.docker_image,
|
||||
"python", "deploy/python/infer.py",
|
||||
f"--model_dir={self.model_dir}",
|
||||
"--image_file=/workspace/input.jpg",
|
||||
"--device=CPU",
|
||||
"--output_dir=/workspace",
|
||||
f"--threshold={conf}"
|
||||
]
|
||||
|
||||
# 执行检测
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60
|
||||
)
|
||||
|
||||
if verbose:
|
||||
logger.info(f"检测完成,返回码: {result.returncode}")
|
||||
|
||||
# 解析结果
|
||||
detections = self._parse_output(result.stdout)
|
||||
|
||||
if verbose and detections:
|
||||
logger.info(f"检测到 {len(detections)} 个目标")
|
||||
|
||||
# 清理临时文件
|
||||
try:
|
||||
os.remove(temp_input)
|
||||
except:
|
||||
pass
|
||||
|
||||
return [YOLOResult(detections)]
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("检测超时")
|
||||
return [YOLOResult([])]
|
||||
except Exception as e:
|
||||
logger.error(f"检测失败: {e}")
|
||||
return [YOLOResult([])]
|
||||
|
||||
def _parse_output(self, output: str) -> List[Dict]:
|
||||
"""解析检测输出"""
|
||||
detections = []
|
||||
import re
|
||||
|
||||
# 使用正则表达式匹配检测行
|
||||
# 格式: class_id:0, confidence:0.8921, left_top:[268.66,231.64],right_bottom:[351.87,258.66]
|
||||
pattern = r'class_id:\d+,\s*confidence:([\d.]+),\s*left_top:\[([\d.]+),\s*([\d.]+)\],\s*right_bottom:\[([\d.]+),\s*([\d.]+)\]'
|
||||
|
||||
for line in output.split('\n'):
|
||||
match = re.search(pattern, line)
|
||||
if match:
|
||||
try:
|
||||
confidence = float(match.group(1))
|
||||
x1 = float(match.group(2))
|
||||
y1 = float(match.group(3))
|
||||
x2 = float(match.group(4))
|
||||
y2 = float(match.group(5))
|
||||
|
||||
detections.append({
|
||||
'bbox': [int(x1), int(y1), int(x2), int(y2)],
|
||||
'confidence': confidence,
|
||||
'class': 0,
|
||||
'name': 'cigarette'
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"解析检测结果失败: {e}, line: {line}")
|
||||
continue
|
||||
|
||||
return detections
|
||||
|
||||
def predict(self, source, **kwargs):
|
||||
"""兼容 predict 方法"""
|
||||
return self.__call__(source, **kwargs)
|
||||
|
||||
|
||||
class YOLOResult:
|
||||
"""
|
||||
模拟 YOLO 检测结果对象
|
||||
提供与 ultralytics YOLO 结果相同的接口
|
||||
"""
|
||||
|
||||
def __init__(self, detections: List[Dict]):
|
||||
self.detections = detections
|
||||
self.names = {0: 'cigarette'}
|
||||
|
||||
# 创建 boxes 对象
|
||||
self.boxes = Boxes(detections)
|
||||
|
||||
# 其他 YOLO 结果属性
|
||||
self.probs = None
|
||||
self.keypoints = None
|
||||
self.obb = None
|
||||
self.speed = {'preprocess': 0, 'inference': 0, 'postprocess': 0}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.detections)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""支持索引访问"""
|
||||
if idx < len(self.detections):
|
||||
return YOLOResult([self.detections[idx]])
|
||||
return YOLOResult([])
|
||||
|
||||
def plot(self, **kwargs):
|
||||
"""绘制检测结果(兼容方法)"""
|
||||
return None
|
||||
|
||||
|
||||
class Boxes:
|
||||
"""
|
||||
模拟 YOLO boxes 对象
|
||||
提供 xyxy, conf, cls 等属性
|
||||
"""
|
||||
|
||||
def __init__(self, detections: List[Dict]):
|
||||
self.detections = detections
|
||||
|
||||
# 尝试使用 torch,如果没有则使用 numpy
|
||||
try:
|
||||
import torch
|
||||
|
||||
if detections:
|
||||
xyxy_list = [[d['bbox'][0], d['bbox'][1], d['bbox'][2], d['bbox'][3]] for d in detections]
|
||||
conf_list = [[d['confidence']] for d in detections]
|
||||
cls_list = [[d['class']] for d in detections]
|
||||
|
||||
self.xyxy = torch.tensor(xyxy_list, dtype=torch.float32)
|
||||
self.conf = torch.tensor(conf_list, dtype=torch.float32)
|
||||
self.cls = torch.tensor(cls_list, dtype=torch.int64)
|
||||
self.id = None
|
||||
else:
|
||||
self.xyxy = torch.empty((0, 4))
|
||||
self.conf = torch.empty((0, 1))
|
||||
self.cls = torch.empty((0, 1), dtype=torch.int64)
|
||||
self.id = None
|
||||
|
||||
except ImportError:
|
||||
# 如果没有 torch,使用 numpy
|
||||
import numpy as np
|
||||
|
||||
if detections:
|
||||
xyxy_list = [[d['bbox'][0], d['bbox'][1], d['bbox'][2], d['bbox'][3]] for d in detections]
|
||||
conf_list = [[d['confidence']] for d in detections]
|
||||
cls_list = [[d['class']] for d in detections]
|
||||
|
||||
self.xyxy = np.array(xyxy_list, dtype=np.float32)
|
||||
self.conf = np.array(conf_list, dtype=np.float32)
|
||||
self.cls = np.array(cls_list, dtype=np.int64)
|
||||
self.id = None
|
||||
else:
|
||||
self.xyxy = np.empty((0, 4), dtype=np.float32)
|
||||
self.conf = np.empty((0, 1), dtype=np.float32)
|
||||
self.cls = np.empty((0, 1), dtype=np.int64)
|
||||
self.id = None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.detections)
|
||||
|
||||
def __iter__(self):
|
||||
"""使 Boxes 可迭代"""
|
||||
for i in range(len(self.detections)):
|
||||
yield Box(self, i)
|
||||
|
||||
def cpu(self):
|
||||
"""兼容方法"""
|
||||
return self
|
||||
|
||||
def numpy(self):
|
||||
"""转换为 numpy"""
|
||||
if hasattr(self.xyxy, 'numpy'):
|
||||
return type('Boxes', (), {
|
||||
'xyxy': self.xyxy.numpy(),
|
||||
'conf': self.conf.numpy(),
|
||||
'cls': self.cls.numpy(),
|
||||
'id': self.id
|
||||
})()
|
||||
return self
|
||||
|
||||
|
||||
class Box:
|
||||
"""
|
||||
模拟单个检测框对象
|
||||
"""
|
||||
|
||||
def __init__(self, boxes: Boxes, index: int):
|
||||
self._boxes = boxes
|
||||
self._index = index
|
||||
|
||||
@property
|
||||
def xyxy(self):
|
||||
"""返回 xyxy 坐标 (1, 4) 形状"""
|
||||
import torch
|
||||
import numpy as np
|
||||
coords = self._boxes.xyxy[self._index]
|
||||
if isinstance(coords, torch.Tensor):
|
||||
return coords.unsqueeze(0)
|
||||
else:
|
||||
return np.array([coords])
|
||||
|
||||
@property
|
||||
def conf(self):
|
||||
"""返回置信度 (1,) 形状 - 与 YOLO 兼容"""
|
||||
import torch
|
||||
import numpy as np
|
||||
conf_val = self._boxes.conf[self._index]
|
||||
# 返回 (1,) 形状,与 YOLO 一致
|
||||
if isinstance(conf_val, torch.Tensor):
|
||||
return conf_val.view(1)
|
||||
else:
|
||||
return np.array([conf_val])
|
||||
|
||||
@property
|
||||
def cls(self):
|
||||
"""返回类别 (1,) 形状 - 与 YOLO 兼容"""
|
||||
import torch
|
||||
import numpy as np
|
||||
cls_val = self._boxes.cls[self._index]
|
||||
# 返回 (1,) 形状,与 YOLO 一致
|
||||
if isinstance(cls_val, torch.Tensor):
|
||||
return cls_val.view(1)
|
||||
else:
|
||||
return np.array([cls_val])
|
||||
359
apps/server/models/smoking_yolo_adapter_fast.py
Normal file
359
apps/server/models/smoking_yolo_adapter_fast.py
Normal file
@@ -0,0 +1,359 @@
|
||||
"""
|
||||
YOLO 格式的抽烟检测模型适配器(快速版)
|
||||
使用 HTTP API 与常驻 Docker 容器通信
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import subprocess
|
||||
import tempfile
|
||||
import logging
|
||||
import time
|
||||
import json
|
||||
import requests
|
||||
from typing import List, Dict, Optional
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SmokingDetectionYOLO:
|
||||
"""
|
||||
模拟 YOLO 接口的抽烟检测模型(快速版)
|
||||
使用 HTTP API 与常驻 Docker 容器通信
|
||||
"""
|
||||
|
||||
_container_name = "smoking-detection-daemon"
|
||||
_initialized = False
|
||||
_server_url = "http://localhost:8080"
|
||||
|
||||
def __init__(self, model_path=None):
|
||||
self.model_name = "smoking_detection"
|
||||
self.docker_image = "smoking-detection:test"
|
||||
self.model_dir = "output_inference/ppyoloe_crn_s_80e_smoking_visdrone"
|
||||
self.threshold = 0.1
|
||||
|
||||
# YOLO 兼容属性
|
||||
self.names = {0: 'cigarette'}
|
||||
self.model = self
|
||||
|
||||
# 检查 Docker 并启动常驻服务器
|
||||
self._check_docker()
|
||||
if self.available:
|
||||
self._start_daemon()
|
||||
|
||||
logger.info(f"抽烟检测模型快速版初始化完成,Docker可用: {self.available}")
|
||||
|
||||
def _check_docker(self):
|
||||
"""检查 Docker 环境"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["docker", "info"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
self.available = result.returncode == 0
|
||||
|
||||
if self.available:
|
||||
result = subprocess.run(
|
||||
["docker", "image", "inspect", self.docker_image],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
self.available = result.returncode == 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Docker 检查失败: {e}")
|
||||
self.available = False
|
||||
|
||||
def _start_daemon(self):
|
||||
"""启动常驻服务器"""
|
||||
try:
|
||||
# 检查容器是否已在运行
|
||||
result = subprocess.run(
|
||||
["docker", "ps", "-q", "-f", f"name={self._container_name}"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
|
||||
if result.stdout.strip():
|
||||
logger.info(f"常驻服务器已在运行: {self._container_name}")
|
||||
SmokingDetectionYOLO._initialized = True
|
||||
return
|
||||
|
||||
# 检查容器是否存在但已停止
|
||||
result = subprocess.run(
|
||||
["docker", "ps", "-aq", "-f", f"name={self._container_name}"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
|
||||
if result.stdout.strip():
|
||||
# 删除旧容器
|
||||
logger.info("删除旧容器")
|
||||
subprocess.run(
|
||||
["docker", "rm", "-f", self._container_name],
|
||||
capture_output=True,
|
||||
timeout=10
|
||||
)
|
||||
|
||||
# 创建新容器并启动常驻服务器
|
||||
logger.info("启动常驻服务器...")
|
||||
subprocess.run(
|
||||
[
|
||||
"docker", "run", "-d",
|
||||
"--name", self._container_name,
|
||||
"-p", "8080:8080",
|
||||
"-v", "/tmp:/workspace/input",
|
||||
"-v", "/Users/wwh/project/video-model/PaddlePaddle/PaddleDetection-release-2.9/smoking_server_daemon.py:/workspace/PaddleDetection/smoking_server_daemon.py",
|
||||
"-w", "/workspace/PaddleDetection",
|
||||
self.docker_image,
|
||||
"python", "smoking_server_daemon.py"
|
||||
],
|
||||
capture_output=True,
|
||||
timeout=10
|
||||
)
|
||||
|
||||
# 等待服务器启动
|
||||
logger.info("等待服务器启动...")
|
||||
time.sleep(5)
|
||||
|
||||
SmokingDetectionYOLO._initialized = True
|
||||
logger.info("常驻服务器启动成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"启动常驻服务器失败: {e}")
|
||||
SmokingDetectionYOLO._initialized = False
|
||||
|
||||
def __call__(self, source, conf=0.1, iou=0.45, verbose=False, stream=False):
|
||||
"""模拟 YOLO 模型的调用接口"""
|
||||
if not self.available:
|
||||
logger.error("Docker 不可用,无法运行检测")
|
||||
return [YOLOResult([])]
|
||||
|
||||
# 处理不同类型的输入
|
||||
if isinstance(source, str):
|
||||
image = cv2.imread(source)
|
||||
if image is None:
|
||||
logger.error(f"无法读取图片: {source}")
|
||||
return [YOLOResult([])]
|
||||
return self._detect_single(image, conf, verbose)
|
||||
|
||||
elif isinstance(source, np.ndarray):
|
||||
return self._detect_single(source, conf, verbose)
|
||||
|
||||
elif isinstance(source, list):
|
||||
results = []
|
||||
for img in source:
|
||||
if isinstance(img, str):
|
||||
img = cv2.imread(img)
|
||||
if img is not None:
|
||||
results.extend(self._detect_single(img, conf, verbose))
|
||||
return results
|
||||
|
||||
else:
|
||||
logger.error(f"不支持的输入类型: {type(source)}")
|
||||
return [YOLOResult([])]
|
||||
|
||||
def _detect_single(self, image: np.ndarray, conf: float, verbose: bool) -> List['YOLOResult']:
|
||||
"""检测单张图片(使用 HTTP API)"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 创建临时文件
|
||||
input_filename = f"smoking_fast_{int(time.time()*1000)}.jpg"
|
||||
temp_input = f"/tmp/{input_filename}"
|
||||
|
||||
# 保存输入图片
|
||||
cv2.imwrite(temp_input, image)
|
||||
|
||||
if verbose:
|
||||
logger.info(f"正在检测: {temp_input}")
|
||||
|
||||
# 发送 HTTP 请求
|
||||
request = {
|
||||
'image_path': f'/workspace/input/{input_filename}',
|
||||
'threshold': conf
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{self._server_url}/detect",
|
||||
json=request,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
if verbose:
|
||||
logger.info(f"检测完成,耗时: {elapsed:.2f}秒")
|
||||
|
||||
# 解析结果
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if data.get('success'):
|
||||
detections = data.get('detections', [])
|
||||
else:
|
||||
logger.error(f"检测失败: {data.get('error')}")
|
||||
detections = []
|
||||
else:
|
||||
logger.error(f"HTTP 错误: {response.status_code}")
|
||||
detections = []
|
||||
|
||||
# 清理临时文件
|
||||
try:
|
||||
os.remove(temp_input)
|
||||
except:
|
||||
pass
|
||||
|
||||
return [YOLOResult(detections)]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检测失败: {e}")
|
||||
return [YOLOResult([])]
|
||||
|
||||
def predict(self, source, **kwargs):
|
||||
"""兼容 predict 方法"""
|
||||
return self.__call__(source, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def stop_daemon(cls):
|
||||
"""停止常驻服务器"""
|
||||
try:
|
||||
subprocess.run(
|
||||
["docker", "stop", cls._container_name],
|
||||
capture_output=True,
|
||||
timeout=10
|
||||
)
|
||||
logger.info("常驻服务器已停止")
|
||||
except Exception as e:
|
||||
logger.error(f"停止常驻服务器失败: {e}")
|
||||
|
||||
|
||||
# YOLOResult, Boxes, Box 类(与之前相同)
|
||||
class YOLOResult:
|
||||
"""模拟 YOLO 检测结果对象"""
|
||||
|
||||
def __init__(self, detections: List[Dict]):
|
||||
self.detections = detections
|
||||
self.names = {0: 'cigarette'}
|
||||
self.boxes = Boxes(detections)
|
||||
self.probs = None
|
||||
self.keypoints = None
|
||||
self.obb = None
|
||||
self.speed = {'preprocess': 0, 'inference': 0, 'postprocess': 0}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.detections)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if idx < len(self.detections):
|
||||
return YOLOResult([self.detections[idx]])
|
||||
return YOLOResult([])
|
||||
|
||||
def plot(self, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
class Boxes:
|
||||
"""模拟 YOLO boxes 对象"""
|
||||
|
||||
def __init__(self, detections: List[Dict]):
|
||||
self.detections = detections
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
if detections:
|
||||
xyxy_list = [[d['bbox'][0], d['bbox'][1], d['bbox'][2], d['bbox'][3]] for d in detections]
|
||||
conf_list = [[d['confidence']] for d in detections]
|
||||
cls_list = [[d['class']] for d in detections]
|
||||
|
||||
self.xyxy = torch.tensor(xyxy_list, dtype=torch.float32)
|
||||
self.conf = torch.tensor(conf_list, dtype=torch.float32)
|
||||
self.cls = torch.tensor(cls_list, dtype=torch.int64)
|
||||
self.id = None
|
||||
else:
|
||||
self.xyxy = torch.empty((0, 4))
|
||||
self.conf = torch.empty((0, 1))
|
||||
self.cls = torch.empty((0, 1), dtype=torch.int64)
|
||||
self.id = None
|
||||
|
||||
except ImportError:
|
||||
import numpy as np
|
||||
|
||||
if detections:
|
||||
xyxy_list = [[d['bbox'][0], d['bbox'][1], d['bbox'][2], d['bbox'][3]] for d in detections]
|
||||
conf_list = [[d['confidence']] for d in detections]
|
||||
cls_list = [[d['class']] for d in detections]
|
||||
|
||||
self.xyxy = np.array(xyxy_list, dtype=np.float32)
|
||||
self.conf = np.array(conf_list, dtype=np.float32)
|
||||
self.cls = np.array(cls_list, dtype=np.int64)
|
||||
self.id = None
|
||||
else:
|
||||
self.xyxy = np.empty((0, 4), dtype=np.float32)
|
||||
self.conf = np.empty((0, 1), dtype=np.float32)
|
||||
self.cls = np.empty((0, 1), dtype=np.int64)
|
||||
self.id = None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.detections)
|
||||
|
||||
def __iter__(self):
|
||||
for i in range(len(self.detections)):
|
||||
yield Box(self, i)
|
||||
|
||||
def cpu(self):
|
||||
return self
|
||||
|
||||
def numpy(self):
|
||||
if hasattr(self.xyxy, 'numpy'):
|
||||
return type('Boxes', (), {
|
||||
'xyxy': self.xyxy.numpy(),
|
||||
'conf': self.conf.numpy(),
|
||||
'cls': self.cls.numpy(),
|
||||
'id': self.id
|
||||
})()
|
||||
return self
|
||||
|
||||
|
||||
class Box:
|
||||
"""模拟单个检测框对象"""
|
||||
|
||||
def __init__(self, boxes: Boxes, index: int):
|
||||
self._boxes = boxes
|
||||
self._index = index
|
||||
|
||||
@property
|
||||
def xyxy(self):
|
||||
import torch
|
||||
import numpy as np
|
||||
coords = self._boxes.xyxy[self._index]
|
||||
if isinstance(coords, torch.Tensor):
|
||||
return coords.unsqueeze(0)
|
||||
else:
|
||||
return np.array([coords])
|
||||
|
||||
@property
|
||||
def conf(self):
|
||||
import torch
|
||||
import numpy as np
|
||||
conf_val = self._boxes.conf[self._index]
|
||||
if isinstance(conf_val, torch.Tensor):
|
||||
return conf_val.view(1)
|
||||
else:
|
||||
return np.array([conf_val])
|
||||
|
||||
@property
|
||||
def cls(self):
|
||||
import torch
|
||||
import numpy as np
|
||||
cls_val = self._boxes.cls[self._index]
|
||||
if isinstance(cls_val, torch.Tensor):
|
||||
return cls_val.view(1)
|
||||
else:
|
||||
return np.array([cls_val])
|
||||
399
apps/server/models/smoking_yolo_adapter_optimized.py
Normal file
399
apps/server/models/smoking_yolo_adapter_optimized.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""
|
||||
YOLO 格式的抽烟检测模型适配器(优化版)
|
||||
使用后台 Docker 容器,避免重复启动
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import subprocess
|
||||
import tempfile
|
||||
import logging
|
||||
import time
|
||||
from typing import List, Dict, Optional
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SmokingDetectionYOLO:
|
||||
"""
|
||||
模拟 YOLO 接口的抽烟检测模型(优化版)
|
||||
使用后台 Docker 容器,避免每次检测都启动新容器
|
||||
"""
|
||||
|
||||
_container_name = "smoking-detection-server"
|
||||
_container_started = False
|
||||
|
||||
def __init__(self, model_path=None):
|
||||
self.model_name = "smoking_detection"
|
||||
self.docker_image = "smoking-detection:test"
|
||||
self.model_dir = "output_inference/ppyoloe_crn_s_80e_smoking_visdrone"
|
||||
self.threshold = 0.1
|
||||
|
||||
# YOLO 兼容属性
|
||||
self.names = {0: 'cigarette'}
|
||||
self.model = self
|
||||
|
||||
# 检查 Docker 并启动后台容器
|
||||
self._check_docker()
|
||||
if self.available:
|
||||
self._start_background_container()
|
||||
|
||||
logger.info(f"抽烟检测模型初始化完成,Docker可用: {self.available}")
|
||||
|
||||
def _check_docker(self):
|
||||
"""检查 Docker 环境"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["docker", "info"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
self.available = result.returncode == 0
|
||||
|
||||
if self.available:
|
||||
result = subprocess.run(
|
||||
["docker", "image", "inspect", self.docker_image],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
self.available = result.returncode == 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Docker 检查失败: {e}")
|
||||
self.available = False
|
||||
|
||||
def _start_background_container(self):
|
||||
"""启动后台容器"""
|
||||
try:
|
||||
# 检查容器是否已在运行
|
||||
result = subprocess.run(
|
||||
["docker", "ps", "-q", "-f", f"name={self._container_name}"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
|
||||
if result.stdout.strip():
|
||||
logger.info(f"后台容器已在运行: {self._container_name}")
|
||||
SmokingDetectionYOLO._container_started = True
|
||||
return
|
||||
|
||||
# 检查容器是否存在但已停止
|
||||
result = subprocess.run(
|
||||
["docker", "ps", "-aq", "-f", f"name={self._container_name}"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
|
||||
if result.stdout.strip():
|
||||
# 启动已存在的容器
|
||||
logger.info(f"启动已存在的容器: {self._container_name}")
|
||||
subprocess.run(
|
||||
["docker", "start", self._container_name],
|
||||
capture_output=True,
|
||||
timeout=10
|
||||
)
|
||||
else:
|
||||
# 创建新容器
|
||||
logger.info(f"创建后台容器: {self._container_name}")
|
||||
subprocess.run(
|
||||
[
|
||||
"docker", "run", "-d",
|
||||
"--name", self._container_name,
|
||||
"-v", "/tmp:/workspace/input",
|
||||
"-v", "/tmp:/workspace/output",
|
||||
self.docker_image,
|
||||
"tail", "-f", "/dev/null"
|
||||
],
|
||||
capture_output=True,
|
||||
timeout=10
|
||||
)
|
||||
|
||||
SmokingDetectionYOLO._container_started = True
|
||||
logger.info("后台容器启动成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"启动后台容器失败: {e}")
|
||||
SmokingDetectionYOLO._container_started = False
|
||||
|
||||
def __call__(self, source, conf=0.1, iou=0.45, verbose=False, stream=False):
|
||||
"""模拟 YOLO 模型的调用接口"""
|
||||
if not self.available:
|
||||
logger.error("Docker 不可用,无法运行检测")
|
||||
return [YOLOResult([])]
|
||||
|
||||
# 处理不同类型的输入
|
||||
if isinstance(source, str):
|
||||
image = cv2.imread(source)
|
||||
if image is None:
|
||||
logger.error(f"无法读取图片: {source}")
|
||||
return [YOLOResult([])]
|
||||
return self._detect_single(image, conf, verbose)
|
||||
|
||||
elif isinstance(source, np.ndarray):
|
||||
return self._detect_single(source, conf, verbose)
|
||||
|
||||
elif isinstance(source, list):
|
||||
results = []
|
||||
for img in source:
|
||||
if isinstance(img, str):
|
||||
img = cv2.imread(img)
|
||||
if img is not None:
|
||||
results.extend(self._detect_single(img, conf, verbose))
|
||||
return results
|
||||
|
||||
else:
|
||||
logger.error(f"不支持的输入类型: {type(source)}")
|
||||
return [YOLOResult([])]
|
||||
|
||||
def _detect_single(self, image: np.ndarray, conf: float, verbose: bool) -> List['YOLOResult']:
|
||||
"""检测单张图片(使用后台容器)"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 创建临时文件
|
||||
input_filename = f"smoking_input_{int(time.time()*1000)}.jpg"
|
||||
output_filename = f"smoking_output_{int(time.time()*1000)}.jpg"
|
||||
temp_input = f"/tmp/{input_filename}"
|
||||
temp_output = f"/tmp/{output_filename}"
|
||||
|
||||
# 保存输入图片
|
||||
cv2.imwrite(temp_input, image)
|
||||
|
||||
if verbose:
|
||||
logger.info(f"正在检测: {temp_input}")
|
||||
|
||||
# 使用后台容器执行检测
|
||||
if SmokingDetectionYOLO._container_started:
|
||||
# 使用 exec 在运行中的容器内执行
|
||||
cmd = [
|
||||
"docker", "exec",
|
||||
self._container_name,
|
||||
"python", "deploy/python/infer.py",
|
||||
f"--model_dir={self.model_dir}",
|
||||
f"--image_file=/workspace/input/{input_filename}",
|
||||
"--device=CPU",
|
||||
f"--output_dir=/workspace/output",
|
||||
f"--threshold={conf}"
|
||||
]
|
||||
else:
|
||||
# 回退到原来的方式
|
||||
cmd = [
|
||||
"docker", "run", "--rm",
|
||||
"-v", f"{temp_input}:/workspace/input.jpg",
|
||||
self.docker_image,
|
||||
"python", "deploy/python/infer.py",
|
||||
f"--model_dir={self.model_dir}",
|
||||
"--image_file=/workspace/input.jpg",
|
||||
"--device=CPU",
|
||||
"--output_dir=/workspace",
|
||||
f"--threshold={conf}"
|
||||
]
|
||||
|
||||
# 执行检测
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60
|
||||
)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
if verbose:
|
||||
logger.info(f"检测完成,耗时: {elapsed:.2f}秒")
|
||||
|
||||
# 解析结果
|
||||
detections = self._parse_output(result.stdout)
|
||||
|
||||
# 清理临时文件
|
||||
try:
|
||||
os.remove(temp_input)
|
||||
except:
|
||||
pass
|
||||
|
||||
return [YOLOResult(detections)]
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("检测超时")
|
||||
return [YOLOResult([])]
|
||||
except Exception as e:
|
||||
logger.error(f"检测失败: {e}")
|
||||
return [YOLOResult([])]
|
||||
|
||||
def _parse_output(self, output: str) -> List[Dict]:
|
||||
"""解析检测输出"""
|
||||
detections = []
|
||||
import re
|
||||
|
||||
pattern = r'class_id:\d+,\s*confidence:([\d.]+),\s*left_top:\[([\d.]+),\s*([\d.]+)\],\s*right_bottom:\[([\d.]+),\s*([\d.]+)\]'
|
||||
|
||||
for line in output.split('\n'):
|
||||
match = re.search(pattern, line)
|
||||
if match:
|
||||
try:
|
||||
confidence = float(match.group(1))
|
||||
x1 = float(match.group(2))
|
||||
y1 = float(match.group(3))
|
||||
x2 = float(match.group(4))
|
||||
y2 = float(match.group(5))
|
||||
|
||||
detections.append({
|
||||
'bbox': [int(x1), int(y1), int(x2), int(y2)],
|
||||
'confidence': confidence,
|
||||
'class': 0,
|
||||
'name': 'cigarette'
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"解析检测结果失败: {e}, line: {line}")
|
||||
continue
|
||||
|
||||
return detections
|
||||
|
||||
def predict(self, source, **kwargs):
|
||||
"""兼容 predict 方法"""
|
||||
return self.__call__(source, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def stop_background_container(cls):
|
||||
"""停止后台容器"""
|
||||
try:
|
||||
subprocess.run(
|
||||
["docker", "stop", cls._container_name],
|
||||
capture_output=True,
|
||||
timeout=10
|
||||
)
|
||||
logger.info("后台容器已停止")
|
||||
except Exception as e:
|
||||
logger.error(f"停止后台容器失败: {e}")
|
||||
|
||||
|
||||
# 复制 YOLOResult, Boxes, Box 类(与原版相同)
|
||||
class YOLOResult:
|
||||
"""模拟 YOLO 检测结果对象"""
|
||||
|
||||
def __init__(self, detections: List[Dict]):
|
||||
self.detections = detections
|
||||
self.names = {0: 'cigarette'}
|
||||
self.boxes = Boxes(detections)
|
||||
self.probs = None
|
||||
self.keypoints = None
|
||||
self.obb = None
|
||||
self.speed = {'preprocess': 0, 'inference': 0, 'postprocess': 0}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.detections)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if idx < len(self.detections):
|
||||
return YOLOResult([self.detections[idx]])
|
||||
return YOLOResult([])
|
||||
|
||||
def plot(self, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
class Boxes:
|
||||
"""模拟 YOLO boxes 对象"""
|
||||
|
||||
def __init__(self, detections: List[Dict]):
|
||||
self.detections = detections
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
if detections:
|
||||
xyxy_list = [[d['bbox'][0], d['bbox'][1], d['bbox'][2], d['bbox'][3]] for d in detections]
|
||||
conf_list = [[d['confidence']] for d in detections]
|
||||
cls_list = [[d['class']] for d in detections]
|
||||
|
||||
self.xyxy = torch.tensor(xyxy_list, dtype=torch.float32)
|
||||
self.conf = torch.tensor(conf_list, dtype=torch.float32)
|
||||
self.cls = torch.tensor(cls_list, dtype=torch.int64)
|
||||
self.id = None
|
||||
else:
|
||||
self.xyxy = torch.empty((0, 4))
|
||||
self.conf = torch.empty((0, 1))
|
||||
self.cls = torch.empty((0, 1), dtype=torch.int64)
|
||||
self.id = None
|
||||
|
||||
except ImportError:
|
||||
import numpy as np
|
||||
|
||||
if detections:
|
||||
xyxy_list = [[d['bbox'][0], d['bbox'][1], d['bbox'][2], d['bbox'][3]] for d in detections]
|
||||
conf_list = [[d['confidence']] for d in detections]
|
||||
cls_list = [[d['class']] for d in detections]
|
||||
|
||||
self.xyxy = np.array(xyxy_list, dtype=np.float32)
|
||||
self.conf = np.array(conf_list, dtype=np.float32)
|
||||
self.cls = np.array(cls_list, dtype=np.int64)
|
||||
self.id = None
|
||||
else:
|
||||
self.xyxy = np.empty((0, 4), dtype=np.float32)
|
||||
self.conf = np.empty((0, 1), dtype=np.float32)
|
||||
self.cls = np.empty((0, 1), dtype=np.int64)
|
||||
self.id = None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.detections)
|
||||
|
||||
def __iter__(self):
|
||||
for i in range(len(self.detections)):
|
||||
yield Box(self, i)
|
||||
|
||||
def cpu(self):
|
||||
return self
|
||||
|
||||
def numpy(self):
|
||||
if hasattr(self.xyxy, 'numpy'):
|
||||
return type('Boxes', (), {
|
||||
'xyxy': self.xyxy.numpy(),
|
||||
'conf': self.conf.numpy(),
|
||||
'cls': self.cls.numpy(),
|
||||
'id': self.id
|
||||
})()
|
||||
return self
|
||||
|
||||
|
||||
class Box:
|
||||
"""模拟单个检测框对象"""
|
||||
|
||||
def __init__(self, boxes: Boxes, index: int):
|
||||
self._boxes = boxes
|
||||
self._index = index
|
||||
|
||||
@property
|
||||
def xyxy(self):
|
||||
import torch
|
||||
import numpy as np
|
||||
coords = self._boxes.xyxy[self._index]
|
||||
if isinstance(coords, torch.Tensor):
|
||||
return coords.unsqueeze(0)
|
||||
else:
|
||||
return np.array([coords])
|
||||
|
||||
@property
|
||||
def conf(self):
|
||||
import torch
|
||||
import numpy as np
|
||||
conf_val = self._boxes.conf[self._index]
|
||||
if isinstance(conf_val, torch.Tensor):
|
||||
return conf_val.view(1)
|
||||
else:
|
||||
return np.array([conf_val])
|
||||
|
||||
@property
|
||||
def cls(self):
|
||||
import torch
|
||||
import numpy as np
|
||||
cls_val = self._boxes.cls[self._index]
|
||||
if isinstance(cls_val, torch.Tensor):
|
||||
return cls_val.view(1)
|
||||
else:
|
||||
return np.array([cls_val])
|
||||
379
apps/server/models/smoking_yolo_adapter_v2.py
Normal file
379
apps/server/models/smoking_yolo_adapter_v2.py
Normal file
@@ -0,0 +1,379 @@
|
||||
"""
|
||||
YOLO 格式的抽烟检测模型适配器(V2 - 常驻进程版)
|
||||
使用 Docker 容器内的常驻 Python 进程,避免每次检测都启动新进程
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import subprocess
|
||||
import tempfile
|
||||
import logging
|
||||
import time
|
||||
import json
|
||||
from typing import List, Dict, Optional
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SmokingDetectionYOLO:
|
||||
"""
|
||||
模拟 YOLO 接口的抽烟检测模型(V2 - 常驻进程版)
|
||||
使用 Docker 容器内的常驻 Python 进程
|
||||
"""
|
||||
|
||||
_container_name = "smoking-detection-v2"
|
||||
_process = None
|
||||
_initialized = False
|
||||
|
||||
def __init__(self, model_path=None):
|
||||
self.model_name = "smoking_detection"
|
||||
self.docker_image = "smoking-detection:test"
|
||||
self.model_dir = "output_inference/ppyoloe_crn_s_80e_smoking_visdrone"
|
||||
self.threshold = 0.1
|
||||
|
||||
# YOLO 兼容属性
|
||||
self.names = {0: 'cigarette'}
|
||||
self.model = self
|
||||
|
||||
# 检查 Docker 并启动常驻进程
|
||||
self._check_docker()
|
||||
if self.available:
|
||||
self._start_server()
|
||||
|
||||
logger.info(f"抽烟检测模型 V2 初始化完成,Docker可用: {self.available}")
|
||||
|
||||
def _check_docker(self):
|
||||
"""检查 Docker 环境"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["docker", "info"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
self.available = result.returncode == 0
|
||||
|
||||
if self.available:
|
||||
result = subprocess.run(
|
||||
["docker", "image", "inspect", self.docker_image],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
self.available = result.returncode == 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Docker 检查失败: {e}")
|
||||
self.available = False
|
||||
|
||||
def _start_server(self):
|
||||
"""启动常驻服务器进程"""
|
||||
try:
|
||||
# 检查是否已有进程在运行
|
||||
if SmokingDetectionYOLO._process is not None:
|
||||
logger.info("常驻进程已在运行")
|
||||
return
|
||||
|
||||
# 检查容器是否已存在
|
||||
result = subprocess.run(
|
||||
["docker", "ps", "-aq", "-f", f"name={self._container_name}"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
|
||||
if result.stdout.strip():
|
||||
# 删除旧容器
|
||||
logger.info("删除旧容器")
|
||||
subprocess.run(
|
||||
["docker", "rm", "-f", self._container_name],
|
||||
capture_output=True,
|
||||
timeout=10
|
||||
)
|
||||
|
||||
# 启动新容器并运行服务器
|
||||
logger.info("启动常驻服务器...")
|
||||
|
||||
# 获取 smoking_server.py 的绝对路径
|
||||
server_script_path = "/Users/wwh/project/video-model/PaddlePaddle/PaddleDetection-release-2.9/smoking_server.py"
|
||||
|
||||
# 使用 Popen 保持进程运行,挂载 server 脚本
|
||||
SmokingDetectionYOLO._process = subprocess.Popen(
|
||||
[
|
||||
"docker", "run", "-i", "--rm",
|
||||
"--name", self._container_name,
|
||||
"-v", "/tmp:/workspace/input",
|
||||
"-v", f"{server_script_path}:/workspace/PaddleDetection/smoking_server.py",
|
||||
"-w", "/workspace/PaddleDetection",
|
||||
self.docker_image,
|
||||
"python", "smoking_server.py"
|
||||
],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
bufsize=1
|
||||
)
|
||||
|
||||
# 等待服务器启动(读取模型加载完成的消息)
|
||||
logger.info("等待服务器启动...")
|
||||
start_wait = time.time()
|
||||
while time.time() - start_wait < 30: # 最多等待30秒
|
||||
if SmokingDetectionYOLO._process.poll() is not None:
|
||||
# 进程已退出
|
||||
stderr = SmokingDetectionYOLO._process.stderr.read()
|
||||
logger.error(f"服务器启动失败: {stderr}")
|
||||
SmokingDetectionYOLO._process = None
|
||||
return
|
||||
|
||||
# 尝试读取 stderr 看是否加载完成
|
||||
import select
|
||||
if SmokingDetectionYOLO._process.stderr:
|
||||
ready, _, _ = select.select([SmokingDetectionYOLO._process.stderr], [], [], 0.5)
|
||||
if ready:
|
||||
line = SmokingDetectionYOLO._process.stderr.readline()
|
||||
if line:
|
||||
logger.info(f"Server: {line.strip()}")
|
||||
if "模型加载完成" in line:
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
# 检查进程是否还在运行
|
||||
if SmokingDetectionYOLO._process.poll() is None:
|
||||
SmokingDetectionYOLO._initialized = True
|
||||
logger.info("常驻服务器启动成功")
|
||||
else:
|
||||
stderr = SmokingDetectionYOLO._process.stderr.read()
|
||||
logger.error(f"服务器启动失败: {stderr}")
|
||||
SmokingDetectionYOLO._process = None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"启动常驻服务器失败: {e}")
|
||||
SmokingDetectionYOLO._process = None
|
||||
|
||||
def __call__(self, source, conf=0.1, iou=0.45, verbose=False, stream=False):
|
||||
"""模拟 YOLO 模型的调用接口"""
|
||||
if not self.available:
|
||||
logger.error("Docker 不可用,无法运行检测")
|
||||
return [YOLOResult([])]
|
||||
|
||||
if not SmokingDetectionYOLO._initialized:
|
||||
logger.error("常驻服务器未初始化")
|
||||
return [YOLOResult([])]
|
||||
|
||||
# 处理不同类型的输入
|
||||
if isinstance(source, str):
|
||||
image = cv2.imread(source)
|
||||
if image is None:
|
||||
logger.error(f"无法读取图片: {source}")
|
||||
return [YOLOResult([])]
|
||||
return self._detect_single(image, conf, verbose)
|
||||
|
||||
elif isinstance(source, np.ndarray):
|
||||
return self._detect_single(source, conf, verbose)
|
||||
|
||||
elif isinstance(source, list):
|
||||
results = []
|
||||
for img in source:
|
||||
if isinstance(img, str):
|
||||
img = cv2.imread(img)
|
||||
if img is not None:
|
||||
results.extend(self._detect_single(img, conf, verbose))
|
||||
return results
|
||||
|
||||
else:
|
||||
logger.error(f"不支持的输入类型: {type(source)}")
|
||||
return [YOLOResult([])]
|
||||
|
||||
def _detect_single(self, image: np.ndarray, conf: float, verbose: bool) -> List['YOLOResult']:
|
||||
"""检测单张图片(使用常驻进程)"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 创建临时文件
|
||||
input_filename = f"smoking_v2_{int(time.time()*1000)}.jpg"
|
||||
temp_input = f"/tmp/{input_filename}"
|
||||
|
||||
# 保存输入图片
|
||||
cv2.imwrite(temp_input, image)
|
||||
|
||||
if verbose:
|
||||
logger.info(f"正在检测: {temp_input}")
|
||||
|
||||
# 发送请求到常驻进程
|
||||
request = {
|
||||
'image_path': f'/workspace/input/{input_filename}',
|
||||
'threshold': conf
|
||||
}
|
||||
|
||||
SmokingDetectionYOLO._process.stdin.write(json.dumps(request) + '\n')
|
||||
SmokingDetectionYOLO._process.stdin.flush()
|
||||
|
||||
# 读取响应
|
||||
response_line = SmokingDetectionYOLO._process.stdout.readline()
|
||||
response = json.loads(response_line)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
if verbose:
|
||||
logger.info(f"检测完成,耗时: {elapsed:.2f}秒")
|
||||
|
||||
# 解析结果
|
||||
if response.get('success'):
|
||||
detections = response.get('detections', [])
|
||||
else:
|
||||
logger.error(f"检测失败: {response.get('error')}")
|
||||
detections = []
|
||||
|
||||
# 清理临时文件
|
||||
try:
|
||||
os.remove(temp_input)
|
||||
except:
|
||||
pass
|
||||
|
||||
return [YOLOResult(detections)]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检测失败: {e}")
|
||||
return [YOLOResult([])]
|
||||
|
||||
def predict(self, source, **kwargs):
|
||||
"""兼容 predict 方法"""
|
||||
return self.__call__(source, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def stop_server(cls):
|
||||
"""停止常驻服务器"""
|
||||
if cls._process is not None:
|
||||
cls._process.terminate()
|
||||
cls._process.wait()
|
||||
cls._process = None
|
||||
cls._initialized = False
|
||||
logger.info("常驻服务器已停止")
|
||||
|
||||
|
||||
# YOLOResult, Boxes, Box 类(与之前相同)
|
||||
class YOLOResult:
|
||||
"""模拟 YOLO 检测结果对象"""
|
||||
|
||||
def __init__(self, detections: List[Dict]):
|
||||
self.detections = detections
|
||||
self.names = {0: 'cigarette'}
|
||||
self.boxes = Boxes(detections)
|
||||
self.probs = None
|
||||
self.keypoints = None
|
||||
self.obb = None
|
||||
self.speed = {'preprocess': 0, 'inference': 0, 'postprocess': 0}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.detections)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if idx < len(self.detections):
|
||||
return YOLOResult([self.detections[idx]])
|
||||
return YOLOResult([])
|
||||
|
||||
def plot(self, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
class Boxes:
|
||||
"""模拟 YOLO boxes 对象"""
|
||||
|
||||
def __init__(self, detections: List[Dict]):
|
||||
self.detections = detections
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
if detections:
|
||||
xyxy_list = [[d['bbox'][0], d['bbox'][1], d['bbox'][2], d['bbox'][3]] for d in detections]
|
||||
conf_list = [[d['confidence']] for d in detections]
|
||||
cls_list = [[d['class']] for d in detections]
|
||||
|
||||
self.xyxy = torch.tensor(xyxy_list, dtype=torch.float32)
|
||||
self.conf = torch.tensor(conf_list, dtype=torch.float32)
|
||||
self.cls = torch.tensor(cls_list, dtype=torch.int64)
|
||||
self.id = None
|
||||
else:
|
||||
self.xyxy = torch.empty((0, 4))
|
||||
self.conf = torch.empty((0, 1))
|
||||
self.cls = torch.empty((0, 1), dtype=torch.int64)
|
||||
self.id = None
|
||||
|
||||
except ImportError:
|
||||
import numpy as np
|
||||
|
||||
if detections:
|
||||
xyxy_list = [[d['bbox'][0], d['bbox'][1], d['bbox'][2], d['bbox'][3]] for d in detections]
|
||||
conf_list = [[d['confidence']] for d in detections]
|
||||
cls_list = [[d['class']] for d in detections]
|
||||
|
||||
self.xyxy = np.array(xyxy_list, dtype=np.float32)
|
||||
self.conf = np.array(conf_list, dtype=np.float32)
|
||||
self.cls = np.array(cls_list, dtype=np.int64)
|
||||
self.id = None
|
||||
else:
|
||||
self.xyxy = np.empty((0, 4), dtype=np.float32)
|
||||
self.conf = np.empty((0, 1), dtype=np.float32)
|
||||
self.cls = np.empty((0, 1), dtype=np.int64)
|
||||
self.id = None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.detections)
|
||||
|
||||
def __iter__(self):
|
||||
for i in range(len(self.detections)):
|
||||
yield Box(self, i)
|
||||
|
||||
def cpu(self):
|
||||
return self
|
||||
|
||||
def numpy(self):
|
||||
if hasattr(self.xyxy, 'numpy'):
|
||||
return type('Boxes', (), {
|
||||
'xyxy': self.xyxy.numpy(),
|
||||
'conf': self.conf.numpy(),
|
||||
'cls': self.cls.numpy(),
|
||||
'id': self.id
|
||||
})()
|
||||
return self
|
||||
|
||||
|
||||
class Box:
|
||||
"""模拟单个检测框对象"""
|
||||
|
||||
def __init__(self, boxes: Boxes, index: int):
|
||||
self._boxes = boxes
|
||||
self._index = index
|
||||
|
||||
@property
|
||||
def xyxy(self):
|
||||
import torch
|
||||
import numpy as np
|
||||
coords = self._boxes.xyxy[self._index]
|
||||
if isinstance(coords, torch.Tensor):
|
||||
return coords.unsqueeze(0)
|
||||
else:
|
||||
return np.array([coords])
|
||||
|
||||
@property
|
||||
def conf(self):
|
||||
import torch
|
||||
import numpy as np
|
||||
conf_val = self._boxes.conf[self._index]
|
||||
if isinstance(conf_val, torch.Tensor):
|
||||
return conf_val.view(1)
|
||||
else:
|
||||
return np.array([conf_val])
|
||||
|
||||
@property
|
||||
def cls(self):
|
||||
import torch
|
||||
import numpy as np
|
||||
cls_val = self._boxes.cls[self._index]
|
||||
if isinstance(cls_val, torch.Tensor):
|
||||
return cls_val.view(1)
|
||||
else:
|
||||
return np.array([cls_val])
|
||||
377
apps/server/models/smoking_yolo_adapter_v2_simple.py
Normal file
377
apps/server/models/smoking_yolo_adapter_v2_simple.py
Normal file
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
YOLO 格式的抽烟检测模型适配器(V2 简化版)
|
||||
使用 Docker exec 在后台容器中执行检测
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import subprocess
|
||||
import tempfile
|
||||
import logging
|
||||
import time
|
||||
import json
|
||||
from typing import List, Dict, Optional
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SmokingDetectionYOLO:
|
||||
"""
|
||||
模拟 YOLO 接口的抽烟检测模型(V2 简化版)
|
||||
使用 Docker exec 在后台容器中执行检测
|
||||
"""
|
||||
|
||||
_container_name = "smoking-detection-server"
|
||||
_initialized = False
|
||||
|
||||
def __init__(self, model_path=None):
|
||||
self.model_name = "smoking_detection"
|
||||
self.docker_image = "smoking-detection:test"
|
||||
self.model_dir = "output_inference/ppyoloe_crn_s_80e_smoking_visdrone"
|
||||
self.threshold = 0.1
|
||||
|
||||
# YOLO 兼容属性
|
||||
self.names = {0: 'cigarette'}
|
||||
self.model = self
|
||||
|
||||
# 检查 Docker 并启动后台容器
|
||||
self._check_docker()
|
||||
if self.available:
|
||||
self._start_background_container()
|
||||
|
||||
logger.info(f"抽烟检测模型 V2 简化版初始化完成,Docker可用: {self.available}")
|
||||
|
||||
def _check_docker(self):
|
||||
"""检查 Docker 环境"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["docker", "info"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
self.available = result.returncode == 0
|
||||
|
||||
if self.available:
|
||||
result = subprocess.run(
|
||||
["docker", "image", "inspect", self.docker_image],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
self.available = result.returncode == 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Docker 检查失败: {e}")
|
||||
self.available = False
|
||||
|
||||
def _start_background_container(self):
|
||||
"""启动后台容器"""
|
||||
try:
|
||||
# 检查容器是否已在运行
|
||||
result = subprocess.run(
|
||||
["docker", "ps", "-q", "-f", f"name={self._container_name}"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
|
||||
if result.stdout.strip():
|
||||
logger.info(f"后台容器已在运行: {self._container_name}")
|
||||
SmokingDetectionYOLO._initialized = True
|
||||
return
|
||||
|
||||
# 检查容器是否存在但已停止
|
||||
result = subprocess.run(
|
||||
["docker", "ps", "-aq", "-f", f"name={self._container_name}"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
|
||||
if result.stdout.strip():
|
||||
# 启动已存在的容器
|
||||
logger.info(f"启动已存在的容器: {self._container_name}")
|
||||
subprocess.run(
|
||||
["docker", "start", self._container_name],
|
||||
capture_output=True,
|
||||
timeout=10
|
||||
)
|
||||
else:
|
||||
# 创建新容器(保持运行)
|
||||
logger.info(f"创建后台容器: {self._container_name}")
|
||||
subprocess.run(
|
||||
[
|
||||
"docker", "run", "-d",
|
||||
"--name", self._container_name,
|
||||
"-v", "/tmp:/workspace/input",
|
||||
"-v", "/Users/wwh/project/video-model/PaddlePaddle/PaddleDetection-release-2.9/smoking_server.py:/workspace/PaddleDetection/smoking_server.py",
|
||||
"-w", "/workspace/PaddleDetection",
|
||||
self.docker_image,
|
||||
"tail", "-f", "/dev/null"
|
||||
],
|
||||
capture_output=True,
|
||||
timeout=10
|
||||
)
|
||||
|
||||
SmokingDetectionYOLO._initialized = True
|
||||
logger.info("后台容器启动成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"启动后台容器失败: {e}")
|
||||
SmokingDetectionYOLO._initialized = False
|
||||
|
||||
def __call__(self, source, conf=0.1, iou=0.45, verbose=False, stream=False):
|
||||
"""模拟 YOLO 模型的调用接口"""
|
||||
if not self.available:
|
||||
logger.error("Docker 不可用,无法运行检测")
|
||||
return [YOLOResult([])]
|
||||
|
||||
# 处理不同类型的输入
|
||||
if isinstance(source, str):
|
||||
image = cv2.imread(source)
|
||||
if image is None:
|
||||
logger.error(f"无法读取图片: {source}")
|
||||
return [YOLOResult([])]
|
||||
return self._detect_single(image, conf, verbose)
|
||||
|
||||
elif isinstance(source, np.ndarray):
|
||||
return self._detect_single(source, conf, verbose)
|
||||
|
||||
elif isinstance(source, list):
|
||||
results = []
|
||||
for img in source:
|
||||
if isinstance(img, str):
|
||||
img = cv2.imread(img)
|
||||
if img is not None:
|
||||
results.extend(self._detect_single(img, conf, verbose))
|
||||
return results
|
||||
|
||||
else:
|
||||
logger.error(f"不支持的输入类型: {type(source)}")
|
||||
return [YOLOResult([])]
|
||||
|
||||
def _detect_single(self, image: np.ndarray, conf: float, verbose: bool) -> List['YOLOResult']:
|
||||
"""检测单张图片(使用 docker exec)"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 创建临时文件
|
||||
input_filename = f"smoking_v2_{int(time.time()*1000)}.jpg"
|
||||
temp_input = f"/tmp/{input_filename}"
|
||||
|
||||
# 保存输入图片
|
||||
cv2.imwrite(temp_input, image)
|
||||
|
||||
if verbose:
|
||||
logger.info(f"正在检测: {temp_input}")
|
||||
|
||||
# 使用 docker exec 执行检测
|
||||
cmd = [
|
||||
"docker", "exec", "-i",
|
||||
self._container_name,
|
||||
"python", "smoking_server.py"
|
||||
]
|
||||
|
||||
# 发送请求
|
||||
request = {
|
||||
'image_path': f'/workspace/input/{input_filename}',
|
||||
'threshold': conf
|
||||
}
|
||||
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
input=json.dumps(request) + '\n',
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60
|
||||
)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
if verbose:
|
||||
logger.info(f"检测完成,耗时: {elapsed:.2f}秒")
|
||||
|
||||
# 解析结果
|
||||
try:
|
||||
# 找到最后一行 JSON 输出
|
||||
lines = result.stdout.strip().split('\n')
|
||||
json_line = None
|
||||
for line in reversed(lines):
|
||||
line = line.strip()
|
||||
if line.startswith('{'):
|
||||
json_line = line
|
||||
break
|
||||
|
||||
if json_line:
|
||||
response = json.loads(json_line)
|
||||
if response.get('success'):
|
||||
detections = response.get('detections', [])
|
||||
else:
|
||||
logger.error(f"检测失败: {response.get('error')}")
|
||||
detections = []
|
||||
else:
|
||||
logger.error(f"无法解析输出: {result.stdout}")
|
||||
detections = []
|
||||
except Exception as e:
|
||||
logger.error(f"解析结果失败: {e}, stdout: {result.stdout}")
|
||||
detections = []
|
||||
|
||||
# 清理临时文件
|
||||
try:
|
||||
os.remove(temp_input)
|
||||
except:
|
||||
pass
|
||||
|
||||
return [YOLOResult(detections)]
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("检测超时")
|
||||
return [YOLOResult([])]
|
||||
except Exception as e:
|
||||
logger.error(f"检测失败: {e}")
|
||||
return [YOLOResult([])]
|
||||
|
||||
def predict(self, source, **kwargs):
|
||||
"""兼容 predict 方法"""
|
||||
return self.__call__(source, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def stop_server(cls):
|
||||
"""停止后台容器"""
|
||||
try:
|
||||
subprocess.run(
|
||||
["docker", "stop", cls._container_name],
|
||||
capture_output=True,
|
||||
timeout=10
|
||||
)
|
||||
logger.info("后台容器已停止")
|
||||
except Exception as e:
|
||||
logger.error(f"停止后台容器失败: {e}")
|
||||
|
||||
|
||||
# YOLOResult, Boxes, Box 类(与之前相同)
|
||||
class YOLOResult:
|
||||
"""模拟 YOLO 检测结果对象"""
|
||||
|
||||
def __init__(self, detections: List[Dict]):
|
||||
self.detections = detections
|
||||
self.names = {0: 'cigarette'}
|
||||
self.boxes = Boxes(detections)
|
||||
self.probs = None
|
||||
self.keypoints = None
|
||||
self.obb = None
|
||||
self.speed = {'preprocess': 0, 'inference': 0, 'postprocess': 0}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.detections)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if idx < len(self.detections):
|
||||
return YOLOResult([self.detections[idx]])
|
||||
return YOLOResult([])
|
||||
|
||||
def plot(self, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
class Boxes:
|
||||
"""模拟 YOLO boxes 对象"""
|
||||
|
||||
def __init__(self, detections: List[Dict]):
|
||||
self.detections = detections
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
if detections:
|
||||
xyxy_list = [[d['bbox'][0], d['bbox'][1], d['bbox'][2], d['bbox'][3]] for d in detections]
|
||||
conf_list = [[d['confidence']] for d in detections]
|
||||
cls_list = [[d['class']] for d in detections]
|
||||
|
||||
self.xyxy = torch.tensor(xyxy_list, dtype=torch.float32)
|
||||
self.conf = torch.tensor(conf_list, dtype=torch.float32)
|
||||
self.cls = torch.tensor(cls_list, dtype=torch.int64)
|
||||
self.id = None
|
||||
else:
|
||||
self.xyxy = torch.empty((0, 4))
|
||||
self.conf = torch.empty((0, 1))
|
||||
self.cls = torch.empty((0, 1), dtype=torch.int64)
|
||||
self.id = None
|
||||
|
||||
except ImportError:
|
||||
import numpy as np
|
||||
|
||||
if detections:
|
||||
xyxy_list = [[d['bbox'][0], d['bbox'][1], d['bbox'][2], d['bbox'][3]] for d in detections]
|
||||
conf_list = [[d['confidence']] for d in detections]
|
||||
cls_list = [[d['class']] for d in detections]
|
||||
|
||||
self.xyxy = np.array(xyxy_list, dtype=np.float32)
|
||||
self.conf = np.array(conf_list, dtype=np.float32)
|
||||
self.cls = np.array(cls_list, dtype=np.int64)
|
||||
self.id = None
|
||||
else:
|
||||
self.xyxy = np.empty((0, 4), dtype=np.float32)
|
||||
self.conf = np.empty((0, 1), dtype=np.float32)
|
||||
self.cls = np.empty((0, 1), dtype=np.int64)
|
||||
self.id = None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.detections)
|
||||
|
||||
def __iter__(self):
|
||||
for i in range(len(self.detections)):
|
||||
yield Box(self, i)
|
||||
|
||||
def cpu(self):
|
||||
return self
|
||||
|
||||
def numpy(self):
|
||||
if hasattr(self.xyxy, 'numpy'):
|
||||
return type('Boxes', (), {
|
||||
'xyxy': self.xyxy.numpy(),
|
||||
'conf': self.conf.numpy(),
|
||||
'cls': self.cls.numpy(),
|
||||
'id': self.id
|
||||
})()
|
||||
return self
|
||||
|
||||
|
||||
class Box:
|
||||
"""模拟单个检测框对象"""
|
||||
|
||||
def __init__(self, boxes: Boxes, index: int):
|
||||
self._boxes = boxes
|
||||
self._index = index
|
||||
|
||||
@property
|
||||
def xyxy(self):
|
||||
import torch
|
||||
import numpy as np
|
||||
coords = self._boxes.xyxy[self._index]
|
||||
if isinstance(coords, torch.Tensor):
|
||||
return coords.unsqueeze(0)
|
||||
else:
|
||||
return np.array([coords])
|
||||
|
||||
@property
|
||||
def conf(self):
|
||||
import torch
|
||||
import numpy as np
|
||||
conf_val = self._boxes.conf[self._index]
|
||||
if isinstance(conf_val, torch.Tensor):
|
||||
return conf_val.view(1)
|
||||
else:
|
||||
return np.array([conf_val])
|
||||
|
||||
@property
|
||||
def cls(self):
|
||||
import torch
|
||||
import numpy as np
|
||||
cls_val = self._boxes.cls[self._index]
|
||||
if isinstance(cls_val, torch.Tensor):
|
||||
return cls_val.view(1)
|
||||
else:
|
||||
return np.array([cls_val])
|
||||
17
apps/server/package.json
Normal file
17
apps/server/package.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"name": "server",
|
||||
"version": "1.0.0",
|
||||
"description": "视频模型检测平台后端服务",
|
||||
"scripts": {
|
||||
"dev": "python main.py",
|
||||
"start": "uvicorn main:app --host 0.0.0.0 --port 8000",
|
||||
"lint": "ruff check .",
|
||||
"test": "pytest tests/",
|
||||
"clean": "rm -rf __pycache__ .pytest_cache"
|
||||
},
|
||||
"dependencies": {},
|
||||
"devDependencies": {},
|
||||
"engines": {
|
||||
"python": ">=3.9"
|
||||
}
|
||||
}
|
||||
12
apps/server/requirements.txt
Normal file
12
apps/server/requirements.txt
Normal file
@@ -0,0 +1,12 @@
|
||||
fastapi>=0.104.0
|
||||
uvicorn[standard]>=0.24.0
|
||||
python-multipart>=0.0.6
|
||||
pydantic>=2.0.0
|
||||
python-dotenv>=1.0.0
|
||||
aiofiles>=23.2.0
|
||||
opencv-python>=4.8.0
|
||||
pillow>=10.0.0
|
||||
ultralytics>=8.0.0
|
||||
numpy>=1.24.0
|
||||
torch>=2.0.0
|
||||
websockets>=12.0.0
|
||||
0
apps/server/services/__init__.py
Normal file
0
apps/server/services/__init__.py
Normal file
351
apps/server/services/camera_service.py
Normal file
351
apps/server/services/camera_service.py
Normal file
@@ -0,0 +1,351 @@
|
||||
import cv2
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import platform
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
from typing import Dict, Optional
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CameraService:
|
||||
def __init__(self, model_service):
|
||||
self.model_service = model_service
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.camera_captures: Dict[str, cv2.VideoCapture] = {}
|
||||
self.running = False
|
||||
self.camera_configs: Dict[str, Dict] = {}
|
||||
self._locks: Dict[str, asyncio.Lock] = {}
|
||||
self._stop_events: Dict[str, asyncio.Event] = {}
|
||||
|
||||
@staticmethod
|
||||
def force_release_cameras():
|
||||
"""强制释放被占用的摄像头资源
|
||||
|
||||
在以下场景调用:
|
||||
1. 服务启动前 - 清理之前异常退出残留的占用
|
||||
2. 服务关闭时 - 确保资源被释放
|
||||
3. 信号处理时 - 异常退出前的清理
|
||||
"""
|
||||
logger.info("强制释放摄像头资源...")
|
||||
|
||||
# 1. 尝试释放当前进程中的摄像头
|
||||
try:
|
||||
# 在macOS上,摄像头设备通常是 /dev/video* 或 AVFoundation 设备
|
||||
# 尝试打开并立即释放来清理状态
|
||||
for i in range(10): # 检查前10个可能的摄像头索引
|
||||
try:
|
||||
cap = cv2.VideoCapture(i)
|
||||
if cap.isOpened():
|
||||
cap.release()
|
||||
logger.info(f"已释放摄像头索引 {i}")
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"释放当前进程摄像头失败: {e}")
|
||||
|
||||
# 2. 终止占用摄像头的其他Python进程
|
||||
current_pid = os.getpid()
|
||||
system = platform.system()
|
||||
|
||||
try:
|
||||
if system == "Darwin": # macOS
|
||||
# 使用 lsof 查找占用摄像头的进程
|
||||
result = subprocess.run(
|
||||
['lsof', '-c', 'python'],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
if result.returncode == 0:
|
||||
pids_to_kill = set()
|
||||
for line in result.stdout.split('\n'):
|
||||
# 查找包含摄像头设备 (/dev/video* 或 V4L 相关) 的行
|
||||
if any(x in line for x in ['/dev/video', 'Camera', 'V4L']):
|
||||
parts = line.split()
|
||||
if len(parts) >= 2:
|
||||
try:
|
||||
pid = int(parts[1])
|
||||
if pid != current_pid:
|
||||
pids_to_kill.add(pid)
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
# 终止找到的进程
|
||||
for pid in pids_to_kill:
|
||||
try:
|
||||
logger.info(f"终止占用摄像头的进程: {pid}")
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
# 给进程一点时间优雅退出
|
||||
import time
|
||||
time.sleep(0.5)
|
||||
# 检查是否还在运行,如果是则强制终止
|
||||
try:
|
||||
os.kill(pid, 0) # 检查进程是否存在
|
||||
logger.warning(f"进程 {pid} 未响应 SIGTERM,使用 SIGKILL")
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
except ProcessLookupError:
|
||||
pass # 进程已退出
|
||||
except ProcessLookupError:
|
||||
pass # 进程已不存在
|
||||
except Exception as e:
|
||||
logger.error(f"终止进程 {pid} 失败: {e}")
|
||||
|
||||
elif system == "Linux":
|
||||
# Linux 系统使用 fuser 或 lsof
|
||||
for device in ['/dev/video0', '/dev/video1', '/dev/video2']:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
['fuser', '-k', device],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
if result.returncode == 0:
|
||||
logger.info(f"已释放设备 {device}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"终止占用进程失败: {e}")
|
||||
|
||||
# 3. 给系统一点时间完成资源释放
|
||||
import time
|
||||
time.sleep(0.5)
|
||||
logger.info("摄像头资源释放完成")
|
||||
|
||||
async def handle_connection(self, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
connection_id = id(websocket)
|
||||
self.active_connections[connection_id] = websocket
|
||||
|
||||
logger.info(f"新连接: {connection_id}")
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
data = await websocket.receive_text()
|
||||
message = json.loads(data)
|
||||
|
||||
if message.get('action') == 'start':
|
||||
await self.start_camera(connection_id, websocket, message.get('config'))
|
||||
elif message.get('action') == 'stop':
|
||||
await self.stop_camera(connection_id)
|
||||
elif message.get('action') == 'update_config':
|
||||
await self.update_config(connection_id, message.get('config', {}))
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket连接断开: {connection_id}")
|
||||
break
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"连接断开: {connection_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"连接错误: {connection_id}, {e}")
|
||||
finally:
|
||||
await self.stop_camera(connection_id)
|
||||
if connection_id in self.active_connections:
|
||||
del self.active_connections[connection_id]
|
||||
logger.info(f"连接清理完成: {connection_id}")
|
||||
|
||||
async def start_camera(self, connection_id: str, websocket: WebSocket, config: Dict = None):
|
||||
try:
|
||||
camera_source = 0 # 默认本地摄像头
|
||||
|
||||
if not config:
|
||||
config = {
|
||||
'model_id': 'fire_detection',
|
||||
'confidence': 0.5,
|
||||
'iou': 0.45,
|
||||
'camera_source': 0 # 默认本地摄像头
|
||||
}
|
||||
|
||||
# 支持多种摄像头源
|
||||
if 'camera_source' in config:
|
||||
camera_source = config['camera_source']
|
||||
|
||||
# 如果该连接已有摄像头在运行,先停止它
|
||||
if connection_id in self.camera_captures:
|
||||
await self.stop_camera(connection_id)
|
||||
|
||||
# 初始化锁和停止事件
|
||||
self._locks[connection_id] = asyncio.Lock()
|
||||
self._stop_events[connection_id] = asyncio.Event()
|
||||
|
||||
# 尝试打开摄像头
|
||||
cap = cv2.VideoCapture(camera_source)
|
||||
if not cap.isOpened():
|
||||
await websocket.send_json({
|
||||
'type': 'error',
|
||||
'message': f'无法打开摄像头源: {camera_source}'
|
||||
})
|
||||
return
|
||||
|
||||
# 只对本地摄像头设置分辨率
|
||||
if isinstance(camera_source, int):
|
||||
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
|
||||
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
|
||||
|
||||
self.camera_captures[connection_id] = cap
|
||||
self.camera_configs[connection_id] = config
|
||||
|
||||
await websocket.send_json({
|
||||
'type': 'camera_started',
|
||||
'message': f'摄像头已启动: {camera_source}'
|
||||
})
|
||||
|
||||
await self.process_frames(connection_id, websocket, cap)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"启动摄像头失败: {e}")
|
||||
await websocket.send_json({
|
||||
'type': 'error',
|
||||
'message': f'启动摄像头失败: {str(e)}'
|
||||
})
|
||||
|
||||
async def process_frames(self, connection_id: str, websocket: WebSocket, cap: cv2.VideoCapture):
|
||||
from .detection_service import DetectionService
|
||||
detection_service = DetectionService(self.model_service)
|
||||
|
||||
try:
|
||||
frame_count = 0
|
||||
stop_event = self._stop_events.get(connection_id)
|
||||
|
||||
while connection_id in self.active_connections:
|
||||
# 检查是否收到停止信号
|
||||
if stop_event and stop_event.is_set():
|
||||
logger.info(f"收到停止信号,结束帧处理: {connection_id}")
|
||||
break
|
||||
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
|
||||
config = self.camera_configs.get(connection_id, {
|
||||
'model_id': 'fire_detection',
|
||||
'confidence': 0.5,
|
||||
'iou': 0.45
|
||||
})
|
||||
|
||||
model_id = config.get('model_id', 'fire_detection')
|
||||
confidence = config.get('confidence', 0.5)
|
||||
iou = config.get('iou', 0.45)
|
||||
draw = True
|
||||
|
||||
processed_frame, result = await detection_service.detect_frame(
|
||||
frame,
|
||||
model_id=model_id,
|
||||
confidence=confidence,
|
||||
iou=iou,
|
||||
draw=draw
|
||||
)
|
||||
|
||||
if result['success']:
|
||||
frame_count += 1
|
||||
|
||||
logger.info(f"发送检测结果: {len(result['detections'])} 个目标, {result['stats']}")
|
||||
|
||||
await websocket.send_json({
|
||||
'type': 'detection',
|
||||
'detections': result['detections'],
|
||||
'stats': result['stats']
|
||||
})
|
||||
|
||||
_, buffer = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 80])
|
||||
import base64
|
||||
original_frame_data = base64.b64encode(buffer).decode('utf-8')
|
||||
|
||||
_, buffer = cv2.imencode('.jpg', processed_frame, [cv2.IMWRITE_JPEG_QUALITY, 80])
|
||||
frame_data = base64.b64encode(buffer).decode('utf-8')
|
||||
|
||||
try:
|
||||
await websocket.send_json({
|
||||
'type': 'original_frame',
|
||||
'frame': original_frame_data
|
||||
})
|
||||
|
||||
await websocket.send_json({
|
||||
'type': 'annotated_frame',
|
||||
'frame': frame_data
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"发送帧数据失败: {e}")
|
||||
break
|
||||
|
||||
await asyncio.sleep(0.03)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理帧错误: {e}")
|
||||
finally:
|
||||
# 只在摄像头仍存在于字典中时才释放(避免重复释放)
|
||||
if connection_id in self.camera_captures:
|
||||
cap.release()
|
||||
del self.camera_captures[connection_id]
|
||||
logger.info(f"帧处理结束,摄像头已释放: {connection_id}")
|
||||
|
||||
async def stop_camera(self, connection_id: str):
|
||||
# 使用锁确保同一时间只有一个协程在操作该连接的摄像头资源
|
||||
lock = self._locks.get(connection_id)
|
||||
if lock:
|
||||
async with lock:
|
||||
await self._do_stop_camera(connection_id)
|
||||
else:
|
||||
await self._do_stop_camera(connection_id)
|
||||
|
||||
async def _do_stop_camera(self, connection_id: str):
|
||||
"""实际执行停止摄像头的操作(内部方法,应在获取锁后调用)"""
|
||||
# 设置停止事件,通知帧处理循环退出
|
||||
if connection_id in self._stop_events:
|
||||
self._stop_events[connection_id].set()
|
||||
|
||||
if connection_id in self.camera_captures:
|
||||
cap = self.camera_captures[connection_id]
|
||||
cap.release()
|
||||
del self.camera_captures[connection_id]
|
||||
|
||||
if connection_id in self.active_connections:
|
||||
try:
|
||||
await self.active_connections[connection_id].send_json({
|
||||
'type': 'camera_stopped',
|
||||
'message': '摄像头已停止'
|
||||
})
|
||||
except:
|
||||
pass
|
||||
|
||||
logger.info(f"摄像头已停止: {connection_id}")
|
||||
|
||||
# 清理锁和事件
|
||||
if connection_id in self._locks:
|
||||
del self._locks[connection_id]
|
||||
if connection_id in self._stop_events:
|
||||
del self._stop_events[connection_id]
|
||||
|
||||
async def update_config(self, connection_id: str, config: Dict):
|
||||
if connection_id in self.camera_configs:
|
||||
self.camera_configs[connection_id].update(config)
|
||||
|
||||
model_id = self.camera_configs[connection_id].get('model_id', 'fire_detection')
|
||||
confidence = self.camera_configs[connection_id].get('confidence', 0.5)
|
||||
iou = self.camera_configs[connection_id].get('iou', 0.45)
|
||||
|
||||
logger.info(f"配置更新: model_id={model_id}, confidence={confidence}, iou={iou}")
|
||||
|
||||
if connection_id in self.active_connections:
|
||||
try:
|
||||
await self.active_connections[connection_id].send_json({
|
||||
'type': 'config_updated',
|
||||
'message': '配置已更新'
|
||||
})
|
||||
except:
|
||||
pass
|
||||
|
||||
async def stop(self):
|
||||
for connection_id in list(self.camera_captures.keys()):
|
||||
await self.stop_camera(connection_id)
|
||||
|
||||
self.running = False
|
||||
logger.info("摄像头服务已停止")
|
||||
199
apps/server/services/detection_service.py
Normal file
199
apps/server/services/detection_service.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DetectionService:
|
||||
def __init__(self, model_service):
|
||||
self.model_service = model_service
|
||||
self.base_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
self.results_dir = os.path.join(self.base_dir, "static", "results")
|
||||
self.temp_dir = os.path.join(self.base_dir, "static", "temp")
|
||||
|
||||
os.makedirs(self.results_dir, exist_ok=True)
|
||||
os.makedirs(self.temp_dir, exist_ok=True)
|
||||
|
||||
def draw_detections(self, frame: np.ndarray, detections: List[Dict], fps: float = 0) -> np.ndarray:
|
||||
try:
|
||||
img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
pil_img = Image.fromarray(img_rgb)
|
||||
draw = ImageDraw.Draw(pil_img)
|
||||
|
||||
try:
|
||||
font = ImageFont.truetype("/System/Library/Fonts/PingFang.ttc", 20)
|
||||
font_large = ImageFont.truetype("/System/Library/Fonts/PingFang.ttc", 24)
|
||||
except:
|
||||
font = ImageFont.load_default()
|
||||
font_large = font
|
||||
|
||||
class_colors = {
|
||||
'Fire': (255, 0, 0),
|
||||
'Smoke': (128, 128, 128),
|
||||
'person': (0, 255, 0),
|
||||
'helmet': (255, 255, 0),
|
||||
'no_helmet': (255, 0, 255),
|
||||
'cigarette': (0, 165, 255) # 橙色,用于抽烟检测
|
||||
}
|
||||
|
||||
for det in detections:
|
||||
x1, y1, x2, y2 = det['bbox']
|
||||
class_name = det['class']
|
||||
conf = det['confidence']
|
||||
label = det['label']
|
||||
|
||||
color = class_colors.get(class_name, (0, 255, 0))
|
||||
|
||||
draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
|
||||
|
||||
label_text = f"{label} {conf:.2f}"
|
||||
bbox = draw.textbbox((0, 0), label_text, font=font)
|
||||
text_w = bbox[2] - bbox[0]
|
||||
text_h = bbox[3] - bbox[1]
|
||||
draw.rectangle([x1, y1 - text_h - 4, x1 + text_w + 4, y1], fill=color)
|
||||
draw.text((x1 + 2, y1 - text_h - 2), label_text, fill=(255, 255, 255), font=font)
|
||||
|
||||
if fps > 0:
|
||||
fps_text = f"FPS: {fps:.1f} | Detections: {len(detections)}"
|
||||
draw.text((10, 10), fps_text, fill=(0, 255, 0), font=font)
|
||||
|
||||
return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
|
||||
except Exception as e:
|
||||
logger.error(f"绘制检测结果失败: {e}")
|
||||
return frame
|
||||
|
||||
async def detect_image(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
model_id: str,
|
||||
confidence: float = 0.5,
|
||||
iou: float = 0.45
|
||||
) -> Dict:
|
||||
start_time = time.time()
|
||||
|
||||
model = await self.model_service.load_model(model_id)
|
||||
if not model:
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'模型加载失败: {model_id}',
|
||||
'detections': [],
|
||||
'stats': None
|
||||
}
|
||||
|
||||
try:
|
||||
results = model(image, conf=confidence, iou=iou, verbose=False)
|
||||
|
||||
detections = []
|
||||
for result in results:
|
||||
boxes = result.boxes
|
||||
for box in boxes:
|
||||
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
|
||||
conf = float(box.conf[0].cpu().numpy())
|
||||
cls = int(box.cls[0].cpu().numpy())
|
||||
class_name = result.names[cls]
|
||||
|
||||
label_map = self.model_service.model_configs[model_id]['labels']
|
||||
label = label_map.get(class_name, class_name)
|
||||
|
||||
detections.append({
|
||||
'class': class_name,
|
||||
'label': label,
|
||||
'confidence': round(conf, 3),
|
||||
'bbox': [int(x1), int(y1), int(x2), int(y2)]
|
||||
})
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
avg_confidence = sum(d['confidence'] for d in detections) / len(detections) if detections else 0
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': '检测完成',
|
||||
'detections': detections,
|
||||
'stats': {
|
||||
'total_detections': len(detections),
|
||||
'avg_confidence': round(avg_confidence, 3),
|
||||
'processing_time': round(processing_time, 3),
|
||||
'model_used': model_id
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"图片检测失败: {e}")
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'检测失败: {str(e)}',
|
||||
'detections': [],
|
||||
'stats': None
|
||||
}
|
||||
|
||||
async def detect_frame(
|
||||
self,
|
||||
frame: np.ndarray,
|
||||
model_id: str,
|
||||
confidence: float = 0.5,
|
||||
iou: float = 0.45,
|
||||
draw: bool = True
|
||||
) -> tuple:
|
||||
start_time = time.time()
|
||||
|
||||
model = await self.model_service.load_model(model_id)
|
||||
if not model:
|
||||
return frame, {
|
||||
'success': False,
|
||||
'detections': [],
|
||||
'stats': None
|
||||
}
|
||||
|
||||
try:
|
||||
results = model(frame, conf=confidence, iou=iou, verbose=False)
|
||||
|
||||
detections = []
|
||||
for result in results:
|
||||
boxes = result.boxes
|
||||
for box in boxes:
|
||||
x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
|
||||
conf = float(box.conf[0].cpu().numpy())
|
||||
cls = int(box.cls[0].cpu().numpy())
|
||||
class_name = result.names[cls]
|
||||
|
||||
label_map = self.model_service.model_configs[model_id]['labels']
|
||||
label = label_map.get(class_name, class_name)
|
||||
|
||||
detections.append({
|
||||
'class': class_name,
|
||||
'label': label,
|
||||
'confidence': round(conf, 3),
|
||||
'bbox': [int(x1), int(y1), int(x2), int(y2)]
|
||||
})
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
fps = 1.0 / processing_time if processing_time > 0 else 0
|
||||
avg_confidence = sum(d['confidence'] for d in detections) / len(detections) if detections else 0
|
||||
|
||||
result_data = {
|
||||
'success': True,
|
||||
'detections': detections,
|
||||
'stats': {
|
||||
'total_detections': len(detections),
|
||||
'avg_confidence': round(avg_confidence, 3),
|
||||
'processing_time': round(processing_time, 3),
|
||||
'fps': round(fps, 2),
|
||||
'model_used': model_id
|
||||
}
|
||||
}
|
||||
|
||||
if draw:
|
||||
frame = self.draw_detections(frame, detections, fps)
|
||||
|
||||
return frame, result_data
|
||||
except Exception as e:
|
||||
logger.error(f"帧检测失败: {e}")
|
||||
return frame, {
|
||||
'success': False,
|
||||
'detections': [],
|
||||
'stats': None
|
||||
}
|
||||
115
apps/server/services/model_service.py
Normal file
115
apps/server/services/model_service.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import os
|
||||
import logging
|
||||
from ultralytics import YOLO
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelService:
|
||||
def __init__(self):
|
||||
self.models: Dict[str, YOLO] = {}
|
||||
# 基础路径:从 apps/server/services/model_service.py 到 jc-video-web 根目录
|
||||
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
self.model_configs = {
|
||||
'fire_detection': {
|
||||
'path': os.path.join(base_dir, 'models', 'fire_detection', 'best.pt'),
|
||||
'type': 'yolov10',
|
||||
'classes': ['Fire', 'Smoke'],
|
||||
'labels': {'Fire': '火焰', 'Smoke': '烟雾'},
|
||||
'size': '61MB',
|
||||
'description': '基于YOLOv10的火灾烟雾检测模型',
|
||||
'name': '火灾检测'
|
||||
},
|
||||
'helmet_detection': {
|
||||
'path': os.path.join(base_dir, 'models', 'helmet_detection', 'yolov8n.pt'),
|
||||
'type': 'yolov8',
|
||||
'classes': ['person', 'helmet'],
|
||||
'labels': {'person': '人员', 'helmet': '安全帽'},
|
||||
'size': '6MB',
|
||||
'description': '基于YOLOv8的安全帽检测模型',
|
||||
'name': '安全帽检测'
|
||||
},
|
||||
'crowd_detection': {
|
||||
'path': os.path.join(base_dir, 'models', 'crowd_detection', 'yolov8l.pt'),
|
||||
'type': 'yolov8',
|
||||
'classes': ['person'],
|
||||
'labels': {'person': '人员'},
|
||||
'size': '100MB',
|
||||
'description': '基于YOLOv8的人群聚集检测模型',
|
||||
'name': '人群检测'
|
||||
},
|
||||
'smoking_detection': {
|
||||
'path': os.path.join(base_dir, 'models', 'smoking_detection', 'smoking_yolov8n.pt'),
|
||||
'type': 'yolov8',
|
||||
'classes': ['cigarette', 'smoke'],
|
||||
'labels': {'cigarette': '香烟', 'smoke': '烟雾'},
|
||||
'size': '6MB',
|
||||
'description': '基于YOLOv8的抽烟检测模型',
|
||||
'name': '抽烟检测'
|
||||
},
|
||||
'loitering_detection': {
|
||||
'path': os.path.join(base_dir, 'models', 'loitering_detection', 'yolov8n.pt'),
|
||||
'type': 'yolov8',
|
||||
'classes': ['person'],
|
||||
'labels': {'person': '人员'},
|
||||
'size': '6MB',
|
||||
'description': '基于YOLOv8的徘徊检测模型',
|
||||
'name': '徘徊检测'
|
||||
}
|
||||
}
|
||||
|
||||
def get_available_models(self) -> List[Dict]:
|
||||
available_models = []
|
||||
for model_id, config in self.model_configs.items():
|
||||
if os.path.exists(config['path']):
|
||||
available_models.append({
|
||||
'id': model_id,
|
||||
'name': config['name'],
|
||||
'description': config['description'],
|
||||
'classes': config['classes'],
|
||||
'labels': config['labels'],
|
||||
'size': config['size'],
|
||||
'type': config['type']
|
||||
})
|
||||
else:
|
||||
logger.warning(f"模型文件不存在: {config['path']}")
|
||||
return available_models
|
||||
|
||||
async def load_model(self, model_id: str) -> Optional[YOLO]:
|
||||
if model_id not in self.model_configs:
|
||||
logger.error(f"未知模型ID: {model_id}")
|
||||
return None
|
||||
|
||||
if model_id in self.models:
|
||||
logger.info(f"模型已加载: {model_id}")
|
||||
return self.models[model_id]
|
||||
|
||||
config = self.model_configs[model_id]
|
||||
|
||||
# 处理 YOLO 模型
|
||||
model_path = config['path']
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
logger.error(f"模型文件不存在: {model_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
logger.info(f"正在加载模型: {model_id} from {model_path}")
|
||||
model = YOLO(model_path)
|
||||
self.models[model_id] = model
|
||||
logger.info(f"模型加载成功: {model_id}")
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.error(f"模型加载失败: {model_id}, 错误: {e}")
|
||||
return None
|
||||
|
||||
def get_model(self, model_id: str) -> Optional[YOLO]:
|
||||
return self.models.get(model_id)
|
||||
|
||||
async def unload_model(self, model_id: str) -> bool:
|
||||
if model_id in self.models:
|
||||
del self.models[model_id]
|
||||
logger.info(f"模型已卸载: {model_id}")
|
||||
return True
|
||||
return False
|
||||
147
apps/server/services/model_service_updated.py
Normal file
147
apps/server/services/model_service_updated.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import os
|
||||
import logging
|
||||
from ultralytics import YOLO
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelService:
|
||||
def __init__(self):
|
||||
self.models: Dict[str, YOLO] = {}
|
||||
self.model_configs = {
|
||||
'fire_detection': {
|
||||
'path': os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
||||
'fire_detection', 'models', 'best.pt'),
|
||||
'type': 'yolov10',
|
||||
'classes': ['Fire', 'Smoke'],
|
||||
'labels': {'Fire': '火焰', 'Smoke': '烟雾'},
|
||||
'size': '61MB',
|
||||
'description': '基于YOLOv10的火灾烟雾检测模型',
|
||||
'name': '火灾检测'
|
||||
},
|
||||
'helmet_detection': {
|
||||
'path': os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
||||
'yolov', 'yolov8n.pt'),
|
||||
'type': 'yolov8',
|
||||
'classes': ['person', 'helmet'],
|
||||
'labels': {'person': '人员', 'helmet': '安全帽'},
|
||||
'size': '6MB',
|
||||
'description': '基于YOLOv8的安全帽检测模型',
|
||||
'name': '安全帽检测'
|
||||
},
|
||||
'crowd_detection': {
|
||||
'path': os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
||||
'behavior_detection', 'Crowd-Gathering', 'models', 'yolov8l.pt'),
|
||||
'type': 'yolov8',
|
||||
'classes': ['person'],
|
||||
'labels': {'person': '人员'},
|
||||
'size': '100MB',
|
||||
'description': '基于YOLOv8的人群聚集检测模型',
|
||||
'name': '人群检测'
|
||||
},
|
||||
'smoking_detection': {
|
||||
'path': 'PADDLE_DETECTION', # 特殊标记,表示使用 PaddleDetection
|
||||
'type': 'paddle',
|
||||
'classes': ['cigarette'],
|
||||
'labels': {'cigarette': '香烟'},
|
||||
'size': '27MB',
|
||||
'description': '基于PP-YOLOE的抽烟检测模型',
|
||||
'name': '抽烟检测',
|
||||
'docker_image': 'smoking-detection:test',
|
||||
'model_dir': 'output_inference/ppyoloe_crn_s_80e_smoking_visdrone'
|
||||
}
|
||||
}
|
||||
|
||||
def get_available_models(self) -> List[Dict]:
|
||||
available_models = []
|
||||
for model_id, config in self.model_configs.items():
|
||||
# 对于 PaddleDetection 模型,不需要检查本地文件
|
||||
if config.get('type') == 'paddle':
|
||||
# 检查 Docker 是否可用
|
||||
import subprocess
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["docker", "image", "inspect", config['docker_image']],
|
||||
capture_output=True,
|
||||
timeout=5
|
||||
)
|
||||
if result.returncode == 0:
|
||||
available_models.append({
|
||||
'id': model_id,
|
||||
'name': config['name'],
|
||||
'description': config['description'],
|
||||
'classes': config['classes'],
|
||||
'labels': config['labels'],
|
||||
'size': config['size'],
|
||||
'type': config['type']
|
||||
})
|
||||
else:
|
||||
logger.warning(f"PaddleDetection Docker 镜像不可用: {config['docker_image']}")
|
||||
except Exception as e:
|
||||
logger.warning(f"检查 PaddleDetection Docker 镜像失败: {e}")
|
||||
elif os.path.exists(config['path']):
|
||||
available_models.append({
|
||||
'id': model_id,
|
||||
'name': config['name'],
|
||||
'description': config['description'],
|
||||
'classes': config['classes'],
|
||||
'labels': config['labels'],
|
||||
'size': config['size'],
|
||||
'type': config['type']
|
||||
})
|
||||
else:
|
||||
logger.warning(f"模型文件不存在: {config['path']}")
|
||||
return available_models
|
||||
|
||||
async def load_model(self, model_id: str):
|
||||
if model_id not in self.model_configs:
|
||||
logger.error(f"未知模型ID: {model_id}")
|
||||
return None
|
||||
|
||||
# 如果已经加载,直接返回
|
||||
if model_id in self.models:
|
||||
logger.info(f"模型已加载: {model_id}")
|
||||
return self.models[model_id]
|
||||
|
||||
config = self.model_configs[model_id]
|
||||
|
||||
# 处理 PaddleDetection 模型
|
||||
if config.get('type') == 'paddle':
|
||||
try:
|
||||
from .paddle_detection_service import SmokingDetectionModel
|
||||
|
||||
logger.info(f"正在加载 PaddleDetection 抽烟检测模型...")
|
||||
model = SmokingDetectionModel()
|
||||
self.models[model_id] = model
|
||||
logger.info(f"PaddleDetection 模型加载成功: {model_id}")
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.error(f"PaddleDetection 模型加载失败: {e}")
|
||||
return None
|
||||
|
||||
# 处理 YOLO 模型
|
||||
model_path = config['path']
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
logger.error(f"模型文件不存在: {model_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
logger.info(f"正在加载模型: {model_id} from {model_path}")
|
||||
model = YOLO(model_path)
|
||||
self.models[model_id] = model
|
||||
logger.info(f"模型加载成功: {model_id}")
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.error(f"模型加载失败: {model_id}, 错误: {e}")
|
||||
return None
|
||||
|
||||
def get_model(self, model_id: str):
|
||||
return self.models.get(model_id)
|
||||
|
||||
async def unload_model(self, model_id: str) -> bool:
|
||||
if model_id in self.models:
|
||||
del self.models[model_id]
|
||||
logger.info(f"模型已卸载: {model_id}")
|
||||
return True
|
||||
return False
|
||||
274
apps/server/services/paddle_detection_service.py
Normal file
274
apps/server/services/paddle_detection_service.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""
|
||||
PaddleDetection 抽烟检测服务适配器
|
||||
通过 Docker 调用 Paddle 模型
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import subprocess
|
||||
import tempfile
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PaddleDetectionService:
|
||||
"""PaddleDetection 服务适配器"""
|
||||
|
||||
def __init__(self):
|
||||
self.model_name = "smoking_detection"
|
||||
self.docker_image = "smoking-detection:test"
|
||||
self.model_dir = "output_inference/ppyoloe_crn_s_80e_smoking_visdrone"
|
||||
self.threshold = 0.1 # 抽烟检测需要较低的阈值
|
||||
|
||||
# 检查 Docker 和镜像
|
||||
self._check_docker()
|
||||
|
||||
def _check_docker(self):
|
||||
"""检查 Docker 环境"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["docker", "info"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.error("Docker 未运行")
|
||||
self.available = False
|
||||
return
|
||||
|
||||
# 检查镜像
|
||||
result = subprocess.run(
|
||||
["docker", "image", "inspect", self.docker_image],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
self.available = result.returncode == 0
|
||||
|
||||
if self.available:
|
||||
logger.info(f"PaddleDetection 服务已就绪: {self.docker_image}")
|
||||
else:
|
||||
logger.error(f"Docker 镜像不存在: {self.docker_image}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Docker 检查失败: {e}")
|
||||
self.available = False
|
||||
|
||||
def detect_image(self, image: np.ndarray) -> Dict:
|
||||
"""
|
||||
检测图片中的抽烟行为
|
||||
|
||||
Args:
|
||||
image: OpenCV 图片 (BGR格式)
|
||||
|
||||
Returns:
|
||||
检测结果字典
|
||||
"""
|
||||
if not self.available:
|
||||
return {
|
||||
'success': False,
|
||||
'message': 'PaddleDetection 服务不可用',
|
||||
'detections': [],
|
||||
'stats': None
|
||||
}
|
||||
|
||||
try:
|
||||
# 创建临时文件
|
||||
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as f:
|
||||
temp_input = f.name
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as f:
|
||||
temp_output = f.name
|
||||
|
||||
# 保存输入图片
|
||||
cv2.imwrite(temp_input, image)
|
||||
|
||||
# 构建 Docker 命令
|
||||
cmd = [
|
||||
"docker", "run", "--rm",
|
||||
"-v", f"{temp_input}:/workspace/input.jpg",
|
||||
"-v", f"{os.path.dirname(temp_output)}:/workspace/output",
|
||||
self.docker_image,
|
||||
"python", "deploy/python/infer.py",
|
||||
f"--model_dir={self.model_dir}",
|
||||
"--image_file=/workspace/input.jpg",
|
||||
"--device=CPU",
|
||||
"--output_dir=/workspace/output",
|
||||
f"--threshold={self.threshold}"
|
||||
]
|
||||
|
||||
# 执行检测
|
||||
logger.info(f"执行抽烟检测: {temp_input}")
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60
|
||||
)
|
||||
|
||||
# 解析结果
|
||||
detections = self._parse_detection_output(result.stdout)
|
||||
|
||||
# 读取输出图片
|
||||
output_image = None
|
||||
output_path = temp_output.replace('.jpg', '') + '_result.jpg'
|
||||
if os.path.exists(output_path):
|
||||
output_image = cv2.imread(output_path)
|
||||
|
||||
# 清理临时文件
|
||||
self._cleanup_temp_files([temp_input, temp_output, output_path])
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': '检测完成',
|
||||
'detections': detections,
|
||||
'output_image': output_image,
|
||||
'stats': {
|
||||
'total_detections': len(detections),
|
||||
'model_used': 'ppyoloe_crn_s_80e_smoking_visdrone',
|
||||
'threshold': self.threshold
|
||||
}
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("检测超时")
|
||||
return {
|
||||
'success': False,
|
||||
'message': '检测超时',
|
||||
'detections': [],
|
||||
'stats': None
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"检测失败: {e}")
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'检测失败: {str(e)}',
|
||||
'detections': [],
|
||||
'stats': None
|
||||
}
|
||||
|
||||
def _parse_detection_output(self, output: str) -> List[Dict]:
|
||||
"""解析检测输出"""
|
||||
detections = []
|
||||
|
||||
# 查找检测结果行
|
||||
for line in output.split('\n'):
|
||||
if 'class_id:' in line and 'confidence:' in line:
|
||||
try:
|
||||
# 解析: class_id:0, confidence:0.8921, left_top:[268.66,231.64],right_bottom:[351.87,258.66]
|
||||
parts = line.split(',')
|
||||
|
||||
# 提取置信度
|
||||
conf_part = [p for p in parts if 'confidence:' in p][0]
|
||||
confidence = float(conf_part.split(':')[1])
|
||||
|
||||
# 提取坐标
|
||||
left_top_part = [p for p in parts if 'left_top:' in p][0]
|
||||
right_bottom_part = [p for p in parts if 'right_bottom:' in p][0]
|
||||
|
||||
# 解析坐标
|
||||
left_top = eval(left_top_part.split(':')[1])
|
||||
right_bottom = eval(right_bottom_part.split(':')[1])
|
||||
|
||||
x1, y1 = left_top
|
||||
x2, y2 = right_bottom
|
||||
|
||||
detections.append({
|
||||
'class': 'cigarette',
|
||||
'label': '香烟',
|
||||
'confidence': round(confidence, 3),
|
||||
'bbox': [int(x1), int(y1), int(x2), int(y2)]
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"解析检测结果失败: {e}")
|
||||
continue
|
||||
|
||||
return detections
|
||||
|
||||
def _cleanup_temp_files(self, files: List[str]):
|
||||
"""清理临时文件"""
|
||||
for f in files:
|
||||
try:
|
||||
if os.path.exists(f):
|
||||
os.remove(f)
|
||||
except Exception as e:
|
||||
logger.warning(f"清理临时文件失败: {f}, {e}")
|
||||
|
||||
|
||||
# 兼容性包装,保持与 YOLO 模型相同的接口
|
||||
class SmokingDetectionModel:
|
||||
"""抽烟检测模型包装器,兼容 YOLO 接口"""
|
||||
|
||||
def __init__(self):
|
||||
self.service = PaddleDetectionService()
|
||||
self.names = {0: 'cigarette'}
|
||||
|
||||
def __call__(self, image, conf=0.1, iou=0.45, verbose=False):
|
||||
"""
|
||||
模拟 YOLO 模型的调用接口
|
||||
|
||||
Args:
|
||||
image: OpenCV 图片
|
||||
conf: 置信度阈值
|
||||
iou: IoU 阈值
|
||||
verbose: 是否输出详细信息
|
||||
|
||||
Returns:
|
||||
模拟 YOLO 结果的对象
|
||||
"""
|
||||
result = self.service.detect_image(image)
|
||||
|
||||
# 创建模拟的 YOLO 结果对象
|
||||
return [PaddleDetectionResult(result, self.names)]
|
||||
|
||||
|
||||
class PaddleDetectionResult:
|
||||
"""模拟 YOLO 检测结果对象"""
|
||||
|
||||
def __init__(self, detection_result: Dict, names: Dict):
|
||||
self.detection_result = detection_result
|
||||
self.names = names
|
||||
|
||||
# 创建模拟的 boxes 对象
|
||||
self.boxes = self._create_boxes()
|
||||
|
||||
def _create_boxes(self):
|
||||
"""创建模拟的 boxes 对象"""
|
||||
detections = self.detection_result.get('detections', [])
|
||||
|
||||
if not detections:
|
||||
return MockBoxes([])
|
||||
|
||||
# 转换为 YOLO 格式
|
||||
xyxy = []
|
||||
conf = []
|
||||
cls = []
|
||||
|
||||
for det in detections:
|
||||
xyxy.append(det['bbox'])
|
||||
conf.append(det['confidence'])
|
||||
cls.append(0) # cigarette 类别
|
||||
|
||||
return MockBoxes(xyxy, conf, cls)
|
||||
|
||||
|
||||
class MockBoxes:
|
||||
"""模拟 YOLO boxes 对象"""
|
||||
|
||||
def __init__(self, xyxy_list, conf_list=None, cls_list=None):
|
||||
import torch
|
||||
|
||||
if xyxy_list:
|
||||
self.xyxy = torch.tensor(xyxy_list, dtype=torch.float32)
|
||||
self.conf = torch.tensor(conf_list, dtype=torch.float32).reshape(-1, 1)
|
||||
self.cls = torch.tensor(cls_list, dtype=torch.int64).reshape(-1, 1)
|
||||
else:
|
||||
self.xyxy = torch.empty((0, 4))
|
||||
self.conf = torch.empty((0, 1))
|
||||
self.cls = torch.empty((0, 1), dtype=torch.int64)
|
||||
Reference in New Issue
Block a user