feat: 新增人员徘徊/静止行为分析功能

本次提交实现了完整的人员行为分析系统,包括:
1. 新增基于位置和跟踪ID的两种行为检测算法
2. 新增徘徊检测服务与行为处理器模块
3. 前后端集成算法配置界面与告警展示
4. 支持图片和视频流场景下的行为分析
5. 新增算法配置接口与文档说明

具体改动:
- 新增loitering_detection模型目录与算法实现
- 新增AlgorithmConfig组件实现可视化配置
- 扩展图片/视频检测接口支持算法参数传递
- 新增行为告警推送与前端展示页面
- 优化检测服务,集成行为分析逻辑
- 移除冗余日志输出,完善代码注释
This commit is contained in:
wwh
2026-05-19 09:17:09 +08:00
parent 2691761f01
commit 7aa71c5f83
15 changed files with 1937 additions and 76 deletions

View File

@@ -2,24 +2,50 @@ import cv2
import numpy as np
import base64
import logging
import json
from typing import Optional
from fastapi import APIRouter, UploadFile, File, Form, Query
from models.schemas import ImageDetectionResult
router = APIRouter()
logger = logging.getLogger(__name__)
@router.post("/detect/image", response_model=ImageDetectionResult)
async def detect_image(
file: UploadFile = File(...),
model_id: str = Query("fire_detection"),
confidence: float = Query(0.5),
iou: float = Query(0.45)
iou: float = Query(0.45),
algorithm_config: Optional[str] = Query(None, description="算法配置JSON字符串")
):
"""
图片检测接口
Args:
algorithm_config: 算法配置JSON例如
{
"enable_stationary_detection": true,
"enable_loitering_detection": false,
"stationary_threshold": 10.0,
"position_tolerance": 50,
"loitering_threshold": 300.0,
"movement_threshold": 5.0
}
"""
from main import model_service
from services.detection_service import DetectionService
detection_service = DetectionService(model_service)
# 解析算法配置
algo_config = None
if algorithm_config:
try:
algo_config = json.loads(algorithm_config)
except json.JSONDecodeError as e:
logger.warning(f"算法配置解析失败: {e}")
try:
contents = await file.read()
nparr = np.frombuffer(contents, np.uint8)
@@ -32,10 +58,14 @@ async def detect_image(
data={}
)
result = await detection_service.detect_image(frame, model_id, confidence, iou)
result = await detection_service.detect_image(
frame, model_id, confidence, iou, algorithm_config=algo_config
)
if result['success']:
annotated_frame = detection_service.draw_detections(frame, result['detections'])
annotated_frame = detection_service.draw_detections(
frame, result['detections'], algorithm_config=algo_config
)
# 将标注后的图片转换为 base64
_, buffer = cv2.imencode('.jpg', annotated_frame)
@@ -47,7 +77,9 @@ async def detect_image(
data={
"detections": result['detections'],
"image_base64": img_base64,
"stats": result['stats']
"stats": result['stats'],
"alerts": result.get('alerts', []),
"behavior_stats": result.get('behavior_stats', {})
}
)
else:
@@ -64,3 +96,66 @@ async def detect_image(
message=f"检测失败: {str(e)}",
data={}
)
@router.get("/algorithms/config")
async def get_algorithm_config():
"""获取算法配置选项"""
return {
"algorithms": [
{
"id": "stationary_detection",
"name": "静止检测",
"description": "检测人员在同一位置静止停留",
"params": [
{
"name": "stationary_threshold",
"label": "静止阈值",
"type": "number",
"default": 10.0,
"min": 1.0,
"max": 300.0,
"unit": "",
"description": "超过此时间视为静止"
},
{
"name": "position_tolerance",
"label": "位置容差",
"type": "number",
"default": 50,
"min": 10,
"max": 200,
"unit": "像素",
"description": "位置匹配容差范围"
}
]
},
{
"id": "loitering_detection",
"name": "徘徊检测",
"description": "检测人员长时间停留需要跟踪ID",
"params": [
{
"name": "loitering_threshold",
"label": "徘徊阈值",
"type": "number",
"default": 300.0,
"min": 60.0,
"max": 1800.0,
"unit": "",
"description": "超过此时间视为徘徊"
},
{
"name": "movement_threshold",
"label": "移动阈值",
"type": "number",
"default": 5.0,
"min": 1.0,
"max": 50.0,
"unit": "像素",
"description": "小于此移动视为静止"
}
]
}
]
}

View File

@@ -249,11 +249,21 @@ class CameraService:
logger.info(f"发送检测结果: {len(result['detections'])} 个目标, {result['stats']}")
await websocket.send_json({
detection_message = {
'type': 'detection',
'detections': result['detections'],
'stats': result['stats']
})
}
# 包含行为告警信息
if 'alerts' in result and result['alerts']:
detection_message['alerts'] = result['alerts']
logger.info(f"发送告警: {len(result['alerts'])}")
if 'behavior_stats' in result:
detection_message['behavior_stats'] = result['behavior_stats']
await websocket.send_json(detection_message)
_, buffer = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 80])
import base64

View File

@@ -7,6 +7,8 @@ import logging
from typing import Dict, List, Optional
from PIL import Image, ImageDraw, ImageFont
from .loitering_service import get_loitering_service
logger = logging.getLogger(__name__)
class DetectionService:
@@ -19,60 +21,16 @@ class DetectionService:
os.makedirs(self.results_dir, exist_ok=True)
os.makedirs(self.temp_dir, exist_ok=True)
def draw_detections(self, frame: np.ndarray, detections: List[Dict], fps: float = 0) -> np.ndarray:
try:
img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_img = Image.fromarray(img_rgb)
draw = ImageDraw.Draw(pil_img)
try:
font = ImageFont.truetype("/System/Library/Fonts/PingFang.ttc", 20)
font_large = ImageFont.truetype("/System/Library/Fonts/PingFang.ttc", 24)
except:
font = ImageFont.load_default()
font_large = font
class_colors = {
'Fire': (255, 0, 0),
'Smoke': (128, 128, 128),
'person': (0, 255, 0),
'helmet': (255, 255, 0),
'no_helmet': (255, 0, 255),
'cigarette': (0, 165, 255) # 橙色,用于抽烟检测
}
for det in detections:
x1, y1, x2, y2 = det['bbox']
class_name = det['class']
conf = det['confidence']
label = det['label']
color = class_colors.get(class_name, (0, 255, 0))
draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
label_text = f"{label} {conf:.2f}"
bbox = draw.textbbox((0, 0), label_text, font=font)
text_w = bbox[2] - bbox[0]
text_h = bbox[3] - bbox[1]
draw.rectangle([x1, y1 - text_h - 4, x1 + text_w + 4, y1], fill=color)
draw.text((x1 + 2, y1 - text_h - 2), label_text, fill=(255, 255, 255), font=font)
if fps > 0:
fps_text = f"FPS: {fps:.1f} | Detections: {len(detections)}"
draw.text((10, 10), fps_text, fill=(0, 255, 0), font=font)
return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
except Exception as e:
logger.error(f"绘制检测结果失败: {e}")
return frame
# 初始化徘徊检测服务(懒加载,实际初始化在第一次使用时)
self.loitering_service = get_loitering_service()
async def detect_image(
self,
image: np.ndarray,
model_id: str,
confidence: float = 0.5,
iou: float = 0.45
iou: float = 0.45,
algorithm_config: Optional[Dict] = None
) -> Dict:
start_time = time.time()
@@ -110,7 +68,7 @@ class DetectionService:
processing_time = time.time() - start_time
avg_confidence = sum(d['confidence'] for d in detections) / len(detections) if detections else 0
return {
result_data = {
'success': True,
'message': '检测完成',
'detections': detections,
@@ -121,6 +79,14 @@ class DetectionService:
'model_used': model_id
}
}
# 如果启用了行为检测算法
if algorithm_config and detections:
result_data = self._apply_behavior_analysis(
result_data, algorithm_config
)
return result_data
except Exception as e:
logger.error(f"图片检测失败: {e}")
return {
@@ -186,6 +152,37 @@ class DetectionService:
}
}
# 如果是人员检测模型,进行行为分析
logger.info(f"[DetectionService] 模型: {model_id}, 检测目标: {len(detections)}")
if model_id == 'loitering_detection' and detections:
logger.info("[DetectionService] 调用行为分析...")
# 确保服务已初始化
if not self.loitering_service.is_initialized:
logger.info("[DetectionService] 初始化徘徊检测服务...")
self.loitering_service.initialize(
# 检测阈值(用于判断是否静止/徘徊)
stationary_threshold=10.0,
position_tolerance=50,
loitering_threshold=300.0,
movement_threshold=5.0,
# 告警阈值(用于触发告警,应该比检测阈值高)
stationary_alert_threshold=30.0,
loitering_alert_threshold=600.0,
# 启用告警
enable_stationary_alert=True,
enable_loitering_alert=True
)
behavior_result = self.loitering_service.process_detections(
detections,
use_tracking=False # 可以改为 True 如果使用跟踪
)
detections = behavior_result['detections']
result_data['alerts'] = behavior_result['alerts']
result_data['behavior_stats'] = behavior_result['stats']
logger.info(f"[DetectionService] 行为分析完成: alerts={len(behavior_result['alerts'])}, stats={behavior_result['stats']}")
if draw:
frame = self.draw_detections(frame, detections, fps)
@@ -197,3 +194,139 @@ class DetectionService:
'detections': [],
'stats': None
}
def _apply_behavior_analysis(
self,
result_data: Dict,
algorithm_config: Dict
) -> Dict:
"""
应用行为分析算法
Args:
result_data: 检测结果
algorithm_config: 算法配置
{
"enable_stationary_detection": true,
"enable_loitering_detection": false,
"stationary_threshold": 10.0,
"position_tolerance": 50,
...
}
Returns:
添加行为分析结果的检测结果
"""
detections = result_data['detections']
# 检查是否需要行为分析
enable_stationary = algorithm_config.get('enable_stationary_detection', False)
enable_loitering = algorithm_config.get('enable_loitering_detection', False)
if not enable_stationary and not enable_loitering:
return result_data
try:
# 使用前端传入的配置初始化服务
self.loitering_service.initialize(
stationary_threshold=algorithm_config.get('stationary_threshold', 10.0),
position_tolerance=algorithm_config.get('position_tolerance', 50),
loitering_threshold=algorithm_config.get('loitering_threshold', 300.0),
movement_threshold=algorithm_config.get('movement_threshold', 5.0),
enable_stationary_alert=enable_stationary,
enable_loitering_alert=enable_loitering
)
# 处理检测
behavior_result = self.loitering_service.process_detections(
detections,
use_tracking=enable_loitering # 只有启用徘徊检测时才使用跟踪
)
result_data['detections'] = behavior_result['detections']
result_data['alerts'] = behavior_result['alerts']
result_data['behavior_stats'] = behavior_result['stats']
except Exception as e:
logger.error(f"行为分析失败: {e}")
result_data['behavior_error'] = str(e)
return result_data
def draw_detections(
self,
frame: np.ndarray,
detections: List[Dict],
fps: float = 0,
algorithm_config: Optional[Dict] = None
) -> np.ndarray:
"""
绘制检测结果和行为告警
Args:
frame: 图像帧
detections: 检测结果列表(可能包含 stationary_info/loitering_info
fps: 帧率
algorithm_config: 算法配置(已废弃,保留用于向后兼容)
"""
try:
img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_img = Image.fromarray(img_rgb)
draw = ImageDraw.Draw(pil_img)
try:
font = ImageFont.truetype("/System/Library/Fonts/PingFang.ttc", 20)
font_large = ImageFont.truetype("/System/Library/Fonts/PingFang.ttc", 24)
except:
font = ImageFont.load_default()
font_large = font
class_colors = {
'Fire': (255, 0, 0),
'Smoke': (128, 128, 128),
'person': (0, 255, 0),
'helmet': (255, 255, 0),
'no_helmet': (255, 0, 255),
'cigarette': (0, 165, 255)
}
for det in detections:
x1, y1, x2, y2 = det['bbox']
class_name = det['class']
conf = det['confidence']
label = det['label']
# 根据是否有行为告警选择颜色
color = class_colors.get(class_name, (0, 255, 0))
# 检查行为告警
if algorithm_config:
if 'stationary_info' in det:
info = det['stationary_info']
if info.get('is_stationary'):
color = (0, 0, 255) # 红色警告
label = f"静止{int(info['duration'])}s"
if 'loitering_info' in det:
info = det['loitering_info']
if info.get('is_loitering'):
color = (255, 0, 0) # 蓝色警告
label = f"徘徊{int(info['loitering_duration']//60)}min"
draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
label_text = f"{label} {conf:.2f}"
bbox = draw.textbbox((0, 0), label_text, font=font)
text_w = bbox[2] - bbox[0]
text_h = bbox[3] - bbox[1]
draw.rectangle([x1, y1 - text_h - 4, x1 + text_w + 4, y1], fill=color)
draw.text((x1 + 2, y1 - text_h - 2), label_text, fill=(255, 255, 255), font=font)
if fps > 0:
fps_text = f"FPS: {fps:.1f} | Detections: {len(detections)}"
draw.text((10, 10), fps_text, fill=(0, 255, 0), font=font)
return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
except Exception as e:
logger.error(f"绘制检测结果失败: {e}")
return frame

View File

@@ -0,0 +1,168 @@
"""
徘徊检测服务
集成行为检测算法到后端服务
"""
import sys
import os
from typing import Dict, List, Optional
import logging
# 添加算法模块路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..', 'models', 'loitering_detection'))
from processors import BehaviorProcessor
logger = logging.getLogger(__name__)
class LoiteringService:
"""
徘徊检测服务
为视频流检测提供行为分析功能:
- 静止检测(基于位置,无需跟踪)
- 徘徊检测基于跟踪ID
"""
def __init__(self):
self.processor = None
self.is_initialized = False
def initialize(
self,
stationary_threshold: float = 10.0,
position_tolerance: int = 50,
loitering_threshold: float = 300.0,
movement_threshold: float = 5.0,
enable_stationary_alert: bool = True,
enable_loitering_alert: bool = True,
stationary_alert_threshold: Optional[float] = None,
loitering_alert_threshold: Optional[float] = None
):
"""
初始化服务
Args:
stationary_threshold: 静止检测阈值(秒)- 用于判断是否静止
position_tolerance: 位置容差(像素)
loitering_threshold: 徘徊检测阈值(秒)- 用于判断是否徘徊
movement_threshold: 移动阈值(像素)
enable_stationary_alert: 是否启用静止告警
enable_loitering_alert: 是否启用徘徊告警
stationary_alert_threshold: 静止告警阈值(秒)- 超过此时间产生告警,默认等于 stationary_threshold
loitering_alert_threshold: 徘徊告警阈值(秒)- 超过此时间产生告警,默认等于 loitering_threshold
"""
try:
self.processor = BehaviorProcessor(
stationary_threshold=stationary_threshold,
position_tolerance=position_tolerance,
loitering_threshold=loitering_threshold,
movement_threshold=movement_threshold,
enable_stationary_alert=enable_stationary_alert,
enable_loitering_alert=enable_loitering_alert,
stationary_alert_threshold=stationary_alert_threshold if stationary_alert_threshold is not None else stationary_threshold,
loitering_alert_threshold=loitering_alert_threshold if loitering_alert_threshold is not None else loitering_threshold
)
self.is_initialized = True
logger.info(f"徘徊检测服务初始化成功: 静止阈值={stationary_threshold}s, 告警阈值={stationary_alert_threshold or stationary_threshold}s")
except Exception as e:
logger.error(f"徘徊检测服务初始化失败: {e}")
self.is_initialized = False
def process_detections(
self,
detections: List[Dict],
use_tracking: bool = False,
track_id_key: str = 'track_id'
) -> Dict:
"""
处理检测结果
Args:
detections: YOLO检测结果列表
use_tracking: 是否使用跟踪ID
track_id_key: 跟踪ID字段名
Returns:
{
'detections': 添加行为信息的检测结果,
'alerts': 触发的告警列表,
'stats': 统计信息
}
"""
if not self.is_initialized or not self.processor:
return {
'detections': detections,
'alerts': [],
'stats': {'error': '服务未初始化'}
}
try:
return self.processor.process(
detections=detections,
use_tracking=use_tracking,
track_id_key=track_id_key
)
except Exception as e:
logger.error(f"处理检测结果失败: {e}")
return {
'detections': detections,
'alerts': [],
'stats': {'error': str(e)}
}
def get_stationary_persons(self) -> List[Dict]:
"""获取所有静止人员"""
if not self.is_initialized or not self.processor:
return []
return self.processor.get_stationary_persons()
def get_loitering_persons(self) -> List[Dict]:
"""获取所有徘徊人员"""
if not self.is_initialized or not self.processor:
return []
return self.processor.get_loitering_persons()
def reset(self):
"""重置检测器"""
if self.processor:
self.processor.reset()
logger.info("徘徊检测器已重置")
def get_config(self) -> Dict:
"""获取当前配置"""
if not self.is_initialized or not self.processor:
return {'error': '服务未初始化'}
return self.processor.get_config()
def get_stats(self) -> Dict:
"""获取统计信息"""
if not self.is_initialized or not self.processor:
return {'error': '服务未初始化'}
stats = {
'stationary_count': len(self.get_stationary_persons()),
'loitering_count': len(self.get_loitering_persons()),
'config': self.get_config()
}
return stats
# 全局服务实例
_loitering_service: Optional[LoiteringService] = None
def get_loitering_service() -> LoiteringService:
"""获取全局徘徊检测服务实例"""
global _loitering_service
if _loitering_service is None:
_loitering_service = LoiteringService()
return _loitering_service
def initialize_loitering_service(**kwargs):
"""初始化全局徘徊检测服务"""
service = get_loitering_service()
service.initialize(**kwargs)
return service

View File

@@ -82,7 +82,6 @@ class ModelService:
return None
if model_id in self.models:
logger.info(f"模型已加载: {model_id}")
return self.models[model_id]
config = self.model_configs[model_id]

View File

@@ -10,11 +10,21 @@ export const detectionApi = {
return api.get('/models')
},
detectImage(formData) {
getAlgorithmConfig() {
return api.get('/algorithms/config')
},
detectImage(formData, algorithmConfig = null) {
const params = {}
if (algorithmConfig) {
params.algorithm_config = JSON.stringify(algorithmConfig)
}
return api.post('/detect/image', formData, {
headers: {
'Content-Type': 'multipart/form-data'
}
},
params
})
}
}

View File

