forked from HQU-gxy/CVTH3PE

11 Commits

Author SHA1 Message Date
6cd13064f3 fix: various
- Added the `LeastMeanSquareVelocityFilter` to improve tracking velocity estimation using historical 3D poses.
- Updated the `triangulate_one_point_from_multiple_views_linear` and `triangulate_points_from_multiple_views_linear` functions to enhance documentation and ensure proper handling of input parameters.
- Refined the logic in triangulation functions to ensure correct handling of homogeneous coordinates.
- Improved error handling in the `LastDifferenceVelocityFilter` to assert non-negative time deltas, enhancing robustness.
2025-06-18 10:35:23 +08:00
5d816d92d5 feat: Add general rules configuration for cursor behavior
- Introduced a new `general.mdc` file containing default rules for cursor behavior, specifying guidelines for interaction and response.
- Established a structured format for rules, including a description and application conditions, to enhance user experience and clarity in cursor functionality.
2025-05-09 14:27:26 +08:00
4bc3fce0b1 feat: Add minimum affinity filter to affinity result grouping
- Introduced a `min_affinity` parameter to the `affinity_result_by_tracking` function, allowing users to specify a threshold for filtering affinity results.
- Updated the logic to skip results with affinities below the specified minimum, enhancing the relevance of grouped detections.
- Improved function documentation to include details about the new parameter and its purpose.
2025-05-03 17:29:50 +08:00
1f8d70803f feat: Implement time-weighted triangulation for enhanced 3D point reconstruction
- Added two new functions: `triangulate_one_point_from_multiple_views_linear_time_weighted` and `triangulate_points_from_multiple_views_linear_time_weighted` to perform triangulation with time-based weighting, improving accuracy in 3D point estimation.
- Introduced a method to group detections by camera while preserving the latest detection, enhancing tracking state management.
- Updated the `update_tracking` function to incorporate time-weighted triangulation, allowing for more robust updates to tracking states based on new detections.
- Refactored the `TrackingState` to utilize a mapping of historical detections by camera, improving data organization and access.
2025-05-03 17:17:47 +08:00
20b2cf59f2 refactor: Enhance OneEuroFilter with type hints and error handling improvements
- Added overloads for the `_smoothing_factor` method to improve type hinting for different input types.
- Enhanced error handling in the timestamp validation to provide clearer feedback when an invalid timestamp is encountered.
- Streamlined the calculation of the filtered velocity by simplifying the logic in the `update` method.
- Improved code organization with additional type annotations for better clarity and maintainability.
2025-05-03 15:12:06 +08:00
4a5cfde245 feat: Add OneEuroFilter for adaptive keypoint smoothing
- Introduced the `OneEuroFilter` class to implement an adaptive low-pass filter for smoothing keypoint data, enhancing tracking stability during varying movement speeds.
- Implemented methods for initialization, prediction, and updating of keypoints, allowing for dynamic adjustment of smoothing based on movement.
- Added detailed documentation and type hints to clarify the filter's functionality and parameters.
- Improved the handling of timestamps and filtering logic to ensure accurate predictions and updates.
2025-05-03 14:58:51 +08:00
d2c1c8d624 refactor: Revamp LeastMeanSquareVelocityFilter to utilize historical 3D poses
- Replaced the historical detections mechanism with deques for managing historical 3D poses and timestamps, enhancing performance and memory efficiency.
- Updated the constructor to accept historical data directly, ensuring proper initialization and sorting of poses and timestamps.
- Refined the `predict` and `update` methods to work with the new data structure, improving clarity and functionality.
- Enhanced error handling to ensure robustness when no historical data is available for predictions.
2025-05-03 14:31:59 +08:00
c31cc4e7bf refactor: Enhance tracking state management and velocity filter integration
- Introduced `TrackingState` to encapsulate the state of tracking, improving data organization and immutability.
- Updated the `Tracking` class to utilize `TrackingState`, enhancing clarity in state management.
- Refactored methods to access keypoints and timestamps through the new state structure, ensuring consistency across the codebase.
- Added a `DummyVelocityFilter` for cases where no velocity estimation is needed, improving flexibility in tracking implementations.
- Cleaned up imports and improved type hints for better code organization.
2025-05-02 12:44:58 +08:00
46b8518a10 refactor: Clean up and enhance LeastMeanSquareVelocityFilter implementation
- Removed unused `reset` methods from `GenericVelocityFilter` and `LastDifferenceVelocityFilter` classes to streamline the code.
- Added a static method `from_tracking` to `LeastMeanSquareVelocityFilter` for creating instances from a `Tracking` object.
- Implemented robust error handling in the `predict` and `get` methods to ensure proper functioning with historical detections.
- Enhanced the `update` method to utilize least squares for velocity estimation, improving accuracy in tracking predictions.
- Updated class documentation to reflect changes and clarify method purposes.
2025-05-02 12:06:05 +08:00
4e78165f12 feat: Add LeastMeanSquareVelocityFilter for advanced tracking velocity estimation
- Introduced a new `LeastMeanSquareVelocityFilter` class to enhance tracking velocity estimation using historical detections.
- Implemented methods for updating measurements and predicting future states, laying the groundwork for advanced tracking capabilities.
- Improved import organization and added necessary dependencies for the new filter functionality.
- Updated class documentation to reflect the new filter's purpose and methods.
2025-05-02 11:39:01 +08:00
c78850855c feat: Introduce LastDifferenceVelocityFilter for improved tracking velocity estimation
- Added a new `LastDifferenceVelocityFilter` class to estimate tracking velocities based on the last observed keypoints, enhancing the tracking capabilities.
- Updated the `Tracking` class to utilize the new velocity filter, allowing for more accurate predictions of keypoints over time.
- Refactored the `predict` method to support various input types (float, timedelta, datetime) for better flexibility in time handling.
- Improved timestamp handling in the `perpendicular_distance_camera_2d_points_to_tracking_raycasting` function to ensure adherence to minimum delta time constraints.
- Cleaned up imports and type hints for better organization and clarity across the codebase.
2025-05-02 11:11:32 +08:00
3 changed files with 770 additions and 65 deletions

View File

@@ -1,14 +1,19 @@
+import weakref
+from collections import deque
 from dataclasses import dataclass
-from datetime import datetime
+from datetime import datetime, timedelta
+from itertools import chain
 from typing import (
     Any,
     Callable,
     Generator,
     Optional,
+    Protocol,
     Sequence,
     TypeAlias,
     TypedDict,
     TypeVar,
+    Union,
     cast,
     overload,
 )
@@ -18,18 +23,428 @@ from beartype import beartype
 from beartype.typing import Mapping, Sequence
 from jax import Array
 from jaxtyping import Array, Float, Int, jaxtyped
-from pyrsistent import PVector
+from pyrsistent import PVector, v, PRecord, PMap

-from app.camera import Detection
+from app.camera import Detection, CameraID
TrackingID: TypeAlias = int
class TrackingPrediction(TypedDict):
velocity: Optional[Float[Array, "J 3"]]
keypoints: Float[Array, "J 3"]
class GenericVelocityFilter(Protocol):
"""
a filter interface for tracking velocity estimation
"""
def predict(self, timestamp: datetime) -> TrackingPrediction:
"""
predict the velocity and the keypoints location
Args:
timestamp: timestamp of the prediction
Returns:
velocity: velocity of the tracking
keypoints: keypoints of the tracking
"""
... # pylint: disable=unnecessary-ellipsis
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
"""
update the filter state with new measurements
Args:
keypoints: new measurements
timestamp: timestamp of the update
"""
... # pylint: disable=unnecessary-ellipsis
def get(self) -> TrackingPrediction:
"""
get the current state of the filter state
Returns:
velocity: velocity of the tracking
keypoints: keypoints of the tracking
"""
... # pylint: disable=unnecessary-ellipsis
class DummyVelocityFilter(GenericVelocityFilter):
"""
a dummy velocity filter that does nothing
"""
_keypoints_shape: tuple[int, ...]
def __init__(self, keypoints: Float[Array, "J 3"]):
self._keypoints_shape = keypoints.shape
def predict(self, timestamp: datetime) -> TrackingPrediction:
return TrackingPrediction(
velocity=None,
keypoints=jnp.zeros(self._keypoints_shape),
)
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None: ...
def get(self) -> TrackingPrediction:
return TrackingPrediction(
velocity=None,
keypoints=jnp.zeros(self._keypoints_shape),
)
class LastDifferenceVelocityFilter(GenericVelocityFilter):
"""
a naive velocity filter that uses the last difference of keypoints
"""
_last_timestamp: datetime
_last_keypoints: Float[Array, "J 3"]
_last_velocity: Optional[Float[Array, "J 3"]] = None
def __init__(self, keypoints: Float[Array, "J 3"], timestamp: datetime):
self._last_keypoints = keypoints
self._last_timestamp = timestamp
def predict(self, timestamp: datetime) -> TrackingPrediction:
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
assert delta_t_s >= 0, f"delta_t_s is negative: {delta_t_s}"
if self._last_velocity is None:
return TrackingPrediction(
velocity=None,
keypoints=self._last_keypoints,
)
else:
return TrackingPrediction(
velocity=self._last_velocity,
keypoints=self._last_keypoints + self._last_velocity * delta_t_s,
)
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
delta_t_s = (timestamp - self._last_timestamp).total_seconds()
assert delta_t_s >= 0, f"delta_t_s is negative: {delta_t_s}"
self._last_velocity = (keypoints - self._last_keypoints) / delta_t_s
self._last_keypoints = keypoints
self._last_timestamp = timestamp
def get(self) -> TrackingPrediction:
if self._last_velocity is None:
return TrackingPrediction(
velocity=None,
keypoints=self._last_keypoints,
)
else:
return TrackingPrediction(
velocity=self._last_velocity,
keypoints=self._last_keypoints,
)
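Illustrative aside, not part of the commit: a minimal smoke test of `LastDifferenceVelocityFilter` with made-up values, a 2-joint skeleton moving at a constant 1 unit/s on every axis.

from datetime import datetime, timedelta
import jax.numpy as jnp

t0 = datetime(2025, 5, 2, 11, 0, 0)
f = LastDifferenceVelocityFilter(jnp.zeros((2, 3)), t0)
f.update(jnp.ones((2, 3)), t0 + timedelta(seconds=1))  # moved 1 unit/s on each axis
pred = f.predict(t0 + timedelta(seconds=2))
# linear extrapolation one second past the last update:
# keypoints == 2.0 everywhere, velocity == 1.0 everywhere
assert jnp.allclose(pred["keypoints"], 2.0)
assert pred["velocity"] is not None and jnp.allclose(pred["velocity"], 1.0)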
class LeastMeanSquareVelocityFilter(GenericVelocityFilter):
"""
a velocity filter that uses the least mean square method to estimate the velocity
"""
_historical_3d_poses: deque[Float[Array, "J 3"]]
_historical_timestamps: deque[datetime]
_velocity: Optional[Float[Array, "J 3"]] = None
_max_samples: int
def __init__(
self,
historical_3d_poses: Sequence[Float[Array, "J 3"]],
historical_timestamps: Sequence[datetime],
max_samples: int = 10,
):
"""
Args:
historical_3d_poses: sequence of 3D poses, at least one element is required
historical_timestamps: sequence of timestamps, whose length is the same as `historical_3d_poses`
max_samples: maximum number of samples to keep
"""
assert (N := len(historical_3d_poses)) == len(
historical_timestamps
), f"the length of `historical_3d_poses` and `historical_timestamps` must be the same; got {N} and {len(historical_timestamps)}"
if N < 1:
raise ValueError("at least one historical 3D pose is required")
temp = zip(historical_3d_poses, historical_timestamps)
# sorted by timestamp
temp_sorted = sorted(temp, key=lambda x: x[1])
self._historical_3d_poses = deque(
map(lambda x: x[0], temp_sorted), maxlen=max_samples
)
self._historical_timestamps = deque(
map(lambda x: x[1], temp_sorted), maxlen=max_samples
)
self._max_samples = max_samples
if len(self._historical_3d_poses) < 2:
    self._velocity = None
else:
    # `_update` expects numeric arrays; datetimes must be converted to
    # seconds relative to the oldest sample before calling it
    t_0 = self._historical_timestamps[0]
    self._update(
        jnp.stack(list(self._historical_3d_poses)),
        jnp.array(
            [(t - t_0).total_seconds() for t in self._historical_timestamps]
        ),
    )
def predict(self, timestamp: datetime) -> TrackingPrediction:
if not self._historical_3d_poses:
raise ValueError("No historical 3D poses available for prediction")
# use the latest historical detection
latest_3d_pose = self._historical_3d_poses[-1]
latest_timestamp = self._historical_timestamps[-1]
delta_t_s = (timestamp - latest_timestamp).total_seconds()
assert delta_t_s >= 0, f"delta_t_s is negative: {delta_t_s}"
if self._velocity is None:
return TrackingPrediction(
velocity=None,
keypoints=latest_3d_pose,
)
else:
# Linear motion model: x̂(t) = x(t') + v(t') · (t - t')
predicted_3d_pose = latest_3d_pose + self._velocity * delta_t_s
return TrackingPrediction(
velocity=self._velocity, keypoints=predicted_3d_pose
)
@jaxtyped(typechecker=beartype)
def _update(
self,
keypoints: Float[Array, "N J 3"],
timestamps: Float[Array, "N"],
) -> None:
"""
update measurements with least mean square method
"""
if keypoints.shape[0] < 2:
raise ValueError("Not enough measurements to estimate velocity")
# Using least squares to fit a linear model for each joint and dimension
# X = timestamps, y = keypoints
# For each joint and each dimension, we solve for velocity
n_samples = timestamps.shape[0]
n_joints = keypoints.shape[1]
# Create design matrix for linear regression
# [t, 1] for each timestamp
X = jnp.column_stack([timestamps, jnp.ones(n_samples)])
# Reshape keypoints to solve for all joints and dimensions at once
# From [N, J, 3] to [N, J*3]
keypoints_reshaped = keypoints.reshape(n_samples, -1)
# Use JAX's lstsq to solve the least squares problem
coefficients, _, _, _ = jnp.linalg.lstsq(X, keypoints_reshaped, rcond=None)
# Coefficients shape is [2, J*3]
# First row: velocities, Second row: intercepts
velocities = coefficients[0].reshape(n_joints, 3)
# Update velocity
self._velocity = velocities
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
last_timestamp = self._historical_timestamps[-1]
assert last_timestamp <= timestamp
# deque would manage the maxlen automatically
self._historical_3d_poses.append(keypoints)
self._historical_timestamps.append(timestamp)
t_0 = self._historical_timestamps[0]
all_keypoints = jnp.stack(list(self._historical_3d_poses))

def timestamp_to_seconds(timestamp: datetime) -> float:
    assert t_0 <= timestamp
    return (timestamp - t_0).total_seconds()

# timestamps relative to t_0 (the oldest detection timestamp);
# `jnp.array` needs a realized sequence, not a lazy `map` object
all_timestamps = jnp.array(
    [timestamp_to_seconds(t) for t in self._historical_timestamps]
)
self._update(all_keypoints, all_timestamps)
def get(self) -> TrackingPrediction:
if not self._historical_3d_poses:
raise ValueError("No historical 3D poses available")
latest_3d_pose = self._historical_3d_poses[-1]
if self._velocity is None:
return TrackingPrediction(velocity=None, keypoints=latest_3d_pose)
else:
return TrackingPrediction(velocity=self._velocity, keypoints=latest_3d_pose)
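Illustrative aside, not part of the commit: `_update` is ordinary per-coordinate linear regression, so the slope recovery can be sketched standalone with synthetic data.

import jax.numpy as jnp

# three samples of a single joint moving at (1, 2, 3) units/s
ts = jnp.array([0.0, 0.5, 1.0])                          # (N,)
kps = ts[:, None, None] * jnp.array([[1.0, 2.0, 3.0]])   # (N, 1, 3)

X = jnp.column_stack([ts, jnp.ones_like(ts)])            # design matrix [t, 1]
coef, _, _, _ = jnp.linalg.lstsq(X, kps.reshape(3, -1), rcond=None)
velocity = coef[0].reshape(1, 3)                         # slope row is the velocity
assert jnp.allclose(velocity, jnp.array([[1.0, 2.0, 3.0]]), atol=1e-4)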
class OneEuroFilter(GenericVelocityFilter):
"""
Implementation of the 1€ filter (One Euro Filter) for smoothing keypoint data.
The 1€ filter is an adaptive low-pass filter that adjusts its cutoff frequency
based on movement speed to reduce jitter during slow movements while maintaining
responsiveness during fast movements.
Reference: https://cristal.univ-lille.fr/~casiez/1euro/
"""
_x_filtered: Float[Array, "J 3"]
_dx_filtered: Optional[Float[Array, "J 3"]] = None
_last_timestamp: datetime
_min_cutoff: float
_beta: float
_d_cutoff: float
def __init__(
self,
keypoints: Float[Array, "J 3"],
timestamp: datetime,
min_cutoff: float = 1.0,
beta: float = 0.0,
d_cutoff: float = 1.0,
):
"""
Initialize the One Euro Filter.
Args:
keypoints: Initial keypoints positions
timestamp: Initial timestamp
min_cutoff: Minimum cutoff frequency (lower = more smoothing)
beta: Speed coefficient (higher = less lag during fast movements)
d_cutoff: Cutoff frequency for the derivative filter
"""
self._last_timestamp = timestamp
# Filter parameters
self._min_cutoff = min_cutoff
self._beta = beta
self._d_cutoff = d_cutoff
# Filter state
self._x_filtered = keypoints # Position filter state
self._dx_filtered = None # Initially no velocity estimate
@overload
def _smoothing_factor(self, cutoff: float, dt: float) -> float: ...
@overload
def _smoothing_factor(
    self, cutoff: Float[Array, "J 1"], dt: float
) -> Float[Array, "J 1"]: ...
@jaxtyped(typechecker=beartype)
def _smoothing_factor(
    self, cutoff: Union[float, Float[Array, "J 1"]], dt: float
) -> Union[float, Float[Array, "J 1"]]:
    """Calculate the smoothing factor for the low-pass filter."""
    # the adaptive cutoff computed in `update` keeps its trailing singleton
    # axis, shape (J, 1), so that it broadcasts against the (J, 3) keypoints
    r = 2 * jnp.pi * cutoff * dt
    return r / (r + 1)
@jaxtyped(typechecker=beartype)
def _exponential_smoothing(
    self,
    a: Union[float, Float[Array, "J 1"]],
    x: Float[Array, "J 3"],
    x_prev: Float[Array, "J 3"],
) -> Float[Array, "J 3"]:
    """Apply exponential smoothing to the input."""
    return a * x + (1 - a) * x_prev
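For intuition, with worked numbers not taken from the repo: at a 30 fps update rate and a 1 Hz cutoff, r = 2π · 1 · (1/30) ≈ 0.209, so α = r / (r + 1) ≈ 0.17 and each update moves the filtered state about 17% of the way toward the new sample; raising the cutoff raises α and reduces smoothing.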
def predict(self, timestamp: datetime) -> TrackingPrediction:
"""
Predict keypoints position at a given timestamp.
Args:
timestamp: Timestamp for prediction
Returns:
TrackingPrediction with velocity and keypoints
"""
dt = (timestamp - self._last_timestamp).total_seconds()
if self._dx_filtered is None:
return TrackingPrediction(
velocity=None,
keypoints=self._x_filtered,
)
else:
predicted_keypoints = self._x_filtered + self._dx_filtered * dt
return TrackingPrediction(
velocity=self._dx_filtered,
keypoints=predicted_keypoints,
)
def update(self, keypoints: Float[Array, "J 3"], timestamp: datetime) -> None:
"""
Update the filter with new measurements.
Args:
keypoints: New keypoint measurements
timestamp: Timestamp of the measurements
"""
dt = (timestamp - self._last_timestamp).total_seconds()
if dt <= 0:
raise ValueError(
f"new timestamp is not greater than the last timestamp; expecting: {timestamp} > {self._last_timestamp}"
)
dx = (keypoints - self._x_filtered) / dt
# Determine cutoff frequency based on movement speed
cutoff = self._min_cutoff + self._beta * jnp.linalg.norm(
dx, axis=-1, keepdims=True
)
# Apply low-pass filter to velocity
a_d = self._smoothing_factor(self._d_cutoff, dt)
self._dx_filtered = self._exponential_smoothing(
a_d,
dx,
(
jnp.zeros_like(keypoints)
if self._dx_filtered is None
else self._dx_filtered
),
)
# Apply low-pass filter to position with adaptive cutoff
a_cutoff = self._smoothing_factor(jnp.asarray(cutoff), dt)
self._x_filtered = self._exponential_smoothing(
a_cutoff, keypoints, self._x_filtered
)
# Update timestamp
self._last_timestamp = timestamp
def get(self) -> TrackingPrediction:
"""
Get the current state of the filter.
Returns:
TrackingPrediction with velocity and keypoints
"""
return TrackingPrediction(
velocity=self._dx_filtered,
keypoints=self._x_filtered,
)
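Illustrative aside, not part of the commit: a minimal usage sketch with synthetic jitter. With `beta = 0` the filter degenerates to a fixed-cutoff exponential smoother.

from datetime import datetime, timedelta
import jax.numpy as jnp

t0 = datetime(2025, 5, 3, 15, 0, 0)
oef = OneEuroFilter(jnp.zeros((17, 3)), t0, min_cutoff=1.0, beta=0.0)
for i in range(1, 5):
    noisy = jnp.ones((17, 3)) + 0.01 * (-1.0) ** i  # jittery measurement around 1.0
    oef.update(noisy, t0 + timedelta(seconds=i / 30))  # ~30 fps
state = oef.get()
# the filtered keypoints converge toward the underlying value of 1.0
# while the alternating jitter is attenuated
print(state["keypoints"].mean(), state["velocity"])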
 @jaxtyped(typechecker=beartype)
 @dataclass(frozen=True)
-class Tracking:
-    id: int
-    """
-    The tracking id
-    """
+class TrackingState:
+    """
+    immutable state of a tracking
+    """
     keypoints: Float[Array, "J 3"]
     """
     The 3D keypoints of the tracking
@@ -41,50 +456,97 @@ class Tracking:
     The last active timestamp of the tracking
     """
-    historical_detections: PVector[Detection]
+    historical_detections_by_camera: PMap[CameraID, Detection]
     """
     Historical detections of the tracking.
     Used for 3D re-triangulation
     """
-    velocity: Optional[Float[Array, "3"]] = None
-    """
-    Could be `None`. Like when the 3D pose is initialized.
-    `velocity` should be updated when target association yields a new
-    3D pose.
-    """
+
+
+class Tracking:
+    id: TrackingID
+    state: TrackingState
+    velocity_filter: GenericVelocityFilter
+
+    def __init__(
+        self,
+        id: TrackingID,
+        state: TrackingState,
+        velocity_filter: Optional[GenericVelocityFilter] = None,
+    ):
+        self.id = id
+        self.state = state
+        self.velocity_filter = velocity_filter or DummyVelocityFilter(state.keypoints)

     def __repr__(self) -> str:
-        return f"Tracking({self.id}, {self.last_active_timestamp})"
+        return f"Tracking({self.id}, {self.state.last_active_timestamp})"
+
+    @overload
+    def predict(self, time: float) -> Float[Array, "J 3"]:
+        """
+        predict the keypoints at a given time
+
+        Args:
+            time: the time in seconds to predict the keypoints
+        Returns:
+            the predicted keypoints
+        """
+        ...  # pylint: disable=unnecessary-ellipsis
+
+    @overload
+    def predict(self, time: timedelta) -> Float[Array, "J 3"]:
+        """
+        predict the keypoints at a given time
+
+        Args:
+            time: the time delta to predict the keypoints
+        """
+        ...  # pylint: disable=unnecessary-ellipsis
+
+    @overload
+    def predict(self, time: datetime) -> Float[Array, "J 3"]:
+        """
+        predict the keypoints at a given time
+
+        Args:
+            time: the timestamp to predict the keypoints
+        """
+        ...  # pylint: disable=unnecessary-ellipsis

     def predict(
         self,
-        delta_t_s: float,
+        time: float | timedelta | datetime,
     ) -> Float[Array, "J 3"]:
-        """
-        Predict the 3D pose of a tracking based on its velocity.
-        JAX-friendly implementation that avoids Python control flow.
-
-        Args:
-            delta_t_s: Time delta in seconds
-        Returns:
-            Predicted 3D pose keypoints
-        """
-        # ------------------------------------------------------------------
-        # Step 1: decide velocity on the Python side
-        # ------------------------------------------------------------------
-        if self.velocity is None:
-            velocity = jnp.zeros_like(self.keypoints)  # (J, 3)
-        else:
-            velocity = self.velocity  # (J, 3)
-
-        # ------------------------------------------------------------------
-        # Step 2: pure JAX math
-        # ------------------------------------------------------------------
-        return self.keypoints + velocity * delta_t_s
+        if isinstance(time, timedelta):
+            timestamp = self.state.last_active_timestamp + time
+        elif isinstance(time, datetime):
+            timestamp = time
+        else:
+            timestamp = self.state.last_active_timestamp + timedelta(seconds=time)
+        # pylint: disable-next=unsubscriptable-object
+        return self.velocity_filter.predict(timestamp)["keypoints"]
+
+    def update(self, new_3d_pose: Float[Array, "J 3"], timestamp: datetime) -> None:
+        """
+        update the tracking with a new 3D pose
+
+        Note:
+            equivalent to calling `velocity_filter.update(new_3d_pose, timestamp)`
+        """
+        self.velocity_filter.update(new_3d_pose, timestamp)
+
+    @property
+    def velocity(self) -> Float[Array, "J 3"]:
+        """
+        The velocity of the tracking for each keypoint
+        """
+        # pylint: disable-next=unsubscriptable-object
+        if (vel := self.velocity_filter.get()["velocity"]) is None:
+            return jnp.zeros_like(self.state.keypoints)
+        else:
+            return vel
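Illustrative aside, not part of the commit: the three `predict` overloads accept plain seconds, a `timedelta`, or an absolute `datetime`, and all resolve to the same timestamp internally. The empty `PMap` here exists only to satisfy the constructor.

from datetime import datetime, timedelta
import jax.numpy as jnp
from pyrsistent import m

now = datetime(2025, 5, 2, 12, 0, 0)
state = TrackingState(
    keypoints=jnp.zeros((17, 3)),
    last_active_timestamp=now,
    historical_detections_by_camera=m(),  # no detections yet
)
trk = Tracking(id=1, state=state)  # no filter given: falls back to DummyVelocityFilter
assert jnp.allclose(trk.predict(0.5), trk.predict(timedelta(milliseconds=500)))
assert trk.predict(now + timedelta(seconds=0.5)).shape == (17, 3)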
 @jaxtyped(typechecker=beartype)
@@ -100,7 +562,7 @@ class AffinityResult:
     indices_T: Int[Array, "T"]  # pylint: disable=invalid-name
     indices_D: Int[Array, "D"]  # pylint: disable=invalid-name

-    def tracking_detections(
+    def tracking_association(
         self,
     ) -> Generator[tuple[float, Tracking, Detection], None, None]:
         """

View File

@@ -31,13 +31,13 @@ from typing import (
     TypeVar,
     cast,
     overload,
+    Iterable,
 )

 import awkward as ak
 import jax
 import jax.numpy as jnp
 import numpy as np
+import orjson
 from beartype import beartype
 from beartype.typing import Mapping, Sequence
 from cv2 import undistortPoints
@@ -46,9 +46,10 @@ from jaxtyping import Array, Float, Num, jaxtyped
 from matplotlib import pyplot as plt
 from numpy.typing import ArrayLike
 from optax.assignment import hungarian_algorithm as linear_sum_assignment
-from pyrsistent import v, pvector
+from pyrsistent import pvector, v, m, pmap, PMap, freeze, thaw
 from scipy.spatial.transform import Rotation as R
 from typing_extensions import deprecated
+from collections import defaultdict

 from app.camera import (
     Camera,
@@ -59,15 +60,22 @@ from app.camera import (
     classify_by_camera,
 )
 from app.solver._old import GLPKSolver
-from app.tracking import AffinityResult, Tracking
+from app.tracking import (
+    TrackingID,
+    AffinityResult,
+    LastDifferenceVelocityFilter,
+    LeastMeanSquareVelocityFilter,
+    Tracking,
+    TrackingState,
+)
 from app.visualize.whole_body import visualize_whole_body

 NDArray: TypeAlias = np.ndarray

 # %%
 DATASET_PATH = Path("samples") / "04_02"
-AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet")
+AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet")  # type: ignore
-DELTA_T_MIN = timedelta(milliseconds=10)
+DELTA_T_MIN = timedelta(milliseconds=1)
 display(AK_CAMERA_DATASET)
@@ -431,12 +439,12 @@ def triangulate_one_point_from_multiple_views_linear(
 ) -> Float[Array, "3"]:
     """
     Args:
-        proj_matrices: 形状为(N, 3, 4)的投影矩阵序列
-        points: 形状为(N, 2)的点坐标序列
-        confidences: 形状为(N,)的置信度序列,范围[0.0, 1.0]
+        proj_matrices: (N, 3, 4) projection matrices
+        points: (N, 2) image-coordinates per view
+        confidences: (N,) optional per-view confidences in [0,1]
     Returns:
-        point_3d: 形状为(3,)的三角测量得到的3D点
+        point_3d: (3,) 3D point
     """

     assert len(proj_matrices) == len(points)
@@ -462,7 +470,7 @@
     # replace the Python `if` with a jnp.where
     point_3d_homo = jnp.where(
-        point_3d_homo[3] < 0,  # predicate (scalar bool tracer)
+        point_3d_homo[3] <= 0,  # predicate (scalar bool tracer)
         -point_3d_homo,  # if True
         point_3d_homo,  # if False
     )
@@ -486,14 +494,14 @@ def triangulate_points_from_multiple_views_linear(
         confidences: (N, P, 1) optional per-view confidences in [0,1]
     Returns:
-        (P, 3) 3D point for each of the P tracks
+        (P, 3) 3D point for each of the P
     """
     N, P, _ = points.shape
     assert proj_matrices.shape[0] == N

     if confidences is None:
         conf = jnp.ones((N, P), dtype=jnp.float32)
     else:
-        conf = jnp.sqrt(jnp.clip(confidences, 0.0, 1.0))
+        conf = confidences

     # vectorize your one-point routine over P
     vmap_triangulate = jax.vmap(
@@ -504,6 +512,142 @@
     return vmap_triangulate(proj_matrices, points, conf)
# %%
@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear_time_weighted(
proj_matrices: Float[Array, "N 3 4"],
points: Num[Array, "N 2"],
delta_t: Num[Array, "N"],
lambda_t: float = 10.0,
confidences: Optional[Float[Array, "N"]] = None,
) -> Float[Array, "3"]:
"""
Triangulate one point from multiple views with time-weighted linear least squares.
Implements the incremental reconstruction method from "Cross-View Tracking for Multi-Human 3D Pose"
with weighting formula: w_i = exp(-λ_t (t - t_i)) / ||c_i^T||_2, where c_i^T is the i-th row of the coefficient matrix
Args:
proj_matrices: Shape (N, 3, 4) projection matrices sequence
points: Shape (N, 2) point coordinates sequence
delta_t: Time differences between current time and each observation (in seconds)
lambda_t: Time penalty rate (higher values decrease influence of older observations)
confidences: Shape (N,) confidence values in range [0.0, 1.0]
Returns:
point_3d: Shape (3,) triangulated 3D point
"""
assert len(proj_matrices) == len(points)
assert len(delta_t) == len(points)
N = len(proj_matrices)
# Prepare confidence weights
confi: Float[Array, "N"]
if confidences is None:
confi = jnp.ones(N, dtype=np.float32)
else:
confi = jnp.sqrt(jnp.clip(confidences, 0, 1))
A = jnp.zeros((N * 2, 4), dtype=np.float32)
# First build the coefficient matrix without weights
for i in range(N):
x, y = points[i]
A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])
A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])
# Then apply the time-based and confidence weights
for i in range(N):
# Calculate time-decay weight: e^(-λ_t * Δt)
time_weight = jnp.exp(-lambda_t * delta_t[i])
# Calculate normalization factor: ||c^i^T||_2
row_norm_1 = jnp.linalg.norm(A[2 * i])
row_norm_2 = jnp.linalg.norm(A[2 * i + 1])
# Apply combined weight: time_weight / row_norm * confidence
w1 = (time_weight / row_norm_1) * confi[i]
w2 = (time_weight / row_norm_2) * confi[i]
A = A.at[2 * i].mul(w1)
A = A.at[2 * i + 1].mul(w2)
# Solve using SVD
_, _, vh = jnp.linalg.svd(A, full_matrices=False)
point_3d_homo = vh[-1] # shape (4,)
# Ensure homogeneous coordinate is positive
point_3d_homo = jnp.where(
point_3d_homo[3] <= 0,
-point_3d_homo,
point_3d_homo,
)
# Convert from homogeneous to Euclidean coordinates
point_3d = point_3d_homo[:3] / point_3d_homo[3]
return point_3d
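Illustrative aside, not part of the commit: a small numeric check with two synthetic pinhole cameras. An exactly consistent pair of views must recover the 3D point regardless of the time weights, since scaling rows does not change the null space of the system.

import jax.numpy as jnp

# two canonical cameras: P1 at the origin, P2 shifted one unit along x
P1 = jnp.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]])
P2 = jnp.array([[1.0, 0.0, 0.0, -1.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]])
proj = jnp.stack([P1, P2])                    # (2, 3, 4)

X = jnp.array([0.5, 0.2, 2.0])                # ground-truth 3D point
pts = jnp.array([[0.25, 0.1], [-0.25, 0.1]])  # ideal projections of X
dts = jnp.array([0.0, 0.05])                  # second view is 50 ms older

X_hat = triangulate_one_point_from_multiple_views_linear_time_weighted(
    proj, pts, dts, lambda_t=10.0
)
assert jnp.allclose(X_hat, X, atol=1e-4)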
@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear_time_weighted(
proj_matrices: Float[Array, "N 3 4"],
points: Num[Array, "N P 2"],
delta_t: Num[Array, "N"],
lambda_t: float = 10.0,
confidences: Optional[Float[Array, "N P"]] = None,
) -> Float[Array, "P 3"]:
"""
Vectorized version that triangulates P points from N camera views with time-weighting.
This function uses JAX's vmap to efficiently triangulate multiple points in parallel.
Args:
proj_matrices: Shape (N, 3, 4) projection matrices for N cameras
points: Shape (N, P, 2) 2D points for P keypoints across N cameras
delta_t: Shape (N,) time differences between current time and each camera's timestamp (seconds)
lambda_t: Time penalty rate (higher values decrease influence of older observations)
confidences: Shape (N, P) confidence values for each point in each camera
Returns:
points_3d: Shape (P, 3) triangulated 3D points
"""
N, P, _ = points.shape
assert (
proj_matrices.shape[0] == N
), "Number of projection matrices must match number of cameras"
assert delta_t.shape[0] == N, "Number of time deltas must match number of cameras"
if confidences is None:
# Create uniform confidences if none provided
conf = jnp.ones((N, P), dtype=jnp.float32)
else:
conf = confidences
# Define the vmapped version of the single-point function
# We map over the second dimension (P points) of the input arrays
vmap_triangulate = jax.vmap(
triangulate_one_point_from_multiple_views_linear_time_weighted,
in_axes=(
None,
1,
None,
None,
1,
), # proj_matrices and delta_t static, map over points
out_axes=0, # Output has first dimension corresponding to points
)
# For each point p, extract the 2D coordinates from all cameras and triangulate
return vmap_triangulate(
proj_matrices, # (N, 3, 4) - static across points
points, # (N, P, 2) - map over dim 1 (P)
delta_t, # (N,) - static across points
lambda_t, # scalar - static
conf, # (N, P) - map over dim 1 (P)
)
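Continuing the same sketch, the vectorized variant maps the single-point routine over the P axis:

pts_batched = pts[:, None, :]  # (2, 1, 2): a single keypoint track across both views
X_hat_batched = triangulate_points_from_multiple_views_linear_time_weighted(
    proj, pts_batched, dts, lambda_t=10.0
)
assert X_hat_batched.shape == (1, 3)
assert jnp.allclose(X_hat_batched[0], X, atol=1e-4)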
 # %%

@@ -524,6 +668,23 @@

 # %%
def group_by_cluster_by_camera(
cluster: Sequence[Detection],
) -> PMap[CameraID, Detection]:
"""
group the detections by camera, and preserve the latest detection for each camera
"""
r: dict[CameraID, Detection] = {}
for el in cluster:
if el.camera.id in r:
eld = r[el.camera.id]
preserved = max([eld, el], key=lambda x: x.timestamp)
r[el.camera.id] = preserved
else:
r[el.camera.id] = el
return pmap(r)
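Illustrative aside, not part of the commit, using hypothetical minimal stand-ins for the real `app.camera` classes: when two detections share a camera, the later one wins.

from dataclasses import dataclass
from datetime import datetime, timedelta

@dataclass
class _Cam:  # hypothetical stand-in for app.camera.Camera
    id: int

@dataclass
class _Det:  # hypothetical stand-in for app.camera.Detection
    camera: _Cam
    timestamp: datetime

t0 = datetime(2025, 5, 3, 17, 0, 0)
cam = _Cam(id=7)
early, late = _Det(cam, t0), _Det(cam, t0 + timedelta(milliseconds=40))
grouped = group_by_cluster_by_camera([early, late])
assert grouped[7] is late  # the latest detection per camera is preserved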
 class GlobalTrackingState:
     _last_id: int
     _trackings: dict[int, Tracking]
@@ -542,13 +703,25 @@
         return shallow_copy(self._trackings)

     def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
+        if len(cluster) < 2:
+            raise ValueError(
+                "cluster must contain at least 2 detections to form a tracking"
+            )
         kps_3d, latest_timestamp = triangle_from_cluster(cluster)
         next_id = self._last_id + 1
-        tracking = Tracking(
-            id=next_id,
+        tracking_state = TrackingState(
             keypoints=kps_3d,
             last_active_timestamp=latest_timestamp,
-            historical_detections=v(*cluster),
+            historical_detections_by_camera=group_by_cluster_by_camera(cluster),
+        )
+        tracking = Tracking(
+            id=next_id,
+            state=tracking_state,
+            # velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
+            velocity_filter=LeastMeanSquareVelocityFilter(
+                historical_3d_poses=[kps_3d],
+                historical_timestamps=[latest_timestamp],
+            ),
         )
         self._trackings[next_id] = tracking
         self._last_id = next_id
@@ -671,11 +844,7 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
         Array of perpendicular distances for each keypoint
     """
     camera = detection.camera
-    # Use the delta_t supplied by the caller, but clamp to DELTA_T_MIN to
-    # avoid division-by-zero / exploding affinities.
-    delta_t = max(delta_t, DELTA_T_MIN)
-    delta_t_s = delta_t.total_seconds()
-    predicted_pose = tracking.predict(delta_t_s)
+    predicted_pose = tracking.predict(delta_t)

     # Back-project the 2D points to 3D space
     # intersection with z=0 plane
@@ -755,12 +924,12 @@
         Combined affinity score
     """
     camera = detection.camera
-    delta_t_raw = detection.timestamp - tracking.last_active_timestamp
+    delta_t_raw = detection.timestamp - tracking.state.last_active_timestamp
     # Clamp delta_t to avoid division-by-zero / exploding affinity.
     delta_t = max(delta_t_raw, DELTA_T_MIN)

     # Calculate 2D affinity
-    tracking_2d_projection = camera.project(tracking.keypoints)
+    tracking_2d_projection = camera.project(tracking.state.keypoints)
     w, h = camera.params.image_size
     distance_2d = calculate_distance_2d(
         tracking_2d_projection,
@@ -840,7 +1009,7 @@ def calculate_camera_affinity_matrix_jax(
     # === Tracking-side tensors ===
     kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
-        [trk.keypoints for trk in trackings]
+        [trk.state.keypoints for trk in trackings]
     )  # (T, J, 3)
     J = kps3d_trk.shape[1]

     # === Detection-side tensors ===
@@ -857,12 +1026,12 @@
     # --- timestamps ----------
     t0 = min(
         chain(
-            (trk.last_active_timestamp for trk in trackings),
+            (trk.state.last_active_timestamp for trk in trackings),
             (det.timestamp for det in camera_detections),
         )
     ).timestamp()  # common origin (float)
     ts_trk = jnp.array(
-        [trk.last_active_timestamp.timestamp() - t0 for trk in trackings],
+        [trk.state.last_active_timestamp.timestamp() - t0 for trk in trackings],
         dtype=jnp.float32,  # now small, ms-scale fits in fp32
     )
     ts_det = jnp.array(
@@ -1033,8 +1202,82 @@
 display(affinities)

 # %%
-def update_tracking(tracking: Tracking, detection: Detection):
-    delta_t_ = detection.timestamp - tracking.last_active_timestamp
-    delta_t = max(delta_t_, DELTA_T_MIN)
-
-    return tracking
+def affinity_result_by_tracking(
+    results: Iterable[AffinityResult],
+    min_affinity: float = 0.0,
+) -> dict[TrackingID, list[Detection]]:
+    """
+    Group affinity results by target ID.
+
+    Args:
+        results: the affinity results to group
+        min_affinity: the minimum affinity to consider
+    Returns:
+        a dictionary mapping tracking IDs to a list of detections
+    """
+    res: dict[TrackingID, list[Detection]] = defaultdict(list)
+    for affinity_result in results:
+        for affinity, t, d in affinity_result.tracking_association():
+            if affinity < min_affinity:
+                continue
+            res[t.id].append(d)
+    return res

def update_tracking(
tracking: Tracking,
detections: Sequence[Detection],
max_delta_t: timedelta = timedelta(milliseconds=100),
lambda_t: float = 10.0,
) -> None:
"""
update the tracking with a new set of detections
Args:
tracking: the tracking to update
detections: the detections to update the tracking with
max_delta_t: the maximum age of a retained detection relative to the newest detection
lambda_t: time-decay rate passed to the time-weighted triangulation
Note:
this function mutates the tracking object in place
"""
last_active_timestamp = tracking.state.last_active_timestamp
latest_timestamp = max(d.timestamp for d in detections)
d = thaw(tracking.state.historical_detections_by_camera)
for detection in detections:
    d[detection.camera.id] = detection
# drop detections older than `max_delta_t` relative to the newest one;
# iterate over a snapshot, since deleting from `d` while iterating
# `d.items()` directly would raise a RuntimeError
for camera_id, detection in list(d.items()):
    if latest_timestamp - detection.timestamp > max_delta_t:
        del d[camera_id]
new_detections = freeze(d)
new_detections_list = list(new_detections.values())
project_matrices = jnp.stack(
[detection.camera.params.projection_matrix for detection in new_detections_list]
)
# time elapsed from each retained detection to the newest one (t - t_i),
# so that older observations receive exponentially smaller weights
delta_t = jnp.array(
    [
        (latest_timestamp - detection.timestamp).total_seconds()
        for detection in new_detections_list
    ]
)
kps = jnp.stack([detection.keypoints for detection in new_detections_list])
conf = jnp.stack([detection.confidences for detection in new_detections_list])
kps_3d = triangulate_points_from_multiple_views_linear_time_weighted(
project_matrices, kps, delta_t, lambda_t, conf
)
new_state = TrackingState(
keypoints=kps_3d,
last_active_timestamp=latest_timestamp,
historical_detections_by_camera=new_detections,
)
tracking.update(kps_3d, latest_timestamp)
tracking.state = new_state
# %%
affinity_results_by_tracking = affinity_result_by_tracking(affinities.values())
for tracking_id, detections in affinity_results_by_tracking.items():
update_tracking(global_tracking_state.trackings[tracking_id], detections)
# %%