forked from HQU-gxy/CVTH3PE
refactor: Clean up and enhance LeastMeanSquareVelocityFilter implementation
- Removed unused `reset` methods from `GenericVelocityFilter` and `LastDifferenceVelocityFilter` classes to streamline the code. - Added a static method `from_tracking` to `LeastMeanSquareVelocityFilter` for creating instances from a `Tracking` object. - Implemented robust error handling in the `predict` and `get` methods to ensure proper functioning with historical detections. - Enhanced the `update` method to utilize least squares for velocity estimation, improving accuracy in tracking predictions. - Updated class documentation to reflect changes and clarify method purposes.
This commit is contained in:
@ -67,16 +67,6 @@ class GenericVelocityFilter(Protocol):
|
||||
"""
|
||||
... # pylint: disable=unnecessary-ellipsis
|
||||
|
||||
def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||
"""
|
||||
reset the filter state with new keypoints
|
||||
|
||||
Args:
|
||||
keypoints: new keypoints
|
||||
timestamp: timestamp of the reset
|
||||
"""
|
||||
... # pylint: disable=unnecessary-ellipsis
|
||||
|
||||
|
||||
class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
||||
"""
|
||||
@ -122,13 +112,12 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
||||
keypoints=self._last_keypoints,
|
||||
)
|
||||
|
||||
def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||
self._last_keypoints = keypoints
|
||||
self._last_timestamp = timestamp
|
||||
self._last_velocity = None
|
||||
|
||||
|
||||
class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
||||
"""
|
||||
a velocity filter that uses the least mean square method to estimate the velocity
|
||||
"""
|
||||
|
||||
_get_historical_detections: Callable[[], Sequence[Detection]]
|
||||
"""
|
||||
get the current historical detections, assuming the detections are sorted by
|
||||
@ -137,11 +126,56 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
||||
"""
|
||||
_velocity: Optional[Float[Array, "J 3"]] = None
|
||||
|
||||
@staticmethod
|
||||
def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter":
|
||||
"""
|
||||
create a LeastMeanSquareVelocityFilter from a Tracking object
|
||||
"""
|
||||
velocity = tracking.velocity_filter.get()["velocity"]
|
||||
if jnp.all(velocity == jnp.zeros_like(velocity)):
|
||||
return LeastMeanSquareVelocityFilter(
|
||||
get_historical_detections=lambda: tracking.historical_detections
|
||||
)
|
||||
else:
|
||||
f = LeastMeanSquareVelocityFilter(
|
||||
get_historical_detections=lambda: tracking.historical_detections
|
||||
)
|
||||
# pylint: disable-next=protected-access
|
||||
f._velocity = velocity
|
||||
return f
|
||||
|
||||
def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
|
||||
self._get_historical_detections = get_historical_detections
|
||||
self._velocity = None
|
||||
|
||||
@property
|
||||
def velocity(self) -> Float[Array, "J 3"]:
|
||||
if self._velocity is None:
|
||||
raise ValueError("Velocity not initialized")
|
||||
return self._velocity
|
||||
|
||||
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
||||
raise NotImplementedError
|
||||
historical_detections = self._get_historical_detections()
|
||||
if not historical_detections:
|
||||
raise ValueError("No historical detections available for prediction")
|
||||
|
||||
# Use the latest historical detection
|
||||
latest_detection = historical_detections[-1]
|
||||
latest_keypoints = latest_detection.keypoints
|
||||
latest_timestamp = latest_detection.timestamp
|
||||
|
||||
delta_t_s = (timestamp - latest_timestamp).total_seconds()
|
||||
|
||||
if self._velocity is None:
|
||||
return TrackingPrediction(
|
||||
velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints
|
||||
)
|
||||
else:
|
||||
# Linear motion model: ẋt = xt' + Vt' · (t - t')
|
||||
predicted_keypoints = latest_keypoints + self._velocity * delta_t_s
|
||||
return TrackingPrediction(
|
||||
velocity=self._velocity, keypoints=predicted_keypoints
|
||||
)
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def _update(
|
||||
@ -152,28 +186,79 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
||||
"""
|
||||
update measurements with least mean square method
|
||||
"""
|
||||
raise NotImplementedError
|
||||
if keypoints.shape[0] < 2:
|
||||
raise ValueError("Not enough measurements to estimate velocity")
|
||||
|
||||
# Using least squares to fit a linear model for each joint and dimension
|
||||
# X = timestamps, y = keypoints
|
||||
# For each joint and each dimension, we solve for velocity
|
||||
|
||||
n_samples = timestamps.shape[0]
|
||||
n_joints = keypoints.shape[1]
|
||||
|
||||
# Create design matrix for linear regression
|
||||
# [t, 1] for each timestamp
|
||||
X = jnp.column_stack([timestamps, jnp.ones(n_samples)])
|
||||
|
||||
# Reshape keypoints to solve for all joints and dimensions at once
|
||||
# From [N, J, 3] to [N, J*3]
|
||||
keypoints_reshaped = keypoints.reshape(n_samples, -1)
|
||||
|
||||
# Use JAX's lstsq to solve the least squares problem
|
||||
# This is more numerically stable than manually computing pseudoinverse
|
||||
coefficients, _, _, _ = jnp.linalg.lstsq(X, keypoints_reshaped, rcond=None)
|
||||
|
||||
# Coefficients shape is [2, J*3]
|
||||
# First row: velocities, Second row: intercepts
|
||||
velocities = coefficients[0].reshape(n_joints, 3)
|
||||
|
||||
# Update velocity
|
||||
self._velocity = velocities
|
||||
|
||||
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||
historical_detections = self._get_historical_detections()
|
||||
|
||||
if not historical_detections:
|
||||
self._velocity = jnp.zeros_like(keypoints)
|
||||
return
|
||||
|
||||
t_0 = min(d.timestamp for d in historical_detections)
|
||||
detections = jnp.array(
|
||||
chain((d.keypoints for d in historical_detections), (keypoints,))
|
||||
|
||||
all_keypoints = jnp.array(
|
||||
list(chain((d.keypoints for d in historical_detections), (keypoints,)))
|
||||
)
|
||||
# timestamps relative to t_0 (the oldest detection timestamp)
|
||||
timestamps = jnp.array(
|
||||
|
||||
# Timestamps relative to t_0 (the oldest detection timestamp)
|
||||
all_timestamps = jnp.array(
|
||||
list(
|
||||
chain(
|
||||
((d.timestamp - t_0).total_seconds() for d in historical_detections),
|
||||
(
|
||||
(d.timestamp - t_0).total_seconds()
|
||||
for d in historical_detections
|
||||
),
|
||||
((timestamp - t_0).total_seconds(),),
|
||||
)
|
||||
)
|
||||
raise NotImplementedError
|
||||
)
|
||||
|
||||
self._update(all_keypoints, all_timestamps)
|
||||
|
||||
def get(self) -> TrackingPrediction:
|
||||
raise NotImplementedError
|
||||
historical_detections = self._get_historical_detections()
|
||||
if not historical_detections:
|
||||
raise ValueError("No historical detections available")
|
||||
|
||||
def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||
raise NotImplementedError
|
||||
latest_detection = historical_detections[-1]
|
||||
latest_keypoints = latest_detection.keypoints
|
||||
|
||||
if self._velocity is None:
|
||||
return TrackingPrediction(
|
||||
velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints
|
||||
)
|
||||
else:
|
||||
return TrackingPrediction(
|
||||
velocity=self._velocity, keypoints=latest_keypoints
|
||||
)
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
|
||||
Reference in New Issue
Block a user