Files
CVTH3PE/app/tracking/__init__.py
crosstyan 4a5cfde245 feat: Add OneEuroFilter for adaptive keypoint smoothing
- Introduced the `OneEuroFilter` class to implement an adaptive low-pass filter for smoothing keypoint data, enhancing tracking stability during varying movement speeds.
- Implemented methods for initialization, prediction, and updating of keypoints, allowing for dynamic adjustment of smoothing based on movement.
- Added detailed documentation and type hints to clarify the filter's functionality and parameters.
- Improved the handling of timestamps and filtering logic to ensure accurate predictions and updates.
2025-05-03 14:58:51 +08:00

545 lines
17 KiB
Python

import weakref
from collections import deque
from dataclasses import dataclass
from datetime import datetime, timedelta
from itertools import chain
from typing import (
Any,
Callable,
Generator,
Optional,
Protocol,
Sequence,
TypeAlias,
TypedDict,
TypeVar,
cast,
overload,
Union,
)
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Mapping, Sequence
from jax import Array
from jaxtyping import Array, Float, Int, jaxtyped
from pyrsistent import PVector, v
from app.camera import Detection
class TrackingPrediction(TypedDict):
    """
    Result returned by a velocity filter: a keypoint estimate plus, when the
    filter has one, a per-keypoint velocity.
    """

    # per-keypoint velocity; the filters in this module use None to signal
    # "no velocity estimate available"
    velocity: Optional[Float[Array, "J 3"]]
    # keypoint positions, one 3-vector per joint
    keypoints: Float[Array, "J 3"]
class GenericVelocityFilter(Protocol):
    """
    a filter interface for tracking velocity estimation

    An implementation is fed keypoint measurements through `update` and
    exposes its estimate through `predict` (at an arbitrary timestamp) and
    `get` (current state, no extrapolation).
    """

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        predict the velocity and the keypoints location

        Args:
            timestamp: timestamp of the prediction

        Returns:
            A `TrackingPrediction` with
            velocity: velocity of the tracking (None if not estimated)
            keypoints: keypoints of the tracking
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        update the filter state with new measurements

        Args:
            keypoints: new measurements
            timestamp: timestamp of the update
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def get(self) -> TrackingPrediction:
        """
        get the current state of the filter

        Returns:
            A `TrackingPrediction` with
            velocity: velocity of the tracking (None if not estimated)
            keypoints: keypoints of the tracking
        """
        ...  # pylint: disable=unnecessary-ellipsis
class DummyVelocityFilter(GenericVelocityFilter):
    """
    Placeholder filter that ignores every measurement and never estimates a
    velocity.

    NOTE(review): the reported keypoints are always all-zero (only the shape of
    the initial keypoints is remembered) -- confirm callers do not rely on
    `predict` returning a usable pose.
    """

    _keypoints_shape: tuple[int, ...]

    def __init__(self, keypoints: Float[Array, "J 3"]):
        # only the shape is retained; the values themselves are discarded
        self._keypoints_shape = keypoints.shape

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """Return the constant placeholder state; `timestamp` is ignored."""
        return self.get()

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """Accept a measurement and discard it."""

    def get(self) -> TrackingPrediction:
        """Return zero keypoints of the remembered shape and no velocity."""
        placeholder = jnp.zeros(self._keypoints_shape)
        return TrackingPrediction(velocity=None, keypoints=placeholder)
class LastDifferenceVelocityFilter(GenericVelocityFilter):
    """
    a naive velocity filter that uses the last difference of keypoints

    The velocity is the finite difference between the two most recent
    measurements; predictions extrapolate linearly from the latest measurement.
    """

    _last_timestamp: datetime
    _last_keypoints: Float[Array, "J 3"]
    _last_velocity: Optional[Float[Array, "J 3"]] = None

    def __init__(self, keypoints: Float[Array, "J 3"], timestamp: datetime):
        """
        Args:
            keypoints: initial keypoints
            timestamp: timestamp of the initial keypoints
        """
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp
        # no velocity can be derived from a single measurement
        self._last_velocity = None

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        Linearly extrapolate the latest keypoints to `timestamp`.

        Args:
            timestamp: timestamp of the prediction

        Returns:
            the latest keypoints unchanged (velocity None) while no velocity is
            known, otherwise the extrapolated keypoints and the velocity
        """
        velocity = self._last_velocity
        if velocity is None:
            return TrackingPrediction(velocity=None, keypoints=self._last_keypoints)
        delta_t_s = (timestamp - self._last_timestamp).total_seconds()
        return TrackingPrediction(
            velocity=velocity,
            keypoints=self._last_keypoints + velocity * delta_t_s,
        )

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        Replace the stored measurement and refresh the velocity estimate.

        Args:
            keypoints: new measurements
            timestamp: timestamp of the update
        """
        delta_t_s = (timestamp - self._last_timestamp).total_seconds()
        if delta_t_s != 0:
            self._last_velocity = (keypoints - self._last_keypoints) / delta_t_s
        # else: a repeated timestamp carries no velocity information; dividing
        # by zero would yield inf/NaN and corrupt every later prediction, so the
        # previous velocity (possibly None) is kept.
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp

    def get(self) -> TrackingPrediction:
        """
        Returns:
            the latest keypoints and the current velocity (None if unknown)
        """
        return TrackingPrediction(
            velocity=self._last_velocity,
            keypoints=self._last_keypoints,
        )
class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
    """
    a velocity filter that uses the least mean square method to estimate the velocity

    A sliding window of at most `max_samples` (pose, timestamp) pairs is kept;
    the velocity is the slope of a straight line fitted through the window for
    every joint coordinate.
    """

    _historical_3d_poses: deque[Float[Array, "J 3"]]
    _historical_timestamps: deque[datetime]
    _velocity: Optional[Float[Array, "J 3"]] = None
    _max_samples: int

    def __init__(
        self,
        historical_3d_poses: Sequence[Float[Array, "J 3"]],
        historical_timestamps: Sequence[datetime],
        max_samples: int = 10,
    ):
        """
        Args:
            historical_3d_poses: past poses, in any order
            historical_timestamps: timestamp of each pose (same length)
            max_samples: window size; only the most recent samples are kept
        """
        assert len(historical_3d_poses) == len(
            historical_timestamps
        ), "poses and timestamps must have the same length"
        # sort by timestamp only (the key avoids comparing arrays on ties)
        ordered = sorted(
            zip(historical_3d_poses, historical_timestamps),
            key=lambda pair: pair[1],
        )
        self._historical_3d_poses = deque(
            (pose for pose, _ in ordered), maxlen=max_samples
        )
        self._historical_timestamps = deque(
            (stamp for _, stamp in ordered), maxlen=max_samples
        )
        self._max_samples = max_samples
        self._velocity = None
        self._refit()

    def _refit(self) -> None:
        """Re-estimate the velocity from the samples currently in the window."""
        if len(self._historical_3d_poses) < 2:
            # a line cannot be fitted through fewer than two samples
            self._velocity = None
            return
        # timestamps relative to t_0 (the oldest sample in the window), in
        # seconds; datetime objects themselves cannot be turned into an array
        t_0 = self._historical_timestamps[0]
        seconds = [
            (stamp - t_0).total_seconds() for stamp in self._historical_timestamps
        ]
        # materialise the deques as lists: jnp.array needs a real sequence
        self._update(
            jnp.array(list(self._historical_3d_poses)),
            jnp.array(seconds),
        )

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        Extrapolate the latest pose to `timestamp` with the fitted velocity.

        Args:
            timestamp: timestamp of the prediction

        Returns:
            the latest pose unchanged (velocity None) while no velocity is
            known, otherwise the extrapolated pose and the velocity

        Raises:
            ValueError: if the filter holds no pose at all
        """
        if not self._historical_3d_poses:
            raise ValueError("No historical 3D poses available for prediction")
        # use the latest historical detection
        latest_3d_pose = self._historical_3d_poses[-1]
        if self._velocity is None:
            return TrackingPrediction(velocity=None, keypoints=latest_3d_pose)
        latest_timestamp = self._historical_timestamps[-1]
        delta_t_s = (timestamp - latest_timestamp).total_seconds()
        # Linear motion model: ẋt = xt' + Vt' · (t - t')
        predicted_3d_pose = latest_3d_pose + self._velocity * delta_t_s
        return TrackingPrediction(velocity=self._velocity, keypoints=predicted_3d_pose)

    @jaxtyped(typechecker=beartype)
    def _update(
        self,
        keypoints: Float[Array, "N J 3"],
        timestamps: Float[Array, "N"],
    ) -> None:
        """
        update measurements with least mean square method

        Args:
            keypoints: stacked poses of the window
            timestamps: time of each pose as a float (seconds)

        Raises:
            ValueError: if fewer than two measurements are given
        """
        if keypoints.shape[0] < 2:
            raise ValueError("Not enough measurements to estimate velocity")
        n_samples = timestamps.shape[0]
        n_joints = keypoints.shape[1]
        # Design matrix for the linear model y = v * t + b: one [t, 1] row per
        # timestamp
        X = jnp.column_stack([timestamps, jnp.ones(n_samples)])
        # Flatten [N, J, 3] to [N, J*3] so every joint coordinate is solved in
        # a single least-squares call
        keypoints_reshaped = keypoints.reshape(n_samples, -1)
        coefficients, _, _, _ = jnp.linalg.lstsq(X, keypoints_reshaped, rcond=None)
        # coefficients has shape [2, J*3]: row 0 = slopes (velocities),
        # row 1 = intercepts
        self._velocity = coefficients[0].reshape(n_joints, 3)

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        Append a measurement to the window and re-estimate the velocity.

        Args:
            keypoints: new measurements
            timestamp: timestamp of the update; must not precede the latest
                stored timestamp
        """
        if self._historical_timestamps:
            assert (
                self._historical_timestamps[-1] <= timestamp
            ), "measurements must arrive in chronological order"
        # deque would manage the maxlen automatically
        self._historical_3d_poses.append(keypoints)
        self._historical_timestamps.append(timestamp)
        self._refit()

    def get(self) -> TrackingPrediction:
        """
        Returns:
            the latest pose and the current velocity (None if unknown)

        Raises:
            ValueError: if the filter holds no pose at all
        """
        if not self._historical_3d_poses:
            raise ValueError("No historical 3D poses available")
        return TrackingPrediction(
            velocity=self._velocity,
            keypoints=self._historical_3d_poses[-1],
        )
class OneEuroFilter(GenericVelocityFilter):
    """
    Implementation of the 1€ filter (One Euro Filter) for smoothing keypoint data.

    The 1€ filter is an adaptive low-pass filter that adjusts its cutoff frequency
    based on movement speed to reduce jitter during slow movements while maintaining
    responsiveness during fast movements.

    Reference: https://cristal.univ-lille.fr/~casiez/1euro/
    """

    _x_filtered: Float[Array, "J 3"]
    _dx_filtered: Optional[Float[Array, "J 3"]] = None
    _last_timestamp: datetime
    _min_cutoff: float
    _beta: float
    _d_cutoff: float

    def __init__(
        self,
        keypoints: Float[Array, "J 3"],
        timestamp: datetime,
        min_cutoff: float = 1.0,
        beta: float = 0.0,
        d_cutoff: float = 1.0,
    ):
        """
        Initialize the One Euro Filter.

        Args:
            keypoints: Initial keypoints positions
            timestamp: Initial timestamp
            min_cutoff: Minimum cutoff frequency (lower = more smoothing)
            beta: Speed coefficient (higher = less lag during fast movements)
            d_cutoff: Cutoff frequency for the derivative filter
        """
        self._last_timestamp = timestamp
        # Filter parameters
        self._min_cutoff = min_cutoff
        self._beta = beta
        self._d_cutoff = d_cutoff
        # Filter state
        self._x_filtered = keypoints  # Position filter state
        self._dx_filtered = None  # Initially no velocity estimate

    def _smoothing_factor(
        self, cutoff: Union[float, Float[Array, "J 1"]], dt: float
    ) -> Union[float, Float[Array, "J 1"]]:
        """
        Calculate the smoothing factor for the low-pass filter.

        Args:
            cutoff: cutoff frequency, a scalar or one value per joint
            dt: elapsed time in seconds since the previous sample
        """
        r = 2 * jnp.pi * cutoff * dt
        return r / (r + 1)

    def _exponential_smoothing(
        self,
        a: Union[float, Float[Array, "J 1"]],
        x: Float[Array, "J 3"],
        x_prev: Float[Array, "J 3"],
    ) -> Float[Array, "J 3"]:
        """Blend the new sample `x` with the previous value using weight `a`."""
        return a * x + (1 - a) * x_prev

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        Predict keypoints position at a given timestamp.

        Args:
            timestamp: Timestamp for prediction

        Returns:
            TrackingPrediction with velocity and keypoints; the filtered
            keypoints are returned unchanged while no velocity is known
        """
        if self._dx_filtered is None:
            return TrackingPrediction(velocity=None, keypoints=self._x_filtered)
        dt = (timestamp - self._last_timestamp).total_seconds()
        return TrackingPrediction(
            velocity=self._dx_filtered,
            keypoints=self._x_filtered + self._dx_filtered * dt,
        )

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        Update the filter with new measurements.

        Args:
            keypoints: New keypoint measurements
            timestamp: Timestamp of the measurements

        Raises:
            ValueError: if `timestamp` is not strictly later than the
                previous one
        """
        dt = (timestamp - self._last_timestamp).total_seconds()
        if dt <= 0:
            raise ValueError("invalid timestamp")
        # Raw velocity from the new measurement and the filtered position
        dx = (keypoints - self._x_filtered) / dt
        # Low-pass filter the velocity (the first update starts from zero)
        previous_dx = (
            jnp.zeros_like(keypoints)
            if self._dx_filtered is None
            else self._dx_filtered
        )
        a_d = self._smoothing_factor(self._d_cutoff, dt)
        dx_filtered = self._exponential_smoothing(a_d, dx, previous_dx)
        # Adaptive cutoff per joint, driven by the *filtered* speed as in the
        # reference algorithm, so measurement jitter in the raw derivative does
        # not open the filter up. keepdims gives shape (J, 1), which broadcasts
        # over the three coordinates.
        cutoff = self._min_cutoff + self._beta * jnp.linalg.norm(
            dx_filtered, axis=-1, keepdims=True
        )
        # Low-pass filter the position with the adaptive cutoff
        a_x = self._smoothing_factor(cutoff, dt)
        self._x_filtered = self._exponential_smoothing(
            a_x, keypoints, self._x_filtered
        )
        self._dx_filtered = dx_filtered
        self._last_timestamp = timestamp

    def get(self) -> TrackingPrediction:
        """
        Get the current state of the filter.

        Returns:
            TrackingPrediction with velocity and keypoints
        """
        return TrackingPrediction(
            velocity=self._dx_filtered,
            keypoints=self._x_filtered,
        )
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class TrackingState:
    """
    immutable state of a tracking

    Declared as a frozen dataclass, so fields cannot be reassigned after
    construction; a new state has to be built instead of mutating this one.
    """

    keypoints: Float[Array, "J 3"]
    """
    The 3D keypoints of the tracking
    Used for calculate affinity 3D
    """
    last_active_timestamp: datetime
    """
    The last active timestamp of the tracking
    """
    historical_detections: PVector[Detection]
    """
    Historical detections of the tracking.
    Used for 3D re-triangulation
    """
class Tracking:
    """
    A tracked target: an identifier, its latest state and the velocity filter
    used to extrapolate its keypoints.
    """

    id: int
    state: TrackingState
    velocity_filter: GenericVelocityFilter

    def __init__(
        self,
        id: int,
        state: TrackingState,
        velocity_filter: Optional[GenericVelocityFilter] = None,
    ):
        """
        Args:
            id: identifier of the tracking
            state: initial state
            velocity_filter: filter used for prediction; a
                `DummyVelocityFilter` built from the state's keypoints is
                used when omitted
        """
        self.id = id
        self.state = state
        self.velocity_filter = velocity_filter or DummyVelocityFilter(state.keypoints)

    def __repr__(self) -> str:
        return f"Tracking({self.id}, {self.state.last_active_timestamp})"

    @overload
    def predict(self, time: float) -> Float[Array, "J 3"]:
        """
        Predict the keypoints `time` seconds after the last active timestamp.

        Args:
            time: offset in seconds

        Returns:
            the predicted keypoints
        """
        ...  # pylint: disable=unnecessary-ellipsis

    @overload
    def predict(self, time: timedelta) -> Float[Array, "J 3"]:
        """
        Predict the keypoints `time` after the last active timestamp.

        Args:
            time: offset as a time delta
        """
        ...  # pylint: disable=unnecessary-ellipsis

    @overload
    def predict(self, time: datetime) -> Float[Array, "J 3"]:
        """
        Predict the keypoints at an absolute point in time.

        Args:
            time: the timestamp to predict the keypoints at
        """
        ...  # pylint: disable=unnecessary-ellipsis

    def predict(
        self,
        time: float | timedelta | datetime,
    ) -> Float[Array, "J 3"]:
        if isinstance(time, datetime):
            # absolute timestamps are handed to the filter untouched
            target = time
        else:
            # relative inputs are offsets from the last active timestamp;
            # bare numbers are interpreted as seconds
            offset = time if isinstance(time, timedelta) else timedelta(seconds=time)
            target = self.state.last_active_timestamp + offset
        prediction = self.velocity_filter.predict(target)
        # pylint: disable-next=unsubscriptable-object
        return prediction["keypoints"]

    @property
    def velocity(self) -> Float[Array, "J 3"]:
        """
        Per-keypoint velocity reported by the filter, or zeros shaped like the
        state's keypoints when the filter has no estimate
        """
        # pylint: disable-next=unsubscriptable-object
        estimate = self.velocity_filter.get()["velocity"]
        if estimate is not None:
            return estimate
        return jnp.zeros_like(self.state.keypoints)
@jaxtyped(typechecker=beartype)
@dataclass
class AffinityResult:
    """
    Outcome of an affinity computation between trackings and detections: the
    affinity matrix, the two sequences it was computed over, and the index
    pairs selected as matches.

    NOTE(review): the index arrays are annotated with the same "T" / "D" sizes
    as the matrix axes -- confirm that holds when the number of trackings and
    detections differ.
    """

    matrix: Float[Array, "T D"]
    trackings: Sequence[Tracking]
    detections: Sequence[Detection]
    indices_T: Int[Array, "T"]  # pylint: disable=invalid-name
    indices_D: Int[Array, "D"]  # pylint: disable=invalid-name

    def tracking_detections(
        self,
    ) -> Generator[tuple[float, Tracking, Detection], None, None]:
        """
        Yield `(affinity, tracking, detection)` for every matched index pair,
        pairing `indices_T` with `indices_D` position by position.
        """
        # convert both index arrays to plain Python ints once up front
        rows = self.indices_T.tolist()
        cols = self.indices_D.tolist()
        for row, col in zip(rows, cols):
            affinity = self.matrix[row, col].item()
            yield (affinity, self.trackings[row], self.detections[col])