forked from HQU-gxy/CVTH3PE
refactor: Clean up and enhance LeastMeanSquareVelocityFilter implementation
- Removed unused `reset` methods from `GenericVelocityFilter` and `LastDifferenceVelocityFilter` classes to streamline the code. - Added a static method `from_tracking` to `LeastMeanSquareVelocityFilter` for creating instances from a `Tracking` object. - Implemented robust error handling in the `predict` and `get` methods to ensure proper functioning with historical detections. - Enhanced the `update` method to utilize least squares for velocity estimation, improving accuracy in tracking predictions. - Updated class documentation to reflect changes and clarify method purposes.
This commit is contained in:
@ -67,16 +67,6 @@ class GenericVelocityFilter(Protocol):
|
|||||||
"""
|
"""
|
||||||
... # pylint: disable=unnecessary-ellipsis
|
... # pylint: disable=unnecessary-ellipsis
|
||||||
|
|
||||||
def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
|
||||||
"""
|
|
||||||
reset the filter state with new keypoints
|
|
||||||
|
|
||||||
Args:
|
|
||||||
keypoints: new keypoints
|
|
||||||
timestamp: timestamp of the reset
|
|
||||||
"""
|
|
||||||
... # pylint: disable=unnecessary-ellipsis
|
|
||||||
|
|
||||||
|
|
||||||
class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
||||||
"""
|
"""
|
||||||
@ -122,13 +112,12 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
|||||||
keypoints=self._last_keypoints,
|
keypoints=self._last_keypoints,
|
||||||
)
|
)
|
||||||
|
|
||||||
def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
|
||||||
self._last_keypoints = keypoints
|
|
||||||
self._last_timestamp = timestamp
|
|
||||||
self._last_velocity = None
|
|
||||||
|
|
||||||
|
|
||||||
class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
||||||
|
"""
|
||||||
|
a velocity filter that uses the least mean square method to estimate the velocity
|
||||||
|
"""
|
||||||
|
|
||||||
_get_historical_detections: Callable[[], Sequence[Detection]]
|
_get_historical_detections: Callable[[], Sequence[Detection]]
|
||||||
"""
|
"""
|
||||||
get the current historical detections, assuming the detections are sorted by
|
get the current historical detections, assuming the detections are sorted by
|
||||||
@ -137,11 +126,56 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
|||||||
"""
|
"""
|
||||||
_velocity: Optional[Float[Array, "J 3"]] = None
|
_velocity: Optional[Float[Array, "J 3"]] = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter":
|
||||||
|
"""
|
||||||
|
create a LeastMeanSquareVelocityFilter from a Tracking object
|
||||||
|
"""
|
||||||
|
velocity = tracking.velocity_filter.get()["velocity"]
|
||||||
|
if jnp.all(velocity == jnp.zeros_like(velocity)):
|
||||||
|
return LeastMeanSquareVelocityFilter(
|
||||||
|
get_historical_detections=lambda: tracking.historical_detections
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
f = LeastMeanSquareVelocityFilter(
|
||||||
|
get_historical_detections=lambda: tracking.historical_detections
|
||||||
|
)
|
||||||
|
# pylint: disable-next=protected-access
|
||||||
|
f._velocity = velocity
|
||||||
|
return f
|
||||||
|
|
||||||
def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
|
def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
|
||||||
self._get_historical_detections = get_historical_detections
|
self._get_historical_detections = get_historical_detections
|
||||||
|
self._velocity = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def velocity(self) -> Float[Array, "J 3"]:
|
||||||
|
if self._velocity is None:
|
||||||
|
raise ValueError("Velocity not initialized")
|
||||||
|
return self._velocity
|
||||||
|
|
||||||
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
||||||
raise NotImplementedError
|
historical_detections = self._get_historical_detections()
|
||||||
|
if not historical_detections:
|
||||||
|
raise ValueError("No historical detections available for prediction")
|
||||||
|
|
||||||
|
# Use the latest historical detection
|
||||||
|
latest_detection = historical_detections[-1]
|
||||||
|
latest_keypoints = latest_detection.keypoints
|
||||||
|
latest_timestamp = latest_detection.timestamp
|
||||||
|
|
||||||
|
delta_t_s = (timestamp - latest_timestamp).total_seconds()
|
||||||
|
|
||||||
|
if self._velocity is None:
|
||||||
|
return TrackingPrediction(
|
||||||
|
velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Linear motion model: ẋt = xt' + Vt' · (t - t')
|
||||||
|
predicted_keypoints = latest_keypoints + self._velocity * delta_t_s
|
||||||
|
return TrackingPrediction(
|
||||||
|
velocity=self._velocity, keypoints=predicted_keypoints
|
||||||
|
)
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
def _update(
|
def _update(
|
||||||
@ -152,28 +186,79 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
|||||||
"""
|
"""
|
||||||
update measurements with least mean square method
|
update measurements with least mean square method
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
if keypoints.shape[0] < 2:
|
||||||
|
raise ValueError("Not enough measurements to estimate velocity")
|
||||||
|
|
||||||
|
# Using least squares to fit a linear model for each joint and dimension
|
||||||
|
# X = timestamps, y = keypoints
|
||||||
|
# For each joint and each dimension, we solve for velocity
|
||||||
|
|
||||||
|
n_samples = timestamps.shape[0]
|
||||||
|
n_joints = keypoints.shape[1]
|
||||||
|
|
||||||
|
# Create design matrix for linear regression
|
||||||
|
# [t, 1] for each timestamp
|
||||||
|
X = jnp.column_stack([timestamps, jnp.ones(n_samples)])
|
||||||
|
|
||||||
|
# Reshape keypoints to solve for all joints and dimensions at once
|
||||||
|
# From [N, J, 3] to [N, J*3]
|
||||||
|
keypoints_reshaped = keypoints.reshape(n_samples, -1)
|
||||||
|
|
||||||
|
# Use JAX's lstsq to solve the least squares problem
|
||||||
|
# This is more numerically stable than manually computing pseudoinverse
|
||||||
|
coefficients, _, _, _ = jnp.linalg.lstsq(X, keypoints_reshaped, rcond=None)
|
||||||
|
|
||||||
|
# Coefficients shape is [2, J*3]
|
||||||
|
# First row: velocities, Second row: intercepts
|
||||||
|
velocities = coefficients[0].reshape(n_joints, 3)
|
||||||
|
|
||||||
|
# Update velocity
|
||||||
|
self._velocity = velocities
|
||||||
|
|
||||||
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||||
historical_detections = self._get_historical_detections()
|
historical_detections = self._get_historical_detections()
|
||||||
|
|
||||||
|
if not historical_detections:
|
||||||
|
self._velocity = jnp.zeros_like(keypoints)
|
||||||
|
return
|
||||||
|
|
||||||
t_0 = min(d.timestamp for d in historical_detections)
|
t_0 = min(d.timestamp for d in historical_detections)
|
||||||
detections = jnp.array(
|
|
||||||
chain((d.keypoints for d in historical_detections), (keypoints,))
|
all_keypoints = jnp.array(
|
||||||
|
list(chain((d.keypoints for d in historical_detections), (keypoints,)))
|
||||||
)
|
)
|
||||||
# timestamps relative to t_0 (the oldest detection timestamp)
|
|
||||||
timestamps = jnp.array(
|
# Timestamps relative to t_0 (the oldest detection timestamp)
|
||||||
chain(
|
all_timestamps = jnp.array(
|
||||||
((d.timestamp - t_0).total_seconds() for d in historical_detections),
|
list(
|
||||||
((timestamp - t_0).total_seconds(),),
|
chain(
|
||||||
|
(
|
||||||
|
(d.timestamp - t_0).total_seconds()
|
||||||
|
for d in historical_detections
|
||||||
|
),
|
||||||
|
((timestamp - t_0).total_seconds(),),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
raise NotImplementedError
|
|
||||||
|
self._update(all_keypoints, all_timestamps)
|
||||||
|
|
||||||
def get(self) -> TrackingPrediction:
|
def get(self) -> TrackingPrediction:
|
||||||
raise NotImplementedError
|
historical_detections = self._get_historical_detections()
|
||||||
|
if not historical_detections:
|
||||||
|
raise ValueError("No historical detections available")
|
||||||
|
|
||||||
def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
latest_detection = historical_detections[-1]
|
||||||
raise NotImplementedError
|
latest_keypoints = latest_detection.keypoints
|
||||||
|
|
||||||
|
if self._velocity is None:
|
||||||
|
return TrackingPrediction(
|
||||||
|
velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return TrackingPrediction(
|
||||||
|
velocity=self._velocity, keypoints=latest_keypoints
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
|
|||||||
Reference in New Issue
Block a user