From c78850855cf52dade13ed5ab0fa7f3416b115863 Mon Sep 17 00:00:00 2001 From: crosstyan Date: Fri, 2 May 2025 11:11:32 +0800 Subject: [PATCH] feat: Introduce LastDifferenceVelocityFilter for improved tracking velocity estimation - Added a new `LastDifferenceVelocityFilter` class to estimate tracking velocities based on the last observed keypoints, enhancing the tracking capabilities. - Updated the `Tracking` class to utilize the new velocity filter, allowing for more accurate predictions of keypoints over time. - Refactored the `predict` method to support various input types (float, timedelta, datetime) for better flexibility in time handling. - Improved timestamp handling in the `perpendicular_distance_camera_2d_points_to_tracking_raycasting` function to ensure adherence to minimum delta time constraints. - Cleaned up imports and type hints for better organization and clarity across the codebase. --- app/tracking/__init__.py | 179 +++++++++++++++++++++++++++++++++------ playground.py | 14 ++- 2 files changed, 160 insertions(+), 33 deletions(-) diff --git a/app/tracking/__init__.py b/app/tracking/__init__.py index ab7d52a..5ef8f9c 100644 --- a/app/tracking/__init__.py +++ b/app/tracking/__init__.py @@ -11,8 +11,9 @@ from typing import ( TypeVar, cast, overload, + Protocol, ) - +from datetime import timedelta import jax.numpy as jnp from beartype import beartype from beartype.typing import Mapping, Sequence @@ -23,6 +24,110 @@ from pyrsistent import PVector from app.camera import Detection +class TrackingPrediction(TypedDict): + velocity: Float[Array, "J 3"] + keypoints: Float[Array, "J 3"] + + +class GenericVelocityFilter(Protocol): + """ + a filter interface for tracking velocity estimation + """ + + def predict(self, timestamp: datetime) -> TrackingPrediction: + """ + predict the velocity and the keypoints location + + Args: + timestamp: timestamp of the prediction + + Returns: + velocity: velocity of the tracking + keypoints: keypoints of the tracking + """ + ... # pylint: disable=unnecessary-ellipsis + + def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: + """ + update the filter state with new measurements + + Args: + keypoints: new measurements + timestamp: timestamp of the update + """ + ... # pylint: disable=unnecessary-ellipsis + + def get(self) -> TrackingPrediction: + """ + get the current state of the filter state + + Returns: + velocity: velocity of the tracking + keypoints: keypoints of the tracking + """ + ... # pylint: disable=unnecessary-ellipsis + + def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: + """ + reset the filter state with new keypoints + + Args: + keypoints: new keypoints + timestamp: timestamp of the reset + """ + ... # pylint: disable=unnecessary-ellipsis + + +class LastDifferenceVelocityFilter(GenericVelocityFilter): + """ + a velocity filter that uses the last difference of keypoints + """ + + _last_timestamp: datetime + _last_keypoints: Float[Array, "J 3"] + _last_velocity: Optional[Float[Array, "J 3"]] = None + + def __init__(self, keypoints: Float[Array, "J 3"], timestamp: datetime): + self._last_keypoints = keypoints + self._last_timestamp = timestamp + + def predict(self, timestamp: datetime) -> TrackingPrediction: + delta_t_s = (timestamp - self._last_timestamp).total_seconds() + if self._last_velocity is None: + return TrackingPrediction( + velocity=jnp.zeros_like(self._last_keypoints), + keypoints=self._last_keypoints, + ) + else: + return TrackingPrediction( + velocity=self._last_velocity, + keypoints=self._last_keypoints + self._last_velocity * delta_t_s, + ) + + def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: + delta_t_s = (timestamp - self._last_timestamp).total_seconds() + self._last_velocity = (keypoints - self._last_keypoints) / delta_t_s + self._last_keypoints = keypoints + self._last_timestamp = timestamp + + def get(self) -> TrackingPrediction: + if self._last_velocity is None: + return TrackingPrediction( + velocity=jnp.zeros_like(self._last_keypoints), + keypoints=self._last_keypoints, + ) + else: + return TrackingPrediction( + velocity=self._last_velocity, + keypoints=self._last_keypoints, + ) + + def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: + self._last_keypoints = keypoints + self._last_timestamp = timestamp + self._last_velocity = None + + @jaxtyped(typechecker=beartype) @dataclass(frozen=True) class Tracking: @@ -48,43 +153,67 @@ class Tracking: Used for 3D re-triangulation """ - velocity: Optional[Float[Array, "3"]] = None + velocity_filter: GenericVelocityFilter """ - Could be `None`. Like when the 3D pose is initialized. - - `velocity` should be updated when target association yields a new - 3D pose. + The velocity filter of the tracking """ def __repr__(self) -> str: return f"Tracking({self.id}, {self.last_active_timestamp})" - def predict( - self, - delta_t_s: float, - ) -> Float[Array, "J 3"]: + @overload + def predict(self, time: float) -> Float[Array, "J 3"]: """ - Predict the 3D pose of a tracking based on its velocity. - JAX-friendly implementation that avoids Python control flow. + predict the keypoints at a given time Args: - delta_t_s: Time delta in seconds + time: the time in seconds to predict the keypoints Returns: - Predicted 3D pose keypoints + the predicted keypoints """ - # ------------------------------------------------------------------ - # Step 1 – decide velocity on the Python side - # ------------------------------------------------------------------ - if self.velocity is None: - velocity = jnp.zeros_like(self.keypoints) # (J, 3) - else: - velocity = self.velocity # (J, 3) + ... # pylint: disable=unnecessary-ellipsis - # ------------------------------------------------------------------ - # Step 2 – pure JAX math - # ------------------------------------------------------------------ - return self.keypoints + velocity * delta_t_s + @overload + def predict(self, time: timedelta) -> Float[Array, "J 3"]: + """ + predict the keypoints at a given time + + Args: + time: the time delta to predict the keypoints + """ + ... # pylint: disable=unnecessary-ellipsis + + @overload + def predict(self, time: datetime) -> Float[Array, "J 3"]: + """ + predict the keypoints at a given time + + Args: + time: the timestamp to predict the keypoints + """ + ... # pylint: disable=unnecessary-ellipsis + + def predict( + self, + time: float | timedelta | datetime, + ) -> Float[Array, "J 3"]: + if isinstance(time, timedelta): + timestamp = self.last_active_timestamp + time + elif isinstance(time, datetime): + timestamp = time + else: + timestamp = self.last_active_timestamp + timedelta(seconds=time) + # pylint: disable-next=unsubscriptable-object + return self.velocity_filter.predict(timestamp)["keypoints"] + + @property + def velocity(self) -> Float[Array, "J 3"]: + """ + The velocity of the tracking for each keypoint + """ + # pylint: disable-next=unsubscriptable-object + return self.velocity_filter.get()["velocity"] @jaxtyped(typechecker=beartype) diff --git a/playground.py b/playground.py index 2dc3cd4..fa19480 100644 --- a/playground.py +++ b/playground.py @@ -37,7 +37,6 @@ import awkward as ak import jax import jax.numpy as jnp import numpy as np -import orjson from beartype import beartype from beartype.typing import Mapping, Sequence from cv2 import undistortPoints @@ -46,7 +45,7 @@ from jaxtyping import Array, Float, Num, jaxtyped from matplotlib import pyplot as plt from numpy.typing import ArrayLike from optax.assignment import hungarian_algorithm as linear_sum_assignment -from pyrsistent import v, pvector +from pyrsistent import pvector, v from scipy.spatial.transform import Rotation as R from typing_extensions import deprecated @@ -59,15 +58,15 @@ from app.camera import ( classify_by_camera, ) from app.solver._old import GLPKSolver -from app.tracking import AffinityResult, Tracking +from app.tracking import AffinityResult, LastDifferenceVelocityFilter, Tracking from app.visualize.whole_body import visualize_whole_body NDArray: TypeAlias = np.ndarray # %% DATASET_PATH = Path("samples") / "04_02" -AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet") -DELTA_T_MIN = timedelta(milliseconds=10) +AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet") # type: ignore +DELTA_T_MIN = timedelta(milliseconds=1) display(AK_CAMERA_DATASET) @@ -549,6 +548,7 @@ class GlobalTrackingState: keypoints=kps_3d, last_active_timestamp=latest_timestamp, historical_detections=v(*cluster), + velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp), ) self._trackings[next_id] = tracking self._last_id = next_id @@ -673,9 +673,7 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting( camera = detection.camera # Use the delta_t supplied by the caller, but clamp to DELTA_T_MIN to # avoid division-by-zero / exploding affinities. - delta_t = max(delta_t, DELTA_T_MIN) - delta_t_s = delta_t.total_seconds() - predicted_pose = tracking.predict(delta_t_s) + predicted_pose = tracking.predict(max(delta_t, DELTA_T_MIN)) # Back-project the 2D points to 3D space # intersection with z=0 plane