# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.17.0
#   kernelspec:
#     display_name: .venv
#     language: python
#     name: python3
# ---

# %%
from collections import OrderedDict
from copy import copy as shallow_copy
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import (
    Any,
    Generator,
    Optional,
    Sequence,
    TypeAlias,
    TypedDict,
    cast,
    overload,
)

import awkward as ak
import jax
import jax.numpy as jnp
import numpy as np
import orjson
from beartype import beartype
from cv2 import undistortPoints
from IPython.display import display
from jaxtyping import Array, Float, Num, jaxtyped
from matplotlib import pyplot as plt
from numpy.typing import ArrayLike
from scipy.spatial.transform import Rotation as R

from app.camera import (
    Camera,
    CameraID,
    CameraParams,
    Detection,
    calculate_affinity_matrix_by_epipolar_constraint,
    classify_by_camera,
)
from app.solver._old import GLPKSolver
from app.visualize.whole_body import visualize_whole_body

NDArray: TypeAlias = np.ndarray

# %%
DATASET_PATH = Path("samples") / "04_02"
AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet")
display(AK_CAMERA_DATASET)


# %%
class Resolution(TypedDict):
    width: int
    height: int


class Intrinsic(TypedDict):
    camera_matrix: Num[Array, "3 3"]
    """
    K
    """
    distortion_coefficients: Num[Array, "N"]
    """
    distortion coefficients; usually 5
    """


class Extrinsic(TypedDict):
    rvec: Num[NDArray, "3"]
    tvec: Num[NDArray, "3"]


class ExternalCameraParams(TypedDict):
    name: str
    port: int
    intrinsic: Intrinsic
    extrinsic: Extrinsic
    resolution: Resolution


# %%
def read_dataset_by_port(port: int) -> ak.Array:
    P = DATASET_PATH / f"{port}.parquet"
    return ak.from_parquet(P)


KEYPOINT_DATASET = {
    int(p): read_dataset_by_port(p) for p in ak.to_numpy(AK_CAMERA_DATASET["port"])
}


# %%
class KeypointDataset(TypedDict):
    frame_index: int
    boxes: Num[NDArray, "N 4"]
    kps: Num[NDArray, "N J 2"]
    kps_scores: Num[NDArray, "N J"]


@jaxtyped(typechecker=beartype)
def to_transformation_matrix(
    rvec: Num[NDArray, "3"], tvec: Num[NDArray, "3"]
) -> Num[NDArray, "4 4"]:
    res = np.eye(4)
    res[:3, :3] = R.from_rotvec(rvec).as_matrix()
    res[:3, 3] = tvec
    return res


@jaxtyped(typechecker=beartype)
def undistort_points(
    points: Num[NDArray, "M 2"],
    camera_matrix: Num[NDArray, "3 3"],
    dist_coeffs: Num[NDArray, "N"],
) -> Num[NDArray, "M 2"]:
    K = camera_matrix
    dist = dist_coeffs
    res = undistortPoints(points, K, dist, P=K)  # type: ignore
    return res.reshape(-1, 2)


def from_camera_params(camera: ExternalCameraParams) -> Camera:
    rt = jnp.array(
        to_transformation_matrix(
            ak.to_numpy(camera["extrinsic"]["rvec"]),
            ak.to_numpy(camera["extrinsic"]["tvec"]),
        )
    )
    K = jnp.array(camera["intrinsic"]["camera_matrix"]).reshape(3, 3)
    dist_coeffs = jnp.array(camera["intrinsic"]["distortion_coefficients"])
    image_size = jnp.array(
        (camera["resolution"]["width"], camera["resolution"]["height"])
    )
    return Camera(
        id=camera["name"],
        params=CameraParams(
            K=K,
            Rt=rt,
            dist_coeffs=dist_coeffs,
            image_size=image_size,
        ),
    )


def preprocess_keypoint_dataset(
    dataset: Sequence[KeypointDataset],
    camera: Camera,
    fps: float,
    start_timestamp: datetime,
) -> Generator[Detection, None, None]:
    frame_interval_s = 1 / fps
    for el in dataset:
        frame_index = el["frame_index"]
        timestamp = start_timestamp + timedelta(seconds=frame_index * frame_interval_s)
        for kp, kp_score in zip(el["kps"], el["kps_scores"]):
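            # each detected person in this frame becomes one Detection;
            # all instances share the synthesized frame timestamp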
            yield Detection(
                keypoints=jnp.array(kp),
                confidences=jnp.array(kp_score),
                camera=camera,
                timestamp=timestamp,
            )


# %%
DetectionGenerator: TypeAlias = Generator[Detection, None, None]


def sync_batch_gen(gens: Sequence[DetectionGenerator], diff: timedelta):
    """
    Given a list of detection generators, return a generator that yields batches of
    detections whose timestamps fall within `diff` of each other.

    Args:
        gens: list of detection generators
        diff: maximum timestamp difference between detections to consider them part of the same batch
    """
    N = len(gens)
    last_batch_timestamp: Optional[datetime] = None
    next_batch_timestamp: Optional[datetime] = None
    current_batch: list[Detection] = []
    next_batch: list[Detection] = []
    paused: list[bool] = [False] * N
    finished: list[bool] = [False] * N

    def reset_paused():
        """
        reset paused list based on finished list
        """
        for i in range(N):
            if not finished[i]:
                paused[i] = False
            else:
                paused[i] = True

    EPS = 1e-6  # a small epsilon to avoid floating point precision issues
    diff_eps = diff - timedelta(seconds=EPS)
    while True:
        for i, gen in enumerate(gens):
            try:
                if finished[i] or paused[i]:
                    continue
                val = next(gen)
                if last_batch_timestamp is None:
                    last_batch_timestamp = val.timestamp
                    current_batch.append(val)
                else:
                    if abs(val.timestamp - last_batch_timestamp) >= diff_eps:
                        next_batch.append(val)
                        if next_batch_timestamp is None:
                            next_batch_timestamp = val.timestamp
                        paused[i] = True
                        if all(paused):
                            yield current_batch
                            current_batch = next_batch
                            next_batch = []
                            last_batch_timestamp = next_batch_timestamp
                            next_batch_timestamp = None
                            reset_paused()
                    else:
                        current_batch.append(val)
            except StopIteration:
                finished[i] = True
                paused[i] = True
        if all(finished):
            if len(current_batch) > 0:
                # All generators exhausted, flush remaining batch and exit
                yield current_batch
            break


# %%
@overload
def to_projection_matrix(
    transformation_matrix: Num[NDArray, "4 4"], camera_matrix: Num[NDArray, "3 3"]
) -> Num[NDArray, "3 4"]: ...


@overload
def to_projection_matrix(
    transformation_matrix: Num[Array, "4 4"], camera_matrix: Num[Array, "3 3"]
) -> Num[Array, "3 4"]: ...


@jaxtyped(typechecker=beartype)
def to_projection_matrix(
    transformation_matrix: Num[Any, "4 4"],
    camera_matrix: Num[Any, "3 3"],
) -> Num[Any, "3 4"]:
    return camera_matrix @ transformation_matrix[:3, :]


to_projection_matrix_jit = jax.jit(to_projection_matrix)


@jaxtyped(typechecker=beartype)
def dlt(
    H1: Num[NDArray, "3 4"],
    H2: Num[NDArray, "3 4"],
    p1: Num[NDArray, "2"],
    p2: Num[NDArray, "2"],
) -> Num[NDArray, "3"]:
    """
    Direct Linear Transformation
    """
    A = [
        p1[1] * H1[2, :] - H1[1, :],
        H1[0, :] - p1[0] * H1[2, :],
        p2[1] * H2[2, :] - H2[1, :],
        H2[0, :] - p2[0] * H2[2, :],
    ]
    A = np.array(A).reshape((4, 4))
    B = A.transpose() @ A
    from scipy import linalg

    U, s, Vh = linalg.svd(B, full_matrices=False)
    return Vh[3, 0:3] / Vh[3, 3]


@overload
def homogeneous_to_euclidean(points: Num[NDArray, "N 4"]) -> Num[NDArray, "N 3"]: ...


@overload
def homogeneous_to_euclidean(points: Num[Array, "N 4"]) -> Num[Array, "N 3"]: ...
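
# shared implementation for both overloads (accepts NumPy or JAX arrays)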
@jaxtyped(typechecker=beartype)
def homogeneous_to_euclidean(
    points: Num[Any, "N 4"],
) -> Num[Any, "N 3"]:
    """
    Convert homogeneous coordinates to Euclidean coordinates.

    Args:
        points: homogeneous coordinates (x, y, z, w) in numpy array or jax array

    Returns:
        euclidean coordinates (x, y, z) in numpy array or jax array
    """
    return points[..., :-1] / points[..., -1:]


# %%
FPS = 24

image_gen_5600 = preprocess_keypoint_dataset(
    KEYPOINT_DATASET[5600],
    from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5600][0]),  # type: ignore
    FPS,
    datetime(2024, 4, 2, 12, 0, 0),
)
image_gen_5601 = preprocess_keypoint_dataset(
    KEYPOINT_DATASET[5601],
    from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5601][0]),  # type: ignore
    FPS,
    datetime(2024, 4, 2, 12, 0, 0),
)
image_gen_5602 = preprocess_keypoint_dataset(
    KEYPOINT_DATASET[5602],
    from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5602][0]),  # type: ignore
    FPS,
    datetime(2024, 4, 2, 12, 0, 0),
)
display(1 / FPS)
sync_gen = sync_batch_gen(
    [image_gen_5600, image_gen_5601, image_gen_5602], timedelta(seconds=1 / FPS)
)

# %%
sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint(
    next(sync_gen), alpha_2d=2000
)
display(sorted_detections)

# %%
display(
    list(
        map(
            lambda x: {"timestamp": str(x.timestamp), "camera": x.camera.id},
            sorted_detections,
        )
    )
)
with jnp.printoptions(precision=3, suppress=True):
    display(affinity_matrix)


# %%
def clusters_to_detections(
    clusters: Sequence[Sequence[int]], sorted_detections: Sequence[Detection]
) -> list[list[Detection]]:
    """
    Given a list of clusters (each cluster being a list of indices into
    `sorted_detections`), extract the corresponding detections.

    Args:
        clusters: list of clusters, each cluster is a list of indices of the detections
            in the `sorted_detections` list
        sorted_detections: list of SORTED detections

    Returns:
        list of clusters, each cluster is a list of detections
    """
    return [[sorted_detections[i] for i in cluster] for cluster in clusters]


solver = GLPKSolver()
aff_np = np.asarray(affinity_matrix).astype(np.float64)
clusters, sol_matrix = solver.solve(aff_np)
display(clusters)
display(sol_matrix)

# %%
WIDTH = 2560
HEIGHT = 1440

clusters_detections = clusters_to_detections(clusters, sorted_detections)
im = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
for el in clusters_detections[0]:
    im = visualize_whole_body(np.asarray(el.keypoints), im)
p = plt.imshow(im)
display(p)

# %%
im_prime = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
for el in clusters_detections[1]:
    im_prime = visualize_whole_body(np.asarray(el.keypoints), im_prime)
p_prime = plt.imshow(im_prime)
display(p_prime)


# %%
@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N 2"],
    confidences: Optional[Float[Array, "N"]] = None,
) -> Float[Array, "3"]:
    """
    Args:
        proj_matrices: sequence of projection matrices with shape (N, 3, 4)
        points: sequence of 2D point coordinates with shape (N, 2)
        confidences: sequence of confidences with shape (N,), in the range [0.0, 1.0]

    Returns:
        point_3d: the triangulated 3D point with shape (3,)
    """
    assert len(proj_matrices) == len(points)
    N = len(proj_matrices)
    confi: Float[Array, "N"]
    if confidences is None:
        confi = jnp.ones(N, dtype=jnp.float32)
    else:
        # Use square root of confidences for weighting - more balanced approach
        confi = jnp.sqrt(jnp.clip(confidences, 0, 1))
    A = jnp.zeros((N * 2, 4), dtype=jnp.float32)
    for i in range(N):
        x, y = points[i]
        A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])
        A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])
        A = A.at[2 * i].mul(confi[i])
        A = A.at[2 * i + 1].mul(confi[i])
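
    # A stacks two confidence-weighted DLT rows per view; the right singular vector
    # associated with the smallest singular value minimizes ||A @ x|| over homogeneous x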
    # https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.svd.html
    _, _, vh = jnp.linalg.svd(A, full_matrices=False)
    point_3d_homo = vh[-1]  # shape (4,)

    # replace the Python `if` with a jnp.where
    point_3d_homo = jnp.where(
        point_3d_homo[3] < 0,  # predicate (scalar bool tracer)
        -point_3d_homo,  # if True
        point_3d_homo,  # if False
    )
    point_3d = point_3d_homo[:3] / point_3d_homo[3]
    return point_3d


@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N P 2"],
    confidences: Optional[Float[Array, "N P"]] = None,
) -> Float[Array, "P 3"]:
    """
    Batch-triangulate P points observed by N cameras, linearly via SVD.

    Args:
        proj_matrices: (N, 3, 4) projection matrices
        points: (N, P, 2) image coordinates per view
        confidences: (N, P) optional per-view confidences in [0, 1]

    Returns:
        (P, 3) 3D point for each of the P tracks
    """
    N, P, _ = points.shape
    assert proj_matrices.shape[0] == N
    if confidences is None:
        conf = jnp.ones((N, P), dtype=jnp.float32)
    else:
        # the one-point routine already applies sqrt-of-confidence weighting,
        # so pass the (clipped) confidences through unchanged here
        conf = jnp.clip(confidences, 0.0, 1.0)

    # vectorize the one-point routine over P
    vmap_triangulate = jax.vmap(
        triangulate_one_point_from_multiple_views_linear,
        in_axes=(None, 1, 1),  # proj_matrices static, map over points[:, p, :], conf[:, p]
        out_axes=0,
    )
    return vmap_triangulate(proj_matrices, points, conf)


# %%
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
    id: int
    """
    The tracking id
    """
    keypoints: Float[Array, "J 3"]
    """
    The 3D keypoints of the tracking
    """
    last_active_timestamp: datetime
    velocity: Optional[Float[Array, "3"]] = None
    """
    Could be `None`, e.g. when the 3D pose has just been initialized.

    `velocity` should be updated when target association yields a new 3D pose.
    """

    def __repr__(self) -> str:
        return f"Tracking({self.id}, {self.last_active_timestamp})"


@jaxtyped(typechecker=beartype)
def triangle_from_cluster(
    cluster: Sequence[Detection],
) -> tuple[Float[Array, "N 3"], datetime]:
    proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])
    points = jnp.array([el.keypoints_undistorted for el in cluster])
    confidences = jnp.array([el.confidences for el in cluster])
    latest_timestamp = max(el.timestamp for el in cluster)
    return (
        triangulate_points_from_multiple_views_linear(
            proj_matrices, points, confidences=confidences
        ),
        latest_timestamp,
    )


# %%
class GlobalTrackingState:
    _last_id: int
    _trackings: dict[int, Tracking]

    def __init__(self):
        self._last_id = 0
        self._trackings = {}

    def __repr__(self) -> str:
        return (
            f"GlobalTrackingState(last_id={self._last_id}, trackings={self._trackings})"
        )

    @property
    def trackings(self) -> dict[int, Tracking]:
        return shallow_copy(self._trackings)

    def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
        kps_3d, latest_timestamp = triangle_from_cluster(cluster)
        next_id = self._last_id + 1
        tracking = Tracking(
            id=next_id, keypoints=kps_3d, last_active_timestamp=latest_timestamp
        )
        self._trackings[next_id] = tracking
        self._last_id = next_id
        return tracking


global_tracking_state = GlobalTrackingState()
for cluster in clusters_detections:
    global_tracking_state.add_tracking(cluster)
display(global_tracking_state)

# %%
next_group = next(sync_gen)
display(next_group)


# %%
@jaxtyped(typechecker=beartype)
def calculate_distance_2d(
    left: Num[Array, "J 2"],
    right: Num[Array, "J 2"],
    image_size: tuple[int, int] = (1, 1),
) -> Float[Array, "J"]:
    """
    Calculate the *normalized* distance between two sets of keypoints.
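
    Keypoints are divided by the image size so distances are comparable across
    cameras with different resolutions; with the default `(1, 1)` the inputs are
    assumed to be normalized already.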

    Args:
        left: The left keypoints
        right: The right keypoints
        image_size: The size of the image

    Returns:
        Array of normalized Euclidean distances between corresponding keypoints
    """
    w, h = image_size
    if w == 1 and h == 1:
        # already normalized
        left_normalized = left
        right_normalized = right
    else:
        left_normalized = left / jnp.array([w, h])
        right_normalized = right / jnp.array([w, h])
    return jnp.linalg.norm(left_normalized - right_normalized, axis=-1)


@jaxtyped(typechecker=beartype)
def calculate_affinity_2d(
    distance_2d: Float[Array, "J"],
    delta_t: timedelta,
    w_2d: float,
    alpha_2d: float,
    lambda_a: float,
) -> Float[Array, "J"]:
    """
    Calculate the affinity between two detections based on the distances between
    their keypoints.

    The affinity score is calculated by summing individual keypoint affinities:

        A_2D = sum(w_2D * (1 - distance_2D / (alpha_2D * delta_t)) * np.exp(-lambda_a * delta_t))

    for each keypoint.

    Args:
        distance_2d: The normalized distances between keypoints (one value per keypoint)
        delta_t: The time delta between the two detections
        w_2d: The weight for 2D affinity
        alpha_2d: The normalization factor for distance
        lambda_a: The decay rate for time difference

    Returns:
        Per-keypoint affinity scores (sum them to get the total affinity)
    """
    delta_t_s = delta_t.total_seconds()
    affinity_per_keypoint = (
        w_2d
        * (1 - distance_2d / (alpha_2d * delta_t_s))
        * jnp.exp(-lambda_a * delta_t_s)
    )
    return affinity_per_keypoint


@jaxtyped(typechecker=beartype)
def perpendicular_distance_point_to_line_two_points(
    point: Num[Array, "3"], line: tuple[Num[Array, "3"], Num[Array, "3"]]
) -> Float[Array, ""]:
    """
    Calculate the perpendicular distance between a point and a line,
    where `line` is represented by two points: `(line_start, line_end)`.
    """
    line_start, line_end = line
    distance = jnp.linalg.norm(
        jnp.cross(line_end - line_start, line_start - point)
    ) / jnp.linalg.norm(line_end - line_start)
    return distance


@jaxtyped(typechecker=beartype)
def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
    detection: Detection,
    tracking: Tracking,
    delta_t: timedelta,
) -> Float[Array, "J"]:
    """
    Calculate the perpendicular distances between predicted 3D tracking points and
    the rays cast from the camera center through the 2D image points.
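
    Each 2D keypoint defines a ray from the camera center through its
    back-projection (here, onto the z = 0 plane); the distance from the
    velocity-predicted 3D keypoint to that ray measures how geometrically
    consistent the detection is with the tracking.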

    Args:
        detection: The detection object containing 2D keypoints and camera parameters
        tracking: The tracking object containing 3D keypoints
        delta_t: Time delta between the tracking's last update and the current observation

    Returns:
        Array of perpendicular distances for each keypoint
    """
    camera = detection.camera

    # Convert timedelta to seconds for prediction
    delta_t_s = delta_t.total_seconds()

    # Predict the 3D pose based on tracking and delta_t
    predicted_pose = predict_pose_3d(tracking, delta_t_s)

    # Back-project the 2D points to 3D space (assuming z=0 plane)
    back_projected_points = detection.camera.unproject_points_to_z_plane(
        detection.keypoints, z=0.0
    )

    # Get camera center from the camera parameters
    camera_center = camera.params.location

    # Define function to calculate distance between a predicted point and its corresponding ray
    def calc_distance(predicted_point, back_projected_point):
        return perpendicular_distance_point_to_line_two_points(
            predicted_point, (camera_center, back_projected_point)
        )

    # Vectorize over all keypoints
    vmap_calc_distance = jax.vmap(calc_distance)

    # Calculate and return distances for all keypoints
    return vmap_calc_distance(predicted_pose, back_projected_points)


@jaxtyped(typechecker=beartype)
def calculate_affinity_3d(
    distances: Float[Array, "J"],
    delta_t: timedelta,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> Float[Array, "J"]:
    """
    Calculate the 3D affinity score between a tracking and a detection.

    The affinity score is calculated by summing individual keypoint affinities:

        A_3D = sum(w_3D * (1 - dl / alpha_3D) * np.exp(-lambda_a * delta_t))

    for each keypoint.

    Args:
        distances: Array of perpendicular distances for each keypoint
        delta_t: Time difference between tracking and detection
        w_3d: Weight for 3D affinity
        alpha_3d: Normalization factor for distance
        lambda_a: Decay rate for time difference

    Returns:
        Per-keypoint affinity scores (sum them to get the total affinity)
    """
    delta_t_s = delta_t.total_seconds()
    affinity_per_keypoint = (
        w_3d * (1 - distances / alpha_3d) * jnp.exp(-lambda_a * delta_t_s)
    )
    return affinity_per_keypoint


def predict_pose_3d(
    tracking: Tracking,
    delta_t_s: float,
) -> Float[Array, "J 3"]:
    """
    Predict the 3D pose of a tracking based on its velocity.
    """
    if tracking.velocity is None:
        return tracking.keypoints
    return tracking.keypoints + tracking.velocity * delta_t_s


@beartype
def calculate_tracking_detection_affinity(
    tracking: Tracking,
    detection: Detection,
    w_2d: float,
    alpha_2d: float,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> float:
    """
    Calculate the affinity between a tracking and a detection.
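
    The score is the sum over keypoints of the 2D re-projection affinity and the
    3D ray-casting affinity, both decayed by the time gap between the tracking's
    last update and the detection timestamp.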

    Args:
        tracking: The tracking object
        detection: The detection object
        w_2d: Weight for 2D affinity
        alpha_2d: Normalization factor for 2D distance
        w_3d: Weight for 3D affinity
        alpha_3d: Normalization factor for 3D distance
        lambda_a: Decay rate for time difference

    Returns:
        Combined affinity score
    """
    camera = detection.camera
    delta_t = detection.timestamp - tracking.last_active_timestamp

    # Calculate 2D affinity
    tracking_2d_projection = camera.project(tracking.keypoints)
    w, h = camera.params.image_size
    distance_2d = calculate_distance_2d(
        tracking_2d_projection,
        detection.keypoints,
        image_size=(int(w), int(h)),
    )
    affinity_2d = calculate_affinity_2d(
        distance_2d,
        delta_t,
        w_2d=w_2d,
        alpha_2d=alpha_2d,
        lambda_a=lambda_a,
    )

    # Calculate 3D affinity
    distances = perpendicular_distance_camera_2d_points_to_tracking_raycasting(
        detection, tracking, delta_t
    )
    affinity_3d = calculate_affinity_3d(
        distances,
        delta_t,
        w_3d=w_3d,
        alpha_3d=alpha_3d,
        lambda_a=lambda_a,
    )

    # Combine affinities
    total_affinity = affinity_2d + affinity_3d
    return jnp.sum(total_affinity).item()


@beartype
def calculate_affinity_matrix(
    trackings: Sequence[Tracking],
    detections: Sequence[Detection],
    w_2d: float,
    alpha_2d: float,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> tuple[Float[Array, "T D"], OrderedDict[CameraID, list[Detection]]]:
    """
    Calculate the affinity matrix between a set of trackings and detections.

    Args:
        trackings: Sequence of tracking objects
        detections: Sequence of detection objects
        w_2d: Weight for 2D affinity
        alpha_2d: Normalization factor for 2D distance
        w_3d: Weight for 3D affinity
        alpha_3d: Normalization factor for 3D distance
        lambda_a: Decay rate for time difference

    Returns:
        - affinity matrix of shape (T, D), where T is the number of trackings and D is
          the number of detections
        - dictionary mapping camera IDs to lists of detections from that camera; since
          it is an `OrderedDict`, you can flatten it to recover the column indices of
          the detections in the affinity matrix

    Matrix Layout:
        The affinity matrix has shape (T, D), where:
        - T = number of trackings (rows)
        - D = total number of detections across all cameras (columns)

        The matrix is organized as follows:

        ```
                 |  Camera 1   |  Camera 2   |  Camera c   |
                 | d1  d2  ... | d1  d2  ... | d1  d2  ... |
        ---------+-------------+-------------+-------------+
        Track 1  | a11 a12 ... | a11 a12 ... | a11 a12 ... |
        Track 2  | a21 a22 ... | a21 a22 ... | a21 a22 ... |
        ...      | ...         | ...         | ...         |
        Track t  | at1 at2 ... | at1 at2 ... | at1 at2 ... |
        ```

        Where:
        - Rows are ordered by tracking ID
        - Columns are ordered by camera, then by detection within each camera
        - Each cell aij represents the affinity between tracking i and detection j

        The detection ordering in columns follows the exact same order as iterating
        through the `detection_by_camera` dictionary, which is returned alongside the
        matrix to maintain this relationship.
    """
    affinity = jnp.zeros((len(trackings), len(detections)))
    detection_by_camera = classify_by_camera(detections)
    for i, tracking in enumerate(trackings):
        j = 0
        for camera_detections in detection_by_camera.values():
            for det in camera_detections:
                affinity_value = calculate_tracking_detection_affinity(
                    tracking,
                    det,
                    w_2d=w_2d,
                    alpha_2d=alpha_2d,
                    w_3d=w_3d,
                    alpha_3d=alpha_3d,
                    lambda_a=lambda_a,
                )
                affinity = affinity.at[i, j].set(affinity_value)
                j += 1
    return affinity, detection_by_camera


# %%
# let's do cross-view association
W_2D = 1.0
ALPHA_2D = 1.0
LAMBDA_A = 0.1
W_3D = 1.0
ALPHA_3D = 1.0

trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
unmatched_detections = shallow_copy(next_group)
affinity, detection_by_camera = calculate_affinity_matrix(
    trackings,
    unmatched_detections,
    w_2d=W_2D,
    alpha_2d=ALPHA_2D,
    w_3d=W_3D,
    alpha_3d=ALPHA_3D,
    lambda_a=LAMBDA_A,
)
display(affinity)