feat: Add LeastMeanSquareVelocityFilter for advanced tracking velocity estimation

- Introduced a new `LeastMeanSquareVelocityFilter` class to enhance tracking velocity estimation using historical detections.
- Implemented methods for updating measurements and predicting future states, laying the groundwork for advanced tracking capabilities.
- Improved import organization and added necessary dependencies for the new filter functionality.
- Updated class documentation to reflect the new filter's purpose and methods.
This commit is contained in:
2025-05-02 11:39:01 +08:00
parent c78850855c
commit 4e78165f12

View File

@ -19,8 +19,8 @@ from beartype import beartype
from beartype.typing import Mapping, Sequence from beartype.typing import Mapping, Sequence
from jax import Array from jax import Array
from jaxtyping import Array, Float, Int, jaxtyped from jaxtyping import Array, Float, Int, jaxtyped
from pyrsistent import PVector from pyrsistent import PVector, v
from itertools import chain
from app.camera import Detection from app.camera import Detection
@ -80,7 +80,7 @@ class GenericVelocityFilter(Protocol):
class LastDifferenceVelocityFilter(GenericVelocityFilter): class LastDifferenceVelocityFilter(GenericVelocityFilter):
""" """
a velocity filter that uses the last difference of keypoints a naive velocity filter that uses the last difference of keypoints
""" """
_last_timestamp: datetime _last_timestamp: datetime
@ -128,6 +128,54 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
self._last_velocity = None self._last_velocity = None
class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
_get_historical_detections: Callable[[], Sequence[Detection]]
"""
get the current historical detections, assuming the detections are sorted by
timestamp incrementally (i.e. index 0 is the oldest detection, index -1 is
the newest detection)
"""
_velocity: Optional[Float[Array, "J 3"]] = None
def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
self._get_historical_detections = get_historical_detections
def predict(self, timestamp: datetime) -> TrackingPrediction:
raise NotImplementedError
@jaxtyped(typechecker=beartype)
def _update(
self,
keypoints: Float[Array, "N J 3"],
timestamps: Float[Array, "N"],
) -> None:
"""
update measurements with least mean square method
"""
raise NotImplementedError
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
historical_detections = self._get_historical_detections()
t_0 = min(d.timestamp for d in historical_detections)
detections = jnp.array(
chain((d.keypoints for d in historical_detections), (keypoints,))
)
# timestamps relative to t_0 (the oldest detection timestamp)
timestamps = jnp.array(
chain(
((d.timestamp - t_0).total_seconds() for d in historical_detections),
((timestamp - t_0).total_seconds(),),
)
)
raise NotImplementedError
def get(self) -> TrackingPrediction:
raise NotImplementedError
def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
raise NotImplementedError
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
@dataclass(frozen=True) @dataclass(frozen=True)
class Tracking: class Tracking: