from math import isnan from pathlib import Path from re import L import awkward as ak from typing import ( Any, Generator, Optional, Sequence, TypeAlias, TypedDict, cast, TypeVar, ) from datetime import datetime, timedelta from jaxtyping import Array, Float, Num, jaxtyped import numpy as np from cv2 import undistortPoints from sympy import true from app import camera from app.camera import Camera, CameraParams, Detection import jax.numpy as jnp from beartype import beartype from scipy.spatial.transform import Rotation as R from app.camera import ( Camera, CameraID, CameraParams, Detection, calculate_affinity_matrix_by_epipolar_constraint, classify_by_camera, ) import jax from beartype.typing import Mapping, Sequence import orjson from app.visualize.whole_body import visualize_whole_body from matplotlib import pyplot as plt from app.solver._old import GLPKSolver from app.tracking import ( TrackingID, AffinityResult, LastDifferenceVelocityFilter, Tracking, TrackingState, ) from copy import copy as shallow_copy from pyrsistent import pvector, v, m, pmap, PMap, freeze, thaw from optax.assignment import hungarian_algorithm as linear_sum_assignment from itertools import chain NDArray: TypeAlias = np.ndarray DetectionGenerator: TypeAlias = Generator[Detection, None, None] DELTA_T_MIN = timedelta(milliseconds=1) """所有类型""" T = TypeVar("T") def unwrap(val: Optional[T]) -> T: if val is None: raise ValueError("None") return val class KeypointDataset(TypedDict): frame_index: int boxes: Num[NDArray, "N 4"] kps: Num[NDArray, "N J 2"] kps_scores: Num[NDArray, "N J"] class Resolution(TypedDict): width: int height: int class Intrinsic(TypedDict): camera_matrix: Num[Array, "3 3"] """ K """ distortion_coefficients: Num[Array, "N"] """ distortion coefficients; usually 5 """ class Extrinsic(TypedDict): rvec: Num[NDArray, "3"] tvec: Num[NDArray, "3"] class ExternalCameraParams(TypedDict): name: str port: int intrinsic: Intrinsic extrinsic: Extrinsic resolution: Resolution """获得所有机位的相机内外参""" def get_camera_params(camera_path: Path) -> ak.Array: camera_dataset: ak.Array = ak.from_parquet(camera_path / "camera_params.parquet") return camera_dataset """获取所有机位的2d检测数据""" def get_camera_detect( detect_path: Path, camera_port: list[int], camera_dataset: ak.Array ) -> dict[int, ak.Array]: keypoint_data: dict[int, ak.Array] = {} for element_port in ak.to_numpy(camera_dataset["port"]): if element_port in camera_port: keypoint_data[int(element_port)] = ak.from_parquet( detect_path / f"{element_port}.parquet" ) return keypoint_data """获得指定帧的2d检测数据(一段完整的跳跃片段)""" def get_segment( camera_port: list[int], frame_index: list[int], keypoint_data: dict[int, ak.Array] ) -> dict[int, ak.Array]: for port in camera_port: segement_data = [] camera_data = keypoint_data[port] for index, element_frame in enumerate(ak.to_numpy(camera_data["frame_index"])): if element_frame in frame_index: segement_data.append(camera_data[index]) keypoint_data[port] = ak.Array(segement_data) return keypoint_data """将所有2d检测数据打包""" @jaxtyped(typechecker=beartype) def undistort_points( points: Num[NDArray, "M 2"], camera_matrix: Num[NDArray, "3 3"], dist_coeffs: Num[NDArray, "N"], ) -> Num[NDArray, "M 2"]: K = camera_matrix dist = dist_coeffs res = undistortPoints(points, K, dist, P=K) # type: ignore return res.reshape(-1, 2) @jaxtyped(typechecker=beartype) def to_transformation_matrix( rvec: Num[NDArray, "3"], tvec: Num[NDArray, "3"] ) -> Num[NDArray, "4 4"]: res = np.eye(4) res[:3, :3] = R.from_rotvec(rvec).as_matrix() res[:3, 3] = tvec return res def from_camera_params(camera: ExternalCameraParams) -> Camera: rt = jnp.array( to_transformation_matrix( ak.to_numpy(camera["extrinsic"]["rvec"]), ak.to_numpy(camera["extrinsic"]["tvec"]), ) ) K = jnp.array(camera["intrinsic"]["camera_matrix"]).reshape(3, 3) dist_coeffs = jnp.array(camera["intrinsic"]["distortion_coefficients"]) image_size = jnp.array( (camera["resolution"]["width"], camera["resolution"]["height"]) ) return Camera( id=camera["name"], params=CameraParams( K=K, Rt=rt, dist_coeffs=dist_coeffs, image_size=image_size, ), ) def preprocess_keypoint_dataset( dataset: Sequence[KeypointDataset], camera: Camera, fps: float, start_timestamp: datetime, ) -> Generator[Detection, None, None]: frame_interval_s = 1 / fps for el in dataset: frame_index = el["frame_index"] timestamp = start_timestamp + timedelta(seconds=frame_index * frame_interval_s) for kp, kp_score, boxes in zip(el["kps"], el["kps_scores"], el["boxes"]): kp = undistort_points( np.asarray(kp), np.asarray(camera.params.K), np.asarray(camera.params.dist_coeffs), ) yield Detection( keypoints=jnp.array(kp), confidences=jnp.array(kp_score), camera=camera, timestamp=timestamp, ) def sync_batch_gen( gens: list[DetectionGenerator], diff: timedelta ) -> Generator[list[Detection], Any, None]: from more_itertools import partition """ given a list of detection generators, return a generator that yields a batch of detections Args: gens: list of detection generators diff: maximum timestamp difference between detections to consider them part of the same batch """ N = len(gens) last_batch_timestamp: Optional[datetime] = None current_batch: list[Detection] = [] paused: list[bool] = [False] * N finished: list[bool] = [False] * N unmached_detections: list[Detection] = [] def reset_paused(): """ reset paused list based on finished list """ for i in range(N): if not finished[i]: paused[i] = False else: paused[i] = True EPS = 1e-6 # a small epsilon to avoid floating point precision issues diff_esp = diff - timedelta(seconds=EPS) while True: for i, gen in enumerate(gens): try: if finished[i] or paused[i]: if all(finished): if len(current_batch) > 0: # All generators exhausted, flush remaining batch and exit yield current_batch return else: continue val = next(gen) if last_batch_timestamp is None: last_batch_timestamp = val.timestamp current_batch.append(val) else: if abs(val.timestamp - last_batch_timestamp) >= diff_esp: unmached_detections.append(val) paused[i] = True if all(paused): yield current_batch reset_paused() last_batch_timestamp = last_batch_timestamp + diff bad, good = partition( lambda x: x.timestamp < unwrap(last_batch_timestamp), unmached_detections, ) current_batch = list(good) unmached_detections = list(bad) else: current_batch.append(val) except StopIteration: return def get_batch_detect( keypoint_dataset, camera_dataset, camera_port: list[int], FPS: int = 24, batch_fps: int = 10, ) -> Generator[list[Detection], Any, None]: gen_data = [ preprocess_keypoint_dataset( keypoint_dataset[port], from_camera_params(camera_dataset[camera_dataset["port"] == port][0]), FPS, datetime(2024, 4, 2, 12, 0, 0), ) for port in camera_port ] sync_gen: Generator[list[Detection], Any, None] = sync_batch_gen( gen_data, timedelta(seconds=1 / batch_fps), ) return sync_gen @jaxtyped(typechecker=beartype) def triangulate_one_point_from_multiple_views_linear( proj_matrices: Float[Array, "N 3 4"], points: Num[Array, "N 2"], confidences: Optional[Float[Array, "N"]] = None, conf_threshold: float = 0.4, # 0.2 ) -> Float[Array, "3"]: """ Args: proj_matrices: 形状为(N, 3, 4)的投影矩阵序列 points: 形状为(N, 2)的点坐标序列 confidences: 形状为(N,)的置信度序列,范围[0.0, 1.0] conf_threshold: 置信度阈值,低于该值的观测不参与DLT Returns: point_3d: 形状为(3,)的三角测量得到的3D点 """ assert len(proj_matrices) == len(points) N = len(proj_matrices) # 置信度加权DLT if confidences is None: weights = jnp.ones(N, dtype=jnp.float32) else: valid_mask = confidences >= conf_threshold weights = jnp.where(valid_mask, confidences, 0.0) sum_weights = jnp.sum(weights) weights = jnp.where(sum_weights > 0, weights / sum_weights, weights) A = jnp.zeros((N * 2, 4), dtype=jnp.float32) for i in range(N): x, y = points[i] row1 = proj_matrices[i, 2] * x - proj_matrices[i, 0] row2 = proj_matrices[i, 2] * y - proj_matrices[i, 1] A = A.at[2 * i].set(row1 * weights[i]) A = A.at[2 * i + 1].set(row2 * weights[i]) _, _, vh = jnp.linalg.svd(A, full_matrices=False) point_3d_homo = vh[-1] point_3d_homo = jnp.where(point_3d_homo[3] < 0, -point_3d_homo, point_3d_homo) is_zero_weight = jnp.sum(weights) == 0 point_3d = jnp.where( is_zero_weight, jnp.full((3,), jnp.nan, dtype=jnp.float32), jnp.where( jnp.abs(point_3d_homo[3]) > 1e-8, point_3d_homo[:3] / point_3d_homo[3], jnp.full((3,), jnp.nan, dtype=jnp.float32), ), ) return point_3d jaxtyped(typechecker=beartype) def triangulate_points_from_multiple_views_linear( proj_matrices: Float[Array, "N 3 4"], points: Num[Array, "N P 2"], confidences: Optional[Float[Array, "N P"]] = None, ) -> Float[Array, "P 3"]: """ Batch‐triangulate P points observed by N cameras, linearly via SVD. Args: proj_matrices: (N, 3, 4) projection matrices points: (N, P, 2) image-coordinates per view confidences: (N, P, 1) optional per-view confidences in [0,1] Returns: (P, 3) 3D point for each of the P tracks """ N, P, _ = points.shape assert proj_matrices.shape[0] == N if confidences is None: conf = jnp.ones((N, P), dtype=jnp.float32) else: conf = jnp.array(confidences) vmap_triangulate = jax.vmap( triangulate_one_point_from_multiple_views_linear, in_axes=(None, 1, 1), out_axes=0, ) return vmap_triangulate(proj_matrices, points, conf) def clusters_to_detections( clusters: list[list[int]], sorted_detections: list[Detection] ) -> list[list[Detection]]: """ given a list of clusters (which is the indices of the detections in the sorted_detections list), extract the detections from the sorted_detections list Args: clusters: list of clusters, each cluster is a list of indices of the detections in the `sorted_detections` list sorted_detections: list of SORTED detections Returns: list of clusters, each cluster is a list of detections """ return [[sorted_detections[i] for i in cluster] for cluster in clusters] # def triangle_from_cluster(cluster: list[Detection]) -> Float[Array, "3"]: # proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster]) # points = jnp.array([el.keypoints for el in cluster]) # confidences = jnp.array([el.confidences for el in cluster]) # return triangulate_points_from_multiple_views_linear( # proj_matrices, points, confidences=confidences # ) def group_by_cluster_by_camera( cluster: Sequence[Detection], ) -> PMap[CameraID, Detection]: """ group the detections by camera, and preserve the latest detection for each camera """ r: dict[CameraID, Detection] = {} for el in cluster: if el.camera.id in r: eld = r[el.camera.id] preserved = max([eld, el], key=lambda x: x.timestamp) r[el.camera.id] = preserved return pmap(r) @jaxtyped(typechecker=beartype) def triangle_from_cluster( cluster: Sequence[Detection], ) -> tuple[Float[Array, "N 3"], datetime]: proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster]) points = jnp.array([el.keypoints_undistorted for el in cluster]) confidences = jnp.array([el.confidences for el in cluster]) latest_timestamp = max(el.timestamp for el in cluster) return ( triangulate_points_from_multiple_views_linear( proj_matrices, points, confidences=confidences ), latest_timestamp, ) class GlobalTrackingState: _last_id: int _trackings: dict[int, Tracking] def __init__(self): self._last_id = 0 self._trackings = {} def __repr__(self) -> str: return ( f"GlobalTrackingState(last_id={self._last_id}, trackings={self._trackings})" ) @property def trackings(self) -> dict[int, Tracking]: return shallow_copy(self._trackings) def add_tracking(self, cluster: Sequence[Detection]) -> Tracking: if len(cluster) < 2: raise ValueError( "cluster must contain at least 2 detections to form a tracking" ) kps_3d, latest_timestamp = triangle_from_cluster(cluster) next_id = self._last_id + 1 tracking_state = TrackingState( keypoints=kps_3d, last_active_timestamp=latest_timestamp, historical_detections_by_camera=group_by_cluster_by_camera(cluster), ) tracking = Tracking( id=next_id, state=tracking_state, velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp), ) self._trackings[next_id] = tracking self._last_id = next_id return tracking @jaxtyped(typechecker=beartype) def triangulate_one_point_from_multiple_views_linear_time_weighted( proj_matrices: Float[Array, "N 3 4"], points: Num[Array, "N 2"], delta_t: Num[Array, "N"], lambda_t: float = 10.0, confidences: Optional[Float[Array, "N"]] = None, ) -> Float[Array, "3"]: """ Triangulate one point from multiple views with time-weighted linear least squares. Implements the incremental reconstruction method from "Cross-View Tracking for Multi-Human 3D Pose" with weighting formula: w_i = exp(-λ_t(t-t_i)) / ||c^i^T||_2 Args: proj_matrices: Shape (N, 3, 4) projection matrices sequence points: Shape (N, 2) point coordinates sequence delta_t: Time differences between current time and each observation (in seconds) lambda_t: Time penalty rate (higher values decrease influence of older observations) confidences: Shape (N,) confidence values in range [0.0, 1.0] Returns: point_3d: Shape (3,) triangulated 3D point """ assert len(proj_matrices) == len(points) assert len(delta_t) == len(points) N = len(proj_matrices) # Prepare confidence weights confi: Float[Array, "N"] if confidences is None: confi = jnp.ones(N, dtype=np.float32) else: confi = jnp.sqrt(jnp.clip(confidences, 0, 1)) time_weights = jnp.exp(-lambda_t * delta_t) weights = time_weights * confi sum_weights = jnp.sum(weights) weights = jnp.where(sum_weights > 0, weights / sum_weights, weights) A = jnp.zeros((N * 2, 4), dtype=np.float32) for i in range(N): x, y = points[i] row1 = proj_matrices[i, 2] * x - proj_matrices[i, 0] row2 = proj_matrices[i, 2] * y - proj_matrices[i, 1] A = A.at[2 * i].set(row1 * weights[i]) A = A.at[2 * i + 1].set(row2 * weights[i]) _, _, vh = jnp.linalg.svd(A, full_matrices=False) point_3d_homo = vh[-1] point_3d_homo = jnp.where(point_3d_homo[3] < 0, -point_3d_homo, point_3d_homo) is_zero_weight = jnp.sum(weights) == 0 point_3d = jnp.where( is_zero_weight, jnp.full((3,), jnp.nan, dtype=jnp.float32), jnp.where( jnp.abs(point_3d_homo[3]) > 1e-8, point_3d_homo[:3] / point_3d_homo[3], jnp.full((3,), jnp.nan, dtype=jnp.float32), ), ) return point_3d @jaxtyped(typechecker=beartype) def triangulate_points_from_multiple_views_linear_time_weighted( proj_matrices: Float[Array, "N 3 4"], points: Num[Array, "N P 2"], delta_t: Num[Array, "N"], lambda_t: float = 10.0, confidences: Optional[Float[Array, "N P"]] = None, ) -> Float[Array, "P 3"]: """ Vectorized version that triangulates P points from N camera views with time-weighting. This function uses JAX's vmap to efficiently triangulate multiple points in parallel. Args: proj_matrices: Shape (N, 3, 4) projection matrices for N cameras points: Shape (N, P, 2) 2D points for P keypoints across N cameras delta_t: Shape (N,) time differences between current time and each camera's timestamp (seconds) lambda_t: Time penalty rate (higher values decrease influence of older observations) confidences: Shape (N, P) confidence values for each point in each camera Returns: points_3d: Shape (P, 3) triangulated 3D points """ N, P, _ = points.shape assert ( proj_matrices.shape[0] == N ), "Number of projection matrices must match number of cameras" assert delta_t.shape[0] == N, "Number of time deltas must match number of cameras" if confidences is None: # Create uniform confidences if none provided conf = jnp.ones((N, P), dtype=jnp.float32) else: conf = confidences # Define the vmapped version of the single-point function # We map over the second dimension (P points) of the input arrays vmap_triangulate = jax.vmap( triangulate_one_point_from_multiple_views_linear_time_weighted, in_axes=( None, 1, None, None, 1, ), # proj_matrices and delta_t static, map over points out_axes=0, # Output has first dimension corresponding to points ) # For each point p, extract the 2D coordinates from all cameras and triangulate return vmap_triangulate( proj_matrices, # (N, 3, 4) - static across points points, # (N, P, 2) - map over dim 1 (P) delta_t, # (N,) - static across points lambda_t, # scalar - static conf, # (N, P) - map over dim 1 (P) ) DetectionMap: TypeAlias = PMap[CameraID, Detection] def update_tracking( tracking: Tracking, detections: Sequence[Detection], max_delta_t: timedelta = timedelta(milliseconds=100), lambda_t: float = 10.0, ) -> None: """ update the tracking with a new set of detections Args: tracking: the tracking to update detections: the detections to update the tracking with max_delta_t: the maximum time difference between the last active timestamp and the latest detection lambda_t: the lambda value for the time difference Note: the function would mutate the tracking object """ last_active_timestamp = tracking.state.last_active_timestamp latest_timestamp = max(d.timestamp for d in detections) d = tracking.state.historical_detections_by_camera for detection in detections: d = cast(DetectionMap, d.update({detection.camera.id: detection})) for camera_id, detection in d.items(): if detection.timestamp - latest_timestamp > max_delta_t: d = d.remove(camera_id) new_detections = d new_detections_list = list(new_detections.values()) project_matrices = jnp.stack( [detection.camera.params.projection_matrix for detection in new_detections_list] ) delta_t = jnp.array( [ detection.timestamp.timestamp() - last_active_timestamp.timestamp() for detection in new_detections_list ] ) kps = jnp.stack([detection.keypoints for detection in new_detections_list]) conf = jnp.stack([detection.confidences for detection in new_detections_list]) kps_3d = triangulate_points_from_multiple_views_linear_time_weighted( project_matrices, kps, delta_t, lambda_t, conf ) new_state = TrackingState( keypoints=kps_3d, last_active_timestamp=latest_timestamp, historical_detections_by_camera=new_detections, ) tracking.update(kps_3d, latest_timestamp) tracking.state = new_state def filter_camera_port(detections: list[Detection]): camera_port = set() for detection in detections: camera_port.add(detection.camera.id) return list(camera_port) @beartype def calculate_camera_affinity_matrix_jax( trackings: Sequence[Tracking], camera_detections: Sequence[Detection], w_2d: float, alpha_2d: float, w_3d: float, alpha_3d: float, lambda_a: float, ) -> Float[Array, "T D"]: """ Vectorized implementation to compute an affinity matrix between *trackings* and *detections* coming from **one** camera. Compared with the simple double-for-loop version, this leverages `jax`'s broadcasting + `vmap` facilities and avoids Python loops over every (tracking, detection) pair. The mathematical definition of the affinity is **unchanged**, so the result remains bit-identical to the reference implementation used in the tests. """ # ------------------------------------------------------------------ # Quick validations / early-exit guards # ------------------------------------------------------------------ if len(trackings) == 0 or len(camera_detections) == 0: # Return an empty affinity matrix with appropriate shape. return jnp.zeros((len(trackings), len(camera_detections))) # type: ignore[return-value] cam = next(iter(camera_detections)).camera # Ensure every detection truly belongs to the same camera (guard clause) cam_id = cam.id if any(det.camera.id != cam_id for det in camera_detections): raise ValueError( "All detections passed to `calculate_camera_affinity_matrix` must come from one camera." ) # We will rely on a single `Camera` instance (all detections share it) w_img_, h_img_ = cam.params.image_size w_img, h_img = float(w_img_), float(h_img_) # ------------------------------------------------------------------ # Gather data into ndarray / DeviceArray batches so that we can compute # everything in a single (or a few) fused kernels. # ------------------------------------------------------------------ # === Tracking-side tensors === kps3d_trk: Float[Array, "T J 3"] = jnp.stack( [trk.state.keypoints for trk in trackings] ) # (T, J, 3) J = kps3d_trk.shape[1] # === Detection-side tensors === kps2d_det: Float[Array, "D J 2"] = jnp.stack( [det.keypoints for det in camera_detections] ) # (D, J, 2) # ------------------------------------------------------------------ # Compute Δt matrix – shape (T, D) # ------------------------------------------------------------------ # Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out # sub‑second detail (resolution ≈ 200 ms). Keep them in float64 until # after subtraction so we preserve Δt‑on‑the‑order‑of‑milliseconds. # --- timestamps ---------- t0 = min( chain( (trk.state.last_active_timestamp for trk in trackings), (det.timestamp for det in camera_detections), ) ).timestamp() # common origin (float) ts_trk = jnp.array( [trk.state.last_active_timestamp.timestamp() - t0 for trk in trackings], dtype=jnp.float32, # now small, ms-scale fits in fp32 ) ts_det = jnp.array( [det.timestamp.timestamp() - t0 for det in camera_detections], dtype=jnp.float32, ) # Δt in seconds, fp32 throughout delta_t = ts_det[None, :] - ts_trk[:, None] # (T,D) min_dt_s = float(DELTA_T_MIN.total_seconds()) delta_t = jnp.clip(delta_t, a_min=min_dt_s, a_max=None) # ------------------------------------------------------------------ # ---------- 2D affinity ------------------------------------------- # ------------------------------------------------------------------ # Project each tracking's 3D keypoints onto the image once. # `Camera.project` works per-sample, so we vmap over the first axis. proj_fn = jax.vmap(cam.project, in_axes=0) # maps over the keypoint sets kps2d_trk_proj: Float[Array, "T J 2"] = proj_fn(kps3d_trk) # (T, J, 2) # Normalise keypoints by image size so absolute units do not bias distance norm_trk = kps2d_trk_proj / jnp.array([w_img, h_img]) norm_det = kps2d_det / jnp.array([w_img, h_img]) # L2 distance for every (T, D, J) # reshape for broadcasting: (T,1,J,2) vs (1,D,J,2) diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :] dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1) # Compute per-keypoint 2D affinity delta_t_broadcast = delta_t[:, :, None] # (T, D, 1) affinity_2d = ( w_2d * (1 - dist2d / (alpha_2d * delta_t_broadcast)) * jnp.exp(-lambda_a * delta_t_broadcast) ) # ------------------------------------------------------------------ # ---------- 3D affinity ------------------------------------------- # ------------------------------------------------------------------ # For each detection pre-compute back-projected 3D points lying on z=0 plane. backproj_points_list = [ det.camera.unproject_points_to_z_plane(det.keypoints, z=0.0) for det in camera_detections ] # each (J,3) backproj: Float[Array, "D J 3"] = jnp.stack(backproj_points_list) # (D, J, 3) zero_velocity = jnp.zeros((J, 3)) trk_velocities = jnp.stack( [ trk.velocity if trk.velocity is not None else zero_velocity for trk in trackings ] ) predicted_pose: Float[Array, "T D J 3"] = ( kps3d_trk[:, None, :, :] # (T,1,J,3) + trk_velocities[:, None, :, :] * delta_t[:, :, None, None] # (T,D,1,1) ) # Camera center – shape (3,) -> will broadcast cam_center = cam.params.location # Compute perpendicular distance using vectorized formula # p1 = cam_center (3,) # p2 = backproj (D, J, 3) # P = predicted_pose (T, D, J, 3) # Broadcast plan: v1 = P - p1 → (T, D, J, 3) # v2 = p2[None, ...]-p1 → (1, D, J, 3) # Shapes now line up; no stray singleton axis. p1 = cam_center p2 = backproj P = predicted_pose v1 = P - p1 v2 = p2[None, :, :, :] - p1 # (1, D, J, 3) cross = jnp.cross(v1, v2) # (T, D, J, 3) num = jnp.linalg.norm(cross, axis=-1) # (T, D, J) den = jnp.linalg.norm(v2, axis=-1) # (1, D, J) dist3d: Float[Array, "T D J"] = num / den affinity_3d = ( w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast) ) # ------------------------------------------------------------------ # Combine and reduce across keypoints → (T, D) # ------------------------------------------------------------------ total_affinity: Float[Array, "T D"] = jnp.sum(affinity_2d + affinity_3d, axis=-1) return total_affinity # type: ignore[return-value] @beartype def calculate_affinity_matrix( trackings: Sequence[Tracking], detections: Sequence[Detection] | Mapping[CameraID, list[Detection]], w_2d: float, alpha_2d: float, w_3d: float, alpha_3d: float, lambda_a: float, ) -> dict[CameraID, AffinityResult]: """ Calculate the affinity matrix between a set of trackings and detections. Args: trackings: Sequence of tracking objects detections: Sequence of detection objects or a group detections by ID w_2d: Weight for 2D affinity alpha_2d: Normalization factor for 2D distance w_3d: Weight for 3D affinity alpha_3d: Normalization factor for 3D distance lambda_a: Decay rate for time difference Returns: A dictionary mapping camera IDs to affinity results. """ if isinstance(detections, Mapping): detection_by_camera = detections else: detection_by_camera = classify_by_camera(detections) res: dict[CameraID, AffinityResult] = {} for camera_id, camera_detections in detection_by_camera.items(): affinity_matrix = calculate_camera_affinity_matrix_jax( trackings, camera_detections, w_2d, alpha_2d, w_3d, alpha_3d, lambda_a, ) # row, col """ indices_T:表示匹配到检测的tracking的索引(在tracking列表中的下标) indices_D:表示匹配到tracking的detection的索引(在detections列表中的下标) """ indices_T, indices_D = linear_sum_assignment(affinity_matrix) affinity_result = AffinityResult( matrix=affinity_matrix, trackings=trackings, detections=camera_detections, indices_T=indices_T, indices_D=indices_D, ) res[camera_id] = affinity_result return res # 对每一个3d目标进行滑动窗口平滑处理 def smooth_3d_keypoints( all_3d_kps: dict[str, list], window_size: int = 5 ) -> dict[str, list]: # window_size = 5 kernel = np.ones(window_size) / window_size smoothed_points = dict() for item_object_index in all_3d_kps.keys(): item_object = np.array(all_3d_kps[item_object_index]) if item_object.shape[0] < window_size: # 如果数据点少于窗口大小,则直接返回原始数据 smoothed_points[item_object_index] = item_object.tolist() continue # 对每个关键点的每个坐标轴分别做滑动平均 item_smoothed = np.zeros_like(item_object) # 遍历133个关节 for kp_idx in range(item_object.shape[1]): # 遍历每个关节的空间三维坐标点 for axis in range(3): # 对第i帧的滑动平滑方式 smoothed[i] = (point[i-2] + point[i-1] + point[i] + point[i+1] + point[i+2]) / 5 item_smoothed[:, kp_idx, axis] = np.convolve( item_object[:, kp_idx, axis], kernel, mode="same" ) smoothed_points[item_object_index] = item_smoothed.tolist() return smoothed_points # 相机内外参路径 CAMERA_PATH = Path("/home/admin/Documents/2025_4_17/camera_params") # 所有机位的相机内外参 AK_CAMERA_DATASET: ak.Array = get_camera_params(CAMERA_PATH) # 2d检测数据路径 DATASET_PATH = Path("/home/admin/Documents/2025_4_17/detect_result/many_people_01/") # 指定机位的2d检测数据 camera_port = [5607, 5608, 5609] KEYPOINT_DATASET = get_camera_detect(DATASET_PATH, camera_port, AK_CAMERA_DATASET) # 获取片段 # FRAME_INDEX = [i for i in range(20, 140)] # KEYPOINT_DATASET = get_segment(camera_port, FRAME_INDEX, KEYPOINT_DATASET) # 将所有的2d检测数据打包 sync_gen: Generator[list[Detection], Any, None] = get_batch_detect( KEYPOINT_DATASET, AK_CAMERA_DATASET, camera_port, batch_fps=24, ) # 图像宽高 WIDTH = 2560 HEIGHT = 1440 # 跟踪超参数 W_2D = 0.7 ALPHA_2D = 60.0 LAMBDA_A = 5.0 W_3D = 0.3 ALPHA_3D = 0.1 # 3d数据,键为追踪目标id,值为该目标的所有3d数据 all_3d_kps: dict[str, list] = {} # 跟踪目标集合 trackings: list[Tracking] = [] # 建立追踪目标 global_tracking_state = GlobalTrackingState() count = 0 while True: count += 1 try: # 获取下一个时间戳的所有相机检测结果 detections = next(sync_gen) print(detections) except StopIteration: # 检测数据读取完毕,退出主循环 print("No more detections.") break # 计算相似度矩阵 sorted_detections, affinity_matrix = ( calculate_affinity_matrix_by_epipolar_constraint(detections, alpha_2d=3500) ) # 计算集群 solver = GLPKSolver() aff_np = np.asarray(affinity_matrix).astype(np.float64) clusters, sol_matrix = solver.solve(aff_np) print(f"Clusters: {clusters}") # 划分集群 clusters_detections = clusters_to_detections(clusters, sorted_detections) # 获取当前的追踪目标 trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id) # =====初始化跟踪===== # 若当前帧没有跟踪目标,则初始化跟踪目标(假设第一帧就可以对两个目标完成初始化) if len(trackings) == 0: # 遍历每一个计算初始化跟踪目标 for cluster in clusters_detections: if len(cluster) < 2: continue camera_port_this = filter_camera_port(cluster) if len(camera_port_this) < len(camera_port): continue global_tracking_state.add_tracking(cluster) # 保留第一帧的3d姿态数据,按id存储到all_3d_kps字典 for element_tracking in global_tracking_state.trackings.values(): if str(element_tracking.id) not in all_3d_kps.keys(): all_3d_kps[str(element_tracking.id)] = [ element_tracking.state.keypoints.tolist() ] # 跳过本帧后续处理,进入下一帧 continue # =====丢失目标处理===== # =====更新跟踪状态===== # 遍历集群,计算每个集群2d检测数据与跟踪目标的相似度矩阵 for cluster in clusters_detections: if len(cluster) == 0: continue # 计算所有跟踪目标雨检测目标的相似度矩阵 affinities: dict[str, AffinityResult] = calculate_affinity_matrix( trackings, cluster, w_2d=W_2D, alpha_2d=ALPHA_2D, w_3d=W_3D, alpha_3d=ALPHA_3D, lambda_a=LAMBDA_A, ) unmatch_detection = [] # 遍历跟踪目标,更新跟踪目标 for element_tracking in trackings: # 存储每个跟踪目标匹配的2d检测数据 tracking_detections = [] for camera_name in affinities.keys(): indices_T = affinities[camera_name].indices_T.item() indices_D = affinities[camera_name].indices_D.item() match_tracking = affinities[camera_name].trackings[indices_T] if match_tracking == element_tracking: tracking_detections.append( affinities[camera_name].detections[indices_D] ) # 判断2d检测数据数量是否可以更新跟踪目标 if len(tracking_detections) > 2: # 跟新跟踪目标 update_tracking(element_tracking, tracking_detections) # 记录更新后的3d姿态数据 all_3d_kps[str(element_tracking.id)].append( element_tracking.state.keypoints.tolist() ) if count == 4: break # break # 对每一个3d目标进行滑动窗口平滑处理 smoothed_points = smooth_3d_keypoints(all_3d_kps, window_size=5) # 将结果保存到json文件中 with open("samples/Test_YEU.json", "wb") as f: f.write(orjson.dumps(smoothed_points)) """=====代买还没改完,目前存在的问题是时间戳不同步,同一组中存在多帧的数据====="""