Files
CVTH3PE/app/tracking/__init__.py
crosstyan 46b8518a10 refactor: Clean up and enhance LeastMeanSquareVelocityFilter implementation
- Removed unused `reset` methods from `GenericVelocityFilter` and `LastDifferenceVelocityFilter` classes to streamline the code.
- Added a static method `from_tracking` to `LeastMeanSquareVelocityFilter` for creating instances from a `Tracking` object.
- Implemented robust error handling in the `predict` and `get` methods to ensure proper functioning with historical detections.
- Enhanced the `update` method to utilize least squares for velocity estimation, improving accuracy in tracking predictions.
- Updated class documentation to reflect changes and clarify method purposes.
2025-05-02 12:06:05 +08:00

377 lines
12 KiB
Python

from dataclasses import dataclass
from datetime import datetime
from typing import (
Any,
Callable,
Generator,
Optional,
Sequence,
TypeAlias,
TypedDict,
TypeVar,
cast,
overload,
Protocol,
)
from datetime import timedelta
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Mapping, Sequence
from jax import Array
from jaxtyping import Array, Float, Int, jaxtyped
from pyrsistent import PVector, v
from itertools import chain
from app.camera import Detection
class TrackingPrediction(TypedDict):
    """
    the state reported by a velocity filter: one velocity vector and one
    keypoint position per joint
    """

    # velocity of every joint
    velocity: Float[Array, "J 3"]
    # position of every joint
    keypoints: Float[Array, "J 3"]
class GenericVelocityFilter(Protocol):
    """
    structural interface shared by every velocity estimator used for tracking

    An implementation is fed measurements through `update` and can be queried
    for its current state (`get`) or for an extrapolated state (`predict`).
    """

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        estimate where the keypoints are at a given point in time

        Args:
            timestamp: the point in time the prediction is made for

        Returns:
            velocity: velocity of the tracking
            keypoints: keypoints of the tracking at `timestamp`
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        feed a new measurement into the filter

        Args:
            keypoints: the newly measured keypoints
            timestamp: the point in time the measurement was taken
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def get(self) -> TrackingPrediction:
        """
        read the current filter state without extrapolating it

        Returns:
            velocity: velocity of the tracking
            keypoints: keypoints of the tracking
        """
        ...  # pylint: disable=unnecessary-ellipsis
class LastDifferenceVelocityFilter(GenericVelocityFilter):
    """
    a naive velocity filter that uses the last difference of keypoints

    The velocity is the difference between the two most recent measurements
    divided by the time elapsed between them (in seconds). As long as no
    velocity could be computed, a zero velocity is reported.
    """

    # timestamp of the most recent measurement
    _last_timestamp: datetime
    # keypoints of the most recent measurement
    _last_keypoints: Float[Array, "J 3"]
    # most recent velocity estimate; None until one could be computed
    _last_velocity: Optional[Float[Array, "J 3"]] = None

    def __init__(self, keypoints: Float[Array, "J 3"], timestamp: datetime):
        """
        Args:
            keypoints: the first measurement
            timestamp: timestamp of the first measurement
        """
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        extrapolate the last measurement linearly to `timestamp`

        Args:
            timestamp: timestamp of the prediction

        Returns:
            velocity: the last velocity estimate (zeros if there is none yet)
            keypoints: last keypoints moved by velocity * elapsed seconds
                (the unmodified last keypoints if there is no velocity yet)
        """
        delta_t_s = (timestamp - self._last_timestamp).total_seconds()
        if self._last_velocity is None:
            return TrackingPrediction(
                velocity=jnp.zeros_like(self._last_keypoints),
                keypoints=self._last_keypoints,
            )
        return TrackingPrediction(
            velocity=self._last_velocity,
            keypoints=self._last_keypoints + self._last_velocity * delta_t_s,
        )

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        store a new measurement and refresh the velocity estimate

        Args:
            keypoints: new measurements
            timestamp: timestamp of the update
        """
        delta_t_s = (timestamp - self._last_timestamp).total_seconds()
        # Two measurements with the same timestamp carry no velocity
        # information. Dividing by zero would not raise on a jax array, it
        # would silently store inf/NaN and corrupt every later prediction, so
        # the previous estimate is kept and only the measurement is replaced.
        if delta_t_s != 0:
            self._last_velocity = (keypoints - self._last_keypoints) / delta_t_s
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp

    def get(self) -> TrackingPrediction:
        """
        read the current state without extrapolating it

        Returns:
            velocity: the last velocity estimate (zeros if there is none yet)
            keypoints: the last measured keypoints
        """
        if self._last_velocity is None:
            return TrackingPrediction(
                velocity=jnp.zeros_like(self._last_keypoints),
                keypoints=self._last_keypoints,
            )
        return TrackingPrediction(
            velocity=self._last_velocity,
            keypoints=self._last_keypoints,
        )
