import weakref
from collections import deque
from dataclasses import dataclass
from datetime import datetime, timedelta
from itertools import chain
from typing import (
Any,
Callable,
Generator,
Optional,
Protocol,
Sequence,
TypeAlias,
TypedDict,
TypeVar,
Union,
cast,
overload,
)
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Mapping, Sequence
from jaxtyping import Array, Float, Int, jaxtyped
from pyrsistent import PVector, v, PRecord, PMap
from app.camera import Detection, CameraID
TrackingID: TypeAlias = int
class TrackingPrediction(TypedDict):
velocity: Optional[Float[Array, "J 3"]]
keypoints: Float[Array, "J 3"]
class GenericVelocityFilter(Protocol):
"""
a filter interface for tracking velocity estimation
"""
def predict(self, timestamp: datetime) -> TrackingPrediction:
"""
predict the velocity and the keypoints location
Args:
timestamp: timestamp of the prediction
Returns:
velocity: velocity of the tracking
keypoints: keypoints of the tracking
"""
... # pylint: disable=unnecessary-ellipsis
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
"""
update the filter state with new measurements
Args:
keypoints: new measurements
timestamp: timestamp of the update
"""
... # pylint: disable=unnecessary-ellipsis
def get(self) -> TrackingPrediction:
"""
        get the current state of the filter
Returns:
velocity: velocity of the tracking
keypoints: keypoints of the tracking
"""
... # pylint: disable=unnecessary-ellipsis
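# Illustrative sketch (not part of the original tracking pipeline): how a caller is
# expected to drive any GenericVelocityFilter. The name `step_filter` is hypothetical;
# it first extrapolates to the measurement timestamp, then folds the measurement in.
def step_filter(
    flt: GenericVelocityFilter,
    keypoints: Float[Array, "J 3"],
    timestamp: datetime,
) -> TrackingPrediction:
    """Predict at `timestamp`, then update the filter with the new measurement."""
    prior = flt.predict(timestamp)  # extrapolated state before seeing the measurement
    flt.update(keypoints, timestamp)  # incorporate the measurement into the filter state
    return prior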
class DummyVelocityFilter(GenericVelocityFilter):
"""
a dummy velocity filter that does nothing
"""
_keypoints_shape: tuple[int, ...]
def __init__(self, keypoints: Float[Array, "J 3"]):
self._keypoints_shape = keypoints.shape
def predict(self, timestamp: datetime) -> TrackingPrediction:
return TrackingPrediction(
velocity=None,
keypoints=jnp.zeros(self._keypoints_shape),
)
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: ...
def get(self) -> TrackingPrediction:
return TrackingPrediction(
velocity=None,
keypoints=jnp.zeros(self._keypoints_shape),
)
class LastDifferenceVelocityFilter(GenericVelocityFilter):
"""
a naive velocity filter that uses the last difference of keypoints
"""
_last_timestamp: datetime
_last_keypoints: Float[Array, "J 3"]
_last_velocity: Optional[Float[Array, "J 3"]] = None
def __init__(self, keypoints: Float[Array, "J 3"], timestamp: datetime):
self._last_keypoints = keypoints
self._last_timestamp = timestamp
def predict(self, timestamp: datetime) -> TrackingPrediction:
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
if self._last_velocity is None:
return TrackingPrediction(
velocity=None,
keypoints=self._last_keypoints,
)
else:
return TrackingPrediction(
velocity=self._last_velocity,
keypoints=self._last_keypoints + self._last_velocity * delta_t_s,
)
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
self._last_velocity = (keypoints - self._last_keypoints) / delta_t_s
self._last_keypoints = keypoints
self._last_timestamp = timestamp
def get(self) -> TrackingPrediction:
if self._last_velocity is None:
return TrackingPrediction(
velocity=None,
keypoints=self._last_keypoints,
)
else:
return TrackingPrediction(
velocity=self._last_velocity,
keypoints=self._last_keypoints,
)
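# Minimal usage sketch for LastDifferenceVelocityFilter (hypothetical example values):
# after one update the velocity is the finite difference of the two poses, and predict()
# extrapolates linearly from the last pose.
def _example_last_difference() -> None:
    t0 = datetime(2025, 1, 1, 0, 0, 0)
    pose0 = jnp.zeros((17, 3))
    flt = LastDifferenceVelocityFilter(pose0, t0)
    # one second later every joint has moved 1 unit along x
    pose1 = pose0.at[:, 0].add(1.0)
    flt.update(pose1, t0 + timedelta(seconds=1))
    # extrapolating one more second should place the joints near x == 2
    predicted = flt.predict(t0 + timedelta(seconds=2))["keypoints"]
    assert bool(jnp.allclose(predicted[:, 0], 2.0))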
class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
"""
    a velocity filter that fits a constant-velocity (linear) motion model to recent poses via least squares
"""
_historical_3d_poses: deque[Float[Array, "J 3"]]
_historical_timestamps: deque[datetime]
_velocity: Optional[Float[Array, "J 3"]] = None
_max_samples: int
def __init__(
self,
historical_3d_poses: Sequence[Float[Array, "J 3"]],
historical_timestamps: Sequence[datetime],
max_samples: int = 10,
):
assert len(historical_3d_poses) == len(historical_timestamps)
temp = zip(historical_3d_poses, historical_timestamps)
temp_sorted = sorted(temp, key=lambda x: x[1])
self._historical_3d_poses = deque(
map(lambda x: x[0], temp_sorted), maxlen=max_samples
)
self._historical_timestamps = deque(
map(lambda x: x[1], temp_sorted), maxlen=max_samples
)
self._max_samples = max_samples
        if len(self._historical_3d_poses) < 2:
            self._velocity = None
        else:
            # timestamps must be converted to seconds relative to the oldest sample;
            # jnp.array cannot build a float array from datetime objects directly
            t_0 = self._historical_timestamps[0]
            rel_timestamps = jnp.array(
                [(t - t_0).total_seconds() for t in self._historical_timestamps]
            )
            self._update(jnp.array(self._historical_3d_poses), rel_timestamps)
def predict(self, timestamp: datetime) -> TrackingPrediction:
if not self._historical_3d_poses:
raise ValueError("No historical 3D poses available for prediction")
# use the latest historical detection
latest_3d_pose = self._historical_3d_poses[-1]
latest_timestamp = self._historical_timestamps[-1]
delta_t_s = (timestamp - latest_timestamp).total_seconds()
if self._velocity is None:
return TrackingPrediction(
velocity=None,
keypoints=latest_3d_pose,
)
else:
# Linear motion model: ẋt = xt' + Vt' · (t - t')
predicted_3d_pose = latest_3d_pose + self._velocity * delta_t_s
return TrackingPrediction(
velocity=self._velocity, keypoints=predicted_3d_pose
)
@jaxtyped(typechecker=beartype)
def _update(
self,
keypoints: Float[Array, "N J 3"],
timestamps: Float[Array, "N"],
) -> None:
"""
        fit a constant-velocity model to the stored measurements via least squares
"""
if keypoints.shape[0] < 2:
raise ValueError("Not enough measurements to estimate velocity")
# Using least squares to fit a linear model for each joint and dimension
# X = timestamps, y = keypoints
# For each joint and each dimension, we solve for velocity
n_samples = timestamps.shape[0]
n_joints = keypoints.shape[1]
# Create design matrix for linear regression
# [t, 1] for each timestamp
X = jnp.column_stack([timestamps, jnp.ones(n_samples)])
# Reshape keypoints to solve for all joints and dimensions at once
# From [N, J, 3] to [N, J*3]
keypoints_reshaped = keypoints.reshape(n_samples, -1)
# Use JAX's lstsq to solve the least squares problem
# This is more numerically stable than manually computing pseudoinverse
coefficients, _, _, _ = jnp.linalg.lstsq(X, keypoints_reshaped, rcond=None)
# Coefficients shape is [2, J*3]
# First row: velocities, Second row: intercepts
velocities = coefficients[0].reshape(n_joints, 3)
# Update velocity
self._velocity = velocities
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
last_timestamp = self._historical_timestamps[-1]
assert last_timestamp <= timestamp
# deque would manage the maxlen automatically
self._historical_3d_poses.append(keypoints)
self._historical_timestamps.append(timestamp)
t_0 = self._historical_timestamps[0]
all_keypoints = jnp.array(self._historical_3d_poses)
def timestamp_to_seconds(timestamp: datetime) -> float:
assert t_0 <= timestamp
return (timestamp - t_0).total_seconds()
        # timestamps relative to t_0 (the oldest detection timestamp);
        # materialize into a list first, since jnp.array cannot consume a lazy map object
        all_timestamps = jnp.array(
            [timestamp_to_seconds(t) for t in self._historical_timestamps]
        )
self._update(all_keypoints, all_timestamps)
def get(self) -> TrackingPrediction:
if not self._historical_3d_poses:
raise ValueError("No historical 3D poses available")
latest_3d_pose = self._historical_3d_poses[-1]
if self._velocity is None:
return TrackingPrediction(velocity=None, keypoints=latest_3d_pose)
else:
return TrackingPrediction(velocity=self._velocity, keypoints=latest_3d_pose)
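# Worked sketch (hypothetical data) of the least-squares fit used by
# LeastMeanSquareVelocityFilter._update: for noise-free linear motion the slope of the
# fitted line recovers the constant per-joint velocity exactly.
def _example_least_squares_velocity() -> None:
    n_samples, n_joints = 5, 17
    times = jnp.arange(n_samples, dtype=jnp.float32)  # seconds relative to the oldest sample
    velocity = jnp.ones((n_joints, 3)) * jnp.array([0.5, -0.25, 0.0])
    # keypoints[i] = i * velocity, i.e. perfectly linear motion starting at the origin
    keypoints = times[:, None, None] * velocity[None, :, :]
    # same design matrix as _update: one column of timestamps, one constant column
    X = jnp.column_stack([times, jnp.ones(n_samples)])
    coefficients, _, _, _ = jnp.linalg.lstsq(X, keypoints.reshape(n_samples, -1), rcond=None)
    fitted_velocity = coefficients[0].reshape(n_joints, 3)
    assert bool(jnp.allclose(fitted_velocity, velocity, atol=1e-4))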
class OneEuroFilter(GenericVelocityFilter):
"""
Implementation of the 1€ filter (One Euro Filter) for smoothing keypoint data.
The 1€ filter is an adaptive low-pass filter that adjusts its cutoff frequency
based on movement speed to reduce jitter during slow movements while maintaining
responsiveness during fast movements.
Reference: https://cristal.univ-lille.fr/~casiez/1euro/
"""
_x_filtered: Float[Array, "J 3"]
_dx_filtered: Optional[Float[Array, "J 3"]] = None
_last_timestamp: datetime
_min_cutoff: float
_beta: float
_d_cutoff: float
def __init__(
self,
keypoints: Float[Array, "J 3"],
timestamp: datetime,
min_cutoff: float = 1.0,
beta: float = 0.0,
d_cutoff: float = 1.0,
):
"""
Initialize the One Euro Filter.
Args:
keypoints: Initial keypoints positions
timestamp: Initial timestamp
min_cutoff: Minimum cutoff frequency (lower = more smoothing)
beta: Speed coefficient (higher = less lag during fast movements)
d_cutoff: Cutoff frequency for the derivative filter
"""
self._last_timestamp = timestamp
# Filter parameters
self._min_cutoff = min_cutoff
self._beta = beta
self._d_cutoff = d_cutoff
# Filter state
self._x_filtered = keypoints # Position filter state
self._dx_filtered = None # Initially no velocity estimate
@overload
def _smoothing_factor(self, cutoff: float, dt: float) -> float: ...
@overload
def _smoothing_factor(
        self, cutoff: Float[Array, "J 1"], dt: float
    ) -> Float[Array, "J 1"]: ...
@jaxtyped(typechecker=beartype)
def _smoothing_factor(
        self, cutoff: Union[float, Float[Array, "J 1"]], dt: float
    ) -> Union[float, Float[Array, "J 1"]]:
"""Calculate the smoothing factor for the low-pass filter."""
r = 2 * jnp.pi * cutoff * dt
return r / (r + 1)
@jaxtyped(typechecker=beartype)
def _exponential_smoothing(
self,
        a: Union[float, Float[Array, "J 1"]],
x: Float[Array, "J 3"],
x_prev: Float[Array, "J 3"],
) -> Float[Array, "J 3"]:
"""Apply exponential smoothing to the input."""
return a * x + (1 - a) * x_prev
def predict(self, timestamp: datetime) -> TrackingPrediction:
"""
Predict keypoints position at a given timestamp.
Args:
timestamp: Timestamp for prediction
Returns:
TrackingPrediction with velocity and keypoints
"""
dt = (timestamp - self._last_timestamp).total_seconds()
if self._dx_filtered is None:
return TrackingPrediction(
velocity=None,
keypoints=self._x_filtered,
)
else:
predicted_keypoints = self._x_filtered + self._dx_filtered * dt
return TrackingPrediction(
velocity=self._dx_filtered,
keypoints=predicted_keypoints,
)
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
"""
Update the filter with new measurements.
Args:
keypoints: New keypoint measurements
timestamp: Timestamp of the measurements
"""
dt = (timestamp - self._last_timestamp).total_seconds()
if dt <= 0:
raise ValueError(
f"new timestamp is not greater than the last timestamp; expecting: {timestamp} > {self._last_timestamp}"
)
dx = (keypoints - self._x_filtered) / dt
# Determine cutoff frequency based on movement speed
cutoff = self._min_cutoff + self._beta * jnp.linalg.norm(
dx, axis=-1, keepdims=True
)
# Apply low-pass filter to velocity
a_d = self._smoothing_factor(self._d_cutoff, dt)
self._dx_filtered = self._exponential_smoothing(
a_d,
dx,
(
jnp.zeros_like(keypoints)
if self._dx_filtered is None
else self._dx_filtered
),
)
# Apply low-pass filter to position with adaptive cutoff
a_cutoff = self._smoothing_factor(jnp.asarray(cutoff), dt)
self._x_filtered = self._exponential_smoothing(
a_cutoff, keypoints, self._x_filtered
)
# Update timestamp
self._last_timestamp = timestamp
def get(self) -> TrackingPrediction:
"""
Get the current state of the filter.
Returns:
TrackingPrediction with velocity and keypoints
"""
return TrackingPrediction(
velocity=self._dx_filtered,
keypoints=self._x_filtered,
)
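# Usage sketch for OneEuroFilter (hypothetical parameters and data): feed jittery
# measurements of a stationary pose and read back the smoothed estimate. With a low
# min_cutoff the filtered position stays close to the underlying constant pose.
def _example_one_euro() -> None:
    t0 = datetime(2025, 1, 1, 0, 0, 0)
    true_pose = jnp.ones((17, 3))
    flt = OneEuroFilter(true_pose, t0, min_cutoff=1.0, beta=0.0)
    for i in range(1, 5):
        # small deterministic "jitter" around the true pose
        noisy = true_pose + 0.01 * jnp.sin(jnp.full((17, 3), float(i)))
        flt.update(noisy, t0 + timedelta(seconds=0.1 * i))
    smoothed = flt.get()["keypoints"]
    assert bool(jnp.allclose(smoothed, true_pose, atol=0.05))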
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class TrackingState:
"""
immutable state of a tracking
"""
keypoints: Float[Array, "J 3"]
"""
The 3D keypoints of the tracking
    Used for calculating the 3D affinity
"""
last_active_timestamp: datetime
"""
The last active timestamp of the tracking
"""
historical_detections_by_camera: PMap[CameraID, Detection]
"""
Historical detections of the tracking.
Used for 3D re-triangulation
"""
class Tracking:
id: TrackingID
state: TrackingState
velocity_filter: GenericVelocityFilter
def __init__(
self,
id: TrackingID,
state: TrackingState,
velocity_filter: Optional[GenericVelocityFilter] = None,
):
self.id = id
self.state = state
self.velocity_filter = velocity_filter or DummyVelocityFilter(state.keypoints)
def __repr__(self) -> str:
return f"Tracking({self.id}, {self.state.last_active_timestamp})"
@overload
def predict(self, time: float) -> Float[Array, "J 3"]:
"""
predict the keypoints at a given time
Args:
time: the time in seconds to predict the keypoints
Returns:
the predicted keypoints
"""
... # pylint: disable=unnecessary-ellipsis
@overload
def predict(self, time: timedelta) -> Float[Array, "J 3"]:
"""
predict the keypoints at a given time
Args:
time: the time delta to predict the keypoints
"""
... # pylint: disable=unnecessary-ellipsis
@overload
def predict(self, time: datetime) -> Float[Array, "J 3"]:
"""
predict the keypoints at a given time
Args:
time: the timestamp to predict the keypoints
"""
... # pylint: disable=unnecessary-ellipsis
def predict(
self,
time: float | timedelta | datetime,
) -> Float[Array, "J 3"]:
if isinstance(time, timedelta):
timestamp = self.state.last_active_timestamp + time
elif isinstance(time, datetime):
timestamp = time
else:
timestamp = self.state.last_active_timestamp + timedelta(seconds=time)
# pylint: disable-next=unsubscriptable-object
return self.velocity_filter.predict(timestamp)["keypoints"]
def update(self, new_3d_pose: Float[Array, "J 3"], timestamp: datetime) -> None:
"""
update the tracking with a new 3D pose
Note:
equivalent to call `velocity_filter.update(new_3d_pose, timestamp)`
"""
self.velocity_filter.update(new_3d_pose, timestamp)
@property
def velocity(self) -> Float[Array, "J 3"]:
"""
The velocity of the tracking for each keypoint
"""
# pylint: disable-next=unsubscriptable-object
if (vel := self.velocity_filter.get()["velocity"]) is None:
return jnp.zeros_like(self.state.keypoints)
else:
return vel
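# Usage sketch for Tracking.predict (hypothetical values): the three overloads accept
# seconds as a float, a timedelta, or an absolute datetime, and all resolve to the same
# prediction timestamp. `pyrsistent.m()` is used here only to build an empty detection map.
def _example_tracking_predict() -> None:
    from pyrsistent import m
    t0 = datetime(2025, 1, 1, 0, 0, 0)
    keypoints = jnp.zeros((17, 3))
    state = TrackingState(
        keypoints=keypoints,
        last_active_timestamp=t0,
        historical_detections_by_camera=m(),
    )
    trk = Tracking(
        id=0,
        state=state,
        velocity_filter=LastDifferenceVelocityFilter(keypoints, t0),
    )
    by_seconds = trk.predict(1.0)
    by_delta = trk.predict(timedelta(seconds=1))
    by_datetime = trk.predict(t0 + timedelta(seconds=1))
    assert bool(jnp.allclose(by_seconds, by_delta)) and bool(
        jnp.allclose(by_delta, by_datetime)
    )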
@jaxtyped(typechecker=beartype)
@dataclass
class AffinityResult:
"""
Result of affinity computation between trackings and detections.
"""
matrix: Float[Array, "T D"]
trackings: Sequence[Tracking]
detections: Sequence[Detection]
indices_T: Int[Array, "T"] # pylint: disable=invalid-name
indices_D: Int[Array, "D"] # pylint: disable=invalid-name
def tracking_association(
self,
) -> Generator[tuple[float, Tracking, Detection], None, None]:
"""
iterate over the best matching trackings and detections
"""
for t, d in zip(self.indices_T, self.indices_D):
yield (
self.matrix[t, d].item(),
self.trackings[t],
self.detections[d],
)
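# Illustrative sketch (hypothetical helper, not part of the original module): keep only
# the associated (tracking, detection) pairs whose affinity clears a caller-chosen
# threshold. Assumes a larger affinity value means a better match.
def filter_associations(
    result: AffinityResult, threshold: float
) -> list[tuple[Tracking, Detection]]:
    return [
        (tracking, detection)
        for score, tracking, detection in result.tracking_association()
        if score >= threshold
    ]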