Files
SimpleTracker/app/tracker/single_object_tracker.py
2025-02-27 16:00:13 +08:00

165 lines
5.2 KiB
Python

from enum import Enum
from typing import Callable, Optional, Tuple
import numpy as np
from jaxtyping import Num, jaxtyped
from loguru import logger
from typeguard import typechecked
from app.typing import BoundingBoxFormat, NDArray
from .bboxes_tracker import BoxTracker, BoxTrackerConfig, BoxTracking
@jaxtyped(typechecker=typechecked)
def bounding_box_area(
    tracker: Num[NDArray, "N 4"], format: BoundingBoxFormat
) -> float:
    """Return the mean area of ``N`` bounding boxes.

    Args:
        tracker: ``(N, 4)`` numeric array of boxes; the shape is checked at
            call time by the ``jaxtyped``/``typechecked`` decorator.
        format: ``"xyxy"`` (columns 0/1 and 2/3 are opposite corners) or
            ``"xywh"`` (columns 2 and 3 are width and height).

    Returns:
        The mean of the per-box ``width * height`` products, as a Python
        float. With ``N == 0`` NumPy's mean yields ``nan``.

    Raises:
        ValueError: for any other ``format`` value (assuming the decorator's
            type check has not already rejected it).
    """
    if format == "xywh":
        widths = tracker[:, 2]
        heights = tracker[:, 3]
    elif format == "xyxy":
        widths = tracker[:, 2] - tracker[:, 0]
        heights = tracker[:, 3] - tracker[:, 1]
    else:
        raise ValueError(f"Unknown bounding box format: {format}")
    return float(np.mean(widths * heights))
class TrackingIdType(Enum):
    """Tells how a tracking id reported by ``SingleObjectTracker`` was obtained."""

    # The id came from the tracker's overridden tracking id.
    Overridden = "overridden"
    # The id came from the tracker's own selection (kept or newly chosen).
    Selected = "selected"
    # Not produced anywhere in this module; presumably used by other callers
    # -- confirm before relying on it.
    General = "general"


# ``(tracking id, how it was obtained)`` pair handed to the acquisition callback.
TrackingId = Tuple[int, TrackingIdType]
def find_suitable_tracking_id(
    tracking: list[BoxTracking], format: BoundingBoxFormat
) -> Optional[int]:
    """Pick the id of the tracking with the largest mean bounding-box area.

    Args:
        tracking: candidate trackings; each one's ``last_n_bounding_boxes``
            is scored with ``bounding_box_area``.
        format: bounding-box layout forwarded to ``bounding_box_area``.

    Returns:
        ``None`` when there are no candidates. If there is exactly one
        candidate, its id is returned without scoring it. Otherwise the id
        of the highest-scoring candidate is returned; ``np.argmax`` keeps
        the first one on ties.
    """
    if len(tracking) == 0:
        return None
    if len(tracking) == 1:
        # Nothing to compare against; skip the area computation entirely.
        return tracking[0].id
    areas = [
        bounding_box_area(candidate.last_n_bounding_boxes, format)
        for candidate in tracking
    ]
    # np.argmax (rather than max(key=...)) is kept for its tie/NaN behaviour.
    best_index = np.argmax(areas)
    return tracking[best_index].id
