打架斗殴模型集成
This commit is contained in:
265
apps/server/services/action_detection_service.py
Normal file
265
apps/server/services/action_detection_service.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
PaddleVideo 行为识别服务适配器(Docker API 调用方式)
|
||||
通过 HTTP 请求调用运行在 Docker 中的 ppTSM 行为识别模型
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import logging
|
||||
import httpx
|
||||
import base64
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 获取外部项目路径
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
EXTERNAL_PADDLE_PATH = os.path.join(BASE_DIR, 'external', 'video-recognition-system', 'PaddlePaddle')
|
||||
|
||||
|
||||
class ActionDetectionService:
|
||||
"""行为识别服务(调用外部 Docker 服务)"""
|
||||
|
||||
def __init__(self):
|
||||
# 从环境变量获取 Docker 服务地址
|
||||
self.api_base_url = os.environ.get(
|
||||
'ACTION_DETECTION_API_URL',
|
||||
'http://localhost:8081' # 统一使用 8081 端口
|
||||
)
|
||||
self.timeout = int(os.environ.get('ACTION_DETECTION_TIMEOUT', '30'))
|
||||
|
||||
# 类别定义(根据你的 ppTSM 模型)
|
||||
self.classes = ['fight', 'normal']
|
||||
self.labels = {0: 'fight', 1: 'normal'}
|
||||
|
||||
# 外部项目路径
|
||||
self.external_paddle_path = EXTERNAL_PADDLE_PATH
|
||||
self.model_path = os.path.join(self.external_paddle_path, 'PaddleVideo', 'inference', 'ppTSM')
|
||||
|
||||
# 健康检查
|
||||
self.available = self._check_service_health()
|
||||
|
||||
logger.info(f"行为识别服务初始化完成")
|
||||
logger.info(f"API地址: {self.api_base_url}")
|
||||
logger.info(f"外部项目路径: {self.external_paddle_path}")
|
||||
logger.info(f"服务可用: {self.available}")
|
||||
|
||||
def _check_service_health(self) -> bool:
|
||||
"""检查 Docker 服务是否可用"""
|
||||
try:
|
||||
with httpx.Client(timeout=5) as client:
|
||||
response = client.get(f"{self.api_base_url}/health")
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
logger.warning(f"服务健康检查失败: {e}")
|
||||
return False
|
||||
|
||||
def detect_image(self, image: np.ndarray, threshold: float = 0.5) -> Dict:
|
||||
"""
|
||||
调用 Docker 服务进行行为识别
|
||||
|
||||
Args:
|
||||
image: OpenCV 图片 (BGR格式)
|
||||
threshold: 置信度阈值
|
||||
|
||||
Returns:
|
||||
检测结果字典
|
||||
"""
|
||||
if not self.available:
|
||||
return {
|
||||
'success': False,
|
||||
'message': '行为识别服务不可用',
|
||||
'detections': [],
|
||||
'stats': None
|
||||
}
|
||||
|
||||
try:
|
||||
# 将图片转换为 base64
|
||||
_, img_encoded = cv2.imencode('.jpg', image)
|
||||
img_base64 = base64.b64encode(img_encoded).decode('utf-8')
|
||||
|
||||
# 构建请求数据
|
||||
payload = {
|
||||
'image': img_base64,
|
||||
'threshold': threshold
|
||||
}
|
||||
|
||||
# 调用 Docker 服务
|
||||
with httpx.Client(timeout=self.timeout) as client:
|
||||
response = client.post(
|
||||
f"{self.api_base_url}/api/detect",
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
return {
|
||||
'success': False,
|
||||
'message': f"API调用失败: {response.status_code}",
|
||||
'detections': [],
|
||||
'stats': None
|
||||
}
|
||||
|
||||
# 解析响应
|
||||
result = response.json()
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': '检测完成',
|
||||
'detections': result.get('detections', []),
|
||||
'stats': result.get('stats', None)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检测失败: {e}")
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'检测失败: {e}',
|
||||
'detections': [],
|
||||
'stats': None
|
||||
}
|
||||
|
||||
|
||||
class ActionDetectionModel:
|
||||
"""行为识别模型包装器,兼容 YOLO 接口"""
|
||||
|
||||
def __init__(self):
|
||||
self.service = ActionDetectionService()
|
||||
self.names = {0: 'fight', 1: 'normal'}
|
||||
|
||||
def __call__(self, image, conf=0.5, iou=0.45, verbose=False):
|
||||
"""
|
||||
模拟 YOLO 模型的调用接口
|
||||
|
||||
Args:
|
||||
image: OpenCV 图片
|
||||
conf: 置信度阈值
|
||||
iou: IoU 阈值
|
||||
verbose: 是否输出详细信息
|
||||
|
||||
Returns:
|
||||
模拟 YOLO 结果的对象
|
||||
"""
|
||||
result = self.service.detect_image(image, threshold=conf)
|
||||
return [ActionDetectionResult(result, self.names)]
|
||||
|
||||
|
||||
class ActionDetectionResult:
|
||||
"""模拟 YOLO 检测结果对象"""
|
||||
|
||||
def __init__(self, detection_result: Dict, names: Dict):
|
||||
self.detection_result = detection_result
|
||||
self.names = names
|
||||
self.boxes = self._create_boxes()
|
||||
|
||||
def _create_boxes(self):
|
||||
"""创建模拟的 boxes 对象"""
|
||||
detections = self.detection_result.get('detections', [])
|
||||
|
||||
if not detections:
|
||||
return MockBoxes([])
|
||||
|
||||
xyxy = []
|
||||
conf = []
|
||||
cls = []
|
||||
|
||||
for det in detections:
|
||||
if 'bbox' in det:
|
||||
xyxy.append(det['bbox'])
|
||||
conf.append(det.get('confidence', 0.0))
|
||||
cls.append(det.get('class_id', 0))
|
||||
|
||||
return MockBoxes(xyxy, conf, cls)
|
||||
|
||||
|
||||
class MockBoxes:
|
||||
"""模拟 YOLO boxes 对象"""
|
||||
|
||||
def __init__(self, xyxy_list, conf_list=None, cls_list=None):
|
||||
try:
|
||||
import torch
|
||||
use_torch = True
|
||||
except ImportError:
|
||||
use_torch = False
|
||||
|
||||
if xyxy_list and len(xyxy_list) > 0:
|
||||
if use_torch:
|
||||
self.xyxy = torch.tensor(xyxy_list, dtype=torch.float32)
|
||||
self.conf = torch.tensor(conf_list, dtype=torch.float32).reshape(-1, 1)
|
||||
self.cls = torch.tensor(cls_list, dtype=torch.int64).reshape(-1, 1)
|
||||
else:
|
||||
self.xyxy = np.array(xyxy_list, dtype=np.float32)
|
||||
self.conf = np.array(conf_list, dtype=np.float32).reshape(-1, 1)
|
||||
self.cls = np.array(cls_list, dtype=np.int64).reshape(-1, 1)
|
||||
else:
|
||||
if use_torch:
|
||||
self.xyxy = torch.empty((0, 4), dtype=torch.float32)
|
||||
self.conf = torch.empty((0, 1), dtype=torch.float32)
|
||||
self.cls = torch.empty((0, 1), dtype=torch.int64)
|
||||
else:
|
||||
self.xyxy = np.array([]).reshape(0, 4)
|
||||
self.conf = np.array([]).reshape(0, 1)
|
||||
self.cls = np.array([]).reshape(0, 1)
|
||||
|
||||
self._use_torch = use_torch
|
||||
|
||||
def __iter__(self):
|
||||
for i in range(len(self.xyxy)):
|
||||
yield MockBox(
|
||||
self.xyxy[i],
|
||||
self.conf[i][0] if len(self.conf) > i else 0.0,
|
||||
self.cls[i][0] if len(self.cls) > i else 0
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.xyxy)
|
||||
|
||||
def cpu(self):
|
||||
return self
|
||||
|
||||
def numpy(self):
|
||||
if self._use_torch:
|
||||
if len(self.xyxy) > 0:
|
||||
return (
|
||||
self.xyxy.numpy(),
|
||||
self.conf.numpy(),
|
||||
self.cls.numpy()
|
||||
)
|
||||
else:
|
||||
return (
|
||||
np.array([]).reshape(0, 4),
|
||||
np.array([]).reshape(0, 1),
|
||||
np.array([], dtype=np.int64).reshape(0, 1)
|
||||
)
|
||||
else:
|
||||
return (
|
||||
self.xyxy,
|
||||
self.conf,
|
||||
self.cls
|
||||
)
|
||||
|
||||
|
||||
class MockBox:
|
||||
"""模拟单个 YOLO box 对象"""
|
||||
|
||||
def __init__(self, xyxy, conf, cls):
|
||||
try:
|
||||
import torch
|
||||
use_torch = True
|
||||
except ImportError:
|
||||
use_torch = False
|
||||
|
||||
if use_torch:
|
||||
if isinstance(xyxy, torch.Tensor):
|
||||
self.xyxy = xyxy
|
||||
else:
|
||||
self.xyxy = torch.tensor(xyxy, dtype=torch.float32)
|
||||
else:
|
||||
if isinstance(xyxy, np.ndarray):
|
||||
self.xyxy = xyxy
|
||||
else:
|
||||
self.xyxy = np.array(xyxy, dtype=np.float32)
|
||||
|
||||
self.conf = conf
|
||||
self.cls = cls
|
||||
Reference in New Issue
Block a user