from enum import Enum
from typing import Callable, Optional, Tuple

import numpy as np
from jaxtyping import Num, jaxtyped
from loguru import logger
from typeguard import typechecked

from app.typing import BoundingBoxFormat, NDArray

from .bboxes_tracker import BoxTracker, BoxTrackerConfig, BoxTracking


@jaxtyped(typechecker=typechecked)
def bounding_box_area(tracker: Num[NDArray, "N 4"], format: BoundingBoxFormat) -> float:
    """Return the mean area of a stack of bounding boxes.

    Args:
        tracker: ``(N, 4)`` array of boxes (despite the name, this holds boxes,
            not a tracker object).
        format: ``"xyxy"`` (two corner points) or ``"xywh"`` (columns 2 and 3
            are width and height).

    Returns:
        The mean of the per-box areas, as a Python ``float``.

    Raises:
        ValueError: If ``format`` is not ``"xyxy"`` or ``"xywh"``.
    """
    if format == "xyxy":
        widths = tracker[:, 2] - tracker[:, 0]
        heights = tracker[:, 3] - tracker[:, 1]
        return float(np.mean(widths * heights))
    elif format == "xywh":
        return float(np.mean(tracker[:, 2] * tracker[:, 3]))
    else:
        raise ValueError(f"Unknown bounding box format: {format}")


class TrackingIdType(Enum):
    """How a tracking id came to be the one being followed."""

    # Came from an explicitly set override id (see try_get_by_overridden_id).
    Overridden = "overridden"
    # Came from the selected / auto-selected id (see try_get_by_selected_id).
    Selected = "selected"
    # Not produced anywhere in this module.
    General = "general"


# (tracking_id, how that id was obtained)
TrackingId = Tuple[int, TrackingIdType]


def find_suitable_tracking_id(
    tracking: list[BoxTracking], format: BoundingBoxFormat
) -> Optional[int]:
    """Pick the tracking whose recent boxes have the largest mean area.

    Args:
        tracking: Candidate trackings.
        format: Box format of each tracking's ``last_n_bounding_boxes``.

    Returns:
        The ``id`` of the chosen tracking, or ``None`` if ``tracking`` is empty.
        On ties, the earliest candidate wins (``np.argmax`` semantics).
    """
    if not tracking:
        return None
    if len(tracking) == 1:
        # Nothing to compare; skip the area computation.
        return tracking[0].id
    areas = [
        bounding_box_area(candidate.last_n_bounding_boxes, format)
        for candidate in tracking
    ]
    return tracking[int(np.argmax(areas))].id


