feat: Add LeastMeanSquareVelocityFilter for advanced tracking velocity estimation

- Introduced a new `LeastMeanSquareVelocityFilter` class to enhance tracking velocity estimation using historical detections.
- Implemented methods for updating measurements and predicting future states, laying the groundwork for advanced tracking capabilities.
- Improved import organization and added necessary dependencies for the new filter functionality.
- Updated class documentation to reflect the new filter's purpose and methods.
This commit is contained in:
2025-05-02 11:39:01 +08:00
parent c78850855c
commit 4e78165f12

View File

@ -19,8 +19,8 @@ from beartype import beartype
from beartype.typing import Mapping, Sequence
from jax import Array
from jaxtyping import Array, Float, Int, jaxtyped
from pyrsistent import PVector
from pyrsistent import PVector, v
from itertools import chain
from app.camera import Detection
@ -80,7 +80,7 @@ class GenericVelocityFilter(Protocol):
class LastDifferenceVelocityFilter(GenericVelocityFilter):
"""
a velocity filter that uses the last difference of keypoints
a naive velocity filter that uses the last difference of keypoints
"""
_last_timestamp: datetime
@ -128,6 +128,54 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
self._last_velocity = None
class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
_get_historical_detections: Callable[[], Sequence[Detection]]
"""
get the current historical detections, assuming the detections are sorted by
timestamp incrementally (i.e. index 0 is the oldest detection, index -1 is
the newest detection)
"""
_velocity: Optional[Float[Array, "J 3"]] = None
def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
self._get_historical_detections = get_historical_detections
def predict(self, timestamp: datetime) -> TrackingPrediction:
raise NotImplementedError
@jaxtyped(typechecker=beartype)
def _update(
self,
keypoints: Float[Array, "N J 3"],
timestamps: Float[Array, "N"],
) -> None:
"""
update measurements with least mean square method
"""
raise NotImplementedError
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
historical_detections = self._get_historical_detections()
t_0 = min(d.timestamp for d in historical_detections)
detections = jnp.array(
chain((d.keypoints for d in historical_detections), (keypoints,))
)
# timestamps relative to t_0 (the oldest detection timestamp)
timestamps = jnp.array(
chain(
((d.timestamp - t_0).total_seconds() for d in historical_detections),
((timestamp - t_0).total_seconds(),),
)
)
raise NotImplementedError
def get(self) -> TrackingPrediction:
raise NotImplementedError
def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
raise NotImplementedError
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking: