"""
Loitering detection based on tracker IDs.

Relies on per-person track IDs, so it is suited to scenes where the upstream
tracker keeps IDs stable.
"""
import math
import time
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass, field
from collections import defaultdict


@dataclass
class PersonTrack:
    """Per-person tracking record kept by LoiteringDetector."""
    person_id: int
    first_seen: float  # timestamp of the first observation of this track
    last_seen: float  # timestamp of the most recent observation
    # Center points observed for this track, oldest first (optionally capped
    # by LoiteringDetector.max_positions).
    positions: List[Tuple[int, int]] = field(default_factory=list)
    last_position: Optional[Tuple[int, int]] = None  # most recent center point
    # Timestamp at which the current stationary period began.
    stationary_start: Optional[float] = None
    total_duration: float = 0.0  # last_seen - first_seen
    # Seconds spent within movement_threshold of anchor_position.
    stationary_duration: float = 0.0
    # Reference point of the current stationary period. Movement is measured
    # against this point (not the previous frame) so slow drift is detected.
    anchor_position: Optional[Tuple[int, int]] = None


class LoiteringDetector:
    """
    Loitering detector driven by tracker IDs.

    Characteristics:
    - Depends on track IDs, so it needs a stable tracker.
    - Detects long dwell times (loitering): total time a track has been seen.
    - Detects standing still (stationary): time a track has stayed within
      ``movement_threshold`` pixels of the point where it stopped.

    Timestamps default to ``time.time()``. Pass ``timestamp=`` to ``update`` /
    ``detect`` to drive the detector from frame timestamps instead, for
    example when processing recorded video faster or slower than real time.
    Use a single time base consistently for a given detector instance.
    """

    def __init__(
        self,
        loitering_threshold: float = 300.0,
        stationary_threshold: float = 2.0,
        movement_threshold: float = 5.0,
        cleanup_interval: float = 10.0,
        track_timeout: float = 60.0,
        max_positions: Optional[int] = None,
    ):
        """
        Args:
            loitering_threshold: seconds a track must be seen before it counts
                as loitering (default 5 minutes).
            stationary_threshold: seconds a track must stay put before it
                counts as stationary.
            movement_threshold: pixels a center point may move away from its
                stationary anchor before the track counts as moving.
            cleanup_interval: minimum seconds between stale-track sweeps.
            track_timeout: seconds without an update after which a track is
                dropped; a reappearing ID after that gap starts a fresh track.
            max_positions: keep at most this many recent positions per track;
                ``None`` keeps the full history.

        Raises:
            ValueError: if ``max_positions`` is negative.
        """
        if max_positions is not None and max_positions < 0:
            raise ValueError("max_positions must be non-negative or None")
        self.loitering_threshold = loitering_threshold
        self.stationary_threshold = stationary_threshold
        self.movement_threshold = movement_threshold
        self.cleanup_interval = cleanup_interval
        self.track_timeout = track_timeout
        self.max_positions = max_positions

        # Tracking records: {person_id: PersonTrack}
        self._tracks: Dict[int, PersonTrack] = {}
        # Time of the last stale-track sweep. None until the first observation
        # so the sweep schedule follows whatever time base the caller uses.
        self._last_cleanup: Optional[float] = None

    @staticmethod
    def _now(timestamp: Optional[float]) -> float:
        """Return ``timestamp`` if given, otherwise the current wall-clock time."""
        return time.time() if timestamp is None else timestamp

    @staticmethod
    def _distance(a: Tuple[int, int], b: Tuple[int, int]) -> float:
        """Euclidean distance between two (x, y) points."""
        return math.hypot(a[0] - b[0], a[1] - b[1])

    def _maybe_cleanup(self, current_time: float) -> None:
        """Sweep stale tracks if ``cleanup_interval`` has passed since the last sweep."""
        if self._last_cleanup is None or current_time < self._last_cleanup:
            # First observation, or the time base went backwards: restart the
            # schedule instead of waiting for the clock to catch up.
            self._last_cleanup = current_time
            return
        if current_time - self._last_cleanup > self.cleanup_interval:
            self._cleanup_old_tracks(current_time=current_time)
            self._last_cleanup = current_time

    def _cleanup_old_tracks(
        self,
        max_age: Optional[float] = None,
        current_time: Optional[float] = None,
    ) -> int:
        """
        Remove tracks that have not been updated for more than ``max_age`` seconds.

        Args:
            max_age: staleness limit in seconds; defaults to ``track_timeout``.
            current_time: reference time; defaults to ``time.time()``.

        Returns:
            Number of tracks removed.
        """
        if max_age is None:
            max_age = self.track_timeout
        current_time = self._now(current_time)
        stale = [
            pid for pid, track in self._tracks.items()
            if current_time - track.last_seen > max_age
        ]
        for pid in stale:
            del self._tracks[pid]
        return len(stale)

    def _record_position(self, track: PersonTrack, position: Tuple[int, int]) -> None:
        """Append ``position`` to the track history, honoring ``max_positions``."""
        track.positions.append(position)
        if self.max_positions is not None and len(track.positions) > self.max_positions:
            del track.positions[:len(track.positions) - self.max_positions]

    def update(
        self,
        person_id: int,
        position: Tuple[int, int],
        timestamp: Optional[float] = None,
    ) -> Tuple[bool, float, bool, float]:
        """
        Update one person's position and return their current status.

        Args:
            person_id: track ID from the upstream tracker.
            position: (x, y) center point in pixels.
            timestamp: time of this observation in seconds; defaults to
                ``time.time()``.

        Returns:
            is_loitering: whether the track has been seen longer than
                ``loitering_threshold``.
            loitering_duration: seconds since the track was first seen.
            is_stationary: whether the track has stayed within
                ``movement_threshold`` of its anchor point longer than
                ``stationary_threshold``.
            stationary_duration: seconds spent at the current spot.
            The first observation of a track always returns
            ``(False, 0.0, False, 0.0)``.
        """
        current_time = self._now(timestamp)
        self._maybe_cleanup(current_time)

        track = self._tracks.get(person_id)
        if track is None or current_time - track.last_seen > self.track_timeout:
            # New ID, or an ID the tracker reused after the old track went
            # stale but before the periodic sweep removed it: start over so
            # the new person does not inherit the old durations.
            track = PersonTrack(
                person_id=person_id,
                first_seen=current_time,
                last_seen=current_time,
                last_position=position,
                stationary_start=current_time,
                anchor_position=position,
            )
            self._record_position(track, position)
            self._tracks[person_id] = track
            return False, 0.0, False, 0.0

        track.last_seen = current_time
        track.last_position = position
        self._record_position(track, position)

        # Total dwell time since first seen.
        track.total_duration = current_time - track.first_seen

        # Movement is measured from the anchor (where the current stationary
        # period began), not from the previous frame. This way slow, steady
        # drift below movement_threshold per frame still counts as moving.
        if (
            track.anchor_position is None
            or track.stationary_start is None
            or self._distance(position, track.anchor_position) > self.movement_threshold
        ):
            # Moved to a new spot: the stationary clock restarts here.
            track.anchor_position = position
            track.stationary_start = current_time
        track.stationary_duration = current_time - track.stationary_start

        is_loitering = track.total_duration > self.loitering_threshold
        is_stationary = track.stationary_duration > self.stationary_threshold

        return (
            is_loitering,
            track.total_duration,
            is_stationary,
            track.stationary_duration,
        )

    def detect(
        self,
        detections: List[Dict],
        id_key: str = 'track_id',
        timestamp: Optional[float] = None,
    ) -> List[Dict]:
        """
        Update loitering status for a batch of detections (typically one frame).

        Args:
            detections: detection dicts, each with a 'bbox' (x1, y1, x2, y2)
                and a track ID under ``id_key``.
            id_key: name of the track-ID field.
            timestamp: time of this frame in seconds; defaults to
                ``time.time()``. All detections in the batch share it.

        Returns:
            A list parallel to ``detections``. Entries with a track ID are
            shallow copies with an added 'loitering_info' dict; entries
            without one are passed through unchanged.
        """
        current_time = self._now(timestamp)
        # Sweep here as well so tracks expire even on frames with no tracked people.
        self._maybe_cleanup(current_time)

        results = []
        for det in detections:
            person_id = det.get(id_key)
            if person_id is None:
                results.append(det)
                continue

            x1, y1, x2, y2 = det['bbox']
            center = ((x1 + x2) // 2, (y1 + y2) // 2)

            is_loitering, loitering_duration, is_stationary, stationary_duration = \
                self.update(person_id, center, timestamp=current_time)

            det_copy = det.copy()
            det_copy['loitering_info'] = {
                'person_id': person_id,
                'is_loitering': is_loitering,
                'loitering_duration': round(loitering_duration, 2),
                'is_stationary': is_stationary,
                'stationary_duration': round(stationary_duration, 2),
                'loitering_threshold': self.loitering_threshold,
                'stationary_threshold': self.stationary_threshold
            }
            results.append(det_copy)

        return results

    def get_all_loitering(
        self,
        threshold: Optional[float] = None
    ) -> List[Dict]:
        """
        List every tracked person whose dwell time exceeds a threshold.

        Args:
            threshold: loitering threshold in seconds; ``None`` uses the
                value given at construction (an explicit 0.0 is honored).

        Returns:
            list: [{person_id, duration, positions, is_stationary,
            stationary_duration}, ...], longest duration first.
        """
        if threshold is None:
            threshold = self.loitering_threshold
        result = []

        for person_id, track in self._tracks.items():
            if track.total_duration > threshold:
                result.append({
                    'person_id': person_id,
                    'duration': track.total_duration,
                    'positions': track.positions.copy(),
                    'is_stationary': track.stationary_duration > self.stationary_threshold,
                    'stationary_duration': track.stationary_duration
                })

        result.sort(key=lambda x: x['duration'], reverse=True)
        return result

    def get_all_stationary(
        self,
        threshold: Optional[float] = None
    ) -> List[Dict]:
        """
        List every tracked person whose stationary time exceeds a threshold.

        Args:
            threshold: stationary threshold in seconds; ``None`` uses the
                value given at construction (an explicit 0.0 is honored).

        Returns:
            list: [{person_id, duration, position, total_duration}, ...],
            longest duration first.
        """
        if threshold is None:
            threshold = self.stationary_threshold
        result = []

        for person_id, track in self._tracks.items():
            if track.stationary_duration > threshold:
                result.append({
                    'person_id': person_id,
                    'duration': track.stationary_duration,
                    'position': track.last_position,
                    'total_duration': track.total_duration
                })

        result.sort(key=lambda x: x['duration'], reverse=True)
        return result

    def reset(self):
        """Drop all tracking data and restart the cleanup schedule."""
        self._tracks.clear()
        self._last_cleanup = None

    def get_stats(self) -> Dict:
        """Return counts of tracked, loitering and stationary people plus thresholds."""
        return {
            'total_tracks': len(self._tracks),
            'loitering_count': len(self.get_all_loitering()),
            'stationary_count': len(self.get_all_stationary()),
            'loitering_threshold': self.loitering_threshold,
            'stationary_threshold': self.stationary_threshold
        }