from dataclasses import dataclass
from enum import Enum, auto
from typing import (
    Callable,
    Generator,
    Optional,
    Tuple,
    TypedDict,
    Union,
    cast,
)

from loguru import logger
import numpy as np
from jaxtyping import Float, Int, Num, jaxtyped
from pydantic import BaseModel
from scipy.optimize import linear_sum_assignment
from typeguard import typechecked

from app.typing import BoundingBoxFormat
from app.typing.constant import AREA_FILTER_THRESHOLD


class BoxTrackerConfig(BaseModel):
    """Tunable parameters of :class:`BoxTracker`."""

    # time step handed to the constant-velocity motion model
    dt: float = 1.0
    # tracks whose covariance trace exceeds this value are removed
    cov_threshold: float = 4.0
    # gate (Mahalanobis distance) when matching measurements to tentative tracks
    tentative_mahalanobis_threshold: float = 10.0
    # gate (Mahalanobis distance) when matching measurements to confirmed tracks
    confirm_mahalanobis_threshold: float = 10.0
    # max distance between unmatched centers of two consecutive frames for the
    # pair to seed a new tentative track
    forming_tracks_euclidean_threshold: float = 25.0
    # a tentative track is promoted once its survived steps exceed this value
    survival_steps_threshold: int = 3
    # size of the per-track sliding window of bounding boxes
    max_preserved_history_bounding_boxes: int = 10

    @staticmethod
    def default() -> "BoxTrackerConfig":
        """Return a configuration that uses every field's default value."""
        # Rely on the field defaults above so the two can never drift apart.
        return BoxTrackerConfig()


# NOTE(review): this relative import deliberately stays below BoxTrackerConfig;
# the ordering may matter for an import cycle with the package `__init__` —
# confirm before hoisting it to the top of the file.
from . import (
    CvModelGaussianState,
    GaussianState,
    LinearMeasurementModel,
    LinearMotionNoInputModel,
    NDArray,
    PosterioriResult,
    outer_distance,
    predict,
    update,
)


class TrackingState(Enum):
    """Which track list (group) a track or an event belongs to."""

    Tentative = auto()
    Confirmed = auto()


class BoxTrackingDict(TypedDict):
    """Plain-dict export of a :class:`BoxTracking` (see `BoxTracking.to_dict`)."""

    id: int
    bounding_box: NDArray
    state_x: NDArray
    state_P: NDArray


@dataclass
class BoxTracking:
    """A tracked target: filter state, bookkeeping counters and box history."""

    id: int
    state: CvModelGaussianState
    # incremented every time the track is matched to a measurement
    survived_time_steps: int
    # incremented every time the track is not matched; statistics only, removal
    # is driven by the covariance threshold
    missed_time_steps: int
    # History of bounding boxes in a sliding window, with the latest one at the
    # end. The window size is determined by the
    # `max_preserved_history_bounding_boxes` parameter.
    last_n_bounding_boxes: Num[NDArray, "N 4"]

    @property
    def last_bounding_box(self) -> Num[NDArray, "4"]:
        """The most recent bounding box of the history window."""
        b = cast(NDArray, self.last_n_bounding_boxes[-1])
        assert b.shape == (4,)
        return b

    def to_dict(self) -> BoxTrackingDict:
        """Export the id, the latest box and the filter mean/covariance."""
        return {
            "id": self.id,
            "bounding_box": self.last_bounding_box,
            "state_x": self.state.x,
            "state_P": self.state.P,
        }


@dataclass
class CreateTrackingEvent:
    """A track was created in, or promoted to, `group`."""

    group: TrackingState
    id: int
    tracking: BoxTracking


@dataclass
class RemoveTrackingEvent:
    """A track was removed from `group`."""

    group: TrackingState
    id: int
    tracking: BoxTracking


@dataclass
class MatchedTrackingEvent:
    """A track of `group` was matched with `matched_bounding_box`."""

    group: TrackingState
    id: int
    matched_bounding_box: Num[NDArray, "4"]


TrackingEvent = Union[CreateTrackingEvent, RemoveTrackingEvent, MatchedTrackingEvent]


