refactor: Revamp LeastMeanSquareVelocityFilter to utilize historical 3D poses
- Replaced the historical detections mechanism with deques for managing historical 3D poses and timestamps, enhancing performance and memory efficiency. - Updated the constructor to accept historical data directly, ensuring proper initialization and sorting of poses and timestamps. - Refined the `predict` and `update` methods to work with the new data structure, improving clarity and functionality. - Enhanced error handling to ensure robustness when no historical data is available for predictions.
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
import weakref
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from itertools import chain
|
||||
@ -145,73 +146,55 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
||||
a velocity filter that uses the least mean square method to estimate the velocity
|
||||
"""
|
||||
|
||||
_get_historical_detections: Callable[[], Sequence[Detection]]
|
||||
"""
|
||||
get the current historical detections, assuming the detections are sorted by
|
||||
timestamp incrementally (i.e. index 0 is the oldest detection, index -1 is
|
||||
the newest detection)
|
||||
"""
|
||||
_historical_3d_poses: deque[Float[Array, "J 3"]]
|
||||
_historical_timestamps: deque[datetime]
|
||||
_velocity: Optional[Float[Array, "J 3"]] = None
|
||||
_max_samples: int
|
||||
|
||||
def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
|
||||
self._get_historical_detections = get_historical_detections
|
||||
def __init__(
|
||||
self,
|
||||
historical_3d_poses: Sequence[Float[Array, "J 3"]],
|
||||
historical_timestamps: Sequence[datetime],
|
||||
max_samples: int = 10,
|
||||
):
|
||||
assert len(historical_3d_poses) == len(historical_timestamps)
|
||||
temp = zip(historical_3d_poses, historical_timestamps)
|
||||
temp_sorted = sorted(temp, key=lambda x: x[1])
|
||||
self._historical_3d_poses = deque(
|
||||
map(lambda x: x[0], temp_sorted), maxlen=max_samples
|
||||
)
|
||||
self._historical_timestamps = deque(
|
||||
map(lambda x: x[1], temp_sorted), maxlen=max_samples
|
||||
)
|
||||
self._max_samples = max_samples
|
||||
if len(self._historical_3d_poses) < 2:
|
||||
self._velocity = None
|
||||
|
||||
@staticmethod
|
||||
def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter":
|
||||
"""
|
||||
create a LeastMeanSquareVelocityFilter from a Tracking object
|
||||
|
||||
Note that this function is using a weak reference to the tracking object,
|
||||
so that the tracking object can be garbage collected if there are no other
|
||||
references to it.
|
||||
"""
|
||||
# Create a weak reference to avoid circular references
|
||||
# https://docs.python.org/3/library/weakref.html
|
||||
tracking_ref = weakref.ref(tracking)
|
||||
|
||||
# Create a getter function that uses the weak reference
|
||||
def get_historical_detections() -> Sequence[Detection]:
|
||||
tr = tracking_ref()
|
||||
if tr is None:
|
||||
return [] # Return empty list if tracking has been garbage collected
|
||||
return tr.state.historical_detections
|
||||
|
||||
velocity = tracking.velocity_filter.get()["velocity"]
|
||||
if velocity is None:
|
||||
return LeastMeanSquareVelocityFilter(
|
||||
get_historical_detections=get_historical_detections
|
||||
)
|
||||
else:
|
||||
f = LeastMeanSquareVelocityFilter(
|
||||
get_historical_detections=get_historical_detections
|
||||
self._update(
|
||||
jnp.array(self._historical_3d_poses),
|
||||
jnp.array(self._historical_timestamps),
|
||||
)
|
||||
# pylint: disable-next=protected-access
|
||||
f._velocity = velocity
|
||||
return f
|
||||
|
||||
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
||||
historical_detections = self._get_historical_detections()
|
||||
if not historical_detections:
|
||||
raise ValueError("No historical detections available for prediction")
|
||||
if not self._historical_3d_poses:
|
||||
raise ValueError("No historical 3D poses available for prediction")
|
||||
|
||||
# Use the latest historical detection
|
||||
latest_detection = historical_detections[-1]
|
||||
latest_keypoints = latest_detection.keypoints
|
||||
latest_timestamp = latest_detection.timestamp
|
||||
# use the latest historical detection
|
||||
latest_3d_pose = self._historical_3d_poses[-1]
|
||||
latest_timestamp = self._historical_timestamps[-1]
|
||||
|
||||
delta_t_s = (timestamp - latest_timestamp).total_seconds()
|
||||
|
||||
if self._velocity is None:
|
||||
return TrackingPrediction(
|
||||
velocity=None,
|
||||
keypoints=latest_keypoints,
|
||||
keypoints=latest_3d_pose,
|
||||
)
|
||||
else:
|
||||
# Linear motion model: ẋt = xt' + Vt' · (t - t')
|
||||
predicted_keypoints = latest_keypoints + self._velocity * delta_t_s
|
||||
predicted_3d_pose = latest_3d_pose + self._velocity * delta_t_s
|
||||
return TrackingPrediction(
|
||||
velocity=self._velocity, keypoints=predicted_keypoints
|
||||
velocity=self._velocity, keypoints=predicted_3d_pose
|
||||
)
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
@ -253,47 +236,37 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
||||
self._velocity = velocities
|
||||
|
||||
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||
historical_detections = self._get_historical_detections()
|
||||
last_timestamp = self._historical_timestamps[-1]
|
||||
assert last_timestamp <= timestamp
|
||||
|
||||
if not historical_detections:
|
||||
self._velocity = jnp.zeros_like(keypoints)
|
||||
return
|
||||
# deque would manage the maxlen automatically
|
||||
self._historical_3d_poses.append(keypoints)
|
||||
self._historical_timestamps.append(timestamp)
|
||||
|
||||
t_0 = min(d.timestamp for d in historical_detections)
|
||||
t_0 = self._historical_timestamps[0]
|
||||
all_keypoints = jnp.array(self._historical_3d_poses)
|
||||
|
||||
all_keypoints = jnp.array(
|
||||
list(chain((d.keypoints for d in historical_detections), (keypoints,)))
|
||||
)
|
||||
def timestamp_to_seconds(timestamp: datetime) -> float:
|
||||
assert t_0 <= timestamp
|
||||
return (timestamp - t_0).total_seconds()
|
||||
|
||||
# Timestamps relative to t_0 (the oldest detection timestamp)
|
||||
# timestamps relative to t_0 (the oldest detection timestamp)
|
||||
all_timestamps = jnp.array(
|
||||
list(
|
||||
chain(
|
||||
(
|
||||
(d.timestamp - t_0).total_seconds()
|
||||
for d in historical_detections
|
||||
),
|
||||
((timestamp - t_0).total_seconds(),),
|
||||
)
|
||||
)
|
||||
map(timestamp_to_seconds, self._historical_timestamps)
|
||||
)
|
||||
|
||||
self._update(all_keypoints, all_timestamps)
|
||||
|
||||
def get(self) -> TrackingPrediction:
|
||||
historical_detections = self._get_historical_detections()
|
||||
if not historical_detections:
|
||||
raise ValueError("No historical detections available")
|
||||
if not self._historical_3d_poses:
|
||||
raise ValueError("No historical 3D poses available")
|
||||
|
||||
latest_detection = historical_detections[-1]
|
||||
latest_keypoints = latest_detection.keypoints
|
||||
latest_3d_pose = self._historical_3d_poses[-1]
|
||||
|
||||
if self._velocity is None:
|
||||
return TrackingPrediction(velocity=None, keypoints=latest_keypoints)
|
||||
return TrackingPrediction(velocity=None, keypoints=latest_3d_pose)
|
||||
else:
|
||||
return TrackingPrediction(
|
||||
velocity=self._velocity, keypoints=latest_keypoints
|
||||
)
|
||||
return TrackingPrediction(velocity=self._velocity, keypoints=latest_3d_pose)
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
@ -393,7 +366,7 @@ class Tracking:
|
||||
"""
|
||||
# pylint: disable-next=unsubscriptable-object
|
||||
if (vel := self.velocity_filter.get()["velocity"]) is None:
|
||||
raise ValueError("Velocity is not available")
|
||||
return jnp.zeros_like(self.state.keypoints)
|
||||
else:
|
||||
return vel
|
||||
|
||||
|
||||
Reference in New Issue
Block a user