1
0
forked from HQU-gxy/CVTH3PE

refactor: Enhance OneEuroFilter with type hints and error handling improvements

- Added overloads for the `_smoothing_factor` method to improve type hinting for different input types.
- Enhanced error handling in the timestamp validation to provide clearer feedback when an invalid timestamp is encountered.
- Streamlined the calculation of the filtered velocity by simplifying the logic in the `update` method.
- Improved code organization with additional type annotations for better clarity and maintainability.
This commit is contained in:
2025-05-03 15:12:06 +08:00
parent 4a5cfde245
commit 20b2cf59f2

View File

@ -317,6 +317,15 @@ class OneEuroFilter(GenericVelocityFilter):
self._x_filtered = keypoints # Position filter state self._x_filtered = keypoints # Position filter state
self._dx_filtered = None # Initially no velocity estimate self._dx_filtered = None # Initially no velocity estimate
@overload
def _smoothing_factor(self, cutoff: float, dt: float) -> float: ...
@overload
def _smoothing_factor(
self, cutoff: Float[Array, "J"], dt: float
) -> Float[Array, "J"]: ...
@jaxtyped(typechecker=beartype)
def _smoothing_factor( def _smoothing_factor(
self, cutoff: Union[float, Float[Array, "J"]], dt: float self, cutoff: Union[float, Float[Array, "J"]], dt: float
) -> Union[float, Float[Array, "J"]]: ) -> Union[float, Float[Array, "J"]]:
@ -324,6 +333,7 @@ class OneEuroFilter(GenericVelocityFilter):
r = 2 * jnp.pi * cutoff * dt r = 2 * jnp.pi * cutoff * dt
return r / (r + 1) return r / (r + 1)
@jaxtyped(typechecker=beartype)
def _exponential_smoothing( def _exponential_smoothing(
self, self,
a: Union[float, Float[Array, "J"]], a: Union[float, Float[Array, "J"]],
@ -367,12 +377,11 @@ class OneEuroFilter(GenericVelocityFilter):
""" """
dt = (timestamp - self._last_timestamp).total_seconds() dt = (timestamp - self._last_timestamp).total_seconds()
if dt <= 0: if dt <= 0:
raise ValueError("invalid timestamp") raise ValueError(
f"new timestamp is not greater than the last timestamp; expecting: {timestamp} > {self._last_timestamp}"
)
# Compute velocity from current input and filtered state dx = (keypoints - self._x_filtered) / dt
dx = (
(keypoints - self._x_filtered) / dt if dt > 0 else jnp.zeros_like(keypoints)
)
# Determine cutoff frequency based on movement speed # Determine cutoff frequency based on movement speed
cutoff = self._min_cutoff + self._beta * jnp.linalg.norm( cutoff = self._min_cutoff + self._beta * jnp.linalg.norm(
@ -392,9 +401,7 @@ class OneEuroFilter(GenericVelocityFilter):
) )
# Apply low-pass filter to position with adaptive cutoff # Apply low-pass filter to position with adaptive cutoff
a_cutoff = self._smoothing_factor( a_cutoff = self._smoothing_factor(jnp.asarray(cutoff), dt)
jnp.asarray(cutoff), dt
) # Convert cutoff to scalar if needed
self._x_filtered = self._exponential_smoothing( self._x_filtered = self._exponential_smoothing(
a_cutoff, keypoints, self._x_filtered a_cutoff, keypoints, self._x_filtered
) )