# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.17.0
#   kernelspec:
#     display_name: .venv
#     language: python
#     name: python3
# ---

# %%
from collections import OrderedDict, defaultdict
from copy import copy as shallow_copy
from copy import deepcopy as deep_copy
from dataclasses import dataclass
from datetime import datetime, timedelta
from functools import partial, reduce
from itertools import chain
from pathlib import Path
from typing import (
    Any,
    Generator,
    Iterable,
    Optional,
    TypeAlias,
    TypedDict,
    TypeVar,
    cast,
    overload,
)

import awkward as ak
import jax
import jax.numpy as jnp
import numpy as np
from beartype import beartype
from beartype.typing import Mapping, Sequence
from cv2 import undistortPoints
from IPython.display import display
from jaxtyping import Array, Float, Num, jaxtyped
from matplotlib import pyplot as plt
from numpy.typing import ArrayLike
from optax.assignment import hungarian_algorithm as linear_sum_assignment
from pyrsistent import pvector, v, m, pmap, PMap, freeze, thaw
from scipy.spatial.transform import Rotation as R
from typing_extensions import deprecated

from app.camera import (
    Camera,
    CameraID,
    CameraParams,
    Detection,
    calculate_affinity_matrix_by_epipolar_constraint,
    classify_by_camera,
)
from app.solver._old import GLPKSolver
from app.tracking import (
    TrackingID,
    AffinityResult,
    LastDifferenceVelocityFilter,
    Tracking,
    TrackingState,
)
from app.visualize.whole_body import visualize_whole_body

NDArray: TypeAlias = np.ndarray

# %%
DATASET_PATH = Path("samples") / "04_02"
AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet")  # type: ignore
DELTA_T_MIN = timedelta(milliseconds=1)

display(AK_CAMERA_DATASET)


# %%
class Resolution(TypedDict):
    width: int
    height: int


class Intrinsic(TypedDict):
    camera_matrix: Num[Array, "3 3"]
    """
    K
    """
    distortion_coefficients: Num[Array, "N"]
    """
    distortion coefficients; usually 5
    """


class Extrinsic(TypedDict):
    rvec: Num[NDArray, "3"]
    tvec: Num[NDArray, "3"]


class ExternalCameraParams(TypedDict):
    name: str
    port: int
    intrinsic: Intrinsic
    extrinsic: Extrinsic
    resolution: Resolution


# %%
def read_dataset_by_port(port: int) -> ak.Array:
    P = DATASET_PATH / f"{port}.parquet"
    return ak.from_parquet(P)


KEYPOINT_DATASET = {
    int(p): read_dataset_by_port(p) for p in ak.to_numpy(AK_CAMERA_DATASET["port"])
}


# %%
class KeypointDataset(TypedDict):
    frame_index: int
    boxes: Num[NDArray, "N 4"]
    kps: Num[NDArray, "N J 2"]
    kps_scores: Num[NDArray, "N J"]


@jaxtyped(typechecker=beartype)
def to_transformation_matrix(
    rvec: Num[NDArray, "3"], tvec: Num[NDArray, "3"]
) -> Num[NDArray, "4 4"]:
    res = np.eye(4)
    res[:3, :3] = R.from_rotvec(rvec).as_matrix()
    res[:3, 3] = tvec
    return res


@jaxtyped(typechecker=beartype)
def undistort_points(
    points: Num[NDArray, "M 2"],
    camera_matrix: Num[NDArray, "3 3"],
    dist_coeffs: Num[NDArray, "N"],
) -> Num[NDArray, "M 2"]:
    K = camera_matrix
    dist = dist_coeffs
    res = undistortPoints(points, K, dist, P=K)  # type: ignore
    return res.reshape(-1, 2)
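
# A tiny illustrative check (not part of the original pipeline): a zero rotation
# vector with translation (1, 2, 3) should yield an identity rotation block and
# the translation in the last column of the 4x4 matrix.
display(to_transformation_matrix(np.zeros(3), np.array([1.0, 2.0, 3.0])))
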
def from_camera_params(camera: ExternalCameraParams) -> Camera:
    rt = jnp.array(
        to_transformation_matrix(
            ak.to_numpy(camera["extrinsic"]["rvec"]),
            ak.to_numpy(camera["extrinsic"]["tvec"]),
        )
    )
    K = jnp.array(camera["intrinsic"]["camera_matrix"]).reshape(3, 3)
    dist_coeffs = jnp.array(camera["intrinsic"]["distortion_coefficients"])
    image_size = jnp.array(
        (camera["resolution"]["width"], camera["resolution"]["height"])
    )
    return Camera(
        id=camera["name"],
        params=CameraParams(
            K=K,
            Rt=rt,
            dist_coeffs=dist_coeffs,
            image_size=image_size,
        ),
    )


def preprocess_keypoint_dataset(
    dataset: Sequence[KeypointDataset],
    camera: Camera,
    fps: float,
    start_timestamp: datetime,
) -> Generator[Detection, None, None]:
    frame_interval_s = 1 / fps
    for el in dataset:
        frame_index = el["frame_index"]
        timestamp = start_timestamp + timedelta(seconds=frame_index * frame_interval_s)
        for kp, kp_score in zip(el["kps"], el["kps_scores"]):
            yield Detection(
                keypoints=jnp.array(kp),
                confidences=jnp.array(kp_score),
                camera=camera,
                timestamp=timestamp,
            )


# %%
DetectionGenerator: TypeAlias = Generator[Detection, None, None]


def sync_batch_gen(gens: Sequence[DetectionGenerator], diff: timedelta):
    """
    Given a list of detection generators, return a generator that yields
    batches of detections whose timestamps are close to each other.

    Args:
        gens: list of detection generators
        diff: maximum timestamp difference between detections to consider them
            part of the same batch
    """
    N = len(gens)
    last_batch_timestamp: Optional[datetime] = None
    next_batch_timestamp: Optional[datetime] = None
    current_batch: list[Detection] = []
    next_batch: list[Detection] = []
    paused: list[bool] = [False] * N
    finished: list[bool] = [False] * N

    def reset_paused():
        """
        Reset the paused list based on the finished list.
        """
        for i in range(N):
            if not finished[i]:
                paused[i] = False
            else:
                paused[i] = True

    EPS = 1e-6  # a small epsilon to avoid floating point precision issues
    diff_eps = diff - timedelta(seconds=EPS)
    while True:
        for i, gen in enumerate(gens):
            try:
                if finished[i] or paused[i]:
                    continue
                val = next(gen)
                if last_batch_timestamp is None:
                    last_batch_timestamp = val.timestamp
                    current_batch.append(val)
                else:
                    if abs(val.timestamp - last_batch_timestamp) >= diff_eps:
                        next_batch.append(val)
                        if next_batch_timestamp is None:
                            next_batch_timestamp = val.timestamp
                        paused[i] = True
                        if all(paused):
                            yield current_batch
                            current_batch = next_batch
                            next_batch = []
                            last_batch_timestamp = next_batch_timestamp
                            next_batch_timestamp = None
                            reset_paused()
                    else:
                        current_batch.append(val)
            except StopIteration:
                finished[i] = True
                paused[i] = True
                if all(finished):
                    if len(current_batch) > 0:
                        # All generators exhausted, flush remaining batch and exit
                        yield current_batch
                    return


# %%
@overload
def to_projection_matrix(
    transformation_matrix: Num[NDArray, "4 4"], camera_matrix: Num[NDArray, "3 3"]
) -> Num[NDArray, "3 4"]: ...
@overload
def to_projection_matrix(
    transformation_matrix: Num[Array, "4 4"], camera_matrix: Num[Array, "3 3"]
) -> Num[Array, "3 4"]: ...
@jaxtyped(typechecker=beartype)
def to_projection_matrix(
    transformation_matrix: Num[Any, "4 4"],
    camera_matrix: Num[Any, "3 3"],
) -> Num[Any, "3 4"]:
    return camera_matrix @ transformation_matrix[:3, :]


to_projection_matrix_jit = jax.jit(to_projection_matrix)


@jaxtyped(typechecker=beartype)
def dlt(
    H1: Num[NDArray, "3 4"],
    H2: Num[NDArray, "3 4"],
    p1: Num[NDArray, "2"],
    p2: Num[NDArray, "2"],
) -> Num[NDArray, "3"]:
    """
    Direct Linear Transformation: triangulate a single 3D point from two views,
    given their projection matrices and the corresponding image points.
    """
    A = [
        p1[1] * H1[2, :] - H1[1, :],
        H1[0, :] - p1[0] * H1[2, :],
        p2[1] * H2[2, :] - H2[1, :],
        H2[0, :] - p2[0] * H2[2, :],
    ]
    A = np.array(A).reshape((4, 4))
    B = A.transpose() @ A

    from scipy import linalg

    U, s, Vh = linalg.svd(B, full_matrices=False)
    return Vh[3, 0:3] / Vh[3, 3]


@overload
def homogeneous_to_euclidean(points: Num[NDArray, "N 4"]) -> Num[NDArray, "N 3"]: ...
@overload
def homogeneous_to_euclidean(points: Num[Array, "N 4"]) -> Num[Array, "N 3"]: ...
@jaxtyped(typechecker=beartype)
def homogeneous_to_euclidean(
    points: Num[Any, "N 4"],
) -> Num[Any, "N 3"]:
    """
    Convert homogeneous coordinates to Euclidean coordinates.

    Args:
        points: homogeneous coordinates (x, y, z, w) in a numpy array or jax array

    Returns:
        euclidean coordinates (x, y, z) in a numpy array or jax array
    """
    return points[..., :-1] / points[..., -1:]
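
# %%
# A minimal sanity check for `dlt` on synthetic data (not part of the original
# pipeline; all numbers below are made up): project a known 3D point through
# two synthetic cameras and confirm the triangulation recovers it.
_demo_K = np.array([[800.0, 0.0, 320.0], [0.0, 800.0, 240.0], [0.0, 0.0, 1.0]])
_demo_H1 = to_projection_matrix(
    to_transformation_matrix(np.zeros(3), np.array([0.0, 0.0, 5.0])), _demo_K
)
_demo_H2 = to_projection_matrix(
    to_transformation_matrix(np.array([0.0, 0.3, 0.0]), np.array([-1.0, 0.0, 5.0])),
    _demo_K,
)
_demo_X = np.array([0.2, -0.1, 1.0])
_demo_p1 = _demo_H1 @ np.append(_demo_X, 1.0)
_demo_p1 = _demo_p1[:2] / _demo_p1[2]
_demo_p2 = _demo_H2 @ np.append(_demo_X, 1.0)
_demo_p2 = _demo_p2[:2] / _demo_p2[2]
display(dlt(_demo_H1, _demo_H2, _demo_p1, _demo_p2))  # expected ≈ [0.2, -0.1, 1.0]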

# %%
FPS = 24

image_gen_5600 = preprocess_keypoint_dataset(
    KEYPOINT_DATASET[5600],
    from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5600][0]),  # type: ignore
    FPS,
    datetime(2024, 4, 2, 12, 0, 0),
)
image_gen_5601 = preprocess_keypoint_dataset(
    KEYPOINT_DATASET[5601],
    from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5601][0]),  # type: ignore
    FPS,
    datetime(2024, 4, 2, 12, 0, 0),
)
image_gen_5602 = preprocess_keypoint_dataset(
    KEYPOINT_DATASET[5602],
    from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5602][0]),  # type: ignore
    FPS,
    datetime(2024, 4, 2, 12, 0, 0),
)
display(1 / FPS)

sync_gen = sync_batch_gen(
    [image_gen_5600, image_gen_5601, image_gen_5602], timedelta(seconds=1 / FPS)
)

# %%
sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint(
    next(sync_gen), alpha_2d=2000
)
display(sorted_detections)

# %%
display(
    list(
        map(
            lambda x: {"timestamp": str(x.timestamp), "camera": x.camera.id},
            sorted_detections,
        )
    )
)
with jnp.printoptions(precision=3, suppress=True):
    display(affinity_matrix)


# %%
def clusters_to_detections(
    clusters: Sequence[Sequence[int]], sorted_detections: Sequence[Detection]
) -> list[list[Detection]]:
    """
    Given a list of clusters (each cluster being a list of indices into the
    `sorted_detections` list), extract the corresponding detections.

    Args:
        clusters: list of clusters, each cluster is a list of indices into the
            `sorted_detections` list
        sorted_detections: list of SORTED detections

    Returns:
        list of clusters, each cluster is a list of detections
    """
    return [[sorted_detections[i] for i in cluster] for cluster in clusters]


solver = GLPKSolver()
aff_np = np.asarray(affinity_matrix).astype(np.float64)
clusters, sol_matrix = solver.solve(aff_np)
display(clusters)
display(sol_matrix)
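
# %%
# Illustrative peek at the solver output (not part of the original pipeline):
# each cluster is a list of indices into `sorted_detections`, so we can list,
# per cluster, which cameras contributed a detection.
display(
    [
        {
            "cluster": ci,
            "size": len(cluster),
            "cameras": [sorted_detections[i].camera.id for i in cluster],
        }
        for ci, cluster in enumerate(clusters)
    ]
)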
""" val = reduce(lambda acc, xs: acc + len(xs), d.values(), 0) return val # %% WIDTH = 2560 HEIGHT = 1440 clusters_detections = clusters_to_detections(clusters, sorted_detections) im = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8) for el in clusters_detections[0]: im = visualize_whole_body(np.asarray(el.keypoints), im) p = plt.imshow(im) display(p) # %% im_prime = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8) for el in clusters_detections[1]: im_prime = visualize_whole_body(np.asarray(el.keypoints), im_prime) p_prime = plt.imshow(im_prime) display(p_prime) # %% @jaxtyped(typechecker=beartype) def triangulate_one_point_from_multiple_views_linear( proj_matrices: Float[Array, "N 3 4"], points: Num[Array, "N 2"], confidences: Optional[Float[Array, "N"]] = None, ) -> Float[Array, "3"]: """ Args: proj_matrices: 形状为(N, 3, 4)的投影矩阵序列 points: 形状为(N, 2)的点坐标序列 confidences: 形状为(N,)的置信度序列,范围[0.0, 1.0] Returns: point_3d: 形状为(3,)的三角测量得到的3D点 """ assert len(proj_matrices) == len(points) N = len(proj_matrices) confi: Float[Array, "N"] if confidences is None: confi = jnp.ones(N, dtype=np.float32) else: # Use square root of confidences for weighting - more balanced approach confi = jnp.sqrt(jnp.clip(confidences, 0, 1)) A = jnp.zeros((N * 2, 4), dtype=np.float32) for i in range(N): x, y = points[i] A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0]) A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1]) A = A.at[2 * i].mul(confi[i]) A = A.at[2 * i + 1].mul(confi[i]) # https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.svd.html _, _, vh = jnp.linalg.svd(A, full_matrices=False) point_3d_homo = vh[-1] # shape (4,) # replace the Python `if` with a jnp.where point_3d_homo = jnp.where( point_3d_homo[3] < 0, # predicate (scalar bool tracer) -point_3d_homo, # if True point_3d_homo, # if False ) point_3d = point_3d_homo[:3] / point_3d_homo[3] return point_3d @jaxtyped(typechecker=beartype) def triangulate_points_from_multiple_views_linear( proj_matrices: Float[Array, "N 3 4"], points: Num[Array, "N P 2"], confidences: Optional[Float[Array, "N P"]] = None, ) -> Float[Array, "P 3"]: """ Batch-triangulate P points observed by N cameras, linearly via SVD. Args: proj_matrices: (N, 3, 4) projection matrices points: (N, P, 2) image-coordinates per view confidences: (N, P, 1) optional per-view confidences in [0,1] Returns: (P, 3) 3D point for each of the P tracks """ N, P, _ = points.shape assert proj_matrices.shape[0] == N if confidences is None: conf = jnp.ones((N, P), dtype=jnp.float32) else: conf = jnp.sqrt(jnp.clip(confidences, 0.0, 1.0)) # vectorize your one-point routine over P vmap_triangulate = jax.vmap( triangulate_one_point_from_multiple_views_linear, in_axes=(None, 1, 1), # proj_matrices static, map over points[:,p,:], conf[:,p] out_axes=0, ) return vmap_triangulate(proj_matrices, points, conf) # %% @jaxtyped(typechecker=beartype) def triangulate_one_point_from_multiple_views_linear_time_weighted( proj_matrices: Float[Array, "N 3 4"], points: Num[Array, "N 2"], delta_t: Num[Array, "N"], lambda_t: float = 10.0, confidences: Optional[Float[Array, "N"]] = None, ) -> Float[Array, "3"]: """ Triangulate one point from multiple views with time-weighted linear least squares. 

# %%
@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear_time_weighted(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N 2"],
    delta_t: Num[Array, "N"],
    lambda_t: float = 10.0,
    confidences: Optional[Float[Array, "N"]] = None,
) -> Float[Array, "3"]:
    """
    Triangulate one point from multiple views with time-weighted linear least squares.

    Implements the incremental reconstruction method from "Cross-View Tracking for
    Multi-Human 3D Pose" with the weighting formula:

        w_i = exp(-λ_t (t - t_i)) / ||c^i^T||_2

    Args:
        proj_matrices: projection matrices of shape (N, 3, 4)
        points: point coordinates of shape (N, 2)
        delta_t: time differences between the current time and each observation (in seconds)
        lambda_t: time penalty rate (higher values decrease the influence of older observations)
        confidences: confidence values of shape (N,) in the range [0.0, 1.0]

    Returns:
        point_3d: the triangulated 3D point of shape (3,)
    """
    assert len(proj_matrices) == len(points)
    assert len(delta_t) == len(points)
    N = len(proj_matrices)

    # Prepare confidence weights
    confi: Float[Array, "N"]
    if confidences is None:
        confi = jnp.ones(N, dtype=np.float32)
    else:
        confi = jnp.sqrt(jnp.clip(confidences, 0, 1))

    A = jnp.zeros((N * 2, 4), dtype=np.float32)

    # First build the coefficient matrix without weights
    for i in range(N):
        x, y = points[i]
        A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])
        A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])

    # Then apply the time-based and confidence weights
    for i in range(N):
        # Calculate time-decay weight: e^(-λ_t * Δt)
        time_weight = jnp.exp(-lambda_t * delta_t[i])
        # Calculate normalization factor: ||c^i^T||_2
        row_norm_1 = jnp.linalg.norm(A[2 * i])
        row_norm_2 = jnp.linalg.norm(A[2 * i + 1])
        # Apply combined weight: time_weight / row_norm * confidence
        w1 = (time_weight / row_norm_1) * confi[i]
        w2 = (time_weight / row_norm_2) * confi[i]
        A = A.at[2 * i].mul(w1)
        A = A.at[2 * i + 1].mul(w2)

    # Solve using SVD
    _, _, vh = jnp.linalg.svd(A, full_matrices=False)
    point_3d_homo = vh[-1]  # shape (4,)
    # Ensure the homogeneous coordinate is positive
    point_3d_homo = jnp.where(
        point_3d_homo[3] < 0,
        -point_3d_homo,
        point_3d_homo,
    )
    # Convert from homogeneous to Euclidean coordinates
    point_3d = point_3d_homo[:3] / point_3d_homo[3]
    return point_3d
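
# Quick illustration of the time-decay factor used above (assuming the default
# lambda_t = 10): an observation 50 ms old keeps exp(-0.5) ≈ 61% of its weight,
# one 200 ms old only exp(-2) ≈ 14%; the row norms and confidences are applied
# on top of this inside the function.
display({dt: float(jnp.exp(-10.0 * dt)) for dt in (0.01, 0.05, 0.1, 0.2)})
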

@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear_time_weighted(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N P 2"],
    delta_t: Num[Array, "N"],
    lambda_t: float = 10.0,
    confidences: Optional[Float[Array, "N P"]] = None,
) -> Float[Array, "P 3"]:
    """
    Vectorized version that triangulates P points from N camera views with time-weighting.

    This function uses JAX's vmap to efficiently triangulate multiple points in parallel.

    Args:
        proj_matrices: projection matrices of shape (N, 3, 4) for N cameras
        points: 2D points of shape (N, P, 2) for P keypoints across N cameras
        delta_t: time differences of shape (N,) between the current time and each
            camera's timestamp (seconds)
        lambda_t: time penalty rate (higher values decrease the influence of older observations)
        confidences: confidence values of shape (N, P) for each point in each camera

    Returns:
        points_3d: triangulated 3D points of shape (P, 3)
    """
    N, P, _ = points.shape
    assert (
        proj_matrices.shape[0] == N
    ), "Number of projection matrices must match number of cameras"
    assert delta_t.shape[0] == N, "Number of time deltas must match number of cameras"

    if confidences is None:
        # Create uniform confidences if none provided
        conf = jnp.ones((N, P), dtype=jnp.float32)
    else:
        conf = confidences

    # Define the vmapped version of the single-point function.
    # We map over the second dimension (P points) of the input arrays.
    vmap_triangulate = jax.vmap(
        triangulate_one_point_from_multiple_views_linear_time_weighted,
        in_axes=(
            None,
            1,
            None,
            None,
            1,
        ),  # proj_matrices and delta_t static, map over points
        out_axes=0,  # Output has first dimension corresponding to points
    )

    # For each point p, extract the 2D coordinates from all cameras and triangulate
    return vmap_triangulate(
        proj_matrices,  # (N, 3, 4) - static across points
        points,  # (N, P, 2) - map over dim 1 (P)
        delta_t,  # (N,) - static across points
        lambda_t,  # scalar - static
        conf,  # (N, P) - map over dim 1 (P)
    )


# %%
@jaxtyped(typechecker=beartype)
def triangle_from_cluster(
    cluster: Sequence[Detection],
) -> tuple[Float[Array, "N 3"], datetime]:
    proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])
    points = jnp.array([el.keypoints_undistorted for el in cluster])
    confidences = jnp.array([el.confidences for el in cluster])
    latest_timestamp = max(el.timestamp for el in cluster)
    return (
        triangulate_points_from_multiple_views_linear(
            proj_matrices, points, confidences=confidences
        ),
        latest_timestamp,
    )


# %%
def group_by_cluster_by_camera(
    cluster: Sequence[Detection],
) -> PMap[CameraID, Detection]:
    """
    Group the detections by camera, preserving the latest detection for each camera.
    """
    r: dict[CameraID, Detection] = {}
    for el in cluster:
        if el.camera.id in r:
            eld = r[el.camera.id]
            preserved = max([eld, el], key=lambda x: x.timestamp)
            r[el.camera.id] = preserved
        else:
            r[el.camera.id] = el
    return pmap(r)


class GlobalTrackingState:
    _last_id: int
    _trackings: dict[int, Tracking]

    def __init__(self):
        self._last_id = 0
        self._trackings = {}

    def __repr__(self) -> str:
        return (
            f"GlobalTrackingState(last_id={self._last_id}, trackings={self._trackings})"
        )

    @property
    def trackings(self) -> dict[int, Tracking]:
        return shallow_copy(self._trackings)

    def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
        if len(cluster) < 2:
            raise ValueError(
                "cluster must contain at least 2 detections to form a tracking"
            )
        kps_3d, latest_timestamp = triangle_from_cluster(cluster)
        next_id = self._last_id + 1
        tracking_state = TrackingState(
            keypoints=kps_3d,
            last_active_timestamp=latest_timestamp,
            historical_detections_by_camera=group_by_cluster_by_camera(cluster),
        )
        tracking = Tracking(
            id=next_id,
            state=tracking_state,
            velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
        )
        self._trackings[next_id] = tracking
        self._last_id = next_id
        return tracking


global_tracking_state = GlobalTrackingState()

for cluster in clusters_detections:
    global_tracking_state.add_tracking(cluster)

display(global_tracking_state)
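
# %%
# Illustrative peek (not part of the original flow): each newly created tracking
# stores a (J, 3) triangulated pose, its latest activity timestamp, and the most
# recent detection per camera.
display(
    {
        tid: {
            "keypoints_shape": tuple(trk.state.keypoints.shape),
            "last_active": str(trk.state.last_active_timestamp),
            "cameras": list(trk.state.historical_detections_by_camera.keys()),
        }
        for tid, trk in global_tracking_state.trackings.items()
    }
)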

# %%
next_group = next(sync_gen)
display(next_group)


# %%
@jaxtyped(typechecker=beartype)
def calculate_distance_2d(
    left: Num[Array, "J 2"],
    right: Num[Array, "J 2"],
    image_size: tuple[int, int] = (1, 1),
) -> Float[Array, "J"]:
    """
    Calculate the *normalized* distance between two sets of keypoints.

    Args:
        left: The left keypoints
        right: The right keypoints
        image_size: The size of the image

    Returns:
        Array of normalized Euclidean distances between corresponding keypoints
    """
    w, h = image_size
    if w == 1 and h == 1:
        # already normalized
        left_normalized = left
        right_normalized = right
    else:
        left_normalized = left / jnp.array([w, h])
        right_normalized = right / jnp.array([w, h])
    dist = jnp.linalg.norm(left_normalized - right_normalized, axis=-1)
    return dist


@jaxtyped(typechecker=beartype)
def calculate_affinity_2d(
    distance_2d: Float[Array, "J"],
    delta_t: timedelta,
    w_2d: float,
    alpha_2d: float,
    lambda_a: float,
) -> Float[Array, "J"]:
    """
    Calculate the per-keypoint 2D affinity between two detections based on the
    distances between their keypoints.

    Each keypoint's affinity is:

        A_2D = w_2D * (1 - distance_2D / (alpha_2D * delta_t)) * exp(-lambda_a * delta_t)

    (the caller sums over keypoints to obtain the final score)

    Args:
        distance_2d: The normalized distances between keypoints (one value per keypoint)
        delta_t: The time delta between the two detections
        w_2d: The weight for 2D affinity
        alpha_2d: The normalization factor for distance
        lambda_a: The decay rate for time difference

    Returns:
        Affinity score for each keypoint
    """
    delta_t_s = delta_t.total_seconds()
    affinity_per_keypoint = (
        w_2d
        * (1 - distance_2d / (alpha_2d * delta_t_s))
        * jnp.exp(-lambda_a * delta_t_s)
    )
    return affinity_per_keypoint


@jaxtyped(typechecker=beartype)
def perpendicular_distance_point_to_line_two_points(
    point: Num[Array, "3"], line: tuple[Num[Array, "3"], Num[Array, "3"]]
) -> Float[Array, ""]:
    """
    Calculate the perpendicular distance between a point and a line, where the
    line is represented by two points: `(line_start, line_end)`.

    Args:
        point: The point to calculate the distance to
        line: The line to calculate the distance to, represented by two points

    Returns:
        The perpendicular distance between the point and the line (a scalar)
    """
    line_start, line_end = line
    distance = jnp.linalg.norm(
        jnp.cross(line_end - line_start, line_start - point)
    ) / jnp.linalg.norm(line_end - line_start)
    return distance
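
# Tiny numeric check (illustrative only): the point (0, 1, 0) lies at distance 1
# from the x-axis, here given by the two points (0, 0, 0) and (1, 0, 0).
display(
    perpendicular_distance_point_to_line_two_points(
        jnp.array([0.0, 1.0, 0.0]),
        (jnp.array([0.0, 0.0, 0.0]), jnp.array([1.0, 0.0, 0.0])),
    )
)
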

@jaxtyped(typechecker=beartype)
def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
    detection: Detection,
    tracking: Tracking,
    delta_t: timedelta,
) -> Float[Array, "J"]:
    """
    Calculate the perpendicular distances between predicted 3D tracking points and
    the rays cast from the camera center through the 2D image points.

    NOTE: `delta_t` is taken from the caller and NOT recomputed internally.

    Args:
        detection: The detection object containing 2D keypoints and camera parameters
        tracking: The tracking object containing 3D keypoints
        delta_t: Time delta between the tracking's last update and the current observation

    Returns:
        Array of perpendicular distances for each keypoint
    """
    camera = detection.camera
    predicted_pose = tracking.predict(delta_t)

    # Back-project the 2D points to 3D space
    # (intersection with the z=0 plane)
    back_projected_points = detection.camera.unproject_points_to_z_plane(
        detection.keypoints, z=0.0
    )
    camera_center = camera.params.location

    def calc_distance(predicted_point, back_projected_point):
        return perpendicular_distance_point_to_line_two_points(
            predicted_point, (camera_center, back_projected_point)
        )

    # Vectorize over all keypoints
    vmap_calc_distance = jax.vmap(calc_distance)
    distances: Float[Array, "J"] = vmap_calc_distance(
        predicted_pose, back_projected_points
    )
    return distances


@jaxtyped(typechecker=beartype)
def calculate_affinity_3d(
    distances: Float[Array, "J"],
    delta_t: timedelta,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> Float[Array, "J"]:
    """
    Calculate the per-keypoint 3D affinity between a tracking and a detection.

    Each keypoint's affinity is:

        A_3D = w_3D * (1 - dl / alpha_3D) * exp(-lambda_a * delta_t)

    (the caller sums over keypoints to obtain the final score)

    Args:
        distances: Array of perpendicular distances for each keypoint
        delta_t: Time difference between tracking and detection
        w_3d: Weight for 3D affinity
        alpha_3d: Normalization factor for distance
        lambda_a: Decay rate for time difference

    Returns:
        Affinity score for each keypoint
    """
    delta_t_s = delta_t.total_seconds()
    affinity_per_keypoint = (
        w_3d * (1 - distances / alpha_3d) * jnp.exp(-lambda_a * delta_t_s)
    )
    return affinity_per_keypoint
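
# Illustrative numbers for the 3D affinity above (made-up parameters): a keypoint
# whose ray passes 0.1 world units from the prediction, observed 50 ms after the
# last update with w_3d = alpha_3d = 1 and lambda_a = 5, scores about
# 0.9 * exp(-0.25) ≈ 0.70; a ray 0.5 units away scores about 0.39.
display(
    calculate_affinity_3d(
        jnp.array([0.1, 0.5]),
        timedelta(milliseconds=50),
        w_3d=1.0,
        alpha_3d=1.0,
        lambda_a=5.0,
    )
)
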

@beartype
def calculate_tracking_detection_affinity(
    tracking: Tracking,
    detection: Detection,
    w_2d: float,
    alpha_2d: float,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> float:
    """
    Calculate the affinity between a tracking and a detection.

    Args:
        tracking: The tracking object
        detection: The detection object
        w_2d: Weight for 2D affinity
        alpha_2d: Normalization factor for 2D distance
        w_3d: Weight for 3D affinity
        alpha_3d: Normalization factor for 3D distance
        lambda_a: Decay rate for time difference

    Returns:
        Combined affinity score
    """
    camera = detection.camera
    delta_t_raw = detection.timestamp - tracking.state.last_active_timestamp
    # Clamp delta_t to avoid division-by-zero / exploding affinity.
    delta_t = max(delta_t_raw, DELTA_T_MIN)

    # Calculate 2D affinity
    tracking_2d_projection = camera.project(tracking.state.keypoints)
    w, h = camera.params.image_size
    distance_2d = calculate_distance_2d(
        tracking_2d_projection,
        detection.keypoints,
        image_size=(int(w), int(h)),
    )
    affinity_2d = calculate_affinity_2d(
        distance_2d,
        delta_t,
        w_2d=w_2d,
        alpha_2d=alpha_2d,
        lambda_a=lambda_a,
    )

    # Calculate 3D affinity
    distances = perpendicular_distance_camera_2d_points_to_tracking_raycasting(
        detection, tracking, delta_t
    )
    affinity_3d = calculate_affinity_3d(
        distances,
        delta_t,
        w_3d=w_3d,
        alpha_3d=alpha_3d,
        lambda_a=lambda_a,
    )

    # Combine affinities
    total_affinity = affinity_2d + affinity_3d
    return jnp.sum(total_affinity).item()


# %%
@beartype
def calculate_camera_affinity_matrix_jax(
    trackings: Sequence[Tracking],
    camera_detections: Sequence[Detection],
    w_2d: float,
    alpha_2d: float,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> Float[Array, "T D"]:
    """
    Vectorized implementation to compute an affinity matrix between *trackings*
    and *detections* coming from **one** camera.

    Compared with the simple double-for-loop version, this leverages `jax`'s
    broadcasting + `vmap` facilities and avoids Python loops over every
    (tracking, detection) pair. The mathematical definition of the affinity is
    **unchanged**, so the result remains bit-identical to the reference
    implementation used in the tests.
    """
    # ------------------------------------------------------------------
    # Quick validations / early-exit guards
    # ------------------------------------------------------------------
    if len(trackings) == 0 or len(camera_detections) == 0:
        # Return an empty affinity matrix with appropriate shape.
        return jnp.zeros((len(trackings), len(camera_detections)))  # type: ignore[return-value]

    cam = next(iter(camera_detections)).camera

    # Ensure every detection truly belongs to the same camera (guard clause)
    cam_id = cam.id
    if any(det.camera.id != cam_id for det in camera_detections):
        raise ValueError(
            "All detections passed to `calculate_camera_affinity_matrix` must come from one camera."
        )

    # We will rely on a single `Camera` instance (all detections share it)
    w_img_, h_img_ = cam.params.image_size
    w_img, h_img = float(w_img_), float(h_img_)

    # ------------------------------------------------------------------
    # Gather data into ndarray / DeviceArray batches so that we can compute
    # everything in a single (or a few) fused kernels.
    # ------------------------------------------------------------------
    # === Tracking-side tensors ===
    kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
        [trk.state.keypoints for trk in trackings]
    )  # (T, J, 3)
    J = kps3d_trk.shape[1]

    # === Detection-side tensors ===
    kps2d_det: Float[Array, "D J 2"] = jnp.stack(
        [det.keypoints for det in camera_detections]
    )  # (D, J, 2)

    # ------------------------------------------------------------------
    # Compute Δt matrix - shape (T, D)
    # ------------------------------------------------------------------
    # Epoch timestamps are ~1.7 × 10^9; storing them in float32 wipes out
    # sub-second detail (resolution ≈ 200 ms). Keep them in float64 until after
    # subtraction so we preserve Δt on the order of milliseconds.
    # --- timestamps ----------
    t0 = min(
        chain(
            (trk.state.last_active_timestamp for trk in trackings),
            (det.timestamp for det in camera_detections),
        )
    ).timestamp()  # common origin (float)

    ts_trk = jnp.array(
        [trk.state.last_active_timestamp.timestamp() - t0 for trk in trackings],
        dtype=jnp.float32,  # now small, ms-scale fits in fp32
    )
    ts_det = jnp.array(
        [det.timestamp.timestamp() - t0 for det in camera_detections],
        dtype=jnp.float32,
    )

    # Δt in seconds, fp32 throughout
    delta_t = ts_det[None, :] - ts_trk[:, None]  # (T, D)
    min_dt_s = float(DELTA_T_MIN.total_seconds())
    delta_t = jnp.clip(delta_t, min_dt_s, None)

    # ------------------------------------------------------------------
    # ---------- 2D affinity -------------------------------------------
    # ------------------------------------------------------------------
    # Project each tracking's 3D keypoints onto the image once.
    # `Camera.project` works per-sample, so we vmap over the first axis.
    proj_fn = jax.vmap(cam.project, in_axes=0)  # maps over the keypoint sets
    kps2d_trk_proj: Float[Array, "T J 2"] = proj_fn(kps3d_trk)  # (T, J, 2)

    # Normalise keypoints by image size so absolute units do not bias distance
    norm_trk = kps2d_trk_proj / jnp.array([w_img, h_img])
    norm_det = kps2d_det / jnp.array([w_img, h_img])

    # L2 distance for every (T, D, J)
    # reshape for broadcasting: (T, 1, J, 2) vs (1, D, J, 2)
    diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
    dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1)

    # Compute per-keypoint 2D affinity
    delta_t_broadcast = delta_t[:, :, None]  # (T, D, 1)
    affinity_2d = (
        w_2d
        * (1 - dist2d / (alpha_2d * delta_t_broadcast))
        * jnp.exp(-lambda_a * delta_t_broadcast)
    )

    # ------------------------------------------------------------------
    # ---------- 3D affinity -------------------------------------------
    # ------------------------------------------------------------------
    # For each detection pre-compute back-projected 3D points lying on the z=0 plane.
    backproj_points_list = [
        det.camera.unproject_points_to_z_plane(det.keypoints, z=0.0)
        for det in camera_detections
    ]  # each (J, 3)
    backproj: Float[Array, "D J 3"] = jnp.stack(backproj_points_list)  # (D, J, 3)

    zero_velocity = jnp.zeros((J, 3))
    trk_velocities = jnp.stack(
        [
            trk.velocity if trk.velocity is not None else zero_velocity
            for trk in trackings
        ]
    )
    predicted_pose: Float[Array, "T D J 3"] = (
        kps3d_trk[:, None, :, :]  # (T, 1, J, 3)
        + trk_velocities[:, None, :, :] * delta_t[:, :, None, None]  # (T, D, 1, 1)
    )

    # Camera center - shape (3,) -> will broadcast
    cam_center = cam.params.location

    # Compute perpendicular distance using a vectorized formula
    #   p1 = cam_center       (3,)
    #   p2 = backproj         (D, J, 3)
    #   P  = predicted_pose   (T, D, J, 3)
    # Broadcast plan: v1 = P - p1             -> (T, D, J, 3)
    #                 v2 = p2[None, ...] - p1 -> (1, D, J, 3)
    # Shapes now line up; no stray singleton axis.
    p1 = cam_center
    p2 = backproj
    P = predicted_pose

    v1 = P - p1
    v2 = p2[None, :, :, :] - p1  # (1, D, J, 3)

    cross = jnp.cross(v1, v2)  # (T, D, J, 3)
    num = jnp.linalg.norm(cross, axis=-1)  # (T, D, J)
    den = jnp.linalg.norm(v2, axis=-1)  # (1, D, J)
    dist3d: Float[Array, "T D J"] = num / den

    affinity_3d = (
        w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
    )

    # ------------------------------------------------------------------
    # Combine and reduce across keypoints -> (T, D)
    # ------------------------------------------------------------------
    total_affinity: Float[Array, "T D"] = jnp.sum(affinity_2d + affinity_3d, axis=-1)

    return total_affinity  # type: ignore[return-value]


@beartype
def calculate_affinity_matrix(
    trackings: Sequence[Tracking],
    detections: Sequence[Detection] | Mapping[CameraID, list[Detection]],
    w_2d: float,
    alpha_2d: float,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> dict[CameraID, AffinityResult]:
    """
    Calculate the affinity matrix between a set of trackings and detections.

    Args:
        trackings: Sequence of tracking objects
        detections: Sequence of detection objects, or detections already grouped by camera ID
        w_2d: Weight for 2D affinity
        alpha_2d: Normalization factor for 2D distance
        w_3d: Weight for 3D affinity
        alpha_3d: Normalization factor for 3D distance
        lambda_a: Decay rate for time difference

    Returns:
        A dictionary mapping camera IDs to affinity results.
    """
    if isinstance(detections, Mapping):
        detection_by_camera = detections
    else:
        detection_by_camera = classify_by_camera(detections)

    res: dict[CameraID, AffinityResult] = {}
    for camera_id, camera_detections in detection_by_camera.items():
        affinity_matrix = calculate_camera_affinity_matrix_jax(
            trackings,
            camera_detections,
            w_2d,
            alpha_2d,
            w_3d,
            alpha_3d,
            lambda_a,
        )
        # row, col indices
        indices_T, indices_D = linear_sum_assignment(affinity_matrix)
        affinity_result = AffinityResult(
            matrix=affinity_matrix,
            trackings=trackings,
            detections=camera_detections,
            indices_T=indices_T,
            indices_D=indices_D,
        )
        res[camera_id] = affinity_result
    return res


# %%
# let's do cross-view association
W_2D = 1.0
ALPHA_2D = 1.0
LAMBDA_A = 0.1
W_3D = 1.0
ALPHA_3D = 1.0

trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
unmatched_detections = shallow_copy(next_group)
camera_detections = classify_by_camera(unmatched_detections)

affinities = calculate_affinity_matrix(
    trackings,
    unmatched_detections,
    w_2d=W_2D,
    alpha_2d=ALPHA_2D,
    w_3d=W_3D,
    alpha_3d=ALPHA_3D,
    lambda_a=LAMBDA_A,
)
display(affinities)
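
# %%
# Illustrative peek at the assignment result (not part of the original flow):
# for each camera, list the matched (tracking index, detection index) pairs
# returned by the Hungarian assignment together with their affinity score.
display(
    {
        camera_id: [
            (int(t), int(d), float(result.matrix[t, d]))
            for t, d in zip(result.indices_T, result.indices_D)
        ]
        for camera_id, result in affinities.items()
    }
)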
""" res: dict[TrackingID, list[Detection]] = defaultdict(list) for affinity_result in results: for _affinity, t, d in affinity_result.tracking_association(): res[t.id].append(d) return res def update_tracking( tracking: Tracking, detections: Sequence[Detection], max_delta_t: timedelta = timedelta(milliseconds=100), lambda_t: float = 10.0, ) -> None: """ update the tracking with a new set of detections Args: tracking: the tracking to update detections: the detections to update the tracking with max_delta_t: the maximum time difference between the last active timestamp and the latest detection lambda_t: the lambda value for the time difference Note: the function would mutate the tracking object """ last_active_timestamp = tracking.state.last_active_timestamp latest_timestamp = max(d.timestamp for d in detections) d = thaw(tracking.state.historical_detections_by_camera) for detection in detections: d[detection.camera.id] = detection for camera_id, detection in d.items(): if detection.timestamp - latest_timestamp > max_delta_t: del d[camera_id] new_detections = freeze(d) new_detections_list = list(new_detections.values()) project_matrices = jnp.stack( [detection.camera.params.projection_matrix for detection in new_detections_list] ) delta_t = jnp.array( [ detection.timestamp.timestamp() - last_active_timestamp.timestamp() for detection in new_detections_list ] ) kps = jnp.stack([detection.keypoints for detection in new_detections_list]) conf = jnp.stack([detection.confidences for detection in new_detections_list]) kps_3d = triangulate_points_from_multiple_views_linear_time_weighted( project_matrices, kps, delta_t, lambda_t, conf ) new_state = TrackingState( keypoints=kps_3d, last_active_timestamp=latest_timestamp, historical_detections_by_camera=new_detections, ) tracking.update(kps_3d, latest_timestamp) tracking.state = new_state # %% affinity_results_by_tracking = affinity_result_by_tracking(affinities.values()) for tracking_id, detections in affinity_results_by_tracking.items(): update_tracking(global_tracking_state.trackings[tracking_id], detections) # %%