Files
CVTH3PE/app/tracking/__init__.py
crosstyan c31cc4e7bf refactor: Enhance tracking state management and velocity filter integration
- Introduced `TrackingState` to encapsulate the state of tracking, improving data organization and immutability.
- Updated the `Tracking` class to utilize `TrackingState`, enhancing clarity in state management.
- Refactored methods to access keypoints and timestamps through the new state structure, ensuring consistency across the codebase.
- Added a `DummyVelocityFilter` for cases where no velocity estimation is needed, improving flexibility in tracking implementations.
- Cleaned up imports and improved type hints for better code organization.
2025-05-02 12:44:58 +08:00

426 lines
13 KiB
Python

import weakref
from dataclasses import dataclass
from datetime import datetime, timedelta
from itertools import chain
from typing import (
Any,
Callable,
Generator,
Optional,
Protocol,
Sequence,
TypeAlias,
TypedDict,
TypeVar,
cast,
overload,
)
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Mapping, Sequence
from jax import Array
from jaxtyping import Array, Float, Int, jaxtyped
from pyrsistent import PVector, v
from app.camera import Detection
class TrackingPrediction(TypedDict):
    """
    result reported by a velocity filter's `predict` and `get`
    """

    # per-keypoint velocity; `None` when the filter has no velocity estimate
    velocity: Optional[Float[Array, "J 3"]]
    # keypoint positions, annotated as a float array of shape (J, 3)
    keypoints: Float[Array, "J 3"]
class GenericVelocityFilter(Protocol):
    """
    a filter interface for tracking velocity estimation

    Implementations receive measurements through `update` and report a
    `TrackingPrediction` (velocity and keypoints) through `predict` and `get`.
    """

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        predict the velocity and the keypoints location

        Args:
            timestamp: timestamp of the prediction

        Returns:
            velocity: velocity of the tracking (may be `None`)
            keypoints: keypoints of the tracking
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        update the filter state with new measurements

        Args:
            keypoints: new measurements
            timestamp: timestamp of the update
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def get(self) -> TrackingPrediction:
        """
        get the current state of the filter

        Returns:
            velocity: velocity of the tracking (may be `None`)
            keypoints: keypoints of the tracking
        """
        ...  # pylint: disable=unnecessary-ellipsis
class DummyVelocityFilter(GenericVelocityFilter):
    """
    a placeholder velocity filter for trackings that need no velocity estimation

    Only the shape of the keypoints given at construction is remembered; every
    query reports no velocity together with an all-zero array of that shape.

    NOTE(review): the reported keypoints are zeros, not the keypoints passed to
    the constructor -- confirm that callers never treat them as positions.
    """

    _keypoints_shape: tuple[int, ...]

    def __init__(self, keypoints: Float[Array, "J 3"]):
        # the keypoint values themselves are not kept, only their shape
        self._keypoints_shape = keypoints.shape

    def _empty_prediction(self) -> TrackingPrediction:
        """
        build the constant answer: no velocity, zero-filled keypoints
        """
        zero_keypoints = jnp.zeros(self._keypoints_shape)
        return TrackingPrediction(velocity=None, keypoints=zero_keypoints)

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        return the constant prediction; `timestamp` is ignored
        """
        return self._empty_prediction()

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        accept a measurement and discard it
        """

    def get(self) -> TrackingPrediction:
        """
        return the constant prediction
        """
        return self._empty_prediction()
class LastDifferenceVelocityFilter(GenericVelocityFilter):
    """
    a naive velocity filter that uses the last difference of keypoints

    The velocity is the finite difference between the two most recent
    measurements; `predict` extrapolates linearly from the latest measurement.
    Until a second measurement arrives no velocity is available and the
    latest keypoints are reported unchanged.
    """

    _last_timestamp: datetime
    _last_keypoints: Float[Array, "J 3"]
    _last_velocity: Optional[Float[Array, "J 3"]] = None

    def __init__(self, keypoints: Float[Array, "J 3"], timestamp: datetime):
        """
        Args:
            keypoints: the first measurement
            timestamp: timestamp of the first measurement
        """
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        extrapolate the keypoints to `timestamp`

        Args:
            timestamp: timestamp of the prediction

        Returns:
            velocity: the last estimated velocity, `None` before the first update
            keypoints: last keypoints moved by `velocity * elapsed seconds`
                (the last keypoints unchanged when no velocity is known)
        """
        delta_t_s = (timestamp - self._last_timestamp).total_seconds()
        if self._last_velocity is None:
            return TrackingPrediction(
                velocity=None,
                keypoints=self._last_keypoints,
            )
        return TrackingPrediction(
            velocity=self._last_velocity,
            keypoints=self._last_keypoints + self._last_velocity * delta_t_s,
        )

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        store a new measurement and refresh the velocity estimate

        Args:
            keypoints: new measurements
            timestamp: timestamp of the update
        """
        delta_t_s = (timestamp - self._last_timestamp).total_seconds()
        # A measurement carrying the same timestamp as the previous one spans
        # no time: dividing by it would turn the velocity into inf/NaN and
        # corrupt every later prediction. Keep the previous velocity estimate
        # in that case and only take over the newer keypoints.
        if delta_t_s != 0:
            self._last_velocity = (keypoints - self._last_keypoints) / delta_t_s
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp

    def get(self) -> TrackingPrediction:
        """
        get the current filter state without extrapolating

        Returns:
            velocity: the last estimated velocity, `None` before the first update
            keypoints: the last measured keypoints
        """
        return TrackingPrediction(
            velocity=self._last_velocity,
            keypoints=self._last_keypoints,
        )
class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
    """
    a velocity filter that fits a per-joint linear motion model by least squares

    The filter keeps no measurements of its own: it reads them through a
    callback each time and only stores the fitted velocity.

    NOTE(review): `predict`, `update` and `get` combine `Detection.keypoints`
    with (J, 3) arrays -- confirm that detections carry keypoints of that shape.
    """

    # Callback returning the current historical detections. They are assumed
    # to be sorted by timestamp incrementally (index 0 is the oldest
    # detection, index -1 is the newest detection).
    _get_historical_detections: Callable[[], Sequence[Detection]]
    _velocity: Optional[Float[Array, "J 3"]] = None

    def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
        """
        Args:
            get_historical_detections: callback returning the detections to fit
        """
        self._get_historical_detections = get_historical_detections
        self._velocity = None

    @staticmethod
    def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter":
        """
        create a LeastMeanSquareVelocityFilter from a Tracking object

        The new filter is seeded with the velocity currently reported by the
        tracking's own filter, when there is one.

        Note that only a weak reference to the tracking object is held, so
        that the tracking object can be garbage collected if there are no
        other references to it.
        """
        # weak reference to avoid a tracking <-> filter reference cycle
        # https://docs.python.org/3/library/weakref.html
        tracking_ref = weakref.ref(tracking)

        def get_historical_detections() -> Sequence[Detection]:
            target = tracking_ref()
            # an already collected tracking has no history any more
            return [] if target is None else target.state.historical_detections

        seed_velocity = tracking.velocity_filter.get()["velocity"]
        created = LeastMeanSquareVelocityFilter(
            get_historical_detections=get_historical_detections
        )
        if seed_velocity is not None:
            # pylint: disable-next=protected-access
            created._velocity = seed_velocity
        return created

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        extrapolate the newest detection to `timestamp`

        Args:
            timestamp: timestamp of the prediction

        Returns:
            velocity: the fitted velocity, `None` when nothing was fitted yet
            keypoints: keypoints of the newest detection, moved along the
                velocity when one is available

        Raises:
            ValueError: when there are no historical detections
        """
        history = self._get_historical_detections()
        if not history:
            raise ValueError("No historical detections available for prediction")
        newest = history[-1]
        base_keypoints = newest.keypoints
        elapsed_s = (timestamp - newest.timestamp).total_seconds()
        velocity = self._velocity
        if velocity is None:
            return TrackingPrediction(velocity=None, keypoints=base_keypoints)
        # linear motion model: x(t) = x(t') + v(t') * (t - t')
        return TrackingPrediction(
            velocity=velocity,
            keypoints=base_keypoints + velocity * elapsed_s,
        )

    @jaxtyped(typechecker=beartype)
    def _update(
        self,
        keypoints: Float[Array, "N J 3"],
        timestamps: Float[Array, "N"],
    ) -> None:
        """
        fit the velocity to N measurements with the least squares method

        Raises:
            ValueError: when fewer than two measurements are given
        """
        if keypoints.shape[0] < 2:
            raise ValueError("Not enough measurements to estimate velocity")
        sample_count = timestamps.shape[0]
        joint_count = keypoints.shape[1]
        # design matrix with one row [t, 1] per measurement
        design = jnp.column_stack([timestamps, jnp.ones(sample_count)])
        # flatten [N, J, 3] to [N, J*3] so all coordinates are solved at once
        targets = keypoints.reshape(sample_count, -1)
        solution = jnp.linalg.lstsq(design, targets, rcond=None)
        # solution[0] has shape [2, J*3]: row 0 slopes, row 1 intercepts
        slopes = solution[0][0]
        self._velocity = slopes.reshape(joint_count, 3)

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        refit the velocity from the historical detections plus a new measurement

        Without any historical detection the velocity is set to zeros.

        Args:
            keypoints: new measurements
            timestamp: timestamp of the update
        """
        history = self._get_historical_detections()
        if not history:
            self._velocity = jnp.zeros_like(keypoints)
            return
        origin = min(detection.timestamp for detection in history)
        samples = [detection.keypoints for detection in history]
        samples.append(keypoints)
        stacked_keypoints = jnp.array(samples)
        # seconds elapsed since the oldest historical detection
        offsets_s = [
            (detection.timestamp - origin).total_seconds() for detection in history
        ]
        offsets_s.append((timestamp - origin).total_seconds())
        self._update(stacked_keypoints, jnp.array(offsets_s))

    def get(self) -> TrackingPrediction:
        """
        get the fitted velocity together with the newest detection's keypoints

        Raises:
            ValueError: when there are no historical detections
        """
        history = self._get_historical_detections()
        if not history:
            raise ValueError("No historical detections available")
        newest_keypoints = history[-1].keypoints
        return TrackingPrediction(velocity=self._velocity, keypoints=newest_keypoints)
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class TrackingState:
    """
    immutable state of a tracking

    Declared as a frozen dataclass, so fields cannot be reassigned on an
    existing instance; a changed state has to be built as a new
    `TrackingState`.
    """

    keypoints: Float[Array, "J 3"]
    """
    The 3D keypoints of the tracking
    Used for calculate affinity 3D
    """
    last_active_timestamp: datetime
    """
    The last active timestamp of the tracking
    """
    historical_detections: PVector[Detection]
    """
    Historical detections of the tracking.
    Used for 3D re-triangulation
    """
