"""Velocity filters and tracking data structures for multi-camera 3D keypoint tracking."""
import warnings
|
|
import weakref
|
|
from collections import deque
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timedelta
|
|
from itertools import chain
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Generator,
|
|
Optional,
|
|
Protocol,
|
|
Sequence,
|
|
TypeAlias,
|
|
TypedDict,
|
|
TypeVar,
|
|
Union,
|
|
cast,
|
|
overload,
|
|
)
|
|
|
|
import jax.numpy as jnp
|
|
from beartype import beartype
|
|
from beartype.typing import Mapping, Sequence
|
|
from jax import Array
|
|
from jaxtyping import Array, Float, Int, jaxtyped
|
|
from pyrsistent import PVector, v, PRecord, PMap
|
|
|
|
from app.camera import Detection, CameraID
|
|
|
|
# Integer identifier assigned to each tracked subject.
TrackingID: TypeAlias = int


class TrackingPrediction(TypedDict):
    """
    Prediction produced by a velocity filter.

    Keys:
        velocity: per-joint velocity estimate, shape (J, 3);
            ``None`` when the filter has no velocity estimate yet.
        keypoints: per-joint 3D positions, shape (J, 3).
    """

    velocity: Optional[Float[Array, "J 3"]]
    keypoints: Float[Array, "J 3"]
|
|
|
|
|
|
class GenericVelocityFilter(Protocol):
    """
    Structural interface for velocity-estimating filters.

    Implementations keep an internal estimate of per-joint 3D positions
    (and, once enough data has been seen, velocities) and expose
    ``predict`` / ``update`` / ``get``.
    """

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        Extrapolate the filter state to ``timestamp``.

        Args:
            timestamp: point in time to extrapolate the state to

        Returns:
            velocity: estimated per-joint velocity (``None`` if unknown)
            keypoints: extrapolated per-joint 3D positions
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        Feed a new measurement into the filter.

        Args:
            keypoints: measured per-joint 3D positions
            timestamp: time the measurement was taken
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def get(self) -> TrackingPrediction:
        """
        Return the filter state as of the last update, without extrapolation.

        Returns:
            velocity: current per-joint velocity estimate (``None`` if unknown)
            keypoints: current per-joint 3D positions
        """
        ...  # pylint: disable=unnecessary-ellipsis
|
|
|
|
|
|
class DummyVelocityFilter(GenericVelocityFilter):
    """
    No-op velocity filter.

    Ignores all updates and always reports all-zero keypoints (matching the
    shape of the initial keypoints) with an unknown velocity.
    """

    # Shape of the initial keypoints; the only state this filter keeps.
    _keypoints_shape: tuple[int, ...]

    def __init__(self, keypoints: Float[Array, "J 3"]):
        # Only the shape is retained — the actual values are never used.
        self._keypoints_shape = keypoints.shape

    def _zero_state(self) -> TrackingPrediction:
        # Shared constant result for both predict() and get().
        return TrackingPrediction(
            velocity=None,
            keypoints=jnp.zeros(self._keypoints_shape),
        )

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        return self._zero_state()

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: ...

    def get(self) -> TrackingPrediction:
        return self._zero_state()
