import weakref
from collections import deque
from dataclasses import dataclass
from datetime import datetime, timedelta
from itertools import chain
from typing import (
    Any,
    Callable,
    Generator,
    Optional,
    Protocol,
    TypeAlias,
    TypedDict,
    TypeVar,
    Union,
    cast,
    overload,
)

import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Mapping, Sequence
from jaxtyping import Array, Float, Int, jaxtyped
from pyrsistent import PMap, PRecord, PVector, v

from app.camera import CameraID, Detection

TrackingID: TypeAlias = int


class TrackingPrediction(TypedDict):
    """
    Result of a velocity-filter query.

    velocity is None until the filter has seen enough measurements to
    estimate one; keypoints is always populated.
    """

    velocity: Optional[Float[Array, "J 3"]]
    keypoints: Float[Array, "J 3"]


class GenericVelocityFilter(Protocol):
    """
    a filter interface for tracking velocity estimation
    """

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        predict the velocity and the keypoints location

        Args:
            timestamp: timestamp of the prediction

        Returns:
            velocity: velocity of the tracking
            keypoints: keypoints of the tracking
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        update the filter state with new measurements

        Args:
            keypoints: new measurements
            timestamp: timestamp of the update
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def get(self) -> TrackingPrediction:
        """
        get the current state of the filter state

        Returns:
            velocity: velocity of the tracking
            keypoints: keypoints of the tracking
        """
        ...  # pylint: disable=unnecessary-ellipsis


class DummyVelocityFilter(GenericVelocityFilter):
    """
    a dummy velocity filter that does nothing

    It never estimates a velocity and always reports all-zero keypoints of
    the shape captured at construction time.
    """

    _keypoints_shape: tuple[int, ...]

    def __init__(self, keypoints: Float[Array, "J 3"]):
        # only the shape is retained; the initial values are never used
        self._keypoints_shape = keypoints.shape

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        return TrackingPrediction(
            velocity=None,
            keypoints=jnp.zeros(self._keypoints_shape),
        )

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: ...

    def get(self) -> TrackingPrediction:
        return TrackingPrediction(
            velocity=None,
            keypoints=jnp.zeros(self._keypoints_shape),
        )


class LastDifferenceVelocityFilter(GenericVelocityFilter):
    """
    a naive velocity filter that uses the last difference of keypoints

    The velocity estimate is the finite difference of the two most recent
    measurements; prediction linearly extrapolates from the latest one.
    """

    _last_timestamp: datetime
    _last_keypoints: Float[Array, "J 3"]
    _last_velocity: Optional[Float[Array, "J 3"]] = None

    def __init__(self, keypoints: Float[Array, "J 3"], timestamp: datetime):
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        delta_t_s = (timestamp - self._last_timestamp).total_seconds()
        if self._last_velocity is None:
            # no velocity estimate yet: report the last position unchanged
            return TrackingPrediction(
                velocity=None,
                keypoints=self._last_keypoints,
            )
        # linear motion model: x(t) = x(t') + v · (t - t')
        return TrackingPrediction(
            velocity=self._last_velocity,
            keypoints=self._last_keypoints + self._last_velocity * delta_t_s,
        )

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        update the filter with a new measurement

        Raises:
            ValueError: if ``timestamp`` is not strictly after the previous
                update (a zero/negative time step would otherwise produce an
                inf/nan velocity through the division below)
        """
        delta_t_s = (timestamp - self._last_timestamp).total_seconds()
        if delta_t_s <= 0:
            raise ValueError(
                f"new timestamp is not greater than the last timestamp; "
                f"expecting: {timestamp} > {self._last_timestamp}"
            )
        self._last_velocity = (keypoints - self._last_keypoints) / delta_t_s
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp

    def get(self) -> TrackingPrediction:
        # velocity stays None until the first update() call
        return TrackingPrediction(
            velocity=self._last_velocity,
            keypoints=self._last_keypoints,
        )


class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
    """
    a velocity filter that uses the least mean square method to estimate the velocity

    Keeps a sliding window of the most recent ``max_samples`` poses and fits a
    per-joint linear motion model (position vs. time) by least squares.
    """

    _historical_3d_poses: deque[Float[Array, "J 3"]]
    _historical_timestamps: deque[datetime]
    _velocity: Optional[Float[Array, "J 3"]] = None
    _max_samples: int

    def __init__(
        self,
        historical_3d_poses: Sequence[Float[Array, "J 3"]],
        historical_timestamps: Sequence[datetime],
        max_samples: int = 10,
    ):
        """
        Args:
            historical_3d_poses: initial poses, in any chronological order
            historical_timestamps: timestamp of each pose, same length
            max_samples: sliding-window size

        Raises:
            ValueError: if the two sequences differ in length
        """
        if len(historical_3d_poses) != len(historical_timestamps):
            raise ValueError(
                "historical_3d_poses and historical_timestamps must have the same length"
            )
        # keep the window sorted chronologically, oldest first
        chronological = sorted(
            zip(historical_3d_poses, historical_timestamps), key=lambda pair: pair[1]
        )
        self._historical_3d_poses = deque(
            (pose for pose, _ in chronological), maxlen=max_samples
        )
        self._historical_timestamps = deque(
            (ts for _, ts in chronological), maxlen=max_samples
        )
        self._max_samples = max_samples
        if len(self._historical_3d_poses) < 2:
            self._velocity = None
        else:
            # NOTE: timestamps must be converted to relative seconds first;
            # jnp.array cannot consume datetime objects directly
            self._refit()

    def _relative_timestamps_s(self) -> Float[Array, "N"]:
        """timestamps as seconds relative to the oldest sample in the window"""
        t_0 = self._historical_timestamps[0]
        # jnp.array needs a realized list, not a lazy map/generator
        return jnp.array(
            [(t - t_0).total_seconds() for t in self._historical_timestamps]
        )

    def _refit(self) -> None:
        """re-estimate the velocity from the current window contents"""
        self._update(
            jnp.array(list(self._historical_3d_poses)),
            self._relative_timestamps_s(),
        )

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        predict the keypoints at ``timestamp`` via linear extrapolation

        Raises:
            ValueError: if the window is empty
        """
        if not self._historical_3d_poses:
            raise ValueError("No historical 3D poses available for prediction")
        # use the latest historical detection
        latest_3d_pose = self._historical_3d_poses[-1]
        latest_timestamp = self._historical_timestamps[-1]
        delta_t_s = (timestamp - latest_timestamp).total_seconds()
        if self._velocity is None:
            return TrackingPrediction(
                velocity=None,
                keypoints=latest_3d_pose,
            )
        # Linear motion model: ẋt = xt' + Vt' · (t - t')
        predicted_3d_pose = latest_3d_pose + self._velocity * delta_t_s
        return TrackingPrediction(velocity=self._velocity, keypoints=predicted_3d_pose)

    @jaxtyped(typechecker=beartype)
    def _update(
        self,
        keypoints: Float[Array, "N J 3"],
        timestamps: Float[Array, "N"],
    ) -> None:
        """
        update measurements with least mean square method

        Args:
            keypoints: stacked poses, oldest first
            timestamps: matching times in seconds (relative origin arbitrary)

        Raises:
            ValueError: if fewer than two measurements are provided
        """
        if keypoints.shape[0] < 2:
            raise ValueError("Not enough measurements to estimate velocity")
        n_samples = timestamps.shape[0]
        n_joints = keypoints.shape[1]
        # design matrix [t, 1]: fit position = v·t + intercept per coordinate
        X = jnp.column_stack([timestamps, jnp.ones(n_samples)])
        # flatten [N, J, 3] -> [N, J*3] so a single lstsq solves every joint
        # and dimension at once
        keypoints_reshaped = keypoints.reshape(n_samples, -1)
        # lstsq is more numerically stable than a manual pseudoinverse
        coefficients, _, _, _ = jnp.linalg.lstsq(X, keypoints_reshaped, rcond=None)
        # coefficients is [2, J*3]: first row slopes (velocities), second row
        # intercepts
        self._velocity = coefficients[0].reshape(n_joints, 3)

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        append a new measurement and re-fit the velocity

        Raises:
            ValueError: if ``timestamp`` is older than the newest sample
        """
        if self._historical_timestamps[-1] > timestamp:
            raise ValueError("timestamps must be monotonically non-decreasing")
        # deque would manage the maxlen automatically
        self._historical_3d_poses.append(keypoints)
        self._historical_timestamps.append(timestamp)
        self._refit()

    def get(self) -> TrackingPrediction:
        """
        current state of the filter

        Raises:
            ValueError: if the window is empty
        """
        if not self._historical_3d_poses:
            raise ValueError("No historical 3D poses available")
        latest_3d_pose = self._historical_3d_poses[-1]
        # velocity may be None when fewer than two samples are in the window
        return TrackingPrediction(velocity=self._velocity, keypoints=latest_3d_pose)


class OneEuroFilter(GenericVelocityFilter):
    """
    Implementation of the 1€ filter (One Euro Filter) for smoothing keypoint data.

    The 1€ filter is an adaptive low-pass filter that adjusts its cutoff
    frequency based on movement speed to reduce jitter during slow movements
    while maintaining responsiveness during fast movements.

    Reference: https://cristal.univ-lille.fr/~casiez/1euro/
    """

    _x_filtered: Float[Array, "J 3"]
    _dx_filtered: Optional[Float[Array, "J 3"]] = None
    _last_timestamp: datetime
    _min_cutoff: float
    _beta: float
    _d_cutoff: float

    def __init__(
        self,
        keypoints: Float[Array, "J 3"],
        timestamp: datetime,
        min_cutoff: float = 1.0,
        beta: float = 0.0,
        d_cutoff: float = 1.0,
    ):
        """
        Initialize the One Euro Filter.

        Args:
            keypoints: Initial keypoints positions
            timestamp: Initial timestamp
            min_cutoff: Minimum cutoff frequency (lower = more smoothing)
            beta: Speed coefficient (higher = less lag during fast movements)
            d_cutoff: Cutoff frequency for the derivative filter
        """
        self._last_timestamp = timestamp
        # Filter parameters
        self._min_cutoff = min_cutoff
        self._beta = beta
        self._d_cutoff = d_cutoff
        # Filter state
        self._x_filtered = keypoints  # Position filter state
        self._dx_filtered = None  # Initially no velocity estimate

    @overload
    def _smoothing_factor(self, cutoff: float, dt: float) -> float: ...

    @overload
    def _smoothing_factor(
        self, cutoff: Float[Array, "J 1"], dt: float
    ) -> Float[Array, "J 1"]: ...

    @jaxtyped(typechecker=beartype)
    def _smoothing_factor(
        self, cutoff: Union[float, Float[Array, "J 1"]], dt: float
    ) -> Union[float, Float[Array, "J 1"]]:
        """Calculate the smoothing factor for the low-pass filter.

        The adaptive cutoff built in :meth:`update` has shape [J, 1]
        (per-joint, dims kept) so it broadcasts against [J, 3] keypoints.
        """
        r = 2 * jnp.pi * cutoff * dt
        return r / (r + 1)

    @jaxtyped(typechecker=beartype)
    def _exponential_smoothing(
        self,
        a: Union[float, Float[Array, "J 1"]],
        x: Float[Array, "J 3"],
        x_prev: Float[Array, "J 3"],
    ) -> Float[Array, "J 3"]:
        """Apply exponential smoothing to the input."""
        return a * x + (1 - a) * x_prev

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        Predict keypoints position at a given timestamp.

        Args:
            timestamp: Timestamp for prediction

        Returns:
            TrackingPrediction with velocity and keypoints
        """
        dt = (timestamp - self._last_timestamp).total_seconds()
        if self._dx_filtered is None:
            return TrackingPrediction(
                velocity=None,
                keypoints=self._x_filtered,
            )
        # linear extrapolation from the filtered state
        predicted_keypoints = self._x_filtered + self._dx_filtered * dt
        return TrackingPrediction(
            velocity=self._dx_filtered,
            keypoints=predicted_keypoints,
        )

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        Update the filter with new measurements.

        Args:
            keypoints: New keypoint measurements
            timestamp: Timestamp of the measurements

        Raises:
            ValueError: if ``timestamp`` is not strictly after the last update
        """
        dt = (timestamp - self._last_timestamp).total_seconds()
        if dt <= 0:
            raise ValueError(
                f"new timestamp is not greater than the last timestamp; expecting: {timestamp} > {self._last_timestamp}"
            )
        dx = (keypoints - self._x_filtered) / dt
        # Determine cutoff frequency based on movement speed; shape [J, 1]
        cutoff = self._min_cutoff + self._beta * jnp.linalg.norm(
            dx, axis=-1, keepdims=True
        )
        # Apply low-pass filter to velocity
        a_d = self._smoothing_factor(self._d_cutoff, dt)
        self._dx_filtered = self._exponential_smoothing(
            a_d,
            dx,
            (
                jnp.zeros_like(keypoints)
                if self._dx_filtered is None
                else self._dx_filtered
            ),
        )
        # Apply low-pass filter to position with adaptive cutoff
        a_cutoff = self._smoothing_factor(jnp.asarray(cutoff), dt)
        self._x_filtered = self._exponential_smoothing(
            a_cutoff, keypoints, self._x_filtered
        )
        # Update timestamp
        self._last_timestamp = timestamp

    def get(self) -> TrackingPrediction:
        """
        Get the current state of the filter.

        Returns:
            TrackingPrediction with velocity and keypoints
        """
        return TrackingPrediction(
            velocity=self._dx_filtered,
            keypoints=self._x_filtered,
        )


@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class TrackingState:
    """
    immutable state of a tracking
    """

    keypoints: Float[Array, "J 3"]
    """
    The 3D keypoints of the tracking

    Used to calculate affinity 3D
    """

    last_active_timestamp: datetime
    """
    The last active timestamp of the tracking
    """

    historical_detections_by_camera: PMap[CameraID, Detection]
    """
    Historical detections of the tracking.

    Used for 3D re-triangulation
    """


class Tracking:
    """
    A tracked subject: an immutable ``TrackingState`` paired with a mutable
    velocity filter used for motion prediction.
    """

    id: TrackingID
    state: TrackingState
    velocity_filter: GenericVelocityFilter

    def __init__(
        self,
        id: TrackingID,
        state: TrackingState,
        velocity_filter: Optional[GenericVelocityFilter] = None,
    ):
        self.id = id
        self.state = state
        # fall back to a no-op filter so predict()/velocity always work
        self.velocity_filter = velocity_filter or DummyVelocityFilter(state.keypoints)

    def __repr__(self) -> str:
        return f"Tracking({self.id}, {self.state.last_active_timestamp})"

    @overload
    def predict(self, time: float) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: the time in seconds to predict the keypoints

        Returns:
            the predicted keypoints
        """
        ...  # pylint: disable=unnecessary-ellipsis

    @overload
    def predict(self, time: timedelta) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: the time delta to predict the keypoints
        """
        ...  # pylint: disable=unnecessary-ellipsis

    @overload
    def predict(self, time: datetime) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: the timestamp to predict the keypoints
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def predict(
        self,
        time: float | timedelta | datetime,
    ) -> Float[Array, "J 3"]:
        # float and timedelta are relative to the last active timestamp;
        # datetime is an absolute target time
        if isinstance(time, timedelta):
            timestamp = self.state.last_active_timestamp + time
        elif isinstance(time, datetime):
            timestamp = time
        else:
            timestamp = self.state.last_active_timestamp + timedelta(seconds=time)
        # pylint: disable-next=unsubscriptable-object
        return self.velocity_filter.predict(timestamp)["keypoints"]

    def update(self, new_3d_pose: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        update the tracking with a new 3D pose

        Note: equivalent to call `velocity_filter.update(new_3d_pose, timestamp)`
        """
        self.velocity_filter.update(new_3d_pose, timestamp)

    @property
    def velocity(self) -> Float[Array, "J 3"]:
        """
        The velocity of the tracking for each keypoint

        Zeros when the underlying filter has no estimate yet.
        """
        # pylint: disable-next=unsubscriptable-object
        if (vel := self.velocity_filter.get()["velocity"]) is None:
            return jnp.zeros_like(self.state.keypoints)
        return vel


@jaxtyped(typechecker=beartype)
@dataclass
class AffinityResult:
    """
    Result of affinity computation between trackings and detections.
    """

    matrix: Float[Array, "T D"]
    trackings: Sequence[Tracking]
    detections: Sequence[Detection]
    indices_T: Int[Array, "T"]  # pylint: disable=invalid-name
    indices_D: Int[Array, "D"]  # pylint: disable=invalid-name

    def tracking_association(
        self,
    ) -> Generator[tuple[float, Tracking, Detection], None, None]:
        """
        iterate over the best matching trackings and detections

        Yields:
            (affinity score, tracking, detection) triples, one per matched
            index pair.
        """
        for t, d in zip(self.indices_T, self.indices_D):
            # zip yields 0-d jax int arrays; convert to plain Python ints so
            # any Sequence implementation can index with them
            t_i, d_i = int(t), int(d)
            yield (
                self.matrix[t_i, d_i].item(),
                self.trackings[t_i],
                self.detections[d_i],
            )