diff --git a/app/tracking/__init__.py b/app/tracking/__init__.py index 41e0266..2991aee 100644 --- a/app/tracking/__init__.py +++ b/app/tracking/__init__.py @@ -1,31 +1,33 @@ +import weakref from dataclasses import dataclass -from datetime import datetime +from datetime import datetime, timedelta +from itertools import chain from typing import ( Any, Callable, Generator, Optional, + Protocol, Sequence, TypeAlias, TypedDict, TypeVar, cast, overload, - Protocol, ) -from datetime import timedelta + import jax.numpy as jnp from beartype import beartype from beartype.typing import Mapping, Sequence from jax import Array from jaxtyping import Array, Float, Int, jaxtyped from pyrsistent import PVector, v -from itertools import chain + from app.camera import Detection class TrackingPrediction(TypedDict): - velocity: Float[Array, "J 3"] + velocity: Optional[Float[Array, "J 3"]] keypoints: Float[Array, "J 3"] @@ -68,6 +70,31 @@ class GenericVelocityFilter(Protocol): ... # pylint: disable=unnecessary-ellipsis +class DummyVelocityFilter(GenericVelocityFilter): + """ + a dummy velocity filter that does nothing + """ + + _keypoints_shape: tuple[int, ...] + + def __init__(self, keypoints: Float[Array, "J 3"]): + self._keypoints_shape = keypoints.shape + + def predict(self, timestamp: datetime) -> TrackingPrediction: + return TrackingPrediction( + velocity=None, + keypoints=jnp.zeros(self._keypoints_shape), + ) + + def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: ... + + def get(self) -> TrackingPrediction: + return TrackingPrediction( + velocity=None, + keypoints=jnp.zeros(self._keypoints_shape), + ) + + class LastDifferenceVelocityFilter(GenericVelocityFilter): """ a naive velocity filter that uses the last difference of keypoints @@ -85,7 +112,7 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter): delta_t_s = (timestamp - self._last_timestamp).total_seconds() if self._last_velocity is None: return TrackingPrediction( - velocity=jnp.zeros_like(self._last_keypoints), + velocity=None, keypoints=self._last_keypoints, ) else: @@ -103,7 +130,7 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter): def get(self) -> TrackingPrediction: if self._last_velocity is None: return TrackingPrediction( - velocity=jnp.zeros_like(self._last_keypoints), + velocity=None, keypoints=self._last_keypoints, ) else: @@ -126,33 +153,42 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter): """ _velocity: Optional[Float[Array, "J 3"]] = None - @staticmethod - def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter": - """ - create a LeastMeanSquareVelocityFilter from a Tracking object - """ - velocity = tracking.velocity_filter.get()["velocity"] - if jnp.all(velocity == jnp.zeros_like(velocity)): - return LeastMeanSquareVelocityFilter( - get_historical_detections=lambda: tracking.historical_detections - ) - else: - f = LeastMeanSquareVelocityFilter( - get_historical_detections=lambda: tracking.historical_detections - ) - # pylint: disable-next=protected-access - f._velocity = velocity - return f - def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]): self._get_historical_detections = get_historical_detections self._velocity = None - @property - def velocity(self) -> Float[Array, "J 3"]: - if self._velocity is None: - raise ValueError("Velocity not initialized") - return self._velocity + @staticmethod + def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter": + """ + create a LeastMeanSquareVelocityFilter from a Tracking object + + Note that this function is using a weak reference to the tracking object, + so that the tracking object can be garbage collected if there are no other + references to it. + """ + # Create a weak reference to avoid circular references + # https://docs.python.org/3/library/weakref.html + tracking_ref = weakref.ref(tracking) + + # Create a getter function that uses the weak reference + def get_historical_detections() -> Sequence[Detection]: + tr = tracking_ref() + if tr is None: + return [] # Return empty list if tracking has been garbage collected + return tr.state.historical_detections + + velocity = tracking.velocity_filter.get()["velocity"] + if velocity is None: + return LeastMeanSquareVelocityFilter( + get_historical_detections=get_historical_detections + ) + else: + f = LeastMeanSquareVelocityFilter( + get_historical_detections=get_historical_detections + ) + # pylint: disable-next=protected-access + f._velocity = velocity + return f def predict(self, timestamp: datetime) -> TrackingPrediction: historical_detections = self._get_historical_detections() @@ -168,7 +204,8 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter): if self._velocity is None: return TrackingPrediction( - velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints + velocity=None, + keypoints=latest_keypoints, ) else: # Linear motion model: ẋt = xt' + Vt' · (t - t') @@ -252,9 +289,7 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter): latest_keypoints = latest_detection.keypoints if self._velocity is None: - return TrackingPrediction( - velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints - ) + return TrackingPrediction(velocity=None, keypoints=latest_keypoints) else: return TrackingPrediction( velocity=self._velocity, keypoints=latest_keypoints @@ -263,11 +298,11 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter): @jaxtyped(typechecker=beartype) @dataclass(frozen=True) -class Tracking: - id: int +class TrackingState: """ - The tracking id + immutable state of a tracking """ + keypoints: Float[Array, "J 3"] """ The 3D keypoints of the tracking @@ -286,13 +321,24 @@ class Tracking: Used for 3D re-triangulation """ + +class Tracking: + id: int + state: TrackingState velocity_filter: GenericVelocityFilter - """ - The velocity filter of the tracking - """ + + def __init__( + self, + id: int, + state: TrackingState, + velocity_filter: Optional[GenericVelocityFilter] = None, + ): + self.id = id + self.state = state + self.velocity_filter = velocity_filter or DummyVelocityFilter(state.keypoints) def __repr__(self) -> str: - return f"Tracking({self.id}, {self.last_active_timestamp})" + return f"Tracking({self.id}, {self.state.last_active_timestamp})" @overload def predict(self, time: float) -> Float[Array, "J 3"]: @@ -332,11 +378,11 @@ class Tracking: time: float | timedelta | datetime, ) -> Float[Array, "J 3"]: if isinstance(time, timedelta): - timestamp = self.last_active_timestamp + time + timestamp = self.state.last_active_timestamp + time elif isinstance(time, datetime): timestamp = time else: - timestamp = self.last_active_timestamp + timedelta(seconds=time) + timestamp = self.state.last_active_timestamp + timedelta(seconds=time) # pylint: disable-next=unsubscriptable-object return self.velocity_filter.predict(timestamp)["keypoints"] @@ -346,7 +392,10 @@ class Tracking: The velocity of the tracking for each keypoint """ # pylint: disable-next=unsubscriptable-object - return self.velocity_filter.get()["velocity"] + if (vel := self.velocity_filter.get()["velocity"]) is None: + raise ValueError("Velocity is not available") + else: + return vel @jaxtyped(typechecker=beartype) diff --git a/playground.py b/playground.py index fa19480..32eba46 100644 --- a/playground.py +++ b/playground.py @@ -58,7 +58,12 @@ from app.camera import ( classify_by_camera, ) from app.solver._old import GLPKSolver -from app.tracking import AffinityResult, LastDifferenceVelocityFilter, Tracking +from app.tracking import ( + AffinityResult, + LastDifferenceVelocityFilter, + Tracking, + TrackingState, +) from app.visualize.whole_body import visualize_whole_body NDArray: TypeAlias = np.ndarray @@ -543,11 +548,14 @@ class GlobalTrackingState: def add_tracking(self, cluster: Sequence[Detection]) -> Tracking: kps_3d, latest_timestamp = triangle_from_cluster(cluster) next_id = self._last_id + 1 - tracking = Tracking( - id=next_id, + tracking_state = TrackingState( keypoints=kps_3d, last_active_timestamp=latest_timestamp, historical_detections=v(*cluster), + ) + tracking = Tracking( + id=next_id, + state=tracking_state, velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp), ) self._trackings[next_id] = tracking @@ -753,12 +761,12 @@ def calculate_tracking_detection_affinity( Combined affinity score """ camera = detection.camera - delta_t_raw = detection.timestamp - tracking.last_active_timestamp + delta_t_raw = detection.timestamp - tracking.state.last_active_timestamp # Clamp delta_t to avoid division-by-zero / exploding affinity. delta_t = max(delta_t_raw, DELTA_T_MIN) # Calculate 2D affinity - tracking_2d_projection = camera.project(tracking.keypoints) + tracking_2d_projection = camera.project(tracking.state.keypoints) w, h = camera.params.image_size distance_2d = calculate_distance_2d( tracking_2d_projection, @@ -838,7 +846,7 @@ def calculate_camera_affinity_matrix_jax( # === Tracking-side tensors === kps3d_trk: Float[Array, "T J 3"] = jnp.stack( - [trk.keypoints for trk in trackings] + [trk.state.keypoints for trk in trackings] ) # (T, J, 3) J = kps3d_trk.shape[1] # === Detection-side tensors === @@ -855,12 +863,12 @@ def calculate_camera_affinity_matrix_jax( # --- timestamps ---------- t0 = min( chain( - (trk.last_active_timestamp for trk in trackings), + (trk.state.last_active_timestamp for trk in trackings), (det.timestamp for det in camera_detections), ) ).timestamp() # common origin (float) ts_trk = jnp.array( - [trk.last_active_timestamp.timestamp() - t0 for trk in trackings], + [trk.state.last_active_timestamp.timestamp() - t0 for trk in trackings], dtype=jnp.float32, # now small, ms-scale fits in fp32 ) ts_det = jnp.array( @@ -1032,7 +1040,5 @@ display(affinities) # %% def update_tracking(tracking: Tracking, detection: Detection): - delta_t_ = detection.timestamp - tracking.last_active_timestamp - delta_t = max(delta_t_, DELTA_T_MIN) - - return tracking + delta_t_ = detection.timestamp - tracking.state.last_active_timestamp + raise NotImplementedError