import weakref
from dataclasses import dataclass
from datetime import datetime, timedelta
from itertools import chain
from typing import (
    Any,
    Callable,
    Generator,
    Optional,
    Protocol,
    Sequence,
    TypeAlias,
    TypedDict,
    TypeVar,
    cast,
    overload,
)

import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Mapping, Sequence
from jax import Array
from jaxtyping import Array, Float, Int, jaxtyped
from pyrsistent import PVector, v

from app.camera import Detection


class TrackingPrediction(TypedDict):
    """
    result of a velocity filter query

    `velocity` is None when the filter has no velocity estimate.
    """

    velocity: Optional[Float[Array, "J 3"]]
    keypoints: Float[Array, "J 3"]


class GenericVelocityFilter(Protocol):
    """
    a filter interface for tracking velocity estimation
    """

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        predict the velocity and the keypoints location

        Args:
            timestamp: timestamp of the prediction

        Returns:
            velocity: velocity of the tracking
            keypoints: keypoints of the tracking
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        update the filter state with new measurements

        Args:
            keypoints: new measurements
            timestamp: timestamp of the update
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def get(self) -> TrackingPrediction:
        """
        get the current state of the filter

        Returns:
            velocity: velocity of the tracking
            keypoints: keypoints of the tracking
        """
        ...  # pylint: disable=unnecessary-ellipsis


class DummyVelocityFilter(GenericVelocityFilter):
    """
    a dummy velocity filter that does nothing

    It only remembers the shape of the keypoints it was created with and
    always reports zero keypoints of that shape with no velocity.
    """

    _keypoints_shape: tuple[int, ...]

    def __init__(self, keypoints: Float[Array, "J 3"]):
        self._keypoints_shape = keypoints.shape

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        return zero keypoints and no velocity, ignoring `timestamp`
        """
        return TrackingPrediction(
            velocity=None,
            keypoints=jnp.zeros(self._keypoints_shape),
        )

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        ignore the measurement; this filter keeps no state to update
        """
        ...

    def get(self) -> TrackingPrediction:
        """
        return zero keypoints and no velocity
        """
        return TrackingPrediction(
            velocity=None,
            keypoints=jnp.zeros(self._keypoints_shape),
        )


class LastDifferenceVelocityFilter(GenericVelocityFilter):
    """
    a naive velocity filter that uses the last difference of keypoints
    """

    _last_timestamp: datetime
    _last_keypoints: Float[Array, "J 3"]
    _last_velocity: Optional[Float[Array, "J 3"]] = None

    def __init__(self, keypoints: Float[Array, "J 3"], timestamp: datetime):
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        extrapolate the last keypoints linearly to `timestamp`

        Args:
            timestamp: timestamp of the prediction

        Returns:
            the last keypoints unchanged (velocity None) while no velocity
            has been estimated yet, otherwise the last keypoints advanced by
            the last velocity over the elapsed time
        """
        delta_t_s = (timestamp - self._last_timestamp).total_seconds()
        if self._last_velocity is None:
            return TrackingPrediction(
                velocity=None,
                keypoints=self._last_keypoints,
            )
        return TrackingPrediction(
            velocity=self._last_velocity,
            keypoints=self._last_keypoints + self._last_velocity * delta_t_s,
        )

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        store the new measurement and re-estimate the velocity

        The velocity is the difference between the new and the previous
        keypoints divided by the elapsed time in seconds.

        Args:
            keypoints: new measurements
            timestamp: timestamp of the update
        """
        delta_t_s = (timestamp - self._last_timestamp).total_seconds()
        # A repeated timestamp gives a zero time step; dividing by it would
        # fill the velocity with inf/nan and corrupt every later prediction,
        # so the previous velocity estimate is kept in that case.
        if delta_t_s != 0:
            self._last_velocity = (keypoints - self._last_keypoints) / delta_t_s
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp

    def get(self) -> TrackingPrediction:
        """
        return the last keypoints and the last velocity (None if unknown)
        """
        return TrackingPrediction(
            velocity=self._last_velocity,
            keypoints=self._last_keypoints,
        )


class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
    """
    a velocity filter that uses the least mean square method to estimate the velocity
    """

    _get_historical_detections: Callable[[], Sequence[Detection]]
    """
    get the current historical detections, assuming the detections are sorted
    by timestamp incrementally (i.e. index 0 is the oldest detection, index -1
    is the newest detection)
    """
    _velocity: Optional[Float[Array, "J 3"]] = None

    def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
        self._get_historical_detections = get_historical_detections
        self._velocity = None

    @staticmethod
    def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter":
        """
        create a LeastMeanSquareVelocityFilter from a Tracking object

        The new filter starts with the velocity currently reported by the
        tracking's own velocity filter (which may be None).

        Note that this function is using a weak reference to the tracking
        object, so that the tracking object can be garbage collected if there
        are no other references to it.
        """
        # Create a weak reference to avoid circular references
        # https://docs.python.org/3/library/weakref.html
        tracking_ref = weakref.ref(tracking)

        # Create a getter function that uses the weak reference
        def get_historical_detections() -> Sequence[Detection]:
            tr = tracking_ref()
            if tr is None:
                return []  # Return empty list if tracking has been garbage collected
            return tr.state.historical_detections

        f = LeastMeanSquareVelocityFilter(
            get_historical_detections=get_historical_detections
        )
        # pylint: disable-next=protected-access
        f._velocity = tracking.velocity_filter.get()["velocity"]
        return f

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        extrapolate the newest historical detection linearly to `timestamp`

        Args:
            timestamp: timestamp of the prediction

        Returns:
            the newest keypoints unchanged (velocity None) while no velocity
            has been estimated yet, otherwise the newest keypoints advanced
            by the current velocity over the elapsed time

        Raises:
            ValueError: if there are no historical detections
        """
        historical_detections = self._get_historical_detections()
        if not historical_detections:
            raise ValueError("No historical detections available for prediction")

        # Use the latest historical detection
        latest_detection = historical_detections[-1]
        latest_keypoints = latest_detection.keypoints
        latest_timestamp = latest_detection.timestamp

        delta_t_s = (timestamp - latest_timestamp).total_seconds()

        if self._velocity is None:
            return TrackingPrediction(
                velocity=None,
                keypoints=latest_keypoints,
            )
        # Linear motion model: ẋt = xt' + Vt' · (t - t')
        predicted_keypoints = latest_keypoints + self._velocity * delta_t_s
        return TrackingPrediction(
            velocity=self._velocity, keypoints=predicted_keypoints
        )

    @jaxtyped(typechecker=beartype)
    def _update(
        self,
        keypoints: Float[Array, "N J 3"],
        timestamps: Float[Array, "N"],
    ) -> None:
        """
        update measurements with least mean square method

        Args:
            keypoints: stacked keypoints, one entry per sample
            timestamps: one time value per sample

        Raises:
            ValueError: if fewer than two samples are given
        """
        if keypoints.shape[0] < 2:
            raise ValueError("Not enough measurements to estimate velocity")

        # Using least squares to fit a linear model for each joint and dimension
        # X = timestamps, y = keypoints
        # For each joint and each dimension, we solve for velocity
        n_samples = timestamps.shape[0]
        n_joints = keypoints.shape[1]

        # Create design matrix for linear regression
        # [t, 1] for each timestamp
        X = jnp.column_stack([timestamps, jnp.ones(n_samples)])

        # Reshape keypoints to solve for all joints and dimensions at once
        # From [N, J, 3] to [N, J*3]
        keypoints_reshaped = keypoints.reshape(n_samples, -1)

        # Use JAX's lstsq to solve the least squares problem
        # This is more numerically stable than manually computing pseudoinverse
        coefficients, _, _, _ = jnp.linalg.lstsq(X, keypoints_reshaped, rcond=None)

        # Coefficients shape is [2, J*3]
        # First row: velocities, Second row: intercepts
        velocities = coefficients[0].reshape(n_joints, 3)

        # Update velocity
        self._velocity = velocities

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        re-estimate the velocity from the historical detections plus the new
        measurement

        When there are no historical detections the velocity is set to zeros
        shaped like `keypoints`.

        Args:
            keypoints: new measurements
            timestamp: timestamp of the update
        """
        historical_detections = self._get_historical_detections()
        if not historical_detections:
            self._velocity = jnp.zeros_like(keypoints)
            return

        t_0 = min(d.timestamp for d in historical_detections)

        all_keypoints = jnp.array(
            list(chain((d.keypoints for d in historical_detections), (keypoints,)))
        )

        # Timestamps relative to t_0 (the oldest detection timestamp)
        all_timestamps = jnp.array(
            list(
                chain(
                    (
                        (d.timestamp - t_0).total_seconds()
                        for d in historical_detections
                    ),
                    ((timestamp - t_0).total_seconds(),),
                )
            )
        )

        self._update(all_keypoints, all_timestamps)

    def get(self) -> TrackingPrediction:
        """
        return the newest historical keypoints and the current velocity
        (None if unknown)

        Raises:
            ValueError: if there are no historical detections
        """
        historical_detections = self._get_historical_detections()
        if not historical_detections:
            raise ValueError("No historical detections available")

        latest_keypoints = historical_detections[-1].keypoints
        return TrackingPrediction(velocity=self._velocity, keypoints=latest_keypoints)


