diff --git a/app/tracking/__init__.py b/app/tracking/__init__.py index a0f797a..7fb7ed0 100644 --- a/app/tracking/__init__.py +++ b/app/tracking/__init__.py @@ -317,6 +317,15 @@ class OneEuroFilter(GenericVelocityFilter): self._x_filtered = keypoints # Position filter state self._dx_filtered = None # Initially no velocity estimate + @overload + def _smoothing_factor(self, cutoff: float, dt: float) -> float: ... + + @overload + def _smoothing_factor( + self, cutoff: Float[Array, "J"], dt: float + ) -> Float[Array, "J"]: ... + + @jaxtyped(typechecker=beartype) def _smoothing_factor( self, cutoff: Union[float, Float[Array, "J"]], dt: float ) -> Union[float, Float[Array, "J"]]: @@ -324,6 +333,7 @@ class OneEuroFilter(GenericVelocityFilter): r = 2 * jnp.pi * cutoff * dt return r / (r + 1) + @jaxtyped(typechecker=beartype) def _exponential_smoothing( self, a: Union[float, Float[Array, "J"]], @@ -367,12 +377,11 @@ class OneEuroFilter(GenericVelocityFilter): """ dt = (timestamp - self._last_timestamp).total_seconds() if dt <= 0: - raise ValueError("invalid timestamp") + raise ValueError( + f"new timestamp is not greater than the last timestamp; expecting: {timestamp} > {self._last_timestamp}" + ) - # Compute velocity from current input and filtered state - dx = ( - (keypoints - self._x_filtered) / dt if dt > 0 else jnp.zeros_like(keypoints) - ) + dx = (keypoints - self._x_filtered) / dt # Determine cutoff frequency based on movement speed cutoff = self._min_cutoff + self._beta * jnp.linalg.norm( @@ -392,9 +401,7 @@ class OneEuroFilter(GenericVelocityFilter): ) # Apply low-pass filter to position with adaptive cutoff - a_cutoff = self._smoothing_factor( - jnp.asarray(cutoff), dt - ) # Convert cutoff to scalar if needed + a_cutoff = self._smoothing_factor(jnp.asarray(cutoff), dt) self._x_filtered = self._exponential_smoothing( a_cutoff, keypoints, self._x_filtered )