refactor: Clean up and enhance LeastMeanSquareVelocityFilter implementation

- Removed unused `reset` methods from `GenericVelocityFilter` and `LastDifferenceVelocityFilter` classes to streamline the code.
- Added a static method `from_tracking` to `LeastMeanSquareVelocityFilter` for creating instances from a `Tracking` object.
- Implemented robust error handling in the `predict` and `get` methods to ensure proper functioning with historical detections.
- Enhanced the `update` method to utilize least squares for velocity estimation, improving accuracy in tracking predictions.
- Updated class documentation to reflect changes and clarify method purposes.
This commit is contained in:
2025-05-02 12:06:05 +08:00
parent 4e78165f12
commit 46b8518a10

View File

@ -67,16 +67,6 @@ class GenericVelocityFilter(Protocol):
""" """
... # pylint: disable=unnecessary-ellipsis ... # pylint: disable=unnecessary-ellipsis
def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
"""
reset the filter state with new keypoints
Args:
keypoints: new keypoints
timestamp: timestamp of the reset
"""
... # pylint: disable=unnecessary-ellipsis
class LastDifferenceVelocityFilter(GenericVelocityFilter): class LastDifferenceVelocityFilter(GenericVelocityFilter):
""" """
@ -122,13 +112,12 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
keypoints=self._last_keypoints, keypoints=self._last_keypoints,
) )
def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
self._last_keypoints = keypoints
self._last_timestamp = timestamp
self._last_velocity = None
class LeastMeanSquareVelocityFilter(GenericVelocityFilter): class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
"""
a velocity filter that uses the least mean square method to estimate the velocity
"""
_get_historical_detections: Callable[[], Sequence[Detection]] _get_historical_detections: Callable[[], Sequence[Detection]]
""" """
get the current historical detections, assuming the detections are sorted by get the current historical detections, assuming the detections are sorted by
@ -137,11 +126,56 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
""" """
_velocity: Optional[Float[Array, "J 3"]] = None _velocity: Optional[Float[Array, "J 3"]] = None
@staticmethod
def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter":
"""
create a LeastMeanSquareVelocityFilter from a Tracking object
"""
velocity = tracking.velocity_filter.get()["velocity"]
if jnp.all(velocity == jnp.zeros_like(velocity)):
return LeastMeanSquareVelocityFilter(
get_historical_detections=lambda: tracking.historical_detections
)
else:
f = LeastMeanSquareVelocityFilter(
get_historical_detections=lambda: tracking.historical_detections
)
# pylint: disable-next=protected-access
f._velocity = velocity
return f
def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]): def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
self._get_historical_detections = get_historical_detections self._get_historical_detections = get_historical_detections
self._velocity = None
@property
def velocity(self) -> Float[Array, "J 3"]:
if self._velocity is None:
raise ValueError("Velocity not initialized")
return self._velocity
def predict(self, timestamp: datetime) -> TrackingPrediction: def predict(self, timestamp: datetime) -> TrackingPrediction:
raise NotImplementedError historical_detections = self._get_historical_detections()
if not historical_detections:
raise ValueError("No historical detections available for prediction")
# Use the latest historical detection
latest_detection = historical_detections[-1]
latest_keypoints = latest_detection.keypoints
latest_timestamp = latest_detection.timestamp
delta_t_s = (timestamp - latest_timestamp).total_seconds()
if self._velocity is None:
return TrackingPrediction(
velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints
)
else:
# Linear motion model: ẋt = xt' + Vt' · (t - t')
predicted_keypoints = latest_keypoints + self._velocity * delta_t_s
return TrackingPrediction(
velocity=self._velocity, keypoints=predicted_keypoints
)
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
def _update( def _update(
@ -152,28 +186,79 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
""" """
update measurements with least mean square method update measurements with least mean square method
""" """
raise NotImplementedError if keypoints.shape[0] < 2:
raise ValueError("Not enough measurements to estimate velocity")
# Using least squares to fit a linear model for each joint and dimension
# X = timestamps, y = keypoints
# For each joint and each dimension, we solve for velocity
n_samples = timestamps.shape[0]
n_joints = keypoints.shape[1]
# Create design matrix for linear regression
# [t, 1] for each timestamp
X = jnp.column_stack([timestamps, jnp.ones(n_samples)])
# Reshape keypoints to solve for all joints and dimensions at once
# From [N, J, 3] to [N, J*3]
keypoints_reshaped = keypoints.reshape(n_samples, -1)
# Use JAX's lstsq to solve the least squares problem
# This is more numerically stable than manually computing pseudoinverse
coefficients, _, _, _ = jnp.linalg.lstsq(X, keypoints_reshaped, rcond=None)
# Coefficients shape is [2, J*3]
# First row: velocities, Second row: intercepts
velocities = coefficients[0].reshape(n_joints, 3)
# Update velocity
self._velocity = velocities
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
historical_detections = self._get_historical_detections() historical_detections = self._get_historical_detections()
if not historical_detections:
self._velocity = jnp.zeros_like(keypoints)
return
t_0 = min(d.timestamp for d in historical_detections) t_0 = min(d.timestamp for d in historical_detections)
detections = jnp.array(
chain((d.keypoints for d in historical_detections), (keypoints,)) all_keypoints = jnp.array(
list(chain((d.keypoints for d in historical_detections), (keypoints,)))
) )
# timestamps relative to t_0 (the oldest detection timestamp)
timestamps = jnp.array( # Timestamps relative to t_0 (the oldest detection timestamp)
all_timestamps = jnp.array(
list(
chain( chain(
((d.timestamp - t_0).total_seconds() for d in historical_detections), (
(d.timestamp - t_0).total_seconds()
for d in historical_detections
),
((timestamp - t_0).total_seconds(),), ((timestamp - t_0).total_seconds(),),
) )
) )
raise NotImplementedError )
self._update(all_keypoints, all_timestamps)
def get(self) -> TrackingPrediction: def get(self) -> TrackingPrediction:
raise NotImplementedError historical_detections = self._get_historical_detections()
if not historical_detections:
raise ValueError("No historical detections available")
def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: latest_detection = historical_detections[-1]
raise NotImplementedError latest_keypoints = latest_detection.keypoints
if self._velocity is None:
return TrackingPrediction(
velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints
)
else:
return TrackingPrediction(
velocity=self._velocity, keypoints=latest_keypoints
)
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)