386 lines
12 KiB
Python
386 lines
12 KiB
Python
"""目标跟踪服务 (MVP-2 / D19)
|
||
|
||
基于 ByteTrack 思想的简化跟踪器,纯 Python 实现,无需额外依赖。
|
||
|
||
核心算法 (ByteTrack 论文要点):
|
||
|
||
1. 按置信度将检测分为 high (>= high_thresh) / low (< high_thresh)
|
||
2. 先用 high 与现有 tracks 进行 IOU 匹配 (匈牙利算法 / 贪心)
|
||
3. 未匹配的 tracks 再与 low 检测匹配 (拯救低置信度真实目标)
|
||
4. 仍未匹配的 high 检测创建新 track
|
||
5. 未匹配的 tracks 进入 Lost 状态,超过 max_lost_frames 移除
|
||
|
||
简化点 (相对原版):
|
||
|
||
- 不使用 Kalman Filter,只用上一帧 bbox 与当前检测 IOU 匹配
|
||
- 匹配算法用贪心 (按 IOU 降序) 替代匈牙利
|
||
- 不支持外观特征 (ReID)
|
||
|
||
集成方式::
|
||
|
||
tracker = ByteTracker()
|
||
for frame_detections in stream:
|
||
tracked = tracker.update(frame_detections)
|
||
# tracked 中每个 UnifiedDetection 会被填充 track_id
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from dataclasses import dataclass, field
|
||
from enum import Enum
|
||
from typing import Dict, List, Optional, Tuple
|
||
|
||
from models.event_schemas import BBox, UnifiedDetection
|
||
|
||
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Track 状态
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TrackState(str, Enum):
    """Lifecycle state of a track.

    Mixing in ``str`` makes every member compare equal to its plain string
    value (for example ``TrackState.LOST == "lost"``).
    """

    # Default state of a freshly constructed Track dataclass instance.
    NEW = "new"
    # Matched to a detection on its most recent update.
    TRACKED = "tracked"
    # Missed on at least one recent frame but still kept around.
    LOST = "lost"
    # Marked for deletion from the tracker.
    REMOVED = "removed"
|
||
|
||
|
||
@dataclass
class Track:
    """State kept for one tracked target."""

    # Identifier assigned by the tracker that owns this track.
    track_id: int
    # Most recent bounding box associated with the track.
    bbox: BBox
    # Confidence of the most recent associated detection.
    confidence: float
    # Class label of the target.
    class_name: str
    # Current lifecycle state.
    state: TrackState = TrackState.NEW
    # Frame counter for how long the track has existed.
    age: int = 0
    # Number of consecutive frames without a matching detection.
    lost_frames: int = 0
    # Frame number at which the track was last updated.
    last_update_frame: int = 0
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 工具函数: IOU
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def compute_iou(b1: BBox, b2: BBox) -> float:
    """Return the intersection-over-union (IOU) of two boxes.

    Both arguments must expose ``x1``, ``y1``, ``x2``, ``y2`` and ``area``.

    Returns:
        ``0.0`` when the boxes do not overlap (or only touch) or when the
        union is not positive; otherwise ``intersection / union``.
    """
    # Clamp at zero so disjoint boxes yield an empty intersection.
    overlap_w = max(0, min(b1.x2, b2.x2) - max(b1.x1, b2.x1))
    overlap_h = max(0, min(b1.y2, b2.y2) - max(b1.y1, b2.y1))
    overlap = overlap_w * overlap_h
    if overlap == 0:
        return 0.0

    total = b1.area + b2.area - overlap
    if total > 0:
        return overlap / total
    return 0.0
|
||
|
||
|
||
def greedy_match(
    cost_matrix: List[List[float]],
    threshold: float,
) -> List[Tuple[int, int]]:
    """Greedily pair rows with columns, best score first.

    Args:
        cost_matrix: ``cost_matrix[i][j]`` is the score (IOU) of pairing row
            ``i`` with column ``j``; larger is better.
        threshold: Pairs scoring below this value are never matched.

    Returns:
        A list of ``(row_idx, col_idx)`` pairs in the order they were chosen.
        Each row and each column appears at most once.
    """
    # Nothing to match when there are no rows or the first row is empty.
    if not cost_matrix or not cost_matrix[0]:
        return []

    # Sorting the full (score, row, col) tuples in descending order means ties
    # on score are broken by the larger row index, then the larger column.
    ranked = sorted(
        (
            (score, row_idx, col_idx)
            for row_idx, row in enumerate(cost_matrix)
            for col_idx, score in enumerate(row)
            if score >= threshold
        ),
        reverse=True,
    )

    pairs: List[Tuple[int, int]] = []
    taken_rows: set = set()
    taken_cols: set = set()
    for _, row_idx, col_idx in ranked:
        if row_idx not in taken_rows and col_idx not in taken_cols:
            pairs.append((row_idx, col_idx))
            taken_rows.add(row_idx)
            taken_cols.add(col_idx)
    return pairs
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# ByteTracker
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class ByteTracker:
    """Simplified ByteTrack-style tracker.

    Detections are associated with existing tracks purely by the IOU between
    each track's most recent bbox and the detection bbox, in two greedy
    rounds: high-confidence detections first, then low-confidence detections
    against the tracks that are still unmatched.

    Args:
        track_thresh: Base confidence threshold; detections below it do not
            take part in tracking.
        high_thresh: Confidence at or above which a detection is put in the
            "high" bucket.
        match_thresh: Minimum IOU for a first-round (high bucket) match.
        max_lost_frames: A track is removed once its consecutive lost frames
            exceed this value.
        min_box_area: Detections whose bbox area is below this are ignored.
        low_match_thresh: Minimum IOU for a second-round (low bucket) match.
            Defaults to 0.5, the value that used to be hard-coded.
    """

    def __init__(
        self,
        track_thresh: float = 0.5,
        high_thresh: float = 0.6,
        match_thresh: float = 0.8,
        max_lost_frames: int = 30,
        min_box_area: float = 10.0,
        low_match_thresh: float = 0.5,
    ) -> None:
        self.track_thresh = track_thresh
        self.high_thresh = high_thresh
        self.match_thresh = match_thresh
        self.max_lost_frames = max_lost_frames
        self.min_box_area = min_box_area
        self.low_match_thresh = low_match_thresh

        self._tracks: Dict[int, Track] = {}
        self._next_id: int = 1
        self._frame_count: int = 0

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------

    def update(
        self,
        detections: List[UnifiedDetection],
    ) -> List[UnifiedDetection]:
        """Run one tracking step on a frame's detections.

        Detections are mutated in place: every detection that is matched to a
        track, or that starts a new track, gets its ``track_id`` set. All
        other detections are left untouched.

        Args:
            detections: Detections of the current frame.

        Returns:
            The same list object that was passed in (order unchanged).
        """
        self._frame_count += 1

        # 1-2. Drop boxes that are too small, then bucket by confidence.
        high_dets: List[UnifiedDetection] = []
        low_dets: List[UnifiedDetection] = []
        for det in detections:
            if det.bbox.area < self.min_box_area:
                continue
            if det.confidence >= self.high_thresh:
                high_dets.append(det)
            elif det.confidence >= self.track_thresh:
                low_dets.append(det)
            # anything below track_thresh is ignored

        # 3. Tracks eligible for matching (everything not yet removed).
        candidates: List[Track] = [
            t for t in self._tracks.values()
            if t.state in (TrackState.TRACKED, TrackState.LOST, TrackState.NEW)
        ]

        # 4. Round one: match against high-confidence detections.
        matches_high, unmatched_tracks, unmatched_high = self._match(
            candidates, high_dets, self.match_thresh
        )
        for track_idx, det_idx in matches_high:
            track = candidates[track_idx]
            det = high_dets[det_idx]
            self._update_track(track, det)
            det.track_id = track.track_id

        # 5. Round two: leftover tracks against low-confidence detections,
        #    using the separate low_match_thresh IOU threshold.
        remaining_tracks = [candidates[i] for i in unmatched_tracks]
        matches_low, still_unmatched, _ = self._match(
            remaining_tracks, low_dets, self.low_match_thresh
        )
        for track_idx, det_idx in matches_low:
            track = remaining_tracks[track_idx]
            det = low_dets[det_idx]
            self._update_track(track, det)
            det.track_id = track.track_id

        # 6. Unmatched high-confidence detections start new tracks.
        #    Unmatched low-confidence detections never do.
        for det_idx in unmatched_high:
            det = high_dets[det_idx]
            det.track_id = self._create_track(det).track_id

        # 7. Tracks matched in neither round accumulate lost frames and are
        #    marked REMOVED once they exceed max_lost_frames.
        for i in still_unmatched:
            track = remaining_tracks[i]
            track.lost_frames += 1
            track.state = TrackState.LOST
            if track.lost_frames > self.max_lost_frames:
                track.state = TrackState.REMOVED

        # 8. Purge removed tracks.
        self._tracks = {
            tid: t for tid, t in self._tracks.items()
            if t.state != TrackState.REMOVED
        }

        return detections

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _match(
        self,
        tracks: List[Track],
        detections: List[UnifiedDetection],
        threshold: float,
    ) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
        """Greedily match tracks to detections by IOU.

        Args:
            tracks: Tracks to match (rows of the IOU matrix).
            detections: Detections to match (columns of the IOU matrix).
            threshold: Minimum IOU for a pair to be matched.

        Returns:
            ``(matches, unmatched_track_indices, unmatched_det_indices)``
            where ``matches`` holds ``(track_idx, det_idx)`` pairs and all
            indices refer to positions in the given lists.
        """
        if not tracks or not detections:
            return [], list(range(len(tracks))), list(range(len(detections)))

        # IOU matrix: one row per track, one column per detection.
        iou_matrix: List[List[float]] = [
            [compute_iou(track.bbox, det.bbox) for det in detections]
            for track in tracks
        ]

        matches = greedy_match(iou_matrix, threshold)
        matched_tracks = {i for i, _ in matches}
        matched_dets = {j for _, j in matches}
        unmatched_tracks = [i for i in range(len(tracks)) if i not in matched_tracks]
        unmatched_dets = [j for j in range(len(detections)) if j not in matched_dets]
        return matches, unmatched_tracks, unmatched_dets

    def _create_track(self, det: UnifiedDetection) -> Track:
        """Create, register and return a new TRACKED track for ``det``."""
        track_id = self._next_id
        self._next_id += 1
        track = Track(
            track_id=track_id,
            bbox=det.bbox,
            confidence=det.confidence,
            class_name=det.class_name,
            state=TrackState.TRACKED,
            age=1,
            lost_frames=0,
            last_update_frame=self._frame_count,
        )
        self._tracks[track_id] = track
        return track

    def _update_track(self, track: Track, det: UnifiedDetection) -> None:
        """Refresh ``track`` from a matched detection.

        Only bbox and confidence are copied from the detection; the track's
        ``class_name`` stays as it was when the track was created. ``age`` is
        incremented here, i.e. only on frames where the track is matched.
        """
        track.bbox = det.bbox
        track.confidence = det.confidence
        track.age += 1
        track.lost_frames = 0
        track.state = TrackState.TRACKED
        track.last_update_frame = self._frame_count

    # ------------------------------------------------------------------
    # State
    # ------------------------------------------------------------------

    @property
    def active_tracks(self) -> List[Track]:
        """Tracks currently in the TRACKED state."""
        return [t for t in self._tracks.values() if t.state == TrackState.TRACKED]

    @property
    def lost_tracks(self) -> List[Track]:
        """Tracks currently in the LOST state."""
        return [t for t in self._tracks.values() if t.state == TrackState.LOST]

    def reset(self) -> None:
        """Drop all tracks and restart the id and frame counters."""
        self._tracks.clear()
        self._next_id = 1
        self._frame_count = 0
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# TrackingService: 多流跟踪器管理
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TrackingService:
    """Multi-stream object tracking service.

    Keeps one independent ByteTracker per ``source_id`` (camera / video
    stream). Trackers are created on first use with the thresholds given to
    this service's constructor.
    """

    def __init__(
        self,
        track_thresh: float = 0.5,
        high_thresh: float = 0.6,
        match_thresh: float = 0.8,
        max_lost_frames: int = 30,
        min_box_area: float = 10.0,
    ) -> None:
        # Per-source trackers, keyed by source_id.
        self._trackers: Dict[str, ByteTracker] = {}

        # Settings handed to every tracker this service creates.
        self.track_thresh = track_thresh
        self.high_thresh = high_thresh
        self.match_thresh = match_thresh
        self.max_lost_frames = max_lost_frames
        self.min_box_area = min_box_area

    def get_or_create(self, source_id: str) -> ByteTracker:
        """Return the tracker for ``source_id``, creating it if needed."""
        try:
            return self._trackers[source_id]
        except KeyError:
            created = ByteTracker(
                track_thresh=self.track_thresh,
                high_thresh=self.high_thresh,
                match_thresh=self.match_thresh,
                max_lost_frames=self.max_lost_frames,
                min_box_area=self.min_box_area,
            )
            self._trackers[source_id] = created
            return created

    def update(
        self,
        source_id: str,
        detections: List[UnifiedDetection],
    ) -> List[UnifiedDetection]:
        """Run one tracking step for ``source_id`` and return its result."""
        return self.get_or_create(source_id).update(detections)

    def reset(self, source_id: Optional[str] = None) -> None:
        """Reset one source's tracker, or drop all trackers.

        Args:
            source_id: When ``None``, every tracker is discarded. Otherwise
                the tracker for that source is reset in place; an unknown
                source is silently ignored.
        """
        if source_id is None:
            self._trackers.clear()
            return
        if source_id in self._trackers:
            self._trackers[source_id].reset()

    def remove(self, source_id: str) -> None:
        """Discard the tracker for ``source_id``; unknown sources are ignored."""
        self._trackers.pop(source_id, None)

    @property
    def source_count(self) -> int:
        """Number of sources that currently have a tracker."""
        return len(self._trackers)
|
||
|
||
|
||
# Public API of this module.
__all__ = [
    # trackers
    "ByteTracker",
    "TrackingService",
    # track data model
    "Track",
    "TrackState",
    # matching helpers
    "compute_iou",
    "greedy_match",
]
|