"""Velocity filters and tracking containers for 3D keypoint tracking."""

from dataclasses import dataclass
from datetime import datetime, timedelta
from itertools import chain
from typing import (
    Any,
    Callable,
    Generator,
    Optional,
    Protocol,
    Sequence,
    TypeAlias,
    TypedDict,
    TypeVar,
    cast,
    overload,
)

import jax.numpy as jnp
from beartype import beartype

# NOTE: ``Sequence`` and ``Array`` are each imported twice in this file; the
# later import is the binding that the rest of the module sees.
from beartype.typing import Mapping, Sequence
from jax import Array
from jaxtyping import Array, Float, Int, jaxtyped
from pyrsistent import PVector, v

from app.camera import Detection


class TrackingPrediction(TypedDict):
    """
    velocity and keypoints reported by a velocity filter
    """

    velocity: Float[Array, "J 3"]
    keypoints: Float[Array, "J 3"]


class GenericVelocityFilter(Protocol):
    """
    a filter interface for tracking velocity estimation
    """

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        predict the velocity and the keypoints location

        Args:
            timestamp: timestamp of the prediction

        Returns:
            velocity: velocity of the tracking
            keypoints: keypoints of the tracking
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        update the filter state with new measurements

        Args:
            keypoints: new measurements
            timestamp: timestamp of the update
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def get(self) -> TrackingPrediction:
        """
        get the current state of the filter state

        Returns:
            velocity: velocity of the tracking
            keypoints: keypoints of the tracking
        """
        ...  # pylint: disable=unnecessary-ellipsis


class LastDifferenceVelocityFilter(GenericVelocityFilter):
    """
    a naive velocity filter that uses the last difference of keypoints
    """

    _last_timestamp: datetime
    _last_keypoints: Float[Array, "J 3"]
    _last_velocity: Optional[Float[Array, "J 3"]] = None

    def __init__(self, keypoints: Float[Array, "J 3"], timestamp: datetime):
        """
        Args:
            keypoints: the initial measurement
            timestamp: timestamp of the initial measurement
        """
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        linearly extrapolate the last keypoints to ``timestamp``

        Until a velocity has been estimated, the last keypoints are returned
        unchanged together with a zero velocity.

        Args:
            timestamp: timestamp of the prediction
        """
        delta_t_s = (timestamp - self._last_timestamp).total_seconds()
        if self._last_velocity is None:
            return TrackingPrediction(
                velocity=jnp.zeros_like(self._last_keypoints),
                keypoints=self._last_keypoints,
            )
        else:
            return TrackingPrediction(
                velocity=self._last_velocity,
                keypoints=self._last_keypoints + self._last_velocity * delta_t_s,
            )

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        replace the stored measurement and re-estimate the velocity

        The velocity is the difference between the new and the previous
        keypoints divided by the elapsed time in seconds.

        Args:
            keypoints: new measurements
            timestamp: timestamp of the update
        """
        delta_t_s = (timestamp - self._last_timestamp).total_seconds()
        # A measurement carrying the same timestamp as the previous one gives
        # a zero interval; dividing by it would yield inf/NaN velocities that
        # contaminate every later prediction. Keep the previous estimate.
        if delta_t_s != 0:
            self._last_velocity = (keypoints - self._last_keypoints) / delta_t_s
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp

    def get(self) -> TrackingPrediction:
        """
        get the last keypoints and the current velocity estimate

        The velocity is zero until one has been estimated.
        """
        if self._last_velocity is None:
            return TrackingPrediction(
                velocity=jnp.zeros_like(self._last_keypoints),
                keypoints=self._last_keypoints,
            )
        else:
            return TrackingPrediction(
                velocity=self._last_velocity,
                keypoints=self._last_keypoints,
            )


class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
    """
    a velocity filter that uses the least mean square method to estimate the velocity
    """

    _get_historical_detections: Callable[[], Sequence[Detection]]
    """
    get the current historical detections, assuming the detections are sorted
    by timestamp incrementally (i.e. index 0 is the oldest detection,
    index -1 is the newest detection)
    """
    _velocity: Optional[Float[Array, "J 3"]] = None

    @staticmethod
    def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter":
        """
        create a LeastMeanSquareVelocityFilter from a Tracking object

        The new filter reads its history from ``tracking.historical_detections``.
        It is seeded with the tracking's current velocity, unless that velocity
        is all zeros, in which case the filter starts without a velocity.
        """
        velocity = tracking.velocity_filter.get()["velocity"]
        f = LeastMeanSquareVelocityFilter(
            get_historical_detections=lambda: tracking.historical_detections
        )
        if not jnp.all(velocity == jnp.zeros_like(velocity)):
            # pylint: disable-next=protected-access
            f._velocity = velocity
        return f

    def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
        """
        Args:
            get_historical_detections: callable returning the current
                historical detections (oldest first)
        """
        self._get_historical_detections = get_historical_detections
        self._velocity = None

    @property
    def velocity(self) -> Float[Array, "J 3"]:
        """
        the estimated velocity

        Raises:
            ValueError: if no velocity has been estimated yet
        """
        if self._velocity is None:
            raise ValueError("Velocity not initialized")
        return self._velocity

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        linearly extrapolate the latest historical detection to ``timestamp``

        Args:
            timestamp: timestamp of the prediction

        Raises:
            ValueError: if there are no historical detections
        """
        historical_detections = self._get_historical_detections()
        if not historical_detections:
            raise ValueError("No historical detections available for prediction")

        # Use the latest historical detection
        latest_detection = historical_detections[-1]
        latest_keypoints = latest_detection.keypoints
        latest_timestamp = latest_detection.timestamp

        delta_t_s = (timestamp - latest_timestamp).total_seconds()

        if self._velocity is None:
            return TrackingPrediction(
                velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints
            )
        else:
            # Linear motion model: ẋt = xt' + Vt' · (t - t')
            predicted_keypoints = latest_keypoints + self._velocity * delta_t_s
            return TrackingPrediction(
                velocity=self._velocity, keypoints=predicted_keypoints
            )

    @jaxtyped(typechecker=beartype)
    def _update(
        self,
        keypoints: Float[Array, "N J 3"],
        timestamps: Float[Array, "N"],
    ) -> None:
        """
        update measurements with least mean square method

        Args:
            keypoints: stacked measurements, one per timestamp
            timestamps: time of each measurement in seconds

        Raises:
            ValueError: if fewer than two measurements are given
        """
        if keypoints.shape[0] < 2:
            raise ValueError("Not enough measurements to estimate velocity")

        # Using least squares to fit a linear model for each joint and dimension
        # X = timestamps, y = keypoints
        # For each joint and each dimension, we solve for velocity
        n_samples = timestamps.shape[0]
        n_joints = keypoints.shape[1]

        # Create design matrix for linear regression
        # [t, 1] for each timestamp
        design_matrix = jnp.column_stack([timestamps, jnp.ones(n_samples)])

        # Reshape keypoints to solve for all joints and dimensions at once
        # From [N, J, 3] to [N, J*3]
        keypoints_reshaped = keypoints.reshape(n_samples, -1)

        # Use JAX's lstsq to solve the least squares problem
        # This is more numerically stable than manually computing pseudoinverse
        coefficients, _, _, _ = jnp.linalg.lstsq(
            design_matrix, keypoints_reshaped, rcond=None
        )

        # Coefficients shape is [2, J*3]
        # First row: velocities, Second row: intercepts
        velocities = coefficients[0].reshape(n_joints, 3)

        # Update velocity
        self._velocity = velocities

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        re-estimate the velocity from the history plus the new measurement

        The filter does not store the measurement itself; the history comes
        from the callable given at construction. With an empty history the
        velocity is set to zeros.

        Args:
            keypoints: new measurements
            timestamp: timestamp of the update
        """
        historical_detections = self._get_historical_detections()

        if not historical_detections:
            self._velocity = jnp.zeros_like(keypoints)
            return

        t_0 = min(d.timestamp for d in historical_detections)

        all_keypoints = jnp.array(
            list(chain((d.keypoints for d in historical_detections), (keypoints,)))
        )

        # Timestamps relative to t_0 (the oldest detection timestamp)
        all_timestamps = jnp.array(
            list(
                chain(
                    (
                        (d.timestamp - t_0).total_seconds()
                        for d in historical_detections
                    ),
                    ((timestamp - t_0).total_seconds(),),
                )
            )
        )

        self._update(all_keypoints, all_timestamps)

    def get(self) -> TrackingPrediction:
        """
        get the latest historical keypoints and the current velocity estimate

        The velocity is zero until one has been estimated.

        Raises:
            ValueError: if there are no historical detections
        """
        historical_detections = self._get_historical_detections()
        if not historical_detections:
            raise ValueError("No historical detections available")

        latest_detection = historical_detections[-1]
        latest_keypoints = latest_detection.keypoints

        if self._velocity is None:
            return TrackingPrediction(
                velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints
            )
        else:
            return TrackingPrediction(
                velocity=self._velocity, keypoints=latest_keypoints
            )


@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
    """
    an immutable snapshot of one tracked target
    """

    id: int
    """
    The tracking id
    """
    keypoints: Float[Array, "J 3"]
    """
    The 3D keypoints of the tracking

    Used for calculate affinity 3D
    """
    last_active_timestamp: datetime
    """
    The last active timestamp of the tracking
    """
    historical_detections: PVector[Detection]
    """
    Historical detections of the tracking.

    Used for 3D re-triangulation
    """
    velocity_filter: GenericVelocityFilter
    """
    The velocity filter of the tracking
    """

    def __repr__(self) -> str:
        return f"Tracking({self.id}, {self.last_active_timestamp})"

    @overload
    def predict(self, time: float) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: the time in seconds to predict the keypoints

        Returns:
            the predicted keypoints
        """
        ...  # pylint: disable=unnecessary-ellipsis

    @overload
    def predict(self, time: timedelta) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: the time delta to predict the keypoints
        """
        ...  # pylint: disable=unnecessary-ellipsis

    @overload
    def predict(self, time: datetime) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: the timestamp to predict the keypoints
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def predict(
        self,
        time: float | timedelta | datetime,
    ) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: a ``datetime`` is used as the absolute prediction timestamp;
                a ``timedelta`` or a number of seconds is added to
                ``last_active_timestamp``

        Returns:
            the keypoints predicted by the velocity filter
        """
        if isinstance(time, timedelta):
            timestamp = self.last_active_timestamp + time
        elif isinstance(time, datetime):
            timestamp = time
        else:
            timestamp = self.last_active_timestamp + timedelta(seconds=time)
        # pylint: disable-next=unsubscriptable-object
        return self.velocity_filter.predict(timestamp)["keypoints"]

    @property
    def velocity(self) -> Float[Array, "J 3"]:
        """
        The velocity of the tracking for each keypoint
        """
        # pylint: disable-next=unsubscriptable-object
        return self.velocity_filter.get()["velocity"]


@jaxtyped(typechecker=beartype)
@dataclass
class AffinityResult:
    """
    Result of affinity computation between trackings and detections.
    """

    matrix: Float[Array, "T D"]
    trackings: Sequence[Tracking]
    detections: Sequence[Detection]
    # NOTE(review): the index arrays reuse the dimension names T and D of
    # ``matrix``. If jaxtyping binds those names across all fields at
    # construction, both arrays must have exactly T and D entries, while
    # ``tracking_detections`` zips them pairwise -- confirm this matches how
    # the indices are produced.
    indices_T: Int[Array, "T"]  # pylint: disable=invalid-name
    indices_D: Int[Array, "D"]  # pylint: disable=invalid-name

    def tracking_detections(
        self,
    ) -> Generator[tuple[float, Tracking, Detection], None, None]:
        """
        iterate over the best matching trackings and detections

        Yields:
            ``(affinity, tracking, detection)`` for each pair of entries taken
            in lockstep from ``indices_T`` and ``indices_D``
        """
        for t, d in zip(self.indices_T, self.indices_D):
            yield (
                self.matrix[t, d].item(),
                self.trackings[t],
                self.detections[d],
            )