def bounding_boxes_to_center(
    bounding_boxes: Num[NDArray, "N 4"], format: BoundingBoxFormat
) -> Num[NDArray, "N 2"]:
    """
    Center points of a batch of bounding boxes.

    `"xyxy"` boxes are `(x0, y0, x1, y1)`; `"xywh"` boxes are `(x0, y0, w, h)`
    where half of `(w, h)` is added to `(x0, y0)` to reach the center.

    Raises
    ------
    ValueError
        if `format` is not supported
    """
    if format == "xyxy":
        return (bounding_boxes[:, :2] + bounding_boxes[:, 2:]) / 2
    if format == "xywh":
        return bounding_boxes[:, :2] + (bounding_boxes[:, 2:] / 2)
    raise ValueError(f"Unsupported bounding box format: {format}")


def bounding_box_to_center(
    bounding_box: Num[NDArray, "4"], format: BoundingBoxFormat
) -> Num[NDArray, "2"]:
    """
    Center point of a single bounding box (same conventions as
    `bounding_boxes_to_center`).

    Raises
    ------
    ValueError
        if `format` is not supported
    """
    if format == "xyxy":
        return (bounding_box[:2] + bounding_box[2:]) / 2
    if format == "xywh":
        return bounding_box[:2] + (bounding_box[2:] / 2)
    raise ValueError(f"Unsupported bounding box format: {format}")


def bounding_boxes_area(
    bounding_boxes: Num[NDArray, "N 4"], format: BoundingBoxFormat
) -> Num[NDArray, "N"]:
    """
    Area of every box in a batch.

    Raises
    ------
    ValueError
        if `format` is not supported
    """
    if format == "xyxy":
        return (bounding_boxes[:, 2] - bounding_boxes[:, 0]) * (
            bounding_boxes[:, 3] - bounding_boxes[:, 1]
        )
    if format == "xywh":
        return bounding_boxes[:, 2] * bounding_boxes[:, 3]
    raise ValueError(f"Unsupported bounding box format: {format}")


