# --- # jupyter: # jupytext: # text_representation: # extension: .py # format_name: percent # format_version: '1.3' # jupytext_version: 1.17.0 # kernelspec: # display_name: .venv # language: python # name: python3 # --- from collections import OrderedDict # %% from copy import copy as shallow_copy from copy import deepcopy from copy import deepcopy as deep_copy from dataclasses import dataclass from datetime import datetime, timedelta from pathlib import Path from typing import ( Any, Generator, Mapping, Optional, Sequence, TypeAlias, TypedDict, TypeVar, cast, overload, ) import awkward as ak import jax import jax.numpy as jnp import numpy as np import orjson from beartype import beartype from cv2 import undistortPoints from IPython.display import display from jaxtyping import Array, Float, Num, jaxtyped from matplotlib import pyplot as plt from numpy.typing import ArrayLike from scipy.optimize import linear_sum_assignment from scipy.spatial.transform import Rotation as R from app.camera import ( Camera, CameraID, CameraParams, Detection, calculate_affinity_matrix_by_epipolar_constraint, classify_by_camera, ) from app.solver._old import GLPKSolver from app.visualize.whole_body import visualize_whole_body NDArray: TypeAlias = np.ndarray # %% DATASET_PATH = Path("samples") / "04_02" AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet") display(AK_CAMERA_DATASET) # %% class Resolution(TypedDict): width: int height: int class Intrinsic(TypedDict): camera_matrix: Num[Array, "3 3"] """ K """ distortion_coefficients: Num[Array, "N"] """ distortion coefficients; usually 5 """ class Extrinsic(TypedDict): rvec: Num[NDArray, "3"] tvec: Num[NDArray, "3"] class ExternalCameraParams(TypedDict): name: str port: int intrinsic: Intrinsic extrinsic: Extrinsic resolution: Resolution # %% def read_dataset_by_port(port: int) -> ak.Array: P = DATASET_PATH / f"{port}.parquet" return ak.from_parquet(P) KEYPOINT_DATASET = { int(p): read_dataset_by_port(p) for p in ak.to_numpy(AK_CAMERA_DATASET["port"]) } # %% class KeypointDataset(TypedDict): frame_index: int boxes: Num[NDArray, "N 4"] kps: Num[NDArray, "N J 2"] kps_scores: Num[NDArray, "N J"] @jaxtyped(typechecker=beartype) def to_transformation_matrix( rvec: Num[NDArray, "3"], tvec: Num[NDArray, "3"] ) -> Num[NDArray, "4 4"]: res = np.eye(4) res[:3, :3] = R.from_rotvec(rvec).as_matrix() res[:3, 3] = tvec return res @jaxtyped(typechecker=beartype) def undistort_points( points: Num[NDArray, "M 2"], camera_matrix: Num[NDArray, "3 3"], dist_coeffs: Num[NDArray, "N"], ) -> Num[NDArray, "M 2"]: K = camera_matrix dist = dist_coeffs res = undistortPoints(points, K, dist, P=K) # type: ignore return res.reshape(-1, 2) def from_camera_params(camera: ExternalCameraParams) -> Camera: rt = jnp.array( to_transformation_matrix( ak.to_numpy(camera["extrinsic"]["rvec"]), ak.to_numpy(camera["extrinsic"]["tvec"]), ) ) K = jnp.array(camera["intrinsic"]["camera_matrix"]).reshape(3, 3) dist_coeffs = jnp.array(camera["intrinsic"]["distortion_coefficients"]) image_size = jnp.array( (camera["resolution"]["width"], camera["resolution"]["height"]) ) return Camera( id=camera["name"], params=CameraParams( K=K, Rt=rt, dist_coeffs=dist_coeffs, image_size=image_size, ), ) def preprocess_keypoint_dataset( dataset: Sequence[KeypointDataset], camera: Camera, fps: float, start_timestamp: datetime, ) -> Generator[Detection, None, None]: frame_interval_s = 1 / fps for el in dataset: frame_index = el["frame_index"] timestamp = start_timestamp + timedelta(seconds=frame_index * frame_interval_s) for kp, kp_score in zip(el["kps"], el["kps_scores"]): yield Detection( keypoints=jnp.array(kp), confidences=jnp.array(kp_score), camera=camera, timestamp=timestamp, ) # %% DetectionGenerator: TypeAlias = Generator[Detection, None, None] def sync_batch_gen(gens: Sequence[DetectionGenerator], diff: timedelta): """ given a list of detection generators, return a generator that yields a batch of detections Args: gens: list of detection generators diff: maximum timestamp difference between detections to consider them part of the same batch """ N = len(gens) last_batch_timestamp: Optional[datetime] = None next_batch_timestamp: Optional[datetime] = None current_batch: list[Detection] = [] next_batch: list[Detection] = [] paused: list[bool] = [False] * N finished: list[bool] = [False] * N def reset_paused(): """ reset paused list based on finished list """ for i in range(N): if not finished[i]: paused[i] = False else: paused[i] = True EPS = 1e-6 # a small epsilon to avoid floating point precision issues diff_esp = diff - timedelta(seconds=EPS) while True: for i, gen in enumerate(gens): try: if finished[i] or paused[i]: continue val = next(gen) if last_batch_timestamp is None: last_batch_timestamp = val.timestamp current_batch.append(val) else: if abs(val.timestamp - last_batch_timestamp) >= diff_esp: next_batch.append(val) if next_batch_timestamp is None: next_batch_timestamp = val.timestamp paused[i] = True if all(paused): yield current_batch current_batch = next_batch next_batch = [] last_batch_timestamp = next_batch_timestamp next_batch_timestamp = None reset_paused() else: current_batch.append(val) except StopIteration: finished[i] = True paused[i] = True if all(finished): if len(current_batch) > 0: # All generators exhausted, flush remaining batch and exit yield current_batch break # %% @overload def to_projection_matrix( transformation_matrix: Num[NDArray, "4 4"], camera_matrix: Num[NDArray, "3 3"] ) -> Num[NDArray, "3 4"]: ... @overload def to_projection_matrix( transformation_matrix: Num[Array, "4 4"], camera_matrix: Num[Array, "3 3"] ) -> Num[Array, "3 4"]: ... @jaxtyped(typechecker=beartype) def to_projection_matrix( transformation_matrix: Num[Any, "4 4"], camera_matrix: Num[Any, "3 3"], ) -> Num[Any, "3 4"]: return camera_matrix @ transformation_matrix[:3, :] to_projection_matrix_jit = jax.jit(to_projection_matrix) @jaxtyped(typechecker=beartype) def dlt( H1: Num[NDArray, "3 4"], H2: Num[NDArray, "3 4"], p1: Num[NDArray, "2"], p2: Num[NDArray, "2"], ) -> Num[NDArray, "3"]: """ Direct Linear Transformation """ A = [ p1[1] * H1[2, :] - H1[1, :], H1[0, :] - p1[0] * H1[2, :], p2[1] * H2[2, :] - H2[1, :], H2[0, :] - p2[0] * H2[2, :], ] A = np.array(A).reshape((4, 4)) B = A.transpose() @ A from scipy import linalg U, s, Vh = linalg.svd(B, full_matrices=False) return Vh[3, 0:3] / Vh[3, 3] @overload def homogeneous_to_euclidean(points: Num[NDArray, "N 4"]) -> Num[NDArray, "N 3"]: ... @overload def homogeneous_to_euclidean(points: Num[Array, "N 4"]) -> Num[Array, "N 3"]: ... @jaxtyped(typechecker=beartype) def homogeneous_to_euclidean( points: Num[Any, "N 4"], ) -> Num[Any, "N 3"]: """ 将齐次坐标转换为欧几里得坐标 Args: points: homogeneous coordinates (x, y, z, w) in numpy array or jax array Returns: euclidean coordinates (x, y, z) in numpy array or jax array """ return points[..., :-1] / points[..., -1:] # %% FPS = 24 image_gen_5600 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5600], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5600][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore image_gen_5601 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5601], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5601][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore image_gen_5602 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5602], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5602][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore display(1 / FPS) sync_gen = sync_batch_gen( [image_gen_5600, image_gen_5601, image_gen_5602], timedelta(seconds=1 / FPS) ) # %% sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint( next(sync_gen), alpha_2d=2000 ) display(sorted_detections) # %% display( list( map( lambda x: {"timestamp": str(x.timestamp), "camera": x.camera.id}, sorted_detections, ) ) ) with jnp.printoptions(precision=3, suppress=True): display(affinity_matrix) # %% def clusters_to_detections( clusters: Sequence[Sequence[int]], sorted_detections: Sequence[Detection] ) -> list[list[Detection]]: """ given a list of clusters (which is the indices of the detections in the sorted_detections list), extract the detections from the sorted_detections list Args: clusters: list of clusters, each cluster is a list of indices of the detections in the `sorted_detections` list sorted_detections: list of SORTED detections Returns: list of clusters, each cluster is a list of detections """ return [[sorted_detections[i] for i in cluster] for cluster in clusters] solver = GLPKSolver() aff_np = np.asarray(affinity_matrix).astype(np.float64) clusters, sol_matrix = solver.solve(aff_np) display(clusters) display(sol_matrix) # %% WIDTH = 2560 HEIGHT = 1440 clusters_detections = clusters_to_detections(clusters, sorted_detections) im = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8) for el in clusters_detections[0]: im = visualize_whole_body(np.asarray(el.keypoints), im) p = plt.imshow(im) display(p) # %% im_prime = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8) for el in clusters_detections[1]: im_prime = visualize_whole_body(np.asarray(el.keypoints), im_prime) p_prime = plt.imshow(im_prime) display(p_prime) # %% @jaxtyped(typechecker=beartype) def triangulate_one_point_from_multiple_views_linear( proj_matrices: Float[Array, "N 3 4"], points: Num[Array, "N 2"], confidences: Optional[Float[Array, "N"]] = None, ) -> Float[Array, "3"]: """ Args: proj_matrices: 形状为(N, 3, 4)的投影矩阵序列 points: 形状为(N, 2)的点坐标序列 confidences: 形状为(N,)的置信度序列,范围[0.0, 1.0] Returns: point_3d: 形状为(3,)的三角测量得到的3D点 """ assert len(proj_matrices) == len(points) N = len(proj_matrices) confi: Float[Array, "N"] if confidences is None: confi = jnp.ones(N, dtype=np.float32) else: # Use square root of confidences for weighting - more balanced approach confi = jnp.sqrt(jnp.clip(confidences, 0, 1)) A = jnp.zeros((N * 2, 4), dtype=np.float32) for i in range(N): x, y = points[i] A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0]) A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1]) A = A.at[2 * i].mul(confi[i]) A = A.at[2 * i + 1].mul(confi[i]) # https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.svd.html _, _, vh = jnp.linalg.svd(A, full_matrices=False) point_3d_homo = vh[-1] # shape (4,) # replace the Python `if` with a jnp.where point_3d_homo = jnp.where( point_3d_homo[3] < 0, # predicate (scalar bool tracer) -point_3d_homo, # if True point_3d_homo, # if False ) point_3d = point_3d_homo[:3] / point_3d_homo[3] return point_3d @jaxtyped(typechecker=beartype) def triangulate_points_from_multiple_views_linear( proj_matrices: Float[Array, "N 3 4"], points: Num[Array, "N P 2"], confidences: Optional[Float[Array, "N P"]] = None, ) -> Float[Array, "P 3"]: """ Batch-triangulate P points observed by N cameras, linearly via SVD. Args: proj_matrices: (N, 3, 4) projection matrices points: (N, P, 2) image-coordinates per view confidences: (N, P, 1) optional per-view confidences in [0,1] Returns: (P, 3) 3D point for each of the P tracks """ N, P, _ = points.shape assert proj_matrices.shape[0] == N if confidences is None: conf = jnp.ones((N, P), dtype=jnp.float32) else: conf = jnp.sqrt(jnp.clip(confidences, 0.0, 1.0)) # vectorize your one-point routine over P vmap_triangulate = jax.vmap( triangulate_one_point_from_multiple_views_linear, in_axes=(None, 1, 1), # proj_matrices static, map over points[:,p,:], conf[:,p] out_axes=0, ) return vmap_triangulate(proj_matrices, points, conf) # %% @jaxtyped(typechecker=beartype) @dataclass(frozen=True) class Tracking: id: int """ The tracking id """ keypoints: Float[Array, "J 3"] """ The 3D keypoints of the tracking """ last_active_timestamp: datetime velocity: Optional[Float[Array, "3"]] = None """ Could be `None`. Like when the 3D pose is initialized. `velocity` should be updated when target association yields a new 3D pose. """ def __repr__(self) -> str: return f"Tracking({self.id}, {self.last_active_timestamp})" @jaxtyped(typechecker=beartype) def triangle_from_cluster( cluster: Sequence[Detection], ) -> tuple[Float[Array, "N 3"], datetime]: proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster]) points = jnp.array([el.keypoints_undistorted for el in cluster]) confidences = jnp.array([el.confidences for el in cluster]) latest_timestamp = max(el.timestamp for el in cluster) return ( triangulate_points_from_multiple_views_linear( proj_matrices, points, confidences=confidences ), latest_timestamp, ) # %% class GlobalTrackingState: _last_id: int _trackings: dict[int, Tracking] def __init__(self): self._last_id = 0 self._trackings = {} def __repr__(self) -> str: return ( f"GlobalTrackingState(last_id={self._last_id}, trackings={self._trackings})" ) @property def trackings(self) -> dict[int, Tracking]: return shallow_copy(self._trackings) def add_tracking(self, cluster: Sequence[Detection]) -> Tracking: kps_3d, latest_timestamp = triangle_from_cluster(cluster) next_id = self._last_id + 1 tracking = Tracking( id=next_id, keypoints=kps_3d, last_active_timestamp=latest_timestamp ) self._trackings[next_id] = tracking self._last_id = next_id return tracking global_tracking_state = GlobalTrackingState() for cluster in clusters_detections: global_tracking_state.add_tracking(cluster) display(global_tracking_state) # %% next_group = next(sync_gen) display(next_group) # %% @jaxtyped(typechecker=beartype) def calculate_distance_2d( left: Num[Array, "J 2"], right: Num[Array, "J 2"], image_size: tuple[int, int] = (1, 1), ) -> Float[Array, "J"]: """ Calculate the *normalized* distance between two sets of keypoints. Args: left: The left keypoints right: The right keypoints image_size: The size of the image Returns: Array of normalized Euclidean distances between corresponding keypoints """ w, h = image_size if w == 1 and h == 1: # already normalized left_normalized = left right_normalized = right else: left_normalized = left / jnp.array([w, h]) right_normalized = right / jnp.array([w, h]) return jnp.linalg.norm(left_normalized - right_normalized, axis=-1) @jaxtyped(typechecker=beartype) def calculate_affinity_2d( distance_2d: Float[Array, "J"], delta_t: timedelta, w_2d: float, alpha_2d: float, lambda_a: float, ) -> Float[Array, "J"]: """ Calculate the affinity between two detections based on the distances between their keypoints. The affinity score is calculated by summing individual keypoint affinities: A_2D = sum(w_2D * (1 - distance_2D / (alpha_2D*delta_t)) * np.exp(-lambda_a * delta_t)) for each keypoint Args: distance_2d: The normalized distances between keypoints (array with one value per keypoint) w_2d: The weight for 2D affinity alpha_2d: The normalization factor for distance lambda_a: The decay rate for time difference delta_t: The time delta between the two detections, in seconds Returns: Sum of affinity scores across all keypoints """ delta_t_s = delta_t.total_seconds() affinity_per_keypoint = ( w_2d * (1 - distance_2d / (alpha_2d * delta_t_s)) * jnp.exp(-lambda_a * delta_t_s) ) return affinity_per_keypoint @jaxtyped(typechecker=beartype) def perpendicular_distance_point_to_line_two_points( point: Num[Array, "3"], line: tuple[Num[Array, "3"], Num[Array, "3"]] ) -> Float[Array, ""]: """ Calculate the perpendicular distance between a point and a line. where `line` is represented by two points: `(line_start, line_end)` Args: point: The point to calculate the distance to line: The line to calculate the distance to, represented by two points Returns: The perpendicular distance between the point and the line (should be a scalar in `float`) """ line_start, line_end = line distance = jnp.linalg.norm( jnp.cross(line_end - line_start, line_start - point) ) / jnp.linalg.norm(line_end - line_start) return distance @jaxtyped(typechecker=beartype) def perpendicular_distance_camera_2d_points_to_tracking_raycasting( detection: Detection, tracking: Tracking, delta_t: timedelta, ) -> Float[Array, "J"]: """ Calculate the perpendicular distances between predicted 3D tracking points and the rays cast from camera center through the 2D image points. Args: detection: The detection object containing 2D keypoints and camera parameters tracking: The tracking object containing 3D keypoints delta_t: Time delta between the tracking's last update and current observation Returns: Array of perpendicular distances for each keypoint """ camera = detection.camera delta_t_s = delta_t.total_seconds() predicted_pose = predict_pose_3d(tracking, delta_t_s) # Back-project the 2D points to 3D space # intersection with z=0 plane back_projected_points = detection.camera.unproject_points_to_z_plane( detection.keypoints, z=0.0 ) camera_center = camera.params.location def calc_distance(predicted_point, back_projected_point): return perpendicular_distance_point_to_line_two_points( predicted_point, (camera_center, back_projected_point) ) # Vectorize over all keypoints vmap_calc_distance = jax.vmap(calc_distance) distances: Float[Array, "J"] = vmap_calc_distance( predicted_pose, back_projected_points ) return distances @jaxtyped(typechecker=beartype) def calculate_affinity_3d( distances: Float[Array, "J"], delta_t: timedelta, w_3d: float, alpha_3d: float, lambda_a: float, ) -> Float[Array, "J"]: """ Calculate 3D affinity score between a tracking and detection. The affinity score is calculated by summing individual keypoint affinities: A_3D = sum(w_3D * (1 - dl / alpha_3D) * np.exp(-lambda_a * delta_t)) for each keypoint Args: distances: Array of perpendicular distances for each keypoint delta_t: Time difference between tracking and detection w_3d: Weight for 3D affinity alpha_3d: Normalization factor for distance lambda_a: Decay rate for time difference Returns: Sum of affinity scores across all keypoints """ delta_t_s = delta_t.total_seconds() affinity_per_keypoint = ( w_3d * (1 - distances / alpha_3d) * jnp.exp(-lambda_a * delta_t_s) ) return affinity_per_keypoint def predict_pose_3d( tracking: Tracking, delta_t_s: float, ) -> Float[Array, "J 3"]: """ Predict the 3D pose of a tracking based on its velocity. """ if tracking.velocity is None: return tracking.keypoints return tracking.keypoints + tracking.velocity * delta_t_s @beartype def calculate_tracking_detection_affinity( tracking: Tracking, detection: Detection, w_2d: float, alpha_2d: float, w_3d: float, alpha_3d: float, lambda_a: float, ) -> float: """ Calculate the affinity between a tracking and a detection. Args: tracking: The tracking object detection: The detection object w_2d: Weight for 2D affinity alpha_2d: Normalization factor for 2D distance w_3d: Weight for 3D affinity alpha_3d: Normalization factor for 3D distance lambda_a: Decay rate for time difference Returns: Combined affinity score """ camera = detection.camera delta_t = detection.timestamp - tracking.last_active_timestamp # Calculate 2D affinity tracking_2d_projection = camera.project(tracking.keypoints) w, h = camera.params.image_size distance_2d = calculate_distance_2d( tracking_2d_projection, detection.keypoints, image_size=(int(w), int(h)), ) affinity_2d = calculate_affinity_2d( distance_2d, delta_t, w_2d=w_2d, alpha_2d=alpha_2d, lambda_a=lambda_a, ) # Calculate 3D affinity distances = perpendicular_distance_camera_2d_points_to_tracking_raycasting( detection, tracking, delta_t ) affinity_3d = calculate_affinity_3d( distances, delta_t, w_3d=w_3d, alpha_3d=alpha_3d, lambda_a=lambda_a, ) # Combine affinities total_affinity = affinity_2d + affinity_3d return jnp.sum(total_affinity).item() # %% @beartype def calculate_affinity_matrix( trackings: Sequence[Tracking], detections: Sequence[Detection], w_2d: float, alpha_2d: float, w_3d: float, alpha_3d: float, lambda_a: float, ) -> tuple[Float[Array, "T D"], OrderedDict[CameraID, list[Detection]]]: """ Calculate the affinity matrix between a set of trackings and detections. Args: trackings: Sequence of tracking objects detections: Sequence of detection objects w_2d: Weight for 2D affinity alpha_2d: Normalization factor for 2D distance w_3d: Weight for 3D affinity alpha_3d: Normalization factor for 3D distance lambda_a: Decay rate for time difference Returns: - affinity matrix of shape (T, D) where T is number of trackings and D is number of detections - dictionary mapping camera IDs to lists of detections from that camera, since it's a `OrderDict` you could flat it out to get the indices of detections in the affinity matrix Matrix Layout: The affinity matrix has shape (T, D), where: - T = number of trackings (rows) - D = total number of detections across all cameras (columns) The matrix is organized as follows: ``` | Camera 1 | Camera 2 | Camera c | | d1 d2 ... | d1 d2 ... | d1 d2 ... | ---------+-------------+-------------+-------------+ Track 1 | a11 a12 ... | a11 a12 ... | a11 a12 ... | Track 2 | a21 a22 ... | a21 a22 ... | a21 a22 ... | ... | ... | ... | ... | Track t | at1 at2 ... | at1 at2 ... | at1 at2 ... | ``` Where: - Rows are ordered by tracking ID - Columns are ordered by camera, then by detection within each camera - Each cell aij represents the affinity between tracking i and detection j The detection ordering in columns follows the exact same order as iterating through the detection_by_camera dictionary, which is returned alongside the matrix to maintain this relationship. """ affinity = jnp.zeros((len(trackings), len(detections))) detection_by_camera = classify_by_camera(detections) for i, tracking in enumerate(trackings): j = 0 for _, camera_detections in detection_by_camera.items(): for det in camera_detections: affinity_value = calculate_tracking_detection_affinity( tracking, det, w_2d=w_2d, alpha_2d=alpha_2d, w_3d=w_3d, alpha_3d=alpha_3d, lambda_a=lambda_a, ) affinity = affinity.at[i, j].set(affinity_value) j += 1 return affinity, detection_by_camera @beartype def calculate_camera_affinity_matrix( trackings: Sequence[Tracking], camera_detections: Sequence[Detection], w_2d: float, alpha_2d: float, w_3d: float, alpha_3d: float, lambda_a: float, ) -> Float[Array, "T D"]: """ Calculate an affinity matrix between trackings and detections from a single camera. This follows the iterative camera-by-camera approach from the paper "Cross-View Tracking for Multi-Human 3D Pose Estimation at over 100 FPS". Instead of creating one large matrix for all cameras, this creates a separate matrix for each camera, which can be processed independently. Args: trackings: Sequence of tracking objects camera_detections: Sequence of detection objects, from the same camera w_2d: Weight for 2D affinity alpha_2d: Normalization factor for 2D distance w_3d: Weight for 3D affinity alpha_3d: Normalization factor for 3D distance lambda_a: Decay rate for time difference Returns: Affinity matrix of shape (T, D) where: - T = number of trackings (rows) - D = number of detections from this specific camera (columns) Matrix Layout: The affinity matrix for a single camera has shape (T, D), where: - T = number of trackings (rows) - D = number of detections from this camera (columns) The matrix is organized as follows: ``` | Detections from Camera c | | d1 d2 d3 ... | ---------+------------------------+ Track 1 | a11 a12 a13 ... | Track 2 | a21 a22 a23 ... | ... | ... ... ... ... | Track t | at1 at2 at3 ... | ``` Each cell aij represents the affinity between tracking i and detection j, computed using both 2D and 3D geometric correspondences. """ def verify_all_detection_from_same_camera(detections: Sequence[Detection]): if not detections: return True camera_id = next(iter(detections)).camera.id return all(map(lambda d: d.camera.id == camera_id, detections)) if not verify_all_detection_from_same_camera(camera_detections): raise ValueError("All detections must be from the same camera") affinity = jnp.zeros((len(trackings), len(camera_detections))) for i, tracking in enumerate(trackings): for j, det in enumerate(camera_detections): affinity_value = calculate_tracking_detection_affinity( tracking, det, w_2d=w_2d, alpha_2d=alpha_2d, w_3d=w_3d, alpha_3d=alpha_3d, lambda_a=lambda_a, ) affinity = affinity.at[i, j].set(affinity_value) return affinity @beartype def process_detections_iteratively( trackings: Sequence[Tracking], detections: Sequence[Detection], w_2d: float = 1.0, alpha_2d: float = 1.0, w_3d: float = 1.0, alpha_3d: float = 1.0, lambda_a: float = 0.1, ) -> list[tuple[int, Detection]]: """ Process detections iteratively camera by camera, matching them to trackings. This implements the paper's approach where each camera is processed independently, and the affinity matrix is calculated for one camera at a time. This approach has several advantages: 1. Computational cost scales linearly with number of cameras 2. Can handle non-synchronized camera frames 3. More efficient for large-scale camera systems Args: trackings: Sequence of tracking objects detections: Sequence of detection objects w_2d: Weight for 2D affinity alpha_2d: Normalization factor for 2D distance w_3d: Weight for 3D affinity alpha_3d: Normalization factor for 3D distance lambda_a: Decay rate for time difference Returns: List of (tracking_index, detection) pairs representing matches """ # Group detections by camera detection_by_camera = classify_by_camera(detections) # Store matches between trackings and detections matches = [] # Process each camera one by one for camera_id, camera_detections in detection_by_camera.items(): # Calculate affinity matrix for this camera only camera_affinity = calculate_camera_affinity_matrix( trackings, camera_detections, w_2d=w_2d, alpha_2d=alpha_2d, w_3d=w_3d, alpha_3d=alpha_3d, lambda_a=lambda_a, ) # Apply Hungarian algorithm for this camera only tracking_indices, detection_indices = linear_sum_assignment( camera_affinity, maximize=True ) tracking_indices = cast(Sequence[int], tracking_indices) detection_indices = cast(Sequence[int], detection_indices) # Add matches to result for t_idx, d_idx in zip(tracking_indices, detection_indices): # Skip matches with zero or negative affinity if camera_affinity[t_idx, d_idx] <= 0: continue matches.append((t_idx, camera_detections[d_idx])) return matches # %% # let's do cross-view association W_2D = 1.0 ALPHA_2D = 1.0 LAMBDA_A = 0.1 W_3D = 1.0 ALPHA_3D = 1.0 trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id) unmatched_detections = shallow_copy(next_group) affinity, detection_by_camera = calculate_affinity_matrix( trackings, unmatched_detections, w_2d=W_2D, alpha_2d=ALPHA_2D, w_3d=W_3D, alpha_3d=ALPHA_3D, lambda_a=LAMBDA_A, ) display(affinity) # %% T = TypeVar("T") def flatten_values( d: Mapping[Any, Sequence[T]], ) -> list[T]: """ Flatten a dictionary of sequences into a single list of values. """ return [v for vs in d.values() for v in vs] detections_sorted = flatten_values(detection_by_camera) display(detections_sorted) display(detection_by_camera) # %% # Perform Hungarian algorithm for assignment for each camera indices_T, indices_D = linear_sum_assignment(affinity, maximize=True) indices_T = cast(Sequence[int], indices_T) indices_D = cast(Sequence[int], indices_D) display(indices_T) display(indices_D) # %%