feat: 新增人员徘徊/静止行为分析功能

本次提交实现了完整的人员行为分析系统,包括:
1. 新增基于位置和跟踪ID的两种行为检测算法
2. 新增徘徊检测服务与行为处理器模块
3. 前后端集成算法配置界面与告警展示
4. 支持图片和视频流场景下的行为分析
5. 新增算法配置接口与文档说明

具体改动:
- 新增loitering_detection模型目录与算法实现
- 新增AlgorithmConfig组件实现可视化配置
- 扩展图片/视频检测接口支持算法参数传递
- 新增行为告警推送与前端展示页面
- 优化检测服务,集成行为分析逻辑
- 移除冗余日志输出,完善代码注释
This commit is contained in:
wwh
2026-05-19 09:17:09 +08:00
parent 2691761f01
commit 7aa71c5f83
15 changed files with 1937 additions and 76 deletions

View File

@@ -2,24 +2,50 @@ import cv2
import numpy as np import numpy as np
import base64 import base64
import logging import logging
import json
from typing import Optional
from fastapi import APIRouter, UploadFile, File, Form, Query from fastapi import APIRouter, UploadFile, File, Form, Query
from models.schemas import ImageDetectionResult from models.schemas import ImageDetectionResult
router = APIRouter() router = APIRouter()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@router.post("/detect/image", response_model=ImageDetectionResult) @router.post("/detect/image", response_model=ImageDetectionResult)
async def detect_image( async def detect_image(
file: UploadFile = File(...), file: UploadFile = File(...),
model_id: str = Query("fire_detection"), model_id: str = Query("fire_detection"),
confidence: float = Query(0.5), confidence: float = Query(0.5),
iou: float = Query(0.45) iou: float = Query(0.45),
algorithm_config: Optional[str] = Query(None, description="算法配置JSON字符串")
): ):
"""
图片检测接口
Args:
algorithm_config: 算法配置JSON例如
{
"enable_stationary_detection": true,
"enable_loitering_detection": false,
"stationary_threshold": 10.0,
"position_tolerance": 50,
"loitering_threshold": 300.0,
"movement_threshold": 5.0
}
"""
from main import model_service from main import model_service
from services.detection_service import DetectionService from services.detection_service import DetectionService
detection_service = DetectionService(model_service) detection_service = DetectionService(model_service)
# 解析算法配置
algo_config = None
if algorithm_config:
try:
algo_config = json.loads(algorithm_config)
except json.JSONDecodeError as e:
logger.warning(f"算法配置解析失败: {e}")
try: try:
contents = await file.read() contents = await file.read()
nparr = np.frombuffer(contents, np.uint8) nparr = np.frombuffer(contents, np.uint8)
@@ -32,10 +58,14 @@ async def detect_image(
data={} data={}
) )
result = await detection_service.detect_image(frame, model_id, confidence, iou) result = await detection_service.detect_image(
frame, model_id, confidence, iou, algorithm_config=algo_config
)
if result['success']: if result['success']:
annotated_frame = detection_service.draw_detections(frame, result['detections']) annotated_frame = detection_service.draw_detections(
frame, result['detections'], algorithm_config=algo_config
)
# 将标注后的图片转换为 base64 # 将标注后的图片转换为 base64
_, buffer = cv2.imencode('.jpg', annotated_frame) _, buffer = cv2.imencode('.jpg', annotated_frame)
@@ -47,7 +77,9 @@ async def detect_image(
data={ data={
"detections": result['detections'], "detections": result['detections'],
"image_base64": img_base64, "image_base64": img_base64,
"stats": result['stats'] "stats": result['stats'],
"alerts": result.get('alerts', []),
"behavior_stats": result.get('behavior_stats', {})
} }
) )
else: else:
@@ -64,3 +96,66 @@ async def detect_image(
message=f"检测失败: {str(e)}", message=f"检测失败: {str(e)}",
data={} data={}
) )
@router.get("/algorithms/config")
async def get_algorithm_config():
"""获取算法配置选项"""
return {
"algorithms": [
{
"id": "stationary_detection",
"name": "静止检测",
"description": "检测人员在同一位置静止停留",
"params": [
{
"name": "stationary_threshold",
"label": "静止阈值",
"type": "number",
"default": 10.0,
"min": 1.0,
"max": 300.0,
"unit": "",
"description": "超过此时间视为静止"
},
{
"name": "position_tolerance",
"label": "位置容差",
"type": "number",
"default": 50,
"min": 10,
"max": 200,
"unit": "像素",
"description": "位置匹配容差范围"
}
]
},
{
"id": "loitering_detection",
"name": "徘徊检测",
"description": "检测人员长时间停留需要跟踪ID",
"params": [
{
"name": "loitering_threshold",
"label": "徘徊阈值",
"type": "number",
"default": 300.0,
"min": 60.0,
"max": 1800.0,
"unit": "",
"description": "超过此时间视为徘徊"
},
{
"name": "movement_threshold",
"label": "移动阈值",
"type": "number",
"default": 5.0,
"min": 1.0,
"max": 50.0,
"unit": "像素",
"description": "小于此移动视为静止"
}
]
}
]
}

View File

@@ -249,11 +249,21 @@ class CameraService:
logger.info(f"发送检测结果: {len(result['detections'])} 个目标, {result['stats']}") logger.info(f"发送检测结果: {len(result['detections'])} 个目标, {result['stats']}")
await websocket.send_json({ detection_message = {
'type': 'detection', 'type': 'detection',
'detections': result['detections'], 'detections': result['detections'],
'stats': result['stats'] 'stats': result['stats']
}) }
# 包含行为告警信息
if 'alerts' in result and result['alerts']:
detection_message['alerts'] = result['alerts']
logger.info(f"发送告警: {len(result['alerts'])}")
if 'behavior_stats' in result:
detection_message['behavior_stats'] = result['behavior_stats']
await websocket.send_json(detection_message)
_, buffer = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 80]) _, buffer = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 80])
import base64 import base64

View File