class BoxTracker:
    """
    A simple GNN tracker, but for tracking targets with bounding boxes

    TODO: use score to help data association
    """

    # Annotations only: every instance gets its own containers in `__init__`
    # (class-level mutable defaults would be shared between instances).
    # centers left unmatched in the previous frame; candidates for new tracks
    _last_measurements: NDArray
    _tentative_tracks: list[BoxTracking]
    _confirmed_tracks: list[BoxTracking]
    # id handed to the next created track
    _last_id: int
    _params: BoxTrackerConfig
    _bounding_boxes_format: BoundingBoxFormat

    def __init__(
        self,
        params: BoxTrackerConfig,
        bounding_boxes_format: BoundingBoxFormat,
    ):
        self._params = params
        self._bounding_boxes_format = bounding_boxes_format
        self.reset()

    def reset(self):
        """Drop all tracks and pending measurements and restart ids from 0."""
        self._last_id = 0
        self._last_measurements = np.empty((0, 2), dtype=np.float32)
        self._tentative_tracks = []
        self._confirmed_tracks = []

    def _push_new_bounding_box(
        self, old_bbs: Num[NDArray, "N 4"], new_bb: Num[NDArray, "4"]
    ) -> Num[NDArray, "N 4"]:
        """Append `new_bb` to the history, keeping only the newest entries."""
        bbs = np.append(old_bbs, np.expand_dims(new_bb, axis=0), axis=0)
        if bbs.shape[0] > self._params.max_preserved_history_bounding_boxes:
            bbs = bbs[-self._params.max_preserved_history_bounding_boxes :]
        return bbs

    def _predict(self, tracks: list[BoxTracking], dt: float = 1.0):
        """
        Run the Kalman prediction step on every track.

        The predicted displacement of the center is applied to each track's
        latest bounding box (width/height unchanged) and the shifted box is
        appended to the history window.

        Returns
        -------
        list[BoxTracking]
            new track objects; the input tracks are not mutated
        """
        motion_model = BoxTracker.motion_model(dt=dt)

        def _predict_one(track: BoxTracking) -> BoxTracking:
            new_st = predict(track.state, motion_model)
            o_cx, o_cy = bounding_box_to_center(
                track.last_bounding_box, self._bounding_boxes_format
            )
            # state layout is (cx, cy, vx, vy), see `motion_model`
            n_cx, n_cy, _v_x, _v_y = new_st.x
            delta_x, delta_y = n_cx - o_cx, n_cy - o_cy
            if self._bounding_boxes_format == "xyxy":
                x_0, y_0, x_1, y_1 = track.last_bounding_box
                new_bb = np.array(
                    [x_0 + delta_x, y_0 + delta_y, x_1 + delta_x, y_1 + delta_y]
                )
            elif self._bounding_boxes_format == "xywh":
                x_0, y_0, w, h = track.last_bounding_box
                # Translating a box moves its corner by exactly the same offset
                # as its center; subtracting w/2, h/2 here would drift the box
                # by half its size on every prediction.
                new_bb = np.array([x_0 + delta_x, y_0 + delta_y, w, h])
            else:
                raise ValueError(
                    f"Unsupported bounding box format: {self._bounding_boxes_format}"
                )
            new_bbs = self._push_new_bounding_box(track.last_n_bounding_boxes, new_bb)
            return BoxTracking(
                id=track.id,
                state=CvModelGaussianState.from_gaussian(new_st),
                survived_time_steps=track.survived_time_steps,
                missed_time_steps=track.missed_time_steps,
                last_n_bounding_boxes=new_bbs,
            )

        return [_predict_one(track) for track in tracks]

    @jaxtyped(typechecker=typechecked)
    def _data_associate_and_update(
        self,
        select_array: TrackingState,
        measurements: Num[NDArray, "N 2"],
        bounding_boxes: Num[NDArray, "N 4"],
    ) -> Tuple[list[MatchedTrackingEvent], Num[NDArray, "M 2"], Num[NDArray, "M 4"]]:
        """
        Match the tracks of one group with measurements and update them.

        The tracks are assumed to have been predicted for this time step
        already. The assignment minimises the total Mahalanobis distance
        (`linear_sum_assignment`); assigned pairs whose distance exceeds the
        group's threshold are rejected.

        Parameters
        ----------
        select_array: TrackingState
            which track list (tentative or confirmed) to match; that list is
            mutated in place (state, counters, box history)
        measurements: Num["N 2"]
            measured box centers
        bounding_boxes: Num["N 4"]
            the boxes the measurements were computed from, in the same order

        Returns
        ----------
        (events, left_measurements, left_bounding_boxes)
            one MatchedTrackingEvent per matched track, plus the measurements
            and boxes that were not matched to any track

        Raises
        ----------
        ValueError
            if `select_array` is not a known TrackingState
        """
        evs: list[MatchedTrackingEvent] = []
        assert measurements.ndim == 2
        assert measurements.shape[1] == 2
        assert bounding_boxes.ndim == 2
        assert bounding_boxes.shape[1] == 4
        assert bounding_boxes.shape[0] == measurements.shape[0]
        if select_array == TrackingState.Tentative:
            tracks = self._tentative_tracks
            distance_threshold = self._params.tentative_mahalanobis_threshold
        elif select_array == TrackingState.Confirmed:
            tracks = self._confirmed_tracks
            distance_threshold = self._params.confirm_mahalanobis_threshold
        else:
            raise ValueError("Unexpected tracking state {}".format(select_array))

        if len(tracks) == 0:
            return evs, measurements, bounding_boxes

        measurement_model = BoxTracker.measurement_model()
        # posteriors[i][j]: Kalman update of track i with measurement j
        posteriors: list[list[PosterioriResult]] = [
            [
                update(measurement, track.state, measurement_model)
                for measurement in measurements
            ]
            for track in tracks
        ]
        distances = np.array(
            [[post.mahalanobis_distance for post in per_track] for per_track in posteriors],
            dtype=np.float32,
        )
        row, col = linear_sum_assignment(distances)

        # gating: drop assignments that are too far away to be the same target
        keep = np.array(
            [
                posteriors[i][j].mahalanobis_distance <= distance_threshold
                for i, j in zip(row, col)
            ],
            dtype=bool,
        )
        row = np.asarray(row, dtype=np.intp)[keep]
        col = np.asarray(col, dtype=np.intp)[keep]

        # update matched tracks
        for i, j in zip(row, col):
            track = tracks[i]
            track.state = CvModelGaussianState.from_gaussian(posteriors[i][j].state)
            track.survived_time_steps += 1
            track.last_n_bounding_boxes = self._push_new_bounding_box(
                track.last_n_bounding_boxes, bounding_boxes[j]
            )
            evs.append(
                MatchedTrackingEvent(
                    group=select_array,
                    id=track.id,
                    matched_bounding_box=bounding_boxes[j],
                )
            )

        # missed tracks
        # note that this is just for statistical purposes;
        # tracks are removed by the covariance threshold
        matched_tracks = set(row.tolist())
        for i, track in enumerate(tracks):
            if i not in matched_tracks:
                track.missed_time_steps += 1

        # remove measurements that have been matched
        left_measurements = np.delete(measurements, col, axis=0)
        left_bounding_boxes = np.delete(bounding_boxes, col, axis=0)
        return evs, left_measurements, left_bounding_boxes

    @jaxtyped(typechecker=typechecked)
    def _tracks_from_past_measurements(
        self,
        measurements: Num[NDArray, "N 2"],
        bounding_boxes: Num[NDArray, "N 4"],
        dt: float = 1.0,
        distance_threshold: float = 3.0,
    ):
        """
        Consume the last unmatched measurements and create tentative tracks.

        Each previous-frame center is paired with a current center by
        `linear_sum_assignment` over `outer_distance`; pairs within
        `distance_threshold` seed a track whose velocity is the finite
        difference of the two centers. Current measurements that seed nothing
        are kept as the "last measurements" for the next call. On the very
        first call (no stored measurements) the input is only stored.

        Returns
        ----
        list[CreateTrackingEvent]
            one event per newly created tentative track

        Note
        ----
        mutate self._tentative_tracks, self._last_measurements and self._last_id
        """
        evs: list[CreateTrackingEvent] = []
        if self._last_measurements.shape[0] == 0:
            self._last_measurements = measurements
            return evs

        distances = np.asarray(outer_distance(self._last_measurements, measurements))
        row, col = linear_sum_assignment(distances)
        row = np.asarray(row, dtype=np.intp)
        col = np.asarray(col, dtype=np.intp)
        # gating: only pairs that are close enough seed a new track
        keep = distances[row, col] <= distance_threshold
        row, col = row[keep], col[keep]

        for i, j in zip(row, col):
            coord = measurements[j]
            vel = (coord - self._last_measurements[i]) / dt
            s = np.concatenate([coord, vel])
            state = GaussianState(x=s, P=np.eye(4))
            track = BoxTracking(
                id=self._last_id,
                state=CvModelGaussianState.from_gaussian(state),
                survived_time_steps=0,
                missed_time_steps=0,
                last_n_bounding_boxes=np.expand_dims(bounding_boxes[j], axis=0),
            )
            self._last_id += 1
            self._tentative_tracks.append(track)
            evs.append(
                CreateTrackingEvent(
                    group=TrackingState.Tentative, id=track.id, tracking=track
                )
            )

        # update the last measurements with the unmatched measurements
        self._last_measurements = np.delete(measurements, col, axis=0)
        return evs

    def _transfer_tentative_to_confirmed(self, survival_steps_threshold: int = 3):
        """
        Promote tentative tracks whose survived steps exceed the threshold.

        Returns
        ----
        list[CreateTrackingEvent]
            one Confirmed-group event per promoted track

        Note
        ----
        mutate self._tentative_tracks and self._confirmed_tracks in place
        """
        evs: list[CreateTrackingEvent] = []
        still_tentative: list[BoxTracking] = []
        # Partition instead of popping while enumerating: popping shifts the
        # list and silently skips the element after every promoted track.
        for track in self._tentative_tracks:
            if track.survived_time_steps > survival_steps_threshold:
                self._confirmed_tracks.append(track)
                evs.append(
                    CreateTrackingEvent(
                        group=TrackingState.Confirmed, id=track.id, tracking=track
                    )
                )
            else:
                still_tentative.append(track)
        # slice assignment keeps the same list object (in-place mutation)
        self._tentative_tracks[:] = still_tentative
        return evs

    def _track_cov_deleter(
        self, track_to_use: TrackingState, cov_threshold: float = 4.0
    ):
        """
        Delete tracks whose covariance trace is greater than the threshold.

        Parameters
        ----------
        track_to_use: TrackingState
            which track list (tentative or confirmed) to prune, in place
        cov_threshold: float
            the threshold of the covariance trace

        Returns
        ----------
        list[RemoveTrackingEvent]
            one event per deleted track

        Raises
        ----------
        ValueError
            if `track_to_use` is not a known TrackingState
        """
        if track_to_use == TrackingState.Tentative:
            tracks = self._tentative_tracks
        elif track_to_use == TrackingState.Confirmed:
            tracks = self._confirmed_tracks
        else:
            raise ValueError("Unexpected tracking state {}".format(track_to_use))

        ret: list[RemoveTrackingEvent] = []
        kept: list[BoxTracking] = []
        # Partition instead of popping while enumerating, which would skip the
        # track right after each deleted one.
        for track in tracks:
            # https://numpy.org/doc/stable/reference/generated/numpy.trace.html
            if np.trace(track.state.P) > cov_threshold:
                ret.append(
                    RemoveTrackingEvent(group=track_to_use, id=track.id, tracking=track)
                )
            else:
                kept.append(track)
        # slice assignment mutates the tracker's own list in place
        tracks[:] = kept
        return ret

    def next_measurements(
        self,
        bounding_boxes: Num[NDArray, "N 4"],
    ) -> list[TrackingEvent]:
        """
        Advance the tracker by one time step with a new frame of boxes.

        Steps: drop boxes with area <= AREA_FILTER_THRESHOLD, predict all
        tracks, match confirmed then tentative tracks, promote long-lived
        tentative tracks, seed new tentative tracks from leftover measurements,
        and delete tracks with too large a covariance trace.

        Returns
        -------
        list[TrackingEvent]
            matched (confirmed, then tentative), promoted, created, and
            removed (tentative, then confirmed) events, in that order
        """
        evs: list[TrackingEvent]
        areas = bounding_boxes_area(bounding_boxes, self._bounding_boxes_format)
        too_small = areas <= AREA_FILTER_THRESHOLD
        if np.any(too_small):
            logger.trace(
                "too small bounding boxes; bboxes={}; areas={}",
                bounding_boxes,
                areas,
            )
            bounding_boxes = bounding_boxes[~too_small]
        measurements = bounding_boxes_to_center(
            bounding_boxes, self._bounding_boxes_format
        )
        self._confirmed_tracks = self._predict(self._confirmed_tracks, self._params.dt)
        self._tentative_tracks = self._predict(self._tentative_tracks, self._params.dt)
        # confirmed tracks get the first pick of the measurements
        c_evs, c_left_m, c_left_bb = self._data_associate_and_update(
            TrackingState.Confirmed, measurements, bounding_boxes
        )
        t_evs, t_left_m, t_left_bb = self._data_associate_and_update(
            TrackingState.Tentative, c_left_m, c_left_bb
        )
        create_c_evs = self._transfer_tentative_to_confirmed(
            self._params.survival_steps_threshold
        )
        # target initialize
        create_t_evs = self._tracks_from_past_measurements(
            t_left_m,
            t_left_bb,
            self._params.dt,
            self._params.forming_tracks_euclidean_threshold,
        )
        del_t_evs = self._track_cov_deleter(
            TrackingState.Tentative, self._params.cov_threshold
        )
        del_c_evs = self._track_cov_deleter(
            TrackingState.Confirmed, self._params.cov_threshold
        )
        evs = c_evs + t_evs + create_c_evs + create_t_evs + del_t_evs + del_c_evs
        return evs

    @property
    def confirmed_trackings(self):
        """The list of confirmed tracks (the live list, not a copy)."""
        return self._confirmed_tracks

    @property
    def bounding_box_format(self):
        """The bounding box format this tracker was configured with."""
        return self._bounding_boxes_format

    @staticmethod
    def motion_model(dt: float = 1, q: float = 0.05) -> LinearMotionNoInputModel:
        """
        a constant velocity motion model

        State is (cx, cy, vx, vy); positions advance by `dt` times the
        velocity, and the process noise is `q * I`.
        """
        # yapf: disable
        F = np.array([[1, 0, dt, 0],
                      [0, 1, 0, dt],
                      [0, 0, 1, 0],
                      [0, 0, 0, 1]])
        # yapf: enable
        Q = q * np.eye(4)
        return LinearMotionNoInputModel(F=F, Q=Q)

    @staticmethod
    def measurement_model(r: float = 0.75) -> LinearMeasurementModel:
        """Observe the (cx, cy) part of the state with noise `r * I`."""
        # yapf: disable
        H = np.array([[1, 0, 0, 0],
                      [0, 1, 0, 0]])
        # yapf: enable
        R = r * np.eye(2)
        return LinearMeasurementModel(H=H, R=R)