class SingleObjectTracker:
    """Follow a single object on top of a multi-object ``BoxTracker``.

    Measurements are fed through ``next_measurements``. Two lookups then
    return at most one of the tracker's confirmed trackings:

    * ``try_get_by_overridden_id`` looks up ``_overridden_tracking_id`` (set
      from outside this class) and drops it once it is no longer confirmed;
    * ``try_get_by_selected_id`` keeps following ``_selected_tracking_id``
      while it stays confirmed, and when there is none chooses a new one
      with ``find_suitable_tracking_id``.

    Whenever the followed id changes, ``_on_tracking_acquired`` (if set) is
    called with the confirmed trackings and an ``(id, TrackingIdType)`` pair.
    """

    _tracker: BoxTracker
    # Tracking id forced from outside the class; reset to None by
    # ``try_get_by_overridden_id`` once it is not among the confirmed trackings.
    _overridden_tracking_id: Optional[int] = None
    # Tracking id currently being followed.
    _selected_tracking_id: Optional[int] = None
    # Box layout used when choosing a new tracking; assigned in ``__init__``.
    _bounding_box_format: BoundingBoxFormat
    # Optional callback; declared but not invoked anywhere in this class.
    # NOTE(review): it was documented as ``(tracking, tracking_id) -> None``
    # while the annotation takes a single ``BoxTracking`` -- confirm intent.
    _on_lost_tracking: Optional[Callable[[BoxTracking], None]] = None
    # Called as ``(trackings, tracking_id) -> None`` when a new id is acquired.
    _on_tracking_acquired: Optional[
        Callable[[list[BoxTracking], TrackingId], None]
    ] = None

    def __init__(
        self,
        tracker_param: BoxTrackerConfig,
        bounding_box_format: BoundingBoxFormat = "xyxy",
    ):
        """Create the wrapped ``BoxTracker``.

        Args:
            tracker_param: configuration forwarded to ``BoxTracker``.
            bounding_box_format: layout of incoming boxes, forwarded to
                ``BoxTracker`` and used when choosing a new tracking.
        """
        self._tracker = BoxTracker(tracker_param, bounding_box_format)
        # A bare class-level annotation creates no attribute, so this must be
        # assigned here or ``try_get_by_selected_id`` hits AttributeError.
        self._bounding_box_format = bounding_box_format
        self._overridden_tracking_id = None
        self._selected_tracking_id = None

    def reset(self):
        """Reset the wrapped tracker and forget the selected tracking id.

        The selected id refers to the pre-reset tracker state. Keeping it
        would make the next lookup chase a stale id, or latch onto an
        unrelated tracking if ids are reused (TODO confirm whether
        ``BoxTracker.reset`` restarts id numbering). The overridden id is
        left alone; ``try_get_by_overridden_id`` clears it once it is not
        found.
        """
        self._tracker.reset()
        self._selected_tracking_id = None

    @property
    def confirmed(self):
        """The wrapped tracker's ``confirmed_trackings``."""
        return self._tracker.confirmed_trackings

    @property
    def confirmed_trackings(self):
        """
        alias of `confirmed`
        """
        return self.confirmed

    def get_by_id(
        self, tracking_id: int, trackings: list[BoxTracking]
    ) -> Optional[BoxTracking]:
        """Return the first tracking in ``trackings`` whose ``id`` equals
        ``tracking_id``, or ``None`` when there is no such tracking."""
        assert tracking_id is not None
        return next((t for t in trackings if t.id == tracking_id), None)

    def _acquire(
        self,
        tracking_id: int,
        id_type: TrackingIdType,
        confirmed: list[BoxTracking],
        message: str,
    ) -> None:
        """Make ``tracking_id`` the selected id; log and notify only on change."""
        if self._selected_tracking_id == tracking_id:
            return
        self._selected_tracking_id = tracking_id
        logger.info(message, tracking_id)
        if self._on_tracking_acquired is not None:
            self._on_tracking_acquired(confirmed, (tracking_id, id_type))

    def try_get_by_overridden_id(self) -> Optional[BoxTracking]:
        """Return the confirmed tracking matching the overridden id, if any.

        On success the overridden id also becomes the selected id; the
        acquisition is logged and announced only when the selection changes.
        If the id is not among the confirmed trackings, it is reset to None.
        Returns None when there is no overridden id or it is not found.
        """
        overridden_id = self._overridden_tracking_id
        if overridden_id is None:
            return None
        confirmed = self.confirmed
        sel = self.get_by_id(overridden_id, confirmed)
        if sel is None:
            self._overridden_tracking_id = None
            logger.trace(
                "Overridden tracking id {} not found in {}",
                overridden_id,
                confirmed,
            )
        else:
            self._acquire(
                overridden_id,
                TrackingIdType.Overridden,
                confirmed,
                "Acquired tracking id {} by override",
            )
        return sel

    def try_get_by_selected_id(self) -> Optional[BoxTracking]:
        """Return the confirmed tracking matching the selected id, if any.

        With no selected id, one is chosen with ``find_suitable_tracking_id``
        (None when there are no confirmed trackings). When a previously
        selected id is no longer confirmed, the selection is cleared, a
        warning is logged and None is returned; a new id is chosen on the
        next call.
        """
        confirmed = self.confirmed
        selected_id = self._selected_tracking_id
        if selected_id is None:
            selected_id = find_suitable_tracking_id(
                confirmed, self._bounding_box_format
            )
            if selected_id is None:
                return None
        sel = self.get_by_id(selected_id, confirmed)
        if sel is None:
            self._selected_tracking_id = None
            logger.warning(
                "Selected tracking id {} not found in {}", selected_id, confirmed
            )
        else:
            self._acquire(
                selected_id,
                TrackingIdType.Selected,
                confirmed,
                "Acquired tracking id {}",
            )
        return sel

    @property
    def bounding_box_format(self) -> BoundingBoxFormat:
        """The wrapped tracker's ``bounding_box_format``."""
        return self._tracker.bounding_box_format

    def next_measurements(self, boxes: Num[NDArray, "N 4"]):
        """Feed one batch of boxes to the wrapped tracker.

        Returns the confirmed trackings after the update.
        """
        self._tracker.next_measurements(boxes)
        return self.confirmed