diff --git a/app/tracking/__init__.py b/app/tracking/__init__.py index 5ef8f9c..63fc84b 100644 --- a/app/tracking/__init__.py +++ b/app/tracking/__init__.py @@ -19,8 +19,8 @@ from beartype import beartype from beartype.typing import Mapping, Sequence from jax import Array from jaxtyping import Array, Float, Int, jaxtyped -from pyrsistent import PVector - +from pyrsistent import PVector, v +from itertools import chain from app.camera import Detection @@ -80,7 +80,7 @@ class GenericVelocityFilter(Protocol): class LastDifferenceVelocityFilter(GenericVelocityFilter): """ - a velocity filter that uses the last difference of keypoints + a naive velocity filter that uses the last difference of keypoints """ _last_timestamp: datetime @@ -128,6 +128,54 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter): self._last_velocity = None +class LeastMeanSquareVelocityFilter(GenericVelocityFilter): + _get_historical_detections: Callable[[], Sequence[Detection]] + """ + get the current historical detections, assuming the detections are sorted by + timestamp incrementally (i.e. index 0 is the oldest detection, index -1 is + the newest detection) + """ + _velocity: Optional[Float[Array, "J 3"]] = None + + def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]): + self._get_historical_detections = get_historical_detections + + def predict(self, timestamp: datetime) -> TrackingPrediction: + raise NotImplementedError + + @jaxtyped(typechecker=beartype) + def _update( + self, + keypoints: Float[Array, "N J 3"], + timestamps: Float[Array, "N"], + ) -> None: + """ + update measurements with least mean square method + """ + raise NotImplementedError + + def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: + historical_detections = self._get_historical_detections() + t_0 = min(d.timestamp for d in historical_detections) + detections = jnp.array( + chain((d.keypoints for d in historical_detections), (keypoints,)) + ) + # timestamps relative to t_0 (the oldest detection timestamp) + timestamps = jnp.array( + chain( + ((d.timestamp - t_0).total_seconds() for d in historical_detections), + ((timestamp - t_0).total_seconds(),), + ) + ) + raise NotImplementedError + + def get(self) -> TrackingPrediction: + raise NotImplementedError + + def reset(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: + raise NotImplementedError + + @jaxtyped(typechecker=beartype) @dataclass(frozen=True) class Tracking: