diff --git a/app/tracking/__init__.py b/app/tracking/__init__.py index 63fc84b..41e0266 100644 --- a/app/tracking/__init__.py +++ b/app/tracking/__init__.py @@ -67,16 +67,6 @@ class GenericVelocityFilter(Protocol): """ ... # pylint: disable=unnecessary-ellipsis - def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: - """ - reset the filter state with new keypoints - - Args: - keypoints: new keypoints - timestamp: timestamp of the reset - """ - ... # pylint: disable=unnecessary-ellipsis - class LastDifferenceVelocityFilter(GenericVelocityFilter): """ @@ -122,13 +112,12 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter): keypoints=self._last_keypoints, ) - def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: - self._last_keypoints = keypoints - self._last_timestamp = timestamp - self._last_velocity = None - class LeastMeanSquareVelocityFilter(GenericVelocityFilter): + """ + a velocity filter that uses the least mean square method to estimate the velocity + """ + _get_historical_detections: Callable[[], Sequence[Detection]] """ get the current historical detections, assuming the detections are sorted by @@ -137,11 +126,56 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter): """ _velocity: Optional[Float[Array, "J 3"]] = None + @staticmethod + def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter": + """ + create a LeastMeanSquareVelocityFilter from a Tracking object + """ + velocity = tracking.velocity_filter.get()["velocity"] + if jnp.all(velocity == jnp.zeros_like(velocity)): + return LeastMeanSquareVelocityFilter( + get_historical_detections=lambda: tracking.historical_detections + ) + else: + f = LeastMeanSquareVelocityFilter( + get_historical_detections=lambda: tracking.historical_detections + ) + # pylint: disable-next=protected-access + f._velocity = velocity + return f + def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]): self._get_historical_detections = get_historical_detections + self._velocity = None + + @property + def velocity(self) -> Float[Array, "J 3"]: + if self._velocity is None: + raise ValueError("Velocity not initialized") + return self._velocity def predict(self, timestamp: datetime) -> TrackingPrediction: - raise NotImplementedError + historical_detections = self._get_historical_detections() + if not historical_detections: + raise ValueError("No historical detections available for prediction") + + # Use the latest historical detection + latest_detection = historical_detections[-1] + latest_keypoints = latest_detection.keypoints + latest_timestamp = latest_detection.timestamp + + delta_t_s = (timestamp - latest_timestamp).total_seconds() + + if self._velocity is None: + return TrackingPrediction( + velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints + ) + else: + # Linear motion model: ẋt = xt' + Vt' · (t - t') + predicted_keypoints = latest_keypoints + self._velocity * delta_t_s + return TrackingPrediction( + velocity=self._velocity, keypoints=predicted_keypoints + ) @jaxtyped(typechecker=beartype) def _update( @@ -152,28 +186,79 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter): """ update measurements with least mean square method """ - raise NotImplementedError + if keypoints.shape[0] < 2: + raise ValueError("Not enough measurements to estimate velocity") + + # Using least squares to fit a linear model for each joint and dimension + # X = timestamps, y = keypoints + # For each joint and each dimension, we solve for velocity + + n_samples = timestamps.shape[0] + n_joints = keypoints.shape[1] + + # Create design matrix for linear regression + # [t, 1] for each timestamp + X = jnp.column_stack([timestamps, jnp.ones(n_samples)]) + + # Reshape keypoints to solve for all joints and dimensions at once + # From [N, J, 3] to [N, J*3] + keypoints_reshaped = keypoints.reshape(n_samples, -1) + + # Use JAX's lstsq to solve the least squares problem + # This is more numerically stable than manually computing pseudoinverse + coefficients, _, _, _ = jnp.linalg.lstsq(X, keypoints_reshaped, rcond=None) + + # Coefficients shape is [2, J*3] + # First row: velocities, Second row: intercepts + velocities = coefficients[0].reshape(n_joints, 3) + + # Update velocity + self._velocity = velocities def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: historical_detections = self._get_historical_detections() + + if not historical_detections: + self._velocity = jnp.zeros_like(keypoints) + return + t_0 = min(d.timestamp for d in historical_detections) - detections = jnp.array( - chain((d.keypoints for d in historical_detections), (keypoints,)) + + all_keypoints = jnp.array( + list(chain((d.keypoints for d in historical_detections), (keypoints,))) ) - # timestamps relative to t_0 (the oldest detection timestamp) - timestamps = jnp.array( - chain( - ((d.timestamp - t_0).total_seconds() for d in historical_detections), - ((timestamp - t_0).total_seconds(),), + + # Timestamps relative to t_0 (the oldest detection timestamp) + all_timestamps = jnp.array( + list( + chain( + ( + (d.timestamp - t_0).total_seconds() + for d in historical_detections + ), + ((timestamp - t_0).total_seconds(),), + ) ) ) - raise NotImplementedError + + self._update(all_keypoints, all_timestamps) def get(self) -> TrackingPrediction: - raise NotImplementedError + historical_detections = self._get_historical_detections() + if not historical_detections: + raise ValueError("No historical detections available") - def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: - raise NotImplementedError + latest_detection = historical_detections[-1] + latest_keypoints = latest_detection.keypoints + + if self._velocity is None: + return TrackingPrediction( + velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints + ) + else: + return TrackingPrediction( + velocity=self._velocity, keypoints=latest_keypoints + ) @jaxtyped(typechecker=beartype)