{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multi-view triangulation on sample `04_02`\n",
    "\n",
    "Loads camera calibration and per-camera 2D whole-body keypoints, groups detections from cameras 5600-5602 into time-synchronised batches, associates detections of the same person across views (epipolar affinity matrix + GLPK solver), and triangulates each association into 3D keypoints.\n",
    "\n",
    "**Note:** `sync_gen` is a stateful generator; every `next(sync_gen)` cell advances it, so re-run from the cell that creates it to start over."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import copy as shallow_copy\n",
    "from copy import deepcopy\n",
    "from dataclasses import dataclass\n",
    "from datetime import datetime, timedelta\n",
    "from pathlib import Path\n",
    "from typing import (Any, Generator, Optional, Sequence, TypeAlias, TypedDict,\n",
    "                    cast, overload)\n",
    "\n",
    "import awkward as ak\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "import orjson\n",
    "from beartype import beartype\n",
    "from cv2 import undistortPoints\n",
    "from jaxtyping import Array, Float, Num, jaxtyped\n",
    "from matplotlib import pyplot as plt\n",
    "from numpy.typing import ArrayLike\n",
    "from scipy.spatial.transform import Rotation as R\n",
    "\n",
    "from app.camera import (Camera, CameraParams, Detection,\n",
    "                        calculate_affinity_matrix_by_epipolar_constraint)\n",
    "from app.solver._old import GLPKSolver\n",
    "from app.visualize.whole_body import visualize_whole_body\n",
    "\n",
    "NDArray: TypeAlias = np.ndarray"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration and data loading\n",
    "\n",
    "`DATASET_PATH` contains `camera_params.parquet` (calibration records with a `port` column) and one `<port>.parquet` keypoint file per camera."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET_PATH = Path(\"samples\") / \"04_02\"\n",
    "FPS = 24  # frame rate used to turn frame indices into timestamps\n",
    "START_TIMESTAMP = datetime(2024, 4, 2, 12, 0, 0)  # timestamp assigned to frame 0\n",
    "PORTS = (5600, 5601, 5602)  # cameras that are synchronised and triangulated\n",
    "ALPHA_2D = 2000  # `alpha_2d` passed to the epipolar affinity computation\n",
    "WIDTH = 2560  # visualisation canvas width in pixels\n",
    "HEIGHT = 1440  # visualisation canvas height in pixels\n",
    "\n",
    "AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / \"camera_params.parquet\")\n",
    "display(AK_CAMERA_DATASET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Resolution(TypedDict):\n",
    "    width: int\n",
    "    height: int\n",
    "\n",
    "\n",
    "class Intrinsic(TypedDict):\n",
    "    # K, the camera matrix (`from_camera_params` reshapes it to 3x3, so it may be stored flat)\n",
    "    camera_matrix: Num[Array, \"3 3\"]\n",
    "    # distortion coefficients; usually 5\n",
    "    distortion_coefficients: Num[Array, \"N\"]\n",
    "\n",
    "\n",
    "class Extrinsic(TypedDict):\n",
    "    # rotation vector, consumed by `Rotation.from_rotvec`\n",
    "    rvec: Num[NDArray, \"3\"]\n",
    "    # translation, written into the last column of the 4x4 transform\n",
    "    tvec: Num[NDArray, \"3\"]\n",
    "\n",
    "\n",
    "class ExternalCameraParams(TypedDict):\n",
    "    \"\"\"Schema of one calibration record from `camera_params.parquet`.\"\"\"\n",
    "\n",
    "    name: str\n",
    "    port: int\n",
    "    intrinsic: Intrinsic\n",
    "    extrinsic: Extrinsic\n",
    "    resolution: Resolution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_dataset_by_port(port: int) -> ak.Array:\n",
    "    \"\"\"Load the 2D keypoint file recorded by the camera on `port`.\"\"\"\n",
    "    return ak.from_parquet(DATASET_PATH / f\"{port}.parquet\")\n",
    "\n",
    "\n",
    "KEYPOINT_DATASET = {\n",
    "    int(port): read_dataset_by_port(port)\n",
    "    for port in ak.to_numpy(AK_CAMERA_DATASET[\"port\"])\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "KEYPOINT_DATASET[5601]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Camera and detection helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class KeypointDataset(TypedDict):\n",
    "    \"\"\"One frame of a per-camera keypoint file.\"\"\"\n",
    "\n",
    "    frame_index: int\n",
    "    boxes: Num[NDArray, \"N 4\"]\n",
    "    kps: Num[NDArray, \"N J 2\"]\n",
    "    kps_scores: Num[NDArray, \"N J\"]\n",
    "\n",
    "\n",
    "@jaxtyped(typechecker=beartype)\n",
    "def to_transformation_matrix(\n",
    "    rvec: Num[NDArray, \"3\"], tvec: Num[NDArray, \"3\"]\n",
    ") -> Num[NDArray, \"4 4\"]:\n",
    "    \"\"\"Build the 4x4 homogeneous transform [R | t; 0 0 0 1] from a rotation vector and a translation.\"\"\"\n",
    "    res = np.eye(4)\n",
    "    res[:3, :3] = R.from_rotvec(rvec).as_matrix()\n",
    "    res[:3, 3] = tvec\n",
    "    return res\n",
    "\n",
    "\n",
    "@jaxtyped(typechecker=beartype)\n",
    "def undistort_points(\n",
    "    points: Num[NDArray, \"M 2\"],\n",
    "    camera_matrix: Num[NDArray, \"3 3\"],\n",
    "    dist_coeffs: Num[NDArray, \"N\"],\n",
    ") -> Num[NDArray, \"M 2\"]:\n",
    "    \"\"\"Undistort pixel coordinates with OpenCV.\n",
    "\n",
    "    Passing ``P=camera_matrix`` re-projects the undistorted points through the\n",
    "    same K, so the result is in pixel coordinates rather than normalised ones.\n",
    "    \"\"\"\n",
    "    res = undistortPoints(points, camera_matrix, dist_coeffs, P=camera_matrix)  # type: ignore\n",
    "    return res.reshape(-1, 2)\n",
    "\n",
    "\n",
    "def from_camera_params(camera: ExternalCameraParams) -> Camera:\n",
    "    \"\"\"Convert one calibration record into an `app.camera.Camera`.\"\"\"\n",
    "    extrinsic = camera[\"extrinsic\"]\n",
    "    rt = jnp.array(\n",
    "        to_transformation_matrix(\n",
    "            ak.to_numpy(extrinsic[\"rvec\"]),\n",
    "            ak.to_numpy(extrinsic[\"tvec\"]),\n",
    "        )\n",
    "    )\n",
    "    intrinsic = camera[\"intrinsic\"]\n",
    "    K = jnp.array(intrinsic[\"camera_matrix\"]).reshape(3, 3)\n",
    "    dist_coeffs = jnp.array(intrinsic[\"distortion_coefficients\"])\n",
    "    resolution = camera[\"resolution\"]\n",
    "    image_size = jnp.array((resolution[\"width\"], resolution[\"height\"]))\n",
    "    return Camera(\n",
    "        id=camera[\"name\"],\n",
    "        params=CameraParams(\n",
    "            K=K,\n",
    "            Rt=rt,\n",
    "            dist_coeffs=dist_coeffs,\n",
    "            image_size=image_size,\n",
    "        ),\n",
    "    )\n",
    "\n",
    "\n",
    "def preprocess_keypoint_dataset(\n",
    "    dataset: Sequence[KeypointDataset],\n",
    "    camera: Camera,\n",
    "    fps: float,\n",
    "    start_timestamp: datetime,\n",
    ") -> Generator[Detection, None, None]:\n",
    "    \"\"\"Yield one `Detection` per entry of `kps` in every frame of `dataset`.\n",
    "\n",
    "    The timestamp of a frame is ``start_timestamp + frame_index / fps`` seconds.\n",
    "    \"\"\"\n",
    "    frame_interval_s = 1 / fps\n",
    "    for el in dataset:\n",
    "        timestamp = start_timestamp + timedelta(seconds=el[\"frame_index\"] * frame_interval_s)\n",
    "        for kp, kp_score in zip(el[\"kps\"], el[\"kps_scores\"]):\n",
    "            yield Detection(\n",
    "                keypoints=jnp.array(kp),\n",
    "                confidences=jnp.array(kp_score),\n",
    "                camera=camera,\n",
    "                timestamp=timestamp,\n",
    "            )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Time-synchronised batching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DetectionGenerator: TypeAlias = Generator[Detection, None, None]\n",
    "\n",
    "\n",
    "def sync_batch_gen(\n",
    "    gens: list[DetectionGenerator], diff: timedelta\n",
    ") -> Generator[list[Detection], None, None]:\n",
    "    \"\"\"\n",
    "    Merge several detection streams into time-synchronised batches.\n",
    "\n",
    "    A batch is anchored at the timestamp of its first detection. Each generator is\n",
    "    drained until it yields a detection at least `diff` (minus a small epsilon) away\n",
    "    from that anchor; that detection is parked for the next batch and the generator\n",
    "    is paused. Once every still-running generator is paused, the current batch is\n",
    "    yielded and the parked detections start the next one.\n",
    "\n",
    "    Args:\n",
    "        gens: list of detection generators (presumably each ordered by timestamp)\n",
    "        diff: maximum timestamp difference between detections to consider them part of the same batch\n",
    "\n",
    "    Yields:\n",
    "        batches (lists) of detections; the final partial batch is flushed once all generators are exhausted\n",
    "    \"\"\"\n",
    "    n_gens = len(gens)\n",
    "    batch_timestamp: Optional[datetime] = None\n",
    "    next_batch_timestamp: Optional[datetime] = None\n",
    "    current_batch: list[Detection] = []\n",
    "    next_batch: list[Detection] = []\n",
    "    # a generator is paused once it has produced a detection for the next batch;\n",
    "    # exhausted generators stay paused for good\n",
    "    paused: list[bool] = [False] * n_gens\n",
    "    finished: list[bool] = [False] * n_gens\n",
    "\n",
    "    EPS = 1e-6\n",
    "    # a small epsilon to avoid floating point precision issues\n",
    "    diff_eps = diff - timedelta(seconds=EPS)\n",
    "    while not all(finished):\n",
    "        for i, gen in enumerate(gens):\n",
    "            if finished[i] or paused[i]:\n",
    "                continue\n",
    "            try:\n",
    "                val = next(gen)\n",
    "            except StopIteration:\n",
    "                finished[i] = True\n",
    "                paused[i] = True\n",
    "            else:\n",
    "                if batch_timestamp is None:\n",
    "                    batch_timestamp = val.timestamp\n",
    "                    current_batch.append(val)\n",
    "                elif abs(val.timestamp - batch_timestamp) < diff_eps:\n",
    "                    current_batch.append(val)\n",
    "                else:\n",
    "                    next_batch.append(val)\n",
    "                    if next_batch_timestamp is None:\n",
    "                        next_batch_timestamp = val.timestamp\n",
    "                    paused[i] = True\n",
    "            # Rotate once every live generator waits on the next batch. This must also\n",
    "            # be checked after an exhaustion: if the last running generator ends while\n",
    "            # the others are paused, nothing else would ever unpause them.\n",
    "            if all(paused) and not all(finished):\n",
    "                yield current_batch\n",
    "                current_batch, next_batch = next_batch, []\n",
    "                batch_timestamp, next_batch_timestamp = next_batch_timestamp, None\n",
    "                paused = list(finished)\n",
    "    # all generators exhausted: flush the remaining (partial) batch\n",
    "    if current_batch:\n",
    "        yield current_batch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Geometry helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@overload\n",
    "def to_projection_matrix(\n",
    "    transformation_matrix: Num[NDArray, \"4 4\"], camera_matrix: Num[NDArray, \"3 3\"]\n",
    ") -> Num[NDArray, \"3 4\"]: ...\n",
    "\n",
    "\n",
    "@overload\n",
    "def to_projection_matrix(\n",
    "    transformation_matrix: Num[Array, \"4 4\"], camera_matrix: Num[Array, \"3 3\"]\n",
    ") -> Num[Array, \"3 4\"]: ...\n",
    "\n",
    "\n",
    "@jaxtyped(typechecker=beartype)\n",
    "def to_projection_matrix(\n",
    "    transformation_matrix: Num[Any, \"4 4\"],\n",
    "    camera_matrix: Num[Any, \"3 3\"],\n",
    ") -> Num[Any, \"3 4\"]:\n",
    "    \"\"\"Projection matrix P = K @ [R | t] (top 3 rows of the 4x4 transform).\"\"\"\n",
    "    return camera_matrix @ transformation_matrix[:3, :]\n",
    "\n",
    "\n",
    "to_projection_matrix_jit = jax.jit(to_projection_matrix)\n",
    "\n",
    "\n",
    "@jaxtyped(typechecker=beartype)\n",
    "def dlt(\n",
    "    H1: Num[NDArray, \"3 4\"],\n",
    "    H2: Num[NDArray, \"3 4\"],\n",
    "    p1: Num[NDArray, \"2\"],\n",
    "    p2: Num[NDArray, \"2\"],\n",
    ") -> Num[NDArray, \"3\"]:\n",
    "    \"\"\"\n",
    "    Two-view Direct Linear Transformation.\n",
    "\n",
    "    Stacks two linear constraints per view into a 4x4 matrix A and returns the\n",
    "    right singular vector of A with the smallest singular value, de-homogenised.\n",
    "    \"\"\"\n",
    "    A = np.array(\n",
    "        [\n",
    "            p1[1] * H1[2, :] - H1[1, :],\n",
    "            H1[0, :] - p1[0] * H1[2, :],\n",
    "            p2[1] * H2[2, :] - H2[1, :],\n",
    "            H2[0, :] - p2[0] * H2[2, :],\n",
    "        ]\n",
    "    ).reshape((4, 4))\n",
    "\n",
    "    # SVD of A itself (instead of A^T A) avoids squaring the condition number;\n",
    "    # singular values are sorted descending, so row 3 of Vh is the null-space estimate.\n",
    "    _, _, Vh = np.linalg.svd(A)\n",
    "    return Vh[3, 0:3] / Vh[3, 3]\n",
    "\n",
    "\n",
    "@overload\n",
    "def homogeneous_to_euclidean(points: Num[NDArray, \"N 4\"]) -> Num[NDArray, \"N 3\"]: ...\n",
    "\n",
    "\n",
    "@overload\n",
    "def homogeneous_to_euclidean(points: Num[Array, \"N 4\"]) -> Num[Array, \"N 3\"]: ...\n",
    "\n",
    "\n",
    "@jaxtyped(typechecker=beartype)\n",
    "def homogeneous_to_euclidean(\n",
    "    points: Num[Any, \"N 4\"],\n",
    ") -> Num[Any, \"N 3\"]:\n",
    "    \"\"\"\n",
    "    Convert homogeneous coordinates to Euclidean coordinates.\n",
    "\n",
    "    Args:\n",
    "        points: homogeneous coordinates (x, y, z, w) in numpy array or jax array\n",
    "\n",
    "    Returns:\n",
    "        euclidean coordinates (x, y, z) in numpy array or jax array\n",
    "    \"\"\"\n",
    "    return points[..., :-1] / points[..., -1:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Synchronise the cameras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def camera_for_port(port: int) -> Camera:\n",
    "    \"\"\"Build the `Camera` for the first calibration record whose `port` equals `port`.\"\"\"\n",
    "    return from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET[\"port\"] == port][0])  # type: ignore\n",
    "\n",
    "\n",
    "image_gens = [\n",
    "    preprocess_keypoint_dataset(KEYPOINT_DATASET[port], camera_for_port(port), FPS, START_TIMESTAMP)  # type: ignore\n",
    "    for port in PORTS\n",
    "]\n",
    "sync_gen = sync_batch_gen(image_gens, timedelta(seconds=1 / FPS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "detections = next(sync_gen)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cross-view association"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint(\n",
    "    detections, alpha_2d=ALPHA_2D\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display([{\"timestamp\": str(d.timestamp), \"camera\": d.camera.id} for d in sorted_detections])\n",
    "with jnp.printoptions(precision=3, suppress=True):\n",
    "    display(affinity_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def clusters_to_detections(clusters: list[list[int]], sorted_detections: list[Detection]) -> list[list[Detection]]:\n",
    "    \"\"\"\n",
    "    given a list of clusters (which is the indices of the detections in the sorted_detections list),\n",
    "    extract the detections from the sorted_detections list\n",
    "\n",
    "    Args:\n",
    "        clusters: list of clusters, each cluster is a list of indices of the detections in the `sorted_detections` list\n",
    "        sorted_detections: list of SORTED detections\n",
    "\n",
    "    Returns:\n",
    "        list of clusters, each cluster is a list of detections\n",
    "    \"\"\"\n",
    "    return [[sorted_detections[i] for i in cluster] for cluster in clusters]\n",
    "\n",
    "\n",
    "solver = GLPKSolver()\n",
    "aff_np = np.asarray(affinity_matrix).astype(np.float64)\n",
    "clusters, sol_matrix = solver.solve(aff_np)\n",
    "display(clusters, sol_matrix)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualise clusters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def render_cluster(cluster: list[Detection]) -> np.ndarray:\n",
    "    \"\"\"Draw the keypoints of every detection in `cluster` onto one blank HEIGHT x WIDTH x 3 canvas.\"\"\"\n",
    "    canvas = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)\n",
    "    for detection in cluster:\n",
    "        canvas = visualize_whole_body(np.asarray(detection.keypoints), canvas)\n",
    "    return canvas\n",
    "\n",
    "\n",
    "def plot_cluster(cluster: list[Detection], title: str) -> None:\n",
    "    \"\"\"Show `render_cluster(cluster)` in a new figure.\"\"\"\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.imshow(render_cluster(cluster))\n",
    "    ax.set_title(title)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "clusters_detections = clusters_to_detections(clusters, sorted_detections)\n",
    "plot_cluster(clusters_detections[0], \"cluster 0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_cluster(clusters_detections[1], \"cluster 1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Triangulation\n",
    "\n",
    "Weighted linear (DLT) triangulation: each view contributes two rows to $A$, scaled by $\\sqrt{c}$ where $c$ is the keypoint confidence, and the 3D point is the right singular vector of $A$ with the smallest singular value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jaxtyped(typechecker=beartype)\n",
    "def triangulate_one_point_from_multiple_views_linear(\n",
    "    proj_matrices: Float[Array, \"N 3 4\"],\n",
    "    points: Num[Array, \"N 2\"],\n",
    "    confidences: Optional[Float[Array, \"N\"]] = None,\n",
    ") -> Float[Array, \"3\"]:\n",
    "    \"\"\"\n",
    "    Linearly triangulate one 3D point observed by N cameras.\n",
    "\n",
    "    Args:\n",
    "        proj_matrices: (N, 3, 4) projection matrices\n",
    "        points: (N, 2) image coordinates, one per view\n",
    "        confidences: (N,) per-view confidences in [0.0, 1.0]; each view's two rows are\n",
    "            weighted by sqrt(clip(confidence, 0, 1)). Defaults to equal weights.\n",
    "\n",
    "    Returns:\n",
    "        point_3d: (3,) triangulated 3D point\n",
    "    \"\"\"\n",
    "    assert len(proj_matrices) == len(points)\n",
    "\n",
    "    N = len(proj_matrices)\n",
    "    if confidences is None:\n",
    "        weights = jnp.ones(N, dtype=jnp.float32)\n",
    "    else:\n",
    "        # square root of the confidence: a softer weighting than the raw value\n",
    "        weights = jnp.sqrt(jnp.clip(confidences, 0, 1))\n",
    "\n",
    "    # view i contributes rows x_i * P_i[2] - P_i[0] and y_i * P_i[2] - P_i[1]\n",
    "    x = points[:, 0:1]\n",
    "    y = points[:, 1:2]\n",
    "    rows_x = proj_matrices[:, 2] * x - proj_matrices[:, 0]\n",
    "    rows_y = proj_matrices[:, 2] * y - proj_matrices[:, 1]\n",
    "    # interleave to (x_0, y_0, x_1, y_1, ...) and weight both rows of a view equally\n",
    "    A = jnp.stack((rows_x, rows_y), axis=1).reshape(2 * N, 4)\n",
    "    A = (A * jnp.repeat(weights, 2)[:, None]).astype(jnp.float32)\n",
    "\n",
    "    # https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.svd.html\n",
    "    _, _, vh = jnp.linalg.svd(A, full_matrices=False)\n",
    "    point_3d_homo = vh[-1]  # shape (4,); smallest singular value comes last\n",
    "\n",
    "    # dividing by w also cancels the arbitrary sign of the singular vector\n",
    "    return point_3d_homo[:3] / point_3d_homo[3]\n",
    "\n",
    "\n",
    "@jaxtyped(typechecker=beartype)\n",
    "def triangulate_points_from_multiple_views_linear(\n",
    "    proj_matrices: Float[Array, \"N 3 4\"],\n",
    "    points: Num[Array, \"N P 2\"],\n",
    "    confidences: Optional[Float[Array, \"N P\"]] = None,\n",
    ") -> Float[Array, \"P 3\"]:\n",
    "    \"\"\"\n",
    "    Batch-triangulate P points observed by N cameras, linearly via SVD.\n",
    "\n",
    "    Args:\n",
    "        proj_matrices: (N, 3, 4) projection matrices\n",
    "        points: (N, P, 2) image coordinates per view\n",
    "        confidences: (N, P) optional per-view confidences in [0, 1]; defaults to all ones\n",
    "\n",
    "    Returns:\n",
    "        (P, 3) 3D point for each of the P tracks\n",
    "    \"\"\"\n",
    "    N, P, _ = points.shape\n",
    "    assert proj_matrices.shape[0] == N\n",
    "    # Pass confidences through unchanged: the single-point routine already clips\n",
    "    # them and applies the sqrt weighting (doing it here too would apply sqrt twice).\n",
    "    conf = jnp.ones((N, P), dtype=jnp.float32) if confidences is None else confidences\n",
    "\n",
    "    vmap_triangulate = jax.vmap(\n",
    "        triangulate_one_point_from_multiple_views_linear,\n",
    "        in_axes=(None, 1, 1),  # share proj_matrices; map over points[:, p, :] and conf[:, p]\n",
    "        out_axes=0,\n",
    "    )\n",
    "\n",
    "    # returns (P, 3)\n",
    "    return vmap_triangulate(proj_matrices, points, conf)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tracking state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jaxtyped(typechecker=beartype)\n",
    "@dataclass\n",
    "class Tracking:\n",
    "    id: int\n",
    "    keypoints: Float[Array, \"J 3\"]  # triangulated 3D keypoints\n",
    "\n",
    "\n",
    "@jaxtyped(typechecker=beartype)\n",
    "def triangle_from_cluster(cluster: list[Detection]) -> Float[Array, \"J 3\"]:\n",
    "    \"\"\"Triangulate every keypoint of one cross-view cluster.\n",
    "\n",
    "    Uses each detection's camera projection matrix and undistorted keypoints,\n",
    "    weighted by the detection's per-keypoint confidences.\n",
    "    \"\"\"\n",
    "    proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])\n",
    "    points = jnp.array([el.keypoints_undistorted for el in cluster])\n",
    "    confidences = jnp.array([el.confidences for el in cluster])\n",
    "    return triangulate_points_from_multiple_views_linear(\n",
    "        proj_matrices, points, confidences=confidences\n",
    "    )\n",
    "\n",
    "\n",
    "class GlobalTrackingState:\n",
    "    \"\"\"Registry of tracks, each created from one cluster and given a sequential id.\"\"\"\n",
    "\n",
    "    _last_id: int\n",
    "    _trackings: list[Tracking]\n",
    "\n",
    "    def __init__(self):\n",
    "        self._last_id = 0\n",
    "        self._trackings = []\n",
    "\n",
    "    @property\n",
    "    def trackings(self) -> list[Tracking]:\n",
    "        \"\"\"Shallow copy of the track list (the list is copied, the tracks are shared).\"\"\"\n",
    "        return shallow_copy(self._trackings)\n",
    "\n",
    "    def add_tracking(self, cluster: list[Detection]) -> Tracking:\n",
    "        \"\"\"Triangulate `cluster`, register it under the next id and return it.\"\"\"\n",
    "        tracking = Tracking(id=self._last_id, keypoints=triangle_from_cluster(cluster))\n",
    "        self._last_id += 1\n",
    "        self._trackings.append(tracking)\n",
    "        return tracking"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "global_tracking_state = GlobalTrackingState()\n",
    "for cluster in clusters_detections:\n",
    "    global_tracking_state.add_tracking(cluster)\n",
    "\n",
    "# Set to True to dump the first two triangulated clusters to samples/res.json.\n",
    "EXPORT_TRIANGULATION = False\n",
    "if EXPORT_TRIANGULATION:\n",
    "    res = {\n",
    "        \"a\": triangle_from_cluster(clusters_detections[0]).tolist(),\n",
    "        \"b\": triangle_from_cluster(clusters_detections[1]).tolist(),\n",
    "    }\n",
    "    with open(\"samples/res.json\", \"wb\") as f:\n",
    "        f.write(orjson.dumps(res))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "next_group = next(sync_gen)\n",
    "display(next_group)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}