forked from HQU-gxy/CVTH3PE
refactor: Enhance tracking state management and velocity filter integration
- Introduced `TrackingState` to encapsulate the state of tracking, improving data organization and immutability. - Updated the `Tracking` class to utilize `TrackingState`, enhancing clarity in state management. - Refactored methods to access keypoints and timestamps through the new state structure, ensuring consistency across the codebase. - Added a `DummyVelocityFilter` for cases where no velocity estimation is needed, improving flexibility in tracking implementations. - Cleaned up imports and improved type hints for better code organization.
This commit is contained in:
@ -1,31 +1,33 @@
|
|||||||
|
import weakref
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime, timedelta
|
||||||
|
from itertools import chain
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Generator,
|
Generator,
|
||||||
Optional,
|
Optional,
|
||||||
|
Protocol,
|
||||||
Sequence,
|
Sequence,
|
||||||
TypeAlias,
|
TypeAlias,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
cast,
|
cast,
|
||||||
overload,
|
overload,
|
||||||
Protocol,
|
|
||||||
)
|
)
|
||||||
from datetime import timedelta
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from beartype import beartype
|
from beartype import beartype
|
||||||
from beartype.typing import Mapping, Sequence
|
from beartype.typing import Mapping, Sequence
|
||||||
from jax import Array
|
from jax import Array
|
||||||
from jaxtyping import Array, Float, Int, jaxtyped
|
from jaxtyping import Array, Float, Int, jaxtyped
|
||||||
from pyrsistent import PVector, v
|
from pyrsistent import PVector, v
|
||||||
from itertools import chain
|
|
||||||
from app.camera import Detection
|
from app.camera import Detection
|
||||||
|
|
||||||
|
|
||||||
class TrackingPrediction(TypedDict):
|
class TrackingPrediction(TypedDict):
|
||||||
velocity: Float[Array, "J 3"]
|
velocity: Optional[Float[Array, "J 3"]]
|
||||||
keypoints: Float[Array, "J 3"]
|
keypoints: Float[Array, "J 3"]
|
||||||
|
|
||||||
|
|
||||||
@ -68,6 +70,31 @@ class GenericVelocityFilter(Protocol):
|
|||||||
... # pylint: disable=unnecessary-ellipsis
|
... # pylint: disable=unnecessary-ellipsis
|
||||||
|
|
||||||
|
|
||||||
|
class DummyVelocityFilter(GenericVelocityFilter):
|
||||||
|
"""
|
||||||
|
a dummy velocity filter that does nothing
|
||||||
|
"""
|
||||||
|
|
||||||
|
_keypoints_shape: tuple[int, ...]
|
||||||
|
|
||||||
|
def __init__(self, keypoints: Float[Array, "J 3"]):
|
||||||
|
self._keypoints_shape = keypoints.shape
|
||||||
|
|
||||||
|
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
||||||
|
return TrackingPrediction(
|
||||||
|
velocity=None,
|
||||||
|
keypoints=jnp.zeros(self._keypoints_shape),
|
||||||
|
)
|
||||||
|
|
||||||
|
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: ...
|
||||||
|
|
||||||
|
def get(self) -> TrackingPrediction:
|
||||||
|
return TrackingPrediction(
|
||||||
|
velocity=None,
|
||||||
|
keypoints=jnp.zeros(self._keypoints_shape),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
||||||
"""
|
"""
|
||||||
a naive velocity filter that uses the last difference of keypoints
|
a naive velocity filter that uses the last difference of keypoints
|
||||||
@ -85,7 +112,7 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
|||||||
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
|
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
|
||||||
if self._last_velocity is None:
|
if self._last_velocity is None:
|
||||||
return TrackingPrediction(
|
return TrackingPrediction(
|
||||||
velocity=jnp.zeros_like(self._last_keypoints),
|
velocity=None,
|
||||||
keypoints=self._last_keypoints,
|
keypoints=self._last_keypoints,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -103,7 +130,7 @@ class LastDifferenceVelocityFilter(GenericVelocityFilter):
|
|||||||
def get(self) -> TrackingPrediction:
|
def get(self) -> TrackingPrediction:
|
||||||
if self._last_velocity is None:
|
if self._last_velocity is None:
|
||||||
return TrackingPrediction(
|
return TrackingPrediction(
|
||||||
velocity=jnp.zeros_like(self._last_keypoints),
|
velocity=None,
|
||||||
keypoints=self._last_keypoints,
|
keypoints=self._last_keypoints,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -126,33 +153,42 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
|||||||
"""
|
"""
|
||||||
_velocity: Optional[Float[Array, "J 3"]] = None
|
_velocity: Optional[Float[Array, "J 3"]] = None
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter":
|
|
||||||
"""
|
|
||||||
create a LeastMeanSquareVelocityFilter from a Tracking object
|
|
||||||
"""
|
|
||||||
velocity = tracking.velocity_filter.get()["velocity"]
|
|
||||||
if jnp.all(velocity == jnp.zeros_like(velocity)):
|
|
||||||
return LeastMeanSquareVelocityFilter(
|
|
||||||
get_historical_detections=lambda: tracking.historical_detections
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
f = LeastMeanSquareVelocityFilter(
|
|
||||||
get_historical_detections=lambda: tracking.historical_detections
|
|
||||||
)
|
|
||||||
# pylint: disable-next=protected-access
|
|
||||||
f._velocity = velocity
|
|
||||||
return f
|
|
||||||
|
|
||||||
def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
|
def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
|
||||||
self._get_historical_detections = get_historical_detections
|
self._get_historical_detections = get_historical_detections
|
||||||
self._velocity = None
|
self._velocity = None
|
||||||
|
|
||||||
@property
|
@staticmethod
|
||||||
def velocity(self) -> Float[Array, "J 3"]:
|
def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter":
|
||||||
if self._velocity is None:
|
"""
|
||||||
raise ValueError("Velocity not initialized")
|
create a LeastMeanSquareVelocityFilter from a Tracking object
|
||||||
return self._velocity
|
|
||||||
|
Note that this function is using a weak reference to the tracking object,
|
||||||
|
so that the tracking object can be garbage collected if there are no other
|
||||||
|
references to it.
|
||||||
|
"""
|
||||||
|
# Create a weak reference to avoid circular references
|
||||||
|
# https://docs.python.org/3/library/weakref.html
|
||||||
|
tracking_ref = weakref.ref(tracking)
|
||||||
|
|
||||||
|
# Create a getter function that uses the weak reference
|
||||||
|
def get_historical_detections() -> Sequence[Detection]:
|
||||||
|
tr = tracking_ref()
|
||||||
|
if tr is None:
|
||||||
|
return [] # Return empty list if tracking has been garbage collected
|
||||||
|
return tr.state.historical_detections
|
||||||
|
|
||||||
|
velocity = tracking.velocity_filter.get()["velocity"]
|
||||||
|
if velocity is None:
|
||||||
|
return LeastMeanSquareVelocityFilter(
|
||||||
|
get_historical_detections=get_historical_detections
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
f = LeastMeanSquareVelocityFilter(
|
||||||
|
get_historical_detections=get_historical_detections
|
||||||
|
)
|
||||||
|
# pylint: disable-next=protected-access
|
||||||
|
f._velocity = velocity
|
||||||
|
return f
|
||||||
|
|
||||||
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
def predict(self, timestamp: datetime) -> TrackingPrediction:
|
||||||
historical_detections = self._get_historical_detections()
|
historical_detections = self._get_historical_detections()
|
||||||
@ -168,7 +204,8 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
|||||||
|
|
||||||
if self._velocity is None:
|
if self._velocity is None:
|
||||||
return TrackingPrediction(
|
return TrackingPrediction(
|
||||||
velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints
|
velocity=None,
|
||||||
|
keypoints=latest_keypoints,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Linear motion model: ẋt = xt' + Vt' · (t - t')
|
# Linear motion model: ẋt = xt' + Vt' · (t - t')
|
||||||
@ -252,9 +289,7 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
|||||||
latest_keypoints = latest_detection.keypoints
|
latest_keypoints = latest_detection.keypoints
|
||||||
|
|
||||||
if self._velocity is None:
|
if self._velocity is None:
|
||||||
return TrackingPrediction(
|
return TrackingPrediction(velocity=None, keypoints=latest_keypoints)
|
||||||
velocity=jnp.zeros_like(latest_keypoints), keypoints=latest_keypoints
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return TrackingPrediction(
|
return TrackingPrediction(
|
||||||
velocity=self._velocity, keypoints=latest_keypoints
|
velocity=self._velocity, keypoints=latest_keypoints
|
||||||
@ -263,11 +298,11 @@ class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
|
|||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Tracking:
|
class TrackingState:
|
||||||
id: int
|
|
||||||
"""
|
"""
|
||||||
The tracking id
|
immutable state of a tracking
|
||||||
"""
|
"""
|
||||||
|
|
||||||
keypoints: Float[Array, "J 3"]
|
keypoints: Float[Array, "J 3"]
|
||||||
"""
|
"""
|
||||||
The 3D keypoints of the tracking
|
The 3D keypoints of the tracking
|
||||||
@ -286,13 +321,24 @@ class Tracking:
|
|||||||
Used for 3D re-triangulation
|
Used for 3D re-triangulation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Tracking:
|
||||||
|
id: int
|
||||||
|
state: TrackingState
|
||||||
velocity_filter: GenericVelocityFilter
|
velocity_filter: GenericVelocityFilter
|
||||||
"""
|
|
||||||
The velocity filter of the tracking
|
def __init__(
|
||||||
"""
|
self,
|
||||||
|
id: int,
|
||||||
|
state: TrackingState,
|
||||||
|
velocity_filter: Optional[GenericVelocityFilter] = None,
|
||||||
|
):
|
||||||
|
self.id = id
|
||||||
|
self.state = state
|
||||||
|
self.velocity_filter = velocity_filter or DummyVelocityFilter(state.keypoints)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"Tracking({self.id}, {self.last_active_timestamp})"
|
return f"Tracking({self.id}, {self.state.last_active_timestamp})"
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def predict(self, time: float) -> Float[Array, "J 3"]:
|
def predict(self, time: float) -> Float[Array, "J 3"]:
|
||||||
@ -332,11 +378,11 @@ class Tracking:
|
|||||||
time: float | timedelta | datetime,
|
time: float | timedelta | datetime,
|
||||||
) -> Float[Array, "J 3"]:
|
) -> Float[Array, "J 3"]:
|
||||||
if isinstance(time, timedelta):
|
if isinstance(time, timedelta):
|
||||||
timestamp = self.last_active_timestamp + time
|
timestamp = self.state.last_active_timestamp + time
|
||||||
elif isinstance(time, datetime):
|
elif isinstance(time, datetime):
|
||||||
timestamp = time
|
timestamp = time
|
||||||
else:
|
else:
|
||||||
timestamp = self.last_active_timestamp + timedelta(seconds=time)
|
timestamp = self.state.last_active_timestamp + timedelta(seconds=time)
|
||||||
# pylint: disable-next=unsubscriptable-object
|
# pylint: disable-next=unsubscriptable-object
|
||||||
return self.velocity_filter.predict(timestamp)["keypoints"]
|
return self.velocity_filter.predict(timestamp)["keypoints"]
|
||||||
|
|
||||||
@ -346,7 +392,10 @@ class Tracking:
|
|||||||
The velocity of the tracking for each keypoint
|
The velocity of the tracking for each keypoint
|
||||||
"""
|
"""
|
||||||
# pylint: disable-next=unsubscriptable-object
|
# pylint: disable-next=unsubscriptable-object
|
||||||
return self.velocity_filter.get()["velocity"]
|
if (vel := self.velocity_filter.get()["velocity"]) is None:
|
||||||
|
raise ValueError("Velocity is not available")
|
||||||
|
else:
|
||||||
|
return vel
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
|
|||||||
@ -58,7 +58,12 @@ from app.camera import (
|
|||||||
classify_by_camera,
|
classify_by_camera,
|
||||||
)
|
)
|
||||||
from app.solver._old import GLPKSolver
|
from app.solver._old import GLPKSolver
|
||||||
from app.tracking import AffinityResult, LastDifferenceVelocityFilter, Tracking
|
from app.tracking import (
|
||||||
|
AffinityResult,
|
||||||
|
LastDifferenceVelocityFilter,
|
||||||
|
Tracking,
|
||||||
|
TrackingState,
|
||||||
|
)
|
||||||
from app.visualize.whole_body import visualize_whole_body
|
from app.visualize.whole_body import visualize_whole_body
|
||||||
|
|
||||||
NDArray: TypeAlias = np.ndarray
|
NDArray: TypeAlias = np.ndarray
|
||||||
@ -543,11 +548,14 @@ class GlobalTrackingState:
|
|||||||
def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
|
def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
|
||||||
kps_3d, latest_timestamp = triangle_from_cluster(cluster)
|
kps_3d, latest_timestamp = triangle_from_cluster(cluster)
|
||||||
next_id = self._last_id + 1
|
next_id = self._last_id + 1
|
||||||
tracking = Tracking(
|
tracking_state = TrackingState(
|
||||||
id=next_id,
|
|
||||||
keypoints=kps_3d,
|
keypoints=kps_3d,
|
||||||
last_active_timestamp=latest_timestamp,
|
last_active_timestamp=latest_timestamp,
|
||||||
historical_detections=v(*cluster),
|
historical_detections=v(*cluster),
|
||||||
|
)
|
||||||
|
tracking = Tracking(
|
||||||
|
id=next_id,
|
||||||
|
state=tracking_state,
|
||||||
velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
|
velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
|
||||||
)
|
)
|
||||||
self._trackings[next_id] = tracking
|
self._trackings[next_id] = tracking
|
||||||
@ -753,12 +761,12 @@ def calculate_tracking_detection_affinity(
|
|||||||
Combined affinity score
|
Combined affinity score
|
||||||
"""
|
"""
|
||||||
camera = detection.camera
|
camera = detection.camera
|
||||||
delta_t_raw = detection.timestamp - tracking.last_active_timestamp
|
delta_t_raw = detection.timestamp - tracking.state.last_active_timestamp
|
||||||
# Clamp delta_t to avoid division-by-zero / exploding affinity.
|
# Clamp delta_t to avoid division-by-zero / exploding affinity.
|
||||||
delta_t = max(delta_t_raw, DELTA_T_MIN)
|
delta_t = max(delta_t_raw, DELTA_T_MIN)
|
||||||
|
|
||||||
# Calculate 2D affinity
|
# Calculate 2D affinity
|
||||||
tracking_2d_projection = camera.project(tracking.keypoints)
|
tracking_2d_projection = camera.project(tracking.state.keypoints)
|
||||||
w, h = camera.params.image_size
|
w, h = camera.params.image_size
|
||||||
distance_2d = calculate_distance_2d(
|
distance_2d = calculate_distance_2d(
|
||||||
tracking_2d_projection,
|
tracking_2d_projection,
|
||||||
@ -838,7 +846,7 @@ def calculate_camera_affinity_matrix_jax(
|
|||||||
|
|
||||||
# === Tracking-side tensors ===
|
# === Tracking-side tensors ===
|
||||||
kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
|
kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
|
||||||
[trk.keypoints for trk in trackings]
|
[trk.state.keypoints for trk in trackings]
|
||||||
) # (T, J, 3)
|
) # (T, J, 3)
|
||||||
J = kps3d_trk.shape[1]
|
J = kps3d_trk.shape[1]
|
||||||
# === Detection-side tensors ===
|
# === Detection-side tensors ===
|
||||||
@ -855,12 +863,12 @@ def calculate_camera_affinity_matrix_jax(
|
|||||||
# --- timestamps ----------
|
# --- timestamps ----------
|
||||||
t0 = min(
|
t0 = min(
|
||||||
chain(
|
chain(
|
||||||
(trk.last_active_timestamp for trk in trackings),
|
(trk.state.last_active_timestamp for trk in trackings),
|
||||||
(det.timestamp for det in camera_detections),
|
(det.timestamp for det in camera_detections),
|
||||||
)
|
)
|
||||||
).timestamp() # common origin (float)
|
).timestamp() # common origin (float)
|
||||||
ts_trk = jnp.array(
|
ts_trk = jnp.array(
|
||||||
[trk.last_active_timestamp.timestamp() - t0 for trk in trackings],
|
[trk.state.last_active_timestamp.timestamp() - t0 for trk in trackings],
|
||||||
dtype=jnp.float32, # now small, ms-scale fits in fp32
|
dtype=jnp.float32, # now small, ms-scale fits in fp32
|
||||||
)
|
)
|
||||||
ts_det = jnp.array(
|
ts_det = jnp.array(
|
||||||
@ -1032,7 +1040,5 @@ display(affinities)
|
|||||||
|
|
||||||
# %%
|
# %%
|
||||||
def update_tracking(tracking: Tracking, detection: Detection):
|
def update_tracking(tracking: Tracking, detection: Detection):
|
||||||
delta_t_ = detection.timestamp - tracking.last_active_timestamp
|
delta_t_ = detection.timestamp - tracking.state.last_active_timestamp
|
||||||
delta_t = max(delta_t_, DELTA_T_MIN)
|
raise NotImplementedError
|
||||||
|
|
||||||
return tracking
|
|
||||||
|
|||||||
Reference in New Issue
Block a user