@@ -7,6 +7,8 @@ import logging
from typing import Dict, List, Optional from typing import Dict, List, Optional
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from .loitering_service import get_loitering_service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DetectionService: class DetectionService:
@@ -18,64 +20,20 @@ class DetectionService:
os.makedirs(self.results_dir, exist_ok=True) os.makedirs(self.results_dir, exist_ok=True)
os.makedirs(self.temp_dir, exist_ok=True) os.makedirs(self.temp_dir, exist_ok=True)
def draw_detections(self, frame: np.ndarray, detections: List[Dict], fps: float = 0) -> np.ndarray: # 初始化徘徊检测服务(懒加载,实际初始化在第一次使用时)
try: self.loitering_service = get_loitering_service()
img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_img = Image.fromarray(img_rgb)
draw = ImageDraw.Draw(pil_img)
try:
font = ImageFont.truetype("/System/Library/Fonts/PingFang.ttc", 20)
font_large = ImageFont.truetype("/System/Library/Fonts/PingFang.ttc", 24)
except:
font = ImageFont.load_default()
font_large = font
class_colors = {
'Fire': (255, 0, 0),
'Smoke': (128, 128, 128),
'person': (0, 255, 0),
'helmet': (255, 255, 0),
'no_helmet': (255, 0, 255),
'cigarette': (0, 165, 255) # 橙色,用于抽烟检测
}
for det in detections:
x1, y1, x2, y2 = det['bbox']
class_name = det['class']
conf = det['confidence']
label = det['label']
color = class_colors.get(class_name, (0, 255, 0))
draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
label_text = f"{label} {conf:.2f}"
bbox = draw.textbbox((0, 0), label_text, font=font)
text_w = bbox[2] - bbox[0]
text_h = bbox[3] - bbox[1]
draw.rectangle([x1, y1 - text_h - 4, x1 + text_w + 4, y1], fill=color)
draw.text((x1 + 2, y1 - text_h - 2), label_text, fill=(255, 255, 255), font=font)
if fps > 0:
fps_text = f"FPS: {fps:.1f} | Detections: {len(detections)}"
draw.text((10, 10), fps_text, fill=(0, 255, 0), font=font)
return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
except Exception as e:
logger.error(f"绘制检测结果失败: {e}")
return frame
async def detect_image( async def detect_image(
self, self,
image: np.ndarray, image: np.ndarray,
model_id: str, model_id: str,
confidence: float = 0.5, confidence: float = 0.5,
iou: float = 0.45 iou: float = 0.45,
algorithm_config: Optional[Dict] = None
) -> Dict: ) -> Dict:
start_time = time.time() start_time = time.time()
model = await self.model_service.load_model(model_id) model = await self.model_service.load_model(model_id)
if not model: if not model:
return { return {
@@ -84,10 +42,10 @@ class DetectionService:
'detections': [], 'detections': [],
'stats': None 'stats': None
} }
try: try:
results = model(image, conf=confidence, iou=iou, verbose=False) results = model(image, conf=confidence, iou=iou, verbose=False)
detections = [] detections = []
for result in results: for result in results:
boxes = result.boxes boxes = result.boxes
@@ -96,21 +54,21 @@ class DetectionService:
conf = float(box.conf[0].cpu().numpy()) conf = float(box.conf[0].cpu().numpy())
cls = int(box.cls[0].cpu().numpy()) cls = int(box.cls[0].cpu().numpy())
class_name = result.names[cls] class_name = result.names[cls]
label_map = self.model_service.model_configs[model_id]['labels'] label_map = self.model_service.model_configs[model_id]['labels']
label = label_map.get(class_name, class_name) label = label_map.get(class_name, class_name)
detections.append({ detections.append({
'class': class_name, 'class': class_name,
'label': label, 'label': label,
'confidence': round(conf, 3), 'confidence': round(conf, 3),
'bbox': [int(x1), int(y1), int(x2), int(y2)] 'bbox': [int(x1), int(y1), int(x2), int(y2)]
}) })
processing_time = time.time() - start_time processing_time = time.time() - start_time
avg_confidence = sum(d['confidence'] for d in detections) / len(detections) if detections else 0 avg_confidence = sum(d['confidence'] for d in detections) / len(detections) if detections else 0
return { result_data = {
'success': True, 'success': True,
'message': '检测完成', 'message': '检测完成',
'detections': detections, 'detections': detections,
@@ -121,6 +79,14 @@ class DetectionService:
'model_used': model_id 'model_used': model_id
} }
} }
# 如果启用了行为检测算法
if algorithm_config and detections:
result_data = self._apply_behavior_analysis(
result_data, algorithm_config
)
return result_data
except Exception as e: except Exception as e:
logger.error(f"图片检测失败: {e}") logger.error(f"图片检测失败: {e}")
return { return {
@@ -186,9 +152,40 @@ class DetectionService:
} }
} }
# 如果是人员检测模型,进行行为分析
logger.info(f"[DetectionService] 模型: {model_id}, 检测目标: {len(detections)}")
if model_id == 'loitering_detection' and detections:
logger.info("[DetectionService] 调用行为分析...")
# 确保服务已初始化
if not self.loitering_service.is_initialized:
logger.info("[DetectionService] 初始化徘徊检测服务...")
self.loitering_service.initialize(
# 检测阈值(用于判断是否静止/徘徊)
stationary_threshold=10.0,
position_tolerance=50,
loitering_threshold=300.0,
movement_threshold=5.0,
# 告警阈值(用于触发告警,应该比检测阈值高)
stationary_alert_threshold=30.0,
loitering_alert_threshold=600.0,
# 启用告警
enable_stationary_alert=True,
enable_loitering_alert=True
)
behavior_result = self.loitering_service.process_detections(
detections,
use_tracking=False # 可以改为 True 如果使用跟踪
)
detections = behavior_result['detections']
result_data['alerts'] = behavior_result['alerts']
result_data['behavior_stats'] = behavior_result['stats']
logger.info(f"[DetectionService] 行为分析完成: alerts={len(behavior_result['alerts'])}, stats={behavior_result['stats']}")
if draw: if draw:
frame = self.draw_detections(frame, detections, fps) frame = self.draw_detections(frame, detections, fps)
return frame, result_data return frame, result_data
except Exception as e: except Exception as e:
logger.error(f"帧检测失败: {e}") logger.error(f"帧检测失败: {e}")
@@ -197,3 +194,139 @@ class DetectionService:
'detections': [], 'detections': [],
'stats': None 'stats': None
} }
def _apply_behavior_analysis(
self,
result_data: Dict,
algorithm_config: Dict
) -> Dict:
"""
应用行为分析算法
Args:
result_data: 检测结果
algorithm_config: 算法配置
{
"enable_stationary_detection": true,
"enable_loitering_detection": false,
"stationary_threshold": 10.0,
"position_tolerance": 50,
...
}
Returns:
添加行为分析结果的检测结果
"""
detections = result_data['detections']
# 检查是否需要行为分析
enable_stationary = algorithm_config.get('enable_stationary_detection', False)
enable_loitering = algorithm_config.get('enable_loitering_detection', False)
if not enable_stationary and not enable_loitering:
return result_data
try:
# 使用前端传入的配置初始化服务
self.loitering_service.initialize(
stationary_threshold=algorithm_config.get('stationary_threshold', 10.0),
position_tolerance=algorithm_config.get('position_tolerance', 50),
loitering_threshold=algorithm_config.get('loitering_threshold', 300.0),
movement_threshold=algorithm_config.get('movement_threshold', 5.0),
enable_stationary_alert=enable_stationary,
enable_loitering_alert=enable_loitering
)
# 处理检测
behavior_result = self.loitering_service.process_detections(
detections,
use_tracking=enable_loitering # 只有启用徘徊检测时才使用跟踪
)
result_data['detections'] = behavior_result['detections']
result_data['alerts'] = behavior_result['alerts']
result_data['behavior_stats'] = behavior_result['stats']
except Exception as e:
logger.error(f"行为分析失败: {e}")
result_data['behavior_error'] = str(e)
return result_data
def draw_detections(
self,
frame: np.ndarray,
detections: List[Dict],
fps: float = 0,
algorithm_config: Optional[Dict] = None
) -> np.ndarray:
"""
绘制检测结果和行为告警
Args:
frame: 图像帧
detections: 检测结果列表(可能包含 stationary_info/loitering_info
fps: 帧率
algorithm_config: 算法配置(已废弃,保留用于向后兼容)
"""
try:
img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_img = Image.fromarray(img_rgb)
draw = ImageDraw.Draw(pil_img)
try:
font = ImageFont.truetype("/System/Library/Fonts/PingFang.ttc", 20)
font_large = ImageFont.truetype("/System/Library/Fonts/PingFang.ttc", 24)
except:
font = ImageFont.load_default()
font_large = font
class_colors = {
'Fire': (255, 0, 0),
'Smoke': (128, 128, 128),
'person': (0, 255, 0),
'helmet': (255, 255, 0),
'no_helmet': (255, 0, 255),
'cigarette': (0, 165, 255)
}
for det in detections:
x1, y1, x2, y2 = det['bbox']
class_name = det['class']
conf = det['confidence']
label = det['label']
# 根据是否有行为告警选择颜色
color = class_colors.get(class_name, (0, 255, 0))
# 检查行为告警
if algorithm_config:
if 'stationary_info' in det:
info = det['stationary_info']
if info.get('is_stationary'):
color = (0, 0, 255) # 红色警告
label = f"静止{int(info['duration'])}s"
if 'loitering_info' in det:
info = det['loitering_info']
if info.get('is_loitering'):
color = (255, 0, 0) # 蓝色警告
label = f"徘徊{int(info['loitering_duration']//60)}min"
draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
label_text = f"{label} {conf:.2f}"
bbox = draw.textbbox((0, 0), label_text, font=font)
text_w = bbox[2] - bbox[0]
text_h = bbox[3] - bbox[1]
draw.rectangle([x1, y1 - text_h - 4, x1 + text_w + 4, y1], fill=color)
draw.text((x1 + 2, y1 - text_h - 2), label_text, fill=(255, 255, 255), font=font)
if fps > 0:
fps_text = f"FPS: {fps:.1f} | Detections: {len(detections)}"
draw.text((10, 10), fps_text, fill=(0, 255, 0), font=font)
return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
except Exception as e:
logger.error(f"绘制检测结果失败: {e}")
return frame

