forked from HQU-gxy/CVTH3PE
refactor: Enhance tracking state management and velocity filter integration
- Introduced `TrackingState` to encapsulate the state of tracking, improving data organization and immutability. - Updated the `Tracking` class to utilize `TrackingState`, enhancing clarity in state management. - Refactored methods to access keypoints and timestamps through the new state structure, ensuring consistency across the codebase. - Added a `DummyVelocityFilter` for cases where no velocity estimation is needed, improving flexibility in tracking implementations. - Cleaned up imports and improved type hints for better code organization.
This commit is contained in:
@ -1,31 +1,33 @@
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generator,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
TypeAlias,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
Protocol,
|
||||
)
|
||||
from datetime import timedelta
|
||||
|
||||
import jax.numpy as jnp
|
||||
from beartype import beartype
|
||||
from beartype.typing import Mapping, Sequence
|
||||
from jax import Array
|
||||
from jaxtyping import Array, Float, Int, jaxtyped
|
||||
from pyrsistent import PVector, v
|
||||
from itertools import chain
|
||||
|
||||
from app.camera import Detection
|
||||
|
||||
|
||||
class TrackingPrediction(TypedDict):
|
||||
velocity: Float[Array, "J 3"]
|
||||
velocity: Optional[Float[Array, "J 3"]]
|
||||
keypoints: Float[Array, "J 3"]
|
||||
|
||||
|
||||
@ -68,6 +70,31 @@ class GenericVelocityFilter(Protocol):
|
||||
... # pylint: disable=unnecessary-ellipsis
|
||||
|
||||
|
||||
class DummyVelocityFilter(GenericVelocityFilter):
|
||||
"""
|
||||
a dummy velocity filter that does nothing
|
||||
"""
|
||||
|
||||
_keypoints_shape: tuple[int, ...]
|
||||
|
||||
def __init__(self, keypoints: Float[Array, "J 3"]):
|
||||
self._keypoints_shape = keypoints.shape
|
||||
|
||||
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
||||
return TrackingPrediction(
|
||||
velocity=None,
|
||||
keypoints=jnp.zeros(self._keypoints_shape),
|
||||
)
|
||||
|
||||
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: ...
|
||||
|
||||
def get(self) -> TrackingPrediction:
|
||||
return TrackingPrediction(
|
||||
velocity=None,
|
||||
keypoints=jnp.zeros(self._keypoints_shape),
|
||||
)
|
||||
|
||||
|
||||
class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
||||
"""
|
||||
a naive velocity filter that uses the last difference of keypoints
|
||||
@ -85,7 +112,7 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
||||
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
|
||||
if self._last_velocity is None:
|
||||
return TrackingPrediction(
|
||||
velocity=jnp.zeros_like(self._last_keypoints),
|
||||
velocity=None,
|
||||
keypoints=self._last_keypoints,
|
||||
)
|
||||
else:
|
||||
@ -103,7 +130,7 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
||||
def get(self) -> TrackingPrediction:
|
||||
if self._last_velocity is None:
|
||||
return TrackingPrediction(
|
||||
velocity=jnp.zeros_like(self._last_keypoints),
|
||||
velocity=None,
|
||||
keypoints=self._last_keypoints,
|
||||
)
|
||||
else:
|
||||
@ -126,33 +153,42 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
||||
"""
|
||||
_velocity: Optional[Float[Array, "J 3"]] = None
|
||||
|
||||
@staticmethod
|
||||
def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter":
|
||||
"""
|
||||
create a LeastMeanSquareVelocityFilter from a Tracking object
|
||||
"""
|
||||
velocity = tracking.velocity_filter.get()["velocity"]
|
||||
if jnp.all(velocity == jnp.zeros_like(velocity)):
|
||||
return LeastMeanSquareVelocityFilter(
|
||||
get_historical_detections=lambda: tracking.historical_detections
|
||||
)
|
||||
else:
|
||||
f = LeastMeanSquareVelocityFilter(
|
||||
get_historical_detections=lambda: tracking.historical_detections
|
||||
)
|
||||
# pylint: disable-next=protected-access
|
||||
f._velocity = velocity
|
||||
return f
|
||||
|
||||
def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
|
||||
self._get_historical_detections = get_historical_detections
|
||||
self._velocity = None
|
||||
|
||||
@property
|
||||
def velocity(self) -> Float[Array, "J 3"]:
|
||||
if self._velocity is None:
|
||||
raise ValueError("Velocity not initialized")
|
||||
return self._velocity
|
||||
@staticmethod
|
||||
def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter":
|
||||
"""
|
||||
create a LeastMeanSquareVelocityFilter from a Tracking object
|
||||
|
||||
Note that this function is using a weak reference to the tracking object,
|
||||
so that the tracking object can be garbage collected if there are no other
|
||||
references to it.
|
||||
"""
|
||||
# Create a weak reference to avoid circular references
|
||||
# https://docs.python.org/3/library/weakref.html
|
||||
tracking_ref = weakref.ref(tracking)
|
||||
|
||||
# Create a getter function that uses the weak reference
|
||||
def get_historical_detections() -> Sequence[Detection]:
|
||||
tr = tracking_ref()
|
||||
if tr is None:
|
||||
return [] # Return empty list if tracking has been garbage collected
|
||||
return tr.state.historical_detections
|
||||
|
||||
velocity = tracking.velocity_filter.get()["velocity"]
|
||||
if velocity is None:
|
||||
return LeastMeanSquareVelocityFilter(
|
||||
get_historical_detections=get_historical_detections
|
||||
)
|
||||
else:
|
||||
f = LeastMeanSquareVelocityFilter(
|
||||
get_historical_detections=get_historical_detections
|
||||
)
|
||||
# pylint: disable-next=protected-access
|
||||
f._velocity = velocity
|
||||
return f
|
||||
|
||||
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
||||
historical_detections = self._get_historical_detections()
|
||||
@ -168,7 +204,8 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
||||
|
||||
if self._velocity is None:
|
||||
return TrackingPrediction(
|
||||
velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints
|
||||
velocity=None,
|
||||
keypoints=latest_keypoints,
|
||||
)
|
||||
else:
|
||||
# Linear motion model: ẋt = xt' + Vt' · (t - t')
|
||||
@ -252,9 +289,7 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
||||
latest_keypoints = latest_detection.keypoints
|
||||
|
||||
if self._velocity is None:
|
||||
return TrackingPrediction(
|
||||
velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints
|
||||
)
|
||||
return TrackingPrediction(velocity=None, keypoints=latest_keypoints)
|
||||
else:
|
||||
return TrackingPrediction(
|
||||
velocity=self._velocity, keypoints=latest_keypoints
|
||||
@ -263,11 +298,11 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
@dataclass(frozen=True)
|
||||
class Tracking:
|
||||
id: int
|
||||
class TrackingState:
|
||||
"""
|
||||
The tracking id
|
||||
immutable state of a tracking
|
||||
"""
|
||||
|
||||
keypoints: Float[Array, "J 3"]
|
||||
"""
|
||||
The 3D keypoints of the tracking
|
||||
@ -286,13 +321,24 @@ class Tracking:
|
||||
Used for 3D re-triangulation
|
||||
"""
|
||||
|
||||
|
||||
class Tracking:
|
||||
id: int
|
||||
state: TrackingState
|
||||
velocity_filter: GenericVelocityFilter
|
||||
"""
|
||||
The velocity filter of the tracking
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: int,
|
||||
state: TrackingState,
|
||||
velocity_filter: Optional[GenericVelocityFilter] = None,
|
||||
):
|
||||
self.id = id
|
||||
self.state = state
|
||||
self.velocity_filter = velocity_filter or DummyVelocityFilter(state.keypoints)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Tracking({self.id}, {self.last_active_timestamp})"
|
||||
return f"Tracking({self.id}, {self.state.last_active_timestamp})"
|
||||
|
||||
@overload
|
||||
def predict(self, time: float) -> Float[Array, "J 3"]:
|
||||
@ -332,11 +378,11 @@ class Tracking:
|
||||
time: float | timedelta | datetime,
|
||||
) -> Float[Array, "J 3"]:
|
||||
if isinstance(time, timedelta):
|
||||
timestamp = self.last_active_timestamp + time
|
||||
timestamp = self.state.last_active_timestamp + time
|
||||
elif isinstance(time, datetime):
|
||||
timestamp = time
|
||||
else:
|
||||
timestamp = self.last_active_timestamp + timedelta(seconds=time)
|
||||
timestamp = self.state.last_active_timestamp + timedelta(seconds=time)
|
||||
# pylint: disable-next=unsubscriptable-object
|
||||
return self.velocity_filter.predict(timestamp)["keypoints"]
|
||||
|
||||
@ -346,7 +392,10 @@ class Tracking:
|
||||
The velocity of the tracking for each keypoint
|
||||
"""
|
||||
# pylint: disable-next=unsubscriptable-object
|
||||
return self.velocity_filter.get()["velocity"]
|
||||
if (vel := self.velocity_filter.get()["velocity"]) is None:
|
||||
raise ValueError("Velocity is not available")
|
||||
else:
|
||||
return vel
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
|
||||
Reference in New Issue
Block a user