forked from HQU-gxy/CVTH3PE
feat: Introduce LastDifferenceVelocityFilter for improved tracking velocity estimation
- Added a new `LastDifferenceVelocityFilter` class to estimate tracking velocities based on the last observed keypoints, enhancing the tracking capabilities. - Updated the `Tracking` class to utilize the new velocity filter, allowing for more accurate predictions of keypoints over time. - Refactored the `predict` method to support various input types (float, timedelta, datetime) for better flexibility in time handling. - Improved timestamp handling in the `perpendicular_distance_camera_2d_points_to_tracking_raycasting` function to ensure adherence to minimum delta time constraints. - Cleaned up imports and type hints for better organization and clarity across the codebase.
This commit is contained in:
@ -11,8 +11,9 @@ from typing import (
|
|||||||
TypeVar,
|
TypeVar,
|
||||||
cast,
|
cast,
|
||||||
overload,
|
overload,
|
||||||
|
Protocol,
|
||||||
)
|
)
|
||||||
|
from datetime import timedelta
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from beartype import beartype
|
from beartype import beartype
|
||||||
from beartype.typing import Mapping, Sequence
|
from beartype.typing import Mapping, Sequence
|
||||||
@ -23,6 +24,110 @@ from pyrsistent import PVector
|
|||||||
from app.camera import Detection
|
from app.camera import Detection
|
||||||
|
|
||||||
|
|
||||||
|
class TrackingPrediction(TypedDict):
|
||||||
|
velocity: Float[Array, "J 3"]
|
||||||
|
keypoints: Float[Array, "J 3"]
|
||||||
|
|
||||||
|
|
||||||
|
class GenericVelocityFilter(Protocol):
|
||||||
|
"""
|
||||||
|
a filter interface for tracking velocity estimation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
||||||
|
"""
|
||||||
|
predict the velocity and the keypoints location
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestamp: timestamp of the prediction
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
velocity: velocity of the tracking
|
||||||
|
keypoints: keypoints of the tracking
|
||||||
|
"""
|
||||||
|
... # pylint: disable=unnecessary-ellipsis
|
||||||
|
|
||||||
|
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||||
|
"""
|
||||||
|
update the filter state with new measurements
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keypoints: new measurements
|
||||||
|
timestamp: timestamp of the update
|
||||||
|
"""
|
||||||
|
... # pylint: disable=unnecessary-ellipsis
|
||||||
|
|
||||||
|
def get(self) -> TrackingPrediction:
|
||||||
|
"""
|
||||||
|
get the current state of the filter state
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
velocity: velocity of the tracking
|
||||||
|
keypoints: keypoints of the tracking
|
||||||
|
"""
|
||||||
|
... # pylint: disable=unnecessary-ellipsis
|
||||||
|
|
||||||
|
def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||||
|
"""
|
||||||
|
reset the filter state with new keypoints
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keypoints: new keypoints
|
||||||
|
timestamp: timestamp of the reset
|
||||||
|
"""
|
||||||
|
... # pylint: disable=unnecessary-ellipsis
|
||||||
|
|
||||||
|
|
||||||
|
class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
||||||
|
"""
|
||||||
|
a velocity filter that uses the last difference of keypoints
|
||||||
|
"""
|
||||||
|
|
||||||
|
_last_timestamp: datetime
|
||||||
|
_last_keypoints: Float[Array, "J 3"]
|
||||||
|
_last_velocity: Optional[Float[Array, "J 3"]] = None
|
||||||
|
|
||||||
|
def __init__(self, keypoints: Float[Array, "J 3"], timestamp: datetime):
|
||||||
|
self._last_keypoints = keypoints
|
||||||
|
self._last_timestamp = timestamp
|
||||||
|
|
||||||
|
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
||||||
|
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
|
||||||
|
if self._last_velocity is None:
|
||||||
|
return TrackingPrediction(
|
||||||
|
velocity=jnp.zeros_like(self._last_keypoints),
|
||||||
|
keypoints=self._last_keypoints,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return TrackingPrediction(
|
||||||
|
velocity=self._last_velocity,
|
||||||
|
keypoints=self._last_keypoints + self._last_velocity * delta_t_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||||
|
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
|
||||||
|
self._last_velocity = (keypoints - self._last_keypoints) / delta_t_s
|
||||||
|
self._last_keypoints = keypoints
|
||||||
|
self._last_timestamp = timestamp
|
||||||
|
|
||||||
|
def get(self) -> TrackingPrediction:
|
||||||
|
if self._last_velocity is None:
|
||||||
|
return TrackingPrediction(
|
||||||
|
velocity=jnp.zeros_like(self._last_keypoints),
|
||||||
|
keypoints=self._last_keypoints,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return TrackingPrediction(
|
||||||
|
velocity=self._last_velocity,
|
||||||
|
keypoints=self._last_keypoints,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||||
|
self._last_keypoints = keypoints
|
||||||
|
self._last_timestamp = timestamp
|
||||||
|
self._last_velocity = None
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Tracking:
|
class Tracking:
|
||||||
@ -48,43 +153,67 @@ class Tracking:
|
|||||||
Used for 3D re-triangulation
|
Used for 3D re-triangulation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
velocity: Optional[Float[Array, "3"]] = None
|
velocity_filter: GenericVelocityFilter
|
||||||
"""
|
"""
|
||||||
Could be `None`. Like when the 3D pose is initialized.
|
The velocity filter of the tracking
|
||||||
|
|
||||||
`velocity` should be updated when target association yields a new
|
|
||||||
3D pose.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"Tracking({self.id}, {self.last_active_timestamp})"
|
return f"Tracking({self.id}, {self.last_active_timestamp})"
|
||||||
|
|
||||||
def predict(
|
@overload
|
||||||
self,
|
def predict(self, time: float) -> Float[Array, "J 3"]:
|
||||||
delta_t_s: float,
|
|
||||||
) -> Float[Array, "J 3"]:
|
|
||||||
"""
|
"""
|
||||||
Predict the 3D pose of a tracking based on its velocity.
|
predict the keypoints at a given time
|
||||||
JAX-friendly implementation that avoids Python control flow.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
delta_t_s: Time delta in seconds
|
time: the time in seconds to predict the keypoints
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Predicted 3D pose keypoints
|
the predicted keypoints
|
||||||
"""
|
"""
|
||||||
# ------------------------------------------------------------------
|
... # pylint: disable=unnecessary-ellipsis
|
||||||
# Step 1 – decide velocity on the Python side
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
if self.velocity is None:
|
|
||||||
velocity = jnp.zeros_like(self.keypoints) # (J, 3)
|
|
||||||
else:
|
|
||||||
velocity = self.velocity # (J, 3)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
@overload
|
||||||
# Step 2 – pure JAX math
|
def predict(self, time: timedelta) -> Float[Array, "J 3"]:
|
||||||
# ------------------------------------------------------------------
|
"""
|
||||||
return self.keypoints + velocity * delta_t_s
|
predict the keypoints at a given time
|
||||||
|
|
||||||
|
Args:
|
||||||
|
time: the time delta to predict the keypoints
|
||||||
|
"""
|
||||||
|
... # pylint: disable=unnecessary-ellipsis
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def predict(self, time: datetime) -> Float[Array, "J 3"]:
|
||||||
|
"""
|
||||||
|
predict the keypoints at a given time
|
||||||
|
|
||||||
|
Args:
|
||||||
|
time: the timestamp to predict the keypoints
|
||||||
|
"""
|
||||||
|
... # pylint: disable=unnecessary-ellipsis
|
||||||
|
|
||||||
|
def predict(
|
||||||
|
self,
|
||||||
|
time: float | timedelta | datetime,
|
||||||
|
) -> Float[Array, "J 3"]:
|
||||||
|
if isinstance(time, timedelta):
|
||||||
|
timestamp = self.last_active_timestamp + time
|
||||||
|
elif isinstance(time, datetime):
|
||||||
|
timestamp = time
|
||||||
|
else:
|
||||||
|
timestamp = self.last_active_timestamp + timedelta(seconds=time)
|
||||||
|
# pylint: disable-next=unsubscriptable-object
|
||||||
|
return self.velocity_filter.predict(timestamp)["keypoints"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def velocity(self) -> Float[Array, "J 3"]:
|
||||||
|
"""
|
||||||
|
The velocity of the tracking for each keypoint
|
||||||
|
"""
|
||||||
|
# pylint: disable-next=unsubscriptable-object
|
||||||
|
return self.velocity_filter.get()["velocity"]
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
|
|||||||
@ -37,7 +37,6 @@ import awkward as ak
|
|||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import orjson
|
|
||||||
from beartype import beartype
|
from beartype import beartype
|
||||||
from beartype.typing import Mapping, Sequence
|
from beartype.typing import Mapping, Sequence
|
||||||
from cv2 import undistortPoints
|
from cv2 import undistortPoints
|
||||||
@ -46,7 +45,7 @@ from jaxtyping import Array, Float, Num, jaxtyped
|
|||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from numpy.typing import ArrayLike
|
from numpy.typing import ArrayLike
|
||||||
from optax.assignment import hungarian_algorithm as linear_sum_assignment
|
from optax.assignment import hungarian_algorithm as linear_sum_assignment
|
||||||
from pyrsistent import v, pvector
|
from pyrsistent import pvector, v
|
||||||
from scipy.spatial.transform import Rotation as R
|
from scipy.spatial.transform import Rotation as R
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
@ -59,15 +58,15 @@ from app.camera import (
|
|||||||
classify_by_camera,
|
classify_by_camera,
|
||||||
)
|
)
|
||||||
from app.solver._old import GLPKSolver
|
from app.solver._old import GLPKSolver
|
||||||
from app.tracking import AffinityResult, Tracking
|
from app.tracking import AffinityResult, LastDifferenceVelocityFilter, Tracking
|
||||||
from app.visualize.whole_body import visualize_whole_body
|
from app.visualize.whole_body import visualize_whole_body
|
||||||
|
|
||||||
NDArray: TypeAlias = np.ndarray
|
NDArray: TypeAlias = np.ndarray
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
DATASET_PATH = Path("samples") / "04_02"
|
DATASET_PATH = Path("samples") / "04_02"
|
||||||
AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet")
|
AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet") # type: ignore
|
||||||
DELTA_T_MIN = timedelta(milliseconds=10)
|
DELTA_T_MIN = timedelta(milliseconds=1)
|
||||||
display(AK_CAMERA_DATASET)
|
display(AK_CAMERA_DATASET)
|
||||||
|
|
||||||
|
|
||||||
@ -549,6 +548,7 @@ class GlobalTrackingState:
|
|||||||
keypoints=kps_3d,
|
keypoints=kps_3d,
|
||||||
last_active_timestamp=latest_timestamp,
|
last_active_timestamp=latest_timestamp,
|
||||||
historical_detections=v(*cluster),
|
historical_detections=v(*cluster),
|
||||||
|
velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
|
||||||
)
|
)
|
||||||
self._trackings[next_id] = tracking
|
self._trackings[next_id] = tracking
|
||||||
self._last_id = next_id
|
self._last_id = next_id
|
||||||
@ -673,9 +673,7 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
|
|||||||
camera = detection.camera
|
camera = detection.camera
|
||||||
# Use the delta_t supplied by the caller, but clamp to DELTA_T_MIN to
|
# Use the delta_t supplied by the caller, but clamp to DELTA_T_MIN to
|
||||||
# avoid division-by-zero / exploding affinities.
|
# avoid division-by-zero / exploding affinities.
|
||||||
delta_t = max(delta_t, DELTA_T_MIN)
|
predicted_pose = tracking.predict(max(delta_t, DELTA_T_MIN))
|
||||||
delta_t_s = delta_t.total_seconds()
|
|
||||||
predicted_pose = tracking.predict(delta_t_s)
|
|
||||||
|
|
||||||
# Back-project the 2D points to 3D space
|
# Back-project the 2D points to 3D space
|
||||||
# intersection with z=0 plane
|
# intersection with z=0 plane
|
||||||
|
|||||||
Reference in New Issue
Block a user