View File

@@ -0,0 +1,168 @@
"""
徘徊检测服务
集成行为检测算法到后端服务
"""
import sys
import os
from typing import Dict, List, Optional
import logging
# 添加算法模块路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..', 'models', 'loitering_detection'))
from processors import BehaviorProcessor
logger = logging.getLogger(__name__)
class LoiteringService:
"""
徘徊检测服务
为视频流检测提供行为分析功能:
- 静止检测(基于位置,无需跟踪)
- 徘徊检测基于跟踪ID
"""
def __init__(self):
self.processor = None
self.is_initialized = False
def initialize(
self,
stationary_threshold: float = 10.0,
position_tolerance: int = 50,
loitering_threshold: float = 300.0,
movement_threshold: float = 5.0,
enable_stationary_alert: bool = True,
enable_loitering_alert: bool = True,
stationary_alert_threshold: Optional[float] = None,
loitering_alert_threshold: Optional[float] = None
):
"""
初始化服务
Args:
stationary_threshold: 静止检测阈值(秒)- 用于判断是否静止
position_tolerance: 位置容差(像素)
loitering_threshold: 徘徊检测阈值(秒)- 用于判断是否徘徊
movement_threshold: 移动阈值(像素)
enable_stationary_alert: 是否启用静止告警
enable_loitering_alert: 是否启用徘徊告警
stationary_alert_threshold: 静止告警阈值(秒)- 超过此时间产生告警,默认等于 stationary_threshold
loitering_alert_threshold: 徘徊告警阈值(秒)- 超过此时间产生告警,默认等于 loitering_threshold
"""
try:
self.processor = BehaviorProcessor(
stationary_threshold=stationary_threshold,
position_tolerance=position_tolerance,
loitering_threshold=loitering_threshold,
movement_threshold=movement_threshold,
enable_stationary_alert=enable_stationary_alert,
enable_loitering_alert=enable_loitering_alert,
stationary_alert_threshold=stationary_alert_threshold if stationary_alert_threshold is not None else stationary_threshold,
loitering_alert_threshold=loitering_alert_threshold if loitering_alert_threshold is not None else loitering_threshold
)
self.is_initialized = True
logger.info(f"徘徊检测服务初始化成功: 静止阈值={stationary_threshold}s, 告警阈值={stationary_alert_threshold or stationary_threshold}s")
except Exception as e:
logger.error(f"徘徊检测服务初始化失败: {e}")
self.is_initialized = False
def process_detections(
self,
detections: List[Dict],
use_tracking: bool = False,
track_id_key: str = 'track_id'
) -> Dict:
"""
处理检测结果
Args:
detections: YOLO检测结果列表
use_tracking: 是否使用跟踪ID
track_id_key: 跟踪ID字段名
Returns:
{
'detections': 添加行为信息的检测结果,
'alerts': 触发的告警列表,
'stats': 统计信息
}
"""
if not self.is_initialized or not self.processor:
return {
'detections': detections,
'alerts': [],
'stats': {'error': '服务未初始化'}
}
try:
return self.processor.process(
detections=detections,
use_tracking=use_tracking,
track_id_key=track_id_key
)
except Exception as e:
logger.error(f"处理检测结果失败: {e}")
return {
'detections': detections,
'alerts': [],
'stats': {'error': str(e)}
}
def get_stationary_persons(self) -> List[Dict]:
"""获取所有静止人员"""
if not self.is_initialized or not self.processor:
return []
return self.processor.get_stationary_persons()
def get_loitering_persons(self) -> List[Dict]:
"""获取所有徘徊人员"""
if not self.is_initialized or not self.processor:
return []
return self.processor.get_loitering_persons()
def reset(self):
"""重置检测器"""
if self.processor:
self.processor.reset()
logger.info("徘徊检测器已重置")
def get_config(self) -> Dict:
"""获取当前配置"""
if not self.is_initialized or not self.processor:
return {'error': '服务未初始化'}
return self.processor.get_config()
def get_stats(self) -> Dict:
"""获取统计信息"""
if not self.is_initialized or not self.processor:
return {'error': '服务未初始化'}
stats = {
'stationary_count': len(self.get_stationary_persons()),
'loitering_count': len(self.get_loitering_persons()),
'config': self.get_config()
}
return stats
# 全局服务实例
_loitering_service: Optional[LoiteringService] = None
def get_loitering_service() -> LoiteringService:
"""获取全局徘徊检测服务实例"""
global _loitering_service
if _loitering_service is None:
_loitering_service = LoiteringService()
return _loitering_service
def initialize_loitering_service(**kwargs):
"""初始化全局徘徊检测服务"""
service = get_loitering_service()
service.initialize(**kwargs)
return service

View File

@@ -82,7 +82,6 @@ class ModelService:
return None return None
if model_id in self.models: if model_id in self.models:
logger.info(f"模型已加载: {model_id}")
return self.models[model_id] return self.models[model_id]
config = self.model_configs[model_id] config = self.model_configs[model_id]

View File

@@ -9,12 +9,22 @@ export const detectionApi = {
getModels() { getModels() {
return api.get('/models') return api.get('/models')
}, },
detectImage(formData) { getAlgorithmConfig() {
return api.get('/algorithms/config')
},
detectImage(formData, algorithmConfig = null) {
const params = {}
if (algorithmConfig) {
params.algorithm_config = JSON.stringify(algorithmConfig)
}
return api.post('/detect/image', formData, { return api.post('/detect/image', formData, {
headers: { headers: {
'Content-Type': 'multipart/form-data' 'Content-Type': 'multipart/form-data'
} },
params
}) })
} }
} }

View File

