single peopele detect and tracking

This commit is contained in:
2025-06-13 16:02:15 +08:00
parent eb9738cb02
commit 492b4fba04
8 changed files with 6734 additions and 1874 deletions

View File

@ -1,14 +1,20 @@
import warnings
import weakref
from collections import deque
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timedelta
from itertools import chain
from typing import (
Any,
Callable,
Generator,
Optional,
Protocol,
Sequence,
TypeAlias,
TypedDict,
TypeVar,
Union,
cast,
overload,
)
@ -18,18 +24,428 @@ from beartype import beartype
from beartype.typing import Mapping, Sequence
from jax import Array
from jaxtyping import Array, Float, Int, jaxtyped
from pyrsistent import PVector
from pyrsistent import PVector, v, PRecord, PMap
from app.camera import Detection
from app.camera import Detection, CameraID
TrackingID: TypeAlias = int
class TrackingPrediction(TypedDict):
velocity: Optional[Float[Array, "J 3"]]
keypoints: Float[Array, "J 3"]
class GenericVelocityFilter(Protocol):
"""
a filter interface for tracking velocity estimation
"""
def predict(self, timestamp: datetime) -> TrackingPrediction:
"""
predict the velocity and the keypoints location
Args:
timestamp: timestamp of the prediction
Returns:
velocity: velocity of the tracking
keypoints: keypoints of the tracking
"""
... # pylint: disable=unnecessary-ellipsis
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
"""
update the filter state with new measurements
Args:
keypoints: new measurements
timestamp: timestamp of the update
"""
... # pylint: disable=unnecessary-ellipsis
def get(self) -> TrackingPrediction:
"""
get the current state of the filter state
Returns:
velocity: velocity of the tracking
keypoints: keypoints of the tracking
"""
... # pylint: disable=unnecessary-ellipsis
class DummyVelocityFilter(GenericVelocityFilter):
"""
a dummy velocity filter that does nothing
"""
_keypoints_shape: tuple[int, ...]
def __init__(self, keypoints: Float[Array, "J 3"]):
self._keypoints_shape = keypoints.shape
def predict(self, timestamp: datetime) -> TrackingPrediction:
return TrackingPrediction(
velocity=None,
keypoints=jnp.zeros(self._keypoints_shape),
)
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: ...
def get(self) -> TrackingPrediction:
return TrackingPrediction(
velocity=None,
keypoints=jnp.zeros(self._keypoints_shape),
)
class LastDifferenceVelocityFilter(GenericVelocityFilter):
"""
a naive velocity filter that uses the last difference of keypoints
"""
_last_timestamp: datetime
_last_keypoints: Float[Array, "J 3"]
_last_velocity: Optional[Float[Array, "J 3"]] = None
def __init__(self, keypoints: Float[Array, "J 3"], timestamp: datetime):
self._last_keypoints = keypoints
self._last_timestamp = timestamp
def predict(self, timestamp: datetime) -> TrackingPrediction:
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
if delta_t_s <= 0:
warnings.warn(
"delta_t={}; last={}; current={}".format(
delta_t_s, self._last_timestamp, timestamp
)
)
if self._last_velocity is None:
return TrackingPrediction(
velocity=None,
keypoints=self._last_keypoints,
)
else:
if delta_t_s <= 0:
return TrackingPrediction(
velocity=self._last_velocity,
keypoints=self._last_keypoints,
)
return TrackingPrediction(
velocity=self._last_velocity,
keypoints=self._last_keypoints + self._last_velocity * delta_t_s,
)
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
if delta_t_s <= 0:
pass
else:
self._last_timestamp = timestamp
self._last_velocity = (keypoints - self._last_keypoints) / delta_t_s
self._last_keypoints = keypoints
def get(self) -> TrackingPrediction:
if self._last_velocity is None:
return TrackingPrediction(
velocity=None,
keypoints=self._last_keypoints,
)
else:
return TrackingPrediction(
velocity=self._last_velocity,
keypoints=self._last_keypoints,
)
class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
"""
a velocity filter that uses the least mean square method to estimate the velocity
"""
_historical_3d_poses: deque[Float[Array, "J 3"]]
_historical_timestamps: deque[datetime]
_velocity: Optional[Float[Array, "J 3"]] = None
_max_samples: int
def __init__(
self,
historical_3d_poses: Sequence[Float[Array, "J 3"]],
historical_timestamps: Sequence[datetime],
max_samples: int = 10,
):
assert len(historical_3d_poses) == len(historical_timestamps)
temp = zip(historical_3d_poses, historical_timestamps)
temp_sorted = sorted(temp, key=lambda x: x[1])
self._historical_3d_poses = deque(
map(lambda x: x[0], temp_sorted), maxlen=max_samples
)
self._historical_timestamps = deque(
map(lambda x: x[1], temp_sorted), maxlen=max_samples
)
self._max_samples = max_samples
if len(self._historical_3d_poses) < 2:
self._velocity = None
else:
self._update(
jnp.array(self._historical_3d_poses),
jnp.array(self._historical_timestamps),
)
def predict(self, timestamp: datetime) -> TrackingPrediction:
if not self._historical_3d_poses:
raise ValueError("No historical 3D poses available for prediction")
# use the latest historical detection
latest_3d_pose = self._historical_3d_poses[-1]
latest_timestamp = self._historical_timestamps[-1]
delta_t_s = (timestamp - latest_timestamp).total_seconds()
if self._velocity is None:
return TrackingPrediction(
velocity=None,
keypoints=latest_3d_pose,
)
else:
# Linear motion model: ẋt = xt' + Vt' · (t - t')
predicted_3d_pose = latest_3d_pose + self._velocity * delta_t_s
return TrackingPrediction(
velocity=self._velocity, keypoints=predicted_3d_pose
)
@jaxtyped(typechecker=beartype)
def _update(
self,
keypoints: Float[Array, "N J 3"],
timestamps: Float[Array, "N"],
) -> None:
"""
update measurements with least mean square method
"""
if keypoints.shape[0] < 2:
raise ValueError("Not enough measurements to estimate velocity")
# Using least squares to fit a linear model for each joint and dimension
# X = timestamps, y = keypoints
# For each joint and each dimension, we solve for velocity
n_samples = timestamps.shape[0]
n_joints = keypoints.shape[1]
# Create design matrix for linear regression
# [t, 1] for each timestamp
X = jnp.column_stack([timestamps, jnp.ones(n_samples)])
# Reshape keypoints to solve for all joints and dimensions at once
# From [N, J, 3] to [N, J*3]
keypoints_reshaped = keypoints.reshape(n_samples, -1)
# Use JAX's lstsq to solve the least squares problem
# This is more numerically stable than manually computing pseudoinverse
coefficients, _, _, _ = jnp.linalg.lstsq(X, keypoints_reshaped, rcond=None)
# Coefficients shape is [2, J*3]
# First row: velocities, Second row: intercepts
velocities = coefficients[0].reshape(n_joints, 3)
# Update velocity
self._velocity = velocities
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
last_timestamp = self._historical_timestamps[-1]
assert last_timestamp <= timestamp
# deque would manage the maxlen automatically
self._historical_3d_poses.append(keypoints)
self._historical_timestamps.append(timestamp)
t_0 = self._historical_timestamps[0]
all_keypoints = jnp.array(self._historical_3d_poses)
def timestamp_to_seconds(timestamp: datetime) -> float:
assert t_0 <= timestamp
return (timestamp - t_0).total_seconds()
# timestamps relative to t_0 (the oldest detection timestamp)
all_timestamps = jnp.array(
map(timestamp_to_seconds, self._historical_timestamps)
)
self._update(all_keypoints, all_timestamps)
def get(self) -> TrackingPrediction:
if not self._historical_3d_poses:
raise ValueError("No historical 3D poses available")
latest_3d_pose = self._historical_3d_poses[-1]
if self._velocity is None:
return TrackingPrediction(velocity=None, keypoints=latest_3d_pose)
else:
return TrackingPrediction(velocity=self._velocity, keypoints=latest_3d_pose)
class OneEuroFilter(GenericVelocityFilter):
"""
Implementation of the 1€ filter (One Euro Filter) for smoothing keypoint data.
The 1€ filter is an adaptive low-pass filter that adjusts its cutoff frequency
based on movement speed to reduce jitter during slow movements while maintaining
responsiveness during fast movements.
Reference: https://cristal.univ-lille.fr/~casiez/1euro/
"""
_x_filtered: Float[Array, "J 3"]
_dx_filtered: Optional[Float[Array, "J 3"]] = None
_last_timestamp: datetime
_min_cutoff: float
_beta: float
_d_cutoff: float
def __init__(
self,
keypoints: Float[Array, "J 3"],
timestamp: datetime,
min_cutoff: float = 1.0,
beta: float = 0.0,
d_cutoff: float = 1.0,
):
"""
Initialize the One Euro Filter.
Args:
keypoints: Initial keypoints positions
timestamp: Initial timestamp
min_cutoff: Minimum cutoff frequency (lower = more smoothing)
beta: Speed coefficient (higher = less lag during fast movements)
d_cutoff: Cutoff frequency for the derivative filter
"""
self._last_timestamp = timestamp
# Filter parameters
self._min_cutoff = min_cutoff
self._beta = beta
self._d_cutoff = d_cutoff
# Filter state
self._x_filtered = keypoints # Position filter state
self._dx_filtered = None # Initially no velocity estimate
@overload
def _smoothing_factor(self, cutoff: float, dt: float) -> float: ...
@overload
def _smoothing_factor(
self, cutoff: Float[Array, "J"], dt: float
) -> Float[Array, "J"]: ...
@jaxtyped(typechecker=beartype)
def _smoothing_factor(
self, cutoff: Union[float, Float[Array, "J"]], dt: float
) -> Union[float, Float[Array, "J"]]:
"""Calculate the smoothing factor for the low-pass filter."""
r = 2 * jnp.pi * cutoff * dt
return r / (r + 1)
@jaxtyped(typechecker=beartype)
def _exponential_smoothing(
self,
a: Union[float, Float[Array, "J"]],
x: Float[Array, "J 3"],
x_prev: Float[Array, "J 3"],
) -> Float[Array, "J 3"]:
"""Apply exponential smoothing to the input."""
return a * x + (1 - a) * x_prev
def predict(self, timestamp: datetime) -> TrackingPrediction:
"""
Predict keypoints position at a given timestamp.
Args:
timestamp: Timestamp for prediction
Returns:
TrackingPrediction with velocity and keypoints
"""
dt = (timestamp - self._last_timestamp).total_seconds()
if self._dx_filtered is None:
return TrackingPrediction(
velocity=None,
keypoints=self._x_filtered,
)
else:
predicted_keypoints = self._x_filtered + self._dx_filtered * dt
return TrackingPrediction(
velocity=self._dx_filtered,
keypoints=predicted_keypoints,
)
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
"""
Update the filter with new measurements.
Args:
keypoints: New keypoint measurements
timestamp: Timestamp of the measurements
"""
dt = (timestamp - self._last_timestamp).total_seconds()
if dt <= 0:
raise ValueError(
f"new timestamp is not greater than the last timestamp; expecting: {timestamp} > {self._last_timestamp}"
)
dx = (keypoints - self._x_filtered) / dt
# Determine cutoff frequency based on movement speed
cutoff = self._min_cutoff + self._beta * jnp.linalg.norm(
dx, axis=-1, keepdims=True
)
# Apply low-pass filter to velocity
a_d = self._smoothing_factor(self._d_cutoff, dt)
self._dx_filtered = self._exponential_smoothing(
a_d,
dx,
(
jnp.zeros_like(keypoints)
if self._dx_filtered is None
else self._dx_filtered
),
)
# Apply low-pass filter to position with adaptive cutoff
a_cutoff = self._smoothing_factor(jnp.asarray(cutoff), dt)
self._x_filtered = self._exponential_smoothing(
a_cutoff, keypoints, self._x_filtered
)
# Update timestamp
self._last_timestamp = timestamp
def get(self) -> TrackingPrediction:
"""
Get the current state of the filter.
Returns:
TrackingPrediction with velocity and keypoints
"""
return TrackingPrediction(
velocity=self._dx_filtered,
keypoints=self._x_filtered,
)
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
id: int
class TrackingState:
"""
The tracking id
immutable state of a tracking
"""
keypoints: Float[Array, "J 3"]
"""
The 3D keypoints of the tracking
@ -41,50 +457,97 @@ class Tracking:
The last active timestamp of the tracking
"""
historical_detections: PVector[Detection]
historical_detections_by_camera: PMap[CameraID, Detection]
"""
Historical detections of the tracking.
Used for 3D re-triangulation
"""
velocity: Optional[Float[Array, "3"]] = None
"""
Could be `None`. Like when the 3D pose is initialized.
`velocity` should be updated when target association yields a new
3D pose.
"""
class Tracking:
id: TrackingID
state: TrackingState
velocity_filter: GenericVelocityFilter
def __init__(
self,
id: TrackingID,
state: TrackingState,
velocity_filter: Optional[GenericVelocityFilter] = None,
):
self.id = id
self.state = state
self.velocity_filter = velocity_filter or DummyVelocityFilter(state.keypoints)
def __repr__(self) -> str:
return f"Tracking({self.id}, {self.last_active_timestamp})"
return f"Tracking({self.id}, {self.state.last_active_timestamp})"
@overload
def predict(self, time: float) -> Float[Array, "J 3"]:
"""
predict the keypoints at a given time
Args:
time: the time in seconds to predict the keypoints
Returns:
the predicted keypoints
"""
... # pylint: disable=unnecessary-ellipsis
@overload
def predict(self, time: timedelta) -> Float[Array, "J 3"]:
"""
predict the keypoints at a given time
Args:
time: the time delta to predict the keypoints
"""
... # pylint: disable=unnecessary-ellipsis
@overload
def predict(self, time: datetime) -> Float[Array, "J 3"]:
"""
predict the keypoints at a given time
Args:
time: the timestamp to predict the keypoints
"""
... # pylint: disable=unnecessary-ellipsis
def predict(
self,
delta_t_s: float,
time: float | timedelta | datetime,
) -> Float[Array, "J 3"]:
"""
Predict the 3D pose of a tracking based on its velocity.
JAX-friendly implementation that avoids Python control flow.
Args:
delta_t_s: Time delta in seconds
Returns:
Predicted 3D pose keypoints
"""
# ------------------------------------------------------------------
# Step 1 decide velocity on the Python side
# ------------------------------------------------------------------
if self.velocity is None:
velocity = jnp.zeros_like(self.keypoints) # (J, 3)
if isinstance(time, timedelta):
timestamp = self.state.last_active_timestamp + time
elif isinstance(time, datetime):
timestamp = time
else:
velocity = self.velocity # (J, 3)
timestamp = self.state.last_active_timestamp + timedelta(seconds=time)
# pylint: disable-next=unsubscriptable-object
return self.velocity_filter.predict(timestamp)["keypoints"]
# ------------------------------------------------------------------
# Step 2 pure JAX math
# ------------------------------------------------------------------
return self.keypoints + velocity * delta_t_s
def update(self, new_3d_pose: Float[Array, "J 3"], timestamp: datetime) -> None:
"""
update the tracking with a new 3D pose
Note:
equivalent to call `velocity_filter.update(new_3d_pose, timestamp)`
"""
self.velocity_filter.update(new_3d_pose, timestamp)
@property
def velocity(self) -> Float[Array, "J 3"]:
"""
The velocity of the tracking for each keypoint
"""
# pylint: disable-next=unsubscriptable-object
if (vel := self.velocity_filter.get()["velocity"]) is None:
return jnp.zeros_like(self.state.keypoints)
else:
return vel
@jaxtyped(typechecker=beartype)
@ -98,9 +561,9 @@ class AffinityResult:
trackings: Sequence[Tracking]
detections: Sequence[Detection]
indices_T: Int[Array, "T"] # pylint: disable=invalid-name
indices_D: Int[Array, "D"] # pylint: disable=invalid-name
indices_D: Int[Array, "T"] # pylint: disable=invalid-name
def tracking_detections(
def tracking_association(
self,
) -> Generator[tuple[float, Tracking, Detection], None, None]:
"""