forked from HQU-gxy/CVTH3PE
refactor: Enhance OneEuroFilter with type hints and error handling improvements
- Added overloads for the `_smoothing_factor` method to improve type hinting for different input types. - Enhanced error handling in the timestamp validation to provide clearer feedback when an invalid timestamp is encountered. - Streamlined the calculation of the filtered velocity by simplifying the logic in the `update` method. - Improved code organization with additional type annotations for better clarity and maintainability.
This commit is contained in:
@ -317,6 +317,15 @@ class OneEuroFilter(GenericVelocityFilter):
|
||||
self._x_filtered = keypoints # Position filter state
|
||||
self._dx_filtered = None # Initially no velocity estimate
|
||||
|
||||
@overload
|
||||
def _smoothing_factor(self, cutoff: float, dt: float) -> float: ...
|
||||
|
||||
@overload
|
||||
def _smoothing_factor(
|
||||
self, cutoff: Float[Array, "J"], dt: float
|
||||
) -> Float[Array, "J"]: ...
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def _smoothing_factor(
|
||||
self, cutoff: Union[float, Float[Array, "J"]], dt: float
|
||||
) -> Union[float, Float[Array, "J"]]:
|
||||
@ -324,6 +333,7 @@ class OneEuroFilter(GenericVelocityFilter):
|
||||
r = 2 * jnp.pi * cutoff * dt
|
||||
return r / (r + 1)
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def _exponential_smoothing(
|
||||
self,
|
||||
a: Union[float, Float[Array, "J"]],
|
||||
@ -367,13 +377,12 @@ class OneEuroFilter(GenericVelocityFilter):
|
||||
"""
|
||||
dt = (timestamp - self._last_timestamp).total_seconds()
|
||||
if dt <= 0:
|
||||
raise ValueError("invalid timestamp")
|
||||
|
||||
# Compute velocity from current input and filtered state
|
||||
dx = (
|
||||
(keypoints - self._x_filtered) / dt if dt > 0 else jnp.zeros_like(keypoints)
|
||||
raise ValueError(
|
||||
f"new timestamp is not greater than the last timestamp; expecting: {timestamp} > {self._last_timestamp}"
|
||||
)
|
||||
|
||||
dx = (keypoints - self._x_filtered) / dt
|
||||
|
||||
# Determine cutoff frequency based on movement speed
|
||||
cutoff = self._min_cutoff + self._beta * jnp.linalg.norm(
|
||||
dx, axis=-1, keepdims=True
|
||||
@ -392,9 +401,7 @@ class OneEuroFilter(GenericVelocityFilter):
|
||||
)
|
||||
|
||||
# Apply low-pass filter to position with adaptive cutoff
|
||||
a_cutoff = self._smoothing_factor(
|
||||
jnp.asarray(cutoff), dt
|
||||
) # Convert cutoff to scalar if needed
|
||||
a_cutoff = self._smoothing_factor(jnp.asarray(cutoff), dt)
|
||||
self._x_filtered = self._exponential_smoothing(
|
||||
a_cutoff, keypoints, self._x_filtered
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user