feat: 新增人员徘徊/静止行为分析功能
本次提交实现了完整的人员行为分析系统,包括: 1. 新增基于位置和跟踪ID的两种行为检测算法 2. 新增徘徊检测服务与行为处理器模块 3. 前后端集成算法配置界面与告警展示 4. 支持图片和视频流场景下的行为分析 5. 新增算法配置接口与文档说明 具体改动: - 新增loitering_detection模型目录与算法实现 - 新增AlgorithmConfig组件实现可视化配置 - 扩展图片/视频检测接口支持算法参数传递 - 新增行为告警推送与前端展示页面 - 优化检测服务,集成行为分析逻辑 - 移除冗余日志输出,完善代码注释
This commit is contained in:
163
models/README.md
163
models/README.md
@@ -38,6 +38,169 @@ cp /path/to/behavior_detection/Loitering-Detection/yolov8n.pt models/loitering_d
|
||||
bash scripts/setup-models.sh
|
||||
```
|
||||
|
||||
## 算法目录规范
|
||||
|
||||
当需要在检测模型基础上添加自定义算法逻辑时,请在对应模型目录下创建以下子目录:
|
||||
|
||||
### `algorithms/` - 独立算法模块
|
||||
|
||||
用于存放独立的算法实现,如密度估计、流动分析、行为识别等。
|
||||
|
||||
```
|
||||
crowd_detection/
|
||||
├── yolov8l.pt
|
||||
└── algorithms/
|
||||
├── __init__.py
|
||||
├── density_estimator.py # 人群密度估计算法
|
||||
├── flow_analysis.py # 人群流动分析算法
|
||||
├── anomaly_detector.py # 异常行为检测算法
|
||||
└── crowd_counting.py # 人群计数算法
|
||||
```
|
||||
|
||||
**示例代码:**
|
||||
```python
|
||||
# crowd_detection/algorithms/density_estimator.py
|
||||
import numpy as np
|
||||
|
||||
class DensityEstimator:
|
||||
"""人群密度估计器"""
|
||||
|
||||
def __init__(self, grid_size=10):
|
||||
self.grid_size = grid_size
|
||||
|
||||
def estimate(self, detections, image_shape):
|
||||
"""
|
||||
根据检测结果估计人群密度
|
||||
|
||||
Args:
|
||||
detections: YOLO检测结果 [(x1,y1,x2,y2,conf,cls), ...]
|
||||
image_shape: (height, width)
|
||||
|
||||
Returns:
|
||||
density_map: 密度热力图
|
||||
"""
|
||||
h, w = image_shape
|
||||
grid_h, grid_w = h // self.grid_size, w // self.grid_size
|
||||
density_map = np.zeros((grid_h, grid_w))
|
||||
|
||||
for det in detections:
|
||||
cx, cy = int((det[0] + det[2]) / 2), int((det[1] + det[3]) / 2)
|
||||
grid_x, grid_y = cx // self.grid_size, cy // self.grid_size
|
||||
if 0 <= grid_x < grid_w and 0 <= grid_y < grid_h:
|
||||
density_map[grid_y, grid_x] += 1
|
||||
|
||||
return density_map
|
||||
|
||||
def is_crowded(self, density_map, threshold=5):
|
||||
"""判断是否拥挤"""
|
||||
return np.max(density_map) > threshold
|
||||
```
|
||||
|
||||
### `processors/` - 检测结果二次处理
|
||||
|
||||
用于对模型检测结果进行后处理,如过滤、分析、告警判断等。
|
||||
|
||||
```
|
||||
crowd_detection/
|
||||
├── yolov8l.pt
|
||||
└── processors/
|
||||
├── __init__.py
|
||||
├── post_processor.py # 检测结果后处理
|
||||
├── crowd_analyzer.py # 人群分析器
|
||||
└── alert_rules.py # 告警规则判断
|
||||
```
|
||||
|
||||
**示例代码:**
|
||||
```python
|
||||
# crowd_detection/processors/alert_rules.py
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
class CrowdAlertRules:
|
||||
"""人群检测告警规则"""
|
||||
|
||||
def __init__(self):
|
||||
self.alert_history = []
|
||||
self.cooldown_minutes = 5
|
||||
|
||||
def check_crowd_gathering(self, person_count, density, duration_seconds):
|
||||
"""
|
||||
检查是否触发人群聚集告警
|
||||
|
||||
Args:
|
||||
person_count: 检测到的人数
|
||||
density: 人群密度值
|
||||
duration_seconds: 持续时长
|
||||
|
||||
Returns:
|
||||
alert: 告警信息或None
|
||||
"""
|
||||
# 规则:人数>20 且 密度>0.5 且 持续>30秒
|
||||
if person_count > 20 and density > 0.5 and duration_seconds > 30:
|
||||
if self._can_trigger_alert("crowd_gathering"):
|
||||
alert = {
|
||||
"type": "crowd_gathering",
|
||||
"level": "high" if person_count > 50 else "medium",
|
||||
"message": f"检测到人群聚集,人数: {person_count}",
|
||||
"timestamp": datetime.now()
|
||||
}
|
||||
self._record_alert("crowd_gathering")
|
||||
return alert
|
||||
return None
|
||||
|
||||
def check_intrusion(self, detections, restricted_zones):
|
||||
"""
|
||||
检查是否触发区域入侵告警
|
||||
|
||||
Args:
|
||||
detections: 检测结果
|
||||
restricted_zones: 限制区域列表 [(x1,y1,x2,y2), ...]
|
||||
"""
|
||||
alerts = []
|
||||
for det in detections:
|
||||
person_box = (det[0], det[1], det[2], det[3])
|
||||
for zone in restricted_zones:
|
||||
if self._is_intersect(person_box, zone):
|
||||
alerts.append({
|
||||
"type": "zone_intrusion",
|
||||
"level": "high",
|
||||
"message": "检测到人员进入限制区域"
|
||||
})
|
||||
return alerts
|
||||
|
||||
def _can_trigger_alert(self, alert_type):
|
||||
"""检查是否可以通过冷却期触发告警"""
|
||||
cutoff_time = datetime.now() - timedelta(minutes=self.cooldown_minutes)
|
||||
recent_alerts = [
|
||||
a for a in self.alert_history
|
||||
if a["type"] == alert_type and a["timestamp"] > cutoff_time
|
||||
]
|
||||
return len(recent_alerts) == 0
|
||||
|
||||
def _record_alert(self, alert_type):
|
||||
"""记录告警历史"""
|
||||
self.alert_history.append({
|
||||
"type": alert_type,
|
||||
"timestamp": datetime.now()
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def _is_intersect(box1, box2):
|
||||
"""判断两个框是否相交"""
|
||||
x1 = max(box1[0], box2[0])
|
||||
y1 = max(box1[1], box2[1])
|
||||
x2 = min(box1[2], box2[2])
|
||||
y2 = min(box1[3], box2[3])
|
||||
return x1 < x2 and y1 < y2
|
||||
```
|
||||
|
||||
### 目录选择建议
|
||||
|
||||
| 场景 | 推荐目录 | 说明 |
|
||||
|------|----------|------|
|
||||
| 实现新的检测算法(如密度估计、行为识别) | `algorithms/` | 独立的算法逻辑,可复用 |
|
||||
| 对检测结果进行过滤、分析 | `processors/` | 针对业务场景的后处理 |
|
||||
| 简单的工具函数 | `utils/` | 辅助函数,无状态逻辑 |
|
||||
|
||||
## 注意
|
||||
|
||||
模型文件较大,未包含在 Git 仓库中。请从原始位置复制或创建符号链接。
|
||||
|
||||
9
models/loitering_detection/algorithms/__init__.py
Normal file
9
models/loitering_detection/algorithms/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
徘徊检测算法模块
|
||||
包含基于位置和基于跟踪ID的检测算法
|
||||
"""
|
||||
|
||||
from .stationary_detector import PositionBasedStationaryDetector
|
||||
from .loitering_detector import LoiteringDetector
|
||||
|
||||
__all__ = ['PositionBasedStationaryDetector', 'LoiteringDetector']
|
||||
251
models/loitering_detection/algorithms/loitering_detector.py
Normal file
251
models/loitering_detection/algorithms/loitering_detector.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
基于跟踪ID的徘徊检测算法
|
||||
依赖跟踪ID,适用于跟踪稳定的场景
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
@dataclass
|
||||
class PersonTrack:
|
||||
"""人员跟踪记录"""
|
||||
person_id: int
|
||||
first_seen: float
|
||||
last_seen: float
|
||||
positions: List[Tuple[int, int]] = field(default_factory=list)
|
||||
last_position: Optional[Tuple[int, int]] = None
|
||||
stationary_start: Optional[float] = None
|
||||
total_duration: float = 0.0
|
||||
stationary_duration: float = 0.0
|
||||
|
||||
|
||||
class LoiteringDetector:
|
||||
"""
|
||||
徘徊检测器(基于跟踪ID)
|
||||
|
||||
特点:
|
||||
- 依赖跟踪 ID,需要稳定的跟踪器
|
||||
- 可以检测长时间停留(徘徊)
|
||||
- 可以检测静止不动(静止)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loitering_threshold: float = 300.0, # 徘徊阈值(秒),默认5分钟
|
||||
stationary_threshold: float = 2.0, # 静止阈值(秒)
|
||||
movement_threshold: float = 5.0, # 移动阈值(像素)
|
||||
cleanup_interval: float = 10.0 # 清理间隔(秒)
|
||||
):
|
||||
self.loitering_threshold = loitering_threshold
|
||||
self.stationary_threshold = stationary_threshold
|
||||
self.movement_threshold = movement_threshold
|
||||
self.cleanup_interval = cleanup_interval
|
||||
|
||||
# 跟踪记录: {person_id: PersonTrack}
|
||||
self._tracks: Dict[int, PersonTrack] = {}
|
||||
self._last_cleanup = time.time()
|
||||
|
||||
def _cleanup_old_tracks(self, max_age: float = 60.0) -> int:
|
||||
"""清理长时间未更新的跟踪记录"""
|
||||
current_time = time.time()
|
||||
to_remove = [
|
||||
pid for pid, track in self._tracks.items()
|
||||
if current_time - track.last_seen > max_age
|
||||
]
|
||||
|
||||
for pid in to_remove:
|
||||
del self._tracks[pid]
|
||||
|
||||
return len(to_remove)
|
||||
|
||||
def update(
|
||||
self,
|
||||
person_id: int,
|
||||
position: Tuple[int, int]
|
||||
) -> Tuple[bool, float, bool, float]:
|
||||
"""
|
||||
更新人员位置
|
||||
|
||||
Args:
|
||||
person_id: 人员ID
|
||||
position: (x, y) 中心点坐标
|
||||
|
||||
Returns:
|
||||
is_loitering: 是否徘徊超过阈值
|
||||
loitering_duration: 徘徊时长(秒)
|
||||
is_stationary: 是否静止超过阈值
|
||||
stationary_duration: 静止时长(秒)
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# 定期清理
|
||||
if current_time - self._last_cleanup > self.cleanup_interval:
|
||||
self._cleanup_old_tracks()
|
||||
self._last_cleanup = current_time
|
||||
|
||||
# 获取或创建跟踪记录
|
||||
if person_id not in self._tracks:
|
||||
self._tracks[person_id] = PersonTrack(
|
||||
person_id=person_id,
|
||||
first_seen=current_time,
|
||||
last_seen=current_time,
|
||||
last_position=position
|
||||
)
|
||||
return False, 0.0, False, 0.0
|
||||
|
||||
track = self._tracks[person_id]
|
||||
track.last_seen = current_time
|
||||
track.positions.append(position)
|
||||
|
||||
# 计算总停留时长
|
||||
track.total_duration = current_time - track.first_seen
|
||||
|
||||
# 检查是否移动
|
||||
is_moving = False
|
||||
if track.last_position is not None:
|
||||
distance = ((position[0] - track.last_position[0]) ** 2 +
|
||||
(position[1] - track.last_position[1]) ** 2) ** 0.5
|
||||
is_moving = distance > self.movement_threshold
|
||||
|
||||
track.last_position = position
|
||||
|
||||
# 更新静止状态
|
||||
if is_moving:
|
||||
# 如果移动了,重置静止计时
|
||||
track.stationary_start = None
|
||||
track.stationary_duration = 0.0
|
||||
else:
|
||||
# 如果没移动,更新静止时长
|
||||
if track.stationary_start is None:
|
||||
track.stationary_start = current_time
|
||||
track.stationary_duration = current_time - track.stationary_start
|
||||
|
||||
# 判断是否徘徊/静止
|
||||
is_loitering = track.total_duration > self.loitering_threshold
|
||||
is_stationary = track.stationary_duration > self.stationary_threshold
|
||||
|
||||
return (
|
||||
is_loitering,
|
||||
track.total_duration,
|
||||
is_stationary,
|
||||
track.stationary_duration
|
||||
)
|
||||
|
||||
def detect(
|
||||
self,
|
||||
detections: List[Dict],
|
||||
id_key: str = 'track_id'
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
批量检测徘徊状态
|
||||
|
||||
Args:
|
||||
detections: 检测结果列表,每项包含 'bbox' 和 track_id
|
||||
id_key: 跟踪ID的字段名
|
||||
|
||||
Returns:
|
||||
添加 'loitering_info' 字段的检测结果
|
||||
"""
|
||||
results = []
|
||||
|
||||
for det in detections:
|
||||
person_id = det.get(id_key)
|
||||
if person_id is None:
|
||||
results.append(det)
|
||||
continue
|
||||
|
||||
x1, y1, x2, y2 = det['bbox']
|
||||
center = ((x1 + x2) // 2, (y1 + y2) // 2)
|
||||
|
||||
is_loitering, loitering_duration, is_stationary, stationary_duration = \
|
||||
self.update(person_id, center)
|
||||
|
||||
det_copy = det.copy()
|
||||
det_copy['loitering_info'] = {
|
||||
'person_id': person_id,
|
||||
'is_loitering': is_loitering,
|
||||
'loitering_duration': round(loitering_duration, 2),
|
||||
'is_stationary': is_stationary,
|
||||
'stationary_duration': round(stationary_duration, 2),
|
||||
'loitering_threshold': self.loitering_threshold,
|
||||
'stationary_threshold': self.stationary_threshold
|
||||
}
|
||||
results.append(det_copy)
|
||||
|
||||
return results
|
||||
|
||||
def get_all_loitering(
|
||||
self,
|
||||
threshold: Optional[float] = None
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
获取所有徘徊超过阈值的人员
|
||||
|
||||
Args:
|
||||
threshold: 徘徊阈值(秒),默认使用初始化时的阈值
|
||||
|
||||
Returns:
|
||||
list: [{person_id, duration, positions}, ...]
|
||||
"""
|
||||
threshold = threshold or self.loitering_threshold
|
||||
|
||||
result = []
|
||||
for person_id, track in self._tracks.items():
|
||||
if track.total_duration > threshold:
|
||||
result.append({
|
||||
'person_id': person_id,
|
||||
'duration': track.total_duration,
|
||||
'positions': track.positions.copy(),
|
||||
'is_stationary': track.stationary_duration > self.stationary_threshold,
|
||||
'stationary_duration': track.stationary_duration
|
||||
})
|
||||
|
||||
# 按时长排序
|
||||
result.sort(key=lambda x: x['duration'], reverse=True)
|
||||
return result
|
||||
|
||||
def get_all_stationary(
|
||||
self,
|
||||
threshold: Optional[float] = None
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
获取所有静止超过阈值的人员
|
||||
|
||||
Args:
|
||||
threshold: 静止阈值(秒),默认使用初始化时的阈值
|
||||
|
||||
Returns:
|
||||
list: [{person_id, duration, position}, ...]
|
||||
"""
|
||||
threshold = threshold or self.stationary_threshold
|
||||
|
||||
result = []
|
||||
for person_id, track in self._tracks.items():
|
||||
if track.stationary_duration > threshold:
|
||||
result.append({
|
||||
'person_id': person_id,
|
||||
'duration': track.stationary_duration,
|
||||
'position': track.last_position,
|
||||
'total_duration': track.total_duration
|
||||
})
|
||||
|
||||
result.sort(key=lambda x: x['duration'], reverse=True)
|
||||
return result
|
||||
|
||||
def reset(self):
|
||||
"""重置所有跟踪数据"""
|
||||
self._tracks.clear()
|
||||
self._last_cleanup = time.time()
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
"""获取统计信息"""
|
||||
return {
|
||||
'total_tracks': len(self._tracks),
|
||||
'loitering_count': len(self.get_all_loitering()),
|
||||
'stationary_count': len(self.get_all_stationary()),
|
||||
'loitering_threshold': self.loitering_threshold,
|
||||
'stationary_threshold': self.stationary_threshold
|
||||
}
|
||||
236
models/loitering_detection/algorithms/stationary_detector.py
Normal file
236
models/loitering_detection/algorithms/stationary_detector.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""
|
||||
基于位置的静止人员检测算法
|
||||
不依赖跟踪 ID,而是根据位置来关联人员
|
||||
适用于跟踪不稳定但人员相对静止的场景
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
@dataclass
|
||||
class PositionRecord:
|
||||
"""位置记录"""
|
||||
first_seen: float
|
||||
last_seen: float
|
||||
center: Tuple[int, int]
|
||||
box: Tuple[int, int, int, int]
|
||||
duration: float = 0.0
|
||||
|
||||
|
||||
class PositionBasedStationaryDetector:
|
||||
"""
|
||||
基于位置的静止检测器
|
||||
|
||||
特点:
|
||||
- 不依赖跟踪 ID,直接用位置关联人员
|
||||
- 适用于 SORT 等跟踪器不稳定的场景
|
||||
- 使用网格化位置 + 距离容差进行匹配
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stationary_threshold: float = 10.0, # 静止阈值(秒)
|
||||
position_tolerance: int = 50, # 位置容差(像素)
|
||||
cleanup_interval: float = 5.0 # 清理间隔(秒)
|
||||
):
|
||||
self.stationary_threshold = stationary_threshold
|
||||
self.position_tolerance = position_tolerance
|
||||
self.cleanup_interval = cleanup_interval
|
||||
|
||||
# 位置历史记录: {position_key: PositionRecord}
|
||||
self._position_history: Dict[Tuple[int, int], PositionRecord] = {}
|
||||
self._last_cleanup = time.time()
|
||||
|
||||
def _get_position_key(self, center: Tuple[int, int]) -> Tuple[int, int]:
|
||||
"""
|
||||
将连续坐标转换为离散的位置键
|
||||
用于将相近位置归为一类
|
||||
"""
|
||||
x, y = center
|
||||
grid_x = int(x / self.position_tolerance)
|
||||
grid_y = int(y / self.position_tolerance)
|
||||
return (grid_x, grid_y)
|
||||
|
||||
def _find_matching_position(
|
||||
self,
|
||||
center: Tuple[int, int]
|
||||
) -> Optional[Tuple[int, int]]:
|
||||
"""
|
||||
查找与当前位置匹配的历史位置
|
||||
返回匹配的位置键,如果没有则返回 None
|
||||
"""
|
||||
current_key = self._get_position_key(center)
|
||||
|
||||
# 首先检查精确匹配
|
||||
if current_key in self._position_history:
|
||||
hist_center = self._position_history[current_key].center
|
||||
distance = ((center[0] - hist_center[0]) ** 2 +
|
||||
(center[1] - hist_center[1]) ** 2) ** 0.5
|
||||
if distance < self.position_tolerance:
|
||||
return current_key
|
||||
|
||||
# 检查相邻网格
|
||||
for dx in [-1, 0, 1]:
|
||||
for dy in [-1, 0, 1]:
|
||||
if dx == 0 and dy == 0:
|
||||
continue
|
||||
neighbor_key = (current_key[0] + dx, current_key[1] + dy)
|
||||
if neighbor_key in self._position_history:
|
||||
hist_center = self._position_history[neighbor_key].center
|
||||
distance = ((center[0] - hist_center[0]) ** 2 +
|
||||
(center[1] - hist_center[1]) ** 2) ** 0.5
|
||||
if distance < self.position_tolerance:
|
||||
return neighbor_key
|
||||
|
||||
return None
|
||||
|
||||
def update(
|
||||
self,
|
||||
center: Tuple[int, int],
|
||||
box: Tuple[int, int, int, int]
|
||||
) -> Tuple[str, float, bool]:
|
||||
"""
|
||||
更新位置信息
|
||||
|
||||
Args:
|
||||
center: (x, y) 中心点坐标
|
||||
box: (x1, y1, x2, y2) 边界框
|
||||
|
||||
Returns:
|
||||
position_id: 位置 ID(用于关联)
|
||||
stationary_duration: 静止时长(秒)
|
||||
is_stationary: 是否静止超过阈值
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# 定期清理旧记录
|
||||
if current_time - self._last_cleanup > self.cleanup_interval:
|
||||
self.cleanup_old_positions()
|
||||
self._last_cleanup = current_time
|
||||
|
||||
# 查找匹配的历史位置
|
||||
matching_key = self._find_matching_position(center)
|
||||
|
||||
if matching_key is not None:
|
||||
# 更新已有位置
|
||||
record = self._position_history[matching_key]
|
||||
record.last_seen = current_time
|
||||
|
||||
# 平滑更新中心位置(使用移动平均)
|
||||
old_center = record.center
|
||||
record.center = (
|
||||
int(0.7 * old_center[0] + 0.3 * center[0]),
|
||||
int(0.7 * old_center[1] + 0.3 * center[1])
|
||||
)
|
||||
record.box = box
|
||||
|
||||
duration = current_time - record.first_seen
|
||||
record.duration = duration
|
||||
|
||||
is_stationary = duration > self.stationary_threshold
|
||||
position_id = f"pos_{matching_key[0]}_{matching_key[1]}"
|
||||
|
||||
return position_id, duration, is_stationary
|
||||
else:
|
||||
# 创建新位置记录
|
||||
new_key = self._get_position_key(center)
|
||||
self._position_history[new_key] = PositionRecord(
|
||||
first_seen=current_time,
|
||||
last_seen=current_time,
|
||||
center=center,
|
||||
box=box,
|
||||
duration=0.0
|
||||
)
|
||||
new_id = f"pos_{new_key[0]}_{new_key[1]}"
|
||||
return new_id, 0.0, False
|
||||
|
||||
def cleanup_old_positions(self, max_age: float = 5.0) -> int:
|
||||
"""
|
||||
清理长时间未更新的位置记录
|
||||
|
||||
Args:
|
||||
max_age: 最大保留时间(秒)
|
||||
|
||||
Returns:
|
||||
清理的记录数量
|
||||
"""
|
||||
current_time = time.time()
|
||||
to_remove = [
|
||||
key for key, data in self._position_history.items()
|
||||
if current_time - data.last_seen > max_age
|
||||
]
|
||||
|
||||
for key in to_remove:
|
||||
del self._position_history[key]
|
||||
|
||||
return len(to_remove)
|
||||
|
||||
def get_all_stationary(
|
||||
self,
|
||||
threshold: Optional[float] = None
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
获取所有静止超过阈值的位置
|
||||
|
||||
Args:
|
||||
threshold: 静止阈值(秒),默认使用初始化时的阈值
|
||||
|
||||
Returns:
|
||||
list: [{position_id, duration, center, box}, ...]
|
||||
"""
|
||||
threshold = threshold or self.stationary_threshold
|
||||
|
||||
result = []
|
||||
for key, data in self._position_history.items():
|
||||
if data.duration > threshold:
|
||||
result.append({
|
||||
'position_id': f"pos_{key[0]}_{key[1]}",
|
||||
'duration': data.duration,
|
||||
'center': data.center,
|
||||
'box': data.box
|
||||
})
|
||||
|
||||
# 按时长排序
|
||||
result.sort(key=lambda x: x['duration'], reverse=True)
|
||||
return result
|
||||
|
||||
def reset(self):
|
||||
"""重置所有跟踪数据"""
|
||||
self._position_history.clear()
|
||||
self._last_cleanup = time.time()
|
||||
|
||||
def detect(
|
||||
self,
|
||||
detections: List[Dict]
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
批量检测静止状态
|
||||
|
||||
Args:
|
||||
detections: 检测结果列表,每项包含 'bbox': [x1, y1, x2, y2]
|
||||
|
||||
Returns:
|
||||
添加 'stationary_info' 字段的检测结果
|
||||
"""
|
||||
results = []
|
||||
|
||||
for det in detections:
|
||||
x1, y1, x2, y2 = det['bbox']
|
||||
center = ((x1 + x2) // 2, (y1 + y2) // 2)
|
||||
box = (x1, y1, x2, y2)
|
||||
|
||||
position_id, duration, is_stationary = self.update(center, box)
|
||||
|
||||
det_copy = det.copy()
|
||||
det_copy['stationary_info'] = {
|
||||
'position_id': position_id,
|
||||
'duration': round(duration, 2),
|
||||
'is_stationary': is_stationary,
|
||||
'threshold': self.stationary_threshold
|
||||
}
|
||||
results.append(det_copy)
|
||||
|
||||
return results
|
||||
8
models/loitering_detection/processors/__init__.py
Normal file
8
models/loitering_detection/processors/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
徘徊检测处理器模块
|
||||
用于对检测结果进行后处理
|
||||
"""
|
||||
|
||||
from .behavior_processor import BehaviorProcessor
|
||||
|
||||
__all__ = ['BehaviorProcessor']
|
||||
201
models/loitering_detection/processors/behavior_processor.py
Normal file
201
models/loitering_detection/processors/behavior_processor.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
行为检测处理器
|
||||
集成基于位置和基于跟踪ID的检测算法
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
# 添加算法模块路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from algorithms import PositionBasedStationaryDetector, LoiteringDetector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BehaviorAlert:
|
||||
"""行为告警"""
|
||||
alert_type: str # 'stationary', 'loitering'
|
||||
level: str # 'low', 'medium', 'high'
|
||||
message: str
|
||||
person_id: Optional[str] = None
|
||||
position_id: Optional[str] = None
|
||||
duration: float = 0.0
|
||||
bbox: Optional[Tuple[int, int, int, int]] = None
|
||||
|
||||
|
||||
class BehaviorProcessor:
|
||||
"""
|
||||
行为检测处理器
|
||||
|
||||
整合两种检测方式:
|
||||
1. 基于位置的静止检测(无需跟踪ID)
|
||||
2. 基于跟踪ID的徘徊检测(需要跟踪ID)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# 静止检测参数
|
||||
stationary_threshold: float = 10.0,
|
||||
position_tolerance: int = 50,
|
||||
|
||||
# 徘徊检测参数
|
||||
loitering_threshold: float = 300.0,
|
||||
movement_threshold: float = 5.0,
|
||||
|
||||
# 告警参数
|
||||
enable_stationary_alert: bool = True,
|
||||
enable_loitering_alert: bool = True,
|
||||
stationary_alert_threshold: float = 10.0, # 超过此时间产生告警
|
||||
loitering_alert_threshold: float = 300.0 # 超过此时间产生告警
|
||||
):
|
||||
# 初始化检测器
|
||||
self.stationary_detector = PositionBasedStationaryDetector(
|
||||
stationary_threshold=stationary_threshold,
|
||||
position_tolerance=position_tolerance
|
||||
)
|
||||
|
||||
self.loitering_detector = LoiteringDetector(
|
||||
loitering_threshold=loitering_threshold,
|
||||
stationary_threshold=stationary_threshold,
|
||||
movement_threshold=movement_threshold
|
||||
)
|
||||
|
||||
# 配置
|
||||
self.enable_stationary_alert = enable_stationary_alert
|
||||
self.enable_loitering_alert = enable_loitering_alert
|
||||
self.stationary_alert_threshold = stationary_alert_threshold
|
||||
self.loitering_alert_threshold = loitering_alert_threshold
|
||||
|
||||
def process(
|
||||
self,
|
||||
detections: List[Dict],
|
||||
use_tracking: bool = False,
|
||||
track_id_key: str = 'track_id'
|
||||
) -> Dict:
|
||||
"""
|
||||
处理检测结果,检测行为
|
||||
|
||||
Args:
|
||||
detections: 检测结果列表
|
||||
use_tracking: 是否使用跟踪ID(如果有的话)
|
||||
track_id_key: 跟踪ID字段名
|
||||
|
||||
Returns:
|
||||
{
|
||||
'detections': 添加行为信息的检测结果,
|
||||
'alerts': 触发的告警列表,
|
||||
'stats': 统计信息
|
||||
}
|
||||
"""
|
||||
logger.info(f"[BehaviorProcessor] 开始处理 {len(detections)} 个检测结果")
|
||||
logger.info(f"[BehaviorProcessor] 配置: stationary={self.enable_stationary_alert}, loitering={self.enable_loitering_alert}")
|
||||
|
||||
alerts = []
|
||||
|
||||
# 1. 始终进行基于位置的静止检测
|
||||
logger.info(f"[BehaviorProcessor] 调用静止检测器...")
|
||||
detections = self.stationary_detector.detect(detections)
|
||||
logger.info(f"[BehaviorProcessor] 静止检测完成,检测到 {len(detections)} 个结果")
|
||||
|
||||
# 检查静止告警
|
||||
stationary_alerts = 0
|
||||
if self.enable_stationary_alert:
|
||||
for det in detections:
|
||||
info = det.get('stationary_info', {})
|
||||
if info.get('is_stationary') and info.get('duration', 0) >= self.stationary_alert_threshold:
|
||||
alert = BehaviorAlert(
|
||||
alert_type='stationary',
|
||||
level='medium' if info['duration'] < 30 else 'high',
|
||||
message=f"人员静止停留 {int(info['duration'])} 秒",
|
||||
position_id=info.get('position_id'),
|
||||
duration=info['duration'],
|
||||
bbox=tuple(det['bbox'])
|
||||
)
|
||||
alerts.append(alert)
|
||||
stationary_alerts += 1
|
||||
logger.info(f"[BehaviorProcessor] 静止告警: {stationary_alerts} 个")
|
||||
|
||||
# 2. 如果有跟踪ID,进行徘徊检测
|
||||
logger.info(f"[BehaviorProcessor] use_tracking={use_tracking}")
|
||||
if use_tracking:
|
||||
detections = self.loitering_detector.detect(detections, id_key=track_id_key)
|
||||
|
||||
# 检查徘徊告警
|
||||
if self.enable_loitering_alert:
|
||||
for det in detections:
|
||||
info = det.get('loitering_info', {})
|
||||
if info.get('is_loitering') and info.get('loitering_duration', 0) >= self.loitering_alert_threshold:
|
||||
alert = BehaviorAlert(
|
||||
alert_type='loitering',
|
||||
level='high',
|
||||
message=f"人员徘徊 {int(info['loitering_duration'] // 60)} 分钟",
|
||||
person_id=str(info.get('person_id')),
|
||||
duration=info['loitering_duration'],
|
||||
bbox=tuple(det['bbox'])
|
||||
)
|
||||
alerts.append(alert)
|
||||
|
||||
# 统计信息
|
||||
stats = {
|
||||
'total_detections': len(detections),
|
||||
'stationary_count': len(self.stationary_detector.get_all_stationary()),
|
||||
'alert_count': len(alerts)
|
||||
}
|
||||
|
||||
if use_tracking:
|
||||
stats.update({
|
||||
'loitering_count': len(self.loitering_detector.get_all_loitering()),
|
||||
'tracking_count': self.loitering_detector.get_stats()['total_tracks']
|
||||
})
|
||||
|
||||
logger.info(f"[BehaviorProcessor] 处理完成: {stats}")
|
||||
|
||||
return {
|
||||
'detections': detections,
|
||||
'alerts': [self._alert_to_dict(a) for a in alerts],
|
||||
'stats': stats
|
||||
}
|
||||
|
||||
def _alert_to_dict(self, alert: BehaviorAlert) -> Dict:
|
||||
"""将告警对象转换为字典"""
|
||||
return {
|
||||
'type': alert.alert_type,
|
||||
'level': alert.level,
|
||||
'message': alert.message,
|
||||
'person_id': alert.person_id,
|
||||
'position_id': alert.position_id,
|
||||
'duration': round(alert.duration, 2),
|
||||
'bbox': alert.bbox
|
||||
}
|
||||
|
||||
def get_stationary_persons(self) -> List[Dict]:
|
||||
"""获取所有静止人员"""
|
||||
return self.stationary_detector.get_all_stationary()
|
||||
|
||||
def get_loitering_persons(self) -> List[Dict]:
|
||||
"""获取所有徘徊人员"""
|
||||
return self.loitering_detector.get_all_loitering()
|
||||
|
||||
def reset(self):
|
||||
"""重置所有检测器"""
|
||||
self.stationary_detector.reset()
|
||||
self.loitering_detector.reset()
|
||||
|
||||
def get_config(self) -> Dict:
|
||||
"""获取当前配置"""
|
||||
return {
|
||||
'stationary_threshold': self.stationary_detector.stationary_threshold,
|
||||
'position_tolerance': self.stationary_detector.position_tolerance,
|
||||
'loitering_threshold': self.loitering_detector.loitering_threshold,
|
||||
'movement_threshold': self.loitering_detector.movement_threshold,
|
||||
'enable_stationary_alert': self.enable_stationary_alert,
|
||||
'enable_loitering_alert': self.enable_loitering_alert,
|
||||
'stationary_alert_threshold': self.stationary_alert_threshold,
|
||||
'loitering_alert_threshold': self.loitering_alert_threshold
|
||||
}
|
||||
Reference in New Issue
Block a user