Files
jc-video-recognize/apps/server/services/tracking_service.py

386 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""目标跟踪服务 (MVP-2 / D19)
基于 ByteTrack 思想的简化跟踪器,纯 Python 实现,无需额外依赖。
核心算法 (ByteTrack 论文要点):
1. 按置信度将检测分为 high (>= high_thresh) / low (< high_thresh)
2. 先用 high 与现有 tracks 进行 IOU 匹配 (匈牙利算法 / 贪心)
3. 未匹配的 tracks 再与 low 检测匹配 (拯救低置信度真实目标)
4. 仍未匹配的 high 检测创建新 track
5. 未匹配的 tracks 进入 Lost 状态,超过 max_lost_frames 移除
简化点 (相对原版):
- 不使用 Kalman Filter只用上一帧 bbox 与当前检测 IOU 匹配
- 匹配算法用贪心 (按 IOU 降序) 替代匈牙利
- 不支持外观特征 (ReID)
集成方式::
tracker = ByteTracker()
for frame_detections in stream:
tracked = tracker.update(frame_detections)
# tracked 中每个 UnifiedDetection 会被填充 track_id
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Tuple
from models.event_schemas import BBox, UnifiedDetection
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Track 状态
# ---------------------------------------------------------------------------
class TrackState(str, Enum):
    """Lifecycle state of a track.

    Members are also ``str`` instances, so they compare equal to their plain
    string values.
    """

    NEW = "new"
    TRACKED = "tracked"  # matched to a detection in the latest update
    LOST = "lost"  # currently unmatched, still kept around
    REMOVED = "removed"  # lost for too long; dropped by the tracker
@dataclass
class Track:
    """State kept for a single tracked object."""

    # Identity and the most recently stored observation.
    track_id: int
    bbox: BBox
    confidence: float
    class_name: str

    # Lifecycle bookkeeping.
    state: TrackState = TrackState.NEW
    age: int = 0  # total number of frames the track has been alive
    lost_frames: int = 0  # number of consecutive frames the track was lost
    last_update_frame: int = 0
# ---------------------------------------------------------------------------
# 工具函数: IOU
# ---------------------------------------------------------------------------
def compute_iou(b1: BBox, b2: BBox) -> float:
    """Return the intersection-over-union of two boxes.

    Reads ``x1``, ``y1``, ``x2``, ``y2`` and ``area`` from both boxes.

    Returns:
        ``0.0`` when the boxes do not overlap (or only touch) or when the
        union is not positive; otherwise ``intersection / union``.
    """
    # Corners of the overlap rectangle.
    left = max(b1.x1, b2.x1)
    top = max(b1.y1, b2.y1)
    right = min(b1.x2, b2.x2)
    bottom = min(b1.y2, b2.y2)

    # Clamp each extent at zero so disjoint boxes yield an empty overlap.
    overlap = max(0, right - left) * max(0, bottom - top)
    if overlap == 0:
        return 0.0

    combined = b1.area + b2.area - overlap
    if combined > 0:
        return overlap / combined
    return 0.0
def greedy_match(
    cost_matrix: List[List[float]],
    threshold: float,
) -> List[Tuple[int, int]]:
    """Greedily pair rows with columns, best score first.

    Args:
        cost_matrix: ``cost_matrix[i][j]`` is the score (IOU) between row
            ``i`` and column ``j``; larger is better.
        threshold: Entries scoring below this value are never paired.

    Returns:
        ``(row_idx, col_idx)`` pairs in the order they were selected. Each
        row and each column appears at most once.
    """
    if not cost_matrix or not cost_matrix[0]:
        return []

    # Every eligible cell, ordered by descending (score, row, col).
    ranked = sorted(
        (
            (value, r, c)
            for r, scores in enumerate(cost_matrix)
            for c, value in enumerate(scores)
            if value >= threshold
        ),
        reverse=True,
    )

    taken_rows: set = set()
    taken_cols: set = set()
    pairs: List[Tuple[int, int]] = []
    for _, r, c in ranked:
        if r not in taken_rows and c not in taken_cols:
            pairs.append((r, c))
            taken_rows.add(r)
            taken_cols.add(c)
    return pairs
