feat: Add LeastMeanSquareVelocityFilter for advanced tracking velocity estimation
- Introduced a new `LeastMeanSquareVelocityFilter` class to enhance tracking velocity estimation using historical detections. - Implemented methods for updating measurements and predicting future states, laying the groundwork for advanced tracking capabilities. - Improved import organization and added necessary dependencies for the new filter functionality. - Updated class documentation to reflect the new filter's purpose and methods.
This commit is contained in:
@ -19,8 +19,8 @@ from beartype import beartype
|
||||
from beartype.typing import Mapping, Sequence
|
||||
from jax import Array
|
||||
from jaxtyping import Array, Float, Int, jaxtyped
|
||||
from pyrsistent import PVector
|
||||
|
||||
from pyrsistent import PVector, v
|
||||
from itertools import chain
|
||||
from app.camera import Detection
|
||||
|
||||
|
||||
@ -80,7 +80,7 @@ class GenericVelocityFilter(Protocol):
|
||||
|
||||
class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
||||
"""
|
||||
a velocity filter that uses the last difference of keypoints
|
||||
a naive velocity filter that uses the last difference of keypoints
|
||||
"""
|
||||
|
||||
_last_timestamp: datetime
|
||||
@ -128,6 +128,54 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
||||
self._last_velocity = None
|
||||
|
||||
|
||||
class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
||||
_get_historical_detections: Callable[[], Sequence[Detection]]
|
||||
"""
|
||||
get the current historical detections, assuming the detections are sorted by
|
||||
timestamp incrementally (i.e. index 0 is the oldest detection, index -1 is
|
||||
the newest detection)
|
||||
"""
|
||||
_velocity: Optional[Float[Array, "J 3"]] = None
|
||||
|
||||
def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
|
||||
self._get_historical_detections = get_historical_detections
|
||||
|
||||
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
||||
raise NotImplementedError
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def _update(
|
||||
self,
|
||||
keypoints: Float[Array, "N J 3"],
|
||||
timestamps: Float[Array, "N"],
|
||||
) -> None:
|
||||
"""
|
||||
update measurements with least mean square method
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||
historical_detections = self._get_historical_detections()
|
||||
t_0 = min(d.timestamp for d in historical_detections)
|
||||
detections = jnp.array(
|
||||
chain((d.keypoints for d in historical_detections), (keypoints,))
|
||||
)
|
||||
# timestamps relative to t_0 (the oldest detection timestamp)
|
||||
timestamps = jnp.array(
|
||||
chain(
|
||||
((d.timestamp - t_0).total_seconds() for d in historical_detections),
|
||||
((timestamp - t_0).total_seconds(),),
|
||||
)
|
||||
)
|
||||
raise NotImplementedError
|
||||
|
||||
def get(self) -> TrackingPrediction:
|
||||
raise NotImplementedError
|
||||
|
||||
def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
@dataclass(frozen=True)
|
||||
class Tracking:
|
||||
|
||||
Reference in New Issue
Block a user