@@ -0,0 +1,326 @@
<template>
<div v-if="showConfig" class="algorithm-config">
<el-divider content-position="left">
<el-icon><Cpu /></el-icon>
<span style="margin-left: 8px;">行为分析算法</span>
</el-divider>
<div v-if="loading" class="loading-wrapper">
<el-skeleton :rows="3" animated />
</div>
<div v-else-if="algorithms.length === 0" class="empty-config">
<el-empty description="暂无可配置算法" :image-size="60" />
</div>
<div v-else class="algorithm-list">
<div
v-for="algo in algorithms"
:key="algo.id"
class="algorithm-item"
>
<div class="algorithm-header">
<el-switch
v-model="config[algo.id].enabled"
@change="onConfigChange"
:active-text="algo.name"
/>
<el-tooltip :content="algo.description" placement="top">
<el-icon class="info-icon"><InfoFilled /></el-icon>
</el-tooltip>
</div>
<div v-if="config[algo.id].enabled" class="algorithm-params">
<div
v-for="param in algo.params"
:key="param.name"
class="param-item"
>
<div class="param-label">
{{ param.label }}
<el-tooltip :content="param.description" placement="top">
<el-icon class="help-icon"><QuestionFilled /></el-icon>
</el-tooltip>
</div>
<div class="param-control">
<el-slider
v-if="param.type === 'number'"
v-model="config[algo.id].params[param.name]"
:min="param.min"
:max="param.max"
:step="param.name.includes('threshold') ? 1 : 1"
@change="onConfigChange"
show-input
:show-input-controls="false"
input-size="small"
/>
</div>
<div class="param-unit">{{ param.unit }}</div>
</div>
</div>
</div>
</div>
<div class="config-actions">
<el-button size="small" @click="resetConfig" :icon="RefreshRight">
重置
</el-button>
<el-button type="primary" size="small" @click="applyConfig" :icon="Check">
应用
</el-button>
</div>
</div>
</template>
<script setup>
import { ref, reactive, onMounted, watch, computed } from 'vue'
import { ElMessage } from 'element-plus'
import {
Cpu,
InfoFilled,
QuestionFilled,
RefreshRight,
Check
} from '@element-plus/icons-vue'
import { detectionApi } from '@/api/detection'
const props = defineProps({
modelValue: {
type: Object,
default: () => ({})
},
modelId: {
type: String,
default: ''
}
})
// 支持行为分析的模型列表
const SUPPORTED_MODELS = [
'loitering_detection', // 徘徊检测
'crowd_detection', // 人群检测
'person_detection' // 人员检测
]
// 是否显示配置
const showConfig = computed(() => {
return SUPPORTED_MODELS.some(model => props.modelId.includes(model))
})
const emit = defineEmits(['update:modelValue', 'change'])
const loading = ref(false)
const algorithms = ref([])
const config = reactive({})
// 获取算法配置
const fetchAlgorithmConfig = async () => {
loading.value = true
try {
const response = await detectionApi.getAlgorithmConfig()
algorithms.value = response.data.algorithms || []
// 初始化配置
algorithms.value.forEach(algo => {
if (!config[algo.id]) {
config[algo.id] = {
enabled: false,
params: {}
}
}
// 设置默认参数
algo.params.forEach(param => {
if (config[algo.id].params[param.name] === undefined) {
config[algo.id].params[param.name] = param.default
}
})
})
} catch (error) {
console.error('获取算法配置失败:', error)
} finally {
loading.value = false
}
}
// 生成后端需要的配置格式
const generateConfig = () => {
const result = {}
algorithms.value.forEach(algo => {
const algoConfig = config[algo.id]
if (algoConfig && algoConfig.enabled) {
// 根据算法类型设置启用标志
if (algo.id === 'stationary_detection') {
result.enable_stationary_detection = true
} else if (algo.id === 'loitering_detection') {
result.enable_loitering_detection = true
}
// 添加参数
Object.entries(algoConfig.params).forEach(([key, value]) => {
result[key] = value
})
}
})
return result
}
// 配置变化
const onConfigChange = () => {
const backendConfig = generateConfig()
emit('update:modelValue', backendConfig)
emit('change', backendConfig)
}
// 重置配置
const resetConfig = () => {
algorithms.value.forEach(algo => {
config[algo.id] = {
enabled: false,
params: {}
}
algo.params.forEach(param => {
config[algo.id].params[param.name] = param.default
})
})
onConfigChange()
ElMessage.success('配置已重置')
}
// 应用配置
const applyConfig = () => {
onConfigChange()
ElMessage.success('配置已应用')
}
// 监听外部配置变化
watch(() => props.modelValue, (newVal) => {
if (newVal && Object.keys(newVal).length > 0) {
// 根据外部配置更新内部状态
if (newVal.enable_stationary_detection) {
config['stationary_detection'].enabled = true
}
if (newVal.enable_loitering_detection) {
config['loitering_detection'].enabled = true
}
// 更新参数
Object.entries(newVal).forEach(([key, value]) => {
algorithms.value.forEach(algo => {
if (config[algo.id].params[key] !== undefined) {
config[algo.id].params[key] = value
}
})
})
}
}, { deep: true })
onMounted(() => {
fetchAlgorithmConfig()
})
</script>
<style scoped>
.algorithm-config {
margin-top: 16px;
}
.loading-wrapper {
padding: 20px;
}
.empty-config {
padding: 20px;
}
.algorithm-list {
display: flex;
flex-direction: column;
gap: 16px;
}
.algorithm-item {
border: 1px solid #e4e7ed;
border-radius: 8px;
padding: 12px;
background: #fafafa;
}
.algorithm-header {
display: flex;
align-items: center;
gap: 8px;
}
.algorithm-header :deep(.el-switch__label) {
font-weight: 500;
}
.info-icon {
color: #909399;
cursor: help;
font-size: 14px;
}
.algorithm-params {
margin-top: 12px;
padding-top: 12px;
border-top: 1px dashed #dcdfe6;
}
.param-item {
margin-bottom: 12px;
}
.param-item:last-child {
margin-bottom: 0;
}
.param-label {
font-size: 13px;
color: #606266;
margin-bottom: 8px;
display: flex;
align-items: center;
gap: 4px;
}
.help-icon {
color: #c0c4cc;
cursor: help;
font-size: 12px;
}
.param-control {
display: flex;
align-items: center;
gap: 8px;
}
.param-control :deep(.el-slider) {
flex: 1;
}
.param-control :deep(.el-slider__input) {
width: 60px;
}
.param-unit {
font-size: 12px;
color: #909399;
min-width: 40px;
}
.config-actions {
display: flex;
justify-content: flex-end;
gap: 8px;
margin-top: 16px;
padding-top: 16px;
border-top: 1px solid #e4e7ed;
}
</style>

View File

