forked from HQU-gxy/CVTH3PE
feat: Add OneEuroFilter for adaptive keypoint smoothing
- Introduced the `OneEuroFilter` class to implement an adaptive low-pass filter for smoothing keypoint data, enhancing tracking stability during varying movement speeds. - Implemented methods for initialization, prediction, and updating of keypoints, allowing for dynamic adjustment of smoothing based on movement. - Added detailed documentation and type hints to clarify the filter's functionality and parameters. - Improved the handling of timestamps and filtering logic to ensure accurate predictions and updates.
This commit is contained in:
@ -15,6 +15,7 @@ from typing import (
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
Union,
|
||||
)
|
||||
|
||||
import jax.numpy as jnp
|
||||
@ -269,6 +270,151 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
||||
return TrackingPrediction(velocity=self._velocity, keypoints=latest_3d_pose)
|
||||
|
||||
|
||||
class OneEuroFilter(GenericVelocityFilter):
|
||||
"""
|
||||
Implementation of the 1€ filter (One Euro Filter) for smoothing keypoint data.
|
||||
|
||||
The 1€ filter is an adaptive low-pass filter that adjusts its cutoff frequency
|
||||
based on movement speed to reduce jitter during slow movements while maintaining
|
||||
responsiveness during fast movements.
|
||||
|
||||
Reference: https://cristal.univ-lille.fr/~casiez/1euro/
|
||||
"""
|
||||
|
||||
_x_filtered: Float[Array, "J 3"]
|
||||
_dx_filtered: Optional[Float[Array, "J 3"]] = None
|
||||
_last_timestamp: datetime
|
||||
_min_cutoff: float
|
||||
_beta: float
|
||||
_d_cutoff: float
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
keypoints: Float[Array, "J 3"],
|
||||
timestamp: datetime,
|
||||
min_cutoff: float = 1.0,
|
||||
beta: float = 0.0,
|
||||
d_cutoff: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Initialize the One Euro Filter.
|
||||
|
||||
Args:
|
||||
keypoints: Initial keypoints positions
|
||||
timestamp: Initial timestamp
|
||||
min_cutoff: Minimum cutoff frequency (lower = more smoothing)
|
||||
beta: Speed coefficient (higher = less lag during fast movements)
|
||||
d_cutoff: Cutoff frequency for the derivative filter
|
||||
"""
|
||||
self._last_timestamp = timestamp
|
||||
|
||||
# Filter parameters
|
||||
self._min_cutoff = min_cutoff
|
||||
self._beta = beta
|
||||
self._d_cutoff = d_cutoff
|
||||
|
||||
# Filter state
|
||||
self._x_filtered = keypoints # Position filter state
|
||||
self._dx_filtered = None # Initially no velocity estimate
|
||||
|
||||
def _smoothing_factor(
|
||||
self, cutoff: Union[float, Float[Array, "J"]], dt: float
|
||||
) -> Union[float, Float[Array, "J"]]:
|
||||
"""Calculate the smoothing factor for the low-pass filter."""
|
||||
r = 2 * jnp.pi * cutoff * dt
|
||||
return r / (r + 1)
|
||||
|
||||
def _exponential_smoothing(
|
||||
self,
|
||||
a: Union[float, Float[Array, "J"]],
|
||||
x: Float[Array, "J 3"],
|
||||
x_prev: Float[Array, "J 3"],
|
||||
) -> Float[Array, "J 3"]:
|
||||
"""Apply exponential smoothing to the input."""
|
||||
return a * x + (1 - a) * x_prev
|
||||
|
||||
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
||||
"""
|
||||
Predict keypoints position at a given timestamp.
|
||||
|
||||
Args:
|
||||
timestamp: Timestamp for prediction
|
||||
|
||||
Returns:
|
||||
TrackingPrediction with velocity and keypoints
|
||||
"""
|
||||
dt = (timestamp - self._last_timestamp).total_seconds()
|
||||
|
||||
if self._dx_filtered is None:
|
||||
return TrackingPrediction(
|
||||
velocity=None,
|
||||
keypoints=self._x_filtered,
|
||||
)
|
||||
else:
|
||||
predicted_keypoints = self._x_filtered + self._dx_filtered * dt
|
||||
return TrackingPrediction(
|
||||
velocity=self._dx_filtered,
|
||||
keypoints=predicted_keypoints,
|
||||
)
|
||||
|
||||
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||
"""
|
||||
Update the filter with new measurements.
|
||||
|
||||
Args:
|
||||
keypoints: New keypoint measurements
|
||||
timestamp: Timestamp of the measurements
|
||||
"""
|
||||
dt = (timestamp - self._last_timestamp).total_seconds()
|
||||
if dt <= 0:
|
||||
raise ValueError("invalid timestamp")
|
||||
|
||||
# Compute velocity from current input and filtered state
|
||||
dx = (
|
||||
(keypoints - self._x_filtered) / dt if dt > 0 else jnp.zeros_like(keypoints)
|
||||
)
|
||||
|
||||
# Determine cutoff frequency based on movement speed
|
||||
cutoff = self._min_cutoff + self._beta * jnp.linalg.norm(
|
||||
dx, axis=-1, keepdims=True
|
||||
)
|
||||
|
||||
# Apply low-pass filter to velocity
|
||||
a_d = self._smoothing_factor(self._d_cutoff, dt)
|
||||
self._dx_filtered = self._exponential_smoothing(
|
||||
a_d,
|
||||
dx,
|
||||
(
|
||||
jnp.zeros_like(keypoints)
|
||||
if self._dx_filtered is None
|
||||
else self._dx_filtered
|
||||
),
|
||||
)
|
||||
|
||||
# Apply low-pass filter to position with adaptive cutoff
|
||||
a_cutoff = self._smoothing_factor(
|
||||
jnp.asarray(cutoff), dt
|
||||
) # Convert cutoff to scalar if needed
|
||||
self._x_filtered = self._exponential_smoothing(
|
||||
a_cutoff, keypoints, self._x_filtered
|
||||
)
|
||||
|
||||
# Update timestamp
|
||||
self._last_timestamp = timestamp
|
||||
|
||||
def get(self) -> TrackingPrediction:
|
||||
"""
|
||||
Get the current state of the filter.
|
||||
|
||||
Returns:
|
||||
TrackingPrediction with velocity and keypoints
|
||||
"""
|
||||
return TrackingPrediction(
|
||||
velocity=self._dx_filtered,
|
||||
keypoints=self._x_filtered,
|
||||
)
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
@dataclass(frozen=True)
|
||||
class TrackingState:
|
||||
|
||||
Reference in New Issue
Block a user