feat: Introduce LastDifferenceVelocityFilter for improved tracking velocity estimation

- Added a new `LastDifferenceVelocityFilter` class to estimate tracking velocities based on the last observed keypoints, enhancing the tracking capabilities.
- Updated the `Tracking` class to utilize the new velocity filter, allowing for more accurate predictions of keypoints over time.
- Refactored the `predict` method to support various input types (float, timedelta, datetime) for better flexibility in time handling.
- Improved timestamp handling in the `perpendicular_distance_camera_2d_points_to_tracking_raycasting` function to ensure adherence to minimum delta time constraints.
- Cleaned up imports and type hints for better organization and clarity across the codebase.
This commit is contained in:
2025-05-02 11:11:32 +08:00
parent 072bf1c46f
commit c78850855c
2 changed files with 160 additions and 33 deletions

View File

@ -11,8 +11,9 @@ from typing import (
TypeVar, TypeVar,
cast, cast,
overload, overload,
Protocol,
) )
from datetime import timedelta
import jax.numpy as jnp import jax.numpy as jnp
from beartype import beartype from beartype import beartype
from beartype.typing import Mapping, Sequence from beartype.typing import Mapping, Sequence
@ -23,6 +24,110 @@ from pyrsistent import PVector
from app.camera import Detection from app.camera import Detection
class TrackingPrediction(TypedDict):
    """
    Result returned by a velocity filter query.

    Both entries are annotated as float arrays of shape ``(J, 3)``.
    """

    # velocity estimate, one 3-vector per keypoint
    velocity: Float[Array, "J 3"]
    # keypoint positions, one 3-vector per keypoint
    keypoints: Float[Array, "J 3"]
