feat: 新增车辆检测Paddle模型及相关服务,优化依赖与代码兼容性
1. 新增3套PaddlePaddle车辆检测相关模型文件 2. 新增车辆检测服务类与违停检测功能 3. 更新服务依赖并添加环境初始化脚本与文档 4. 修复YOLO检测tensor转换兼容问题 5. 新增PyTorch版本兼容性修复逻辑 6. 扩展模型服务支持Paddle模型加载
This commit is contained in:
585
apps/server/services/vehicle_detection_service.py
Normal file
585
apps/server/services/vehicle_detection_service.py
Normal file
@@ -0,0 +1,585 @@
|
||||
"""
|
||||
车辆检测服务适配器
|
||||
支持车辆检测、跟踪、车牌识别和违停检测功能
|
||||
"""
|
||||
|
||||
# 禁用 PIR API 以支持旧版模型格式(必须在任何导入之前设置)
|
||||
import os
|
||||
os.environ['FLAGS_enable_pir_api'] = '0'
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import sys
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VehicleTrackingInfo:
|
||||
"""车辆跟踪信息"""
|
||||
track_id: int
|
||||
bbox: List[float]
|
||||
center: Tuple[float, float]
|
||||
first_seen: float
|
||||
last_seen: float
|
||||
plate_number: Optional[str] = None
|
||||
is_illegal_parking: bool = False
|
||||
trajectory: List[Tuple[float, float]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.trajectory is None:
|
||||
self.trajectory = []
|
||||
|
||||
|
||||
class VehicleDetectionService:
|
||||
"""车辆检测服务(本地模式)"""
|
||||
|
||||
def __init__(self):
|
||||
self.model_name = "vehicle_detection"
|
||||
self.threshold = 0.1
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# 本地环境配置
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
self.paddle_dir = os.path.join(project_root, "third-party", "paddle-inference")
|
||||
self.model_dir = os.path.join(project_root, "models", "vehicle_detection_paddle")
|
||||
|
||||
# 模型路径配置
|
||||
self.mot_model_dir = os.path.join(self.model_dir, "mot_ppyoloe_l_36e_ppvehicle")
|
||||
self.plate_det_model_dir = os.path.join(self.model_dir, "ch_PP-OCRv3_det_infer")
|
||||
self.plate_rec_model_dir = os.path.join(self.model_dir, "ch_PP-OCRv3_rec_infer")
|
||||
|
||||
# 检测器实例(延迟加载)
|
||||
self._detector = None
|
||||
self._detector_initialized = False
|
||||
|
||||
# 车辆跟踪信息
|
||||
self.vehicle_tracks: Dict[int, VehicleTrackingInfo] = {}
|
||||
self.track_id_counter = 0
|
||||
|
||||
# 违停检测配置
|
||||
self.illegal_parking_time = 5.0 # 默认5秒
|
||||
self.illegal_parking_region = None # 违停区域多边形
|
||||
|
||||
self.available = True
|
||||
logger.info(f"车辆检测服务初始化完成")
|
||||
logger.info(f"车辆检测模型目录: {self.mot_model_dir}")
|
||||
logger.info(f"车牌检测模型目录: {self.plate_det_model_dir}")
|
||||
logger.info(f"车牌识别模型目录: {self.plate_rec_model_dir}")
|
||||
|
||||
# 禁用 PIR API 以支持旧版模型格式
|
||||
os.environ['FLAGS_enable_pir_api'] = '0'
|
||||
|
||||
try:
|
||||
self._initialize_environment()
|
||||
except Exception as e:
|
||||
logger.error(f"环境初始化失败: {e}")
|
||||
self.available = False
|
||||
|
||||
def _initialize_environment(self):
|
||||
"""初始化本地 PaddlePaddle 环境"""
|
||||
try:
|
||||
# 添加 PaddleDetection 部署路径
|
||||
paddle_detection_path = self.paddle_dir
|
||||
if paddle_detection_path not in sys.path:
|
||||
sys.path.insert(0, paddle_detection_path)
|
||||
logger.info(f"✅ 添加 PaddleDetection 路径: {paddle_detection_path}")
|
||||
|
||||
# 检查模型目录是否存在
|
||||
required_models = {
|
||||
'MOT': self.mot_model_dir,
|
||||
'Plate Detection': self.plate_det_model_dir,
|
||||
'Plate Recognition': self.plate_rec_model_dir
|
||||
}
|
||||
|
||||
for model_name, model_path in required_models.items():
|
||||
if not os.path.exists(model_path):
|
||||
raise Exception(f"{model_name} 模型目录不存在: {model_path}")
|
||||
|
||||
required_files = ['inference.pdmodel', 'inference.pdiparams', 'inference.pdiparams.info']
|
||||
if model_name == 'MOT':
|
||||
required_files = ['model.pdmodel', 'model.pdiparams', 'infer_cfg.yml']
|
||||
|
||||
for file in required_files:
|
||||
file_path = os.path.join(model_path, file)
|
||||
if not os.path.exists(file_path):
|
||||
raise Exception(f"{model_name} 模型文件不存在: {file}")
|
||||
|
||||
logger.info("✅ 环境检查通过")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"环境初始化失败: {e}")
|
||||
raise
|
||||
|
||||
def _get_detector(self):
|
||||
"""获取检测器实例(单例模式)"""
|
||||
if self._detector is None or not self._detector_initialized:
|
||||
try:
|
||||
# 设置环境变量以支持旧版模型格式
|
||||
os.environ['FLAGS_enable_pir_api'] = '0'
|
||||
|
||||
# 添加 PaddleDetection 路径
|
||||
if self.paddle_dir not in sys.path:
|
||||
sys.path.insert(0, self.paddle_dir)
|
||||
|
||||
# 导入 PaddleDetection 模块
|
||||
from infer import Detector, PredictConfig
|
||||
|
||||
# 创建检测器(使用MOT模型)
|
||||
self._detector = Detector(
|
||||
model_dir=self.mot_model_dir,
|
||||
device='CPU',
|
||||
run_mode='paddle',
|
||||
batch_size=1,
|
||||
output_dir='output',
|
||||
threshold=self.threshold
|
||||
)
|
||||
|
||||
self._detector_initialized = True
|
||||
logger.info("✅ 车辆检测器初始化成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检测器初始化失败: {e}")
|
||||
raise
|
||||
|
||||
return self._detector
|
||||
|
||||
def detect_image(self, image: np.ndarray, threshold: float = None) -> Dict:
|
||||
"""
|
||||
检测图片中的车辆
|
||||
|
||||
Args:
|
||||
image: OpenCV 图片 (BGR格式)
|
||||
threshold: 置信度阈值
|
||||
|
||||
Returns:
|
||||
检测结果字典
|
||||
"""
|
||||
if threshold is None:
|
||||
threshold = self.threshold
|
||||
|
||||
if not self.available:
|
||||
return {
|
||||
'success': False,
|
||||
'message': '车辆检测服务不可用',
|
||||
'detections': [],
|
||||
'stats': None
|
||||
}
|
||||
|
||||
try:
|
||||
with self._lock:
|
||||
start_time = time.time()
|
||||
|
||||
# 确保检测器已初始化
|
||||
detector = self._get_detector()
|
||||
|
||||
# 准备输入图片
|
||||
if not isinstance(image, np.ndarray):
|
||||
raise Exception(f"不支持的图片类型: {type(image)}")
|
||||
|
||||
if len(image.shape) == 2:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
||||
elif image.shape[2] == 4:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
|
||||
|
||||
# 执行推理
|
||||
inference_start = time.time()
|
||||
|
||||
results = detector.predict_image(
|
||||
[image],
|
||||
visual=False,
|
||||
save_results=False
|
||||
)
|
||||
|
||||
inference_time = time.time() - inference_start
|
||||
logger.info(f"推理耗时: {inference_time:.3f}s")
|
||||
|
||||
# 解析检测结果
|
||||
detections = self._parse_detection_results(results, threshold)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.info(f"检测总耗时: {total_time:.3f}s")
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': '检测完成',
|
||||
'detections': detections,
|
||||
'stats': {
|
||||
'total_detections': len(detections),
|
||||
'model_used': 'mot_ppyoloe_l_36e_ppvehicle',
|
||||
'threshold': threshold,
|
||||
'processing_time': round(total_time, 3),
|
||||
'inference_time': round(inference_time, 3)
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"检测失败: {e}")
|
||||
logger.error(f"错误堆栈: {traceback.format_exc()}")
|
||||
|
||||
self._detector_initialized = False
|
||||
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'检测失败: {e}',
|
||||
'detections': [],
|
||||
'stats': None
|
||||
}
|
||||
|
||||
def _parse_detection_results(self, results: Dict, threshold: float) -> List[Dict]:
|
||||
"""解析 PaddleDetection 返回的检测结果"""
|
||||
detections = []
|
||||
|
||||
try:
|
||||
if results and 'boxes' in results:
|
||||
boxes = results['boxes']
|
||||
|
||||
if boxes is not None and len(boxes) > 0:
|
||||
for box in boxes:
|
||||
if len(box) >= 6:
|
||||
class_id = int(box[0])
|
||||
confidence = float(box[1])
|
||||
x1, y1, x2, y2 = float(box[2]), float(box[3]), float(box[4]), float(box[5])
|
||||
|
||||
# 计算中心点
|
||||
center_x = (x1 + x2) / 2
|
||||
center_y = (y1 + y2) / 2
|
||||
|
||||
# 过滤低置信度检测
|
||||
if confidence >= threshold:
|
||||
detections.append({
|
||||
'class': 'vehicle',
|
||||
'label': '车辆',
|
||||
'confidence': round(confidence, 3),
|
||||
'bbox': [int(x1), int(y1), int(x2), int(y2)],
|
||||
'center': [round(center_x, 2), round(center_y, 2)]
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析检测结果失败: {e}")
|
||||
|
||||
return detections
|
||||
|
||||
def detect_illegal_parking(self, image: np.ndarray, threshold: float = None,
|
||||
illegal_parking_time: float = 5.0,
|
||||
region_polygon: List[Tuple[int, int]] = None) -> Dict:
|
||||
"""
|
||||
检测违停车辆
|
||||
|
||||
Args:
|
||||
image: OpenCV 图片
|
||||
threshold: 置信度阈值
|
||||
illegal_parking_time: 违停时间阈值(秒)
|
||||
region_polygon: 违停区域多边形点集 [(x1,y1), (x2,y2), ...]
|
||||
|
||||
Returns:
|
||||
违停检测结果
|
||||
"""
|
||||
if threshold is None:
|
||||
threshold = self.threshold
|
||||
|
||||
# 更新违停配置
|
||||
self.illegal_parking_time = illegal_parking_time
|
||||
self.illegal_parking_region = region_polygon
|
||||
|
||||
# 基础车辆检测
|
||||
detection_result = self.detect_image(image, threshold)
|
||||
|
||||
if not detection_result['success']:
|
||||
return {
|
||||
'success': False,
|
||||
'message': detection_result['message'],
|
||||
'illegal_parking': [],
|
||||
'vehicles': []
|
||||
}
|
||||
|
||||
current_time = time.time()
|
||||
current_detections = detection_result['detections']
|
||||
|
||||
# 更新车辆跟踪信息
|
||||
illegal_parking_vehicles = []
|
||||
|
||||
for detection in current_detections:
|
||||
bbox = detection['bbox']
|
||||
center = detection['center']
|
||||
|
||||
# 简单的跟踪(基于位置匹配)
|
||||
matched_track_id = self._match_vehicle_to_track(center, bbox)
|
||||
|
||||
if matched_track_id is None:
|
||||
# 新车辆
|
||||
self.track_id_counter += 1
|
||||
matched_track_id = self.track_id_counter
|
||||
self.vehicle_tracks[matched_track_id] = VehicleTrackingInfo(
|
||||
track_id=matched_track_id,
|
||||
bbox=bbox,
|
||||
center=center,
|
||||
first_seen=current_time,
|
||||
last_seen=current_time,
|
||||
trajectory=[center]
|
||||
)
|
||||
else:
|
||||
# 更新现有车辆
|
||||
track_info = self.vehicle_tracks[matched_track_id]
|
||||
track_info.bbox = bbox
|
||||
track_info.center = center
|
||||
track_info.last_seen = current_time
|
||||
track_info.trajectory.append(center)
|
||||
|
||||
# 检查违停条件
|
||||
if self._check_illegal_parking(track_info, region_polygon):
|
||||
track_info.is_illegal_parking = True
|
||||
illegal_parking_vehicles.append({
|
||||
'track_id': matched_track_id,
|
||||
'bbox': bbox,
|
||||
'center': center,
|
||||
'parking_duration': round(current_time - track_info.first_seen, 2),
|
||||
'plate_number': track_info.plate_number
|
||||
})
|
||||
|
||||
# 清理长时间未出现的车辆
|
||||
self._cleanup_old_tracks(current_time)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': '违停检测完成',
|
||||
'illegal_parking': illegal_parking_vehicles,
|
||||
'total_vehicles': len(current_detections),
|
||||
'stats': detection_result['stats']
|
||||
}
|
||||
|
||||
def _match_vehicle_to_track(self, center: Tuple[float, float],
|
||||
bbox: List[float]) -> Optional[int]:
|
||||
"""将检测到的车辆匹配到已有轨迹"""
|
||||
x, y = center
|
||||
|
||||
for track_id, track_info in self.vehicle_tracks.items():
|
||||
track_x, track_y = track_info.center
|
||||
|
||||
# 计算距离
|
||||
distance = np.sqrt((x - track_x) ** 2 + (y - track_y) ** 2)
|
||||
|
||||
# 距离阈值(基于检测框大小)
|
||||
bbox_width = bbox[2] - bbox[0]
|
||||
bbox_height = bbox[3] - bbox[1]
|
||||
max_dim = max(bbox_width, bbox_height)
|
||||
|
||||
if distance < max_dim * 0.5: # 距离小于检测框最大尺寸的一半
|
||||
return track_id
|
||||
|
||||
return None
|
||||
|
||||
def _check_illegal_parking(self, track_info: VehicleTrackingInfo,
|
||||
region_polygon: List[Tuple[int, int]] = None) -> bool:
|
||||
"""检查是否违停"""
|
||||
current_time = time.time()
|
||||
parking_duration = current_time - track_info.first_seen
|
||||
|
||||
# 检查时间是否超过阈值
|
||||
if parking_duration < self.illegal_parking_time:
|
||||
return False
|
||||
|
||||
# 检查是否在违停区域内
|
||||
if region_polygon is None:
|
||||
return False
|
||||
|
||||
# 检查车辆中心是否在多边形内
|
||||
return self._point_in_polygon(track_info.center, region_polygon)
|
||||
|
||||
def _point_in_polygon(self, point: Tuple[float, float],
|
||||
polygon: List[Tuple[int, int]]) -> bool:
|
||||
"""判断点是否在多边形内(射线法)"""
|
||||
x, y = point
|
||||
n = len(polygon)
|
||||
inside = False
|
||||
|
||||
p1x, p1y = polygon[0]
|
||||
for i in range(n + 1):
|
||||
p2x, p2y = polygon[i % n]
|
||||
if y > min(p1y, p2y):
|
||||
if y <= max(p1y, p2y):
|
||||
if x <= max(p1x, p2x):
|
||||
if p1y != p2y:
|
||||
xinters = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x
|
||||
if p1x == p2x or x <= xinters:
|
||||
inside = not inside
|
||||
p1x, p1y = p2x, p2y
|
||||
|
||||
return inside
|
||||
|
||||
def _cleanup_old_tracks(self, current_time: float):
|
||||
"""清理长时间未出现的车辆轨迹"""
|
||||
timeout = 10.0 # 10秒未出现则删除
|
||||
|
||||
tracks_to_remove = []
|
||||
for track_id, track_info in self.vehicle_tracks.items():
|
||||
if current_time - track_info.last_seen > timeout:
|
||||
tracks_to_remove.append(track_id)
|
||||
|
||||
for track_id in tracks_to_remove:
|
||||
del self.vehicle_tracks[track_id]
|
||||
logger.debug(f"清理车辆轨迹: {track_id}")
|
||||
|
||||
def get_performance_info(self) -> Dict:
|
||||
"""获取性能信息"""
|
||||
return {
|
||||
'mode': 'local',
|
||||
'environment': 'PaddlePaddle',
|
||||
'model_dir': self.model_dir,
|
||||
'mot_model_dir': self.mot_model_dir,
|
||||
'plate_det_model_dir': self.plate_det_model_dir,
|
||||
'plate_rec_model_dir': self.plate_rec_model_dir,
|
||||
'detector_loaded': self._detector_initialized,
|
||||
'available': self.available,
|
||||
'active_tracks': len(self.vehicle_tracks)
|
||||
}
|
||||
|
||||
|
||||
# 兼容性包装,保持与 YOLO 模型相同的接口
|
||||
class VehicleDetectionModel:
|
||||
"""车辆检测模型包装器,兼容 YOLO 接口"""
|
||||
|
||||
def __init__(self):
|
||||
self.service = VehicleDetectionService()
|
||||
self.names = {0: 'vehicle'}
|
||||
|
||||
def __call__(self, image, conf=0.1, iou=0.45, verbose=False):
|
||||
"""
|
||||
模拟 YOLO 模型的调用接口
|
||||
"""
|
||||
result = self.service.detect_image(image, threshold=conf)
|
||||
return [PaddleDetectionResult(result, self.names)]
|
||||
|
||||
def detect_illegal_parking(self, image, conf=0.1, illegal_parking_time=5.0,
|
||||
region_polygon=None):
|
||||
"""违停检测接口"""
|
||||
return self.service.detect_illegal_parking(
|
||||
image, conf, illegal_parking_time, region_polygon
|
||||
)
|
||||
|
||||
|
||||
class PaddleDetectionResult:
|
||||
"""模拟 YOLO 检测结果对象"""
|
||||
|
||||
def __init__(self, detection_result: Dict, names: Dict):
|
||||
self.detection_result = detection_result
|
||||
self.names = names
|
||||
self.boxes = self._create_boxes()
|
||||
|
||||
def _create_boxes(self):
|
||||
"""创建模拟的 boxes 对象"""
|
||||
detections = self.detection_result.get('detections', [])
|
||||
|
||||
if not detections:
|
||||
return MockBoxes([])
|
||||
|
||||
xyxy = []
|
||||
conf = []
|
||||
cls = []
|
||||
|
||||
for det in detections:
|
||||
xyxy.append(det['bbox'])
|
||||
conf.append(det['confidence'])
|
||||
cls.append(0)
|
||||
|
||||
return MockBoxes(xyxy, conf, cls)
|
||||
|
||||
|
||||
class MockBoxes:
|
||||
"""模拟 YOLO boxes 对象"""
|
||||
|
||||
def __init__(self, xyxy_list, conf_list=None, cls_list=None):
|
||||
try:
|
||||
import torch
|
||||
use_torch = True
|
||||
except ImportError:
|
||||
use_torch = False
|
||||
|
||||
if xyxy_list and len(xyxy_list) > 0:
|
||||
if use_torch:
|
||||
self.xyxy = torch.tensor(xyxy_list, dtype=torch.float32)
|
||||
self.conf = torch.tensor(conf_list, dtype=torch.float32).reshape(-1, 1)
|
||||
self.cls = torch.tensor(cls_list, dtype=torch.int64).reshape(-1, 1)
|
||||
else:
|
||||
self.xyxy = np.array(xyxy_list, dtype=np.float32)
|
||||
self.conf = np.array(conf_list, dtype=np.float32).reshape(-1, 1)
|
||||
self.cls = np.array(cls_list, dtype=np.int64).reshape(-1, 1)
|
||||
else:
|
||||
if use_torch:
|
||||
self.xyxy = torch.empty((0, 4), dtype=torch.float32)
|
||||
self.conf = torch.empty((0, 1), dtype=torch.float32)
|
||||
self.cls = torch.empty((0, 1), dtype=torch.int64)
|
||||
else:
|
||||
self.xyxy = np.array([]).reshape(0, 4)
|
||||
self.conf = np.array([]).reshape(0, 1)
|
||||
self.cls = np.array([], dtype=np.int64).reshape(0, 1)
|
||||
|
||||
self._use_torch = use_torch
|
||||
|
||||
def __iter__(self):
|
||||
for i in range(len(self.xyxy)):
|
||||
yield MockBox(
|
||||
self.xyxy[i],
|
||||
self.conf[i][0] if len(self.conf) > i else 0.0,
|
||||
self.cls[i][0] if len(self.cls) > i else 0
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.xyxy)
|
||||
|
||||
def cpu(self):
|
||||
return self
|
||||
|
||||
def numpy(self):
|
||||
if self._use_torch:
|
||||
if len(self.xyxy) > 0:
|
||||
return (
|
||||
self.xyxy.numpy(),
|
||||
self.conf.numpy(),
|
||||
self.cls.numpy()
|
||||
)
|
||||
else:
|
||||
return (
|
||||
np.array([]).reshape(0, 4),
|
||||
np.array([]).reshape(0, 1),
|
||||
np.array([], dtype=np.int64).reshape(0, 1)
|
||||
)
|
||||
else:
|
||||
return (
|
||||
self.xyxy,
|
||||
self.conf,
|
||||
self.cls
|
||||
)
|
||||
|
||||
|
||||
class MockBox:
|
||||
"""模拟单个 YOLO box 对象"""
|
||||
|
||||
def __init__(self, xyxy, conf, cls):
|
||||
try:
|
||||
import torch
|
||||
use_torch = True
|
||||
except ImportError:
|
||||
use_torch = False
|
||||
|
||||
if use_torch:
|
||||
if isinstance(xyxy, torch.Tensor):
|
||||
self.xyxy = xyxy
|
||||
else:
|
||||
self.xyxy = torch.tensor(xyxy, dtype=torch.float32)
|
||||
else:
|
||||
if isinstance(xyxy, np.ndarray):
|
||||
self.xyxy = xyxy
|
||||
else:
|
||||
self.xyxy = np.array(xyxy, dtype=np.float32)
|
||||
|
||||
self.conf = conf
|
||||
self.cls = cls
|
||||
Reference in New Issue
Block a user