diff --git a/app/tracking/__init__.py b/app/tracking/__init__.py index 2991aee..e071ab2 100644 --- a/app/tracking/__init__.py +++ b/app/tracking/__init__.py @@ -1,4 +1,5 @@ import weakref +from collections import deque from dataclasses import dataclass from datetime import datetime, timedelta from itertools import chain @@ -145,73 +146,55 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter): a velocity filter that uses the least mean square method to estimate the velocity """ - _get_historical_detections: Callable[[], Sequence[Detection]] - """ - get the current historical detections, assuming the detections are sorted by - timestamp incrementally (i.e. index 0 is the oldest detection, index -1 is - the newest detection) - """ + _historical_3d_poses: deque[Float[Array, "J 3"]] + _historical_timestamps: deque[datetime] _velocity: Optional[Float[Array, "J 3"]] = None + _max_samples: int - def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]): - self._get_historical_detections = get_historical_detections - self._velocity = None - - @staticmethod - def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter": - """ - create a LeastMeanSquareVelocityFilter from a Tracking object - - Note that this function is using a weak reference to the tracking object, - so that the tracking object can be garbage collected if there are no other - references to it. - """ - # Create a weak reference to avoid circular references - # https://docs.python.org/3/library/weakref.html - tracking_ref = weakref.ref(tracking) - - # Create a getter function that uses the weak reference - def get_historical_detections() -> Sequence[Detection]: - tr = tracking_ref() - if tr is None: - return [] # Return empty list if tracking has been garbage collected - return tr.state.historical_detections - - velocity = tracking.velocity_filter.get()["velocity"] - if velocity is None: - return LeastMeanSquareVelocityFilter( - get_historical_detections=get_historical_detections - ) + def __init__( + self, + historical_3d_poses: Sequence[Float[Array, "J 3"]], + historical_timestamps: Sequence[datetime], + max_samples: int = 10, + ): + assert len(historical_3d_poses) == len(historical_timestamps) + temp = zip(historical_3d_poses, historical_timestamps) + temp_sorted = sorted(temp, key=lambda x: x[1]) + self._historical_3d_poses = deque( + map(lambda x: x[0], temp_sorted), maxlen=max_samples + ) + self._historical_timestamps = deque( + map(lambda x: x[1], temp_sorted), maxlen=max_samples + ) + self._max_samples = max_samples + if len(self._historical_3d_poses) < 2: + self._velocity = None else: - f = LeastMeanSquareVelocityFilter( - get_historical_detections=get_historical_detections + self._update( + jnp.array(self._historical_3d_poses), + jnp.array(self._historical_timestamps), ) - # pylint: disable-next=protected-access - f._velocity = velocity - return f def predict(self, timestamp: datetime) -> TrackingPrediction: - historical_detections = self._get_historical_detections() - if not historical_detections: - raise ValueError("No historical detections available for prediction") + if not self._historical_3d_poses: + raise ValueError("No historical 3D poses available for prediction") - # Use the latest historical detection - latest_detection = historical_detections[-1] - latest_keypoints = latest_detection.keypoints - latest_timestamp = latest_detection.timestamp + # use the latest historical detection + latest_3d_pose = self._historical_3d_poses[-1] + latest_timestamp = self._historical_timestamps[-1] delta_t_s = (timestamp - latest_timestamp).total_seconds() if self._velocity is None: return TrackingPrediction( velocity=None, - keypoints=latest_keypoints, + keypoints=latest_3d_pose, ) else: # Linear motion model: ẋt = xt' + Vt' · (t - t') - predicted_keypoints = latest_keypoints + self._velocity * delta_t_s + predicted_3d_pose = latest_3d_pose + self._velocity * delta_t_s return TrackingPrediction( - velocity=self._velocity, keypoints=predicted_keypoints + velocity=self._velocity, keypoints=predicted_3d_pose ) @jaxtyped(typechecker=beartype) @@ -253,47 +236,37 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter): self._velocity = velocities def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: - historical_detections = self._get_historical_detections() + last_timestamp = self._historical_timestamps[-1] + assert last_timestamp <= timestamp - if not historical_detections: - self._velocity = jnp.zeros_like(keypoints) - return + # deque would manage the maxlen automatically + self._historical_3d_poses.append(keypoints) + self._historical_timestamps.append(timestamp) - t_0 = min(d.timestamp for d in historical_detections) + t_0 = self._historical_timestamps[0] + all_keypoints = jnp.array(self._historical_3d_poses) - all_keypoints = jnp.array( - list(chain((d.keypoints for d in historical_detections), (keypoints,))) - ) + def timestamp_to_seconds(timestamp: datetime) -> float: + assert t_0 <= timestamp + return (timestamp - t_0).total_seconds() - # Timestamps relative to t_0 (the oldest detection timestamp) + # timestamps relative to t_0 (the oldest detection timestamp) all_timestamps = jnp.array( - list( - chain( - ( - (d.timestamp - t_0).total_seconds() - for d in historical_detections - ), - ((timestamp - t_0).total_seconds(),), - ) - ) + map(timestamp_to_seconds, self._historical_timestamps) ) self._update(all_keypoints, all_timestamps) def get(self) -> TrackingPrediction: - historical_detections = self._get_historical_detections() - if not historical_detections: - raise ValueError("No historical detections available") + if not self._historical_3d_poses: + raise ValueError("No historical 3D poses available") - latest_detection = historical_detections[-1] - latest_keypoints = latest_detection.keypoints + latest_3d_pose = self._historical_3d_poses[-1] if self._velocity is None: - return TrackingPrediction(velocity=None, keypoints=latest_keypoints) + return TrackingPrediction(velocity=None, keypoints=latest_3d_pose) else: - return TrackingPrediction( - velocity=self._velocity, keypoints=latest_keypoints - ) + return TrackingPrediction(velocity=self._velocity, keypoints=latest_3d_pose) @jaxtyped(typechecker=beartype) @@ -393,7 +366,7 @@ class Tracking: """ # pylint: disable-next=unsubscriptable-object if (vel := self.velocity_filter.get()["velocity"]) is None: - raise ValueError("Velocity is not available") + return jnp.zeros_like(self.state.keypoints) else: return vel