refactor: Enhance OneEuroFilter with type hints and error handling improvements
- Added overloads for the `_smoothing_factor` method to improve type hinting for different input types. - Enhanced error handling in the timestamp validation to provide clearer feedback when an invalid timestamp is encountered. - Streamlined the calculation of the filtered velocity by simplifying the logic in the `update` method. - Improved code organization with additional type annotations for better clarity and maintainability.
This commit is contained in:
@ -317,6 +317,15 @@ class OneEuroFilter(GenericVelocityFilter):
|
|||||||
self._x_filtered = keypoints # Position filter state
|
self._x_filtered = keypoints # Position filter state
|
||||||
self._dx_filtered = None # Initially no velocity estimate
|
self._dx_filtered = None # Initially no velocity estimate
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def _smoothing_factor(self, cutoff: float, dt: float) -> float: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def _smoothing_factor(
|
||||||
|
self, cutoff: Float[Array, "J"], dt: float
|
||||||
|
) -> Float[Array, "J"]: ...
|
||||||
|
|
||||||
|
@jaxtyped(typechecker=beartype)
|
||||||
def _smoothing_factor(
|
def _smoothing_factor(
|
||||||
self, cutoff: Union[float, Float[Array, "J"]], dt: float
|
self, cutoff: Union[float, Float[Array, "J"]], dt: float
|
||||||
) -> Union[float, Float[Array, "J"]]:
|
) -> Union[float, Float[Array, "J"]]:
|
||||||
@ -324,6 +333,7 @@ class OneEuroFilter(GenericVelocityFilter):
|
|||||||
r = 2 * jnp.pi * cutoff * dt
|
r = 2 * jnp.pi * cutoff * dt
|
||||||
return r / (r + 1)
|
return r / (r + 1)
|
||||||
|
|
||||||
|
@jaxtyped(typechecker=beartype)
|
||||||
def _exponential_smoothing(
|
def _exponential_smoothing(
|
||||||
self,
|
self,
|
||||||
a: Union[float, Float[Array, "J"]],
|
a: Union[float, Float[Array, "J"]],
|
||||||
@ -367,12 +377,11 @@ class OneEuroFilter(GenericVelocityFilter):
|
|||||||
"""
|
"""
|
||||||
dt = (timestamp - self._last_timestamp).total_seconds()
|
dt = (timestamp - self._last_timestamp).total_seconds()
|
||||||
if dt <= 0:
|
if dt <= 0:
|
||||||
raise ValueError("invalid timestamp")
|
raise ValueError(
|
||||||
|
f"new timestamp is not greater than the last timestamp; expecting: {timestamp} > {self._last_timestamp}"
|
||||||
|
)
|
||||||
|
|
||||||
# Compute velocity from current input and filtered state
|
dx = (keypoints - self._x_filtered) / dt
|
||||||
dx = (
|
|
||||||
(keypoints - self._x_filtered) / dt if dt > 0 else jnp.zeros_like(keypoints)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine cutoff frequency based on movement speed
|
# Determine cutoff frequency based on movement speed
|
||||||
cutoff = self._min_cutoff + self._beta * jnp.linalg.norm(
|
cutoff = self._min_cutoff + self._beta * jnp.linalg.norm(
|
||||||
@ -392,9 +401,7 @@ class OneEuroFilter(GenericVelocityFilter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Apply low-pass filter to position with adaptive cutoff
|
# Apply low-pass filter to position with adaptive cutoff
|
||||||
a_cutoff = self._smoothing_factor(
|
a_cutoff = self._smoothing_factor(jnp.asarray(cutoff), dt)
|
||||||
jnp.asarray(cutoff), dt
|
|
||||||
) # Convert cutoff to scalar if needed
|
|
||||||
self._x_filtered = self._exponential_smoothing(
|
self._x_filtered = self._exponential_smoothing(
|
||||||
a_cutoff, keypoints, self._x_filtered
|
a_cutoff, keypoints, self._x_filtered
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user