# ---------------------------------------------------------------------------
# ByteTracker
# ---------------------------------------------------------------------------
class ByteTracker:
    """Simplified ByteTrack-style tracker.

    Detections are associated with existing tracks purely by the IOU between
    each detection's box and the box stored on the track at its last update.
    Association is greedy and runs in two rounds: high-confidence detections
    first, then lower-confidence detections against the tracks that are
    still unmatched.

    Args:
        track_thresh: Base confidence threshold; detections below it do not
            take part in tracking.
        high_thresh: Confidence at or above which a detection goes into the
            "high" bucket.
        match_thresh: Minimum IOU for a first-round (high bucket) match.
        max_lost_frames: A track is removed once it has been unmatched for
            more than this many consecutive updates.
        min_box_area: Detections whose box area is below this are ignored.
        low_match_thresh: Minimum IOU for a second-round (low bucket) match.
            Defaults to 0.5.
    """

    def __init__(
        self,
        track_thresh: float = 0.5,
        high_thresh: float = 0.6,
        match_thresh: float = 0.8,
        max_lost_frames: int = 30,
        min_box_area: float = 10.0,
        low_match_thresh: float = 0.5,
    ) -> None:
        self.track_thresh = track_thresh
        self.high_thresh = high_thresh
        # NOTE(review): match_thresh is applied below as a *minimum IOU*. The
        # reference ByteTrack implementation appears to use its match_thresh
        # as an upper bound on the (1 - IOU) cost instead — confirm that 0.8
        # is the intended default for this IOU-only tracker.
        self.match_thresh = match_thresh
        self.max_lost_frames = max_lost_frames
        self.min_box_area = min_box_area
        self.low_match_thresh = low_match_thresh
        self._tracks: Dict[int, Track] = {}
        self._next_id: int = 1
        self._frame_count: int = 0

    # ------------------------------------------------------------------
    # Main entry point
    # ------------------------------------------------------------------
    def update(
        self,
        detections: List[UnifiedDetection],
    ) -> List[UnifiedDetection]:
        """Run one tracking step on a frame's detections.

        ``track_id`` is written in place on every detection that is matched
        to an existing track or that starts a new one; other detections are
        left untouched.

        Returns:
            The same ``detections`` list that was passed in (order unchanged).
        """
        self._frame_count += 1

        # 1-2. Drop boxes that are too small, then bucket by confidence.
        high_dets: List[UnifiedDetection] = []
        low_dets: List[UnifiedDetection] = []
        for det in detections:
            if det.bbox.area < self.min_box_area:
                continue
            if det.confidence >= self.high_thresh:
                high_dets.append(det)
            elif det.confidence >= self.track_thresh:
                low_dets.append(det)
            # Anything below track_thresh is ignored.

        # 3. Every track that has not been removed is a match candidate.
        candidates: List[Track] = [
            t for t in self._tracks.values()
            if t.state in (TrackState.TRACKED, TrackState.LOST, TrackState.NEW)
        ]

        # 4. Round one: candidates vs. high-confidence detections.
        matches_high, unmatched_tracks, unmatched_high = self._match(
            candidates, high_dets, self.match_thresh
        )
        for track_idx, det_idx in matches_high:
            track = candidates[track_idx]
            det = high_dets[det_idx]
            self._update_track(track, det)
            det.track_id = track.track_id

        # 5. Round two: leftover tracks vs. low-confidence detections, using
        #    the separate (by default looser) low_match_thresh.
        remaining_tracks = [candidates[i] for i in unmatched_tracks]
        matches_low, still_unmatched, _ = self._match(
            remaining_tracks, low_dets, self.low_match_thresh
        )
        for track_idx, det_idx in matches_low:
            track = remaining_tracks[track_idx]
            det = low_dets[det_idx]
            self._update_track(track, det)
            det.track_id = track.track_id

        # 6. Unmatched high-confidence detections start new tracks.
        #    (Unmatched low-confidence detections never do.)
        for det_idx in unmatched_high:
            det = high_dets[det_idx]
            det.track_id = self._create_track(det).track_id

        # 7. Tracks unmatched in both rounds become LOST; once they exceed
        #    max_lost_frames they are marked REMOVED.
        for i in still_unmatched:
            track = remaining_tracks[i]
            track.lost_frames += 1
            track.state = TrackState.LOST
            if track.lost_frames > self.max_lost_frames:
                track.state = TrackState.REMOVED

        # 8. Drop removed tracks.
        self._tracks = {
            tid: t for tid, t in self._tracks.items()
            if t.state != TrackState.REMOVED
        }
        return detections

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------
    def _match(
        self,
        tracks: List[Track],
        detections: List[UnifiedDetection],
        threshold: float,
    ) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
        """Match tracks to detections by IOU.

        Returns:
            ``(matches, unmatched_track_indices, unmatched_det_indices)``
            where ``matches`` holds ``(track_idx, det_idx)`` pairs and all
            indices refer to positions in the argument lists.
        """
        if not tracks or not detections:
            return [], list(range(len(tracks))), list(range(len(detections)))

        # IOU matrix: one row per track, one column per detection.
        iou_matrix: List[List[float]] = [
            [compute_iou(t.bbox, d.bbox) for d in detections]
            for t in tracks
        ]
        matches = greedy_match(iou_matrix, threshold)

        matched_tracks = {i for i, _ in matches}
        matched_dets = {j for _, j in matches}
        unmatched_tracks = [
            i for i in range(len(tracks)) if i not in matched_tracks
        ]
        unmatched_dets = [
            j for j in range(len(detections)) if j not in matched_dets
        ]
        return matches, unmatched_tracks, unmatched_dets

    def _create_track(self, det: UnifiedDetection) -> Track:
        """Register and return a new TRACKED track seeded from ``det``."""
        track_id = self._next_id
        self._next_id += 1
        track = Track(
            track_id=track_id,
            bbox=det.bbox,
            confidence=det.confidence,
            class_name=det.class_name,
            state=TrackState.TRACKED,
            age=1,
            lost_frames=0,
            last_update_frame=self._frame_count,
        )
        self._tracks[track_id] = track
        return track

    def _update_track(self, track: Track, det: UnifiedDetection) -> None:
        """Refresh ``track`` from a matched detection and mark it TRACKED."""
        track.bbox = det.bbox
        track.confidence = det.confidence
        track.age += 1
        track.lost_frames = 0
        track.state = TrackState.TRACKED
        track.last_update_frame = self._frame_count

    # ------------------------------------------------------------------
    # State
    # ------------------------------------------------------------------
    @property
    def active_tracks(self) -> List[Track]:
        """Tracks currently in the TRACKED state."""
        return [
            t for t in self._tracks.values() if t.state == TrackState.TRACKED
        ]

    @property
    def lost_tracks(self) -> List[Track]:
        """Tracks currently in the LOST state."""
        return [
            t for t in self._tracks.values() if t.state == TrackState.LOST
        ]

    def reset(self) -> None:
        """Forget all tracks and restart the id and frame counters."""
        self._tracks.clear()
        self._next_id = 1
        self._frame_count = 0
