forked from HQU-gxy/CVTH3PE
single-person detection and tracking
This commit is contained in:
@ -1,14 +1,20 @@
|
|||||||
|
import warnings
|
||||||
|
import weakref
|
||||||
|
from collections import deque
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime, timedelta
|
||||||
|
from itertools import chain
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Generator,
|
Generator,
|
||||||
Optional,
|
Optional,
|
||||||
|
Protocol,
|
||||||
Sequence,
|
Sequence,
|
||||||
TypeAlias,
|
TypeAlias,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
|
Union,
|
||||||
cast,
|
cast,
|
||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
@ -18,18 +24,428 @@ from beartype import beartype
|
|||||||
from beartype.typing import Mapping, Sequence
|
from beartype.typing import Mapping, Sequence
|
||||||
from jax import Array
|
from jax import Array
|
||||||
from jaxtyping import Array, Float, Int, jaxtyped
|
from jaxtyping import Array, Float, Int, jaxtyped
|
||||||
from pyrsistent import PVector
|
from pyrsistent import PVector, v, PRecord, PMap
|
||||||
|
|
||||||
from app.camera import Detection
|
from app.camera import Detection, CameraID
|
||||||
|
|
||||||
|
TrackingID: TypeAlias = int
|
||||||
|
|
||||||
|
|
||||||
|
class TrackingPrediction(TypedDict):
    """
    result returned by a velocity filter: an optional per-joint velocity
    together with the per-joint 3D keypoints
    """

    # per-joint velocity; `None` while the filter has no velocity estimate yet
    velocity: Optional[Float[Array, "J 3"]]
    # per-joint 3D keypoint positions
    keypoints: Float[Array, "J 3"]
|
||||||
|
|
||||||
|
|
||||||
|
class GenericVelocityFilter(Protocol):
    """
    structural interface shared by the velocity estimators of a tracking

    An implementation is fed timestamped 3D keypoints through `update` and
    reports a `TrackingPrediction` through `predict` and `get`.
    """

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        estimate the velocity and the keypoint locations at `timestamp`

        Args:
            timestamp: the moment the prediction is requested for

        Returns:
            velocity: velocity of the tracking
            keypoints: keypoints of the tracking
        """

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        feed a new measurement into the filter state

        Args:
            keypoints: the newly measured keypoints
            timestamp: the moment the measurement was taken
        """

    def get(self) -> TrackingPrediction:
        """
        report the current filter state without extrapolating in time

        Returns:
            velocity: velocity of the tracking
            keypoints: keypoints of the tracking
        """