class Tracking:
    """
    a tracked target: an id, its current state and a velocity filter
    """

    id: int
    state: TrackingState
    velocity_filter: GenericVelocityFilter

    def __init__(
        self,
        id: int,
        state: TrackingState,
        velocity_filter: Optional[GenericVelocityFilter] = None,
    ):
        """
        Args:
            id: identifier of the tracking
            state: the current state of the tracking
            velocity_filter: the filter used for prediction; a
                `DummyVelocityFilter` is created when omitted
        """
        self.id = id
        self.state = state
        # Test for None explicitly: a supplied filter that happens to be falsy
        # (e.g. one defining `__len__` over an empty buffer) must be kept, not
        # silently swapped for the dummy filter.
        if velocity_filter is None:
            velocity_filter = DummyVelocityFilter(state.keypoints)
        self.velocity_filter = velocity_filter

    def __repr__(self) -> str:
        return f"Tracking({self.id}, {self.state.last_active_timestamp})"

    @overload
    def predict(self, time: float) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: the time in seconds to predict the keypoints

        Returns:
            the predicted keypoints
        """
        ...  # pylint: disable=unnecessary-ellipsis

    @overload
    def predict(self, time: timedelta) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: the time delta to predict the keypoints
        """
        ...  # pylint: disable=unnecessary-ellipsis

    @overload
    def predict(self, time: datetime) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: the timestamp to predict the keypoints
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def predict(
        self,
        time: float | timedelta | datetime,
    ) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: an absolute timestamp, or an offset from the state's last
                active timestamp given as a `timedelta` or as seconds

        Returns:
            the keypoints predicted by the velocity filter
        """
        if isinstance(time, timedelta):
            timestamp = self.state.last_active_timestamp + time
        elif isinstance(time, datetime):
            timestamp = time
        else:
            # a plain number is an offset in seconds
            timestamp = self.state.last_active_timestamp + timedelta(seconds=time)
        # pylint: disable-next=unsubscriptable-object
        return self.velocity_filter.predict(timestamp)["keypoints"]

    @property
    def velocity(self) -> Float[Array, "J 3"]:
        """
        The velocity of the tracking for each keypoint

        Raises:
            ValueError: when the velocity filter reports no velocity
        """
        # pylint: disable-next=unsubscriptable-object
        if (vel := self.velocity_filter.get()["velocity"]) is None:
            raise ValueError("Velocity is not available")
        return vel
@jaxtyped(typechecker=beartype)
@dataclass
class AffinityResult:
    """
    Result of affinity computation between trackings and detections.

    `matrix` is indexed as `[tracking, detection]`; `indices_T` and
    `indices_D` are walked in lockstep to pair a tracking index with a
    detection index.

    NOTE(review): the index arrays are annotated with the matrix dimensions
    "T" and "D" although they are consumed as equally long pairs -- confirm the
    annotations hold when the matrix is not square.
    """

    matrix: Float[Array, "T D"]
    trackings: Sequence[Tracking]
    detections: Sequence[Detection]
    indices_T: Int[Array, "T"]  # pylint: disable=invalid-name
    indices_D: Int[Array, "D"]  # pylint: disable=invalid-name

    def tracking_detections(
        self,
    ) -> Generator[tuple[float, Tracking, Detection], None, None]:
        """
        iterate over the best matching trackings and detections

        Yields:
            `(affinity, tracking, detection)` for every paired index, where
            `affinity` is the matrix entry of that pair as a Python scalar
        """
        for row, col in zip(self.indices_T, self.indices_D):
            score = self.matrix[row, col].item()
            yield score, self.trackings[row], self.detections[col]