forked from HQU-gxy/CVTH3PE
feat: Add OneEuroFilter for adaptive keypoint smoothing
- Introduced the `OneEuroFilter` class to implement an adaptive low-pass filter for smoothing keypoint data, enhancing tracking stability during varying movement speeds. - Implemented methods for initialization, prediction, and updating of keypoints, allowing for dynamic adjustment of smoothing based on movement. - Added detailed documentation and type hints to clarify the filter's functionality and parameters. - Improved the handling of timestamps and filtering logic to ensure accurate predictions and updates.
This commit is contained in:
@ -15,6 +15,7 @@ from typing import (
|
|||||||
TypeVar,
|
TypeVar,
|
||||||
cast,
|
cast,
|
||||||
overload,
|
overload,
|
||||||
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@ -269,6 +270,151 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
|||||||
return TrackingPrediction(velocity=self._velocity, keypoints=latest_3d_pose)
|
return TrackingPrediction(velocity=self._velocity, keypoints=latest_3d_pose)
|
||||||
|
|
||||||
|
|
||||||
|
class OneEuroFilter(GenericVelocityFilter):
|
||||||
|
"""
|
||||||
|
Implementation of the 1€ filter (One Euro Filter) for smoothing keypoint data.
|
||||||
|
|
||||||
|
The 1€ filter is an adaptive low-pass filter that adjusts its cutoff frequency
|
||||||
|
based on movement speed to reduce jitter during slow movements while maintaining
|
||||||
|
responsiveness during fast movements.
|
||||||
|
|
||||||
|
Reference: https://cristal.univ-lille.fr/~casiez/1euro/
|
||||||
|
"""
|
||||||
|
|
||||||
|
_x_filtered: Float[Array, "J 3"]
|
||||||
|
_dx_filtered: Optional[Float[Array, "J 3"]] = None
|
||||||
|
_last_timestamp: datetime
|
||||||
|
_min_cutoff: float
|
||||||
|
_beta: float
|
||||||
|
_d_cutoff: float
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
keypoints: Float[Array, "J 3"],
|
||||||
|
timestamp: datetime,
|
||||||
|
min_cutoff: float = 1.0,
|
||||||
|
beta: float = 0.0,
|
||||||
|
d_cutoff: float = 1.0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the One Euro Filter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keypoints: Initial keypoints positions
|
||||||
|
timestamp: Initial timestamp
|
||||||
|
min_cutoff: Minimum cutoff frequency (lower = more smoothing)
|
||||||
|
beta: Speed coefficient (higher = less lag during fast movements)
|
||||||
|
d_cutoff: Cutoff frequency for the derivative filter
|
||||||
|
"""
|
||||||
|
self._last_timestamp = timestamp
|
||||||
|
|
||||||
|
# Filter parameters
|
||||||
|
self._min_cutoff = min_cutoff
|
||||||
|
self._beta = beta
|
||||||
|
self._d_cutoff = d_cutoff
|
||||||
|
|
||||||
|
# Filter state
|
||||||
|
self._x_filtered = keypoints # Position filter state
|
||||||
|
self._dx_filtered = None # Initially no velocity estimate
|
||||||
|
|
||||||
|
def _smoothing_factor(
|
||||||
|
self, cutoff: Union[float, Float[Array, "J"]], dt: float
|
||||||
|
) -> Union[float, Float[Array, "J"]]:
|
||||||
|
"""Calculate the smoothing factor for the low-pass filter."""
|
||||||
|
r = 2 * jnp.pi * cutoff * dt
|
||||||
|
return r / (r + 1)
|
||||||
|
|
||||||
|
def _exponential_smoothing(
|
||||||
|
self,
|
||||||
|
a: Union[float, Float[Array, "J"]],
|
||||||
|
x: Float[Array, "J 3"],
|
||||||
|
x_prev: Float[Array, "J 3"],
|
||||||
|
) -> Float[Array, "J 3"]:
|
||||||
|
"""Apply exponential smoothing to the input."""
|
||||||
|
return a * x + (1 - a) * x_prev
|
||||||
|
|
||||||
|
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
||||||
|
"""
|
||||||
|
Predict keypoints position at a given timestamp.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestamp: Timestamp for prediction
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TrackingPrediction with velocity and keypoints
|
||||||
|
"""
|
||||||
|
dt = (timestamp - self._last_timestamp).total_seconds()
|
||||||
|
|
||||||
|
if self._dx_filtered is None:
|
||||||
|
return TrackingPrediction(
|
||||||
|
velocity=None,
|
||||||
|
keypoints=self._x_filtered,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
predicted_keypoints = self._x_filtered + self._dx_filtered * dt
|
||||||
|
return TrackingPrediction(
|
||||||
|
velocity=self._dx_filtered,
|
||||||
|
keypoints=predicted_keypoints,
|
||||||
|
)
|
||||||
|
|
||||||
|
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||||
|
"""
|
||||||
|
Update the filter with new measurements.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keypoints: New keypoint measurements
|
||||||
|
timestamp: Timestamp of the measurements
|
||||||
|
"""
|
||||||
|
dt = (timestamp - self._last_timestamp).total_seconds()
|
||||||
|
if dt <= 0:
|
||||||
|
raise ValueError("invalid timestamp")
|
||||||
|
|
||||||
|
# Compute velocity from current input and filtered state
|
||||||
|
dx = (
|
||||||
|
(keypoints - self._x_filtered) / dt if dt > 0 else jnp.zeros_like(keypoints)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine cutoff frequency based on movement speed
|
||||||
|
cutoff = self._min_cutoff + self._beta * jnp.linalg.norm(
|
||||||
|
dx, axis=-1, keepdims=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply low-pass filter to velocity
|
||||||
|
a_d = self._smoothing_factor(self._d_cutoff, dt)
|
||||||
|
self._dx_filtered = self._exponential_smoothing(
|
||||||
|
a_d,
|
||||||
|
dx,
|
||||||
|
(
|
||||||
|
jnp.zeros_like(keypoints)
|
||||||
|
if self._dx_filtered is None
|
||||||
|
else self._dx_filtered
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply low-pass filter to position with adaptive cutoff
|
||||||
|
a_cutoff = self._smoothing_factor(
|
||||||
|
jnp.asarray(cutoff), dt
|
||||||
|
) # Convert cutoff to scalar if needed
|
||||||
|
self._x_filtered = self._exponential_smoothing(
|
||||||
|
a_cutoff, keypoints, self._x_filtered
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update timestamp
|
||||||
|
self._last_timestamp = timestamp
|
||||||
|
|
||||||
|
def get(self) -> TrackingPrediction:
|
||||||
|
"""
|
||||||
|
Get the current state of the filter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TrackingPrediction with velocity and keypoints
|
||||||
|
"""
|
||||||
|
return TrackingPrediction(
|
||||||
|
velocity=self._dx_filtered,
|
||||||
|
keypoints=self._x_filtered,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class TrackingState:
|
class TrackingState:
|
||||||
|
|||||||
Reference in New Issue
Block a user