- Added a new `LastDifferenceVelocityFilter` class that estimates tracking velocities from the difference between the last two observed sets of keypoints, enhancing the tracking capabilities.
- Updated the `Tracking` class to use the new velocity filter, allowing more accurate predictions of keypoints over time.
- Refactored the `predict` method to accept several input types (float, timedelta, datetime) for more flexible time handling.
- Improved timestamp handling in the `perpendicular_distance_camera_2d_points_to_tracking_raycasting` function to enforce the minimum delta-time constraint.
- Cleaned up imports and type hints for better organization and clarity across the codebase.
244 lines
6.8 KiB
Python
244 lines
6.8 KiB
Python
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Generator,
|
|
Optional,
|
|
Sequence,
|
|
TypeAlias,
|
|
TypedDict,
|
|
TypeVar,
|
|
cast,
|
|
overload,
|
|
Protocol,
|
|
)
|
|
from datetime import timedelta
|
|
import jax.numpy as jnp
|
|
from beartype import beartype
|
|
from beartype.typing import Mapping, Sequence
|
|
from jax import Array
|
|
from jaxtyping import Array, Float, Int, jaxtyped
|
|
from pyrsistent import PVector
|
|
|
|
from app.camera import Detection
|
|
|
|
|
|
class TrackingPrediction(TypedDict):
    """
    Output of a velocity filter: per-keypoint velocity and position.

    Both arrays have shape (J, 3), i.e. one row per keypoint
    (presumably x/y/z components — confirm against callers).
    """

    # estimated velocity of each keypoint
    velocity: Float[Array, "J 3"]
    # estimated position of each keypoint
    keypoints: Float[Array, "J 3"]
|
|
|
|
|
|
class GenericVelocityFilter(Protocol):
    """
    Structural interface for filters that estimate the velocity of a tracking.

    Implementations keep an internal state built from timestamped keypoint
    measurements and expose it as a ``TrackingPrediction``.
    """

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        Estimate velocity and keypoint positions at ``timestamp``.

        Args:
            timestamp: the moment the estimate should refer to

        Returns:
            a ``TrackingPrediction`` holding the estimated ``velocity`` and
            ``keypoints`` of the tracking
        """

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        Feed a new measurement into the filter state.

        Args:
            keypoints: the newly measured keypoints
            timestamp: when the measurement was taken
        """

    def get(self) -> TrackingPrediction:
        """
        Read back the filter's current state.

        Returns:
            a ``TrackingPrediction`` holding the current ``velocity`` and
            ``keypoints`` of the tracking
        """

    def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        Discard the filter state and restart it from the given keypoints.

        Args:
            keypoints: keypoints to restart from
            timestamp: time associated with ``keypoints``
        """
|
|
|
|
|
|
class LastDifferenceVelocityFilter(GenericVelocityFilter):
    """
    a velocity filter that uses the last difference of keypoints

    The velocity is the finite difference between the stored keypoints and
    the keypoints passed to ``update``, divided by the elapsed seconds.
    Until such a difference has been computed, the velocity is reported as
    zero and predictions return the stored keypoints unchanged.
    """

    _last_timestamp: datetime
    _last_keypoints: Float[Array, "J 3"]
    # None until ``update`` has produced a velocity estimate
    _last_velocity: Optional[Float[Array, "J 3"]] = None

    def __init__(self, keypoints: Float[Array, "J 3"], timestamp: datetime):
        """
        Start the filter from a single measurement (no velocity yet).

        Args:
            keypoints: initial keypoints
            timestamp: time of the initial keypoints
        """
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp
        self._last_velocity = None

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        Linearly extrapolate the stored keypoints to ``timestamp``.

        Args:
            timestamp: time to predict at; may be before or after the
                stored timestamp

        Returns:
            the current velocity and ``last_keypoints + velocity * dt``,
            where ``dt`` is seconds since the stored timestamp; zero
            velocity and the stored keypoints if no velocity is known yet
        """
        delta_t_s = (timestamp - self._last_timestamp).total_seconds()
        if self._last_velocity is None:
            return TrackingPrediction(
                velocity=jnp.zeros_like(self._last_keypoints),
                keypoints=self._last_keypoints,
            )
        return TrackingPrediction(
            velocity=self._last_velocity,
            keypoints=self._last_keypoints + self._last_velocity * delta_t_s,
        )

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        Replace the stored keypoints and re-estimate the velocity.

        Args:
            keypoints: newly measured keypoints
            timestamp: time of the measurement
        """
        delta_t_s = (timestamp - self._last_timestamp).total_seconds()
        if delta_t_s == 0:
            # Same instant as the stored state: dividing by a zero time step
            # would fill the velocity with inf/nan and poison every later
            # prediction. Keep the previous velocity estimate (possibly None)
            # and just adopt the newer keypoints.
            self._last_keypoints = keypoints
            return
        self._last_velocity = (keypoints - self._last_keypoints) / delta_t_s
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp

    def get(self) -> TrackingPrediction:
        """
        Return the stored state without extrapolation.

        Returns:
            the current velocity (zeros if none is known yet) and the stored
            keypoints
        """
        if self._last_velocity is None:
            return TrackingPrediction(
                velocity=jnp.zeros_like(self._last_keypoints),
                keypoints=self._last_keypoints,
            )
        return TrackingPrediction(
            velocity=self._last_velocity,
            keypoints=self._last_keypoints,
        )

    def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        Restart the filter from ``keypoints``, forgetting any velocity.

        Args:
            keypoints: keypoints to restart from
            timestamp: time of ``keypoints``
        """
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp
        self._last_velocity = None