@@ -0,0 +1,326 @@
<template>
<div v-if="showConfig" class="algorithm-config">
<el-divider content-position="left">
<el-icon><Cpu /></el-icon>
<span style="margin-left: 8px;">行为分析算法</span>
</el-divider>
<div v-if="loading" class="loading-wrapper">
<el-skeleton :rows="3" animated />
</div>
<div v-else-if="algorithms.length === 0" class="empty-config">
<el-empty description="暂无可配置算法" :image-size="60" />
</div>
<div v-else class="algorithm-list">
<div
v-for="algo in algorithms"
:key="algo.id"
class="algorithm-item"
>
<div class="algorithm-header">
<el-switch
v-model="config[algo.id].enabled"
@change="onConfigChange"
:active-text="algo.name"
/>
<el-tooltip :content="algo.description" placement="top">
<el-icon class="info-icon"><InfoFilled /></el-icon>
</el-tooltip>
</div>
<div v-if="config[algo.id].enabled" class="algorithm-params">
<div
v-for="param in algo.params"
:key="param.name"
class="param-item"
>
<div class="param-label">
{{ param.label }}
<el-tooltip :content="param.description" placement="top">
<el-icon class="help-icon"><QuestionFilled /></el-icon>
</el-tooltip>
</div>
<div class="param-control">
<el-slider
v-if="param.type === 'number'"
v-model="config[algo.id].params[param.name]"
:min="param.min"
:max="param.max"
:step="param.name.includes('threshold') ? 1 : 1"
@change="onConfigChange"
show-input
:show-input-controls="false"
input-size="small"
/>
</div>
<div class="param-unit">{{ param.unit }}</div>
</div>
</div>
</div>
</div>
<div class="config-actions">
<el-button size="small" @click="resetConfig" :icon="RefreshRight">
重置
</el-button>
<el-button type="primary" size="small" @click="applyConfig" :icon="Check">
应用
</el-button>
</div>
</div>
</template>
<script setup>
import { ref, reactive, onMounted, watch, computed } from 'vue'
import { ElMessage } from 'element-plus'
import {
Cpu,
InfoFilled,
QuestionFilled,
RefreshRight,
Check
} from '@element-plus/icons-vue'
import { detectionApi } from '@/api/detection'
const props = defineProps({
modelValue: {
type: Object,
default: () => ({})
},
modelId: {
type: String,
default: ''
}
})
// 支持行为分析的模型列表
const SUPPORTED_MODELS = [
'loitering_detection', // 徘徊检测
'crowd_detection', // 人群检测
'person_detection' // 人员检测
]
// 是否显示配置
const showConfig = computed(() => {
return SUPPORTED_MODELS.some(model => props.modelId.includes(model))
})
const emit = defineEmits(['update:modelValue', 'change'])
const loading = ref(false)
const algorithms = ref([])
const config = reactive({})
// 获取算法配置
const fetchAlgorithmConfig = async () => {
loading.value = true
try {
const response = await detectionApi.getAlgorithmConfig()
algorithms.value = response.data.algorithms || []
// 初始化配置
algorithms.value.forEach(algo => {
if (!config[algo.id]) {
config[algo.id] = {
enabled: false,
params: {}
}
}
// 设置默认参数
algo.params.forEach(param => {
if (config[algo.id].params[param.name] === undefined) {
config[algo.id].params[param.name] = param.default
}
})
})
} catch (error) {
console.error('获取算法配置失败:', error)
} finally {
loading.value = false
}
}
// 生成后端需要的配置格式
const generateConfig = () => {
const result = {}
algorithms.value.forEach(algo => {
const algoConfig = config[algo.id]
if (algoConfig && algoConfig.enabled) {
// 根据算法类型设置启用标志
if (algo.id === 'stationary_detection') {
result.enable_stationary_detection = true
} else if (algo.id === 'loitering_detection') {
result.enable_loitering_detection = true
}
// 添加参数
Object.entries(algoConfig.params).forEach(([key, value]) => {
result[key] = value
})
}
})
return result
}
// 配置变化
const onConfigChange = () => {
const backendConfig = generateConfig()
emit('update:modelValue', backendConfig)
emit('change', backendConfig)
}
// 重置配置
const resetConfig = () => {
algorithms.value.forEach(algo => {
config[algo.id] = {
enabled: false,
params: {}
}
algo.params.forEach(param => {
config[algo.id].params[param.name] = param.default
})
})
onConfigChange()
ElMessage.success('配置已重置')
}
// 应用配置
const applyConfig = () => {
onConfigChange()
ElMessage.success('配置已应用')
}
// 监听外部配置变化
watch(() => props.modelValue, (newVal) => {
if (newVal && Object.keys(newVal).length > 0) {
// 根据外部配置更新内部状态
if (newVal.enable_stationary_detection) {
config['stationary_detection'].enabled = true
}
if (newVal.enable_loitering_detection) {
config['loitering_detection'].enabled = true
}
// 更新参数
Object.entries(newVal).forEach(([key, value]) => {
algorithms.value.forEach(algo => {
if (config[algo.id].params[key] !== undefined) {
config[algo.id].params[key] = value
}
})
})
}
}, { deep: true })
onMounted(() => {
fetchAlgorithmConfig()
})
</script>
<style scoped>
.algorithm-config {
margin-top: 16px;
}
.loading-wrapper {
padding: 20px;
}
.empty-config {
padding: 20px;
}
.algorithm-list {
display: flex;
flex-direction: column;
gap: 16px;
}
.algorithm-item {
border: 1px solid #e4e7ed;
border-radius: 8px;
padding: 12px;
background: #fafafa;
}
.algorithm-header {
display: flex;
align-items: center;
gap: 8px;
}
.algorithm-header :deep(.el-switch__label) {
font-weight: 500;
}
.info-icon {
color: #909399;
cursor: help;
font-size: 14px;
}
.algorithm-params {
margin-top: 12px;
padding-top: 12px;
border-top: 1px dashed #dcdfe6;
}
.param-item {
margin-bottom: 12px;
}
.param-item:last-child {
margin-bottom: 0;
}
.param-label {
font-size: 13px;
color: #606266;
margin-bottom: 8px;
display: flex;
align-items: center;
gap: 4px;
}
.help-icon {
color: #c0c4cc;
cursor: help;
font-size: 12px;
}
.param-control {
display: flex;
align-items: center;
gap: 8px;
}
.param-control :deep(.el-slider) {
flex: 1;
}
.param-control :deep(.el-slider__input) {
width: 60px;
}
.param-unit {
font-size: 12px;
color: #909399;
min-width: 40px;
}
.config-actions {
display: flex;
justify-content: flex-end;
gap: 8px;
margin-top: 16px;
padding-top: 16px;
border-top: 1px solid #e4e7ed;
}
</style>

View File