|
|
|
|
|
|
class LastDifferenceVelocityFilter(GenericVelocityFilter):
    """
    Naive velocity filter: the velocity estimate is the finite difference
    between the two most recent keypoint measurements.
    """

    _last_timestamp: datetime
    _last_keypoints: Float[Array, "J 3"]
    # None until the first successful update() with a forward time step.
    _last_velocity: Optional[Float[Array, "J 3"]] = None

    def __init__(self, keypoints: Float[Array, "J 3"], timestamp: datetime):
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        elapsed_s = (timestamp - self._last_timestamp).total_seconds()
        if elapsed_s <= 0:
            # A non-positive time step is suspicious; warn but keep going.
            warnings.warn(
                "delta_t={}; last={}; current={}".format(
                    elapsed_s, self._last_timestamp, timestamp
                )
            )
        if self._last_velocity is None or elapsed_s <= 0:
            # No velocity yet, or no forward time step: report the last
            # observed pose unchanged (velocity may still be None here).
            return TrackingPrediction(
                velocity=self._last_velocity,
                keypoints=self._last_keypoints,
            )
        # Constant-velocity extrapolation from the last measurement.
        return TrackingPrediction(
            velocity=self._last_velocity,
            keypoints=self._last_keypoints + self._last_velocity * elapsed_s,
        )

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        elapsed_s = (timestamp - self._last_timestamp).total_seconds()
        if elapsed_s > 0:
            # Difference against the previous keypoints before overwriting them.
            self._last_velocity = (keypoints - self._last_keypoints) / elapsed_s
            self._last_keypoints = keypoints
            self._last_timestamp = timestamp
        # Out-of-order or duplicate timestamps are silently dropped.

    def get(self) -> TrackingPrediction:
        # velocity stays None until the first successful update.
        return TrackingPrediction(
            velocity=self._last_velocity,
            keypoints=self._last_keypoints,
        )
|
|
|
|
|
|
class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
    """
    a velocity filter that uses the least mean square method to estimate the velocity

    A straight line is fitted (per joint, per axis) through the most recent
    ``max_samples`` poses; its slope is the velocity estimate.
    """

    _historical_3d_poses: deque[Float[Array, "J 3"]]
    _historical_timestamps: deque[datetime]
    # None until at least two samples are available for a fit.
    _velocity: Optional[Float[Array, "J 3"]] = None
    _max_samples: int

    def __init__(
        self,
        historical_3d_poses: Sequence[Float[Array, "J 3"]],
        historical_timestamps: Sequence[datetime],
        max_samples: int = 10,
    ):
        """
        Args:
            historical_3d_poses: past poses (any order; sorted internally)
            historical_timestamps: timestamp of each pose (same length)
            max_samples: number of most recent samples kept for the fit
        """
        assert len(historical_3d_poses) == len(historical_timestamps)
        # Sort poses and timestamps jointly so the deques are chronological.
        temp_sorted = sorted(
            zip(historical_3d_poses, historical_timestamps), key=lambda x: x[1]
        )
        self._historical_3d_poses = deque(
            (pose for pose, _ in temp_sorted), maxlen=max_samples
        )
        self._historical_timestamps = deque(
            (ts for _, ts in temp_sorted), maxlen=max_samples
        )
        self._max_samples = max_samples
        if len(self._historical_3d_poses) < 2:
            # Not enough samples to fit a line yet.
            self._velocity = None
        else:
            # BUGFIX: the timestamps deque holds datetime objects, which
            # jnp.array() cannot convert (and _update expects float seconds).
            # Convert to seconds relative to the oldest sample, as update() does.
            self._update(
                jnp.array(self._historical_3d_poses),
                self._relative_timestamps(),
            )

    def _relative_timestamps(self) -> Float[Array, "N"]:
        """Stored timestamps as float seconds relative to the oldest sample."""
        t_0 = self._historical_timestamps[0]
        return jnp.array(
            [(ts - t_0).total_seconds() for ts in self._historical_timestamps]
        )

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        Extrapolate the latest pose to ``timestamp`` with a linear motion model.

        Raises:
            ValueError: if no historical poses are stored
        """
        if not self._historical_3d_poses:
            raise ValueError("No historical 3D poses available for prediction")

        # use the latest historical detection
        latest_3d_pose = self._historical_3d_poses[-1]
        latest_timestamp = self._historical_timestamps[-1]

        delta_t_s = (timestamp - latest_timestamp).total_seconds()

        if self._velocity is None:
            return TrackingPrediction(
                velocity=None,
                keypoints=latest_3d_pose,
            )
        # Linear motion model: ẋt = xt' + Vt' · (t - t')
        predicted_3d_pose = latest_3d_pose + self._velocity * delta_t_s
        return TrackingPrediction(velocity=self._velocity, keypoints=predicted_3d_pose)

    @jaxtyped(typechecker=beartype)
    def _update(
        self,
        keypoints: Float[Array, "N J 3"],
        timestamps: Float[Array, "N"],
    ) -> None:
        """
        update measurements with least mean square method

        Args:
            keypoints: stacked historical poses
            timestamps: matching times in seconds (origin is arbitrary)

        Raises:
            ValueError: if fewer than two samples are given
        """
        if keypoints.shape[0] < 2:
            raise ValueError("Not enough measurements to estimate velocity")

        # Fit keypoint ≈ velocity * t + intercept independently for every
        # joint and axis with one batched least-squares solve.
        n_samples = timestamps.shape[0]
        n_joints = keypoints.shape[1]

        # Design matrix [t, 1] for each timestamp.
        X = jnp.column_stack([timestamps, jnp.ones(n_samples)])

        # Flatten [N, J, 3] -> [N, J*3] so all joints/axes solve at once.
        keypoints_reshaped = keypoints.reshape(n_samples, -1)

        # lstsq is more numerically stable than a manual pseudoinverse.
        coefficients, _, _, _ = jnp.linalg.lstsq(X, keypoints_reshaped, rcond=None)

        # coefficients[0] holds the slopes (velocities),
        # coefficients[1] the intercepts (unused).
        self._velocity = coefficients[0].reshape(n_joints, 3)

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        Append a new measurement and refit the velocity.

        Args:
            keypoints: new per-joint 3D positions
            timestamp: measurement time; must not precede the latest sample
        """
        last_timestamp = self._historical_timestamps[-1]
        assert last_timestamp <= timestamp

        # deque would manage the maxlen automatically
        self._historical_3d_poses.append(keypoints)
        self._historical_timestamps.append(timestamp)

        # BUGFIX: jnp.array() rejects a bare map object; build a concrete
        # list of relative seconds instead (shared helper).
        self._update(
            jnp.array(self._historical_3d_poses),
            self._relative_timestamps(),
        )

    def get(self) -> TrackingPrediction:
        """
        Return the latest stored pose and the current velocity estimate.

        Raises:
            ValueError: if no historical poses are stored
        """
        if not self._historical_3d_poses:
            raise ValueError("No historical 3D poses available")

        latest_3d_pose = self._historical_3d_poses[-1]
        return TrackingPrediction(velocity=self._velocity, keypoints=latest_3d_pose)
