# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.17.0
#   kernelspec:
#     display_name: .venv
#     language: python
#     name: python3
# ---

# %%
from copy import copy as shallow_copy
from copy import deepcopy
from copy import deepcopy as deep_copy
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import (
    Any,
    Generator,
    Optional,
    Sequence,
    TypeAlias,
    TypedDict,
    cast,
    overload,
)

import awkward as ak
import jax
import jax.numpy as jnp
import numpy as np
import orjson
from beartype import beartype
from cv2 import undistortPoints
from IPython.display import display
from jaxtyping import Array, Float, Num, jaxtyped
from matplotlib import pyplot as plt
from numpy.typing import ArrayLike
from scipy import linalg
from scipy.spatial.transform import Rotation as R

from app.camera import (
    Camera,
    CameraParams,
    Detection,
    calculate_affinity_matrix_by_epipolar_constraint,
    classify_by_camera,
)
from app.solver._old import GLPKSolver
from app.visualize.whole_body import visualize_whole_body

NDArray: TypeAlias = np.ndarray

# %%
DATASET_PATH = Path("samples") / "04_02"
AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet")
display(AK_CAMERA_DATASET)


# %%
class Resolution(TypedDict):
    width: int
    height: int


class Intrinsic(TypedDict):
    camera_matrix: Num[Array, "3 3"]
    """K"""
    distortion_coefficients: Num[Array, "N"]
    """distortion coefficients; usually 5"""


class Extrinsic(TypedDict):
    rvec: Num[NDArray, "3"]
    tvec: Num[NDArray, "3"]


class ExternalCameraParams(TypedDict):
    name: str
    port: int
    intrinsic: Intrinsic
    extrinsic: Extrinsic
    resolution: Resolution


# %%
def read_dataset_by_port(port: int) -> ak.Array:
    P = DATASET_PATH / f"{port}.parquet"
    return ak.from_parquet(P)


KEYPOINT_DATASET = {
    int(p): read_dataset_by_port(p) for p in ak.to_numpy(AK_CAMERA_DATASET["port"])
}


# %%
class KeypointDataset(TypedDict):
    frame_index: int
    boxes: Num[NDArray, "N 4"]
    kps: Num[NDArray, "N J 2"]
    kps_scores: Num[NDArray, "N J"]


@jaxtyped(typechecker=beartype)
def to_transformation_matrix(
    rvec: Num[NDArray, "3"], tvec: Num[NDArray, "3"]
) -> Num[NDArray, "4 4"]:
    res = np.eye(4)
    res[:3, :3] = R.from_rotvec(rvec).as_matrix()
    res[:3, 3] = tvec
    return res


@jaxtyped(typechecker=beartype)
def undistort_points(
    points: Num[NDArray, "M 2"],
    camera_matrix: Num[NDArray, "3 3"],
    dist_coeffs: Num[NDArray, "N"],
) -> Num[NDArray, "M 2"]:
    K = camera_matrix
    dist = dist_coeffs
    res = undistortPoints(points, K, dist, P=K)  # type: ignore
    return res.reshape(-1, 2)


def from_camera_params(camera: ExternalCameraParams) -> Camera:
    rt = jnp.array(
        to_transformation_matrix(
            ak.to_numpy(camera["extrinsic"]["rvec"]),
            ak.to_numpy(camera["extrinsic"]["tvec"]),
        )
    )
    K = jnp.array(camera["intrinsic"]["camera_matrix"]).reshape(3, 3)
    dist_coeffs = jnp.array(camera["intrinsic"]["distortion_coefficients"])
    image_size = jnp.array(
        (camera["resolution"]["width"], camera["resolution"]["height"])
    )
    return Camera(
        id=camera["name"],
        params=CameraParams(
            K=K,
            Rt=rt,
            dist_coeffs=dist_coeffs,
            image_size=image_size,
        ),
    )


def preprocess_keypoint_dataset(
    dataset: Sequence[KeypointDataset],
    camera: Camera,
    fps: float,
    start_timestamp: datetime,
) -> Generator[Detection, None, None]:
    frame_interval_s = 1 / fps
    for el in dataset:
        frame_index = el["frame_index"]
        timestamp = start_timestamp + timedelta(seconds=frame_index * frame_interval_s)
        for kp, kp_score in zip(el["kps"], el["kps_scores"]):
            yield Detection(
                keypoints=jnp.array(kp),
                confidences=jnp.array(kp_score),
                camera=camera,
                timestamp=timestamp,
            )
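
# %%
# A quick illustrative sanity check (not part of the pipeline; the rotation
# and translation values below are made up): a 90° rotation about z plus a
# unit translation along x should carry the point (1, 0, 0) to (1, 1, 0).
_rvec = np.array([0.0, 0.0, np.pi / 2])
_tvec = np.array([1.0, 0.0, 0.0])
_T = to_transformation_matrix(_rvec, _tvec)
_p = _T @ np.array([1.0, 0.0, 0.0, 1.0])  # homogeneous point
assert np.allclose(_p[:3], [1.0, 1.0, 0.0])
display(_T)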

# %%
DetectionGenerator: TypeAlias = Generator[Detection, None, None]


def sync_batch_gen(
    gens: list[DetectionGenerator], diff: timedelta
) -> Generator[list[Detection], None, None]:
    """
    Given a list of detection generators, return a generator that yields
    batches of detections whose timestamps fall within `diff` of each other.

    Args:
        gens: list of detection generators
        diff: maximum timestamp difference between detections to consider them
            part of the same batch
    """
    N = len(gens)
    last_batch_timestamp: Optional[datetime] = None
    next_batch_timestamp: Optional[datetime] = None
    current_batch: list[Detection] = []
    next_batch: list[Detection] = []
    paused: list[bool] = [False] * N
    finished: list[bool] = [False] * N

    def reset_paused():
        """
        reset paused list based on finished list
        """
        for i in range(N):
            if not finished[i]:
                paused[i] = False
            else:
                paused[i] = True

    EPS = 1e-6  # a small epsilon to avoid floating point precision issues
    diff_eps = diff - timedelta(seconds=EPS)
    while True:
        for i, gen in enumerate(gens):
            try:
                if finished[i] or paused[i]:
                    continue
                val = next(gen)
                if last_batch_timestamp is None:
                    last_batch_timestamp = val.timestamp
                    current_batch.append(val)
                else:
                    if abs(val.timestamp - last_batch_timestamp) >= diff_eps:
                        next_batch.append(val)
                        if next_batch_timestamp is None:
                            next_batch_timestamp = val.timestamp
                        paused[i] = True
                        if all(paused):
                            yield current_batch
                            current_batch = next_batch
                            next_batch = []
                            last_batch_timestamp = next_batch_timestamp
                            next_batch_timestamp = None
                            reset_paused()
                    else:
                        current_batch.append(val)
            except StopIteration:
                finished[i] = True
                paused[i] = True
                if all(finished):
                    if len(current_batch) > 0:
                        # All generators exhausted, flush remaining batch and exit
                        yield current_batch
                    return


# %%
@overload
def to_projection_matrix(
    transformation_matrix: Num[NDArray, "4 4"], camera_matrix: Num[NDArray, "3 3"]
) -> Num[NDArray, "3 4"]: ...


@overload
def to_projection_matrix(
    transformation_matrix: Num[Array, "4 4"], camera_matrix: Num[Array, "3 3"]
) -> Num[Array, "3 4"]: ...


@jaxtyped(typechecker=beartype)
def to_projection_matrix(
    transformation_matrix: Num[Any, "4 4"],
    camera_matrix: Num[Any, "3 3"],
) -> Num[Any, "3 4"]:
    return camera_matrix @ transformation_matrix[:3, :]


to_projection_matrix_jit = jax.jit(to_projection_matrix)


@jaxtyped(typechecker=beartype)
def dlt(
    H1: Num[NDArray, "3 4"],
    H2: Num[NDArray, "3 4"],
    p1: Num[NDArray, "2"],
    p2: Num[NDArray, "2"],
) -> Num[NDArray, "3"]:
    """
    Direct Linear Transformation: triangulate a 3D point from two views given
    their projection matrices and the corresponding image points.
    """
    A = [
        p1[1] * H1[2, :] - H1[1, :],
        H1[0, :] - p1[0] * H1[2, :],
        p2[1] * H2[2, :] - H2[1, :],
        H2[0, :] - p2[0] * H2[2, :],
    ]
    A = np.array(A).reshape((4, 4))
    B = A.transpose() @ A
    U, s, Vh = linalg.svd(B, full_matrices=False)
    return Vh[3, 0:3] / Vh[3, 3]


@overload
def homogeneous_to_euclidean(points: Num[NDArray, "N 4"]) -> Num[NDArray, "N 3"]: ...


@overload
def homogeneous_to_euclidean(points: Num[Array, "N 4"]) -> Num[Array, "N 3"]: ...


@jaxtyped(typechecker=beartype)
def homogeneous_to_euclidean(
    points: Num[Any, "N 4"],
) -> Num[Any, "N 3"]:
    """
    Convert homogeneous coordinates to Euclidean coordinates.

    Args:
        points: homogeneous coordinates (x, y, z, w) in numpy array or jax array

    Returns:
        euclidean coordinates (x, y, z) in numpy array or jax array
    """
    return points[..., :-1] / points[..., -1:]
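
# %%
# Illustrative check of `dlt` and `homogeneous_to_euclidean` on synthetic,
# noise-free data. The intrinsics and the two camera poses are made up: both
# cameras share K, and the second is shifted one unit along +x. `_project` is
# a local helper defined only for this check. Projecting a known 3D point
# into both views and running DLT should recover that point.
_K = np.array([[1000.0, 0.0, 640.0], [0.0, 1000.0, 360.0], [0.0, 0.0, 1.0]])
_Rt1 = np.eye(4)
_Rt2 = np.eye(4)
_Rt2[:3, 3] = [-1.0, 0.0, 0.0]  # camera center at (1, 0, 0) in world space
_H1 = to_projection_matrix(_Rt1, _K)
_H2 = to_projection_matrix(_Rt2, _K)
_X = np.array([0.2, 0.1, 5.0])  # arbitrary ground-truth point


def _project(H: NDArray, X: NDArray) -> NDArray:
    """Project a 3D point with a 3x4 projection matrix (perspective divide)."""
    p = H @ np.append(X, 1.0)
    return p[:2] / p[2]


assert np.allclose(dlt(_H1, _H2, _project(_H1, _X), _project(_H2, _X)), _X)
assert np.allclose(
    homogeneous_to_euclidean(np.array([[2.0, 4.0, 6.0, 2.0]])), [[1.0, 2.0, 3.0]]
)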

# %%
FPS = 24

image_gen_5600 = preprocess_keypoint_dataset(
    KEYPOINT_DATASET[5600],
    from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5600][0]),  # type: ignore
    FPS,
    datetime(2024, 4, 2, 12, 0, 0),
)
image_gen_5601 = preprocess_keypoint_dataset(
    KEYPOINT_DATASET[5601],
    from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5601][0]),  # type: ignore
    FPS,
    datetime(2024, 4, 2, 12, 0, 0),
)
image_gen_5602 = preprocess_keypoint_dataset(
    KEYPOINT_DATASET[5602],
    from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5602][0]),  # type: ignore
    FPS,
    datetime(2024, 4, 2, 12, 0, 0),
)
display(1 / FPS)

sync_gen = sync_batch_gen(
    [image_gen_5600, image_gen_5601, image_gen_5602], timedelta(seconds=1 / FPS)
)

# %%
sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint(
    next(sync_gen), alpha_2d=2000
)
display(sorted_detections)

# %%
display(
    list(
        map(
            lambda x: {"timestamp": str(x.timestamp), "camera": x.camera.id},
            sorted_detections,
        )
    )
)
with jnp.printoptions(precision=3, suppress=True):
    display(affinity_matrix)


# %%
def clusters_to_detections(
    clusters: list[list[int]], sorted_detections: list[Detection]
) -> list[list[Detection]]:
    """
    Given a list of clusters (each cluster being a list of indices into
    `sorted_detections`), extract the corresponding detections.

    Args:
        clusters: list of clusters, each cluster is a list of indices of the
            detections in the `sorted_detections` list
        sorted_detections: list of SORTED detections

    Returns:
        list of clusters, each cluster is a list of detections
    """
    return [[sorted_detections[i] for i in cluster] for cluster in clusters]


solver = GLPKSolver()
aff_np = np.asarray(affinity_matrix).astype(np.float64)
clusters, sol_matrix = solver.solve(aff_np)
display(clusters)
display(sol_matrix)

# %%
WIDTH = 2560
HEIGHT = 1440

clusters_detections = clusters_to_detections(clusters, sorted_detections)

im = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
for el in clusters_detections[0]:
    im = visualize_whole_body(np.asarray(el.keypoints), im)
p = plt.imshow(im)
display(p)

# %%
im_prime = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
for el in clusters_detections[1]:
    im_prime = visualize_whole_body(np.asarray(el.keypoints), im_prime)
p_prime = plt.imshow(im_prime)
display(p_prime)


# %%
@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N 2"],
    confidences: Optional[Float[Array, "N"]] = None,
) -> Float[Array, "3"]:
    """
    Args:
        proj_matrices: sequence of projection matrices with shape (N, 3, 4)
        points: sequence of 2D point coordinates with shape (N, 2)
        confidences: optional confidences with shape (N,), in [0.0, 1.0]

    Returns:
        point_3d: triangulated 3D point with shape (3,)
    """
    assert len(proj_matrices) == len(points)
    N = len(proj_matrices)
    confi: Float[Array, "N"]
    if confidences is None:
        confi = jnp.ones(N, dtype=jnp.float32)
    else:
        # Use square root of confidences for weighting - more balanced approach
        confi = jnp.sqrt(jnp.clip(confidences, 0, 1))
    A = jnp.zeros((N * 2, 4), dtype=jnp.float32)
    for i in range(N):
        x, y = points[i]
        A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])
        A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])
        A = A.at[2 * i].mul(confi[i])
        A = A.at[2 * i + 1].mul(confi[i])
    # https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.svd.html
    _, _, vh = jnp.linalg.svd(A, full_matrices=False)
    point_3d_homo = vh[-1]  # shape (4,)
    # replace the Python `if` with a jnp.where
    point_3d_homo = jnp.where(
        point_3d_homo[3] < 0,  # predicate (scalar bool tracer)
        -point_3d_homo,  # if True
        point_3d_homo,  # if False
    )
    point_3d = point_3d_homo[:3] / point_3d_homo[3]
    return point_3d


@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N P 2"],
    confidences: Optional[Float[Array, "N P"]] = None,
) -> Float[Array, "P 3"]:
    """
    Batch-triangulate P points observed by N cameras, linearly via SVD.

    Args:
        proj_matrices: (N, 3, 4) projection matrices
        points: (N, P, 2) image coordinates per view
        confidences: (N, P) optional per-view confidences in [0, 1]

    Returns:
        (P, 3) 3D point for each of the P tracks
    """
    N, P, _ = points.shape
    assert proj_matrices.shape[0] == N
    if confidences is None:
        conf = jnp.ones((N, P), dtype=jnp.float32)
    else:
        # the one-point routine already applies the square-root weighting,
        # so only clip here to avoid applying the square root twice
        conf = jnp.clip(confidences, 0.0, 1.0)
    # vectorize the one-point routine over P
    vmap_triangulate = jax.vmap(
        triangulate_one_point_from_multiple_views_linear,
        in_axes=(None, 1, 1),  # proj_matrices static, map over points[:, p, :], conf[:, p]
        out_axes=0,
    )
    return vmap_triangulate(proj_matrices, points, conf)
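
# %%
# Illustrative sanity check for the batched linear triangulation (synthetic,
# noise-free data; the three camera poses and the intrinsics are made up):
# project two known 3D points into each of three views, then confirm the
# SVD-based triangulation recovers them to float32 precision.
_K3 = jnp.array([[1000.0, 0.0, 640.0], [0.0, 1000.0, 360.0], [0.0, 0.0, 1.0]])
_projs = jnp.stack(
    [
        to_projection_matrix(jnp.eye(4).at[:3, 3].set(jnp.array([tx, 0.0, 0.0])), _K3)
        for tx in (-1.0, 0.0, 1.0)  # three cameras spread along the x axis
    ]
)  # (N=3, 3, 4)
_Xs = jnp.array([[0.2, 0.1, 5.0], [-0.3, 0.2, 6.0]])  # (P=2, 3) ground truth
_homo = jnp.concatenate([_Xs, jnp.ones((2, 1))], axis=-1)  # (P, 4)
_pix = jnp.einsum("nij,pj->npi", _projs, _homo)  # (N, P, 3)
_pix = _pix[..., :2] / _pix[..., 2:]  # perspective divide -> (N, P, 2)
_recovered = triangulate_points_from_multiple_views_linear(_projs, _pix)
assert jnp.allclose(_recovered, _Xs, atol=1e-2)  # float32 SVD tolerance
display(_recovered)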

# %%
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
    id: int
    """
    The tracking id
    """
    keypoints: Float[Array, "J 3"]
    """
    The 3D keypoints of the tracking
    """
    last_active_timestamp: datetime
    velocity: Optional[Float[Array, "3"]] = None
    """
    May be `None`, e.g. right after the 3D pose is initialized.

    `velocity` should be updated when target association yields a new 3D pose.
    """

    def __repr__(self) -> str:
        return f"Tracking({self.id}, {self.last_active_timestamp})"


@jaxtyped(typechecker=beartype)
def triangle_from_cluster(
    cluster: list[Detection],
) -> tuple[Float[Array, "N 3"], datetime]:
    proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])
    points = jnp.array([el.keypoints_undistorted for el in cluster])
    confidences = jnp.array([el.confidences for el in cluster])
    latest_timestamp = max(el.timestamp for el in cluster)
    return (
        triangulate_points_from_multiple_views_linear(
            proj_matrices, points, confidences=confidences
        ),
        latest_timestamp,
    )


# res = {
#     "a": triangle_from_cluster(clusters_detections[0]).tolist(),
#     "b": triangle_from_cluster(clusters_detections[1]).tolist(),
# }
# with open("samples/res.json", "wb") as f:
#     f.write(orjson.dumps(res))


class GlobalTrackingState:
    _last_id: int
    _trackings: dict[int, Tracking]

    def __init__(self):
        self._last_id = 0
        self._trackings = {}

    def __repr__(self) -> str:
        return (
            f"GlobalTrackingState(last_id={self._last_id}, trackings={self._trackings})"
        )

    @property
    def trackings(self) -> dict[int, Tracking]:
        return shallow_copy(self._trackings)

    def add_tracking(self, cluster: list[Detection]) -> Tracking:
        kps_3d, latest_timestamp = triangle_from_cluster(cluster)
        next_id = self._last_id + 1
        tracking = Tracking(
            id=next_id, keypoints=kps_3d, last_active_timestamp=latest_timestamp
        )
        self._trackings[next_id] = tracking
        self._last_id = next_id
        return tracking


global_tracking_state = GlobalTrackingState()
for cluster in clusters_detections:
    global_tracking_state.add_tracking(cluster)
display(global_tracking_state)

# %%
next_group = next(sync_gen)
display(next_group)


# %%
@jaxtyped(typechecker=beartype)
def calculate_distance_2d(
    left: Num[Array, "J 2"],
    right: Num[Array, "J 2"],
    image_size: tuple[int, int] = (1, 1),
) -> Float[Array, "J"]:
    """
    Calculate the per-joint *normalized* distance between two sets of keypoints.

    Args:
        left: The left keypoints
        right: The right keypoints
        image_size: The size of the image
    """
    w, h = image_size
    if w == 1 and h == 1:
        # already normalized
        left_normalized = left
        right_normalized = right
    else:
        left_normalized = left / jnp.array([w, h])
        right_normalized = right / jnp.array([w, h])
    return jnp.linalg.norm(left_normalized - right_normalized, axis=-1)


@jaxtyped(typechecker=beartype)
def calculate_affinity_2d(
    distance_2d: float, w_2d: float, alpha_2d: float, lambda_a: float, delta_t: float
) -> float:
    """
    Calculate the affinity between two detections based on the distance
    between their keypoints.

    Args:
        distance_2d: The normalized distance between the two keypoints
            (see `calculate_distance_2d`)
        w_2d: The weight of the distance (parameter)
        alpha_2d: The alpha value for the distance calculation (parameter)
        lambda_a: The lambda value for the distance calculation (parameter)
        delta_t: The time delta between the two detections, in seconds
    """
    return w_2d * (1 - distance_2d / (alpha_2d * delta_t)) * np.exp(-lambda_a * delta_t)
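
# %%
# Illustrative use of the 2D distance/affinity pair. Every parameter value
# below (w_2d, alpha_2d, lambda_a) is made up for demonstration: two keypoint
# sets that nearly overlap should produce a positive affinity at a one-frame
# time delta.
_left = jnp.array([[0.50, 0.50], [0.60, 0.40]])  # already-normalized coordinates
_right = jnp.array([[0.51, 0.50], [0.61, 0.41]])
_dist = calculate_distance_2d(_left, _right)  # per-joint distances, shape (2,)
_aff = calculate_affinity_2d(
    distance_2d=float(jnp.mean(_dist)),
    w_2d=1.0,  # hypothetical weight
    alpha_2d=1.0,  # hypothetical distance scale
    lambda_a=5.0,  # hypothetical time-decay rate
    delta_t=1 / FPS,
)
display(_aff)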

# %%
@jaxtyped(typechecker=beartype)
def perpendicular_distance_point_to_line_two_points(
    point: Num[Array, "2"], line: tuple[Num[Array, "2"], Num[Array, "2"]]
):
    """
    Calculate the perpendicular distance between a point and a line,
    where `line` is represented by two points: `(line_start, line_end)`.
    """
    line_start, line_end = line
    distance = jnp.linalg.norm(
        jnp.cross(line_end - line_start, line_start - point)
    ) / jnp.linalg.norm(line_end - line_start)
    return distance


def predict_pose_3d(
    tracking: Tracking,
    delta_t: float,
) -> Float[Array, "J 3"]:
    """
    Predict the 3D pose of a tracking from its last pose and velocity,
    assuming constant velocity over `delta_t` seconds.
    """
    if tracking.velocity is None:
        return tracking.keypoints
    return tracking.keypoints + tracking.velocity * delta_t


# %%
# let's do cross-view association
trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
unmatched_detections = shallow_copy(next_group)

# cross-view association matrix with shape (T, D), where T is the number of
# trackings and D is the number of detections
# layout:
# a_t1_c1_d1, a_t1_c1_d2, a_t1_c1_d3, ..., a_t1_c2_d1, ..., a_t1_cc_dd
# a_t2_c1_d1, ...
# ...
# a_tt_c1_d1, ..., a_tt_cc_dd
#
# where T <- [t1..tt]; D <- join(c1..cc), where `cn` is the collection of
# detections from camera `n`
affinity = np.zeros((len(trackings), len(unmatched_detections)))
detection_by_camera = classify_by_camera(unmatched_detections)
for i, tracking in enumerate(trackings):
    for c, detections in detection_by_camera.items():
        camera = next(iter(detections)).camera
        # pixel space, unnormalized
        tracking_2d_projection = camera.project(tracking.keypoints)
        for det in detections:
            ...
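
# %%
# Two last illustrative checks (synthetic values) for the helpers used by the
# association step: the point (0, 1) lies at perpendicular distance 1 from
# the x-axis, and a tracking with constant velocity should advance linearly.
# The 17-joint skeleton below is an assumption (COCO-style), not dictated by
# the dataset.
assert jnp.allclose(
    perpendicular_distance_point_to_line_two_points(
        jnp.array([0.0, 1.0]), (jnp.array([0.0, 0.0]), jnp.array([1.0, 0.0]))
    ),
    1.0,
)

_tracking = Tracking(
    id=0,
    keypoints=jnp.zeros((17, 3)),
    last_active_timestamp=datetime(2024, 4, 2, 12, 0, 0),
    velocity=jnp.array([0.0, 0.0, 1.0]),
)
# after 0.5 s every joint should have moved 0.5 units along +z
assert jnp.allclose(predict_pose_3d(_tracking, 0.5)[:, 2], 0.5)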