feat: Add LeastMeanSquareVelocityFilter for advanced tracking velocity estimation
- Introduced a new `LeastMeanSquareVelocityFilter` class to enhance tracking velocity estimation using historical detections. - Implemented methods for updating measurements and predicting future states, laying the groundwork for advanced tracking capabilities. - Improved import organization and added necessary dependencies for the new filter functionality. - Updated class documentation to reflect the new filter's purpose and methods.
This commit is contained in:
@ -19,8 +19,8 @@ from beartype import beartype
|
|||||||
from beartype.typing import Mapping, Sequence
|
from beartype.typing import Mapping, Sequence
|
||||||
from jax import Array
|
from jax import Array
|
||||||
from jaxtyping import Array, Float, Int, jaxtyped
|
from jaxtyping import Array, Float, Int, jaxtyped
|
||||||
from pyrsistent import PVector
|
from pyrsistent import PVector, v
|
||||||
|
from itertools import chain
|
||||||
from app.camera import Detection
|
from app.camera import Detection
|
||||||
|
|
||||||
|
|
||||||
@ -80,7 +80,7 @@ class GenericVelocityFilter(Protocol):
|
|||||||
|
|
||||||
class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
||||||
"""
|
"""
|
||||||
a velocity filter that uses the last difference of keypoints
|
a naive velocity filter that uses the last difference of keypoints
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_last_timestamp: datetime
|
_last_timestamp: datetime
|
||||||
@ -128,6 +128,54 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
|||||||
self._last_velocity = None
|
self._last_velocity = None
|
||||||
|
|
||||||
|
|
||||||
|
class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
||||||
|
_get_historical_detections: Callable[[], Sequence[Detection]]
|
||||||
|
"""
|
||||||
|
get the current historical detections, assuming the detections are sorted by
|
||||||
|
timestamp incrementally (i.e. index 0 is the oldest detection, index -1 is
|
||||||
|
the newest detection)
|
||||||
|
"""
|
||||||
|
_velocity: Optional[Float[Array, "J 3"]] = None
|
||||||
|
|
||||||
|
def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
|
||||||
|
self._get_historical_detections = get_historical_detections
|
||||||
|
|
||||||
|
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@jaxtyped(typechecker=beartype)
|
||||||
|
def _update(
|
||||||
|
self,
|
||||||
|
keypoints: Float[Array, "N J 3"],
|
||||||
|
timestamps: Float[Array, "N"],
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
update measurements with least mean square method
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||||
|
historical_detections = self._get_historical_detections()
|
||||||
|
t_0 = min(d.timestamp for d in historical_detections)
|
||||||
|
detections = jnp.array(
|
||||||
|
chain((d.keypoints for d in historical_detections), (keypoints,))
|
||||||
|
)
|
||||||
|
# timestamps relative to t_0 (the oldest detection timestamp)
|
||||||
|
timestamps = jnp.array(
|
||||||
|
chain(
|
||||||
|
((d.timestamp - t_0).total_seconds() for d in historical_detections),
|
||||||
|
((timestamp - t_0).total_seconds(),),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get(self) -> TrackingPrediction:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Tracking:
|
class Tracking:
|
||||||
|
|||||||
Reference in New Issue
Block a user