feat(server): 新增ByteTrack目标跟踪服务

This commit is contained in:
2026-06-12 13:58:15 +08:00
parent 2fcaf57478
commit bf12a29acd

View File

@@ -0,0 +1,385 @@
"""目标跟踪服务 (MVP-2 / D19)
基于 ByteTrack 思想的简化跟踪器,纯 Python 实现,无需额外依赖。
核心算法 (ByteTrack 论文要点):
1. 按置信度将检测分为 high (>= high_thresh) / low (< high_thresh)
2. 先用 high 与现有 tracks 进行 IOU 匹配 (匈牙利算法 / 贪心)
3. 未匹配的 tracks 再与 low 检测匹配 (拯救低置信度真实目标)
4. 仍未匹配的 high 检测创建新 track
5. 未匹配的 tracks 进入 Lost 状态,超过 max_lost_frames 移除
简化点 (相对原版):
- 不使用 Kalman Filter只用上一帧 bbox 与当前检测 IOU 匹配
- 匹配算法用贪心 (按 IOU 降序) 替代匈牙利
- 不支持外观特征 (ReID)
集成方式::
tracker = ByteTracker()
for frame_detections in stream:
tracked = tracker.update(frame_detections)
# tracked 中每个 UnifiedDetection 会被填充 track_id
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Tuple
from models.event_schemas import BBox, UnifiedDetection
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Track 状态
# ---------------------------------------------------------------------------
class TrackState(str, Enum):
"""跟踪状态。"""
NEW = "new"
TRACKED = "tracked"
LOST = "lost"
REMOVED = "removed"
@dataclass
class Track:
"""单个目标跟踪。"""
track_id: int
bbox: BBox
confidence: float
class_name: str
state: TrackState = TrackState.NEW
age: int = 0 # 总存活帧数
lost_frames: int = 0 # 连续丢失帧数
last_update_frame: int = 0
# ---------------------------------------------------------------------------
# 工具函数: IOU
# ---------------------------------------------------------------------------
def compute_iou(b1: BBox, b2: BBox) -> float:
"""计算两个 BBox 的 IOU。"""
inter_x1 = max(b1.x1, b2.x1)
inter_y1 = max(b1.y1, b2.y1)
inter_x2 = min(b1.x2, b2.x2)
inter_y2 = min(b1.y2, b2.y2)
inter_w = max(0, inter_x2 - inter_x1)
inter_h = max(0, inter_y2 - inter_y1)
inter = inter_w * inter_h
if inter == 0:
return 0.0
area1 = b1.area
area2 = b2.area
union = area1 + area2 - inter
return inter / union if union > 0 else 0.0
def greedy_match(
cost_matrix: List[List[float]],
threshold: float,
) -> List[Tuple[int, int]]:
"""贪心匹配: 按 IOU 降序选择不冲突的最佳匹配。
Args:
cost_matrix: cost_matrix[i][j] = IOU 越大越好
threshold: 低于该阈值不匹配
Returns:
匹配对列表 [(row_idx, col_idx), ...]
"""
matches: List[Tuple[int, int]] = []
if not cost_matrix or not cost_matrix[0]:
return matches
candidates: List[Tuple[float, int, int]] = []
for i, row in enumerate(cost_matrix):
for j, score in enumerate(row):
if score >= threshold:
candidates.append((score, i, j))
candidates.sort(reverse=True)
used_rows: set = set()
used_cols: set = set()
for score, i, j in candidates:
if i in used_rows or j in used_cols:
continue
matches.append((i, j))
used_rows.add(i)
used_cols.add(j)
return matches
# ---------------------------------------------------------------------------
# ByteTracker
# ---------------------------------------------------------------------------
class ByteTracker:
"""简化版 ByteTrack 跟踪器。
Args:
track_thresh: 跟踪基础置信度阈值,低于此值不参与跟踪
high_thresh: 高置信度阈值,分桶用
match_thresh: IOU 匹配阈值
max_lost_frames: 跟踪丢失最大帧数,超过移除
min_box_area: 最小框面积,低于此值忽略
"""
def __init__(
self,
track_thresh: float = 0.5,
high_thresh: float = 0.6,
match_thresh: float = 0.8,
max_lost_frames: int = 30,
min_box_area: float = 10.0,
) -> None:
self.track_thresh = track_thresh
self.high_thresh = high_thresh
self.match_thresh = match_thresh
self.max_lost_frames = max_lost_frames
self.min_box_area = min_box_area
self._tracks: Dict[int, Track] = {}
self._next_id: int = 1
self._frame_count: int = 0
# ------------------------------------------------------------------
# 主入口
# ------------------------------------------------------------------
def update(
self,
detections: List[UnifiedDetection],
) -> List[UnifiedDetection]:
"""对一帧检测结果执行跟踪更新,填充 track_id。
Returns:
带 track_id 的检测结果 (顺序保持不变)
"""
self._frame_count += 1
# 1. 过滤无效检测 (面积太小)
valid: List[Tuple[int, UnifiedDetection]] = []
for idx, det in enumerate(detections):
if det.bbox.area < self.min_box_area:
continue
valid.append((idx, det))
# 2. 按置信度分桶
high_dets: List[Tuple[int, UnifiedDetection]] = []
low_dets: List[Tuple[int, UnifiedDetection]] = []
for idx, det in valid:
if det.confidence >= self.high_thresh:
high_dets.append((idx, det))
elif det.confidence >= self.track_thresh:
low_dets.append((idx, det))
# < track_thresh 直接忽略
# 3. 获取当前 tracked + lost 的 tracks 列表
active_tracks: List[Track] = [
t for t in self._tracks.values()
if t.state in (TrackState.TRACKED, TrackState.LOST, TrackState.NEW)
]
# 4. 第一轮: 与 high 检测匹配
matches_high, unmatched_tracks, unmatched_high = self._match(
active_tracks, [d for _, d in high_dets], self.match_thresh
)
for track_idx, det_idx in matches_high:
track = active_tracks[track_idx]
orig_idx, det = high_dets[det_idx]
self._update_track(track, det)
det.track_id = track.track_id
# 5. 第二轮: 未匹配 tracks 与 low 检测匹配
remaining_tracks = [active_tracks[i] for i in unmatched_tracks]
matches_low, still_unmatched, unmatched_low = self._match(
remaining_tracks, [d for _, d in low_dets], 0.5 # 低置信度用更宽松阈值
)
for track_idx, det_idx in matches_low:
track = remaining_tracks[track_idx]
orig_idx, det = low_dets[det_idx]
self._update_track(track, det)
det.track_id = track.track_id
# 6. 未匹配的 high 检测创建新 track
for det_idx in unmatched_high:
orig_idx, det = high_dets[det_idx]
new_track = self._create_track(det)
det.track_id = new_track.track_id
# 7. 未匹配的 tracks 增加 lost_frames
unmatched_original_tracks = [remaining_tracks[i] for i in still_unmatched]
for track in unmatched_original_tracks:
track.lost_frames += 1
track.state = TrackState.LOST
if track.lost_frames > self.max_lost_frames:
track.state = TrackState.REMOVED
# 8. 清理已移除的 tracks
self._tracks = {
tid: t for tid, t in self._tracks.items()
if t.state != TrackState.REMOVED
}
return detections
# ------------------------------------------------------------------
# 内部
# ------------------------------------------------------------------
def _match(
self,
tracks: List[Track],
detections: List[UnifiedDetection],
threshold: float,
) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
"""计算 tracks 与 detections 的 IOU 匹配。
Returns:
(matches, unmatched_track_indices, unmatched_det_indices)
"""
if not tracks or not detections:
return [], list(range(len(tracks))), list(range(len(detections)))
# 构建 IOU 矩阵
iou_matrix: List[List[float]] = []
for t in tracks:
row = [compute_iou(t.bbox, d.bbox) for d in detections]
iou_matrix.append(row)
matches = greedy_match(iou_matrix, threshold)
matched_tracks = {i for i, _ in matches}
matched_dets = {j for _, j in matches}
unmatched_tracks = [i for i in range(len(tracks)) if i not in matched_tracks]
unmatched_dets = [j for j in range(len(detections)) if j not in matched_dets]
return matches, unmatched_tracks, unmatched_dets
def _create_track(self, det: UnifiedDetection) -> Track:
track_id = self._next_id
self._next_id += 1
track = Track(
track_id=track_id,
bbox=det.bbox,
confidence=det.confidence,
class_name=det.class_name,
state=TrackState.TRACKED,
age=1,
lost_frames=0,
last_update_frame=self._frame_count,
)
self._tracks[track_id] = track
return track
def _update_track(self, track: Track, det: UnifiedDetection) -> None:
track.bbox = det.bbox
track.confidence = det.confidence
track.age += 1
track.lost_frames = 0
track.state = TrackState.TRACKED
track.last_update_frame = self._frame_count
# ------------------------------------------------------------------
# 状态
# ------------------------------------------------------------------
@property
def active_tracks(self) -> List[Track]:
return [t for t in self._tracks.values() if t.state == TrackState.TRACKED]
@property
def lost_tracks(self) -> List[Track]:
return [t for t in self._tracks.values() if t.state == TrackState.LOST]
def reset(self) -> None:
self._tracks.clear()
self._next_id = 1
self._frame_count = 0
# ---------------------------------------------------------------------------
# TrackingService: 多流跟踪器管理
# ---------------------------------------------------------------------------
class TrackingService:
"""多流目标跟踪服务。
为每个 source_id (摄像头/视频流) 维护独立的 ByteTracker 实例。
"""
def __init__(
self,
track_thresh: float = 0.5,
high_thresh: float = 0.6,
match_thresh: float = 0.8,
max_lost_frames: int = 30,
min_box_area: float = 10.0,
) -> None:
self.track_thresh = track_thresh
self.high_thresh = high_thresh
self.match_thresh = match_thresh
self.max_lost_frames = max_lost_frames
self.min_box_area = min_box_area
self._trackers: Dict[str, ByteTracker] = {}
def get_or_create(self, source_id: str) -> ByteTracker:
if source_id not in self._trackers:
self._trackers[source_id] = ByteTracker(
track_thresh=self.track_thresh,
high_thresh=self.high_thresh,
match_thresh=self.match_thresh,
max_lost_frames=self.max_lost_frames,
min_box_area=self.min_box_area,
)
return self._trackers[source_id]
def update(
self,
source_id: str,
detections: List[UnifiedDetection],
) -> List[UnifiedDetection]:
tracker = self.get_or_create(source_id)
return tracker.update(detections)
def reset(self, source_id: Optional[str] = None) -> None:
if source_id is None:
self._trackers.clear()
elif source_id in self._trackers:
self._trackers[source_id].reset()
def remove(self, source_id: str) -> None:
self._trackers.pop(source_id, None)
@property
def source_count(self) -> int:
return len(self._trackers)
__all__ = [
"ByteTracker",
"TrackingService",
"Track",
"TrackState",
"compute_iou",
"greedy_match",
]