@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class TrackingState:
    """
    immutable state of a tracking
    """

    keypoints: Float[Array, "J 3"]
    """
    The 3D keypoints of the tracking

    Used for calculate affinity 3D
    """
    last_active_timestamp: datetime
    """
    The last active timestamp of the tracking
    """
    historical_detections: PVector[Detection]
    """
    Historical detections of the tracking.

    Used for 3D re-triangulation
    """


class Tracking:
    """
    a tracked target: an id, its immutable state and a velocity filter
    """

    id: int
    state: TrackingState
    velocity_filter: GenericVelocityFilter

    def __init__(
        self,
        id: int,
        state: TrackingState,
        velocity_filter: Optional[GenericVelocityFilter] = None,
    ):
        """
        Args:
            id: identifier of the tracking
            state: initial state of the tracking
            velocity_filter: filter used for prediction; a
                DummyVelocityFilter built from `state.keypoints` is used
                when omitted
        """
        self.id = id
        self.state = state
        # Test for absence explicitly: truthiness would also replace a
        # supplied filter that happens to evaluate as falsy.
        self.velocity_filter = (
            velocity_filter
            if velocity_filter is not None
            else DummyVelocityFilter(state.keypoints)
        )

    def __repr__(self) -> str:
        return f"Tracking({self.id}, {self.state.last_active_timestamp})"

    @overload
    def predict(self, time: float) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: the time in seconds to predict the keypoints

        Returns:
            the predicted keypoints
        """
        ...  # pylint: disable=unnecessary-ellipsis

    @overload
    def predict(self, time: timedelta) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: the time delta to predict the keypoints
        """
        ...  # pylint: disable=unnecessary-ellipsis

    @overload
    def predict(self, time: datetime) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: the timestamp to predict the keypoints
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def predict(
        self,
        time: float | timedelta | datetime,
    ) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: an absolute timestamp, or an offset (a timedelta or a
                number of seconds) added to the state's last active timestamp

        Returns:
            the keypoints predicted by the velocity filter
        """
        if isinstance(time, timedelta):
            timestamp = self.state.last_active_timestamp + time
        elif isinstance(time, datetime):
            timestamp = time
        else:
            timestamp = self.state.last_active_timestamp + timedelta(seconds=time)
        # pylint: disable-next=unsubscriptable-object
        return self.velocity_filter.predict(timestamp)["keypoints"]

    @property
    def velocity(self) -> Float[Array, "J 3"]:
        """
        The velocity of the tracking for each keypoint

        Raises:
            ValueError: if the velocity filter has no velocity estimate
        """
        # pylint: disable-next=unsubscriptable-object
        if (vel := self.velocity_filter.get()["velocity"]) is None:
            raise ValueError("Velocity is not available")
        return vel


@jaxtyped(typechecker=beartype)
@dataclass
class AffinityResult:
    """
    Result of affinity computation between trackings and detections.
    """

    matrix: Float[Array, "T D"]
    trackings: Sequence[Tracking]
    detections: Sequence[Detection]
    indices_T: Int[Array, "T"]  # pylint: disable=invalid-name
    indices_D: Int[Array, "D"]  # pylint: disable=invalid-name

    def tracking_detections(
        self,
    ) -> Generator[tuple[float, Tracking, Detection], None, None]:
        """
        iterate over the best matching trackings and detections

        Yields:
            (affinity, tracking, detection) for each pair of indices taken
            in lockstep from `indices_T` and `indices_D`
        """
        # Convert the index arrays to plain Python ints once, so that the
        # sequences are indexed with real ints rather than 0-d arrays.
        for t, d in zip(self.indices_T.tolist(), self.indices_D.tolist()):
            yield (
                self.matrix[t, d].item(),
                self.trackings[t],
                self.detections[d],
            )