class SingleObjectTracker:
    """Follow a single object among the confirmed trackings of a ``BoxTracker``.

    Callers choose between an explicitly overridden tracking id
    (:meth:`try_get_by_overridden_id`) and a selected id that is picked
    automatically by largest mean box area when none is set yet
    (:meth:`try_get_by_selected_id`).
    """

    _tracker: BoxTracker
    _overridden_tracking_id: Optional[int] = None
    _selected_tracking_id: Optional[int] = None
    _bounding_box_format: BoundingBoxFormat
    # Callback slot with signature (tracking) -> None.
    # NOTE(review): nothing in this class invokes it — confirm whether the
    # "selected id not found" path should call it.
    _on_lost_tracking: Optional[Callable[[BoxTracking], None]] = None
    # Called as (confirmed_trackings, (tracking_id, id_type)) whenever the
    # selected tracking id changes.
    _on_tracking_acquired: Optional[Callable[[list[BoxTracking], TrackingId], None]] = (
        None
    )

    def __init__(
        self,
        tracker_param: BoxTrackerConfig,
        bounding_box_format: BoundingBoxFormat = "xyxy",
    ):
        """Create the underlying tracker and start with no selection.

        Args:
            tracker_param: Configuration forwarded to ``BoxTracker``.
            bounding_box_format: Box format handed to ``BoxTracker`` and used
                when ranking candidate trackings by area.
        """
        self._tracker = BoxTracker(tracker_param, bounding_box_format)
        # Read by try_get_by_selected_id when auto-selecting a tracking; it
        # must be set here or that path raises AttributeError.
        self._bounding_box_format = bounding_box_format
        self._selected_tracking_id = None
        self._overridden_tracking_id = None

    def reset(self):
        """Reset the underlying tracker and forget the selected/overridden id.

        Ids held from before the reset refer to trackings that no longer
        exist; clearing them avoids latching onto an unrelated tracking in
        case the tracker reuses ids after a reset.
        """
        self._tracker.reset()
        self._selected_tracking_id = None
        self._overridden_tracking_id = None

    @property
    def confirmed(self):
        """Confirmed trackings of the underlying tracker."""
        return self._tracker.confirmed_trackings

    @property
    def confirmed_trackings(self):
        """
        alias of `confirmed`
        """
        return self.confirmed

    def get_by_id(
        self, tracking_id: int, trackings: list[BoxTracking]
    ) -> Optional[BoxTracking]:
        """Return the first tracking in ``trackings`` whose ``id`` equals
        ``tracking_id``, or ``None`` if there is no such tracking."""
        assert tracking_id is not None
        return next((t for t in trackings if t.id == tracking_id), None)

    def try_get_by_overridden_id(self) -> Optional[BoxTracking]:
        """Return the confirmed tracking matching the overridden id, if any.

        On success, the overridden id becomes the selected id; if that changes
        the selection, it is logged and ``_on_tracking_acquired`` is called with
        ``TrackingIdType.Overridden``. If the overridden id is not among the
        confirmed trackings, it is cleared and ``None`` is returned.
        """
        overridden_id = self._overridden_tracking_id
        if overridden_id is None:
            return None

        confirmed = self.confirmed
        sel: Optional[BoxTracking] = self.get_by_id(overridden_id, confirmed)
        if sel is None:
            self._overridden_tracking_id = None
            logger.trace(
                "Overridden tracking id {} not found in {}",
                overridden_id,
                confirmed,
            )
            return None

        # overridden_id is not None here, so `!=` also covers "nothing selected".
        if self._selected_tracking_id != overridden_id:
            self._selected_tracking_id = overridden_id
            logger.info("Acquired tracking id {} by override", overridden_id)
            if self._on_tracking_acquired is not None:
                self._on_tracking_acquired(
                    confirmed, (overridden_id, TrackingIdType.Overridden)
                )
        return sel

    def try_get_by_selected_id(self) -> Optional[BoxTracking]:
        """Return the confirmed tracking matching the selected id.

        If no id is selected yet, one is chosen with
        ``find_suitable_tracking_id`` (largest mean box area). When the
        selection changes, it is logged and ``_on_tracking_acquired`` is called
        with ``TrackingIdType.Selected``. If the selected id is no longer among
        the confirmed trackings, it is cleared, a warning is logged and
        ``None`` is returned.
        """
        confirmed = self.confirmed
        selected_id = self._selected_tracking_id
        if selected_id is None:
            selected_id = find_suitable_tracking_id(
                confirmed, self._bounding_box_format
            )
            if selected_id is None:
                return None

        sel: Optional[BoxTracking] = self.get_by_id(selected_id, confirmed)
        if sel is None:
            # The followed object is gone from the confirmed set.
            self._selected_tracking_id = None
            logger.warning(
                "Selected tracking id {} not found in {}", selected_id, confirmed
            )
            return None

        # selected_id is not None here, so `!=` also covers "nothing selected".
        if self._selected_tracking_id != selected_id:
            self._selected_tracking_id = selected_id
            logger.info("Acquired tracking id {}", selected_id)
            if self._on_tracking_acquired is not None:
                self._on_tracking_acquired(
                    confirmed, (selected_id, TrackingIdType.Selected)
                )
        return sel

    @property
    def bounding_box_format(self) -> BoundingBoxFormat:
        """Box format reported by the underlying tracker."""
        return self._tracker.bounding_box_format

    def next_measurements(self, boxes: Num[NDArray, "N 4"]):
        """Feed one batch of boxes to the tracker and return the confirmed
        trackings afterwards."""
        self._tracker.next_measurements(boxes)
        return self.confirmed