|
|
|
|
|
|
class OneEuroFilter(GenericVelocityFilter):
    """
    Implementation of the 1€ filter (One Euro Filter) for smoothing keypoint data.

    The 1€ filter is an adaptive low-pass filter that adjusts its cutoff frequency
    based on movement speed to reduce jitter during slow movements while maintaining
    responsiveness during fast movements.

    Reference: https://cristal.univ-lille.fr/~casiez/1euro/
    """

    _x_filtered: Float[Array, "J 3"]
    _dx_filtered: Optional[Float[Array, "J 3"]] = None
    _last_timestamp: datetime
    _min_cutoff: float
    _beta: float
    _d_cutoff: float

    def __init__(
        self,
        keypoints: Float[Array, "J 3"],
        timestamp: datetime,
        min_cutoff: float = 1.0,
        beta: float = 0.0,
        d_cutoff: float = 1.0,
    ):
        """
        Initialize the One Euro Filter.

        Args:
            keypoints: Initial keypoints positions
            timestamp: Initial timestamp
            min_cutoff: Minimum cutoff frequency (lower = more smoothing)
            beta: Speed coefficient (higher = less lag during fast movements)
            d_cutoff: Cutoff frequency for the derivative filter
        """
        self._last_timestamp = timestamp

        # Filter parameters
        self._min_cutoff = min_cutoff
        self._beta = beta
        self._d_cutoff = d_cutoff

        # Filter state
        self._x_filtered = keypoints  # Position filter state
        self._dx_filtered = None  # Initially no velocity estimate

    @overload
    def _smoothing_factor(self, cutoff: float, dt: float) -> float: ...

    @overload
    def _smoothing_factor(
        self, cutoff: Float[Array, "J 1"], dt: float
    ) -> Float[Array, "J 1"]: ...

    @jaxtyped(typechecker=beartype)
    def _smoothing_factor(
        self, cutoff: Union[float, Float[Array, "J 1"]], dt: float
    ) -> Union[float, Float[Array, "J 1"]]:
        """Calculate the smoothing factor for the low-pass filter.

        BUGFIX: the adaptive cutoff built in ``update`` is computed with
        ``keepdims=True`` and therefore has shape (J, 1); the previous
        ``Float[Array, "J"]`` annotation only matches rank-1 (J,) arrays, so
        jaxtyped/beartype rejected every adaptive-cutoff call at runtime.
        """
        r = 2 * jnp.pi * cutoff * dt
        return r / (r + 1)

    @jaxtyped(typechecker=beartype)
    def _exponential_smoothing(
        self,
        a: Union[float, Float[Array, "J 1"]],
        x: Float[Array, "J 3"],
        x_prev: Float[Array, "J 3"],
    ) -> Float[Array, "J 3"]:
        """Apply exponential smoothing: a * x + (1 - a) * x_prev.

        An ``a`` of shape (J, 1) broadcasts against (J, 3), giving each joint
        its own smoothing factor applied across all three axes.
        """
        return a * x + (1 - a) * x_prev

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        Predict keypoints position at a given timestamp.

        Args:
            timestamp: Timestamp for prediction

        Returns:
            TrackingPrediction with velocity and keypoints
        """
        dt = (timestamp - self._last_timestamp).total_seconds()

        if self._dx_filtered is None:
            # No velocity estimate yet: report the filtered position as-is.
            return TrackingPrediction(
                velocity=None,
                keypoints=self._x_filtered,
            )
        # Constant-velocity extrapolation from the filtered state.
        predicted_keypoints = self._x_filtered + self._dx_filtered * dt
        return TrackingPrediction(
            velocity=self._dx_filtered,
            keypoints=predicted_keypoints,
        )

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        Update the filter with new measurements.

        Args:
            keypoints: New keypoint measurements
            timestamp: Timestamp of the measurements

        Raises:
            ValueError: if ``timestamp`` does not advance past the last update
        """
        dt = (timestamp - self._last_timestamp).total_seconds()
        if dt <= 0:
            raise ValueError(
                f"new timestamp is not greater than the last timestamp; expecting: {timestamp} > {self._last_timestamp}"
            )

        # Raw velocity estimate derived from the new measurement.
        dx = (keypoints - self._x_filtered) / dt

        # Determine cutoff frequency based on movement speed; shape (J, 1)
        # so it broadcasts per-joint over the (J, 3) keypoints.
        cutoff = self._min_cutoff + self._beta * jnp.linalg.norm(
            dx, axis=-1, keepdims=True
        )

        # Apply low-pass filter to velocity
        a_d = self._smoothing_factor(self._d_cutoff, dt)
        self._dx_filtered = self._exponential_smoothing(
            a_d,
            dx,
            (
                jnp.zeros_like(keypoints)
                if self._dx_filtered is None
                else self._dx_filtered
            ),
        )

        # Apply low-pass filter to position with adaptive cutoff
        a_cutoff = self._smoothing_factor(jnp.asarray(cutoff), dt)
        self._x_filtered = self._exponential_smoothing(
            a_cutoff, keypoints, self._x_filtered
        )

        # Update timestamp
        self._last_timestamp = timestamp

    def get(self) -> TrackingPrediction:
        """
        Get the current state of the filter.

        Returns:
            TrackingPrediction with velocity and keypoints
        """
        return TrackingPrediction(
            velocity=self._dx_filtered,
            keypoints=self._x_filtered,
        )
