1
0
forked from HQU-gxy/CVTH3PE

feat: Add OneEuroFilter for adaptive keypoint smoothing

- Introduced the `OneEuroFilter` class to implement an adaptive low-pass filter for smoothing keypoint data, enhancing tracking stability during varying movement speeds.
- Implemented methods for initialization, prediction, and updating of keypoints, allowing for dynamic adjustment of smoothing based on movement.
- Added detailed documentation and type hints to clarify the filter's functionality and parameters.
- Improved the handling of timestamps and filtering logic to ensure accurate predictions and updates.
This commit is contained in:
2025-05-03 14:58:51 +08:00
parent d2c1c8d624
commit 4a5cfde245

View File

@ -15,6 +15,7 @@ from typing import (
TypeVar, TypeVar,
cast, cast,
overload, overload,
Union,
) )
import jax.numpy as jnp import jax.numpy as jnp
@ -269,6 +270,151 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
return TrackingPrediction(velocity=self._velocity, keypoints=latest_3d_pose) return TrackingPrediction(velocity=self._velocity, keypoints=latest_3d_pose)
class OneEuroFilter(GenericVelocityFilter):
"""
Implementation of the 1€ filter (One Euro Filter) for smoothing keypoint data.
The 1€ filter is an adaptive low-pass filter that adjusts its cutoff frequency
based on movement speed to reduce jitter during slow movements while maintaining
responsiveness during fast movements.
Reference: https://cristal.univ-lille.fr/~casiez/1euro/
"""
_x_filtered: Float[Array, "J 3"]
_dx_filtered: Optional[Float[Array, "J 3"]] = None
_last_timestamp: datetime
_min_cutoff: float
_beta: float
_d_cutoff: float
def __init__(
self,
keypoints: Float[Array, "J 3"],
timestamp: datetime,
min_cutoff: float = 1.0,
beta: float = 0.0,
d_cutoff: float = 1.0,
):
"""
Initialize the One Euro Filter.
Args:
keypoints: Initial keypoints positions
timestamp: Initial timestamp
min_cutoff: Minimum cutoff frequency (lower = more smoothing)
beta: Speed coefficient (higher = less lag during fast movements)
d_cutoff: Cutoff frequency for the derivative filter
"""
self._last_timestamp = timestamp
# Filter parameters
self._min_cutoff = min_cutoff
self._beta = beta
self._d_cutoff = d_cutoff
# Filter state
self._x_filtered = keypoints # Position filter state
self._dx_filtered = None # Initially no velocity estimate
def _smoothing_factor(
self, cutoff: Union[float, Float[Array, "J"]], dt: float
) -> Union[float, Float[Array, "J"]]:
"""Calculate the smoothing factor for the low-pass filter."""
r = 2 * jnp.pi * cutoff * dt
return r / (r + 1)
def _exponential_smoothing(
self,
a: Union[float, Float[Array, "J"]],
x: Float[Array, "J 3"],
x_prev: Float[Array, "J 3"],
) -> Float[Array, "J 3"]:
"""Apply exponential smoothing to the input."""
return a * x + (1 - a) * x_prev
def predict(self, timestamp: datetime) -> TrackingPrediction:
"""
Predict keypoints position at a given timestamp.
Args:
timestamp: Timestamp for prediction
Returns:
TrackingPrediction with velocity and keypoints
"""
dt = (timestamp - self._last_timestamp).total_seconds()
if self._dx_filtered is None:
return TrackingPrediction(
velocity=None,
keypoints=self._x_filtered,
)
else:
predicted_keypoints = self._x_filtered + self._dx_filtered * dt
return TrackingPrediction(
velocity=self._dx_filtered,
keypoints=predicted_keypoints,
)
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
"""
Update the filter with new measurements.
Args:
keypoints: New keypoint measurements
timestamp: Timestamp of the measurements
"""
dt = (timestamp - self._last_timestamp).total_seconds()
if dt <= 0:
raise ValueError("invalid timestamp")
# Compute velocity from current input and filtered state
dx = (
(keypoints - self._x_filtered) / dt if dt > 0 else jnp.zeros_like(keypoints)
)
# Determine cutoff frequency based on movement speed
cutoff = self._min_cutoff + self._beta * jnp.linalg.norm(
dx, axis=-1, keepdims=True
)
# Apply low-pass filter to velocity
a_d = self._smoothing_factor(self._d_cutoff, dt)
self._dx_filtered = self._exponential_smoothing(
a_d,
dx,
(
jnp.zeros_like(keypoints)
if self._dx_filtered is None
else self._dx_filtered
),
)
# Apply low-pass filter to position with adaptive cutoff
a_cutoff = self._smoothing_factor(
jnp.asarray(cutoff), dt
) # Convert cutoff to scalar if needed
self._x_filtered = self._exponential_smoothing(
a_cutoff, keypoints, self._x_filtered
)
# Update timestamp
self._last_timestamp = timestamp
def get(self) -> TrackingPrediction:
"""
Get the current state of the filter.
Returns:
TrackingPrediction with velocity and keypoints
"""
return TrackingPrediction(
velocity=self._dx_filtered,
keypoints=self._x_filtered,
)
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
@dataclass(frozen=True) @dataclass(frozen=True)
class TrackingState: class TrackingState: