refactor: Enhance tracking state management and velocity filter integration
- Introduced `TrackingState` to encapsulate the state of tracking, improving data organization and immutability.
- Updated the `Tracking` class to utilize `TrackingState`, enhancing clarity in state management.
- Refactored methods to access keypoints and timestamps through the new state structure, ensuring consistency across the codebase.
- Added a `DummyVelocityFilter` for cases where no velocity estimation is needed, improving flexibility in tracking implementations.
- Cleaned up imports and improved type hints for better code organization.
This commit is contained in:
@@ -1,31 +1,33 @@
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generator,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
TypeAlias,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
Protocol,
|
||||
)
|
||||
from datetime import timedelta
|
||||
|
||||
import jax.numpy as jnp
|
||||
from beartype import beartype
|
||||
from beartype.typing import Mapping, Sequence
|
||||
from jax import Array
|
||||
from jaxtyping import Array, Float, Int, jaxtyped
|
||||
from pyrsistent import PVector, v
|
||||
from itertools import chain
|
||||
|
||||
from app.camera import Detection
|
||||
|
||||
|
||||
class TrackingPrediction(TypedDict):
|
||||
velocity: Float[Array, "J 3"]
|
||||
velocity: Optional[Float[Array, "J 3"]]
|
||||
keypoints: Float[Array, "J 3"]
|
||||
|
||||
|
||||
@@ -68,6 +70,31 @@ class GenericVelocityFilter(Protocol):
|
||||
... # pylint: disable=unnecessary-ellipsis
|
||||
|
||||
|
||||
class DummyVelocityFilter(GenericVelocityFilter):
|
||||
"""
|
||||
a dummy velocity filter that does nothing
|
||||
"""
|
||||
|
||||
_keypoints_shape: tuple[int, ...]
|
||||
|
||||
def __init__(self, keypoints: Float[Array, "J 3"]):
|
||||
self._keypoints_shape = keypoints.shape
|
||||
|
||||
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
||||
return TrackingPrediction(
|
||||
velocity=None,
|
||||
keypoints=jnp.zeros(self._keypoints_shape),
|
||||
)
|
||||
|
||||
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: ...
|
||||
|
||||
def get(self) -> TrackingPrediction:
|
||||
return TrackingPrediction(
|
||||
velocity=None,
|
||||
keypoints=jnp.zeros(self._keypoints_shape),
|
||||
)
|
||||
|
||||
|
||||
class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
||||
"""
|
||||
a naive velocity filter that uses the last difference of keypoints
|
||||
@@ -85,7 +112,7 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
||||
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
|
||||
if self._last_velocity is None:
|
||||
return TrackingPrediction(
|
||||
velocity=jnp.zeros_like(self._last_keypoints),
|
||||
velocity=None,
|
||||
keypoints=self._last_keypoints,
|
||||
)
|
||||
else:
|
||||
@@ -103,7 +130,7 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
||||
def get(self) -> TrackingPrediction:
|
||||
if self._last_velocity is None:
|
||||
return TrackingPrediction(
|
||||
velocity=jnp.zeros_like(self._last_keypoints),
|
||||
velocity=None,
|
||||
keypoints=self._last_keypoints,
|
||||
)
|
||||
else:
|
||||
@@ -126,33 +153,42 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
||||
"""
|
||||
_velocity: Optional[Float[Array, "J 3"]] = None
|
||||
|
||||
@staticmethod
|
||||
def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter":
|
||||
"""
|
||||
create a LeastMeanSquareVelocityFilter from a Tracking object
|
||||
"""
|
||||
velocity = tracking.velocity_filter.get()["velocity"]
|
||||
if jnp.all(velocity == jnp.zeros_like(velocity)):
|
||||
return LeastMeanSquareVelocityFilter(
|
||||
get_historical_detections=lambda: tracking.historical_detections
|
||||
)
|
||||
else:
|
||||
f = LeastMeanSquareVelocityFilter(
|
||||
get_historical_detections=lambda: tracking.historical_detections
|
||||
)
|
||||
# pylint: disable-next=protected-access
|
||||
f._velocity = velocity
|
||||
return f
|
||||
|
||||
def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
|
||||
self._get_historical_detections = get_historical_detections
|
||||
self._velocity = None
|
||||
|
||||
@property
|
||||
def velocity(self) -> Float[Array, "J 3"]:
|
||||
if self._velocity is None:
|
||||
raise ValueError("Velocity not initialized")
|
||||
return self._velocity
|
||||
@staticmethod
|
||||
def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter":
|
||||
"""
|
||||
create a LeastMeanSquareVelocityFilter from a Tracking object
|
||||
|
||||
Note that this function is using a weak reference to the tracking object,
|
||||
so that the tracking object can be garbage collected if there are no other
|
||||
references to it.
|
||||
"""
|
||||
# Create a weak reference to avoid circular references
|
||||
# https://docs.python.org/3/library/weakref.html
|
||||
tracking_ref = weakref.ref(tracking)
|
||||
|
||||
# Create a getter function that uses the weak reference
|
||||
def get_historical_detections() -> Sequence[Detection]:
|
||||
tr = tracking_ref()
|
||||
if tr is None:
|
||||
return [] # Return empty list if tracking has been garbage collected
|
||||
return tr.state.historical_detections
|
||||
|
||||
velocity = tracking.velocity_filter.get()["velocity"]
|
||||
if velocity is None:
|
||||
return LeastMeanSquareVelocityFilter(
|
||||
get_historical_detections=get_historical_detections
|
||||
)
|
||||
else:
|
||||
f = LeastMeanSquareVelocityFilter(
|
||||
get_historical_detections=get_historical_detections
|
||||
)
|
||||
# pylint: disable-next=protected-access
|
||||
f._velocity = velocity
|
||||
return f
|
||||
|
||||
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
||||
historical_detections = self._get_historical_detections()
|
||||
@@ -168,7 +204,8 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
||||
|
||||
if self._velocity is None:
|
||||
return TrackingPrediction(
|
||||
velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints
|
||||
velocity=None,
|
||||
keypoints=latest_keypoints,
|
||||
)
|
||||
else:
|
||||
# Linear motion model: ẋt = xt' + Vt' · (t - t')
|
||||
@@ -252,9 +289,7 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
||||
latest_keypoints = latest_detection.keypoints
|
||||
|
||||
if self._velocity is None:
|
||||
return TrackingPrediction(
|
||||
velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints
|
||||
)
|
||||
return TrackingPrediction(velocity=None, keypoints=latest_keypoints)
|
||||
else:
|
||||
return TrackingPrediction(
|
||||
velocity=self._velocity, keypoints=latest_keypoints
|
||||
@@ -263,11 +298,11 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
@dataclass(frozen=True)
|
||||
class Tracking:
|
||||
id: int
|
||||
class TrackingState:
|
||||
"""
|
||||
The tracking id
|
||||
immutable state of a tracking
|
||||
"""
|
||||
|
||||
keypoints: Float[Array, "J 3"]
|
||||
"""
|
||||
The 3D keypoints of the tracking
|
||||
@@ -286,13 +321,24 @@ class Tracking:
|
||||
Used for 3D re-triangulation
|
||||
"""
|
||||
|
||||
|
||||
class Tracking:
|
||||
id: int
|
||||
state: TrackingState
|
||||
velocity_filter: GenericVelocityFilter
|
||||
"""
|
||||
The velocity filter of the tracking
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: int,
|
||||
state: TrackingState,
|
||||
velocity_filter: Optional[GenericVelocityFilter] = None,
|
||||
):
|
||||
self.id = id
|
||||
self.state = state
|
||||
self.velocity_filter = velocity_filter or DummyVelocityFilter(state.keypoints)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Tracking({self.id}, {self.last_active_timestamp})"
|
||||
return f"Tracking({self.id}, {self.state.last_active_timestamp})"
|
||||
|
||||
@overload
|
||||
def predict(self, time: float) -> Float[Array, "J 3"]:
|
||||
@@ -332,11 +378,11 @@ class Tracking:
|
||||
time: float | timedelta | datetime,
|
||||
) -> Float[Array, "J 3"]:
|
||||
if isinstance(time, timedelta):
|
||||
timestamp = self.last_active_timestamp + time
|
||||
timestamp = self.state.last_active_timestamp + time
|
||||
elif isinstance(time, datetime):
|
||||
timestamp = time
|
||||
else:
|
||||
timestamp = self.last_active_timestamp + timedelta(seconds=time)
|
||||
timestamp = self.state.last_active_timestamp + timedelta(seconds=time)
|
||||
# pylint: disable-next=unsubscriptable-object
|
||||
return self.velocity_filter.predict(timestamp)["keypoints"]
|
||||
|
||||
@@ -346,7 +392,10 @@ class Tracking:
|
||||
The velocity of the tracking for each keypoint
|
||||
"""
|
||||
# pylint: disable-next=unsubscriptable-object
|
||||
return self.velocity_filter.get()["velocity"]
|
||||
if (vel := self.velocity_filter.get()["velocity"]) is None:
|
||||
raise ValueError("Velocity is not available")
|
||||
else:
|
||||
return vel
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
|
||||
@@ -58,7 +58,12 @@ from app.camera import (
|
||||
classify_by_camera,
|
||||
)
|
||||
from app.solver._old import GLPKSolver
|
||||
from app.tracking import AffinityResult, LastDifferenceVelocityFilter, Tracking
|
||||
from app.tracking import (
|
||||
AffinityResult,
|
||||
LastDifferenceVelocityFilter,
|
||||
Tracking,
|
||||
TrackingState,
|
||||
)
|
||||
from app.visualize.whole_body import visualize_whole_body
|
||||
|
||||
NDArray: TypeAlias = np.ndarray
|
||||
@@ -543,11 +548,14 @@ class GlobalTrackingState:
|
||||
def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
|
||||
kps_3d, latest_timestamp = triangle_from_cluster(cluster)
|
||||
next_id = self._last_id + 1
|
||||
tracking = Tracking(
|
||||
id=next_id,
|
||||
tracking_state = TrackingState(
|
||||
keypoints=kps_3d,
|
||||
last_active_timestamp=latest_timestamp,
|
||||
historical_detections=v(*cluster),
|
||||
)
|
||||
tracking = Tracking(
|
||||
id=next_id,
|
||||
state=tracking_state,
|
||||
velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
|
||||
)
|
||||
self._trackings[next_id] = tracking
|
||||
@@ -753,12 +761,12 @@ def calculate_tracking_detection_affinity(
|
||||
Combined affinity score
|
||||
"""
|
||||
camera = detection.camera
|
||||
delta_t_raw = detection.timestamp - tracking.last_active_timestamp
|
||||
delta_t_raw = detection.timestamp - tracking.state.last_active_timestamp
|
||||
# Clamp delta_t to avoid division-by-zero / exploding affinity.
|
||||
delta_t = max(delta_t_raw, DELTA_T_MIN)
|
||||
|
||||
# Calculate 2D affinity
|
||||
tracking_2d_projection = camera.project(tracking.keypoints)
|
||||
tracking_2d_projection = camera.project(tracking.state.keypoints)
|
||||
w, h = camera.params.image_size
|
||||
distance_2d = calculate_distance_2d(
|
||||
tracking_2d_projection,
|
||||
@@ -838,7 +846,7 @@ def calculate_camera_affinity_matrix_jax(
|
||||
|
||||
# === Tracking-side tensors ===
|
||||
kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
|
||||
[trk.keypoints for trk in trackings]
|
||||
[trk.state.keypoints for trk in trackings]
|
||||
) # (T, J, 3)
|
||||
J = kps3d_trk.shape[1]
|
||||
# === Detection-side tensors ===
|
||||
@@ -855,12 +863,12 @@ def calculate_camera_affinity_matrix_jax(
|
||||
# --- timestamps ----------
|
||||
t0 = min(
|
||||
chain(
|
||||
(trk.last_active_timestamp for trk in trackings),
|
||||
(trk.state.last_active_timestamp for trk in trackings),
|
||||
(det.timestamp for det in camera_detections),
|
||||
)
|
||||
).timestamp() # common origin (float)
|
||||
ts_trk = jnp.array(
|
||||
[trk.last_active_timestamp.timestamp() - t0 for trk in trackings],
|
||||
[trk.state.last_active_timestamp.timestamp() - t0 for trk in trackings],
|
||||
dtype=jnp.float32, # now small, ms-scale fits in fp32
|
||||
)
|
||||
ts_det = jnp.array(
|
||||
@@ -1032,7 +1040,5 @@ display(affinities)
|
||||
|
||||
# %%
|
||||
def update_tracking(tracking: Tracking, detection: Detection):
|
||||
delta_t_ = detection.timestamp - tracking.last_active_timestamp
|
||||
delta_t = max(delta_t_, DELTA_T_MIN)
|
||||
|
||||
return tracking
|
||||
delta_t_ = detection.timestamp - tracking.state.last_active_timestamp
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user