diff --git a/app/tracking/__init__.py b/app/tracking/__init__.py index e071ab2..a0f797a 100644 --- a/app/tracking/__init__.py +++ b/app/tracking/__init__.py @@ -15,6 +15,7 @@ from typing import ( TypeVar, cast, overload, + Union, ) import jax.numpy as jnp @@ -269,6 +270,151 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter): return TrackingPrediction(velocity=self._velocity, keypoints=latest_3d_pose) +class OneEuroFilter(GenericVelocityFilter): + """ + Implementation of the 1€ filter (One Euro Filter) for smoothing keypoint data. + + The 1€ filter is an adaptive low-pass filter that adjusts its cutoff frequency + based on movement speed to reduce jitter during slow movements while maintaining + responsiveness during fast movements. + + Reference: https://cristal.univ-lille.fr/~casiez/1euro/ + """ + + _x_filtered: Float[Array, "J 3"] + _dx_filtered: Optional[Float[Array, "J 3"]] = None + _last_timestamp: datetime + _min_cutoff: float + _beta: float + _d_cutoff: float + + def __init__( + self, + keypoints: Float[Array, "J 3"], + timestamp: datetime, + min_cutoff: float = 1.0, + beta: float = 0.0, + d_cutoff: float = 1.0, + ): + """ + Initialize the One Euro Filter. + + Args: + keypoints: Initial keypoints positions + timestamp: Initial timestamp + min_cutoff: Minimum cutoff frequency (lower = more smoothing) + beta: Speed coefficient (higher = less lag during fast movements) + d_cutoff: Cutoff frequency for the derivative filter + """ + self._last_timestamp = timestamp + + # Filter parameters + self._min_cutoff = min_cutoff + self._beta = beta + self._d_cutoff = d_cutoff + + # Filter state + self._x_filtered = keypoints # Position filter state + self._dx_filtered = None # Initially no velocity estimate + + def _smoothing_factor( + self, cutoff: Union[float, Float[Array, "J"]], dt: float + ) -> Union[float, Float[Array, "J"]]: + """Calculate the smoothing factor for the low-pass filter.""" + r = 2 * jnp.pi * cutoff * dt + return r / (r + 1) + + def _exponential_smoothing( + self, + a: Union[float, Float[Array, "J"]], + x: Float[Array, "J 3"], + x_prev: Float[Array, "J 3"], + ) -> Float[Array, "J 3"]: + """Apply exponential smoothing to the input.""" + return a * x + (1 - a) * x_prev + + def predict(self, timestamp: datetime) -> TrackingPrediction: + """ + Predict keypoints position at a given timestamp. + + Args: + timestamp: Timestamp for prediction + + Returns: + TrackingPrediction with velocity and keypoints + """ + dt = (timestamp - self._last_timestamp).total_seconds() + + if self._dx_filtered is None: + return TrackingPrediction( + velocity=None, + keypoints=self._x_filtered, + ) + else: + predicted_keypoints = self._x_filtered + self._dx_filtered * dt + return TrackingPrediction( + velocity=self._dx_filtered, + keypoints=predicted_keypoints, + ) + + def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: + """ + Update the filter with new measurements. + + Args: + keypoints: New keypoint measurements + timestamp: Timestamp of the measurements + """ + dt = (timestamp - self._last_timestamp).total_seconds() + if dt <= 0: + raise ValueError("invalid timestamp") + + # Compute velocity from current input and filtered state + dx = ( + (keypoints - self._x_filtered) / dt if dt > 0 else jnp.zeros_like(keypoints) + ) + + # Determine cutoff frequency based on movement speed + cutoff = self._min_cutoff + self._beta * jnp.linalg.norm( + dx, axis=-1, keepdims=True + ) + + # Apply low-pass filter to velocity + a_d = self._smoothing_factor(self._d_cutoff, dt) + self._dx_filtered = self._exponential_smoothing( + a_d, + dx, + ( + jnp.zeros_like(keypoints) + if self._dx_filtered is None + else self._dx_filtered + ), + ) + + # Apply low-pass filter to position with adaptive cutoff + a_cutoff = self._smoothing_factor( + jnp.asarray(cutoff), dt + ) # Convert cutoff to scalar if needed + self._x_filtered = self._exponential_smoothing( + a_cutoff, keypoints, self._x_filtered + ) + + # Update timestamp + self._last_timestamp = timestamp + + def get(self) -> TrackingPrediction: + """ + Get the current state of the filter. + + Returns: + TrackingPrediction with velocity and keypoints + """ + return TrackingPrediction( + velocity=self._dx_filtered, + keypoints=self._x_filtered, + ) + + @jaxtyped(typechecker=beartype) @dataclass(frozen=True) class TrackingState: