import weakref
from collections import deque
from dataclasses import dataclass
from datetime import datetime, timedelta
from itertools import chain
from typing import (
    Any,
    Callable,
    Generator,
    Optional,
    Protocol,
    Sequence,
    TypeAlias,
    TypedDict,
    TypeVar,
    cast,
    overload,
)

import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Mapping, Sequence
from jax import Array
from jaxtyping import Array, Float, Int, jaxtyped
from pyrsistent import PVector, v

from app.camera import Detection


class TrackingPrediction(TypedDict):
    """
    snapshot returned by a velocity filter's ``predict`` / ``get``
    """

    # per-keypoint velocity, or None when the filter has no estimate (yet)
    velocity: Optional[Float[Array, "J 3"]]
    # keypoint locations (measured or extrapolated, depending on the call)
    keypoints: Float[Array, "J 3"]


class GenericVelocityFilter(Protocol):
    """
    a filter interface for tracking velocity estimation
    """

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        predict the velocity and the keypoints location

        Args:
            timestamp: timestamp of the prediction

        Returns:
            velocity: velocity of the tracking
            keypoints: keypoints of the tracking
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        update the filter state with new measurements

        Args:
            keypoints: new measurements
            timestamp: timestamp of the update
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def get(self) -> TrackingPrediction:
        """
        get the current state of the filter state

        Returns:
            velocity: velocity of the tracking
            keypoints: keypoints of the tracking
        """
        ...  # pylint: disable=unnecessary-ellipsis


class DummyVelocityFilter(GenericVelocityFilter):
    """
    a placeholder velocity filter that never estimates a velocity

    It always reports ``velocity=None`` and echoes back the most recent
    keypoints it was given (at construction or through ``update``). In other
    words it behaves as a zero-motion model. Previously it returned all-zero
    keypoints, which made ``Tracking.predict`` collapse every joint onto the
    origin.
    """

    _keypoints_shape: tuple[int, ...]
    _keypoints: Float[Array, "J 3"]

    def __init__(self, keypoints: Float[Array, "J 3"]):
        self._keypoints_shape = keypoints.shape
        self._keypoints = keypoints

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        # zero-motion model: the prediction is the last known pose
        return TrackingPrediction(velocity=None, keypoints=self._keypoints)

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        self._keypoints_shape = keypoints.shape
        self._keypoints = keypoints

    def get(self) -> TrackingPrediction:
        return TrackingPrediction(velocity=None, keypoints=self._keypoints)


class LastDifferenceVelocityFilter(GenericVelocityFilter):
    """
    a naive velocity filter that uses the last difference of keypoints

    The velocity is ``(x_t - x_{t'}) / (t - t')`` between the two most recent
    measurements. Until a second measurement arrives, no velocity is reported.
    """

    _last_timestamp: datetime
    _last_keypoints: Float[Array, "J 3"]
    _last_velocity: Optional[Float[Array, "J 3"]] = None

    def __init__(self, keypoints: Float[Array, "J 3"], timestamp: datetime):
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        if self._last_velocity is None:
            return TrackingPrediction(velocity=None, keypoints=self._last_keypoints)
        delta_t_s = (timestamp - self._last_timestamp).total_seconds()
        # linear extrapolation from the last measurement
        return TrackingPrediction(
            velocity=self._last_velocity,
            keypoints=self._last_keypoints + self._last_velocity * delta_t_s,
        )

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        delta_t_s = (timestamp - self._last_timestamp).total_seconds()
        # A zero time step carries no velocity information and would divide by
        # zero (producing inf/nan), so the previous estimate is kept instead.
        if delta_t_s != 0:
            self._last_velocity = (keypoints - self._last_keypoints) / delta_t_s
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp

    def get(self) -> TrackingPrediction:
        # `_last_velocity` is already None when no estimate exists
        return TrackingPrediction(
            velocity=self._last_velocity,
            keypoints=self._last_keypoints,
        )


class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
    """
    a velocity filter that uses the least mean square method to estimate the velocity

    A window of at most ``max_samples`` (pose, timestamp) pairs is kept. For
    every joint coordinate, a line ``x(t) = v * t + b`` is fitted over the
    window, and the slopes ``v`` are used as the velocity.
    """

    _historical_3d_poses: deque[Float[Array, "J 3"]]
    _historical_timestamps: deque[datetime]
    _velocity: Optional[Float[Array, "J 3"]] = None
    _max_samples: int

    def __init__(
        self,
        historical_3d_poses: Sequence[Float[Array, "J 3"]],
        historical_timestamps: Sequence[datetime],
        max_samples: int = 10,
    ):
        """
        Args:
            historical_3d_poses: past poses; order does not matter
            historical_timestamps: timestamp of each pose (same length)
            max_samples: size of the sliding window. When more samples are
                given, only the newest ``max_samples`` are kept.

        Raises:
            ValueError: if the two sequences differ in length
        """
        if len(historical_3d_poses) != len(historical_timestamps):
            raise ValueError(
                "historical_3d_poses and historical_timestamps must have the same length"
            )
        # sort chronologically so the deques keep the newest samples when
        # truncated to `max_samples`
        samples = sorted(
            zip(historical_3d_poses, historical_timestamps), key=lambda s: s[1]
        )
        self._historical_3d_poses = deque(
            (pose for pose, _ in samples), maxlen=max_samples
        )
        self._historical_timestamps = deque(
            (ts for _, ts in samples), maxlen=max_samples
        )
        self._max_samples = max_samples
        self._velocity = None
        self._refit()

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        if not self._historical_3d_poses:
            raise ValueError("No historical 3D poses available for prediction")

        # use the latest historical detection
        latest_3d_pose = self._historical_3d_poses[-1]
        latest_timestamp = self._historical_timestamps[-1]

        if self._velocity is None:
            return TrackingPrediction(velocity=None, keypoints=latest_3d_pose)

        delta_t_s = (timestamp - latest_timestamp).total_seconds()
        # Linear motion model: ẋt = xt' + Vt' · (t - t')
        predicted_3d_pose = latest_3d_pose + self._velocity * delta_t_s
        return TrackingPrediction(velocity=self._velocity, keypoints=predicted_3d_pose)

    def _refit(self) -> None:
        """
        re-estimate the velocity from the current history window

        With fewer than two samples there is nothing to fit, so the velocity is
        reset to ``None`` instead of raising.
        """
        if len(self._historical_3d_poses) < 2:
            self._velocity = None
            return
        t_0 = self._historical_timestamps[0]
        # `jnp.array` accepts neither `datetime` objects, a `deque`, nor a
        # `map` iterator, so everything is converted to plain lists first.
        # Times are taken relative to the oldest sample so that the values
        # stay small, which preserves float32 precision.
        relative_seconds = jnp.array(
            [(ts - t_0).total_seconds() for ts in self._historical_timestamps]
        )
        all_keypoints = jnp.stack(list(self._historical_3d_poses))
        self._update(all_keypoints, relative_seconds)

    @jaxtyped(typechecker=beartype)
    def _update(
        self,
        keypoints: Float[Array, "N J 3"],
        timestamps: Float[Array, "N"],
    ) -> None:
        """
        update measurements with least mean square method

        Args:
            keypoints: stacked poses, oldest first
            timestamps: time of each pose in seconds (any fixed origin)

        Raises:
            ValueError: if fewer than two measurements are given
        """
        if keypoints.shape[0] < 2:
            raise ValueError("Not enough measurements to estimate velocity")

        # Using least squares to fit a linear model for each joint and dimension
        # X = timestamps, y = keypoints
        # For each joint and each dimension, we solve for velocity

        n_samples = timestamps.shape[0]
        n_joints = keypoints.shape[1]

        # Create design matrix for linear regression
        # [t, 1] for each timestamp
        X = jnp.column_stack([timestamps, jnp.ones(n_samples)])

        # Reshape keypoints to solve for all joints and dimensions at once
        # From [N, J, 3] to [N, J*3]
        keypoints_reshaped = keypoints.reshape(n_samples, -1)

        # Use JAX's lstsq to solve the least squares problem
        # This is more numerically stable than manually computing pseudoinverse
        coefficients, _, _, _ = jnp.linalg.lstsq(X, keypoints_reshaped, rcond=None)

        # Coefficients shape is [2, J*3]
        # First row: velocities, Second row: intercepts
        self._velocity = coefficients[0].reshape(n_joints, 3)

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        append a measurement and re-fit the velocity

        Raises:
            ValueError: if ``timestamp`` is earlier than the latest stored one
        """
        if self._historical_timestamps and timestamp < self._historical_timestamps[-1]:
            raise ValueError(
                "update timestamp must not be earlier than the latest historical timestamp"
            )
        # deque would manage the maxlen automatically
        self._historical_3d_poses.append(keypoints)
        self._historical_timestamps.append(timestamp)
        self._refit()

    def get(self) -> TrackingPrediction:
        if not self._historical_3d_poses:
            raise ValueError("No historical 3D poses available")
        # `_velocity` is already None when no estimate exists
        return TrackingPrediction(
            velocity=self._velocity, keypoints=self._historical_3d_poses[-1]
        )


@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class TrackingState:
    """
    immutable state of a tracking
    """

    keypoints: Float[Array, "J 3"]
    """
    The 3D keypoints of the tracking

    Used for calculate affinity 3D
    """

    last_active_timestamp: datetime
    """
    The last active timestamp of the tracking
    """

    historical_detections: PVector[Detection]
    """
    Historical detections of the tracking.

    Used for 3D re-triangulation
    """


class Tracking:
    """
    a tracked target: identifier, immutable state and a velocity filter
    """

    id: int
    state: TrackingState
    velocity_filter: GenericVelocityFilter

    def __init__(
        self,
        id: int,
        state: TrackingState,
        velocity_filter: Optional[GenericVelocityFilter] = None,
    ):
        self.id = id
        self.state = state
        # explicit None check: do not rely on the truthiness of a filter object
        self.velocity_filter = (
            velocity_filter
            if velocity_filter is not None
            else DummyVelocityFilter(state.keypoints)
        )

    def __repr__(self) -> str:
        return f"Tracking({self.id}, {self.state.last_active_timestamp})"

    @overload
    def predict(self, time: float) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: the time in seconds to predict the keypoints

        Returns:
            the predicted keypoints
        """
        ...  # pylint: disable=unnecessary-ellipsis

    @overload
    def predict(self, time: timedelta) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: the time delta to predict the keypoints
        """
        ...  # pylint: disable=unnecessary-ellipsis

    @overload
    def predict(self, time: datetime) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: the timestamp to predict the keypoints
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def predict(
        self,
        time: float | timedelta | datetime,
    ) -> Float[Array, "J 3"]:
        # float/int and timedelta are offsets from the last active timestamp;
        # a datetime is used as an absolute timestamp
        if isinstance(time, timedelta):
            timestamp = self.state.last_active_timestamp + time
        elif isinstance(time, datetime):
            timestamp = time
        else:
            timestamp = self.state.last_active_timestamp + timedelta(seconds=time)
        # pylint: disable-next=unsubscriptable-object
        return self.velocity_filter.predict(timestamp)["keypoints"]

    @property
    def velocity(self) -> Float[Array, "J 3"]:
        """
        The velocity of the tracking for each keypoint

        Zeros (shaped like the state keypoints) when the filter has no estimate.
        """
        # pylint: disable-next=unsubscriptable-object
        if (vel := self.velocity_filter.get()["velocity"]) is None:
            return jnp.zeros_like(self.state.keypoints)
        return vel


@jaxtyped(typechecker=beartype)
@dataclass
class AffinityResult:
    """
    Result of affinity computation between trackings and detections.
    """

    matrix: Float[Array, "T D"]
    trackings: Sequence[Tracking]
    detections: Sequence[Detection]
    indices_T: Int[Array, "T"]  # pylint: disable=invalid-name
    indices_D: Int[Array, "D"]  # pylint: disable=invalid-name

    def tracking_detections(
        self,
    ) -> Generator[tuple[float, Tracking, Detection], None, None]:
        """
        iterate over the best matching trackings and detections

        Yields:
            (affinity, tracking, detection) for each paired index in
            ``indices_T`` / ``indices_D``
        """
        for t_idx, d_idx in zip(self.indices_T, self.indices_D):
            # convert jax scalar indices to Python ints before indexing the
            # plain Python sequences
            t, d = int(t_idx), int(d_idx)
            yield (
                self.matrix[t, d].item(),
                self.trackings[t],
                self.detections[d],
            )