diff --git a/app/tracking/__init__.py b/app/tracking/__init__.py
index 7fb7ed0..b56c2a3 100644
--- a/app/tracking/__init__.py
+++ b/app/tracking/__init__.py
@@ -13,9 +13,9 @@ from typing import (
     TypeAlias,
     TypedDict,
     TypeVar,
+    Union,
     cast,
     overload,
-    Union,
 )
 
 import jax.numpy as jnp
@@ -23,9 +23,11 @@ from beartype import beartype
 from beartype.typing import Mapping, Sequence
 from jax import Array
 from jaxtyping import Array, Float, Int, jaxtyped
-from pyrsistent import PVector, v
+from pyrsistent import PVector, v, PRecord, PMap
 
-from app.camera import Detection
+from app.camera import Detection, CameraID
+
+TrackingID: TypeAlias = int
 
 
 class TrackingPrediction(TypedDict):
@@ -440,7 +442,7 @@ class TrackingState:
     The last active timestamp of the tracking
     """
 
-    historical_detections: PVector[Detection]
+    historical_detections_by_camera: PMap[CameraID, Detection]
     """
-    Historical detections of the tracking.
+    The latest historical detection for each camera of the tracking.
 
@@ -449,13 +451,13 @@ class TrackingState:
 
 
 class Tracking:
-    id: int
+    id: TrackingID
     state: TrackingState
     velocity_filter: GenericVelocityFilter
 
     def __init__(
         self,
-        id: int,
+        id: TrackingID,
         state: TrackingState,
         velocity_filter: Optional[GenericVelocityFilter] = None,
     ):
@@ -512,6 +514,15 @@ class Tracking:
         # pylint: disable-next=unsubscriptable-object
         return self.velocity_filter.predict(timestamp)["keypoints"]
 
+    def update(self, new_3d_pose: Float[Array, "J 3"], timestamp: datetime) -> None:
+        """
+        Update the tracking with a new 3D pose.
+
+        Note:
+            Equivalent to calling `velocity_filter.update(new_3d_pose, timestamp)`.
+        """
+        self.velocity_filter.update(new_3d_pose, timestamp)
+
     @property
     def velocity(self) -> Float[Array, "J 3"]:
         """
@@ -537,7 +548,7 @@ class AffinityResult:
     indices_T: Int[Array, "T"]  # pylint: disable=invalid-name
     indices_D: Int[Array, "D"]  # pylint: disable=invalid-name
 
-    def tracking_detections(
+    def tracking_association(
         self,
     ) -> Generator[tuple[float, Tracking, Detection], None, None]:
         """
diff --git a/playground.py b/playground.py
index 32eba46..ca6fb4b 100644
--- a/playground.py
+++ b/playground.py
@@ -31,6 +31,7 @@ from typing import (
+    Iterable,
     TypeVar,
     cast,
     overload,
 )
 
 import awkward as ak
@@ -45,9 +46,10 @@ from jaxtyping import Array, Float, Num, jaxtyped
 from matplotlib import pyplot as plt
 from numpy.typing import ArrayLike
 from optax.assignment import hungarian_algorithm as linear_sum_assignment
-from pyrsistent import pvector, v
+from pyrsistent import pvector, v, m, pmap, PMap, freeze, thaw
 from scipy.spatial.transform import Rotation as R
 from typing_extensions import deprecated
+from collections import defaultdict
 
 from app.camera import (
     Camera,
@@ -59,6 +61,7 @@ from app.camera import (
 )
 from app.solver._old import GLPKSolver
 from app.tracking import (
     AffinityResult,
     LastDifferenceVelocityFilter,
     Tracking,
+    TrackingID,
@@ -508,6 +511,142 @@ def triangulate_points_from_multiple_views_linear(
     return vmap_triangulate(proj_matrices, points, conf)
 
 
+# %%
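+# A quick numeric feel for the time-decay weight used by the functions below
+# (an illustrative sketch, not part of the original change): with lambda_t = 10,
+# an observation that is 100 ms old keeps exp(-1) ~ 0.37 of a fresh
+# observation's weight, and a 300 ms old one only exp(-3) ~ 0.05.
+display(jnp.exp(-10.0 * jnp.array([0.0, 0.1, 0.3])))
+
+
+# %%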
+@jaxtyped(typechecker=beartype)
+def triangulate_one_point_from_multiple_views_linear_time_weighted(
+    proj_matrices: Float[Array, "N 3 4"],
+    points: Num[Array, "N 2"],
+    delta_t: Num[Array, "N"],
+    lambda_t: float = 10.0,
+    confidences: Optional[Float[Array, "N"]] = None,
+) -> Float[Array, "3"]:
+    """
+    Triangulate one point from multiple views with time-weighted linear least squares.
+
+    Implements the incremental reconstruction method from "Cross-View Tracking
+    for Multi-Human 3D Pose Estimation at over 100 FPS" (Chen et al., CVPR 2020):
+    each row c_i of the linear system is scaled by
+
+        w_i = exp(-lambda_t * (t - t_i)) / ||c_i||_2
+
+    (further multiplied by the detection confidence), so stale observations
+    contribute less to the solution.
+
+    Args:
+        proj_matrices: Shape (N, 3, 4) sequence of projection matrices
+        points: Shape (N, 2) sequence of 2D point observations
+        delta_t: Shape (N,) time differences between the current time and each
+            observation (in seconds)
+        lambda_t: Time penalty rate (higher values decrease the influence of
+            older observations)
+        confidences: Shape (N,) confidence values in the range [0.0, 1.0]
+
+    Returns:
+        point_3d: Shape (3,) triangulated 3D point
+    """
+    assert len(proj_matrices) == len(points)
+    assert len(delta_t) == len(points)
+
+    N = len(proj_matrices)
+
+    # Prepare confidence weights
+    confi: Float[Array, "N"]
+    if confidences is None:
+        confi = jnp.ones(N, dtype=jnp.float32)
+    else:
+        confi = jnp.sqrt(jnp.clip(confidences, 0, 1))
+
+    A = jnp.zeros((N * 2, 4), dtype=jnp.float32)
+
+    # First build the coefficient matrix without weights
+    for i in range(N):
+        x, y = points[i]
+        A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])
+        A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])
+
+    # Then apply the time-based and confidence weights; the row norms must be
+    # taken on the unweighted rows, hence the second pass
+    for i in range(N):
+        # Time-decay weight: exp(-lambda_t * delta_t)
+        time_weight = jnp.exp(-lambda_t * delta_t[i])
+
+        # Normalization factors: L2 norms of the two unweighted rows
+        row_norm_1 = jnp.linalg.norm(A[2 * i])
+        row_norm_2 = jnp.linalg.norm(A[2 * i + 1])
+
+        # Combined weight: time_weight / row_norm * confidence
+        w1 = (time_weight / row_norm_1) * confi[i]
+        w2 = (time_weight / row_norm_2) * confi[i]
+
+        A = A.at[2 * i].mul(w1)
+        A = A.at[2 * i + 1].mul(w2)
+
+    # Solve using SVD: the solution is the right singular vector with the
+    # smallest singular value
+    _, _, vh = jnp.linalg.svd(A, full_matrices=False)
+    point_3d_homo = vh[-1]  # shape (4,)
+
+    # Ensure the homogeneous coordinate is positive
+    point_3d_homo = jnp.where(
+        point_3d_homo[3] < 0,
+        -point_3d_homo,
+        point_3d_homo,
+    )
+
+    # Convert from homogeneous to Euclidean coordinates
+    point_3d = point_3d_homo[:3] / point_3d_homo[3]
+    return point_3d
+
+
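+# %%
+# Illustrative round-trip check for the single-point routine (a sketch; the
+# camera matrices and ground-truth point are made up, not part of the original
+# change). Two ideal pinhole views observe the same 3D point, with the second
+# view 50 ms older; the triangulation should still recover the point.
+_P1 = jnp.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]])
+_P2 = jnp.array([[0.0, 0.0, 1.0, -2.0], [0.0, 1.0, 0.0, 0.0], [-1.0, 0.0, 0.0, 4.0]])
+_X = jnp.array([1.0, 0.5, 3.0, 1.0])  # homogeneous ground truth
+_x1 = (_P1 @ _X)[:2] / (_P1 @ _X)[2]
+_x2 = (_P2 @ _X)[:2] / (_P2 @ _X)[2]
+display(
+    triangulate_one_point_from_multiple_views_linear_time_weighted(
+        jnp.stack([_P1, _P2]),
+        jnp.stack([_x1, _x2]),
+        jnp.array([0.0, 0.05]),
+    )
+)  # expected to be close to (1.0, 0.5, 3.0)
+
+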
+@jaxtyped(typechecker=beartype)
+def triangulate_points_from_multiple_views_linear_time_weighted(
+    proj_matrices: Float[Array, "N 3 4"],
+    points: Num[Array, "N P 2"],
+    delta_t: Num[Array, "N"],
+    lambda_t: float = 10.0,
+    confidences: Optional[Float[Array, "N P"]] = None,
+) -> Float[Array, "P 3"]:
+    """
+    Vectorized version that triangulates P points from N camera views with
+    time weighting.
+
+    Uses jax.vmap to triangulate all P points in parallel.
+
+    Args:
+        proj_matrices: Shape (N, 3, 4) projection matrices for the N cameras
+        points: Shape (N, P, 2) 2D points for P keypoints across N cameras
+        delta_t: Shape (N,) time differences between the current time and each
+            camera's observation timestamp (in seconds)
+        lambda_t: Time penalty rate (higher values decrease the influence of
+            older observations)
+        confidences: Shape (N, P) confidence values for each point in each camera
+
+    Returns:
+        points_3d: Shape (P, 3) triangulated 3D points
+    """
+    N, P, _ = points.shape
+    assert (
+        proj_matrices.shape[0] == N
+    ), "Number of projection matrices must match number of cameras"
+    assert delta_t.shape[0] == N, "Number of time deltas must match number of cameras"
+
+    if confidences is None:
+        # Create uniform confidences if none provided
+        conf = jnp.ones((N, P), dtype=jnp.float32)
+    else:
+        conf = confidences
+
+    # vmap the single-point function over the P axis of `points` and `conf`;
+    # proj_matrices, delta_t and lambda_t are shared across all points
+    vmap_triangulate = jax.vmap(
+        triangulate_one_point_from_multiple_views_linear_time_weighted,
+        in_axes=(None, 1, None, None, 1),
+        out_axes=0,  # first output axis corresponds to the P points
+    )
+
+    return vmap_triangulate(
+        proj_matrices,  # (N, 3, 4) - shared across points
+        points,  # (N, P, 2) - mapped over axis 1 (P)
+        delta_t,  # (N,) - shared across points
+        lambda_t,  # scalar - shared
+        conf,  # (N, P) - mapped over axis 1 (P)
+    )
+
+
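+# %%
+# Sketch of the time weighting at work on the vectorized routine (synthetic
+# data, not part of the original change; reuses the cameras from the cell
+# above): the older view's observation is corrupted, so with a large lambda_t
+# it is heavily down-weighted and the estimate should stay closer to the
+# ground truth than with lambda_t = 0.
+_pts = jnp.stack([_x1, _x2 + 0.2])[:, None, :]  # (N=2, P=1, 2), older view corrupted
+_dt = jnp.array([0.0, 0.2])  # the corrupted view is 200 ms old
+for _lt in (0.0, 10.0):
+    display(
+        triangulate_points_from_multiple_views_linear_time_weighted(
+            jnp.stack([_P1, _P2]), _pts, _dt, lambda_t=_lt
+        )
+    )
+
+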
 # %%
@@ -528,6 +667,21 @@ def triangle_from_cluster(
 
 
 # %%
+def group_by_cluster_by_camera(
+    cluster: Sequence[Detection],
+) -> PMap[CameraID, Detection]:
+    """
+    Group the detections by camera, keeping only the latest detection for each camera.
+    """
+    r: dict[CameraID, Detection] = {}
+    for el in cluster:
+        existing = r.get(el.camera.id)
+        # keep the most recent detection seen so far for this camera
+        if existing is None or el.timestamp > existing.timestamp:
+            r[el.camera.id] = el
+    return pmap(r)
+
+
 class GlobalTrackingState:
     _last_id: int
     _trackings: dict[int, Tracking]
@@ -546,12 +700,16 @@ class GlobalTrackingState:
         return shallow_copy(self._trackings)
 
     def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
+        if len(cluster) < 2:
+            raise ValueError(
+                "cluster must contain at least 2 detections to form a tracking"
+            )
         kps_3d, latest_timestamp = triangle_from_cluster(cluster)
         next_id = self._last_id + 1
         tracking_state = TrackingState(
             keypoints=kps_3d,
             last_active_timestamp=latest_timestamp,
-            historical_detections=v(*cluster),
+            historical_detections_by_camera=group_by_cluster_by_camera(cluster),
         )
         tracking = Tracking(
             id=next_id,
@@ -679,9 +837,7 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
         Array of perpendicular distances for each keypoint
     """
     camera = detection.camera
-    # Use the delta_t supplied by the caller, but clamp to DELTA_T_MIN to
-    # avoid division-by-zero / exploding affinities.
-    predicted_pose = tracking.predict(max(delta_t, DELTA_T_MIN))
+    predicted_pose = tracking.predict(delta_t)
 
     # Back-project the 2D points to 3D space
     # intersection with z=0 plane
@@ -1039,6 +1195,73 @@
 display(affinities)
 
 # %%
-def update_tracking(tracking: Tracking, detection: Detection):
-    delta_t_ = detection.timestamp - tracking.state.last_active_timestamp
-    raise NotImplementedError
+def affinity_result_by_tracking(
+    results: Iterable[AffinityResult],
+) -> dict[TrackingID, list[Detection]]:
+    """
+    Group affinity results by tracking ID.
+    """
+    res: dict[TrackingID, list[Detection]] = defaultdict(list)
+    for affinity_result in results:
+        for _affinity, t, d in affinity_result.tracking_association():
+            res[t.id].append(d)
+    return res
+
+
+def update_tracking(
+    tracking: Tracking,
+    detections: Sequence[Detection],
+    max_delta_t: timedelta = timedelta(milliseconds=100),
+    lambda_t: float = 10.0,
+) -> None:
+    """
+    Update the tracking with a new set of associated detections.
+
+    Args:
+        tracking: the tracking to update
+        detections: the detections to update the tracking with
+        max_delta_t: detections older than this, relative to the newest
+            detection, are dropped from the tracking's history
+        lambda_t: time-decay rate passed to the weighted triangulation
+
+    Note:
+        this function mutates the tracking object in place
+    """
+    latest_timestamp = max(d.timestamp for d in detections)
+    d = thaw(tracking.state.historical_detections_by_camera)
+    for detection in detections:
+        d[detection.camera.id] = detection
+    # Drop detections that are too stale relative to the newest one;
+    # iterate over a snapshot so we can delete while looping
+    for camera_id, detection in list(d.items()):
+        if latest_timestamp - detection.timestamp > max_delta_t:
+            del d[camera_id]
+    new_detections = freeze(d)
+    new_detections_list = list(new_detections.values())
+    proj_matrices = jnp.stack(
+        [detection.camera.params.projection_matrix for detection in new_detections_list]
+    )
+    # Age of each observation relative to the newest detection (seconds, >= 0),
+    # matching the delta_t convention of the time-weighted triangulation
+    delta_t = jnp.array(
+        [
+            (latest_timestamp - detection.timestamp).total_seconds()
+            for detection in new_detections_list
+        ]
+    )
+    kps = jnp.stack([detection.keypoints for detection in new_detections_list])
+    conf = jnp.stack([detection.confidences for detection in new_detections_list])
+    kps_3d = triangulate_points_from_multiple_views_linear_time_weighted(
+        proj_matrices, kps, delta_t, lambda_t, conf
+    )
+    new_state = TrackingState(
+        keypoints=kps_3d,
+        last_active_timestamp=latest_timestamp,
+        historical_detections_by_camera=new_detections,
+    )
+    tracking.update(kps_3d, latest_timestamp)
+    tracking.state = new_state
+
+
+# %%
+affinity_results_by_tracking = affinity_result_by_tracking(affinities.values())
+for tracking_id, detections in affinity_results_by_tracking.items():
+    update_tracking(global_tracking_state.trackings[tracking_id], detections)
+
+# %%
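+# Inspect the refreshed trackings (an illustrative follow-up, assuming the
+# association and update cells above have run): each updated tracking should
+# now carry its newest associated detection's timestamp in
+# `state.last_active_timestamp`.
+for _tid, _t in global_tracking_state.trackings.items():
+    display((_tid, _t.state.last_active_timestamp))
+
+# %%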