|
||||||
|
|
||||||
|
|
||||||
|
class DummyVelocityFilter(GenericVelocityFilter):
    """
    a placeholder velocity filter: it ignores every measurement and always
    reports no velocity and all-zero keypoints
    """

    _keypoints_shape: tuple[int, ...]

    def __init__(self, keypoints: Float[Array, "J 3"]):
        # only the shape is remembered; the keypoint values are never used
        self._keypoints_shape = keypoints.shape

    def get(self) -> TrackingPrediction:
        """return a prediction with `velocity=None` and zero keypoints"""
        zero_keypoints = jnp.zeros(self._keypoints_shape)
        return TrackingPrediction(velocity=None, keypoints=zero_keypoints)

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """return the same constant prediction as `get`; `timestamp` is ignored"""
        return self.get()

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """discard the measurement; this filter keeps no state"""
|
||||||
|
|
||||||
|
|
||||||
|
class LastDifferenceVelocityFilter(GenericVelocityFilter):
    """
    a naive velocity filter: the velocity is the finite difference between the
    two most recent accepted measurements
    """

    _last_timestamp: datetime
    _last_keypoints: Float[Array, "J 3"]
    _last_velocity: Optional[Float[Array, "J 3"]] = None

    def __init__(self, keypoints: Float[Array, "J 3"], timestamp: datetime):
        self._last_timestamp = timestamp
        self._last_keypoints = keypoints

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        extrapolate the last keypoints linearly to `timestamp`

        A warning is issued when `timestamp` is not later than the last
        accepted measurement; in that case, or while no velocity is known,
        the last keypoints are returned unchanged.
        """
        elapsed_s = (timestamp - self._last_timestamp).total_seconds()
        not_in_future = elapsed_s <= 0
        if not_in_future:
            warnings.warn(
                "delta_t={}; last={}; current={}".format(
                    elapsed_s, self._last_timestamp, timestamp
                )
            )

        velocity = self._last_velocity
        if velocity is None or not_in_future:
            # nothing to extrapolate with (or nowhere to extrapolate to)
            return TrackingPrediction(
                velocity=velocity,
                keypoints=self._last_keypoints,
            )
        return TrackingPrediction(
            velocity=velocity,
            keypoints=self._last_keypoints + velocity * elapsed_s,
        )

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        accept a new measurement and recompute the velocity

        A measurement whose timestamp is not strictly later than the last
        accepted one is ignored and leaves the state untouched.
        """
        elapsed_s = (timestamp - self._last_timestamp).total_seconds()
        if elapsed_s <= 0:
            return
        self._last_velocity = (keypoints - self._last_keypoints) / elapsed_s
        self._last_keypoints = keypoints
        self._last_timestamp = timestamp

    def get(self) -> TrackingPrediction:
        """return the last accepted keypoints and the current velocity (may be `None`)"""
        return TrackingPrediction(
            velocity=self._last_velocity,
            keypoints=self._last_keypoints,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
    """
    a velocity filter that uses the least mean square method to estimate the velocity

    The filter keeps at most `max_samples` of the most recent timestamped 3D
    poses and fits, for every joint coordinate, a straight line over time;
    the slope of that line is reported as the velocity.
    """

    _historical_3d_poses: deque[Float[Array, "J 3"]]
    _historical_timestamps: deque[datetime]
    _velocity: Optional[Float[Array, "J 3"]] = None
    _max_samples: int

    def __init__(
        self,
        historical_3d_poses: Sequence[Float[Array, "J 3"]],
        historical_timestamps: Sequence[datetime],
        max_samples: int = 10,
    ):
        """
        Args:
            historical_3d_poses: known 3D poses, in any order
            historical_timestamps: timestamp of each pose, same length and order
            max_samples: maximum number of samples kept for the fit

        Raises:
            ValueError: if the two sequences differ in length
        """
        if len(historical_3d_poses) != len(historical_timestamps):
            raise ValueError(
                "historical_3d_poses and historical_timestamps must have the same length"
            )
        # sort by timestamp only, so poses are never compared with each other
        ordered = sorted(
            zip(historical_3d_poses, historical_timestamps), key=lambda pair: pair[1]
        )
        # with `maxlen`, the deque keeps the most recent samples
        self._historical_3d_poses = deque(
            (pose for pose, _ in ordered), maxlen=max_samples
        )
        self._historical_timestamps = deque(
            (stamp for _, stamp in ordered), maxlen=max_samples
        )
        self._max_samples = max_samples
        self._refit()

    def _refit(self) -> None:
        """re-estimate the velocity from the current history"""
        if len(self._historical_3d_poses) < 2:
            # a single sample (or none) cannot define a slope
            self._velocity = None
            return

        # timestamps relative to t_0 (the oldest retained sample), in seconds;
        # datetime objects cannot be turned into a float array directly
        t_0 = self._historical_timestamps[0]
        seconds = [
            (stamp - t_0).total_seconds() for stamp in self._historical_timestamps
        ]
        # materialise the deque as a list before handing it to JAX
        all_keypoints = jnp.stack(list(self._historical_3d_poses))
        # float dtype taken from the poses so the array matches `Float[Array, "N"]`
        all_timestamps = jnp.asarray(
            seconds, dtype=jnp.result_type(all_keypoints, jnp.float32)
        )
        self._update(all_keypoints, all_timestamps)

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        extrapolate the latest pose linearly to `timestamp`

        Raises:
            ValueError: if the filter holds no pose at all
        """
        if not self._historical_3d_poses:
            raise ValueError("No historical 3D poses available for prediction")

        # use the latest historical detection
        latest_3d_pose = self._historical_3d_poses[-1]
        latest_timestamp = self._historical_timestamps[-1]

        if self._velocity is None:
            return TrackingPrediction(velocity=None, keypoints=latest_3d_pose)

        delta_t_s = (timestamp - latest_timestamp).total_seconds()
        # Linear motion model: x(t) = x(t') + v(t') * (t - t')
        predicted_3d_pose = latest_3d_pose + self._velocity * delta_t_s
        return TrackingPrediction(velocity=self._velocity, keypoints=predicted_3d_pose)

    @jaxtyped(typechecker=beartype)
    def _update(
        self,
        keypoints: Float[Array, "N J 3"],
        timestamps: Float[Array, "N"],
    ) -> None:
        """
        update measurements with least mean square method

        Args:
            keypoints: the stacked poses, one per sample
            timestamps: the sample times in seconds

        Raises:
            ValueError: if fewer than two samples are given
        """
        if keypoints.shape[0] < 2:
            raise ValueError("Not enough measurements to estimate velocity")

        n_samples = timestamps.shape[0]
        n_joints = keypoints.shape[1]

        # Design matrix for linear regression: one row [t, 1] per timestamp
        design = jnp.column_stack([timestamps, jnp.ones(n_samples)])

        # Solve for all joints and dimensions at once: [N, J, 3] -> [N, J*3]
        flattened = keypoints.reshape(n_samples, -1)

        # lstsq is more numerically stable than a hand-written pseudoinverse
        coefficients, _, _, _ = jnp.linalg.lstsq(design, flattened, rcond=None)

        # coefficients is [2, J*3]; first row: slopes (velocities), second row: intercepts
        self._velocity = coefficients[0].reshape(n_joints, 3)

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        append a measurement and re-estimate the velocity

        Args:
            keypoints: the newly measured pose
            timestamp: the moment the measurement was taken

        Raises:
            ValueError: if `timestamp` is earlier than the latest stored one
        """
        if self._historical_timestamps and timestamp < self._historical_timestamps[-1]:
            raise ValueError(
                f"new timestamp is earlier than the last timestamp; expecting: {timestamp} >= {self._historical_timestamps[-1]}"
            )

        # deque would manage the maxlen automatically
        self._historical_3d_poses.append(keypoints)
        self._historical_timestamps.append(timestamp)
        self._refit()

    def get(self) -> TrackingPrediction:
        """
        return the latest stored pose and the current velocity (may be `None`)

        Raises:
            ValueError: if the filter holds no pose at all
        """
        if not self._historical_3d_poses:
            raise ValueError("No historical 3D poses available")

        return TrackingPrediction(
            velocity=self._velocity, keypoints=self._historical_3d_poses[-1]
        )
|
||||||
|
|
||||||
|
|
||||||
|
class OneEuroFilter(GenericVelocityFilter):
    """
    Implementation of the 1€ filter (One Euro Filter) for smoothing keypoint data.

    The 1€ filter is an adaptive low-pass filter that adjusts its cutoff frequency
    based on movement speed to reduce jitter during slow movements while maintaining
    responsiveness during fast movements.

    Reference: https://cristal.univ-lille.fr/~casiez/1euro/
    """

    _x_filtered: Float[Array, "J 3"]
    _dx_filtered: Optional[Float[Array, "J 3"]] = None
    _last_timestamp: datetime
    _min_cutoff: float
    _beta: float
    _d_cutoff: float

    def __init__(
        self,
        keypoints: Float[Array, "J 3"],
        timestamp: datetime,
        min_cutoff: float = 1.0,
        beta: float = 0.0,
        d_cutoff: float = 1.0,
    ):
        """
        Initialize the One Euro Filter.

        Args:
            keypoints: Initial keypoints positions
            timestamp: Initial timestamp
            min_cutoff: Minimum cutoff frequency (lower = more smoothing)
            beta: Speed coefficient (higher = less lag during fast movements)
            d_cutoff: Cutoff frequency for the derivative filter
        """
        self._last_timestamp = timestamp

        # Filter parameters
        self._min_cutoff = min_cutoff
        self._beta = beta
        self._d_cutoff = d_cutoff

        # Filter state
        self._x_filtered = keypoints  # Position filter state
        self._dx_filtered = None  # Initially no velocity estimate

    @overload
    def _smoothing_factor(self, cutoff: float, dt: float) -> float: ...

    @overload
    def _smoothing_factor(
        self, cutoff: Float[Array, "J 1"], dt: float
    ) -> Float[Array, "J 1"]: ...

    @jaxtyped(typechecker=beartype)
    def _smoothing_factor(
        self, cutoff: Union[float, Float[Array, "J 1"]], dt: float
    ) -> Union[float, Float[Array, "J 1"]]:
        """
        Calculate the smoothing factor for the low-pass filter.

        Args:
            cutoff: cutoff frequency, either one scalar or one value per joint
                kept as a column so it broadcasts over the 3 coordinates
            dt: time elapsed since the previous sample, in seconds
        """
        r = 2 * jnp.pi * cutoff * dt
        return r / (r + 1)

    @jaxtyped(typechecker=beartype)
    def _exponential_smoothing(
        self,
        a: Union[float, Float[Array, "J 1"]],
        x: Float[Array, "J 3"],
        x_prev: Float[Array, "J 3"],
    ) -> Float[Array, "J 3"]:
        """Apply exponential smoothing to the input."""
        return a * x + (1 - a) * x_prev

    def predict(self, timestamp: datetime) -> TrackingPrediction:
        """
        Predict keypoints position at a given timestamp.

        Args:
            timestamp: Timestamp for prediction

        Returns:
            TrackingPrediction with velocity and keypoints
        """
        if self._dx_filtered is None:
            # no velocity estimate yet: report the filtered position as-is
            return TrackingPrediction(
                velocity=None,
                keypoints=self._x_filtered,
            )

        dt = (timestamp - self._last_timestamp).total_seconds()
        return TrackingPrediction(
            velocity=self._dx_filtered,
            keypoints=self._x_filtered + self._dx_filtered * dt,
        )

    def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
        """
        Update the filter with new measurements.

        Args:
            keypoints: New keypoint measurements
            timestamp: Timestamp of the measurements

        Raises:
            ValueError: if `timestamp` is not strictly later than the last one
        """
        dt = (timestamp - self._last_timestamp).total_seconds()
        if dt <= 0:
            raise ValueError(
                f"new timestamp is not greater than the last timestamp; expecting: {timestamp} > {self._last_timestamp}"
            )

        dx = (keypoints - self._x_filtered) / dt

        # Apply low-pass filter to velocity (previous estimate is zero on the
        # first update)
        a_d = self._smoothing_factor(self._d_cutoff, dt)
        previous_dx = (
            jnp.zeros_like(keypoints)
            if self._dx_filtered is None
            else self._dx_filtered
        )
        dx_filtered = self._exponential_smoothing(a_d, dx, previous_dx)

        # Determine cutoff frequency based on movement speed. The reference
        # algorithm uses the *filtered* derivative here, so that measurement
        # noise in the raw derivative does not open up the position filter.
        # keepdims=True keeps one column per joint, shape (J, 1), which
        # broadcasts over the 3 coordinates.
        cutoff = self._min_cutoff + self._beta * jnp.linalg.norm(
            dx_filtered, axis=-1, keepdims=True
        )

        # Apply low-pass filter to position with adaptive cutoff
        a_cutoff = self._smoothing_factor(jnp.asarray(cutoff), dt)
        self._x_filtered = self._exponential_smoothing(
            a_cutoff, keypoints, self._x_filtered
        )
        self._dx_filtered = dx_filtered

        # Update timestamp
        self._last_timestamp = timestamp

    def get(self) -> TrackingPrediction:
        """
        Get the current state of the filter.

        Returns:
            TrackingPrediction with velocity and keypoints
        """
        return TrackingPrediction(
            velocity=self._dx_filtered,
            keypoints=self._x_filtered,
        )
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Tracking:
|
class TrackingState:
|
||||||
id: int
|
|
||||||
"""
|
"""
|
||||||
The tracking id
|
immutable state of a tracking
|
||||||
"""
|
"""
|
||||||
|
|
||||||
keypoints: Float[Array, "J 3"]
|
keypoints: Float[Array, "J 3"]
|
||||||
"""
|
"""
|
||||||
The 3D keypoints of the tracking
|
The 3D keypoints of the tracking
|
||||||
@ -41,50 +457,97 @@ class Tracking:
|
|||||||
The last active timestamp of the tracking
|
The last active timestamp of the tracking
|
||||||
"""
|
"""
|
||||||
|
|
||||||
historical_detections: PVector[Detection]
|
historical_detections_by_camera: PMap[CameraID, Detection]
|
||||||
"""
|
"""
|
||||||
Historical detections of the tracking.
|
Historical detections of the tracking.
|
||||||
|
|
||||||
Used for 3D re-triangulation
|
Used for 3D re-triangulation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
velocity: Optional[Float[Array, "3"]] = None
|
|
||||||
"""
|
|
||||||
Could be `None`. Like when the 3D pose is initialized.
|
|
||||||
|
|
||||||
`velocity` should be updated when target association yields a new
|
class Tracking:
|
||||||
3D pose.
|
id: TrackingID
|
||||||
"""
|
state: TrackingState
|
||||||
|
velocity_filter: GenericVelocityFilter
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
id: TrackingID,
|
||||||
|
state: TrackingState,
|
||||||
|
velocity_filter: Optional[GenericVelocityFilter] = None,
|
||||||
|
):
|
||||||
|
self.id = id
|
||||||
|
self.state = state
|
||||||
|
self.velocity_filter = velocity_filter or DummyVelocityFilter(state.keypoints)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"Tracking({self.id}, {self.last_active_timestamp})"
|
return f"Tracking({self.id}, {self.state.last_active_timestamp})"
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def predict(self, time: float) -> Float[Array, "J 3"]:
|
||||||
|
"""
|
||||||
|
predict the keypoints at a given time
|
||||||
|
|
||||||
|
Args:
|
||||||
|
time: the time in seconds to predict the keypoints
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the predicted keypoints
|
||||||
|
"""
|
||||||
|
... # pylint: disable=unnecessary-ellipsis
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def predict(self, time: timedelta) -> Float[Array, "J 3"]:
|
||||||
|
"""
|
||||||
|
predict the keypoints at a given time
|
||||||
|
|
||||||
|
Args:
|
||||||
|
time: the time delta to predict the keypoints
|
||||||
|
"""
|
||||||
|
... # pylint: disable=unnecessary-ellipsis
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def predict(self, time: datetime) -> Float[Array, "J 3"]:
|
||||||
|
"""
|
||||||
|
predict the keypoints at a given time
|
||||||
|
|
||||||
|
Args:
|
||||||
|
time: the timestamp to predict the keypoints
|
||||||
|
"""
|
||||||
|
... # pylint: disable=unnecessary-ellipsis
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
delta_t_s: float,
|
time: float | timedelta | datetime,
|
||||||
) -> Float[Array, "J 3"]:
|
) -> Float[Array, "J 3"]:
|
||||||
"""
|
if isinstance(time, timedelta):
|
||||||
Predict the 3D pose of a tracking based on its velocity.
|
timestamp = self.state.last_active_timestamp + time
|
||||||
JAX-friendly implementation that avoids Python control flow.
|
elif isinstance(time, datetime):
|
||||||
|
timestamp = time
|
||||||
Args:
|
|
||||||
delta_t_s: Time delta in seconds
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Predicted 3D pose keypoints
|
|
||||||
"""
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Step 1 – decide velocity on the Python side
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
if self.velocity is None:
|
|
||||||
velocity = jnp.zeros_like(self.keypoints) # (J, 3)
|
|
||||||
else:
|
else:
|
||||||
velocity = self.velocity # (J, 3)
|
timestamp = self.state.last_active_timestamp + timedelta(seconds=time)
|
||||||
|
# pylint: disable-next=unsubscriptable-object
|
||||||
|
return self.velocity_filter.predict(timestamp)["keypoints"]
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
def update(self, new_3d_pose: Float[Array, "J 3"], timestamp: datetime) -> None:
|
||||||
# Step 2 – pure JAX math
|
"""
|
||||||
# ------------------------------------------------------------------
|
update the tracking with a new 3D pose
|
||||||
return self.keypoints + velocity * delta_t_s
|
|
||||||
|
Note:
|
||||||
|
equivalent to call `velocity_filter.update(new_3d_pose, timestamp)`
|
||||||
|
"""
|
||||||
|
self.velocity_filter.update(new_3d_pose, timestamp)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def velocity(self) -> Float[Array, "J 3"]:
|
||||||
|
"""
|
||||||
|
The velocity of the tracking for each keypoint
|
||||||
|
"""
|
||||||
|
# pylint: disable-next=unsubscriptable-object
|
||||||
|
if (vel := self.velocity_filter.get()["velocity"]) is None:
|
||||||
|
return jnp.zeros_like(self.state.keypoints)
|
||||||
|
else:
|
||||||
|
return vel
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
@ -98,9 +561,9 @@ class AffinityResult:
|
|||||||
trackings: Sequence[Tracking]
|
trackings: Sequence[Tracking]
|
||||||
detections: Sequence[Detection]
|
detections: Sequence[Detection]
|
||||||
indices_T: Int[Array, "T"] # pylint: disable=invalid-name
|
indices_T: Int[Array, "T"] # pylint: disable=invalid-name
|
||||||
indices_D: Int[Array, "D"] # pylint: disable=invalid-name
|
indices_D: Int[Array, "T"] # pylint: disable=invalid-name
|
||||||
|
|
||||||
def tracking_detections(
|
def tracking_association(
|
||||||
self,
|
self,
|
||||||
) -> Generator[tuple[float, Tracking, Detection], None, None]:
|
) -> Generator[tuple[float, Tracking, Detection], None, None]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,23 +1,11 @@
|
|||||||
import awkward as ak
|
from narwhals import Boolean
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
|
||||||
from matplotlib import pyplot as plt
|
|
||||||
import cv2
|
import cv2
|
||||||
from typing import Optional, cast, Final, TypedDict
|
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
|
||||||
Generator,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
TypeAlias,
|
TypeAlias,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
cast,
|
|
||||||
overload,
|
|
||||||
)
|
)
|
||||||
from jaxtyping import Array, Float, Num, jaxtyped
|
from jaxtyping import Array, Num
|
||||||
from shapely import box
|
|
||||||
from app.visualize.whole_body import visualize_whole_body
|
|
||||||
import pyproj
|
|
||||||
from shapely.geometry import Polygon
|
from shapely.geometry import Polygon
|
||||||
from sympy import false, true
|
from sympy import false, true
|
||||||
|
|
||||||
@ -94,9 +82,9 @@ def calculaterCubeVersices(position, dimensions):
|
|||||||
# 获得盒子三维坐标系
|
# 获得盒子三维坐标系
|
||||||
def calculater_box_3d_points():
|
def calculater_box_3d_points():
|
||||||
# 盒子原点位置,相对于六面体中心偏移
|
# 盒子原点位置,相对于六面体中心偏移
|
||||||
box_ori_potision = [0.205 + 0.2, 0.205 + 0.50, -0.205 - 0.2]
|
box_ori_potision = [0.205 + 0.2, 0.205 + 0.50, -0.205 - 0.45]
|
||||||
# 盒子边长,宽:1.5米,高:1.5米,深度:1.8米
|
# 盒子边长,宽:1.5米,高:1.5米,深度:1.8米
|
||||||
box_geometry = [0.65, 1.8, 1.5]
|
box_geometry = [0.65, 1.8, 1]
|
||||||
filter_box_points_3d = calculaterCubeVersices(box_ori_potision, box_geometry)
|
filter_box_points_3d = calculaterCubeVersices(box_ori_potision, box_geometry)
|
||||||
filter_box_points_3d = {
|
filter_box_points_3d = {
|
||||||
str(index): element for index, element in enumerate(filter_box_points_3d)
|
str(index): element for index, element in enumerate(filter_box_points_3d)
|
||||||
@ -202,22 +190,19 @@ def get_contours(union_polygon):
|
|||||||
|
|
||||||
|
|
||||||
# 筛选落在盒子二维重投影区域内的关键点信息
|
# 筛选落在盒子二维重投影区域内的关键点信息
|
||||||
def filter_kps_box(kps, contours):
|
def filter_kps_in_contours(kps, contours) -> Boolean:
|
||||||
# 存放筛选后的目标框
|
|
||||||
# new_boxes_data = []
|
|
||||||
# 存放筛选后的2d姿态点数据
|
|
||||||
# new_kps_data = []
|
|
||||||
# 存放筛选后的2d姿态置信度
|
|
||||||
# 遍历未筛选的目标框
|
|
||||||
|
|
||||||
x1, y1 = kps[0]
|
# 4 5 16 17
|
||||||
x2, y2 = kps[16]
|
keypoint_index: list[list[int]] = [[4, 5], [16, 17]]
|
||||||
# 保留目标框中心在范围内的坐标点
|
centers = []
|
||||||
x_center = (x1 + x2) / 2
|
for element_keypoint in keypoint_index:
|
||||||
y_centet = (y1 + y2) / 2
|
x1, y1 = kps[element_keypoint[0]]
|
||||||
if point_in_polygon([x1, y1], contours) and point_in_polygon([x2, y2], contours):
|
x2, y2 = kps[element_keypoint[1]]
|
||||||
# if point_in_polygon([x_center, y_centet], contours) :
|
centers.append([(x1 + x2) / 2, (y1 + y2) / 2])
|
||||||
|
|
||||||
|
if point_in_polygon(centers[0], contours) and point_in_polygon(
|
||||||
|
centers[1], contours
|
||||||
|
):
|
||||||
return true
|
return true
|
||||||
else:
|
else:
|
||||||
return false
|
return false
|
||||||
# return new_kps_data
|
|
||||||
|
|||||||
2215
play.ipynb
2215
play.ipynb
File diff suppressed because it is too large
Load Diff
319
playground.py
319
playground.py
@ -31,13 +31,13 @@ from typing import (
|
|||||||
TypeVar,
|
TypeVar,
|
||||||
cast,
|
cast,
|
||||||
overload,
|
overload,
|
||||||
|
Iterable,
|
||||||
)
|
)
|
||||||
|
|
||||||
import awkward as ak
|
import awkward as ak
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import orjson
|
|
||||||
from beartype import beartype
|
from beartype import beartype
|
||||||
from beartype.typing import Mapping, Sequence
|
from beartype.typing import Mapping, Sequence
|
||||||
from cv2 import undistortPoints
|
from cv2 import undistortPoints
|
||||||
@ -46,9 +46,10 @@ from jaxtyping import Array, Float, Num, jaxtyped
|
|||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from numpy.typing import ArrayLike
|
from numpy.typing import ArrayLike
|
||||||
from optax.assignment import hungarian_algorithm as linear_sum_assignment
|
from optax.assignment import hungarian_algorithm as linear_sum_assignment
|
||||||
from pyrsistent import v, pvector
|
from pyrsistent import pvector, v, m, pmap, PMap, freeze, thaw
|
||||||
from scipy.spatial.transform import Rotation as R
|
from scipy.spatial.transform import Rotation as R
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
from app.camera import (
|
from app.camera import (
|
||||||
Camera,
|
Camera,
|
||||||
@ -59,17 +60,21 @@ from app.camera import (
|
|||||||
classify_by_camera,
|
classify_by_camera,
|
||||||
)
|
)
|
||||||
from app.solver._old import GLPKSolver
|
from app.solver._old import GLPKSolver
|
||||||
from app.tracking import AffinityResult, Tracking
|
from app.tracking import (
|
||||||
|
TrackingID,
|
||||||
|
AffinityResult,
|
||||||
|
LastDifferenceVelocityFilter,
|
||||||
|
Tracking,
|
||||||
|
TrackingState,
|
||||||
|
)
|
||||||
from app.visualize.whole_body import visualize_whole_body
|
from app.visualize.whole_body import visualize_whole_body
|
||||||
|
|
||||||
NDArray: TypeAlias = np.ndarray
|
NDArray: TypeAlias = np.ndarray
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
CAMERA_PATH = Path(
|
DATASET_PATH = Path("samples") / "04_02"
|
||||||
"/home/admin/Documents/ActualTest_QuanCheng/camera_ex_params_1_2025_4_20/camera_params"
|
AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet") # type: ignore
|
||||||
)
|
DELTA_T_MIN = timedelta(milliseconds=1)
|
||||||
AK_CAMERA_DATASET: ak.Array = ak.from_parquet(CAMERA_PATH / "camera_params.parquet")
|
|
||||||
DELTA_T_MIN = timedelta(milliseconds=10)
|
|
||||||
display(AK_CAMERA_DATASET)
|
display(AK_CAMERA_DATASET)
|
||||||
|
|
||||||
|
|
||||||
@ -104,13 +109,6 @@ class ExternalCameraParams(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
|
||||||
# %%
|
|
||||||
DATASET_PATH = Path(
|
|
||||||
"/home/admin/Documents/ActualTest_QuanCheng/camera_ex_params_1_2025_4_20/detect_result/segement_1"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def read_dataset_by_port(port: int) -> ak.Array:
|
def read_dataset_by_port(port: int) -> ak.Array:
|
||||||
P = DATASET_PATH / f"{port}.parquet"
|
P = DATASET_PATH / f"{port}.parquet"
|
||||||
return ak.from_parquet(P)
|
return ak.from_parquet(P)
|
||||||
@ -119,7 +117,6 @@ def read_dataset_by_port(port: int) -> ak.Array:
|
|||||||
KEYPOINT_DATASET = {
|
KEYPOINT_DATASET = {
|
||||||
int(p): read_dataset_by_port(p) for p in ak.to_numpy(AK_CAMERA_DATASET["port"])
|
int(p): read_dataset_by_port(p) for p in ak.to_numpy(AK_CAMERA_DATASET["port"])
|
||||||
}
|
}
|
||||||
display(KEYPOINT_DATASET)
|
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
@ -194,8 +191,6 @@ def preprocess_keypoint_dataset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# %%
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
DetectionGenerator: TypeAlias = Generator[Detection, None, None]
|
DetectionGenerator: TypeAlias = Generator[Detection, None, None]
|
||||||
|
|
||||||
@ -338,31 +333,13 @@ def homogeneous_to_euclidean(
|
|||||||
|
|
||||||
# %%
|
# %%
|
||||||
FPS = 24
|
FPS = 24
|
||||||
|
image_gen_5600 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5600], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5600][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
|
||||||
image_gen_5601 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5601], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5601][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
|
image_gen_5601 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5601], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5601][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
|
||||||
image_gen_5602 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5602], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5602][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
|
image_gen_5602 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5602], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5602][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
|
||||||
image_gen_5603 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5603], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5603][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
|
|
||||||
image_gen_5604 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5604], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5604][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
|
|
||||||
image_gen_5605 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5605], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5605][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
|
|
||||||
image_gen_5606 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5606], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5606][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
|
|
||||||
image_gen_5607 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5607], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5607][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
|
|
||||||
image_gen_5608 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5608], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5608][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
|
|
||||||
image_gen_5609 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5609], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5609][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
display(1 / FPS)
|
display(1 / FPS)
|
||||||
sync_gen = sync_batch_gen(
|
sync_gen = sync_batch_gen(
|
||||||
[
|
[image_gen_5600, image_gen_5601, image_gen_5602], timedelta(seconds=1 / FPS)
|
||||||
image_gen_5601,
|
|
||||||
# image_gen_5602,
|
|
||||||
# image_gen_5603,
|
|
||||||
image_gen_5604,
|
|
||||||
image_gen_5605,
|
|
||||||
image_gen_5606,
|
|
||||||
# image_gen_5607,
|
|
||||||
image_gen_5608,
|
|
||||||
image_gen_5609,
|
|
||||||
],
|
|
||||||
timedelta(seconds=1 / FPS),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
@ -375,7 +352,7 @@ display(sorted_detections)
|
|||||||
display(
|
display(
|
||||||
list(
|
list(
|
||||||
map(
|
map(
|
||||||
lambda x: {"timestamp": str(x.timestamp), "camera": x.camera.id, "keypoint":x.keypoints.shape},
|
lambda x: {"timestamp": str(x.timestamp), "camera": x.camera.id},
|
||||||
sorted_detections,
|
sorted_detections,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -443,7 +420,6 @@ for el in clusters_detections[0]:
|
|||||||
p = plt.imshow(im)
|
p = plt.imshow(im)
|
||||||
display(p)
|
display(p)
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
im_prime = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
|
im_prime = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
|
||||||
for el in clusters_detections[1]:
|
for el in clusters_detections[1]:
|
||||||
@ -535,6 +511,142 @@ def triangulate_points_from_multiple_views_linear(
|
|||||||
return vmap_triangulate(proj_matrices, points, conf)
|
return vmap_triangulate(proj_matrices, points, conf)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear_time_weighted(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N 2"],
    delta_t: Num[Array, "N"],
    lambda_t: float = 10.0,
    confidences: Optional[Float[Array, "N"]] = None,
) -> Float[Array, "3"]:
    """
    Triangulate one point from multiple views with time-weighted linear least squares.

    Implements the incremental reconstruction method from "Cross-View Tracking for Multi-Human 3D Pose"
    with weighting formula: w_i = exp(-λ_t(t-t_i)) / ||c^i^T||_2

    Every view contributes two rows (one for x, one for y) to a homogeneous
    linear system; each row is scaled by its weight and the solution is the
    right singular vector belonging to the smallest singular value.

    Args:
        proj_matrices: Shape (N, 3, 4) projection matrices sequence
        points: Shape (N, 2) point coordinates sequence
        delta_t: Shape (N,) time differences between current time and each
            observation (in seconds), i.e. ``t - t_i``; larger means older
        lambda_t: Time penalty rate (higher values decrease influence of older observations)
        confidences: Shape (N,) confidence values; clipped to [0.0, 1.0] and
            square-rooted before use. ``None`` means every view has weight 1.

    Returns:
        point_3d: Shape (3,) triangulated 3D point
    """
    assert len(proj_matrices) == len(points)
    assert len(delta_t) == len(points)

    N = len(proj_matrices)

    # Prepare confidence weights
    confi: Float[Array, "N"]
    if confidences is None:
        confi = jnp.ones(N, dtype=np.float32)
    else:
        confi = jnp.sqrt(jnp.clip(confidences, 0, 1))

    # Unweighted coefficient rows for all views at once, shape (N, 2, 4):
    #   rows[i, 0] = P_i[2] * x_i - P_i[0]
    #   rows[i, 1] = P_i[2] * y_i - P_i[1]
    # (N, 1, 4) * (N, 2, 1) - (N, 2, 4) broadcasts to exactly that layout.
    rows = (
        proj_matrices[:, 2:3, :] * points[:, :, None] - proj_matrices[:, 0:2, :]
    ).astype(jnp.float32)

    # Time-decay weight per view: e^(-λ_t * Δt)
    time_weight = jnp.exp(-lambda_t * delta_t)  # (N,)

    # Normalization factor per row: ||c^i^T||_2
    row_norm = jnp.linalg.norm(rows, axis=-1)  # (N, 2)
    # An all-zero row would otherwise produce 0 * inf = NaN and poison the
    # SVD; such a row stays zero whatever finite weight it receives.
    safe_norm = jnp.where(row_norm > 0, row_norm, 1.0)

    # Combined weight: time_weight / row_norm * confidence
    weights = (time_weight[:, None] / safe_norm) * confi[:, None]  # (N, 2)

    # Flatten to (2N, 4) with the x/y rows of each view kept adjacent
    A = (rows * weights[..., None].astype(rows.dtype)).reshape(N * 2, 4)

    # Solve using SVD
    _, _, vh = jnp.linalg.svd(A, full_matrices=False)
    point_3d_homo = vh[-1]  # shape (4,)

    # Ensure homogeneous coordinate is positive
    point_3d_homo = jnp.where(
        point_3d_homo[3] < 0,
        -point_3d_homo,
        point_3d_homo,
    )

    # Convert from homogeneous to Euclidean coordinates
    point_3d = point_3d_homo[:3] / point_3d_homo[3]
    return point_3d
|
||||||
|
|
||||||
|
|
||||||
|
@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear_time_weighted(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N P 2"],
    delta_t: Num[Array, "N"],
    lambda_t: float = 10.0,
    confidences: Optional[Float[Array, "N P"]] = None,
) -> Float[Array, "P 3"]:
    """
    Triangulate P keypoints from N camera views with time-weighting.

    The single-point solver is lifted over the keypoint axis with
    ``jax.vmap``, so all P points are solved in one batched call.

    Args:
        proj_matrices: Shape (N, 3, 4) projection matrices for N cameras
        points: Shape (N, P, 2) 2D points for P keypoints across N cameras
        delta_t: Shape (N,) time differences between current time and each camera's timestamp (seconds)
        lambda_t: Time penalty rate (higher values decrease influence of older observations)
        confidences: Shape (N, P) confidence values for each point in each
            camera; ``None`` is replaced by all-ones

    Returns:
        points_3d: Shape (P, 3) triangulated 3D points
    """
    num_views, num_points, _ = points.shape
    assert (
        proj_matrices.shape[0] == num_views
    ), "Number of projection matrices must match number of cameras"
    assert (
        delta_t.shape[0] == num_views
    ), "Number of time deltas must match number of cameras"

    # Uniform confidences when the caller supplies none
    per_view_conf = (
        jnp.ones((num_views, num_points), dtype=jnp.float32)
        if confidences is None
        else confidences
    )

    # Axis 1 (the P keypoints) of `points` and of the confidences is mapped;
    # projection matrices, time deltas and lambda_t are shared by every point.
    # Results are stacked along axis 0, giving (P, 3).
    solve_all_points = jax.vmap(
        triangulate_one_point_from_multiple_views_linear_time_weighted,
        in_axes=(None, 1, None, None, 1),
        out_axes=0,
    )
    return solve_all_points(
        proj_matrices,
        points,
        delta_t,
        lambda_t,
        per_view_conf,
    )
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
|
||||||
|
|
||||||
@ -555,6 +667,21 @@ def triangle_from_cluster(
|
|||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
def group_by_cluster_by_camera(
    cluster: Sequence[Detection],
) -> PMap[CameraID, Detection]:
    """
    group the detections by camera, and preserve the latest detection for each camera

    Args:
        cluster: detections, possibly several per camera

    Returns:
        a persistent map from camera id to the detection with the greatest
        timestamp seen for that camera (the earlier one wins on a tie);
        empty when `cluster` is empty
    """
    latest: dict[CameraID, Detection] = {}
    for detection in cluster:
        camera_id = detection.camera.id
        current = latest.get(camera_id)
        # The first detection of a camera must be stored too: the previous
        # version only wrote to the dict when the key already existed, so it
        # never inserted anything and always returned an empty map.
        if current is None or detection.timestamp > current.timestamp:
            latest[camera_id] = detection
    return pmap(latest)
|
||||||
|
|
||||||
|
|
||||||
class GlobalTrackingState:
|
class GlobalTrackingState:
|
||||||
_last_id: int
|
_last_id: int
|
||||||
_trackings: dict[int, Tracking]
|
_trackings: dict[int, Tracking]
|
||||||
@ -573,13 +700,21 @@ class GlobalTrackingState:
|
|||||||
return shallow_copy(self._trackings)
|
return shallow_copy(self._trackings)
|
||||||
|
|
||||||
def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
|
def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
|
||||||
|
if len(cluster) < 2:
|
||||||
|
raise ValueError(
|
||||||
|
"cluster must contain at least 2 detections to form a tracking"
|
||||||
|
)
|
||||||
kps_3d, latest_timestamp = triangle_from_cluster(cluster)
|
kps_3d, latest_timestamp = triangle_from_cluster(cluster)
|
||||||
next_id = self._last_id + 1
|
next_id = self._last_id + 1
|
||||||
tracking = Tracking(
|
tracking_state = TrackingState(
|
||||||
id=next_id,
|
|
||||||
keypoints=kps_3d,
|
keypoints=kps_3d,
|
||||||
last_active_timestamp=latest_timestamp,
|
last_active_timestamp=latest_timestamp,
|
||||||
historical_detections=v(*cluster),
|
historical_detections_by_camera=group_by_cluster_by_camera(cluster),
|
||||||
|
)
|
||||||
|
tracking = Tracking(
|
||||||
|
id=next_id,
|
||||||
|
state=tracking_state,
|
||||||
|
velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
|
||||||
)
|
)
|
||||||
self._trackings[next_id] = tracking
|
self._trackings[next_id] = tracking
|
||||||
self._last_id = next_id
|
self._last_id = next_id
|
||||||
@ -702,11 +837,7 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
|
|||||||
Array of perpendicular distances for each keypoint
|
Array of perpendicular distances for each keypoint
|
||||||
"""
|
"""
|
||||||
camera = detection.camera
|
camera = detection.camera
|
||||||
# Use the delta_t supplied by the caller, but clamp to DELTA_T_MIN to
|
predicted_pose = tracking.predict(delta_t)
|
||||||
# avoid division-by-zero / exploding affinities.
|
|
||||||
delta_t = max(delta_t, DELTA_T_MIN)
|
|
||||||
delta_t_s = delta_t.total_seconds()
|
|
||||||
predicted_pose = tracking.predict(delta_t_s)
|
|
||||||
|
|
||||||
# Back-project the 2D points to 3D space
|
# Back-project the 2D points to 3D space
|
||||||
# intersection with z=0 plane
|
# intersection with z=0 plane
|
||||||
@ -786,12 +917,12 @@ def calculate_tracking_detection_affinity(
|
|||||||
Combined affinity score
|
Combined affinity score
|
||||||
"""
|
"""
|
||||||
camera = detection.camera
|
camera = detection.camera
|
||||||
delta_t_raw = detection.timestamp - tracking.last_active_timestamp
|
delta_t_raw = detection.timestamp - tracking.state.last_active_timestamp
|
||||||
# Clamp delta_t to avoid division-by-zero / exploding affinity.
|
# Clamp delta_t to avoid division-by-zero / exploding affinity.
|
||||||
delta_t = max(delta_t_raw, DELTA_T_MIN)
|
delta_t = max(delta_t_raw, DELTA_T_MIN)
|
||||||
|
|
||||||
# Calculate 2D affinity
|
# Calculate 2D affinity
|
||||||
tracking_2d_projection = camera.project(tracking.keypoints)
|
tracking_2d_projection = camera.project(tracking.state.keypoints)
|
||||||
w, h = camera.params.image_size
|
w, h = camera.params.image_size
|
||||||
distance_2d = calculate_distance_2d(
|
distance_2d = calculate_distance_2d(
|
||||||
tracking_2d_projection,
|
tracking_2d_projection,
|
||||||
@ -871,7 +1002,7 @@ def calculate_camera_affinity_matrix_jax(
|
|||||||
|
|
||||||
# === Tracking-side tensors ===
|
# === Tracking-side tensors ===
|
||||||
kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
|
kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
|
||||||
[trk.keypoints for trk in trackings]
|
[trk.state.keypoints for trk in trackings]
|
||||||
) # (T, J, 3)
|
) # (T, J, 3)
|
||||||
J = kps3d_trk.shape[1]
|
J = kps3d_trk.shape[1]
|
||||||
# === Detection-side tensors ===
|
# === Detection-side tensors ===
|
||||||
@ -888,12 +1019,12 @@ def calculate_camera_affinity_matrix_jax(
|
|||||||
# --- timestamps ----------
|
# --- timestamps ----------
|
||||||
t0 = min(
|
t0 = min(
|
||||||
chain(
|
chain(
|
||||||
(trk.last_active_timestamp for trk in trackings),
|
(trk.state.last_active_timestamp for trk in trackings),
|
||||||
(det.timestamp for det in camera_detections),
|
(det.timestamp for det in camera_detections),
|
||||||
)
|
)
|
||||||
).timestamp() # common origin (float)
|
).timestamp() # common origin (float)
|
||||||
ts_trk = jnp.array(
|
ts_trk = jnp.array(
|
||||||
[trk.last_active_timestamp.timestamp() - t0 for trk in trackings],
|
[trk.state.last_active_timestamp.timestamp() - t0 for trk in trackings],
|
||||||
dtype=jnp.float32, # now small, ms-scale fits in fp32
|
dtype=jnp.float32, # now small, ms-scale fits in fp32
|
||||||
)
|
)
|
||||||
ts_det = jnp.array(
|
ts_det = jnp.array(
|
||||||
@ -1064,8 +1195,82 @@ display(affinities)
|
|||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
def update_tracking(tracking: Tracking, detection: Detection):
|
def affinity_result_by_tracking(
|
||||||
delta_t_ = detection.timestamp - tracking.last_active_timestamp
|
results: Iterable[AffinityResult],
|
||||||
delta_t = max(delta_t_, DELTA_T_MIN)
|
min_affinity: float = 0.0,
|
||||||
|
) -> dict[TrackingID, list[Detection]]:
|
||||||
|
"""
|
||||||
|
Group affinity results by target ID.
|
||||||
|
|
||||||
return tracking
|
Args:
|
||||||
|
results: the affinity results to group
|
||||||
|
min_affinity: the minimum affinity to consider
|
||||||
|
Returns:
|
||||||
|
a dictionary mapping tracking IDs to a list of detections
|
||||||
|
"""
|
||||||
|
res: dict[TrackingID, list[Detection]] = defaultdict(list)
|
||||||
|
for affinity_result in results:
|
||||||
|
for affinity, t, d in affinity_result.tracking_association():
|
||||||
|
if affinity < min_affinity:
|
||||||
|
continue
|
||||||
|
res[t.id].append(d)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def update_tracking(
    tracking: Tracking,
    detections: Sequence[Detection],
    max_delta_t: timedelta = timedelta(milliseconds=100),
    lambda_t: float = 10.0,
) -> None:
    """
    update the tracking with a new set of detections

    The new detections replace the stored detection of their camera, views
    that lag more than `max_delta_t` behind the newest detection are dropped,
    and the 3D keypoints are re-triangulated from the remaining views with
    time-decayed weights.

    Args:
        tracking: the tracking to update
        detections: the detections to update the tracking with (non-empty)
        max_delta_t: the maximum age of a stored per-camera detection,
            measured against the newest detection in `detections`
        lambda_t: the time penalty rate passed to the time-weighted triangulation

    Raises:
        ValueError: if `detections` is empty

    Note:
        the function would mutate the tracking object
    """
    if not detections:
        raise ValueError("detections must contain at least one detection")

    latest_timestamp = max(d.timestamp for d in detections)

    # Latest detection per camera: stored history overridden by the new ones.
    merged = thaw(tracking.state.historical_detections_by_camera)
    for detection in detections:
        merged[detection.camera.id] = detection

    # Keep only views that are fresh relative to the newest detection.
    # Built as a new dict: deleting from a dict while iterating over it raises
    # RuntimeError. The age is `latest - detection`; the reversed difference
    # is never positive, so the filter would never drop anything.
    fresh = {
        camera_id: detection
        for camera_id, detection in merged.items()
        if latest_timestamp - detection.timestamp <= max_delta_t
    }
    # NOTE(review): nothing here guarantees that at least two views remain
    # after pruning; a single view gives a degenerate triangulation — confirm
    # whether the caller should skip the update in that case.
    new_detections = freeze(fresh)
    new_detections_list = list(new_detections.values())

    project_matrices = jnp.stack(
        [detection.camera.params.projection_matrix for detection in new_detections_list]
    )
    # Age of every view in seconds (t - t_i), as the time-weighted
    # triangulation documents it: the newest view has 0 and therefore the
    # largest weight, older views decay with exp(-lambda_t * age).
    delta_t = jnp.array(
        [
            (latest_timestamp - detection.timestamp).total_seconds()
            for detection in new_detections_list
        ]
    )
    kps = jnp.stack([detection.keypoints for detection in new_detections_list])
    conf = jnp.stack([detection.confidences for detection in new_detections_list])
    kps_3d = triangulate_points_from_multiple_views_linear_time_weighted(
        project_matrices, kps, delta_t, lambda_t, conf
    )

    new_state = TrackingState(
        keypoints=kps_3d,
        last_active_timestamp=latest_timestamp,
        historical_detections_by_camera=new_detections,
    )
    # Same order as before: notify the tracking first, then swap the state.
    tracking.update(kps_3d, latest_timestamp)
    tracking.state = new_state
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
affinity_results_by_tracking = affinity_result_by_tracking(affinities.values())
|
||||||
|
for tracking_id, detections in affinity_results_by_tracking.items():
|
||||||
|
update_tracking(global_tracking_state.trackings[tracking_id], detections)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|||||||
@ -14,6 +14,7 @@ dependencies = [
|
|||||||
"jaxtyping>=0.2.38",
|
"jaxtyping>=0.2.38",
|
||||||
"jupytext>=1.17.0",
|
"jupytext>=1.17.0",
|
||||||
"matplotlib>=3.10.1",
|
"matplotlib>=3.10.1",
|
||||||
|
"more-itertools>=10.7.0",
|
||||||
"opencv-python-headless>=4.11.0.86",
|
"opencv-python-headless>=4.11.0.86",
|
||||||
"optax>=0.2.4",
|
"optax>=0.2.4",
|
||||||
"orjson>=3.10.15",
|
"orjson>=3.10.15",
|
||||||
@ -23,6 +24,7 @@ dependencies = [
|
|||||||
"pyrsistent>=0.20.0",
|
"pyrsistent>=0.20.0",
|
||||||
"pytest>=8.3.5",
|
"pytest>=8.3.5",
|
||||||
"scipy>=1.15.2",
|
"scipy>=1.15.2",
|
||||||
|
"shapely>=2.1.1",
|
||||||
"torch>=2.6.0",
|
"torch>=2.6.0",
|
||||||
"torchvision>=0.21.0",
|
"torchvision>=0.21.0",
|
||||||
"typeguard>=4.4.2",
|
"typeguard>=4.4.2",
|
||||||
|
|||||||
1062
single_people_detect_track.py
Normal file
1062
single_people_detect_track.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user