# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.17.0
#   kernelspec:
#     display_name: .venv
#     language: python
#     name: python3
# ---

# %%
from copy import deepcopy
from datetime import datetime, timedelta
from pathlib import Path
from typing import (
    Any,
    Generator,
    Optional,
    Sequence,
    TypeAlias,
    TypedDict,
    cast,
    overload,
)

import awkward as ak
import jax
import jax.numpy as jnp
import numpy as np
import orjson
from beartype import beartype
from cv2 import undistortPoints
from IPython.display import display
from jaxtyping import Array, Float, Num, jaxtyped
from matplotlib import pyplot as plt
from numpy.typing import ArrayLike
from scipy.spatial.transform import Rotation as R

from app.camera import Camera, CameraParams, Detection
from app.visualize.whole_body import visualize_whole_body

NDArray: TypeAlias = np.ndarray

# %%
DATASET_PATH = Path("samples") / "04_02"
AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet")
display(AK_CAMERA_DATASET)


# %%
class Resolution(TypedDict):
    width: int
    height: int


class Intrinsic(TypedDict):
    camera_matrix: Num[Array, "3 3"]
    """K"""
    distortion_coefficients: Num[Array, "N"]
    """distortion coefficients; usually 5"""


class Extrinsic(TypedDict):
    rvec: Num[NDArray, "3"]
    tvec: Num[NDArray, "3"]


class ExternalCameraParams(TypedDict):
    name: str
    port: int
    intrinsic: Intrinsic
    extrinsic: Extrinsic
    resolution: Resolution


# %%
def read_dataset_by_port(port: int) -> ak.Array:
    P = DATASET_PATH / f"{port}.parquet"
    return ak.from_parquet(P)


KEYPOINT_DATASET = {
    int(p): read_dataset_by_port(p) for p in ak.to_numpy(AK_CAMERA_DATASET["port"])
}


# %%
class KeypointDataset(TypedDict):
    frame_index: int
    boxes: Num[NDArray, "N 4"]
    kps: Num[NDArray, "N J 2"]
    kps_scores: Num[NDArray, "N J"]


@jaxtyped(typechecker=beartype)
def to_transformation_matrix(
    rvec: Num[NDArray, "3"], tvec: Num[NDArray, "3"]
) -> Num[NDArray, "4 4"]:
    res = np.eye(4)
    res[:3, :3] = R.from_rotvec(rvec).as_matrix()
    res[:3, 3] = tvec
    return res


@jaxtyped(typechecker=beartype)
def undistort_points(
    points: Num[NDArray, "M 2"],
    camera_matrix: Num[NDArray, "3 3"],
    dist_coeffs: Num[NDArray, "N"],
) -> Num[NDArray, "M 2"]:
    K = camera_matrix
    dist = dist_coeffs
    res = undistortPoints(points, K, dist, P=K)  # type: ignore
    return res.reshape(-1, 2)


def from_camera_params(camera: ExternalCameraParams) -> Camera:
    rt = jnp.array(
        to_transformation_matrix(
            ak.to_numpy(camera["extrinsic"]["rvec"]),
            ak.to_numpy(camera["extrinsic"]["tvec"]),
        )
    )
    K = jnp.array(camera["intrinsic"]["camera_matrix"]).reshape(3, 3)
    dist_coeffs = jnp.array(camera["intrinsic"]["distortion_coefficients"])
    image_size = jnp.array(
        (camera["resolution"]["width"], camera["resolution"]["height"])
    )
    return Camera(
        id=camera["name"],
        params=CameraParams(
            K=K,
            Rt=rt,
            dist_coeffs=dist_coeffs,
            image_size=image_size,
        ),
    )


def preprocess_keypoint_dataset(
    dataset: Sequence[KeypointDataset],
    camera: Camera,
    fps: float,
    start_timestamp: datetime,
) -> Generator[Detection, None, None]:
    frame_interval_s = 1 / fps
    for el in dataset:
        frame_index = el["frame_index"]
        timestamp = start_timestamp + timedelta(seconds=frame_index * frame_interval_s)
        for kp, kp_score in zip(el["kps"], el["kps_scores"]):
            yield Detection(
                keypoints=jnp.array(kp),
                confidences=jnp.array(kp_score),
                camera=camera,
                timestamp=timestamp,
            )
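# %%
# Quick sanity check (hypothetical numbers, not from the dataset): the 4x4 rigid
# transform built by `to_transformation_matrix` should act on a point as R @ p + t.
_rvec = np.array([0.1, -0.2, 0.3])
_tvec = np.array([1.0, 2.0, 3.0])
_T = to_transformation_matrix(_rvec, _tvec)
_p = np.array([0.5, -0.5, 2.0])
assert np.allclose(
    _T[:3, :3] @ _p + _T[:3, 3],
    R.from_rotvec(_rvec).as_matrix() @ _p + _tvec,
)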
# %%
DetectionGenerator: TypeAlias = Generator[Detection, None, None]


def sync_batch_gen(
    gens: list[DetectionGenerator], diff: timedelta
) -> Generator[list[Detection], None, None]:
    """
    Given a list of detection generators, yield batches of detections whose
    timestamps fall within `diff` of each other.

    Args:
        gens: list of detection generators
        diff: maximum timestamp difference between detections for them to be
            considered part of the same batch
    """
    N = len(gens)
    last_batch_timestamp: Optional[datetime] = None
    next_batch_timestamp: Optional[datetime] = None
    current_batch: list[Detection] = []
    next_batch: list[Detection] = []
    paused: list[bool] = [False] * N
    finished: list[bool] = [False] * N

    def reset_paused():
        """Unpause every generator that has not finished yet."""
        for i in range(N):
            paused[i] = finished[i]

    EPS = 1e-6  # a small epsilon to avoid floating-point precision issues
    diff_eps = diff - timedelta(seconds=EPS)
    while True:
        for i, gen in enumerate(gens):
            try:
                if finished[i] or paused[i]:
                    continue
                val = next(gen)
                if last_batch_timestamp is None:
                    last_batch_timestamp = val.timestamp
                    current_batch.append(val)
                elif abs(val.timestamp - last_batch_timestamp) >= diff_eps:
                    # This detection belongs to the next batch: stash it and
                    # pause its generator until the current batch is emitted.
                    next_batch.append(val)
                    if next_batch_timestamp is None:
                        next_batch_timestamp = val.timestamp
                    paused[i] = True
                    if all(paused):
                        yield current_batch
                        current_batch = next_batch
                        next_batch = []
                        last_batch_timestamp = next_batch_timestamp
                        next_batch_timestamp = None
                        reset_paused()
                else:
                    current_batch.append(val)
            except StopIteration:
                finished[i] = True
                paused[i] = True
                if all(finished):
                    # All generators exhausted: flush the remaining batch and
                    # exit (a bare `break` here would spin the while loop forever).
                    if len(current_batch) > 0:
                        yield current_batch
                    return
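# %%
# Minimal sketch (not part of the pipeline): exercising `sync_batch_gen` with
# stand-in objects. The batching logic only reads `.timestamp`, so a namedtuple
# is enough; `_FakeDetection` and `_fake_gen` are hypothetical names.
from collections import namedtuple

_FakeDetection = namedtuple("_FakeDetection", ["timestamp"])
_t0 = datetime(2024, 4, 2, 12, 0, 0)


def _fake_gen(offsets_s: list[float]):
    for o in offsets_s:
        yield _FakeDetection(timestamp=_t0 + timedelta(seconds=o))


_batches = list(
    sync_batch_gen(
        [_fake_gen([0.0, 1 / 24]), _fake_gen([0.001, 1 / 24 + 0.001])],  # type: ignore
        timedelta(seconds=1 / 24),
    )
)
# Expect two batches of two detections each: frame 0 and frame 1 from both cameras.
display([[str(d.timestamp) for d in b] for b in _batches])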
# %%
@overload
def to_projection_matrix(
    transformation_matrix: Num[NDArray, "4 4"], camera_matrix: Num[NDArray, "3 3"]
) -> Num[NDArray, "3 4"]: ...
@overload
def to_projection_matrix(
    transformation_matrix: Num[Array, "4 4"], camera_matrix: Num[Array, "3 3"]
) -> Num[Array, "3 4"]: ...
@jaxtyped(typechecker=beartype)
def to_projection_matrix(
    transformation_matrix: Num[Any, "4 4"],
    camera_matrix: Num[Any, "3 3"],
) -> Num[Any, "3 4"]:
    return camera_matrix @ transformation_matrix[:3, :]


to_projection_matrix_jit = jax.jit(to_projection_matrix)


@jaxtyped(typechecker=beartype)
def dlt(
    H1: Num[NDArray, "3 4"],
    H2: Num[NDArray, "3 4"],
    p1: Num[NDArray, "2"],
    p2: Num[NDArray, "2"],
) -> Num[NDArray, "3"]:
    """
    Direct Linear Transformation: triangulate one 3D point from two views.
    """
    A = [
        p1[1] * H1[2, :] - H1[1, :],
        H1[0, :] - p1[0] * H1[2, :],
        p2[1] * H2[2, :] - H2[1, :],
        H2[0, :] - p2[0] * H2[2, :],
    ]
    A = np.array(A).reshape((4, 4))
    B = A.transpose() @ A

    from scipy import linalg

    U, s, Vh = linalg.svd(B, full_matrices=False)
    return Vh[3, 0:3] / Vh[3, 3]


@overload
def homogeneous_to_euclidean(points: Num[NDArray, "N 4"]) -> Num[NDArray, "N 3"]: ...
@overload
def homogeneous_to_euclidean(points: Num[Array, "N 4"]) -> Num[Array, "N 3"]: ...
@jaxtyped(typechecker=beartype)
def homogeneous_to_euclidean(
    points: Num[Any, "N 4"],
) -> Num[Any, "N 3"]:
    """
    Convert homogeneous coordinates to Euclidean coordinates.

    Args:
        points: homogeneous coordinates (x, y, z, w) as a numpy or jax array

    Returns:
        euclidean coordinates (x, y, z) as a numpy or jax array
    """
    return points[..., :-1] / points[..., -1:]
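# %%
# Quick sanity check of `dlt` (hypothetical camera numbers, not from the dataset):
# project a known 3D point through two synthetic views, then triangulate it back.
# Exact, noise-free inputs should recover the point up to floating-point error.
_K = np.array([[1000.0, 0.0, 640.0], [0.0, 1000.0, 360.0], [0.0, 0.0, 1.0]])
_H1 = to_projection_matrix(np.eye(4), _K)
_H2 = to_projection_matrix(
    to_transformation_matrix(np.array([0.0, 0.1, 0.0]), np.array([-0.5, 0.0, 0.0])),
    _K,
)
_X = np.array([0.2, -0.1, 3.0, 1.0])  # homogeneous world point
_q1 = _H1 @ _X
_q2 = _H2 @ _X
display(dlt(_H1, _H2, _q1[:2] / _q1[2], _q2[:2] / _q2[2]))  # ~ [0.2, -0.1, 3.0]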
# %%
FPS = 24

image_gen_5600 = preprocess_keypoint_dataset(
    KEYPOINT_DATASET[5600],
    from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5600][0]),  # type: ignore
    FPS,
    datetime(2024, 4, 2, 12, 0, 0),
)
image_gen_5601 = preprocess_keypoint_dataset(
    KEYPOINT_DATASET[5601],
    from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5601][0]),  # type: ignore
    FPS,
    datetime(2024, 4, 2, 12, 0, 0),
)
image_gen_5602 = preprocess_keypoint_dataset(
    KEYPOINT_DATASET[5602],
    from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5602][0]),  # type: ignore
    FPS,
    datetime(2024, 4, 2, 12, 0, 0),
)
display(1 / FPS)

sync_gen = sync_batch_gen(
    [image_gen_5600, image_gen_5601, image_gen_5602], timedelta(seconds=1 / FPS)
)

# %%
detections = next(sync_gen)

# %%
from app.camera import calculate_affinity_matrix_by_epipolar_constraint

sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint(
    detections, alpha_2d=2000
)

# %%
display(
    list(
        map(
            lambda x: {"timestamp": str(x.timestamp), "camera": x.camera.id},
            sorted_detections,
        )
    )
)
with jnp.printoptions(precision=3, suppress=True):
    display(affinity_matrix)

# %%
from app.solver._old import GLPKSolver


def clusters_to_detections(
    clusters: list[list[int]], sorted_detections: list[Detection]
) -> list[list[Detection]]:
    """
    Given a list of clusters (indices into `sorted_detections`), extract the
    corresponding detections.

    Args:
        clusters: list of clusters, each a list of indices into `sorted_detections`
        sorted_detections: list of SORTED detections

    Returns:
        list of clusters, each a list of detections
    """
    return [[sorted_detections[i] for i in cluster] for cluster in clusters]


solver = GLPKSolver()
aff_np = np.asarray(affinity_matrix).astype(np.float64)
clusters, sol_matrix = solver.solve(aff_np)
display(clusters)
display(sol_matrix)

# %%
WIDTH = 2560
HEIGHT = 1440

clusters_detections = clusters_to_detections(clusters, sorted_detections)
im = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
for el in clusters_detections[0]:
    im = visualize_whole_body(np.asarray(el.keypoints), im)
p = plt.imshow(im)
display(p)

# %%
im_prime = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
for el in clusters_detections[1]:
    im_prime = visualize_whole_body(np.asarray(el.keypoints), im_prime)
p_prime = plt.imshow(im_prime)
display(p_prime)


# %%
@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N 2"],
    confidences: Optional[Float[Array, "N"]] = None,
) -> Float[Array, "3"]:
    """
    Triangulate a single 3D point from N views, linearly via SVD.

    Args:
        proj_matrices: (N, 3, 4) projection matrices
        points: (N, 2) image coordinates, one per view
        confidences: (N,) optional per-view confidences in [0.0, 1.0]

    Returns:
        point_3d: (3,) triangulated 3D point
    """
    assert len(proj_matrices) == len(points)
    N = len(proj_matrices)
    confi: Float[Array, "N"]
    if confidences is None:
        confi = jnp.ones(N, dtype=jnp.float32)
    else:
        # Use the square root of the confidences for weighting; a more balanced approach
        confi = jnp.sqrt(jnp.clip(confidences, 0, 1))
    A = jnp.zeros((N * 2, 4), dtype=jnp.float32)
    for i in range(N):
        x, y = points[i]
        A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])
        A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])
        A = A.at[2 * i].mul(confi[i])
        A = A.at[2 * i + 1].mul(confi[i])

    # https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.svd.html
    _, _, vh = jnp.linalg.svd(A, full_matrices=False)
    point_3d_homo = vh[-1]  # shape (4,)
    # replace the Python `if` with a jnp.where so the sign flip stays traceable
    point_3d_homo = jnp.where(
        point_3d_homo[3] < 0,  # predicate (scalar bool tracer)
        -point_3d_homo,  # if True
        point_3d_homo,  # if False
    )
    point_3d = point_3d_homo[:3] / point_3d_homo[3]
    return point_3d


@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N P 2"],
    confidences: Optional[Float[Array, "N P"]] = None,
) -> Float[Array, "P 3"]:
    """
    Batch-triangulate P points observed by N cameras, linearly via SVD.

    Args:
        proj_matrices: (N, 3, 4) projection matrices
        points: (N, P, 2) image coordinates per view
        confidences: (N, P) optional per-view confidences in [0, 1]

    Returns:
        (P, 3) 3D point for each of the P tracks
    """
    N, P, _ = points.shape
    assert proj_matrices.shape[0] == N
    if confidences is None:
        conf = jnp.ones((N, P), dtype=jnp.float32)
    else:
        # Pass the confidences through unmodified; the one-point routine applies
        # the sqrt-clip weighting itself, so doing it here would weight twice.
        conf = confidences

    # vectorize the one-point routine over P
    vmap_triangulate = jax.vmap(
        triangulate_one_point_from_multiple_views_linear,
        in_axes=(None, 1, 1),  # proj_matrices shared; map over points[:, p, :], conf[:, p]
        out_axes=0,
    )
    # returns (P, 3)
    return vmap_triangulate(proj_matrices, points, conf)
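# %%
# Quick sanity check of the batched triangulation (synthetic cameras and points,
# not from the dataset): exact projections should triangulate back to the inputs.
_Kj = jnp.array([[1000.0, 0.0, 640.0], [0.0, 1000.0, 360.0], [0.0, 0.0, 1.0]])
_T2 = jnp.array(
    to_transformation_matrix(np.array([0.0, 0.1, 0.0]), np.array([-0.5, 0.0, 0.0]))
)
_Pj = jnp.stack([_Kj @ jnp.eye(4)[:3, :], _Kj @ _T2[:3, :]])  # (N=2, 3, 4)
_Xw = jnp.array([[0.2, -0.1, 3.0], [-0.3, 0.4, 4.0]])  # (P=2, 3) world points
_Xh = jnp.concatenate([_Xw, jnp.ones((2, 1))], axis=1)  # homogeneous (P, 4)
_uvw = jnp.einsum("nij,pj->npi", _Pj, _Xh)  # projected, still homogeneous (N, P, 3)
_pts_2d = _uvw[..., :2] / _uvw[..., 2:3]  # (N, P, 2) pixel coordinates
display(triangulate_points_from_multiple_views_linear(_Pj, _pts_2d))  # ~ _Xw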
# %%
from dataclasses import dataclass
from copy import copy as shallow_copy, deepcopy as deep_copy


@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
    id: int
    keypoints: Float[Array, "J 3"]
    last_active_timestamp: datetime

    def __repr__(self) -> str:
        return f"Tracking({self.id}, {self.last_active_timestamp})"


@jaxtyped(typechecker=beartype)
def triangle_from_cluster(
    cluster: list[Detection],
) -> tuple[Float[Array, "N 3"], datetime]:
    proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])
    points = jnp.array([el.keypoints_undistorted for el in cluster])
    confidences = jnp.array([el.confidences for el in cluster])
    latest_timestamp = max(el.timestamp for el in cluster)
    return (
        triangulate_points_from_multiple_views_linear(
            proj_matrices, points, confidences=confidences
        ),
        latest_timestamp,
    )


# res = {
#     "a": triangle_from_cluster(clusters_detections[0]).tolist(),
#     "b": triangle_from_cluster(clusters_detections[1]).tolist(),
# }
# with open("samples/res.json", "wb") as f:
#     f.write(orjson.dumps(res))


class GlobalTrackingState:
    _last_id: int
    _trackings: dict[int, Tracking]

    def __init__(self):
        self._last_id = 0
        self._trackings = {}

    def __repr__(self) -> str:
        return (
            f"GlobalTrackingState(last_id={self._last_id}, trackings={self._trackings})"
        )

    @property
    def trackings(self) -> dict[int, Tracking]:
        return shallow_copy(self._trackings)

    def add_tracking(self, cluster: list[Detection]) -> Tracking:
        kps_3d, latest_timestamp = triangle_from_cluster(cluster)
        next_id = self._last_id + 1
        tracking = Tracking(
            id=next_id, keypoints=kps_3d, last_active_timestamp=latest_timestamp
        )
        self._trackings[next_id] = tracking
        self._last_id = next_id
        return tracking


global_tracking_state = GlobalTrackingState()
for cluster in clusters_detections:
    global_tracking_state.add_tracking(cluster)
display(global_tracking_state)

# %%
next_group = next(sync_gen)
display(next_group)

# %%
from app.camera import classify_by_camera

# let's do cross-view association
trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
detections = shallow_copy(next_group)

# cross-view association matrix with shape (T, D), where T is the number of
# trackings and D is the number of detections
affinity = np.zeros((len(trackings), len(detections)))

detection_by_camera = classify_by_camera(detections)
for i, tracking in enumerate(trackings):
    # note: the per-camera loop variable must not shadow the outer `detections`
    for c, cam_detections in detection_by_camera.items():
        camera = next(iter(cam_detections)).camera
        # pixel space, unnormalized
        tracking_2d_projection = camera.project(tracking.keypoints)
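# %%
# Hypothetical sketch, separate from the cell above: one common way to score a
# projected tracking against a 2D detection is to squash the mean per-joint pixel
# distance with an exponential. `_affinity_2d` and `_SCALE_PX` are illustrative
# assumptions, not names from the pipeline.
_SCALE_PX = 100.0  # assumed pixel scale for the exponential falloff


@jaxtyped(typechecker=beartype)
def _affinity_2d(
    projected: Float[Array, "J 2"], detected: Float[Array, "J 2"]
) -> Float[Array, ""]:
    # mean per-joint pixel distance, mapped into (0, 1]
    mean_dist = jnp.linalg.norm(projected - detected, axis=-1).mean()
    return jnp.exp(-mean_dist / _SCALE_PX)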