refactor: Revamp LeastMeanSquareVelocityFilter to utilize historical 3D poses
- Replaced the historical detections mechanism with deques for managing historical 3D poses and timestamps, enhancing performance and memory efficiency. - Updated the constructor to accept historical data directly, ensuring proper initialization and sorting of poses and timestamps. - Refined the `predict` and `update` methods to work with the new data structure, improving clarity and functionality. - Enhanced error handling to ensure robustness when no historical data is available for predictions.
This commit is contained in:
@ -1,4 +1,5 @@
|
|||||||
import weakref
|
import weakref
|
||||||
|
from collections import deque
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
@ -145,73 +146,55 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
|||||||
a velocity filter that uses the least mean square method to estimate the velocity
|
a velocity filter that uses the least mean square method to estimate the velocity
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_get_historical_detections: Callable[[], Sequence[Detection]]
|
_historical_3d_poses: deque[Float[Array, "J 3"]]
|
||||||
"""
|
_historical_timestamps: deque[datetime]
|
||||||
get the current historical detections, assuming the detections are sorted by
|
|
||||||
timestamp incrementally (i.e. index 0 is the oldest detection, index -1 is
|
|
||||||
the newest detection)
|
|
||||||
"""
|
|
||||||
_velocity: Optional[Float[Array, "J 3"]] = None
|
_velocity: Optional[Float[Array, "J 3"]] = None
|
||||||
|
_max_samples: int
|
||||||
|
|
||||||
def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
|
def __init__(
|
||||||
self._get_historical_detections = get_historical_detections
|
self,
|
||||||
self._velocity = None
|
historical_3d_poses: Sequence[Float[Array, "J 3"]],
|
||||||
|
historical_timestamps: Sequence[datetime],
|
||||||
@staticmethod
|
max_samples: int = 10,
|
||||||
def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter":
|
):
|
||||||
"""
|
assert len(historical_3d_poses) == len(historical_timestamps)
|
||||||
create a LeastMeanSquareVelocityFilter from a Tracking object
|
temp = zip(historical_3d_poses, historical_timestamps)
|
||||||
|
temp_sorted = sorted(temp, key=lambda x: x[1])
|
||||||
Note that this function is using a weak reference to the tracking object,
|
self._historical_3d_poses = deque(
|
||||||
so that the tracking object can be garbage collected if there are no other
|
map(lambda x: x[0], temp_sorted), maxlen=max_samples
|
||||||
references to it.
|
)
|
||||||
"""
|
self._historical_timestamps = deque(
|
||||||
# Create a weak reference to avoid circular references
|
map(lambda x: x[1], temp_sorted), maxlen=max_samples
|
||||||
# https://docs.python.org/3/library/weakref.html
|
)
|
||||||
tracking_ref = weakref.ref(tracking)
|
self._max_samples = max_samples
|
||||||
|
if len(self._historical_3d_poses) < 2:
|
||||||
# Create a getter function that uses the weak reference
|
self._velocity = None
|
||||||
def get_historical_detections() -> Sequence[Detection]:
|
|
||||||
tr = tracking_ref()
|
|
||||||
if tr is None:
|
|
||||||
return [] # Return empty list if tracking has been garbage collected
|
|
||||||
return tr.state.historical_detections
|
|
||||||
|
|
||||||
velocity = tracking.velocity_filter.get()["velocity"]
|
|
||||||
if velocity is None:
|
|
||||||
return LeastMeanSquareVelocityFilter(
|
|
||||||
get_historical_detections=get_historical_detections
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
f = LeastMeanSquareVelocityFilter(
|
self._update(
|
||||||
get_historical_detections=get_historical_detections
|
jnp.array(self._historical_3d_poses),
|
||||||
|
jnp.array(self._historical_timestamps),
|
||||||
)
|
)
|
||||||
# pylint: disable-next=protected-access
|
|
||||||
f._velocity = velocity
|
|
||||||
return f
|
|
||||||
|
|
||||||
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
||||||
historical_detections = self._get_historical_detections()
|
if not self._historical_3d_poses:
|
||||||
if not historical_detections:
|
raise ValueError("No historical 3D poses available for prediction")
|
||||||
raise ValueError("No historical detections available for prediction")
|
|
||||||
|
|
||||||
# Use the latest historical detection
|
# use the latest historical detection
|
||||||
latest_detection = historical_detections[-1]
|
latest_3d_pose = self._historical_3d_poses[-1]
|
||||||
latest_keypoints = latest_detection.keypoints
|
latest_timestamp = self._historical_timestamps[-1]
|
||||||
latest_timestamp = latest_detection.timestamp
|
|
||||||
|
|
||||||
delta_t_s = (timestamp - latest_timestamp).total_seconds()
|
delta_t_s = (timestamp - latest_timestamp).total_seconds()
|
||||||
|
|
||||||
if self._velocity is None:
|
if self._velocity is None:
|
||||||
return TrackingPrediction(
|
return TrackingPrediction(
|
||||||
velocity=None,
|
velocity=None,
|
||||||
keypoints=latest_keypoints,
|
keypoints=latest_3d_pose,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Linear motion model: ẋt = xt' + Vt' · (t - t')
|
# Linear motion model: ẋt = xt' + Vt' · (t - t')
|
||||||
predicted_keypoints = latest_keypoints + self._velocity * delta_t_s
|
predicted_3d_pose = latest_3d_pose + self._velocity * delta_t_s
|
||||||
return TrackingPrediction(
|
return TrackingPrediction(
|
||||||
velocity=self._velocity, keypoints=predicted_keypoints
|
velocity=self._velocity, keypoints=predicted_3d_pose
|
||||||
)
|
)
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
@ -253,47 +236,37 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
|||||||
self._velocity = velocities
|
self._velocity = velocities
|
||||||
|
|
||||||
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||||
historical_detections = self._get_historical_detections()
|
last_timestamp = self._historical_timestamps[-1]
|
||||||
|
assert last_timestamp <= timestamp
|
||||||
|
|
||||||
if not historical_detections:
|
# deque would manage the maxlen automatically
|
||||||
self._velocity = jnp.zeros_like(keypoints)
|
self._historical_3d_poses.append(keypoints)
|
||||||
return
|
self._historical_timestamps.append(timestamp)
|
||||||
|
|
||||||
t_0 = min(d.timestamp for d in historical_detections)
|
t_0 = self._historical_timestamps[0]
|
||||||
|
all_keypoints = jnp.array(self._historical_3d_poses)
|
||||||
|
|
||||||
all_keypoints = jnp.array(
|
def timestamp_to_seconds(timestamp: datetime) -> float:
|
||||||
list(chain((d.keypoints for d in historical_detections), (keypoints,)))
|
assert t_0 <= timestamp
|
||||||
)
|
return (timestamp - t_0).total_seconds()
|
||||||
|
|
||||||
# Timestamps relative to t_0 (the oldest detection timestamp)
|
# timestamps relative to t_0 (the oldest detection timestamp)
|
||||||
all_timestamps = jnp.array(
|
all_timestamps = jnp.array(
|
||||||
list(
|
map(timestamp_to_seconds, self._historical_timestamps)
|
||||||
chain(
|
|
||||||
(
|
|
||||||
(d.timestamp - t_0).total_seconds()
|
|
||||||
for d in historical_detections
|
|
||||||
),
|
|
||||||
((timestamp - t_0).total_seconds(),),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._update(all_keypoints, all_timestamps)
|
self._update(all_keypoints, all_timestamps)
|
||||||
|
|
||||||
def get(self) -> TrackingPrediction:
|
def get(self) -> TrackingPrediction:
|
||||||
historical_detections = self._get_historical_detections()
|
if not self._historical_3d_poses:
|
||||||
if not historical_detections:
|
raise ValueError("No historical 3D poses available")
|
||||||
raise ValueError("No historical detections available")
|
|
||||||
|
|
||||||
latest_detection = historical_detections[-1]
|
latest_3d_pose = self._historical_3d_poses[-1]
|
||||||
latest_keypoints = latest_detection.keypoints
|
|
||||||
|
|
||||||
if self._velocity is None:
|
if self._velocity is None:
|
||||||
return TrackingPrediction(velocity=None, keypoints=latest_keypoints)
|
return TrackingPrediction(velocity=None, keypoints=latest_3d_pose)
|
||||||
else:
|
else:
|
||||||
return TrackingPrediction(
|
return TrackingPrediction(velocity=self._velocity, keypoints=latest_3d_pose)
|
||||||
velocity=self._velocity, keypoints=latest_keypoints
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
@ -393,7 +366,7 @@ class Tracking:
|
|||||||
"""
|
"""
|
||||||
# pylint: disable-next=unsubscriptable-object
|
# pylint: disable-next=unsubscriptable-object
|
||||||
if (vel := self.velocity_filter.get()["velocity"]) is None:
|
if (vel := self.velocity_filter.get()["velocity"]) is None:
|
||||||
raise ValueError("Velocity is not available")
|
return jnp.zeros_like(self.state.keypoints)
|
||||||
else:
|
else:
|
||||||
return vel
|
return vel
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user