class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
    """
    a velocity filter that fits a straight line through the keypoints over
    time with least squares and uses the slope of that line as the velocity
    """

    _get_historical_detections: Callable[[], Sequence[Detection]]
    """
    returns the historical detections known right now; they are expected to be
    ordered by increasing timestamp (index 0 is the oldest detection, index -1
    is the newest detection)
    """
    _velocity: Optional[Float[Array, "J 3"]] = None

    @staticmethod
    def from_tracking(tracking: "Tracking") -> "LeastMeanSquareVelocityFilter":
        """
        create a LeastMeanSquareVelocityFilter from a Tracking object

        The new filter reads its history from `tracking.historical_detections`
        and takes over the velocity of the tracking's current filter, unless
        that velocity is zero everywhere; then the new filter starts without a
        velocity.
        """
        inherited = tracking.velocity_filter.get()["velocity"]
        created = LeastMeanSquareVelocityFilter(
            get_historical_detections=lambda: tracking.historical_detections
        )
        if not jnp.all(inherited == jnp.zeros_like(inherited)):
            # pylint: disable-next=protected-access
            created._velocity = inherited
        return created

    def __init__(self, get_historical_detections: Callable[[], Sequence[Detection]]):
        """
        Args:
            get_historical_detections: callable returning the detections the
                filter works on, oldest first
        """
        self._get_historical_detections = get_historical_detections
        self._velocity = None

    @property
    def velocity(self) -> Float[Array, "J 3"]:
        """
        the current velocity estimate

        Raises:
            ValueError: if no velocity has been estimated yet
        """
        if self._velocity is not None:
            return self._velocity
        raise ValueError("Velocity not initialized")

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        extrapolate the newest historical detection linearly to `timestamp`

        Args:
            timestamp: timestamp of the prediction

        Returns:
            velocity: the current velocity estimate (zeros if there is none)
            keypoints: newest keypoints moved by velocity * elapsed seconds
                (the unmodified newest keypoints if there is no velocity)

        Raises:
            ValueError: if there are no historical detections
        """
        history = self._get_historical_detections()
        if not history:
            raise ValueError("No historical detections available for prediction")
        # the history is ordered oldest first, so the last entry is the newest
        newest = history[-1]
        # NOTE(review): assumes Detection.keypoints has the same (J, 3) shape
        # as the velocity -- confirm against the Detection definition
        anchor_keypoints = newest.keypoints
        elapsed_s = (timestamp - newest.timestamp).total_seconds()
        if self._velocity is None:
            return TrackingPrediction(
                velocity=jnp.zeros_like(anchor_keypoints), keypoints=anchor_keypoints
            )
        # linear motion model: x(t) = x(t') + v(t') * (t - t')
        return TrackingPrediction(
            velocity=self._velocity,
            keypoints=anchor_keypoints + self._velocity * elapsed_s,
        )

    @jaxtyped(typechecker=beartype)
    def _update(
        self,
        keypoints: Float[Array, "N J 3"],
        timestamps: Float[Array, "N"],
    ) -> None:
        """
        estimate the velocity from N timestamped measurements

        A line `keypoint = velocity * t + intercept` is fitted independently
        for every joint and every coordinate; only the slope is kept.

        Raises:
            ValueError: if fewer than two measurements are given
        """
        if keypoints.shape[0] < 2:
            raise ValueError("Not enough measurements to estimate velocity")
        sample_count = timestamps.shape[0]
        joint_count = keypoints.shape[1]
        # design matrix with one row [t, 1] per measurement
        design = jnp.column_stack([timestamps, jnp.ones(sample_count)])
        # [N, J, 3] -> [N, J*3], so a single solve covers all joints and axes
        targets = keypoints.reshape(sample_count, -1)
        # solution has shape [2, J*3]: row 0 holds slopes, row 1 intercepts
        solution = jnp.linalg.lstsq(design, targets, rcond=None)[0]
        self._velocity = solution[0].reshape(joint_count, 3)

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        re-estimate the velocity from the whole history plus a new measurement

        Without any historical detections the velocity is set to zero.

        Args:
            keypoints: new measurements
            timestamp: timestamp of the update
        """
        history = self._get_historical_detections()
        if not history:
            self._velocity = jnp.zeros_like(keypoints)
            return
        t_0 = min(d.timestamp for d in history)
        # the new measurement is appended after the historical ones
        # NOTE(review): assumes every Detection.keypoints has the same shape
        # as `keypoints` -- confirm against the Detection definition
        stacked_keypoints = jnp.array([d.keypoints for d in history] + [keypoints])
        # seconds elapsed since t_0 (the oldest detection timestamp)
        offsets_s = [(d.timestamp - t_0).total_seconds() for d in history]
        offsets_s.append((timestamp - t_0).total_seconds())
        self._update(stacked_keypoints, jnp.array(offsets_s))

    def get(self) -> TrackingPrediction:
        """
        read the current state without extrapolating it

        Returns:
            velocity: the current velocity estimate (zeros if there is none)
            keypoints: keypoints of the newest historical detection

        Raises:
            ValueError: if there are no historical detections
        """
        history = self._get_historical_detections()
        if not history:
            raise ValueError("No historical detections available")
        newest_keypoints = history[-1].keypoints
        if self._velocity is None:
            reported_velocity = jnp.zeros_like(newest_keypoints)
        else:
            reported_velocity = self._velocity
        return TrackingPrediction(
            velocity=reported_velocity, keypoints=newest_keypoints
        )
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
    """
    an immutable record of one tracked target together with the velocity
    filter used to extrapolate its keypoints
    """

    id: int
    """
    identifier of the tracking
    """
    keypoints: Float[Array, "J 3"]
    """
    3D keypoints of the tracking, used to compute the 3D affinity
    """
    last_active_timestamp: datetime
    """
    timestamp at which the tracking was last active
    """
    historical_detections: PVector[Detection]
    """
    detections that were assigned to the tracking so far, used for 3D
    re-triangulation
    """
    velocity_filter: GenericVelocityFilter
    """
    the filter estimating the velocity of the tracking
    """

    def __repr__(self) -> str:
        return f"Tracking({self.id}, {self.last_active_timestamp})"

    @overload
    def predict(self, time: float) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: seconds after the last active timestamp

        Returns:
            the predicted keypoints
        """
        ...  # pylint: disable=unnecessary-ellipsis

    @overload
    def predict(self, time: timedelta) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: time delta after the last active timestamp
        """
        ...  # pylint: disable=unnecessary-ellipsis

    @overload
    def predict(self, time: datetime) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: the absolute timestamp to predict the keypoints for
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def predict(
        self,
        time: float | timedelta | datetime,
    ) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        A `datetime` is used as the target timestamp directly; a `timedelta`
        or a number of seconds is added to `last_active_timestamp` first. The
        prediction itself is delegated to the velocity filter.

        Returns:
            the predicted keypoints
        """
        if isinstance(time, datetime):
            target = time
        else:
            offset = time if isinstance(time, timedelta) else timedelta(seconds=time)
            target = self.last_active_timestamp + offset
        prediction = self.velocity_filter.predict(target)
        # pylint: disable-next=unsubscriptable-object
        return prediction["keypoints"]

    @property
    def velocity(self) -> Float[Array, "J 3"]:
        """
        the velocity of every keypoint as reported by the velocity filter
        """
        current_state = self.velocity_filter.get()
        # pylint: disable-next=unsubscriptable-object
        return current_state["velocity"]
@jaxtyped(typechecker=beartype)
@dataclass
class AffinityResult:
    """
    Result of affinity computation between trackings and detections.

    `indices_T[k]` and `indices_D[k]` together name the k-th matched pair:
    a position in `trackings` and a position in `detections`.
    """

    # affinity of every tracking (row) with every detection (column)
    matrix: Float[Array, "T D"]
    trackings: Sequence[Tracking]
    detections: Sequence[Detection]
    # NOTE(review): both index arrays are consumed pairwise, yet they are
    # labelled with the matrix axis names "T" and "D"; if jaxtyping binds
    # these names across fields this only admits len(indices) == T == D --
    # confirm the intended behaviour for non-square matrices
    indices_T: Int[Array, "T"]  # pylint: disable=invalid-name
    indices_D: Int[Array, "D"]  # pylint: disable=invalid-name

    def tracking_detections(
        self,
    ) -> Generator[tuple[float, Tracking, Detection], None, None]:
        """
        iterate over the best matching trackings and detections

        Yields:
            (affinity, tracking, detection) for every matched pair, in the
            order given by the index arrays
        """
        # bring the indices to the host as plain ints once
        tracking_positions = self.indices_T.tolist()
        detection_positions = self.indices_D.tolist()
        for row, col in zip(tracking_positions, detection_positions):
            affinity = self.matrix[row, col].item()
            yield (affinity, self.trackings[row], self.detections[col])