refactor: Revamp LeastMeanSquareVelocityFilter to utilize historical 3D poses

- Replaced the historical detections mechanism with deques for managing historical 3D poses and timestamps, enhancing performance and memory efficiency.
- Updated the constructor to accept historical data directly, ensuring proper initialization and sorting of poses and timestamps.
- Refined the `predict` and `update` methods to work with the new data structure, improving clarity and functionality.
- Enhanced error handling to ensure robustness when no historical data is available for predictions.
This commit is contained in:
2025-05-03 14:31:59 +08:00
parent c31cc4e7bf
commit d2c1c8d624

View File

@ -1,4 +1,5 @@
import weakref
from collections import deque
from dataclasses import dataclass
from datetime import datetime, timedelta
from itertools import chain
@ -145,73 +146,55 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
a velocity filter that uses the least mean square method to estimate the velocity
"""
_get_historical_detections: Callable[[], Sequence[Detection]]
"""
get the current historical detections, assuming the detections are sorted by
timestamp incrementally (i.e. index 0 is the oldest detection, index -1 is
the newest detection)
"""
_historical_3d_poses: deque[Float[Array, "J 3"]]
_historical_timestamps: deque[datetime]
_velocity: Optional[Float[Array, "J 3"]] = None
_max_samples: int
def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
self._get_historical_detections = get_historical_detections
self._velocity = None
@staticmethod
def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter":
"""
create a LeastMeanSquareVelocityFilter from a Tracking object
Note that this function is using a weak reference to the tracking object,
so that the tracking object can be garbage collected if there are no other
references to it.
"""
# Create a weak reference to avoid circular references
# https://docs.python.org/3/library/weakref.html
tracking_ref = weakref.ref(tracking)
# Create a getter function that uses the weak reference
def get_historical_detections() -> Sequence[Detection]:
tr = tracking_ref()
if tr is None:
return [] # Return empty list if tracking has been garbage collected
return tr.state.historical_detections
velocity = tracking.velocity_filter.get()["velocity"]
if velocity is None:
return LeastMeanSquareVelocityFilter(
get_historical_detections=get_historical_detections
)
def __init__(
self,
historical_3d_poses: Sequence[Float[Array, "J 3"]],
historical_timestamps: Sequence[datetime],
max_samples: int = 10,
):
assert len(historical_3d_poses) == len(historical_timestamps)
temp = zip(historical_3d_poses, historical_timestamps)
temp_sorted = sorted(temp, key=lambda x: x[1])
self._historical_3d_poses = deque(
map(lambda x: x[0], temp_sorted), maxlen=max_samples
)
self._historical_timestamps = deque(
map(lambda x: x[1], temp_sorted), maxlen=max_samples
)
self._max_samples = max_samples
if len(self._historical_3d_poses) < 2:
self._velocity = None
else:
f = LeastMeanSquareVelocityFilter(
get_historical_detections=get_historical_detections
self._update(
jnp.array(self._historical_3d_poses),
jnp.array(self._historical_timestamps),
)
# pylint: disable-next=protected-access
f._velocity = velocity
return f
def predict(self, timestamp: datetime) -> TrackingPrediction:
historical_detections = self._get_historical_detections()
if not historical_detections:
raise ValueError("No historical detections available for prediction")
if not self._historical_3d_poses:
raise ValueError("No historical 3D poses available for prediction")
# Use the latest historical detection
latest_detection = historical_detections[-1]
latest_keypoints = latest_detection.keypoints
latest_timestamp = latest_detection.timestamp
# use the latest historical detection
latest_3d_pose = self._historical_3d_poses[-1]
latest_timestamp = self._historical_timestamps[-1]
delta_t_s = (timestamp - latest_timestamp).total_seconds()
if self._velocity is None:
return TrackingPrediction(
velocity=None,
keypoints=latest_keypoints,
keypoints=latest_3d_pose,
)
else:
# Linear motion model: ẋt = xt' + Vt' · (t - t')
predicted_keypoints = latest_keypoints + self._velocity * delta_t_s
predicted_3d_pose = latest_3d_pose + self._velocity * delta_t_s
return TrackingPrediction(
velocity=self._velocity, keypoints=predicted_keypoints
velocity=self._velocity, keypoints=predicted_3d_pose
)
@jaxtyped(typechecker=beartype)
@ -253,47 +236,37 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
self._velocity = velocities
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
historical_detections = self._get_historical_detections()
last_timestamp = self._historical_timestamps[-1]
assert last_timestamp <= timestamp
if not historical_detections:
self._velocity = jnp.zeros_like(keypoints)
return
# deque would manage the maxlen automatically
self._historical_3d_poses.append(keypoints)
self._historical_timestamps.append(timestamp)
t_0 = min(d.timestamp for d in historical_detections)
t_0 = self._historical_timestamps[0]
all_keypoints = jnp.array(self._historical_3d_poses)
all_keypoints = jnp.array(
list(chain((d.keypoints for d in historical_detections), (keypoints,)))
)
def timestamp_to_seconds(timestamp: datetime) -> float:
assert t_0 <= timestamp
return (timestamp - t_0).total_seconds()
# Timestamps relative to t_0 (the oldest detection timestamp)
# timestamps relative to t_0 (the oldest detection timestamp)
all_timestamps = jnp.array(
list(
chain(
(
(d.timestamp - t_0).total_seconds()
for d in historical_detections
),
((timestamp - t_0).total_seconds(),),
)
)
map(timestamp_to_seconds, self._historical_timestamps)
)
self._update(all_keypoints, all_timestamps)
def get(self) -> TrackingPrediction:
historical_detections = self._get_historical_detections()
if not historical_detections:
raise ValueError("No historical detections available")
if not self._historical_3d_poses:
raise ValueError("No historical 3D poses available")
latest_detection = historical_detections[-1]
latest_keypoints = latest_detection.keypoints
latest_3d_pose = self._historical_3d_poses[-1]
if self._velocity is None:
return TrackingPrediction(velocity=None, keypoints=latest_keypoints)
return TrackingPrediction(velocity=None, keypoints=latest_3d_pose)
else:
return TrackingPrediction(
velocity=self._velocity, keypoints=latest_keypoints
)
return TrackingPrediction(velocity=self._velocity, keypoints=latest_3d_pose)
@jaxtyped(typechecker=beartype)
@ -393,7 +366,7 @@ class Tracking:
"""
# pylint: disable-next=unsubscriptable-object
if (vel := self.velocity_filter.get()["velocity"]) is None:
raise ValueError("Velocity is not available")
return jnp.zeros_like(self.state.keypoints)
else:
return vel