@@ -93,6 +93,13 @@
<div class="slider-value">{{ config.iou.toFixed(2) }}</div>
</el-form-item>
<!-- 算法配置仅对人员检测模型显示 -->
<AlgorithmConfig
v-model="config.algorithmConfig"
@change="onAlgorithmChange"
:model-id="config.model"
/>
</el-form>
</el-card>
</div>
@@ -225,6 +232,7 @@ import {
QuestionFilled
} from '@element-plus/icons-vue'
import { detectionApi } from '@/api/detection'
import AlgorithmConfig from './AlgorithmConfig.vue'
const props = defineProps({
models: {
@@ -236,7 +244,8 @@ const props = defineProps({
const config = ref({
model: props.models.length > 0 ? props.models[0].id : 'fire_detection',
confidence: 0.5,
iou: 0.45
iou: 0.45,
algorithmConfig: {}
})
// 可拖拽调整宽度相关
@@ -271,7 +280,20 @@ const originalImage = ref('')
const resultImage = ref('')
const detections = ref([])
const stats = ref(null)
const uploadUrl = computed(() => `/api/detect/image?model_id=${config.value.model}&confidence=${config.value.confidence}&iou=${config.value.iou}`)
const uploadUrl = computed(() => {
const params = new URLSearchParams({
model_id: config.value.model,
confidence: config.value.confidence,
iou: config.value.iou
})
// 添加算法配置
if (config.value.algorithmConfig && Object.keys(config.value.algorithmConfig).length > 0) {
params.append('algorithm_config', JSON.stringify(config.value.algorithmConfig))
}
return `/api/detect/image?${params.toString()}`
})
const formatConfidence = (value) => {
return `置信度: ${value.toFixed(2)}`
@@ -303,6 +325,22 @@ const handleUploadSuccess = (response) => {
}
detections.value = response.data.detections || []
stats.value = response.data.stats
// 处理告警信息
if (response.data.alerts && response.data.alerts.length > 0) {
alerts.value = response.data.alerts
console.log('收到告警:', response.data.alerts)
// 显示告警通知
response.data.alerts.forEach(alert => {
ElMessage({
message: `行为告警: ${alert.type} - ${alert.message}`,
type: 'warning',
duration: 3000
})
})
}
ElMessage.success('检测完成')
} else {
ElMessage.error(response.message)
@@ -320,6 +358,10 @@ const modelName = computed(() => {
const model = props.models.find(m => m.id === config.value.model)
return model ? model.name : config.value.model
})
const onAlgorithmChange = (algoConfig) => {
config.value.algorithmConfig = algoConfig
}
</script>
<style scoped>
@@ -624,6 +666,78 @@ const modelName = computed(() => {
color: #409eff;
}
/* 告警卡片 */
.alerts-card {
margin-bottom: 20px;
border: 2px solid #f56c6c;
}
.alerts-card .card-header {
display: flex;
justify-content: space-between;
align-items: center;
}
.alert-count {
margin-left: 8px;
}
.alerts-container {
max-height: 400px;
overflow-y: auto;
padding: 16px;
}
.alert-item {
background: #fef0f0;
border-left: 4px solid #f56c6c;
padding: 12px;
border-radius: 4px;
margin-bottom: 12px;
}
.alert-header {
display: flex;
align-items: center;
gap: 12px;
margin-bottom: 8px;
}
.alert-time {
font-size: 12px;
color: #909399;
}
.alert-detail {
display: flex;
align-items: center;
gap: 12px;
margin-bottom: 6px;
}
.alert-message {
font-size: 14px;
color: #f56c6c;
font-weight: 500;
}
.alert-duration {
font-size: 13px;
color: #606266;
background: #fff;
padding: 2px 8px;
border-radius: 4px;
}
.alert-bbox {
font-size: 12px;
color: #606266;
background: #fff;
padding: 4px 8px;
border-radius: 4px;
display: inline-block;
}
/* 响应式布局 */
@media (max-width: 768px) {
.image-detection-container {

View File

@@ -95,6 +95,13 @@
/>
<div class="slider-value">{{ config.iou.toFixed(2) }}</div>
</el-form-item>
<!-- 算法配置仅对人员检测模型显示 -->
<AlgorithmConfig
v-model="config.algorithmConfig"
@change="onAlgorithmChange"
:model-id="config.model"
/>
</el-form>
</el-card>
</div>
@@ -186,6 +193,40 @@
</el-col>
</el-row>
<!-- 行为告警 -->
<el-card v-if="alerts && alerts.length > 0" class="alerts-card" shadow="hover">
<template #header>
<div class="card-header">
<div class="header-left">
<el-icon class="header-icon"><Warning /></el-icon>
<span>行为告警</span>
<el-tag size="small" type="danger" class="alert-count">{{ alerts.length }} </el-tag>
</div>
</div>
</template>
<div class="alerts-container">
<div
v-for="(alert, index) in alerts"
:key="index"
class="alert-item"
>
<div class="alert-header">
<el-tag :type="alert.type === 'stationary' ? 'warning' : 'danger'" size="small">
{{ alert.type === 'stationary' ? '静止' : '徘徊' }}
</el-tag>
<span class="alert-time">{{ new Date(alert.timestamp * 1000).toLocaleTimeString('zh-CN') }}</span>
</div>
<div class="alert-detail">
<span class="alert-message">{{ alert.message }}</span>
<span v-if="alert.duration" class="alert-duration">持续: {{ alert.duration.toFixed(1) }}s</span>
</div>
<div v-if="alert.bbox" class="alert-bbox">
<code class="bbox-code">[{{ alert.bbox.join(', ') }}]</code>
</div>
</div>
</div>
</el-card>
<!-- 检测详情 -->
<el-card v-if="detections.length > 0" class="details-card" shadow="hover">
<template #header>
@@ -273,6 +314,7 @@ import {
Timer,
Delete
} from '@element-plus/icons-vue'
import AlgorithmConfig from './AlgorithmConfig.vue'
const props = defineProps({
models: {
@@ -284,7 +326,8 @@ const props = defineProps({
const config = ref({
model: props.models.length > 0 ? props.models[0].id : 'fire_detection',
confidence: 0.5,
iou: 0.45
iou: 0.45,
algorithmConfig: {}
})
// 可拖拽调整宽度相关
@@ -321,6 +364,7 @@ const originalCameraFrame = ref('')
const resultCameraFrame = ref('')
const detections = ref([])
const stats = ref(null)
const alerts = ref([])
const websocket = ref(null)
// 检测日志
@@ -387,14 +431,21 @@ const startCamera = async () => {
cameraConnected.value = true
cameraStarting.value = false
websocket.value.send(JSON.stringify({
const startConfig = {
action: 'start',
config: {
model_id: config.value.model,
confidence: config.value.confidence,
iou: config.value.iou
}
}))
}
// 添加算法配置
if (config.value.algorithmConfig && Object.keys(config.value.algorithmConfig).length > 0) {
startConfig.config.algorithm_config = config.value.algorithmConfig
}
websocket.value.send(JSON.stringify(startConfig))
}
websocket.value.onmessage = (event) => {
@@ -450,14 +501,29 @@ const stopCamera = () => {
const updateCameraConfig = () => {
if (websocket.value && cameraConnected.value) {
websocket.value.send(JSON.stringify({
const wsConfig = {
action: 'update_config',
config: {
model_id: config.value.model,
confidence: config.value.confidence,
iou: config.value.iou
}
}))
}
// 添加算法配置
if (config.value.algorithmConfig && Object.keys(config.value.algorithmConfig).length > 0) {
wsConfig.config.algorithm_config = config.value.algorithmConfig
}
websocket.value.send(JSON.stringify(wsConfig))
}
}
const onAlgorithmChange = (algoConfig) => {
config.value.algorithmConfig = algoConfig
// 如果摄像头已连接,实时更新配置
if (websocket.value && cameraConnected.value) {
updateCameraConfig()
}
}
@@ -847,6 +913,78 @@ onUnmounted(() => {
padding: 12px;
}
/* 告警卡片 */
.alerts-card {
margin-bottom: 20px;
border: 2px solid #f56c6c;
}
.alerts-card .card-header {
display: flex;
justify-content: space-between;
align-items: center;
}
.alert-count {
margin-left: 8px;
}
.alerts-container {
max-height: 400px;
overflow-y: auto;
padding: 16px;
}
.alert-item {
background: #fef0f0;
border-left: 4px solid #f56c6c;
padding: 12px;
border-radius: 4px;
margin-bottom: 12px;
}
.alert-header {
display: flex;
align-items: center;
gap: 12px;
margin-bottom: 8px;
}
.alert-time {
font-size: 12px;
color: #909399;
}
.alert-detail {
display: flex;
align-items: center;
gap: 12px;
margin-bottom: 6px;
}
.alert-message {
font-size: 14px;
color: #f56c6c;
font-weight: 500;
}
.alert-duration {
font-size: 13px;
color: #606266;
background: #fff;
padding: 2px 8px;
border-radius: 4px;
}
.alert-bbox {
font-size: 12px;
color: #606266;
background: #fff;
padding: 4px 8px;
border-radius: 4px;
display: inline-block;
}
/* 响应式布局 */
@media (max-width: 768px) {
.video-detection-container {

View File

@@ -38,6 +38,169 @@ cp /path/to/behavior_detection/Loitering-Detection/yolov8n.pt models/loitering_d
bash scripts/setup-models.sh
```
## 算法目录规范
当需要在检测模型基础上添加自定义算法逻辑时,请在对应模型目录下创建以下子目录:
### `algorithms/` - 独立算法模块
用于存放独立的算法实现,如密度估计、流动分析、行为识别等。
```
crowd_detection/
├── yolov8l.pt
└── algorithms/
├── __init__.py
├── density_estimator.py # 人群密度估计算法
├── flow_analysis.py # 人群流动分析算法
├── anomaly_detector.py # 异常行为检测算法
└── crowd_counting.py # 人群计数算法
```
**示例代码:**
```python
# crowd_detection/algorithms/density_estimator.py
import numpy as np
class DensityEstimator:
"""人群密度估计器"""
def __init__(self, grid_size=10):
self.grid_size = grid_size
def estimate(self, detections, image_shape):
"""
根据检测结果估计人群密度
Args:
detections: YOLO检测结果 [(x1,y1,x2,y2,conf,cls), ...]
image_shape: (height, width)
Returns:
density_map: 密度热力图
"""
h, w = image_shape
grid_h, grid_w = h // self.grid_size, w // self.grid_size
density_map = np.zeros((grid_h, grid_w))
for det in detections:
cx, cy = int((det[0] + det[2]) / 2), int((det[1] + det[3]) / 2)
grid_x, grid_y = cx // self.grid_size, cy // self.grid_size
if 0 <= grid_x < grid_w and 0 <= grid_y < grid_h:
density_map[grid_y, grid_x] += 1
return density_map
def is_crowded(self, density_map, threshold=5):
"""判断是否拥挤"""
return np.max(density_map) > threshold
```
### `processors/` - 检测结果二次处理
用于对模型检测结果进行后处理,如过滤、分析、告警判断等。
```
crowd_detection/
├── yolov8l.pt
└── processors/
├── __init__.py
├── post_processor.py # 检测结果后处理
├── crowd_analyzer.py # 人群分析器
└── alert_rules.py # 告警规则判断
```
**示例代码:**
```python
# crowd_detection/processors/alert_rules.py
from datetime import datetime, timedelta
class CrowdAlertRules:
"""人群检测告警规则"""
def __init__(self):
self.alert_history = []
self.cooldown_minutes = 5
def check_crowd_gathering(self, person_count, density, duration_seconds):
"""
检查是否触发人群聚集告警
Args:
person_count: 检测到的人数
density: 人群密度值
duration_seconds: 持续时长
Returns:
alert: 告警信息或None
"""
# 规则:人数>20 且 密度>0.5 且 持续>30秒
if person_count > 20 and density > 0.5 and duration_seconds > 30:
if self._can_trigger_alert("crowd_gathering"):
alert = {
"type": "crowd_gathering",
"level": "high" if person_count > 50 else "medium",
"message": f"检测到人群聚集,人数: {person_count}",
"timestamp": datetime.now()
}
self._record_alert("crowd_gathering")
return alert
return None
def check_intrusion(self, detections, restricted_zones):
"""
检查是否触发区域入侵告警
Args:
detections: 检测结果
restricted_zones: 限制区域列表 [(x1,y1,x2,y2), ...]
"""
alerts = []
for det in detections:
person_box = (det[0], det[1], det[2], det[3])
for zone in restricted_zones:
if self._is_intersect(person_box, zone):
alerts.append({
"type": "zone_intrusion",
"level": "high",
"message": "检测到人员进入限制区域"
})
return alerts
def _can_trigger_alert(self, alert_type):
"""检查是否可以通过冷却期触发告警"""
cutoff_time = datetime.now() - timedelta(minutes=self.cooldown_minutes)
recent_alerts = [
a for a in self.alert_history
if a["type"] == alert_type and a["timestamp"] > cutoff_time
]
return len(recent_alerts) == 0
def _record_alert(self, alert_type):
"""记录告警历史"""
self.alert_history.append({
"type": alert_type,
"timestamp": datetime.now()
})
@staticmethod
def _is_intersect(box1, box2):
"""判断两个框是否相交"""
x1 = max(box1[0], box2[0])
y1 = max(box1[1], box2[1])
x2 = min(box1[2], box2[2])
y2 = min(box1[3], box2[3])
return x1 < x2 and y1 < y2
```
### 目录选择建议
| 场景 | 推荐目录 | 说明 |
|------|----------|------|
| 实现新的检测算法(如密度估计、行为识别) | `algorithms/` | 独立的算法逻辑,可复用 |
| 对检测结果进行过滤、分析 | `processors/` | 针对业务场景的后处理 |
| 简单的工具函数 | `utils/` | 辅助函数,无状态逻辑 |
## 注意
模型文件较大,未包含在 Git 仓库中。请从原始位置复制或创建符号链接。

View File

@@ -0,0 +1,9 @@
"""
徘徊检测算法模块
包含基于位置和基于跟踪ID的检测算法
"""
from .stationary_detector import PositionBasedStationaryDetector
from .loitering_detector import LoiteringDetector
__all__ = ['PositionBasedStationaryDetector', 'LoiteringDetector']

View File

@@ -0,0 +1,251 @@
"""
基于跟踪ID的徘徊检测算法
依赖跟踪ID适用于跟踪稳定的场景
"""
import time
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass, field
from collections import defaultdict
@dataclass
class PersonTrack:
"""人员跟踪记录"""
person_id: int
first_seen: float
last_seen: float
positions: List[Tuple[int, int]] = field(default_factory=list)
last_position: Optional[Tuple[int, int]] = None
stationary_start: Optional[float] = None
total_duration: float = 0.0
stationary_duration: float = 0.0
class LoiteringDetector:
"""
徘徊检测器基于跟踪ID
特点:
- 依赖跟踪 ID需要稳定的跟踪器
- 可以检测长时间停留(徘徊)
- 可以检测静止不动(静止)
"""
def __init__(
self,
loitering_threshold: float = 300.0, # 徘徊阈值默认5分钟
stationary_threshold: float = 2.0, # 静止阈值(秒)
movement_threshold: float = 5.0, # 移动阈值(像素)
cleanup_interval: float = 10.0 # 清理间隔(秒)
):
self.loitering_threshold = loitering_threshold
self.stationary_threshold = stationary_threshold
self.movement_threshold = movement_threshold
self.cleanup_interval = cleanup_interval
# 跟踪记录: {person_id: PersonTrack}
self._tracks: Dict[int, PersonTrack] = {}
self._last_cleanup = time.time()
def _cleanup_old_tracks(self, max_age: float = 60.0) -> int:
"""清理长时间未更新的跟踪记录"""
current_time = time.time()
to_remove = [
pid for pid, track in self._tracks.items()
if current_time - track.last_seen > max_age
]
for pid in to_remove:
del self._tracks[pid]
return len(to_remove)
def update(
self,
person_id: int,
position: Tuple[int, int]
) -> Tuple[bool, float, bool, float]:
"""
更新人员位置
Args:
person_id: 人员ID
position: (x, y) 中心点坐标
Returns:
is_loitering: 是否徘徊超过阈值
loitering_duration: 徘徊时长(秒)
is_stationary: 是否静止超过阈值
stationary_duration: 静止时长(秒)
"""
current_time = time.time()
# 定期清理
if current_time - self._last_cleanup > self.cleanup_interval:
self._cleanup_old_tracks()
self._last_cleanup = current_time
# 获取或创建跟踪记录
if person_id not in self._tracks:
self._tracks[person_id] = PersonTrack(
person_id=person_id,
first_seen=current_time,
last_seen=current_time,
last_position=position
)
return False, 0.0, False, 0.0
track = self._tracks[person_id]
track.last_seen = current_time
track.positions.append(position)
# 计算总停留时长
track.total_duration = current_time - track.first_seen
# 检查是否移动
is_moving = False
if track.last_position is not None:
distance = ((position[0] - track.last_position[0]) ** 2 +
(position[1] - track.last_position[1]) ** 2) ** 0.5
is_moving = distance > self.movement_threshold
track.last_position = position
# 更新静止状态
if is_moving:
# 如果移动了,重置静止计时
track.stationary_start = None
track.stationary_duration = 0.0
else:
# 如果没移动,更新静止时长
if track.stationary_start is None:
track.stationary_start = current_time
track.stationary_duration = current_time - track.stationary_start
# 判断是否徘徊/静止
is_loitering = track.total_duration > self.loitering_threshold
is_stationary = track.stationary_duration > self.stationary_threshold
return (
is_loitering,
track.total_duration,
is_stationary,
track.stationary_duration
)
def detect(
self,
detections: List[Dict],
id_key: str = 'track_id'
) -> List[Dict]:
"""
批量检测徘徊状态
Args:
detections: 检测结果列表,每项包含 'bbox' 和 track_id
id_key: 跟踪ID的字段名
Returns:
添加 'loitering_info' 字段的检测结果
"""
results = []
for det in detections:
person_id = det.get(id_key)
if person_id is None:
results.append(det)
continue
x1, y1, x2, y2 = det['bbox']
center = ((x1 + x2) // 2, (y1 + y2) // 2)
is_loitering, loitering_duration, is_stationary, stationary_duration = \
self.update(person_id, center)
det_copy = det.copy()
det_copy['loitering_info'] = {
'person_id': person_id,
'is_loitering': is_loitering,
'loitering_duration': round(loitering_duration, 2),
'is_stationary': is_stationary,
'stationary_duration': round(stationary_duration, 2),
'loitering_threshold': self.loitering_threshold,
'stationary_threshold': self.stationary_threshold
}
results.append(det_copy)
return results
def get_all_loitering(
self,
threshold: Optional[float] = None
) -> List[Dict]:
"""
获取所有徘徊超过阈值的人员
Args:
threshold: 徘徊阈值(秒),默认使用初始化时的阈值
Returns:
list: [{person_id, duration, positions}, ...]
"""
threshold = threshold or self.loitering_threshold
result = []
for person_id, track in self._tracks.items():
if track.total_duration > threshold:
result.append({
'person_id': person_id,
'duration': track.total_duration,
'positions': track.positions.copy(),
'is_stationary': track.stationary_duration > self.stationary_threshold,
'stationary_duration': track.stationary_duration
})
# 按时长排序
result.sort(key=lambda x: x['duration'], reverse=True)
return result
def get_all_stationary(
self,
threshold: Optional[float] = None
) -> List[Dict]:
"""
获取所有静止超过阈值的人员
Args:
threshold: 静止阈值(秒),默认使用初始化时的阈值
Returns:
list: [{person_id, duration, position}, ...]
"""
threshold = threshold or self.stationary_threshold
result = []
for person_id, track in self._tracks.items():
if track.stationary_duration > threshold:
result.append({
'person_id': person_id,
'duration': track.stationary_duration,
'position': track.last_position,
'total_duration': track.total_duration
})
result.sort(key=lambda x: x['duration'], reverse=True)
return result
def reset(self):
"""重置所有跟踪数据"""
self._tracks.clear()
self._last_cleanup = time.time()
def get_stats(self) -> Dict:
"""获取统计信息"""
return {
'total_tracks': len(self._tracks),
'loitering_count': len(self.get_all_loitering()),
'stationary_count': len(self.get_all_stationary()),
'loitering_threshold': self.loitering_threshold,
'stationary_threshold': self.stationary_threshold
}

View File

@@ -0,0 +1,236 @@
"""
基于位置的静止人员检测算法
不依赖跟踪 ID而是根据位置来关联人员
适用于跟踪不稳定但人员相对静止的场景
"""
import time
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass, field
from collections import defaultdict
@dataclass
class PositionRecord:
"""位置记录"""
first_seen: float
last_seen: float
center: Tuple[int, int]
box: Tuple[int, int, int, int]
duration: float = 0.0
class PositionBasedStationaryDetector:
"""
基于位置的静止检测器
特点:
- 不依赖跟踪 ID直接用位置关联人员
- 适用于 SORT 等跟踪器不稳定的场景
- 使用网格化位置 + 距离容差进行匹配
"""
def __init__(
self,
stationary_threshold: float = 10.0, # 静止阈值(秒)
position_tolerance: int = 50, # 位置容差(像素)
cleanup_interval: float = 5.0 # 清理间隔(秒)
):
self.stationary_threshold = stationary_threshold
self.position_tolerance = position_tolerance
self.cleanup_interval = cleanup_interval
# 位置历史记录: {position_key: PositionRecord}
self._position_history: Dict[Tuple[int, int], PositionRecord] = {}
self._last_cleanup = time.time()
def _get_position_key(self, center: Tuple[int, int]) -> Tuple[int, int]:
"""
将连续坐标转换为离散的位置键
用于将相近位置归为一类
"""
x, y = center
grid_x = int(x / self.position_tolerance)
grid_y = int(y / self.position_tolerance)
return (grid_x, grid_y)
def _find_matching_position(
self,
center: Tuple[int, int]
) -> Optional[Tuple[int, int]]:
"""
查找与当前位置匹配的历史位置
返回匹配的位置键,如果没有则返回 None
"""
current_key = self._get_position_key(center)
# 首先检查精确匹配
if current_key in self._position_history:
hist_center = self._position_history[current_key].center
distance = ((center[0] - hist_center[0]) ** 2 +
(center[1] - hist_center[1]) ** 2) ** 0.5
if distance < self.position_tolerance:
return current_key
# 检查相邻网格
for dx in [-1, 0, 1]:
for dy in [-1, 0, 1]:
if dx == 0 and dy == 0:
continue
neighbor_key = (current_key[0] + dx, current_key[1] + dy)
if neighbor_key in self._position_history:
hist_center = self._position_history[neighbor_key].center
distance = ((center[0] - hist_center[0]) ** 2 +
(center[1] - hist_center[1]) ** 2) ** 0.5
if distance < self.position_tolerance:
return neighbor_key
return None
def update(
self,
center: Tuple[int, int],
box: Tuple[int, int, int, int]
) -> Tuple[str, float, bool]:
"""
更新位置信息
Args:
center: (x, y) 中心点坐标
box: (x1, y1, x2, y2) 边界框
Returns:
position_id: 位置 ID用于关联
stationary_duration: 静止时长(秒)
is_stationary: 是否静止超过阈值
"""
current_time = time.time()
# 定期清理旧记录
if current_time - self._last_cleanup > self.cleanup_interval:
self.cleanup_old_positions()
self._last_cleanup = current_time
# 查找匹配的历史位置
matching_key = self._find_matching_position(center)
if matching_key is not None:
# 更新已有位置
record = self._position_history[matching_key]
record.last_seen = current_time
# 平滑更新中心位置(使用移动平均)
old_center = record.center
record.center = (
int(0.7 * old_center[0] + 0.3 * center[0]),
int(0.7 * old_center[1] + 0.3 * center[1])
)
record.box = box
duration = current_time - record.first_seen
record.duration = duration
is_stationary = duration > self.stationary_threshold
position_id = f"pos_{matching_key[0]}_{matching_key[1]}"
return position_id, duration, is_stationary
else:
# 创建新位置记录
new_key = self._get_position_key(center)
self._position_history[new_key] = PositionRecord(
first_seen=current_time,
last_seen=current_time,
center=center,
box=box,
duration=0.0
)
new_id = f"pos_{new_key[0]}_{new_key[1]}"
return new_id, 0.0, False
def cleanup_old_positions(self, max_age: float = 5.0) -> int:
"""
清理长时间未更新的位置记录
Args:
max_age: 最大保留时间(秒)
Returns:
清理的记录数量
"""
current_time = time.time()
to_remove = [
key for key, data in self._position_history.items()
if current_time - data.last_seen > max_age
]
for key in to_remove:
del self._position_history[key]
return len(to_remove)
def get_all_stationary(
self,
threshold: Optional[float] = None
) -> List[Dict]:
"""
获取所有静止超过阈值的位置
Args:
threshold: 静止阈值(秒),默认使用初始化时的阈值
Returns:
list: [{position_id, duration, center, box}, ...]
"""
threshold = threshold or self.stationary_threshold
result = []
for key, data in self._position_history.items():
if data.duration > threshold:
result.append({
'position_id': f"pos_{key[0]}_{key[1]}",
'duration': data.duration,
'center': data.center,
'box': data.box
})
# 按时长排序
result.sort(key=lambda x: x['duration'], reverse=True)
return result
def reset(self):
"""重置所有跟踪数据"""
self._position_history.clear()
self._last_cleanup = time.time()
def detect(
self,
detections: List[Dict]
) -> List[Dict]:
"""
批量检测静止状态
Args:
detections: 检测结果列表,每项包含 'bbox': [x1, y1, x2, y2]
Returns:
添加 'stationary_info' 字段的检测结果
"""
results = []
for det in detections:
x1, y1, x2, y2 = det['bbox']
center = ((x1 + x2) // 2, (y1 + y2) // 2)
box = (x1, y1, x2, y2)
position_id, duration, is_stationary = self.update(center, box)
det_copy = det.copy()
det_copy['stationary_info'] = {
'position_id': position_id,
'duration': round(duration, 2),
'is_stationary': is_stationary,
'threshold': self.stationary_threshold
}
results.append(det_copy)
return results

View File

@@ -0,0 +1,8 @@
"""
徘徊检测处理器模块
用于对检测结果进行后处理
"""
from .behavior_processor import BehaviorProcessor
__all__ = ['BehaviorProcessor']

View File

@@ -0,0 +1,201 @@
"""
行为检测处理器
集成基于位置和基于跟踪ID的检测算法
"""
import sys
import os
import logging
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
# 添加算法模块路径
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from algorithms import PositionBasedStationaryDetector, LoiteringDetector
logger = logging.getLogger(__name__)
@dataclass
class BehaviorAlert:
"""行为告警"""
alert_type: str # 'stationary', 'loitering'
level: str # 'low', 'medium', 'high'
message: str
person_id: Optional[str] = None
position_id: Optional[str] = None
duration: float = 0.0
bbox: Optional[Tuple[int, int, int, int]] = None
class BehaviorProcessor:
"""
行为检测处理器
整合两种检测方式:
1. 基于位置的静止检测无需跟踪ID
2. 基于跟踪ID的徘徊检测需要跟踪ID
"""
def __init__(
self,
# 静止检测参数
stationary_threshold: float = 10.0,
position_tolerance: int = 50,
# 徘徊检测参数
loitering_threshold: float = 300.0,
movement_threshold: float = 5.0,
# 告警参数
enable_stationary_alert: bool = True,
enable_loitering_alert: bool = True,
stationary_alert_threshold: float = 10.0, # 超过此时间产生告警
loitering_alert_threshold: float = 300.0 # 超过此时间产生告警
):
# 初始化检测器
self.stationary_detector = PositionBasedStationaryDetector(
stationary_threshold=stationary_threshold,
position_tolerance=position_tolerance
)
self.loitering_detector = LoiteringDetector(
loitering_threshold=loitering_threshold,
stationary_threshold=stationary_threshold,
movement_threshold=movement_threshold
)
# 配置
self.enable_stationary_alert = enable_stationary_alert
self.enable_loitering_alert = enable_loitering_alert
self.stationary_alert_threshold = stationary_alert_threshold
self.loitering_alert_threshold = loitering_alert_threshold
def process(
self,
detections: List[Dict],
use_tracking: bool = False,
track_id_key: str = 'track_id'
) -> Dict:
"""
处理检测结果,检测行为
Args:
detections: 检测结果列表
use_tracking: 是否使用跟踪ID如果有的话
track_id_key: 跟踪ID字段名
Returns:
{
'detections': 添加行为信息的检测结果,
'alerts': 触发的告警列表,
'stats': 统计信息
}
"""
logger.info(f"[BehaviorProcessor] 开始处理 {len(detections)} 个检测结果")
logger.info(f"[BehaviorProcessor] 配置: stationary={self.enable_stationary_alert}, loitering={self.enable_loitering_alert}")
alerts = []
# 1. 始终进行基于位置的静止检测
logger.info(f"[BehaviorProcessor] 调用静止检测器...")
detections = self.stationary_detector.detect(detections)
logger.info(f"[BehaviorProcessor] 静止检测完成,检测到 {len(detections)} 个结果")
# 检查静止告警
stationary_alerts = 0
if self.enable_stationary_alert:
for det in detections:
info = det.get('stationary_info', {})
if info.get('is_stationary') and info.get('duration', 0) >= self.stationary_alert_threshold:
alert = BehaviorAlert(
alert_type='stationary',
level='medium' if info['duration'] < 30 else 'high',
message=f"人员静止停留 {int(info['duration'])}",
position_id=info.get('position_id'),
duration=info['duration'],
bbox=tuple(det['bbox'])
)
alerts.append(alert)
stationary_alerts += 1
logger.info(f"[BehaviorProcessor] 静止告警: {stationary_alerts}")
# 2. 如果有跟踪ID进行徘徊检测
logger.info(f"[BehaviorProcessor] use_tracking={use_tracking}")
if use_tracking:
detections = self.loitering_detector.detect(detections, id_key=track_id_key)
# 检查徘徊告警
if self.enable_loitering_alert:
for det in detections:
info = det.get('loitering_info', {})
if info.get('is_loitering') and info.get('loitering_duration', 0) >= self.loitering_alert_threshold:
alert = BehaviorAlert(
alert_type='loitering',
level='high',
message=f"人员徘徊 {int(info['loitering_duration'] // 60)} 分钟",
person_id=str(info.get('person_id')),
duration=info['loitering_duration'],
bbox=tuple(det['bbox'])
)
alerts.append(alert)
# 统计信息
stats = {
'total_detections': len(detections),
'stationary_count': len(self.stationary_detector.get_all_stationary()),
'alert_count': len(alerts)
}
if use_tracking:
stats.update({
'loitering_count': len(self.loitering_detector.get_all_loitering()),
'tracking_count': self.loitering_detector.get_stats()['total_tracks']
})
logger.info(f"[BehaviorProcessor] 处理完成: {stats}")
return {
'detections': detections,
'alerts': [self._alert_to_dict(a) for a in alerts],
'stats': stats
}
def _alert_to_dict(self, alert: BehaviorAlert) -> Dict:
"""将告警对象转换为字典"""
return {
'type': alert.alert_type,
'level': alert.level,
'message': alert.message,
'person_id': alert.person_id,
'position_id': alert.position_id,
'duration': round(alert.duration, 2),
'bbox': alert.bbox
}
def get_stationary_persons(self) -> List[Dict]:
"""获取所有静止人员"""
return self.stationary_detector.get_all_stationary()
def get_loitering_persons(self) -> List[Dict]:
"""获取所有徘徊人员"""
return self.loitering_detector.get_all_loitering()
def reset(self):
"""重置所有检测器"""
self.stationary_detector.reset()
self.loitering_detector.reset()
def get_config(self) -> Dict:
"""获取当前配置"""
return {
'stationary_threshold': self.stationary_detector.stationary_threshold,
'position_tolerance': self.stationary_detector.position_tolerance,
'loitering_threshold': self.loitering_detector.loitering_threshold,
'movement_threshold': self.loitering_detector.movement_threshold,
'enable_stationary_alert': self.enable_stationary_alert,
'enable_loitering_alert': self.enable_loitering_alert,
'stationary_alert_threshold': self.stationary_alert_threshold,
'loitering_alert_threshold': self.loitering_alert_threshold
}