|
|
|
|
|
|
@jaxtyped(typechecker=beartype)
|
|
@dataclass(frozen=True)
|
|
class Tracking:
|
|
id: int
|
|
"""
|
|
The tracking id
|
|
"""
|
|
keypoints: Float[Array, "J 3"]
|
|
"""
|
|
The 3D keypoints of the tracking
|
|
|
|
Used for calculate affinity 3D
|
|
"""
|
|
last_active_timestamp: datetime
|
|
"""
|
|
The last active timestamp of the tracking
|
|
"""
|
|
|
|
historical_detections: PVector[Detection]
|
|
"""
|
|
Historical detections of the tracking.
|
|
|
|
Used for 3D re-triangulation
|
|
"""
|
|
|
|
velocity_filter: GenericVelocityFilter
|
|
"""
|
|
The velocity filter of the tracking
|
|
"""
|
|
|
|
def __repr__(self) -> str:
|
|
return f"Tracking({self.id}, {self.last_active_timestamp})"
|
|
|
|
@overload
|
|
def predict(self, time: float) -> Float[Array, "J 3"]:
|
|
"""
|
|
predict the keypoints at a given time
|
|
|
|
Args:
|
|
time: the time in seconds to predict the keypoints
|
|
|
|
Returns:
|
|
the predicted keypoints
|
|
"""
|
|
... # pylint: disable=unnecessary-ellipsis
|
|
|
|
@overload
|
|
def predict(self, time: timedelta) -> Float[Array, "J 3"]:
|
|
"""
|
|
predict the keypoints at a given time
|
|
|
|
Args:
|
|
time: the time delta to predict the keypoints
|
|
"""
|
|
... # pylint: disable=unnecessary-ellipsis
|
|
|
|
@overload
|
|
def predict(self, time: datetime) -> Float[Array, "J 3"]:
|
|
"""
|
|
predict the keypoints at a given time
|
|
|
|
Args:
|
|
time: the timestamp to predict the keypoints
|
|
"""
|
|
... # pylint: disable=unnecessary-ellipsis
|
|
|
|
def predict(
|
|
self,
|
|
time: float | timedelta | datetime,
|
|
) -> Float[Array, "J 3"]:
|
|
if isinstance(time, timedelta):
|
|
timestamp = self.last_active_timestamp + time
|
|
elif isinstance(time, datetime):
|
|
timestamp = time
|
|
else:
|
|
timestamp = self.last_active_timestamp + timedelta(seconds=time)
|
|
# pylint: disable-next=unsubscriptable-object
|
|
return self.velocity_filter.predict(timestamp)["keypoints"]
|
|
|
|
@property
|
|
def velocity(self) -> Float[Array, "J 3"]:
|
|
"""
|
|
The velocity of the tracking for each keypoint
|
|
"""
|
|
# pylint: disable-next=unsubscriptable-object
|
|
return self.velocity_filter.get()["velocity"]
|
|
|
|
|
|
@jaxtyped(typechecker=beartype)
@dataclass
class AffinityResult:
    """
    Result of affinity computation between trackings and detections.

    ``matrix[t, d]`` is the affinity of ``trackings[t]`` with
    ``detections[d]``; ``indices_T[k]`` and ``indices_D[k]`` together form
    the k-th matched (tracking, detection) pair.
    """

    matrix: Float[Array, "T D"]
    trackings: Sequence[Tracking]
    detections: Sequence[Detection]
    # NOTE(review): as I understand jaxtyping's dataclass checking, the names
    # "T" and "D" below bind to the same sizes as ``matrix``'s rows/columns, so
    # ``indices_T`` must have length T and ``indices_D`` length D. Matched
    # pairs (as consumed by ``tracking_detections``) usually share one length
    # such as min(T, D); confirm whether a separate dimension name is intended.
    indices_T: Int[Array, "T"]  # pylint: disable=invalid-name
    indices_D: Int[Array, "D"]  # pylint: disable=invalid-name

    def tracking_detections(
        self,
    ) -> Generator[tuple[float, Tracking, Detection], None, None]:
        """
        iterate over the best matching trackings and detections

        Yields:
            ``(affinity, tracking, detection)`` for each index pair
            ``(indices_T[k], indices_D[k])``; pairing stops at the shorter of
            the two index arrays (``zip`` semantics).
        """
        # Convert the index arrays to Python ints once. Iterating a jax array
        # directly yields a separate 0-d device array per element, each of
        # which then has to be converted again to index the Python sequences.
        rows = self.indices_T.tolist()
        cols = self.indices_D.tolist()
        for t, d in zip(rows, cols):
            yield (
                self.matrix[t, d].item(),
                self.trackings[t],
                self.detections[d],
            )
|