|
|
|
|
|
|
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class TrackingState:
    """
    Immutable state of a tracking (frozen dataclass — updates create
    a new instance rather than mutating this one).
    """

    keypoints: Float[Array, "J 3"]
    """
    The 3D keypoints of the tracking, shape (J, 3).

    Used for calculating the 3D affinity.
    """
    last_active_timestamp: datetime
    """
    The last timestamp at which the tracking was active.
    """

    historical_detections_by_camera: PMap[CameraID, Detection]
    """
    Historical per-camera detections of the tracking.

    Used for 3D re-triangulation.
    """
|
|
|
|
|
|
class Tracking:
    """
    A tracked subject: an id, an immutable state snapshot, and a velocity
    filter used for motion prediction.
    """

    id: TrackingID
    state: TrackingState
    velocity_filter: GenericVelocityFilter

    def __init__(
        self,
        id: TrackingID,
        state: TrackingState,
        velocity_filter: Optional[GenericVelocityFilter] = None,
    ):
        self.id = id
        self.state = state
        # Fall back to a no-op filter when none is supplied.
        self.velocity_filter = (
            velocity_filter if velocity_filter else DummyVelocityFilter(state.keypoints)
        )

    def __repr__(self) -> str:
        return f"Tracking({self.id}, {self.state.last_active_timestamp})"

    @overload
    def predict(self, time: float) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: seconds after the last active timestamp

        Returns:
            the predicted keypoints
        """
        ...  # pylint: disable=unnecessary-ellipsis

    @overload
    def predict(self, time: timedelta) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: offset after the last active timestamp
        """
        ...  # pylint: disable=unnecessary-ellipsis

    @overload
    def predict(self, time: datetime) -> Float[Array, "J 3"]:
        """
        predict the keypoints at a given time

        Args:
            time: the absolute timestamp to predict the keypoints at
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def predict(
        self,
        time: float | timedelta | datetime,
    ) -> Float[Array, "J 3"]:
        # Normalize every accepted form to an absolute timestamp first.
        if isinstance(time, datetime):
            target = time
        else:
            offset = time if isinstance(time, timedelta) else timedelta(seconds=time)
            target = self.state.last_active_timestamp + offset
        # pylint: disable-next=unsubscriptable-object
        return self.velocity_filter.predict(target)["keypoints"]

    def update(self, new_3d_pose: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        update the tracking with a new 3D pose

        Note:
            equivalent to call `velocity_filter.update(new_3d_pose, timestamp)`
        """
        self.velocity_filter.update(new_3d_pose, timestamp)

    @property
    def velocity(self) -> Float[Array, "J 3"]:
        """
        The velocity of the tracking for each keypoint.

        Zeros when the underlying filter has no velocity estimate yet.
        """
        # pylint: disable-next=unsubscriptable-object
        vel = self.velocity_filter.get()["velocity"]
        return jnp.zeros_like(self.state.keypoints) if vel is None else vel
|
|
|
|
|
|
@jaxtyped(typechecker=beartype)
@dataclass
class AffinityResult:
    """
    Result of affinity computation between trackings and detections.
    """

    matrix: Float[Array, "T D"]
    trackings: Sequence[Tracking]
    detections: Sequence[Detection]
    indices_T: Int[Array, "A"]  # pylint: disable=invalid-name
    indices_D: Int[Array, "A"]  # pylint: disable=invalid-name

    def tracking_association(
        self,
    ) -> Generator[tuple[float, Tracking, Detection], None, None]:
        """
        iterate over the best matching trackings and detections

        Yields:
            (affinity score, tracking, detection) for each matched index pair
        """
        for t_idx, d_idx in zip(self.indices_T, self.indices_D):
            score = self.matrix[t_idx, d_idx].item()
            yield score, self.trackings[t_idx], self.detections[d_idx]
|