1
0
forked from HQU-gxy/CVTH3PE

refactor: Enhance tracking state management and velocity filter integration

- Introduced `TrackingState` to encapsulate the state of tracking, improving data organization and immutability.
- Updated the `Tracking` class to utilize `TrackingState`, enhancing clarity in state management.
- Refactored methods to access keypoints and timestamps through the new state structure, ensuring consistency across the codebase.
- Added a `DummyVelocityFilter` for cases where no velocity estimation is needed, improving flexibility in tracking implementations.
- Cleaned up imports and improved type hints for better code organization.
This commit is contained in:
2025-05-02 12:44:58 +08:00
parent 46b8518a10
commit c31cc4e7bf
2 changed files with 111 additions and 56 deletions

View File

@ -1,31 +1,33 @@
import weakref
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timedelta
from itertools import chain
from typing import (
Any,
Callable,
Generator,
Optional,
Protocol,
Sequence,
TypeAlias,
TypedDict,
TypeVar,
cast,
overload,
Protocol,
)
from datetime import timedelta
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Mapping, Sequence
from jax import Array
from jaxtyping import Array, Float, Int, jaxtyped
from pyrsistent import PVector, v
from itertools import chain
from app.camera import Detection
class TrackingPrediction(TypedDict):
velocity: Float[Array, "J 3"]
velocity: Optional[Float[Array, "J 3"]]
keypoints: Float[Array, "J 3"]
@ -68,6 +70,31 @@ class GenericVelocityFilter(Protocol):
... # pylint: disable=unnecessary-ellipsis
class DummyVelocityFilter(GenericVelocityFilter):
"""
a dummy velocity filter that does nothing
"""
_keypoints_shape: tuple[int, ...]
def __init__(self, keypoints: Float[Array, "J 3"]):
self._keypoints_shape = keypoints.shape
def predict(self, timestamp: datetime) -> TrackingPrediction:
return TrackingPrediction(
velocity=None,
keypoints=jnp.zeros(self._keypoints_shape),
)
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: ...
def get(self) -> TrackingPrediction:
return TrackingPrediction(
velocity=None,
keypoints=jnp.zeros(self._keypoints_shape),
)
class LastDifferenceVelocityFilter(GenericVelocityFilter):
"""
a naive velocity filter that uses the last difference of keypoints
@ -85,7 +112,7 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
if self._last_velocity is None:
return TrackingPrediction(
velocity=jnp.zeros_like(self._last_keypoints),
velocity=None,
keypoints=self._last_keypoints,
)
else:
@ -103,7 +130,7 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
def get(self) -> TrackingPrediction:
if self._last_velocity is None:
return TrackingPrediction(
velocity=jnp.zeros_like(self._last_keypoints),
velocity=None,
keypoints=self._last_keypoints,
)
else:
@ -126,33 +153,42 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
"""
_velocity: Optional[Float[Array, "J 3"]] = None
@staticmethod
def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter":
"""
create a LeastMeanSquareVelocityFilter from a Tracking object
"""
velocity = tracking.velocity_filter.get()["velocity"]
if jnp.all(velocity == jnp.zeros_like(velocity)):
return LeastMeanSquareVelocityFilter(
get_historical_detections=lambda: tracking.historical_detections
)
else:
f = LeastMeanSquareVelocityFilter(
get_historical_detections=lambda: tracking.historical_detections
)
# pylint: disable-next=protected-access
f._velocity = velocity
return f
def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
self._get_historical_detections = get_historical_detections
self._velocity = None
@property
def velocity(self) -> Float[Array, "J 3"]:
if self._velocity is None:
raise ValueError("Velocity not initialized")
return self._velocity
@staticmethod
def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter":
"""
create a LeastMeanSquareVelocityFilter from a Tracking object
Note that this function is using a weak reference to the tracking object,
so that the tracking object can be garbage collected if there are no other
references to it.
"""
# Create a weak reference to avoid circular references
# https://docs.python.org/3/library/weakref.html
tracking_ref = weakref.ref(tracking)
# Create a getter function that uses the weak reference
def get_historical_detections() -> Sequence[Detection]:
tr = tracking_ref()
if tr is None:
return [] # Return empty list if tracking has been garbage collected
return tr.state.historical_detections
velocity = tracking.velocity_filter.get()["velocity"]
if velocity is None:
return LeastMeanSquareVelocityFilter(
get_historical_detections=get_historical_detections
)
else:
f = LeastMeanSquareVelocityFilter(
get_historical_detections=get_historical_detections
)
# pylint: disable-next=protected-access
f._velocity = velocity
return f
def predict(self, timestamp: datetime) -> TrackingPrediction:
historical_detections = self._get_historical_detections()
@ -168,7 +204,8 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
if self._velocity is None:
return TrackingPrediction(
velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints
velocity=None,
keypoints=latest_keypoints,
)
else:
# Linear motion model: ẋt = xt' + Vt' · (t - t')
@ -252,9 +289,7 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
latest_keypoints = latest_detection.keypoints
if self._velocity is None:
return TrackingPrediction(
velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints
)
return TrackingPrediction(velocity=None, keypoints=latest_keypoints)
else:
return TrackingPrediction(
velocity=self._velocity, keypoints=latest_keypoints
@ -263,11 +298,11 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
id: int
class TrackingState:
"""
The tracking id
immutable state of a tracking
"""
keypoints: Float[Array, "J 3"]
"""
The 3D keypoints of the tracking
@ -286,13 +321,24 @@ class Tracking:
Used for 3D re-triangulation
"""
class Tracking:
id: int
state: TrackingState
velocity_filter: GenericVelocityFilter
"""
The velocity filter of the tracking
"""
def __init__(
self,
id: int,
state: TrackingState,
velocity_filter: Optional[GenericVelocityFilter] = None,
):
self.id = id
self.state = state
self.velocity_filter = velocity_filter or DummyVelocityFilter(state.keypoints)
def __repr__(self) -> str:
return f"Tracking({self.id}, {self.last_active_timestamp})"
return f"Tracking({self.id}, {self.state.last_active_timestamp})"
@overload
def predict(self, time: float) -> Float[Array, "J 3"]:
@ -332,11 +378,11 @@ class Tracking:
time: float | timedelta | datetime,
) -> Float[Array, "J 3"]:
if isinstance(time, timedelta):
timestamp = self.last_active_timestamp + time
timestamp = self.state.last_active_timestamp + time
elif isinstance(time, datetime):
timestamp = time
else:
timestamp = self.last_active_timestamp + timedelta(seconds=time)
timestamp = self.state.last_active_timestamp + timedelta(seconds=time)
# pylint: disable-next=unsubscriptable-object
return self.velocity_filter.predict(timestamp)["keypoints"]
@ -346,7 +392,10 @@ class Tracking:
The velocity of the tracking for each keypoint
"""
# pylint: disable-next=unsubscriptable-object
return self.velocity_filter.get()["velocity"]
if (vel := self.velocity_filter.get()["velocity"]) is None:
raise ValueError("Velocity is not available")
else:
return vel
@jaxtyped(typechecker=beartype)

View File

@ -58,7 +58,12 @@ from app.camera import (
classify_by_camera,
)
from app.solver._old import GLPKSolver
from app.tracking import AffinityResult, LastDifferenceVelocityFilter, Tracking
from app.tracking import (
AffinityResult,
LastDifferenceVelocityFilter,
Tracking,
TrackingState,
)
from app.visualize.whole_body import visualize_whole_body
NDArray: TypeAlias = np.ndarray
@ -543,11 +548,14 @@ class GlobalTrackingState:
def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
kps_3d, latest_timestamp = triangle_from_cluster(cluster)
next_id = self._last_id + 1
tracking = Tracking(
id=next_id,
tracking_state = TrackingState(
keypoints=kps_3d,
last_active_timestamp=latest_timestamp,
historical_detections=v(*cluster),
)
tracking = Tracking(
id=next_id,
state=tracking_state,
velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
)
self._trackings[next_id] = tracking
@ -753,12 +761,12 @@ def calculate_tracking_detection_affinity(
Combined affinity score
"""
camera = detection.camera
delta_t_raw = detection.timestamp - tracking.last_active_timestamp
delta_t_raw = detection.timestamp - tracking.state.last_active_timestamp
# Clamp delta_t to avoid division-by-zero / exploding affinity.
delta_t = max(delta_t_raw, DELTA_T_MIN)
# Calculate 2D affinity
tracking_2d_projection = camera.project(tracking.keypoints)
tracking_2d_projection = camera.project(tracking.state.keypoints)
w, h = camera.params.image_size
distance_2d = calculate_distance_2d(
tracking_2d_projection,
@ -838,7 +846,7 @@ def calculate_camera_affinity_matrix_jax(
# === Tracking-side tensors ===
kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
[trk.keypoints for trk in trackings]
[trk.state.keypoints for trk in trackings]
) # (T, J, 3)
J = kps3d_trk.shape[1]
# === Detection-side tensors ===
@ -855,12 +863,12 @@ def calculate_camera_affinity_matrix_jax(
# --- timestamps ----------
t0 = min(
chain(
(trk.last_active_timestamp for trk in trackings),
(trk.state.last_active_timestamp for trk in trackings),
(det.timestamp for det in camera_detections),
)
).timestamp() # common origin (float)
ts_trk = jnp.array(
[trk.last_active_timestamp.timestamp() - t0 for trk in trackings],
[trk.state.last_active_timestamp.timestamp() - t0 for trk in trackings],
dtype=jnp.float32, # now small, ms-scale fits in fp32
)
ts_det = jnp.array(
@ -1032,7 +1040,5 @@ display(affinities)
# %%
def update_tracking(tracking: Tracking, detection: Detection):
delta_t_ = detection.timestamp - tracking.last_active_timestamp
delta_t = max(delta_t_, DELTA_T_MIN)
return tracking
delta_t_ = detection.timestamp - tracking.state.last_active_timestamp
raise NotImplementedError