"""目标跟踪服务 (MVP-2 / D19) 基于 ByteTrack 思想的简化跟踪器,纯 Python 实现,无需额外依赖。 核心算法 (ByteTrack 论文要点): 1. 按置信度将检测分为 high (>= high_thresh) / low (< high_thresh) 2. 先用 high 与现有 tracks 进行 IOU 匹配 (匈牙利算法 / 贪心) 3. 未匹配的 tracks 再与 low 检测匹配 (拯救低置信度真实目标) 4. 仍未匹配的 high 检测创建新 track 5. 未匹配的 tracks 进入 Lost 状态,超过 max_lost_frames 移除 简化点 (相对原版): - 不使用 Kalman Filter,只用上一帧 bbox 与当前检测 IOU 匹配 - 匹配算法用贪心 (按 IOU 降序) 替代匈牙利 - 不支持外观特征 (ReID) 集成方式:: tracker = ByteTracker() for frame_detections in stream: tracked = tracker.update(frame_detections) # tracked 中每个 UnifiedDetection 会被填充 track_id """ from __future__ import annotations import logging from dataclasses import dataclass, field from enum import Enum from typing import Dict, List, Optional, Tuple from models.event_schemas import BBox, UnifiedDetection logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Track 状态 # --------------------------------------------------------------------------- class TrackState(str, Enum): """跟踪状态。""" NEW = "new" TRACKED = "tracked" LOST = "lost" REMOVED = "removed" @dataclass class Track: """单个目标跟踪。""" track_id: int bbox: BBox confidence: float class_name: str state: TrackState = TrackState.NEW age: int = 0 # 总存活帧数 lost_frames: int = 0 # 连续丢失帧数 last_update_frame: int = 0 # --------------------------------------------------------------------------- # 工具函数: IOU # --------------------------------------------------------------------------- def compute_iou(b1: BBox, b2: BBox) -> float: """计算两个 BBox 的 IOU。""" inter_x1 = max(b1.x1, b2.x1) inter_y1 = max(b1.y1, b2.y1) inter_x2 = min(b1.x2, b2.x2) inter_y2 = min(b1.y2, b2.y2) inter_w = max(0, inter_x2 - inter_x1) inter_h = max(0, inter_y2 - inter_y1) inter = inter_w * inter_h if inter == 0: return 0.0 area1 = b1.area area2 = b2.area union = area1 + area2 - inter return inter / union if union > 0 else 0.0 def greedy_match( cost_matrix: List[List[float]], threshold: float, ) -> List[Tuple[int, int]]: """贪心匹配: 按 IOU 降序选择不冲突的最佳匹配。 Args: cost_matrix: cost_matrix[i][j] = IOU 越大越好 threshold: 低于该阈值不匹配 Returns: 匹配对列表 [(row_idx, col_idx), ...] """ matches: List[Tuple[int, int]] = [] if not cost_matrix or not cost_matrix[0]: return matches candidates: List[Tuple[float, int, int]] = [] for i, row in enumerate(cost_matrix): for j, score in enumerate(row): if score >= threshold: candidates.append((score, i, j)) candidates.sort(reverse=True) used_rows: set = set() used_cols: set = set() for score, i, j in candidates: if i in used_rows or j in used_cols: continue matches.append((i, j)) used_rows.add(i) used_cols.add(j) return matches # --------------------------------------------------------------------------- # ByteTracker # --------------------------------------------------------------------------- class ByteTracker: """简化版 ByteTrack 跟踪器。 Args: track_thresh: 跟踪基础置信度阈值,低于此值不参与跟踪 high_thresh: 高置信度阈值,分桶用 match_thresh: IOU 匹配阈值 max_lost_frames: 跟踪丢失最大帧数,超过移除 min_box_area: 最小框面积,低于此值忽略 """ def __init__( self, track_thresh: float = 0.5, high_thresh: float = 0.6, match_thresh: float = 0.8, max_lost_frames: int = 30, min_box_area: float = 10.0, ) -> None: self.track_thresh = track_thresh self.high_thresh = high_thresh self.match_thresh = match_thresh self.max_lost_frames = max_lost_frames self.min_box_area = min_box_area self._tracks: Dict[int, Track] = {} self._next_id: int = 1 self._frame_count: int = 0 # ------------------------------------------------------------------ # 主入口 # ------------------------------------------------------------------ def update( self, detections: List[UnifiedDetection], ) -> List[UnifiedDetection]: """对一帧检测结果执行跟踪更新,填充 track_id。 Returns: 带 track_id 的检测结果 (顺序保持不变) """ self._frame_count += 1 # 1. 过滤无效检测 (面积太小) valid: List[Tuple[int, UnifiedDetection]] = [] for idx, det in enumerate(detections): if det.bbox.area < self.min_box_area: continue valid.append((idx, det)) # 2. 按置信度分桶 high_dets: List[Tuple[int, UnifiedDetection]] = [] low_dets: List[Tuple[int, UnifiedDetection]] = [] for idx, det in valid: if det.confidence >= self.high_thresh: high_dets.append((idx, det)) elif det.confidence >= self.track_thresh: low_dets.append((idx, det)) # < track_thresh 直接忽略 # 3. 获取当前 tracked + lost 的 tracks 列表 active_tracks: List[Track] = [ t for t in self._tracks.values() if t.state in (TrackState.TRACKED, TrackState.LOST, TrackState.NEW) ] # 4. 第一轮: 与 high 检测匹配 matches_high, unmatched_tracks, unmatched_high = self._match( active_tracks, [d for _, d in high_dets], self.match_thresh ) for track_idx, det_idx in matches_high: track = active_tracks[track_idx] orig_idx, det = high_dets[det_idx] self._update_track(track, det) det.track_id = track.track_id # 5. 第二轮: 未匹配 tracks 与 low 检测匹配 remaining_tracks = [active_tracks[i] for i in unmatched_tracks] matches_low, still_unmatched, unmatched_low = self._match( remaining_tracks, [d for _, d in low_dets], 0.5 # 低置信度用更宽松阈值 ) for track_idx, det_idx in matches_low: track = remaining_tracks[track_idx] orig_idx, det = low_dets[det_idx] self._update_track(track, det) det.track_id = track.track_id # 6. 未匹配的 high 检测创建新 track for det_idx in unmatched_high: orig_idx, det = high_dets[det_idx] new_track = self._create_track(det) det.track_id = new_track.track_id # 7. 未匹配的 tracks 增加 lost_frames unmatched_original_tracks = [remaining_tracks[i] for i in still_unmatched] for track in unmatched_original_tracks: track.lost_frames += 1 track.state = TrackState.LOST if track.lost_frames > self.max_lost_frames: track.state = TrackState.REMOVED # 8. 清理已移除的 tracks self._tracks = { tid: t for tid, t in self._tracks.items() if t.state != TrackState.REMOVED } return detections # ------------------------------------------------------------------ # 内部 # ------------------------------------------------------------------ def _match( self, tracks: List[Track], detections: List[UnifiedDetection], threshold: float, ) -> Tuple[List[Tuple[int, int]], List[int], List[int]]: """计算 tracks 与 detections 的 IOU 匹配。 Returns: (matches, unmatched_track_indices, unmatched_det_indices) """ if not tracks or not detections: return [], list(range(len(tracks))), list(range(len(detections))) # 构建 IOU 矩阵 iou_matrix: List[List[float]] = [] for t in tracks: row = [compute_iou(t.bbox, d.bbox) for d in detections] iou_matrix.append(row) matches = greedy_match(iou_matrix, threshold) matched_tracks = {i for i, _ in matches} matched_dets = {j for _, j in matches} unmatched_tracks = [i for i in range(len(tracks)) if i not in matched_tracks] unmatched_dets = [j for j in range(len(detections)) if j not in matched_dets] return matches, unmatched_tracks, unmatched_dets def _create_track(self, det: UnifiedDetection) -> Track: track_id = self._next_id self._next_id += 1 track = Track( track_id=track_id, bbox=det.bbox, confidence=det.confidence, class_name=det.class_name, state=TrackState.TRACKED, age=1, lost_frames=0, last_update_frame=self._frame_count, ) self._tracks[track_id] = track return track def _update_track(self, track: Track, det: UnifiedDetection) -> None: track.bbox = det.bbox track.confidence = det.confidence track.age += 1 track.lost_frames = 0 track.state = TrackState.TRACKED track.last_update_frame = self._frame_count # ------------------------------------------------------------------ # 状态 # ------------------------------------------------------------------ @property def active_tracks(self) -> List[Track]: return [t for t in self._tracks.values() if t.state == TrackState.TRACKED] @property def lost_tracks(self) -> List[Track]: return [t for t in self._tracks.values() if t.state == TrackState.LOST] def reset(self) -> None: self._tracks.clear() self._next_id = 1 self._frame_count = 0 # --------------------------------------------------------------------------- # TrackingService: 多流跟踪器管理 # --------------------------------------------------------------------------- class TrackingService: """多流目标跟踪服务。 为每个 source_id (摄像头/视频流) 维护独立的 ByteTracker 实例。 """ def __init__( self, track_thresh: float = 0.5, high_thresh: float = 0.6, match_thresh: float = 0.8, max_lost_frames: int = 30, min_box_area: float = 10.0, ) -> None: self.track_thresh = track_thresh self.high_thresh = high_thresh self.match_thresh = match_thresh self.max_lost_frames = max_lost_frames self.min_box_area = min_box_area self._trackers: Dict[str, ByteTracker] = {} def get_or_create(self, source_id: str) -> ByteTracker: if source_id not in self._trackers: self._trackers[source_id] = ByteTracker( track_thresh=self.track_thresh, high_thresh=self.high_thresh, match_thresh=self.match_thresh, max_lost_frames=self.max_lost_frames, min_box_area=self.min_box_area, ) return self._trackers[source_id] def update( self, source_id: str, detections: List[UnifiedDetection], ) -> List[UnifiedDetection]: tracker = self.get_or_create(source_id) return tracker.update(detections) def reset(self, source_id: Optional[str] = None) -> None: if source_id is None: self._trackers.clear() elif source_id in self._trackers: self._trackers[source_id].reset() def remove(self, source_id: str) -> None: self._trackers.pop(source_id, None) @property def source_count(self) -> int: return len(self._trackers) __all__ = [ "ByteTracker", "TrackingService", "Track", "TrackState", "compute_iou", "greedy_match", ]