# Origin: apps/server/services/tracking_service.py, as added by commit
# bf12a29acdae937b1936fe342eedef1e69726e32
# ("feat(server): add ByteTrack object tracking service").
"""Object tracking service (MVP-2 / D19).

A simplified tracker based on the ByteTrack idea, implemented in pure Python
with no extra dependencies.

Core algorithm (key points of the ByteTrack paper):

1. Split detections by confidence into high (>= high_thresh) and
   low (< high_thresh).
2. Match the high detections against the existing tracks by IOU
   (Hungarian algorithm / greedy).
3. Match the tracks that are still unmatched against the low detections
   (rescues real targets that were detected with low confidence).
4. High detections that remain unmatched create new tracks.
5. Tracks that remain unmatched enter the Lost state and are removed once
   they exceed max_lost_frames.

Simplifications (relative to the original):

- No Kalman filter; the previous frame's bbox is matched by IOU against the
  current detections.
- Greedy matching (by descending IOU) replaces the Hungarian algorithm.
- No appearance features (ReID).

Integration::

    tracker = ByteTracker()
    for frame_detections in stream:
        tracked = tracker.update(frame_detections)
        # matched / newly tracked UnifiedDetection objects get a track_id
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Tuple

from models.event_schemas import BBox, UnifiedDetection

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Track state
# ---------------------------------------------------------------------------


class TrackState(str, Enum):
    """Lifecycle state of a track."""

    NEW = "new"
    TRACKED = "tracked"
    LOST = "lost"
    REMOVED = "removed"


@dataclass
class Track:
    """A single tracked target."""

    track_id: int
    bbox: BBox
    confidence: float
    class_name: str
    state: TrackState = TrackState.NEW
    # Number of frames in which this track was matched to a detection
    # (1 on creation; not incremented while the track is lost).
    age: int = 0
    # Number of consecutive frames without a matching detection.
    lost_frames: int = 0
    # Tracker frame counter value at the last creation / match.
    last_update_frame: int = 0


# ---------------------------------------------------------------------------
# Helpers: IOU
# ---------------------------------------------------------------------------


def compute_iou(b1: BBox, b2: BBox) -> float:
    """Return the intersection-over-union of two boxes.

    Reads ``x1``, ``y1``, ``x2``, ``y2`` and ``area`` from each box.

    Returns:
        The IOU, or 0.0 when the boxes do not overlap or the union is not
        positive.
    """
    inter_w = max(0, min(b1.x2, b2.x2) - max(b1.x1, b2.x1))
    inter_h = max(0, min(b1.y2, b2.y2) - max(b1.y1, b2.y1))
    inter = inter_w * inter_h
    if inter == 0:
        return 0.0

    union = b1.area + b2.area - inter
    return inter / union if union > 0 else 0.0


def greedy_match(
    cost_matrix: List[List[float]],
    threshold: float,
) -> List[Tuple[int, int]]:
    """Greedily pick non-conflicting (row, col) pairs in descending score order.

    Args:
        cost_matrix: ``cost_matrix[i][j]`` is a score where larger is better
            (an IOU in this module, despite the parameter name).
        threshold: Pairs scoring below this value are never matched.

    Returns:
        List of ``(row_idx, col_idx)`` pairs; each row and each column
        appears at most once.
    """
    if not cost_matrix or not cost_matrix[0]:
        return []

    candidates: List[Tuple[float, int, int]] = [
        (score, i, j)
        for i, row in enumerate(cost_matrix)
        for j, score in enumerate(row)
        if score >= threshold
    ]
    # Tuple ordering: highest score first, ties broken by larger (i, j).
    candidates.sort(reverse=True)

    matches: List[Tuple[int, int]] = []
    used_rows = set()
    used_cols = set()
    for _, i, j in candidates:
        if i in used_rows or j in used_cols:
            continue
        matches.append((i, j))
        used_rows.add(i)
        used_cols.add(j)
    return matches


# ---------------------------------------------------------------------------
# ByteTracker
# ---------------------------------------------------------------------------


class ByteTracker:
    """Simplified ByteTrack tracker.

    Args:
        track_thresh: Base confidence threshold; detections below it do not
            take part in tracking.
        high_thresh: Confidence at or above which a detection is treated as
            "high"; used for bucketing.
        match_thresh: Minimum IOU for matching high detections to tracks.
        max_lost_frames: A track is removed once its consecutive lost-frame
            count exceeds this value.
        min_box_area: Detections whose bbox area is below this are ignored.
        low_match_thresh: Minimum IOU for the second round, which matches
            leftover tracks to low-confidence detections. Defaults to 0.5,
            the value that was previously hard-coded.

    NOTE(review): ``match_thresh`` is applied directly as a minimum IOU. The
    reference ByteTrack implementation is believed to apply its 0.8 default
    to an IOU *distance* (1 - IOU) instead -- confirm that requiring
    IOU >= 0.8 without motion prediction is the intended default.
    """

    def __init__(
        self,
        track_thresh: float = 0.5,
        high_thresh: float = 0.6,
        match_thresh: float = 0.8,
        max_lost_frames: int = 30,
        min_box_area: float = 10.0,
        low_match_thresh: float = 0.5,
    ) -> None:
        self.track_thresh = track_thresh
        self.high_thresh = high_thresh
        self.match_thresh = match_thresh
        self.max_lost_frames = max_lost_frames
        self.min_box_area = min_box_area
        self.low_match_thresh = low_match_thresh

        self._tracks: Dict[int, Track] = {}
        self._next_id: int = 1
        self._frame_count: int = 0

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------

    def update(
        self,
        detections: List[UnifiedDetection],
    ) -> List[UnifiedDetection]:
        """Run one tracking step on a frame's detections and fill ``track_id``.

        Detections are modified in place. Detections that are filtered out
        (too small, below ``track_thresh``) and low-confidence detections
        that match no track are left untouched.

        Returns:
            The same ``detections`` list that was passed in (order unchanged).
        """
        self._frame_count += 1

        # 1-2. Drop boxes that are too small, then bucket by confidence.
        high_dets: List[UnifiedDetection] = []
        low_dets: List[UnifiedDetection] = []
        for det in detections:
            if det.bbox.area < self.min_box_area:
                continue
            if det.confidence >= self.high_thresh:
                high_dets.append(det)
            elif det.confidence >= self.track_thresh:
                low_dets.append(det)
            # confidence < track_thresh: ignored

        # 3. Tracks eligible for matching (everything not yet removed).
        candidate_tracks: List[Track] = [
            t for t in self._tracks.values()
            if t.state in (TrackState.TRACKED, TrackState.LOST, TrackState.NEW)
        ]

        # 4. Round one: match against high-confidence detections.
        matches_high, unmatched_tracks, unmatched_high = self._match(
            candidate_tracks, high_dets, self.match_thresh
        )
        for track_idx, det_idx in matches_high:
            track = candidate_tracks[track_idx]
            det = high_dets[det_idx]
            self._update_track(track, det)
            det.track_id = track.track_id

        # 5. Round two: leftover tracks against low-confidence detections.
        remaining_tracks = [candidate_tracks[i] for i in unmatched_tracks]
        matches_low, still_unmatched, _ = self._match(
            remaining_tracks, low_dets, self.low_match_thresh
        )
        for track_idx, det_idx in matches_low:
            track = remaining_tracks[track_idx]
            det = low_dets[det_idx]
            self._update_track(track, det)
            det.track_id = track.track_id

        # 6. Unmatched high detections start new tracks. These are created
        #    after the candidate list was taken, so they are not marked lost
        #    in step 7.
        for det_idx in unmatched_high:
            det = high_dets[det_idx]
            det.track_id = self._create_track(det).track_id

        # 7. Tracks matched in neither round accumulate lost frames.
        for i in still_unmatched:
            track = remaining_tracks[i]
            track.lost_frames += 1
            if track.lost_frames > self.max_lost_frames:
                track.state = TrackState.REMOVED
            else:
                track.state = TrackState.LOST

        # 8. Purge removed tracks.
        self._tracks = {
            tid: t for tid, t in self._tracks.items()
            if t.state != TrackState.REMOVED
        }

        return detections

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _match(
        self,
        tracks: List[Track],
        detections: List[UnifiedDetection],
        threshold: float,
    ) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
        """Match ``tracks`` to ``detections`` by IOU.

        Returns:
            ``(matches, unmatched_track_indices, unmatched_det_indices)``
            where every index refers to a position in the lists passed in.
        """
        if not tracks or not detections:
            return [], list(range(len(tracks))), list(range(len(detections)))

        # IOU matrix: one row per track, one column per detection.
        iou_matrix: List[List[float]] = [
            [compute_iou(t.bbox, d.bbox) for d in detections]
            for t in tracks
        ]

        matches = greedy_match(iou_matrix, threshold)
        matched_tracks = {i for i, _ in matches}
        matched_dets = {j for _, j in matches}
        unmatched_tracks = [
            i for i in range(len(tracks)) if i not in matched_tracks
        ]
        unmatched_dets = [
            j for j in range(len(detections)) if j not in matched_dets
        ]
        return matches, unmatched_tracks, unmatched_dets

    def _create_track(self, det: UnifiedDetection) -> Track:
        """Register a new TRACKED track for ``det`` under the next free id."""
        track_id = self._next_id
        self._next_id += 1
        track = Track(
            track_id=track_id,
            bbox=det.bbox,
            confidence=det.confidence,
            class_name=det.class_name,
            state=TrackState.TRACKED,
            age=1,
            lost_frames=0,
            last_update_frame=self._frame_count,
        )
        self._tracks[track_id] = track
        return track

    def _update_track(self, track: Track, det: UnifiedDetection) -> None:
        """Refresh ``track`` from a matched detection and mark it TRACKED.

        ``class_name`` is intentionally left as set at creation.
        """
        track.bbox = det.bbox
        track.confidence = det.confidence
        track.age += 1
        track.lost_frames = 0
        track.state = TrackState.TRACKED
        track.last_update_frame = self._frame_count

    # ------------------------------------------------------------------
    # State
    # ------------------------------------------------------------------

    @property
    def active_tracks(self) -> List[Track]:
        """Tracks currently in the TRACKED state."""
        return [
            t for t in self._tracks.values() if t.state == TrackState.TRACKED
        ]

    @property
    def lost_tracks(self) -> List[Track]:
        """Tracks currently in the LOST state."""
        return [t for t in self._tracks.values() if t.state == TrackState.LOST]

    def reset(self) -> None:
        """Drop all tracks and restart id and frame counters."""
        self._tracks.clear()
        self._next_id = 1
        self._frame_count = 0


# ---------------------------------------------------------------------------
# TrackingService: multi-stream tracker management
# ---------------------------------------------------------------------------


class TrackingService:
    """Multi-stream object tracking service.

    Keeps an independent ByteTracker instance per source_id (camera / video
    stream). Constructor arguments are forwarded unchanged to every
    ByteTracker that gets created.
    """

    def __init__(
        self,
        track_thresh: float = 0.5,
        high_thresh: float = 0.6,
        match_thresh: float = 0.8,
        max_lost_frames: int = 30,
        min_box_area: float = 10.0,
        low_match_thresh: float = 0.5,
    ) -> None:
        self.track_thresh = track_thresh
        self.high_thresh = high_thresh
        self.match_thresh = match_thresh
        self.max_lost_frames = max_lost_frames
        self.min_box_area = min_box_area
        self.low_match_thresh = low_match_thresh
        self._trackers: Dict[str, ByteTracker] = {}

    def get_or_create(self, source_id: str) -> ByteTracker:
        """Return the tracker for ``source_id``, creating it on first use."""
        if source_id not in self._trackers:
            self._trackers[source_id] = ByteTracker(
                track_thresh=self.track_thresh,
                high_thresh=self.high_thresh,
                match_thresh=self.match_thresh,
                max_lost_frames=self.max_lost_frames,
                min_box_area=self.min_box_area,
                low_match_thresh=self.low_match_thresh,
            )
        return self._trackers[source_id]

    def update(
        self,
        source_id: str,
        detections: List[UnifiedDetection],
    ) -> List[UnifiedDetection]:
        """Run one tracking step for ``source_id`` on ``detections``."""
        tracker = self.get_or_create(source_id)
        return tracker.update(detections)

    def reset(self, source_id: Optional[str] = None) -> None:
        """Reset one stream's tracker, or discard all trackers when None.

        An unknown ``source_id`` is ignored.
        """
        if source_id is None:
            self._trackers.clear()
        elif source_id in self._trackers:
            self._trackers[source_id].reset()

    def remove(self, source_id: str) -> None:
        """Discard the tracker for ``source_id`` if one exists."""
        self._trackers.pop(source_id, None)

    @property
    def source_count(self) -> int:
        """Number of streams that currently have a tracker."""
        return len(self._trackers)


__all__ = [
    "ByteTracker",
    "TrackingService",
    "Track",
    "TrackState",
    "compute_iou",
    "greedy_match",
]