# ---------------------------------------------------------------------------
# TrackingService: 多流跟踪器管理
# ---------------------------------------------------------------------------
class TrackingService:
    """Object tracking across multiple streams.

    Keeps one independent ``ByteTracker`` per ``source_id`` (camera / video
    stream). Trackers are created on first use with the settings given to
    this service's constructor.
    """

    def __init__(
        self,
        track_thresh: float = 0.5,
        high_thresh: float = 0.6,
        match_thresh: float = 0.8,
        max_lost_frames: int = 30,
        min_box_area: float = 10.0,
    ) -> None:
        # Settings handed to every tracker this service creates.
        self.track_thresh = track_thresh
        self.high_thresh = high_thresh
        self.match_thresh = match_thresh
        self.max_lost_frames = max_lost_frames
        self.min_box_area = min_box_area
        # source_id -> tracker for that stream
        self._trackers: Dict[str, ByteTracker] = {}

    def get_or_create(self, source_id: str) -> ByteTracker:
        """Return the tracker for ``source_id``, creating it if missing."""
        existing = self._trackers.get(source_id)
        if existing is not None:
            return existing
        created = ByteTracker(
            track_thresh=self.track_thresh,
            high_thresh=self.high_thresh,
            match_thresh=self.match_thresh,
            max_lost_frames=self.max_lost_frames,
            min_box_area=self.min_box_area,
        )
        self._trackers[source_id] = created
        return created

    def update(
        self,
        source_id: str,
        detections: List[UnifiedDetection],
    ) -> List[UnifiedDetection]:
        """Run one tracking step for ``source_id`` and return its result."""
        return self.get_or_create(source_id).update(detections)

    def reset(self, source_id: Optional[str] = None) -> None:
        """Reset one stream's tracker, or drop every tracker.

        With ``source_id=None`` all trackers are discarded. Otherwise the
        tracker for that source, if one exists, is reset in place; an
        unknown ``source_id`` is a no-op.
        """
        if source_id is None:
            self._trackers.clear()
            return
        tracker = self._trackers.get(source_id)
        if tracker is not None:
            tracker.reset()

    def remove(self, source_id: str) -> None:
        """Discard the tracker for ``source_id``; no-op if there is none."""
        self._trackers.pop(source_id, None)

    @property
    def source_count(self) -> int:
        """Number of sources that currently have a tracker."""
        return len(self._trackers)
# Public names of this module (what ``from <module> import *`` exports).
__all__ = [
"ByteTracker",
"TrackingService",
"Track",
"TrackState",
"compute_iou",
"greedy_match",
]