class GenericVelocityFilter(Protocol):
    """
    Structural interface for filters that estimate the velocity of a tracking.

    Implementations keep some internal state that is seeded by ``reset``,
    refined by ``update``, and queried through ``get`` / ``predict``.
    """

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        Predict the velocity and the keypoint locations at ``timestamp``.

        Args:
            timestamp: point in time the prediction is requested for

        Returns:
            A ``TrackingPrediction`` holding the ``velocity`` and the
            ``keypoints`` of the tracking.
        """

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        Feed a new measurement into the filter.

        Args:
            keypoints: newly measured keypoints
            timestamp: time at which the measurement was taken
        """

    def get(self) -> TrackingPrediction:
        """
        Return the current state of the filter.

        Returns:
            A ``TrackingPrediction`` holding the ``velocity`` and the
            ``keypoints`` of the tracking.
        """

    def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        Re-initialise the filter state from the given keypoints.

        Args:
            keypoints: keypoints to restart from
            timestamp: time associated with the reset
        """
class LastDifferenceVelocityFilter(GenericVelocityFilter):
    """
    Velocity filter based on the difference of the last two observations.

    The velocity is the finite difference between the two most recent
    keypoint measurements divided by the time elapsed between them. Until a
    second measurement arrives no velocity is known and a zero velocity is
    reported.
    """

    _last_timestamp: datetime
    _last_keypoints: Float[Array, "J 3"]
    _last_velocity: Optional[Float[Array, "J 3"]] = None

    def __init__(self, keypoints: Float[Array, "J 3"], timestamp: datetime):
        """
        Args:
            keypoints: initial keypoints
            timestamp: timestamp of the initial keypoints
        """
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp
        # Set explicitly so the instance does not rely on the class default.
        self._last_velocity = None

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        Linearly extrapolate the last keypoints to ``timestamp``.

        Args:
            timestamp: timestamp of the prediction

        Returns:
            The last velocity estimate and the keypoints moved along it for
            the time elapsed since the last measurement. When no velocity
            has been estimated yet, the last keypoints are returned
            unchanged together with a zero velocity.
        """
        delta_t_s = (timestamp - self._last_timestamp).total_seconds()
        current = self.get()
        if self._last_velocity is None:
            return current
        return TrackingPrediction(
            velocity=current["velocity"],
            keypoints=current["keypoints"] + current["velocity"] * delta_t_s,
        )

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        Store a new measurement and refresh the velocity estimate.

        Args:
            keypoints: new measurements
            timestamp: timestamp of the update
        """
        delta_t_s = (timestamp - self._last_timestamp).total_seconds()
        # A measurement carrying the same timestamp as the previous one gives
        # no information about velocity; dividing by the zero interval would
        # store inf/NaN and corrupt every later prediction. Keep the previous
        # estimate in that case and only adopt the new keypoints.
        if delta_t_s != 0.0:
            self._last_velocity = (keypoints - self._last_keypoints) / delta_t_s
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp

    def get(self) -> TrackingPrediction:
        """
        Return the current filter state without extrapolation.

        Returns:
            The last keypoints and the last velocity estimate (zeros when no
            velocity has been estimated yet).
        """
        velocity = self._last_velocity
        if velocity is None:
            velocity = jnp.zeros_like(self._last_keypoints)
        return TrackingPrediction(
            velocity=velocity,
            keypoints=self._last_keypoints,
        )

    def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        Restart the filter from ``keypoints`` and forget the velocity.

        Args:
            keypoints: new keypoints
            timestamp: timestamp of the reset
        """
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp
        self._last_velocity = None
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
@dataclass(frozen=True) @dataclass(frozen=True)
class Tracking: class Tracking:
@ -48,43 +153,67 @@ class Tracking:
Used for 3D re-triangulation Used for 3D re-triangulation
""" """
velocity: Optional[Float[Array, "3"]] = None velocity_filter: GenericVelocityFilter
""" """
Could be `None`. Like when the 3D pose is initialized. The velocity filter of the tracking
`velocity` should be updated when target association yields a new
3D pose.
""" """
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Tracking({self.id}, {self.last_active_timestamp})" return f"Tracking({self.id}, {self.last_active_timestamp})"
def predict( @overload
self, def predict(self, time: float) -> Float[Array, "J 3"]:
delta_t_s: float,
) -> Float[Array, "J 3"]:
""" """
Predict the 3D pose of a tracking based on its velocity. predict the keypoints at a given time
JAX-friendly implementation that avoids Python control flow.
Args: Args:
delta_t_s: Time delta in seconds time: the time in seconds to predict the keypoints
Returns: Returns:
Predicted 3D pose keypoints the predicted keypoints
""" """
# ------------------------------------------------------------------ ... # pylint: disable=unnecessary-ellipsis
# Step 1 decide velocity on the Python side
# ------------------------------------------------------------------
if self.velocity is None:
velocity = jnp.zeros_like(self.keypoints) # (J, 3)
else:
velocity = self.velocity # (J, 3)
# ------------------------------------------------------------------ @overload
# Step 2 pure JAX math def predict(self, time: timedelta) -> Float[Array, "J 3"]:
# ------------------------------------------------------------------ """
return self.keypoints + velocity * delta_t_s predict the keypoints at a given time
Args:
time: the time delta to predict the keypoints
"""
... # pylint: disable=unnecessary-ellipsis
@overload
def predict(self, time: datetime) -> Float[Array, "J 3"]:
"""
predict the keypoints at a given time
Args:
time: the timestamp to predict the keypoints
"""
... # pylint: disable=unnecessary-ellipsis
def predict(
self,
time: float | timedelta | datetime,
) -> Float[Array, "J 3"]:
if isinstance(time, timedelta):
timestamp = self.last_active_timestamp + time
elif isinstance(time, datetime):
timestamp = time
else:
timestamp = self.last_active_timestamp + timedelta(seconds=time)
# pylint: disable-next=unsubscriptable-object
return self.velocity_filter.predict(timestamp)["keypoints"]
@property
def velocity(self) -> Float[Array, "J 3"]:
"""
The velocity of the tracking for each keypoint
"""
# pylint: disable-next=unsubscriptable-object
return self.velocity_filter.get()["velocity"]
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)

View File

@ -37,7 +37,6 @@ import awkward as ak
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
import orjson
from beartype import beartype from beartype import beartype
from beartype.typing import Mapping, Sequence from beartype.typing import Mapping, Sequence
from cv2 import undistortPoints from cv2 import undistortPoints
@ -46,7 +45,7 @@ from jaxtyping import Array, Float, Num, jaxtyped
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from numpy.typing import ArrayLike from numpy.typing import ArrayLike
from optax.assignment import hungarian_algorithm as linear_sum_assignment from optax.assignment import hungarian_algorithm as linear_sum_assignment
from pyrsistent import v, pvector from pyrsistent import pvector, v
from scipy.spatial.transform import Rotation as R from scipy.spatial.transform import Rotation as R
from typing_extensions import deprecated from typing_extensions import deprecated
@ -59,15 +58,15 @@ from app.camera import (
classify_by_camera, classify_by_camera,
) )
from app.solver._old import GLPKSolver from app.solver._old import GLPKSolver
from app.tracking import AffinityResult, Tracking from app.tracking import AffinityResult, LastDifferenceVelocityFilter, Tracking
from app.visualize.whole_body import visualize_whole_body from app.visualize.whole_body import visualize_whole_body
NDArray: TypeAlias = np.ndarray NDArray: TypeAlias = np.ndarray
# %% # %%
DATASET_PATH = Path("samples") / "04_02" DATASET_PATH = Path("samples") / "04_02"
AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet") AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet") # type: ignore
DELTA_T_MIN = timedelta(milliseconds=10) DELTA_T_MIN = timedelta(milliseconds=1)
display(AK_CAMERA_DATASET) display(AK_CAMERA_DATASET)
@ -549,6 +548,7 @@ class GlobalTrackingState:
keypoints=kps_3d, keypoints=kps_3d,
last_active_timestamp=latest_timestamp, last_active_timestamp=latest_timestamp,
historical_detections=v(*cluster), historical_detections=v(*cluster),
velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
) )
self._trackings[next_id] = tracking self._trackings[next_id] = tracking
self._last_id = next_id self._last_id = next_id
@ -673,9 +673,7 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
camera = detection.camera camera = detection.camera
# Use the delta_t supplied by the caller, but clamp to DELTA_T_MIN to # Use the delta_t supplied by the caller, but clamp to DELTA_T_MIN to
# avoid division-by-zero / exploding affinities. # avoid division-by-zero / exploding affinities.
delta_t = max(delta_t, DELTA_T_MIN) predicted_pose = tracking.predict(max(delta_t, DELTA_T_MIN))
delta_t_s = delta_t.total_seconds()
predicted_pose = tracking.predict(delta_t_s)
# Back-project the 2D points to 3D space # Back-project the 2D points to 3D space
# intersection with z=0 plane # intersection with z=0 plane