@@ -93,6 +93,13 @@
<div class="slider-value">{{ config.iou.toFixed(2) }}</div> <div class="slider-value">{{ config.iou.toFixed(2) }}</div>
</el-form-item> </el-form-item>
<!-- 算法配置仅对人员检测模型显示 -->
<AlgorithmConfig
v-model="config.algorithmConfig"
@change="onAlgorithmChange"
:model-id="config.model"
/>
</el-form> </el-form>
</el-card> </el-card>
</div> </div>
@@ -225,6 +232,7 @@ import {
QuestionFilled QuestionFilled
} from '@element-plus/icons-vue' } from '@element-plus/icons-vue'
import { detectionApi } from '@/api/detection' import { detectionApi } from '@/api/detection'
import AlgorithmConfig from './AlgorithmConfig.vue'
const props = defineProps({ const props = defineProps({
models: { models: {
@@ -236,7 +244,8 @@ const props = defineProps({
const config = ref({ const config = ref({
model: props.models.length > 0 ? props.models[0].id : 'fire_detection', model: props.models.length > 0 ? props.models[0].id : 'fire_detection',
confidence: 0.5, confidence: 0.5,
iou: 0.45 iou: 0.45,
algorithmConfig: {}
}) })
// 可拖拽调整宽度相关 // 可拖拽调整宽度相关
@@ -271,7 +280,20 @@ const originalImage = ref('')
const resultImage = ref('') const resultImage = ref('')
const detections = ref([]) const detections = ref([])
const stats = ref(null) const stats = ref(null)
const uploadUrl = computed(() => `/api/detect/image?model_id=${config.value.model}&confidence=${config.value.confidence}&iou=${config.value.iou}`) const uploadUrl = computed(() => {
const params = new URLSearchParams({
model_id: config.value.model,
confidence: config.value.confidence,
iou: config.value.iou
})
// 添加算法配置
if (config.value.algorithmConfig && Object.keys(config.value.algorithmConfig).length > 0) {
params.append('algorithm_config', JSON.stringify(config.value.algorithmConfig))
}
return `/api/detect/image?${params.toString()}`
})
const formatConfidence = (value) => { const formatConfidence = (value) => {
return `置信度: ${value.toFixed(2)}` return `置信度: ${value.toFixed(2)}`
@@ -303,6 +325,22 @@ const handleUploadSuccess = (response) => {
} }
detections.value = response.data.detections || [] detections.value = response.data.detections || []
stats.value = response.data.stats stats.value = response.data.stats
// 处理告警信息
if (response.data.alerts && response.data.alerts.length > 0) {
alerts.value = response.data.alerts
console.log('收到告警:', response.data.alerts)
// 显示告警通知
response.data.alerts.forEach(alert => {
ElMessage({
message: `行为告警: ${alert.type} - ${alert.message}`,
type: 'warning',
duration: 3000
})
})
}
ElMessage.success('检测完成') ElMessage.success('检测完成')
} else { } else {
ElMessage.error(response.message) ElMessage.error(response.message)
@@ -320,6 +358,10 @@ const modelName = computed(() => {
const model = props.models.find(m => m.id === config.value.model) const model = props.models.find(m => m.id === config.value.model)
return model ? model.name : config.value.model return model ? model.name : config.value.model
}) })
const onAlgorithmChange = (algoConfig) => {
config.value.algorithmConfig = algoConfig
}
</script> </script>
<style scoped> <style scoped>
@@ -624,6 +666,78 @@ const modelName = computed(() => {
color: #409eff; color: #409eff;
} }
/* 告警卡片 */
.alerts-card {
margin-bottom: 20px;
border: 2px solid #f56c6c;
}
.alerts-card .card-header {
display: flex;
justify-content: space-between;
align-items: center;
}
.alert-count {
margin-left: 8px;
}
.alerts-container {
max-height: 400px;
overflow-y: auto;
padding: 16px;
}
.alert-item {
background: #fef0f0;
border-left: 4px solid #f56c6c;
padding: 12px;
border-radius: 4px;
margin-bottom: 12px;
}
.alert-header {
display: flex;
align-items: center;
gap: 12px;
margin-bottom: 8px;
}
.alert-time {
font-size: 12px;
color: #909399;
}
.alert-detail {
display: flex;
align-items: center;
gap: 12px;
margin-bottom: 6px;
}
.alert-message {
font-size: 14px;
color: #f56c6c;
font-weight: 500;
}
.alert-duration {
font-size: 13px;
color: #606266;
background: #fff;
padding: 2px 8px;
border-radius: 4px;
}
.alert-bbox {
font-size: 12px;
color: #606266;
background: #fff;
padding: 4px 8px;
border-radius: 4px;
display: inline-block;
}
/* 响应式布局 */ /* 响应式布局 */
@media (max-width: 768px) { @media (max-width: 768px) {
.image-detection-container { .image-detection-container {

View File

@@ -95,6 +95,13 @@
/> />
<div class="slider-value">{{ config.iou.toFixed(2) }}</div> <div class="slider-value">{{ config.iou.toFixed(2) }}</div>
</el-form-item> </el-form-item>
<!-- 算法配置仅对人员检测模型显示 -->
<AlgorithmConfig
v-model="config.algorithmConfig"
@change="onAlgorithmChange"
:model-id="config.model"
/>
</el-form> </el-form>
</el-card> </el-card>
</div> </div>
@@ -186,6 +193,40 @@
</el-col> </el-col>
</el-row> </el-row>
<!-- 行为告警 -->
<el-card v-if="alerts && alerts.length > 0" class="alerts-card" shadow="hover">
<template #header>
<div class="card-header">
<div class="header-left">
<el-icon class="header-icon"><Warning /></el-icon>
<span>行为告警</span>
<el-tag size="small" type="danger" class="alert-count">{{ alerts.length }} </el-tag>
</div>
</div>
</template>
<div class="alerts-container">
<div
v-for="(alert, index) in alerts"
:key="index"
class="alert-item"
>
<div class="alert-header">
<el-tag :type="alert.type === 'stationary' ? 'warning' : 'danger'" size="small">
{{ alert.type === 'stationary' ? '静止' : '徘徊' }}
</el-tag>
<span class="alert-time">{{ new Date(alert.timestamp * 1000).toLocaleTimeString('zh-CN') }}</span>
</div>
<div class="alert-detail">
<span class="alert-message">{{ alert.message }}</span>
<span v-if="alert.duration" class="alert-duration">持续: {{ alert.duration.toFixed(1) }}s</span>
</div>
<div v-if="alert.bbox" class="alert-bbox">
<code class="bbox-code">[{{ alert.bbox.join(', ') }}]</code>
</div>
</div>
</div>
</el-card>
<!-- 检测详情 --> <!-- 检测详情 -->
<el-card v-if="detections.length > 0" class="details-card" shadow="hover"> <el-card v-if="detections.length > 0" class="details-card" shadow="hover">
<template #header> <template #header>
@@ -273,6 +314,7 @@ import {
Timer, Timer,
Delete Delete
} from '@element-plus/icons-vue' } from '@element-plus/icons-vue'
import AlgorithmConfig from './AlgorithmConfig.vue'
const props = defineProps({ const props = defineProps({
models: { models: {
@@ -284,7 +326,8 @@ const props = defineProps({
const config = ref({ const config = ref({
model: props.models.length > 0 ? props.models[0].id : 'fire_detection', model: props.models.length > 0 ? props.models[0].id : 'fire_detection',
confidence: 0.5, confidence: 0.5,
iou: 0.45 iou: 0.45,
algorithmConfig: {}
}) })
// 可拖拽调整宽度相关 // 可拖拽调整宽度相关
@@ -321,6 +364,7 @@ const originalCameraFrame = ref('')
const resultCameraFrame = ref('') const resultCameraFrame = ref('')
const detections = ref([]) const detections = ref([])
const stats = ref(null) const stats = ref(null)
const alerts = ref([])
const websocket = ref(null) const websocket = ref(null)
// 检测日志 // 检测日志
@@ -387,14 +431,21 @@ const startCamera = async () => {
cameraConnected.value = true cameraConnected.value = true
cameraStarting.value = false cameraStarting.value = false
websocket.value.send(JSON.stringify({ const startConfig = {
action: 'start', action: 'start',
config: { config: {
model_id: config.value.model, model_id: config.value.model,
confidence: config.value.confidence, confidence: config.value.confidence,
iou: config.value.iou iou: config.value.iou
} }
})) }
// 添加算法配置
if (config.value.algorithmConfig && Object.keys(config.value.algorithmConfig).length > 0) {
startConfig.config.algorithm_config = config.value.algorithmConfig
}
websocket.value.send(JSON.stringify(startConfig))
} }
websocket.value.onmessage = (event) => { websocket.value.onmessage = (event) => {
@@ -450,14 +501,29 @@ const stopCamera = () => {
const updateCameraConfig = () => { const updateCameraConfig = () => {
if (websocket.value && cameraConnected.value) { if (websocket.value && cameraConnected.value) {
websocket.value.send(JSON.stringify({ const wsConfig = {
action: 'update_config', action: 'update_config',
config: { config: {
model_id: config.value.model, model_id: config.value.model,
confidence: config.value.confidence, confidence: config.value.confidence,
iou: config.value.iou iou: config.value.iou
} }
})) }
// 添加算法配置
if (config.value.algorithmConfig && Object.keys(config.value.algorithmConfig).length > 0) {
wsConfig.config.algorithm_config = config.value.algorithmConfig
}
websocket.value.send(JSON.stringify(wsConfig))
}
}
const onAlgorithmChange = (algoConfig) => {
config.value.algorithmConfig = algoConfig
// 如果摄像头已连接,实时更新配置
if (websocket.value && cameraConnected.value) {
updateCameraConfig()
} }
} }
@@ -847,6 +913,78 @@ onUnmounted(() => {
padding: 12px; padding: 12px;
} }
/* 告警卡片 */
.alerts-card {
margin-bottom: 20px;
border: 2px solid #f56c6c;
}
.alerts-card .card-header {
display: flex;
justify-content: space-between;
align-items: center;
}
.alert-count {
margin-left: 8px;
}
.alerts-container {
max-height: 400px;
overflow-y: auto;
padding: 16px;
}
.alert-item {
background: #fef0f0;
border-left: 4px solid #f56c6c;
padding: 12px;
border-radius: 4px;
margin-bottom: 12px;
}
.alert-header {
display: flex;
align-items: center;
gap: 12px;
margin-bottom: 8px;
}
.alert-time {
font-size: 12px;
color: #909399;
}
.alert-detail {
display: flex;
align-items: center;
gap: 12px;
margin-bottom: 6px;
}
.alert-message {
font-size: 14px;
color: #f56c6c;
font-weight: 500;
}
.alert-duration {
font-size: 13px;
color: #606266;
background: #fff;
padding: 2px 8px;
border-radius: 4px;
}
.alert-bbox {
font-size: 12px;
color: #606266;
background: #fff;
padding: 4px 8px;
border-radius: 4px;
display: inline-block;
}
/* 响应式布局 */ /* 响应式布局 */
@media (max-width: 768px) { @media (max-width: 768px) {
.video-detection-container { .video-detection-container {

View File

@@ -38,6 +38,169 @@ cp /path/to/behavior_detection/Loitering-Detection/yolov8n.pt models/loitering_d
bash scripts/setup-models.sh bash scripts/setup-models.sh
``` ```
## 算法目录规范
当需要在检测模型基础上添加自定义算法逻辑时,请在对应模型目录下创建以下子目录:
### `algorithms/` - 独立算法模块
用于存放独立的算法实现,如密度估计、流动分析、行为识别等。
```
crowd_detection/
├── yolov8l.pt
└── algorithms/
├── __init__.py
├── density_estimator.py # 人群密度估计算法
├── flow_analysis.py # 人群流动分析算法
├── anomaly_detector.py # 异常行为检测算法
└── crowd_counting.py # 人群计数算法
```
**示例代码:**
```python
# crowd_detection/algorithms/density_estimator.py
import numpy as np
class DensityEstimator:
"""人群密度估计器"""
def __init__(self, grid_size=10):
self.grid_size = grid_size
def estimate(self, detections, image_shape):
"""
根据检测结果估计人群密度
Args:
detections: YOLO检测结果 [(x1,y1,x2,y2,conf,cls), ...]
image_shape: (height, width)
Returns:
density_map: 密度热力图
"""
h, w = image_shape
grid_h, grid_w = h // self.grid_size, w // self.grid_size
density_map = np.zeros((grid_h, grid_w))
for det in detections:
cx, cy = int((det[0] + det[2]) / 2), int((det[1] + det[3]) / 2)
grid_x, grid_y = cx // self.grid_size, cy // self.grid_size
if 0 <= grid_x < grid_w and 0 <= grid_y < grid_h:
density_map[grid_y, grid_x] += 1
return density_map
def is_crowded(self, density_map, threshold=5):
"""判断是否拥挤"""
return np.max(density_map) > threshold
```
### `processors/` - 检测结果二次处理
用于对模型检测结果进行后处理,如过滤、分析、告警判断等。
```
crowd_detection/
├── yolov8l.pt
└── processors/
├── __init__.py
├── post_processor.py # 检测结果后处理
├── crowd_analyzer.py # 人群分析器
└── alert_rules.py # 告警规则判断
```
**示例代码:**
```python
# crowd_detection/processors/alert_rules.py
from datetime import datetime, timedelta
class CrowdAlertRules:
"""人群检测告警规则"""
def __init__(self):
self.alert_history = []
self.cooldown_minutes = 5
def check_crowd_gathering(self, person_count, density, duration_seconds):
"""
检查是否触发人群聚集告警
Args:
person_count: 检测到的人数
density: 人群密度值
duration_seconds: 持续时长
Returns:
alert: 告警信息或None
"""
# 规则:人数>20 且 密度>0.5 且 持续>30秒
if person_count > 20 and density > 0.5 and duration_seconds > 30:
if self._can_trigger_alert("crowd_gathering"):
alert = {
"type": "crowd_gathering",
"level": "high" if person_count > 50 else "medium",
"message": f"检测到人群聚集,人数: {person_count}",
"timestamp": datetime.now()
}
self._record_alert("crowd_gathering")
return alert
return None
def check_intrusion(self, detections, restricted_zones):
"""
检查是否触发区域入侵告警
Args:
detections: 检测结果
restricted_zones: 限制区域列表 [(x1,y1,x2,y2), ...]
"""
alerts = []
for det in detections:
person_box = (det[0], det[1], det[2], det[3])
for zone in restricted_zones:
if self._is_intersect(person_box, zone):
alerts.append({
"type": "zone_intrusion",
"level": "high",
"message": "检测到人员进入限制区域"
})
return alerts
def _can_trigger_alert(self, alert_type):
"""检查是否可以通过冷却期触发告警"""
cutoff_time = datetime.now() - timedelta(minutes=self.cooldown_minutes)
recent_alerts = [
a for a in self.alert_history
if a["type"] == alert_type and a["timestamp"] > cutoff_time
]
return len(recent_alerts) == 0
def _record_alert(self, alert_type):
"""记录告警历史"""
self.alert_history.append({
"type": alert_type,
"timestamp": datetime.now()
})
@staticmethod
def _is_intersect(box1, box2):
"""判断两个框是否相交"""
x1 = max(box1[0], box2[0])
y1 = max(box1[1], box2[1])
x2 = min(box1[2], box2[2])
y2 = min(box1[3], box2[3])
return x1 < x2 and y1 < y2
```
### 目录选择建议
| 场景 | 推荐目录 | 说明 |
|------|----------|------|
| 实现新的检测算法(如密度估计、行为识别) | `algorithms/` | 独立的算法逻辑,可复用 |
| 对检测结果进行过滤、分析 | `processors/` | 针对业务场景的后处理 |
| 简单的工具函数 | `utils/` | 辅助函数,无状态逻辑 |
## 注意 ## 注意
模型文件较大,未包含在 Git 仓库中。请从原始位置复制或创建符号链接。 模型文件较大,未包含在 Git 仓库中。请从原始位置复制或创建符号链接。

View File

@@ -0,0 +1,9 @@
"""
徘徊检测算法模块
包含基于位置和基于跟踪ID的检测算法
"""
from .stationary_detector import PositionBasedStationaryDetector
from .loitering_detector import LoiteringDetector
__all__ = ['PositionBasedStationaryDetector', 'LoiteringDetector']

View File

@@ -0,0 +1,251 @@
"""
基于跟踪ID的徘徊检测算法
依赖跟踪ID适用于跟踪稳定的场景
"""
import time
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass, field
from collections import defaultdict
@dataclass
class PersonTrack:
"""人员跟踪记录"""
person_id: int
first_seen: float
last_seen: float
positions: List[Tuple[int, int]] = field(default_factory=list)
last_position: Optional[Tuple[int, int]] = None
stationary_start: Optional[float] = None
total_duration: float = 0.0
stationary_duration: float = 0.0
class LoiteringDetector:
"""
徘徊检测器基于跟踪ID
特点:
- 依赖跟踪 ID需要稳定的跟踪器
- 可以检测长时间停留(徘徊)
- 可以检测静止不动(静止)
"""
def __init__(
self,
loitering_threshold: float = 300.0, # 徘徊阈值默认5分钟
stationary_threshold: float = 2.0, # 静止阈值(秒)
movement_threshold: float = 5.0, # 移动阈值(像素)
cleanup_interval: float = 10.0 # 清理间隔(秒)
):
self.loitering_threshold = loitering_threshold
self.stationary_threshold = stationary_threshold
self.movement_threshold = movement_threshold
self.cleanup_interval = cleanup_interval
# 跟踪记录: {person_id: PersonTrack}
self._tracks: Dict[int, PersonTrack] = {}
self._last_cleanup = time.time()
def _cleanup_old_tracks(self, max_age: float = 60.0) -> int:
"""清理长时间未更新的跟踪记录"""
current_time = time.time()
to_remove = [
pid for pid, track in self._tracks.items()
if current_time - track.last_seen > max_age
]
for pid in to_remove:
del self._tracks[pid]
return len(to_remove)
def update(
self,
person_id: int,
position: Tuple[int, int]
) -> Tuple[bool, float, bool, float]:
"""
更新人员位置
Args:
person_id: 人员ID
position: (x, y) 中心点坐标
Returns:
is_loitering: 是否徘徊超过阈值
loitering_duration: 徘徊时长(秒)
is_stationary: 是否静止超过阈值
stationary_duration: 静止时长(秒)
"""
current_time = time.time()
# 定期清理
if current_time - self._last_cleanup > self.cleanup_interval:
self._cleanup_old_tracks()
self._last_cleanup = current_time
# 获取或创建跟踪记录
if person_id not in self._tracks:
self._tracks[person_id] = PersonTrack(
person_id=person_id,
first_seen=current_time,
last_seen=current_time,
last_position=position
)
return False, 0.0, False, 0.0
track = self._tracks[person_id]
track.last_seen = current_time
track.positions.append(position)
# 计算总停留时长
track.total_duration = current_time - track.first_seen
# 检查是否移动
is_moving = False
if track.last_position is not None:
distance = ((position[0] - track.last_position[0]) ** 2 +
(position[1] - track.last_position[1]) ** 2) ** 0.5
is_moving = distance > self.movement_threshold
track.last_position = position
# 更新静止状态
if is_moving:
# 如果移动了,重置静止计时
track.stationary_start = None
track.stationary_duration = 0.0
else:
# 如果没移动,更新静止时长
if track.stationary_start is None:
track.stationary_start = current_time
track.stationary_duration = current_time - track.stationary_start
# 判断是否徘徊/静止
is_loitering = track.total_duration > self.loitering_threshold
is_stationary = track.stationary_duration > self.stationary_threshold
return (
is_loitering,
track.total_duration,
is_stationary,
track.stationary_duration
)
def detect(
self,
detections: List[Dict],
id_key: str = 'track_id'
) -> List[Dict]:
"""
批量检测徘徊状态
Args:
detections: 检测结果列表,每项包含 'bbox' 和 track_id
id_key: 跟踪ID的字段名
Returns:
添加 'loitering_info' 字段的检测结果
"""
results = []
for det in detections:
person_id = det.get(id_key)
if person_id is None:
results.append(det)
continue
x1, y1, x2, y2 = det['bbox']
center = ((x1 + x2) // 2, (y1 + y2) // 2)
is_loitering, loitering_duration, is_stationary, stationary_duration = \
self.update(person_id, center)
det_copy = det.copy()
det_copy['loitering_info'] = {
'person_id': person_id,
'is_loitering': is_loitering,
'loitering_duration': round(loitering_duration, 2),
'is_stationary': is_stationary,
'stationary_duration': round(stationary_duration, 2),
'loitering_threshold': self.loitering_threshold,
'stationary_threshold': self.stationary_threshold
}
results.append(det_copy)
return results
def get_all_loitering(
self,
threshold: Optional[float] = None
) -> List[Dict]:
"""
获取所有徘徊超过阈值的人员
Args:
threshold: 徘徊阈值(秒),默认使用初始化时的阈值
Returns:
list: [{person_id, duration, positions}, ...]
"""
threshold = threshold or self.loitering_threshold
result = []
for person_id, track in self._tracks.items():
if track.total_duration > threshold:
result.append({
'person_id': person_id,
'duration': track.total_duration,
'positions': track.positions.copy(),
'is_stationary': track.stationary_duration > self.stationary_threshold,
'stationary_duration': track.stationary_duration
})
# 按时长排序
result.sort(key=lambda x: x['duration'], reverse=True)
return result
def get_all_stationary(
self,
threshold: Optional[float] = None
) -> List[Dict]:
"""
获取所有静止超过阈值的人员
Args:
threshold: 静止阈值(秒),默认使用初始化时的阈值
Returns:
list: [{person_id, duration, position}, ...]
"""
threshold = threshold or self.stationary_threshold
result = []
for person_id, track in self._tracks.items():
if track.stationary_duration > threshold:
result.append({
'person_id': person_id,
'duration': track.stationary_duration,
'position': track.last_position,
'total_duration': track.total_duration
})
result.sort(key=lambda x: x['duration'], reverse=True)
return result
def reset(self):
"""重置所有跟踪数据"""
self._tracks.clear()
self._last_cleanup = time.time()
def get_stats(self) -> Dict:
"""获取统计信息"""
return {
'total_tracks': len(self._tracks),
'loitering_count': len(self.get_all_loitering()),
'stationary_count': len(self.get_all_stationary()),
'loitering_threshold': self.loitering_threshold,
'stationary_threshold': self.stationary_threshold
}

View File

@@ -0,0 +1,236 @@
"""
基于位置的静止人员检测算法
不依赖跟踪 ID而是根据位置来关联人员
适用于跟踪不稳定但人员相对静止的场景
"""
import time
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass, field
from collections import defaultdict
@dataclass
class PositionRecord:
"""位置记录"""
first_seen: float
last_seen: float
center: Tuple[int, int]
box: Tuple[int, int, int, int]
duration: float = 0.0
class PositionBasedStationaryDetector:
"""
基于位置的静止检测器
特点:
- 不依赖跟踪 ID直接用位置关联人员
- 适用于 SORT 等跟踪器不稳定的场景
- 使用网格化位置 + 距离容差进行匹配
"""
def __init__(
self,
stationary_threshold: float = 10.0, # 静止阈值(秒)
position_tolerance: int = 50, # 位置容差(像素)
cleanup_interval: float = 5.0 # 清理间隔(秒)
):
self.stationary_threshold = stationary_threshold
self.position_tolerance = position_tolerance
self.cleanup_interval = cleanup_interval
# 位置历史记录: {position_key: PositionRecord}
self._position_history: Dict[Tuple[int, int], PositionRecord] = {}
self._last_cleanup = time.time()
def _get_position_key(self, center: Tuple[int, int]) -> Tuple[int, int]:
"""
将连续坐标转换为离散的位置键
用于将相近位置归为一类
"""
x, y = center
grid_x = int(x / self.position_tolerance)
grid_y = int(y / self.position_tolerance)
return (grid_x, grid_y)
def _find_matching_position(
self,
center: Tuple[int, int]
) -> Optional[Tuple[int, int]]:
"""
查找与当前位置匹配的历史位置
返回匹配的位置键,如果没有则返回 None
"""
current_key = self._get_position_key(center)
# 首先检查精确匹配
if current_key in self._position_history:
hist_center = self._position_history[current_key].center
distance = ((center[0] - hist_center[0]) ** 2 +
(center[1] - hist_center[1]) ** 2) ** 0.5
if distance < self.position_tolerance:
return current_key
# 检查相邻网格
for dx in [-1, 0, 1]:
for dy in [-1, 0, 1]:
if dx == 0 and dy == 0:
continue
neighbor_key = (current_key[0] + dx, current_key[1] + dy)
if neighbor_key in self._position_history:
hist_center = self._position_history[neighbor_key].center
distance = ((center[0] - hist_center[0]) ** 2 +
(center[1] - hist_center[1]) ** 2) ** 0.5
if distance < self.position_tolerance:
return neighbor_key
return None
def update(
self,
center: Tuple[int, int],
box: Tuple[int, int, int, int]
) -> Tuple[str, float, bool]:
"""
更新位置信息
Args:
center: (x, y) 中心点坐标
box: (x1, y1, x2, y2) 边界框
Returns:
position_id: 位置 ID用于关联
stationary_duration: 静止时长(秒)
is_stationary: 是否静止超过阈值
"""
current_time = time.time()
# 定期清理旧记录
if current_time - self._last_cleanup > self.cleanup_interval:
self.cleanup_old_positions()
self._last_cleanup = current_time
# 查找匹配的历史位置
matching_key = self._find_matching_position(center)
if matching_key is not None:
# 更新已有位置
record = self._position_history[matching_key]
record.last_seen = current_time
# 平滑更新中心位置(使用移动平均)
old_center = record.center
record.center = (
int(0.7 * old_center[0] + 0.3 * center[0]),
int(0.7 * old_center[1] + 0.3 * center[1])
)
record.box = box
duration = current_time - record.first_seen
record.duration = duration
is_stationary = duration > self.stationary_threshold
position_id = f"pos_{matching_key[0]}_{matching_key[1]}"
return position_id, duration, is_stationary
else:
# 创建新位置记录
new_key = self._get_position_key(center)
self._position_history[new_key] = PositionRecord(
first_seen=current_time,
last_seen=current_time,
center=center,
box=box,
duration=0.0
)
new_id = f"pos_{new_key[0]}_{new_key[1]}"
return new_id, 0.0, False
def cleanup_old_positions(self, max_age: float = 5.0) -> int:
"""
清理长时间未更新的位置记录
Args:
max_age: 最大保留时间(秒)
Returns:
清理的记录数量
"""
current_time = time.time()
to_remove = [
key for key, data in self._position_history.items()
if current_time - data.last_seen > max_age
]
for key in to_remove:
del self._position_history[key]
return len(to_remove)
def get_all_stationary(
self,
threshold: Optional[float] = None
) -> List[Dict]:
"""
获取所有静止超过阈值的位置
Args:
threshold: 静止阈值(秒),默认使用初始化时的阈值
Returns:
list: [{position_id, duration, center, box}, ...]
"""
threshold = threshold or self.stationary_threshold
result = []
for key, data in self._position_history.items():
if data.duration > threshold:
result.append({
'position_id': f"pos_{key[0]}_{key[1]}",
'duration': data.duration,
'center': data.center,
'box': data.box
})
# 按时长排序
result.sort(key=lambda x: x['duration'], reverse=True)
return result
def reset(self):
"""重置所有跟踪数据"""
self._position_history.clear()
self._last_cleanup = time.time()
def detect(
self,
detections: List[Dict]
) -> List[Dict]:
"""
批量检测静止状态
Args:
detections: 检测结果列表,每项包含 'bbox': [x1, y1, x2, y2]
Returns:
添加 'stationary_info' 字段的检测结果
"""
results = []
for det in detections:
x1, y1, x2, y2 = det['bbox']
center = ((x1 + x2) // 2, (y1 + y2) // 2)
box = (x1, y1, x2, y2)
position_id, duration, is_stationary = self.update(center, box)
det_copy = det.copy()
det_copy['stationary_info'] = {
'position_id': position_id,
'duration': round(duration, 2),
'is_stationary': is_stationary,
'threshold': self.stationary_threshold
}
results.append(det_copy)
return results

View File

@@ -0,0 +1,8 @@
"""
徘徊检测处理器模块
用于对检测结果进行后处理
"""
from .behavior_processor import BehaviorProcessor
__all__ = ['BehaviorProcessor']

View File

@@ -0,0 +1,201 @@
"""
行为检测处理器
集成基于位置和基于跟踪ID的检测算法
"""
import sys
import os
import logging
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
# 添加算法模块路径
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from algorithms import PositionBasedStationaryDetector, LoiteringDetector
logger = logging.getLogger(__name__)
@dataclass
class BehaviorAlert:
"""行为告警"""
alert_type: str # 'stationary', 'loitering'
level: str # 'low', 'medium', 'high'
message: str
person_id: Optional[str] = None
position_id: Optional[str] = None
duration: float = 0.0
bbox: Optional[Tuple[int, int, int, int]] = None
class BehaviorProcessor:
"""
行为检测处理器
整合两种检测方式:
1. 基于位置的静止检测无需跟踪ID
2. 基于跟踪ID的徘徊检测需要跟踪ID
"""
def __init__(
self,
# 静止检测参数
stationary_threshold: float = 10.0,
position_tolerance: int = 50,
# 徘徊检测参数
loitering_threshold: float = 300.0,
movement_threshold: float = 5.0,
# 告警参数
enable_stationary_alert: bool = True,
enable_loitering_alert: bool = True,
stationary_alert_threshold: float = 10.0, # 超过此时间产生告警
loitering_alert_threshold: float = 300.0 # 超过此时间产生告警
):
# 初始化检测器
self.stationary_detector = PositionBasedStationaryDetector(
stationary_threshold=stationary_threshold,
position_tolerance=position_tolerance
)
self.loitering_detector = LoiteringDetector(
loitering_threshold=loitering_threshold,
stationary_threshold=stationary_threshold,
movement_threshold=movement_threshold
)
# 配置
self.enable_stationary_alert = enable_stationary_alert
self.enable_loitering_alert = enable_loitering_alert
self.stationary_alert_threshold = stationary_alert_threshold
self.loitering_alert_threshold = loitering_alert_threshold
def process(
self,
detections: List[Dict],
use_tracking: bool = False,
track_id_key: str = 'track_id'
) -> Dict:
"""
处理检测结果,检测行为
Args:
detections: 检测结果列表
use_tracking: 是否使用跟踪ID如果有的话
track_id_key: 跟踪ID字段名
Returns:
{
'detections': 添加行为信息的检测结果,
'alerts': 触发的告警列表,
'stats': 统计信息
}
"""
logger.info(f"[BehaviorProcessor] 开始处理 {len(detections)} 个检测结果")
logger.info(f"[BehaviorProcessor] 配置: stationary={self.enable_stationary_alert}, loitering={self.enable_loitering_alert}")
alerts = []
# 1. 始终进行基于位置的静止检测
logger.info(f"[BehaviorProcessor] 调用静止检测器...")
detections = self.stationary_detector.detect(detections)
logger.info(f"[BehaviorProcessor] 静止检测完成,检测到 {len(detections)} 个结果")
# 检查静止告警
stationary_alerts = 0
if self.enable_stationary_alert:
for det in detections:
info = det.get('stationary_info', {})
if info.get('is_stationary') and info.get('duration', 0) >= self.stationary_alert_threshold:
alert = BehaviorAlert(
alert_type='stationary',
level='medium' if info['duration'] < 30 else 'high',
message=f"人员静止停留 {int(info['duration'])}",
position_id=info.get('position_id'),
duration=info['duration'],
bbox=tuple(det['bbox'])
)
alerts.append(alert)
stationary_alerts += 1
logger.info(f"[BehaviorProcessor] 静止告警: {stationary_alerts}")
# 2. 如果有跟踪ID进行徘徊检测
logger.info(f"[BehaviorProcessor] use_tracking={use_tracking}")
if use_tracking:
detections = self.loitering_detector.detect(detections, id_key=track_id_key)
# 检查徘徊告警
if self.enable_loitering_alert:
for det in detections:
info = det.get('loitering_info', {})
if info.get('is_loitering') and info.get('loitering_duration', 0) >= self.loitering_alert_threshold:
alert = BehaviorAlert(
alert_type='loitering',
level='high',
message=f"人员徘徊 {int(info['loitering_duration'] // 60)} 分钟",
person_id=str(info.get('person_id')),
duration=info['loitering_duration'],
bbox=tuple(det['bbox'])
)
alerts.append(alert)
# 统计信息
stats = {
'total_detections': len(detections),
'stationary_count': len(self.stationary_detector.get_all_stationary()),
'alert_count': len(alerts)
}
if use_tracking:
stats.update({
'loitering_count': len(self.loitering_detector.get_all_loitering()),
'tracking_count': self.loitering_detector.get_stats()['total_tracks']
})
logger.info(f"[BehaviorProcessor] 处理完成: {stats}")
return {
'detections': detections,
'alerts': [self._alert_to_dict(a) for a in alerts],
'stats': stats
}
def _alert_to_dict(self, alert: BehaviorAlert) -> Dict:
"""将告警对象转换为字典"""
return {
'type': alert.alert_type,
'level': alert.level,
'message': alert.message,
'person_id': alert.person_id,
'position_id': alert.position_id,
'duration': round(alert.duration, 2),
'bbox': alert.bbox
}
def get_stationary_persons(self) -> List[Dict]:
"""获取所有静止人员"""
return self.stationary_detector.get_all_stationary()
def get_loitering_persons(self) -> List[Dict]:
"""获取所有徘徊人员"""
return self.loitering_detector.get_all_loitering()
def reset(self):
"""重置所有检测器"""
self.stationary_detector.reset()
self.loitering_detector.reset()
def get_config(self) -> Dict:
"""获取当前配置"""
return {
'stationary_threshold': self.stationary_detector.stationary_threshold,
'position_tolerance': self.stationary_detector.position_tolerance,
'loitering_threshold': self.loitering_detector.loitering_threshold,
'movement_threshold': self.loitering_detector.movement_threshold,
'enable_stationary_alert': self.enable_stationary_alert,
'enable_loitering_alert': self.enable_loitering_alert,
'stationary_alert_threshold': self.stationary_alert_threshold,
'loitering_alert_threshold': self.loitering_alert_threshold
}