{ "cells": [ { "cell_type": "code", "execution_count": 290, "metadata": {}, "outputs": [], "source": [ "from copy import deepcopy\n", "from datetime import datetime, timedelta\n", "from pathlib import Path\n", "from typing import (\n", " Any,\n", " Generator,\n", " Optional,\n", " Sequence,\n", " TypeAlias,\n", " TypedDict,\n", " cast,\n", " overload,\n", ")\n", "\n", "import awkward as ak\n", "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "import orjson\n", "from beartype import beartype\n", "from cv2 import undistortPoints\n", "from jaxtyping import Array, Float, Num, jaxtyped\n", "from matplotlib import pyplot as plt\n", "from numpy.typing import ArrayLike\n", "from scipy.spatial.transform import Rotation as R\n", "\n", "from app.camera import Camera, CameraParams, Detection\n", "from app.visualize.whole_body import visualize_whole_body\n", "from filter_object_by_box import *\n", "NDArray: TypeAlias = np.ndarray" ] }, { "cell_type": "code", "execution_count": 291, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
[{name: 'AF_01', port: 5601, intrinsic: {...}, extrinsic: {...}, ...},\n",
       " {name: 'AF_02', port: 5602, intrinsic: {...}, extrinsic: {...}, ...},\n",
       " {name: 'AF_03', port: 5603, intrinsic: {...}, extrinsic: {...}, ...},\n",
       " {name: 'AF_04', port: 5604, intrinsic: {...}, extrinsic: {...}, ...},\n",
       " {name: 'AF_05', port: 5605, intrinsic: {...}, extrinsic: {...}, ...},\n",
       " {name: 'AF_06', port: 5606, intrinsic: {...}, extrinsic: {...}, ...},\n",
       " {name: 'AE_01', port: 5607, intrinsic: {...}, extrinsic: {...}, ...},\n",
       " {name: 'AE_1A', port: 5608, intrinsic: {...}, extrinsic: {...}, ...},\n",
       " {name: 'AE_08', port: 5609, intrinsic: {...}, extrinsic: {...}, ...}]\n",
       "------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n",
       "backend: cpu\n",
       "nbytes: 2.3 kB\n",
       "type: 9 * {\n",
       "    name: string,\n",
       "    port: int64,\n",
       "    intrinsic: {\n",
       "        camera_matrix: var * var * float64,\n",
       "        distortion_coefficients: var * float64\n",
       "    },\n",
       "    extrinsic: {\n",
       "        rvec: var * float64,\n",
       "        tvec: var * float64\n",
       "    },\n",
       "    resolution: {\n",
       "        width: int64,\n",
       "        height: int64\n",
       "    }\n",
       "}
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "CAMERA_PATH = Path(\n", " \"/home/admin/Documents/ActualTest_QuanCheng/camera_ex_params_1_2025_4_20/camera_params\"\n", ")\n", "AK_CAMERA_DATASET: ak.Array = ak.from_parquet(CAMERA_PATH / \"camera_params.parquet\")\n", "display(AK_CAMERA_DATASET)" ] }, { "cell_type": "code", "execution_count": 292, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "np.int64(5604)" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
{rvec: [-2.24, 0.0917, -2.14],\n",
       " tvec: [0.165, 0.217, 5.12]}\n",
       "----------------------------------------------------------\n",
       "backend: cpu\n",
       "nbytes: 592 B\n",
       "type: {\n",
       "    rvec: var * float64,\n",
       "    tvec: var * float64\n",
       "}
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "camera_data_5604 = AK_CAMERA_DATASET[3]\n", "display(camera_data_5604[\"port\"])\n", "display(camera_data_5604[\"extrinsic\"])" ] }, { "cell_type": "code", "execution_count": 293, "metadata": {}, "outputs": [], "source": [ "class Resolution(TypedDict):\n", " width: int\n", " height: int\n", "\n", "\n", "class Intrinsic(TypedDict):\n", " camera_matrix: Num[Array, \"3 3\"]\n", " \"\"\"\n", " K\n", " \"\"\"\n", " distortion_coefficients: Num[Array, \"N\"]\n", " \"\"\"\n", " distortion coefficients; usually 5\n", " \"\"\"\n", "\n", "\n", "class Extrinsic(TypedDict):\n", " rvec: Num[NDArray, \"3\"]\n", " tvec: Num[NDArray, \"3\"]\n", "\n", "\n", "class ExternalCameraParams(TypedDict):\n", " name: str\n", " port: int\n", " intrinsic: Intrinsic\n", " extrinsic: Extrinsic\n", " resolution: Resolution" ] }, { "cell_type": "code", "execution_count": 294, "metadata": {}, "outputs": [], "source": [ "DATASET_PATH = Path(\n", " \"/home/admin/Documents/ActualTest_QuanCheng/camera_ex_params_1_2025_4_20/detect_result/segement_1\"\n", ")\n", "\n", "\n", "def read_dataset_by_port(port: int) -> ak.Array:\n", " P = DATASET_PATH / f\"{port}.parquet\"\n", " P = DATASET_PATH / f\"filter_{port}.parquet\"\n", " return ak.from_parquet(P)\n", "\n", "\n", "KEYPOINT_DATASET = {\n", " int(p): read_dataset_by_port(p) for p in ak.to_numpy(AK_CAMERA_DATASET[\"port\"]) if p in [5603, 5605, 5608, 5609]\n", "}" ] }, { "cell_type": "code", "execution_count": 295, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{5603: ,\n", " 5605: ,\n", " 5608: ,\n", " 5609: }" ] }, "execution_count": 295, "metadata": {}, "output_type": "execute_result" } ], "source": [ "KEYPOINT_DATASET" ] }, { "cell_type": "code", "execution_count": 296, "metadata": {}, "outputs": [], "source": [ "class KeypointDataset(TypedDict):\n", " frame_index: int\n", " boxes: Num[NDArray, \"N 4\"]\n", " kps: Num[NDArray, \"N J 2\"]\n", " kps_scores: Num[NDArray, \"N J\"]\n", "\n", "\n", "@jaxtyped(typechecker=beartype)\n", "def to_transformation_matrix(\n", " rvec: Num[NDArray, \"3\"], tvec: Num[NDArray, \"3\"]\n", ") -> Num[NDArray, \"4 4\"]:\n", " res = np.eye(4)\n", " res[:3, :3] = R.from_rotvec(rvec).as_matrix()\n", " res[:3, 3] = tvec\n", " return res\n", "\n", "\n", "@jaxtyped(typechecker=beartype)\n", "def undistort_points(\n", " points: Num[NDArray, \"M 2\"],\n", " camera_matrix: Num[NDArray, \"3 3\"],\n", " dist_coeffs: Num[NDArray, \"N\"],\n", ") -> Num[NDArray, \"M 2\"]:\n", " K = camera_matrix\n", " dist = dist_coeffs\n", " res = undistortPoints(points, K, dist, P=K) # type: ignore\n", " return res.reshape(-1, 2)\n", "\n", "\n", "def from_camera_params(camera: ExternalCameraParams) -> Camera:\n", " rt = jnp.array(\n", " to_transformation_matrix(\n", " ak.to_numpy(camera[\"extrinsic\"][\"rvec\"]),\n", " ak.to_numpy(camera[\"extrinsic\"][\"tvec\"]),\n", " )\n", " )\n", " K = jnp.array(camera[\"intrinsic\"][\"camera_matrix\"]).reshape(3, 3)\n", " dist_coeffs = jnp.array(camera[\"intrinsic\"][\"distortion_coefficients\"])\n", " image_size = jnp.array(\n", " (camera[\"resolution\"][\"width\"], camera[\"resolution\"][\"height\"])\n", " )\n", " return Camera(\n", " id=camera[\"name\"],\n", " params=CameraParams(\n", " K=K,\n", " Rt=rt,\n", " dist_coeffs=dist_coeffs,\n", " image_size=image_size,\n", " ),\n", " )\n", "\n", "\n", "def preprocess_keypoint_dataset(\n", " dataset: Sequence[KeypointDataset],\n", " camera: Camera,\n", " fps: float,\n", " 
start_timestamp: datetime,\n", ") -> Generator[Detection, None, None]:\n", " frame_interval_s = 1 / fps\n", " for el in dataset:\n", " frame_index = el[\"frame_index\"]\n", " timestamp = start_timestamp + timedelta(seconds=frame_index * frame_interval_s)\n", " for kp, kp_score, _boxes in zip(el[\"kps\"], el[\"kps_scores\"], el[\"boxes\"]):\n", " kp = undistort_points(\n", " np.asarray(kp),\n", " np.asarray(camera.params.K),\n", " np.asarray(camera.params.dist_coeffs),\n", " )\n", "\n", " yield Detection(\n", " keypoints=jnp.array(kp),\n", " confidences=jnp.array(kp_score),\n", " camera=camera,\n", " timestamp=timestamp,\n", " )" ] }, { "cell_type": "code", "execution_count": 297, "metadata": {}, "outputs": [], "source": [ "from typing import Any, Generator\n", "\n", "\n", "from app.camera import Detection\n", "\n", "\n", "DetectionGenerator: TypeAlias = Generator[Detection, None, None]\n", "\n", "\n", "def sync_batch_gen(gens: list[DetectionGenerator], diff: timedelta) -> Generator[list[Detection], Any, None]:\n", " \"\"\"\n", " Given a list of detection generators, yield batches of detections whose timestamps agree to within `diff`.\n", "\n", " Args:\n", " gens: list of detection generators\n", " diff: maximum timestamp difference between detections for them to be part of the same batch\n", " \"\"\"\n", " N = len(gens)\n", " last_batch_timestamp: Optional[datetime] = None\n", " next_batch_timestamp: Optional[datetime] = None\n", " current_batch: list[Detection] = []\n", " next_batch: list[Detection] = []\n", " paused: list[bool] = [False] * N\n", " finished: list[bool] = [False] * N\n", "\n", " def reset_paused():\n", " \"\"\"\n", " Reset the paused flags based on the finished flags.\n", " \"\"\"\n", " for i in range(N):\n", " if not finished[i]:\n", " paused[i] = False\n", " else:\n", " paused[i] = True\n", "\n", " EPS = 1e-6\n", " # a small epsilon to avoid floating-point precision issues\n", " diff_eps = diff - timedelta(seconds=EPS)\n", " while True:\n", " for i, gen in enumerate(gens):\n", " try:\n", " if finished[i] or paused[i]:\n", " continue\n", " val = next(gen)\n", " if last_batch_timestamp is None:\n", " last_batch_timestamp = val.timestamp\n", " current_batch.append(val)\n", " else:\n", " if abs(val.timestamp - last_batch_timestamp) >= diff_eps:\n", " next_batch.append(val)\n", " if next_batch_timestamp is None:\n", " next_batch_timestamp = val.timestamp\n", " paused[i] = True\n", " if all(paused):\n", " yield current_batch\n", " current_batch = next_batch\n", " next_batch = []\n", " last_batch_timestamp = next_batch_timestamp\n", " next_batch_timestamp = None\n", " reset_paused()\n", " else:\n", " current_batch.append(val)\n", " except StopIteration:\n", " finished[i] = True\n", " paused[i] = True\n", " if all(finished):\n", " if len(current_batch) > 0:\n", " # All generators exhausted, flush remaining batch and exit\n", " yield current_batch\n", " break" ] }, { "cell_type": "code", "execution_count": 298, "metadata": {}, "outputs": [], "source": [ "@overload\n", "def to_projection_matrix(\n", " transformation_matrix: Num[NDArray, \"4 4\"], camera_matrix: Num[NDArray, \"3 3\"]\n", ") -> Num[NDArray, \"3 4\"]: ...\n", "\n", "\n", "@overload\n", "def to_projection_matrix(\n", " transformation_matrix: Num[Array, \"4 4\"], camera_matrix: Num[Array, \"3 3\"]\n", ") -> Num[Array, \"3 4\"]: ...\n", "\n", "\n", "@jaxtyped(typechecker=beartype)\n", "def to_projection_matrix(\n", " transformation_matrix: Num[Any, \"4 4\"],\n", " camera_matrix: Num[Any, \"3 3\"],\n", ") -> Num[Any, \"3 4\"]:\n", "
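 # P = K @ [R | t]: keep the top three rows of the 4x4 homogeneous transform\n", "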
 return camera_matrix @ transformation_matrix[:3, :]\n", "\n", "\n", "to_projection_matrix_jit = jax.jit(to_projection_matrix)\n", "\n", "\n", "@jaxtyped(typechecker=beartype)\n", "def dlt(\n", " H1: Num[NDArray, \"3 4\"],\n", " H2: Num[NDArray, \"3 4\"],\n", " p1: Num[NDArray, \"2\"],\n", " p2: Num[NDArray, \"2\"],\n", ") -> Num[NDArray, \"3\"]:\n", " \"\"\"\n", " Direct Linear Transformation: triangulate one 3D point from two views.\n", " \"\"\"\n", " A = [\n", " p1[1] * H1[2, :] - H1[1, :],\n", " H1[0, :] - p1[0] * H1[2, :],\n", " p2[1] * H2[2, :] - H2[1, :],\n", " H2[0, :] - p2[0] * H2[2, :],\n", " ]\n", " A = np.array(A).reshape((4, 4))\n", "\n", " B = A.transpose() @ A\n", " from scipy import linalg\n", "\n", " _, _, Vh = linalg.svd(B, full_matrices=False)\n", " return Vh[3, 0:3] / Vh[3, 3]\n", "\n", "\n", "@overload\n", "def homogeneous_to_euclidean(points: Num[NDArray, \"N 4\"]) -> Num[NDArray, \"N 3\"]: ...\n", "\n", "\n", "@overload\n", "def homogeneous_to_euclidean(points: Num[Array, \"N 4\"]) -> Num[Array, \"N 3\"]: ...\n", "\n", "\n", "@jaxtyped(typechecker=beartype)\n", "def homogeneous_to_euclidean(\n", " points: Num[Any, \"N 4\"],\n", ") -> Num[Any, \"N 3\"]:\n", " \"\"\"\n", " Convert homogeneous coordinates to Euclidean coordinates.\n", "\n", " Args:\n", " points: homogeneous coordinates (x, y, z, w) in numpy array or jax array\n", "\n", " Returns:\n", " euclidean coordinates (x, y, z) in numpy array or jax array\n", " \"\"\"\n", " return points[..., :-1] / points[..., -1:]" ] }, { "cell_type": "code", "execution_count": 299, "metadata": {}, "outputs": [], "source": [ "camera_list = [\n", " 5601,\n", " 5602,\n", " 5603,\n", " 5604,\n", " 5605,\n", " 5606,\n", " 5607,\n", " 5608,\n", " 5609,\n", "]\n", "# build a Camera (intrinsics and extrinsics) for each port\n", "cameras = list(map(lambda x: from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET[\"port\"] == x][0]), camera_list))\n", "\n", "for element_camera in cameras:\n", " with jnp.printoptions(precision=4, suppress=True):\n", " # display(element_camera)\n", " # display(element_camera.params.Rt.reshape(-1))\n", " # display(element_camera.params.K.reshape(-1))\n", "\n", " # compute the camera-to-world-origin distance\n", " translation = element_camera.params.pose_matrix[:3, -1]\n", " # display(translation)\n", " # display(jnp.linalg.norm(translation).item())" ] }, { "cell_type": "code", "execution_count": 300, "metadata": {}, "outputs": [], "source": [ "@jaxtyped(typechecker=beartype)\n", "def triangulate_one_point_from_multiple_views_linear(\n", " proj_matrices: Float[Array, \"N 3 4\"],\n", " points: Num[Array, \"N 2\"],\n", " confidences: Optional[Float[Array, \"N\"]] = None,\n", ") -> Float[Array, \"3\"]:\n", " \"\"\"\n", " Args:\n", " proj_matrices: projection matrices, shape (N, 3, 4)\n", " points: 2D point per view, shape (N, 2)\n", " confidences: per-view confidences in [0.0, 1.0], shape (N,)\n", "\n", " Returns:\n", " point_3d: triangulated 3D point, shape (3,)\n", " \"\"\"\n", " assert len(proj_matrices) == len(points)\n", "\n", " N = len(proj_matrices)\n", " confi: Float[Array, \"N\"]\n", "\n", " if confidences is None:\n", " confi = jnp.ones(N, dtype=np.float32)\n", " else:\n", " # Use square root of confidences for weighting - more balanced approach\n", " confi = jnp.sqrt(jnp.clip(confidences, 0, 1))\n", "\n", " # (disabled) zero out confidences below 0.1:\n", " # valid_mask = confidences >= 0.1\n", " # confi = jnp.sqrt(jnp.clip(confidences * valid_mask, 0.0, 1.0))\n", "\n", " A = jnp.zeros((N * 2, 4), dtype=np.float32)\n", " for i in range(N):\n", " x, y = points[i]\n", "
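 # each view i contributes two weighted DLT rows: x*P3 - P1 and y*P3 - P2,\n", " # where Pk is the k-th row of that view's projection matrix\n", "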
 A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])\n", " A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])\n", " A = A.at[2 * i].mul(confi[i])\n", " A = A.at[2 * i + 1].mul(confi[i])\n", "\n", " # https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.svd.html\n", " _, _, vh = jnp.linalg.svd(A, full_matrices=False)\n", " point_3d_homo = vh[-1] # shape (4,)\n", "\n", " # replace the Python `if` with a jnp.where\n", " point_3d_homo = jnp.where(\n", " point_3d_homo[3] < 0, # predicate (scalar bool tracer)\n", " -point_3d_homo, # if True\n", " point_3d_homo, # if False\n", " )\n", "\n", " point_3d = point_3d_homo[:3] / point_3d_homo[3]\n", " return point_3d\n", "\n", "\n", "@jaxtyped(typechecker=beartype)\n", "def triangulate_points_from_multiple_views_linear(\n", " proj_matrices: Float[Array, \"N 3 4\"],\n", " points: Num[Array, \"N P 2\"],\n", " confidences: Optional[Float[Array, \"N P\"]] = None,\n", ") -> Float[Array, \"P 3\"]:\n", " \"\"\"\n", " Batch-triangulate P points observed by N cameras, linearly via SVD.\n", "\n", " Args:\n", " proj_matrices: (N, 3, 4) projection matrices\n", " points: (N, P, 2) image-coordinates per view\n", " confidences: (N, P) optional per-view confidences in [0, 1]\n", "\n", " Returns:\n", " (P, 3) 3D point for each of the P tracks\n", " \"\"\"\n", " N, P, _ = points.shape\n", " assert proj_matrices.shape[0] == N\n", "\n", " if confidences is None:\n", " # fall back to uniform weights so the vmapped routine always receives an array\n", " confidences = jnp.ones((N, P), dtype=jnp.float32)\n", " conf = jnp.asarray(confidences)\n", "\n", " # vectorize the one-point routine over P\n", " vmap_triangulate = jax.vmap(\n", " triangulate_one_point_from_multiple_views_linear,\n", " in_axes=(None, 1, 1), # proj_matrices static, map over points[:,p,:], conf[:,p]\n", " out_axes=0,\n", " )\n", "\n", " # returns (P, 3)\n", " return vmap_triangulate(proj_matrices, points, conf)" ] }, { "cell_type": "code", "execution_count": 301, "metadata": {}, "outputs": [], "source": [ "FPS = 24\n", "# image_gen_5601 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5601], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET[\"port\"] == 5601][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore\n", "# image_gen_5602 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5602], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET[\"port\"] == 5602][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore\n", "image_gen_5603 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5603], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET[\"port\"] == 5603][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore\n", "# image_gen_5604 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5604], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET[\"port\"] == 5604][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore\n", "image_gen_5605 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5605], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET[\"port\"] == 5605][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore\n", "# image_gen_5606 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5606], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET[\"port\"] == 5606][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore\n", "# image_gen_5607 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5607], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET[\"port\"] == 5607][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore\n", "image_gen_5608 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5608], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET[\"port\"] == 5608][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore\n", "image_gen_5609 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5609], 
from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET[\"port\"] == 5609][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore\n", "\n" ] }, { "cell_type": "code", "execution_count": 302, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.041666666666666664" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "display(1 / FPS)\n", "sync_gen = sync_batch_gen(\n", " [\n", " # image_gen_5601,\n", " # image_gen_5602,\n", " image_gen_5603,\n", " # image_gen_5604,\n", " image_gen_5605,\n", " # image_gen_5607,\n", " # image_gen_5606,\n", " image_gen_5608,\n", " image_gen_5609\n", " ],\n", " timedelta(seconds=1 / FPS),\n", ")" ] }, { "cell_type": "code", "execution_count": 303, "metadata": {}, "outputs": [], "source": [ "detections = next(sync_gen)" ] }, { "cell_type": "code", "execution_count": 304, "metadata": {}, "outputs": [], "source": [ "from app.tracking import AffinityResult, Tracking\n", "from copy import copy as shallow_copy\n", "from pyrsistent import v, pvector\n", "from beartype.typing import Mapping, Sequence\n", "from app.camera import (\n", " Camera,\n", " CameraID,\n", " CameraParams,\n", " Detection,\n", " calculate_affinity_matrix_by_epipolar_constraint,\n", " classify_by_camera,\n", ")\n", "\n", "from optax.assignment import hungarian_algorithm as linear_sum_assignment\n", "from itertools import chain\n", "\n", "DELTA_T_MIN = timedelta(milliseconds=10)\n", "\n", "class GlobalTrackingState:\n", " _last_id: int\n", " _trackings: dict[int, Tracking]\n", "\n", " def __init__(self):\n", " self._last_id = 0\n", " self._trackings = {}\n", "\n", " def __repr__(self) -> str:\n", " return (\n", " f\"GlobalTrackingState(last_id={self._last_id}, trackings={self._trackings})\"\n", " )\n", "\n", " @property\n", " def trackings(self) -> dict[int, Tracking]:\n", " return shallow_copy(self._trackings)\n", "\n", " def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:\n", " kps_3d, latest_timestamp = triangle_from_cluster(cluster)\n", " next_id = self._last_id + 1\n", " tracking = Tracking(\n", " id=next_id,\n", " keypoints=kps_3d,\n", " last_active_timestamp=latest_timestamp,\n", " historical_detections=v(*cluster),\n", " )\n", " self._trackings[next_id] = tracking\n", " self._last_id = next_id\n", " return tracking\n", " \n", "@jaxtyped(typechecker=beartype)\n", "def triangle_from_cluster(\n", " cluster: Sequence[Detection],\n", ") -> tuple[Float[Array, \"N 3\"], datetime]:\n", " proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])\n", " points = jnp.array([el.keypoints_undistorted for el in cluster])\n", " confidences = jnp.array([el.confidences for el in cluster])\n", " latest_timestamp = max(el.timestamp for el in cluster)\n", " return (\n", " triangulate_points_from_multiple_views_linear(\n", " proj_matrices, points, confidences=confidences\n", " ),\n", " latest_timestamp,\n", " )\n", "\n", "@beartype\n", "def calculate_camera_affinity_matrix_jax(\n", " trackings: Sequence[Tracking],\n", " camera_detections: Sequence[Detection],\n", " w_2d: float,\n", " alpha_2d: float,\n", " w_3d: float,\n", " alpha_3d: float,\n", " lambda_a: float,\n", ") -> Float[Array, \"T D\"]:\n", " \"\"\"\n", " Vectorized implementation to compute an affinity matrix between *trackings*\n", " and *detections* coming from **one** camera.\n", "\n", " Compared with the simple double-for-loop version, this leverages `jax`'s\n", " broadcasting + `vmap` facilities and avoids Python loops over every\n", " (tracking, detection) pair. 
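Per joint, the code below\n", " computes a2d = w_2d * (1 - d2d / (alpha_2d * dt)) * exp(-lambda_a * dt) and\n", " a3d = w_3d * (1 - d3d / alpha_3d) * exp(-lambda_a * dt), then sums both over joints. 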
The mathematical definition of the affinity\n", " is **unchanged**, so the result remains bit-identical to the reference\n", " implementation used in the tests.\n", " \"\"\"\n", "\n", " # ------------------------------------------------------------------\n", " # Quick validations / early-exit guards\n", " # ------------------------------------------------------------------\n", " if len(trackings) == 0 or len(camera_detections) == 0:\n", " # Return an empty affinity matrix with appropriate shape.\n", " return jnp.zeros((len(trackings), len(camera_detections))) # type: ignore[return-value]\n", "\n", " cam = next(iter(camera_detections)).camera\n", " # Ensure every detection truly belongs to the same camera (guard clause)\n", " cam_id = cam.id\n", " if any(det.camera.id != cam_id for det in camera_detections):\n", " raise ValueError(\n", " \"All detections passed to `calculate_camera_affinity_matrix` must come from one camera.\"\n", " )\n", "\n", " # We will rely on a single `Camera` instance (all detections share it)\n", " w_img_, h_img_ = cam.params.image_size\n", " w_img, h_img = float(w_img_), float(h_img_)\n", "\n", " # ------------------------------------------------------------------\n", " # Gather data into ndarray / DeviceArray batches so that we can compute\n", " # everything in a single (or a few) fused kernels.\n", " # ------------------------------------------------------------------\n", "\n", " # === Tracking-side tensors ===\n", " kps3d_trk: Float[Array, \"T J 3\"] = jnp.stack(\n", " [trk.keypoints for trk in trackings]\n", " ) # (T, J, 3)\n", " J = kps3d_trk.shape[1]\n", " # === Detection-side tensors ===\n", " kps2d_det: Float[Array, \"D J 2\"] = jnp.stack(\n", " [det.keypoints for det in camera_detections]\n", " ) # (D, J, 2)\n", "\n", " # ------------------------------------------------------------------\n", " # Compute Δt matrix – shape (T, D)\n", " # ------------------------------------------------------------------\n", " # Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out\n", " # sub‑second detail (the float32 ulp at that magnitude is 128 s). 
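In float32,\n", " # 1.7e9 + 0.04 rounds back to 1.7e9, so 1/24 s frame gaps would vanish. 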
Keep them in float64 until\n", " # after subtraction so we preserve Δt‑on‑the‑order‑of‑milliseconds.\n", " # --- timestamps ----------\n", " t0 = min(\n", " chain(\n", " (trk.last_active_timestamp for trk in trackings),\n", " (det.timestamp for det in camera_detections),\n", " )\n", " ).timestamp() # common origin (float)\n", " ts_trk = jnp.array(\n", " [trk.last_active_timestamp.timestamp() - t0 for trk in trackings],\n", " dtype=jnp.float32, # now small, ms-scale fits in fp32\n", " )\n", " ts_det = jnp.array(\n", " [det.timestamp.timestamp() - t0 for det in camera_detections],\n", " dtype=jnp.float32,\n", " )\n", " # Δt in seconds, fp32 throughout\n", " delta_t = ts_det[None, :] - ts_trk[:, None] # (T,D)\n", " min_dt_s = float(DELTA_T_MIN.total_seconds())\n", " delta_t = jnp.clip(delta_t, a_min=min_dt_s, a_max=None)\n", "\n", " # ------------------------------------------------------------------\n", " # ---------- 2D affinity -------------------------------------------\n", " # ------------------------------------------------------------------\n", " # Project each tracking's 3D keypoints onto the image once.\n", " # `Camera.project` works per-sample, so we vmap over the first axis.\n", "\n", " proj_fn = jax.vmap(cam.project, in_axes=0) # maps over the keypoint sets\n", " kps2d_trk_proj: Float[Array, \"T J 2\"] = proj_fn(kps3d_trk) # (T, J, 2)\n", "\n", " # Normalise keypoints by image size so absolute units do not bias distance\n", " norm_trk = kps2d_trk_proj / jnp.array([w_img, h_img])\n", " norm_det = kps2d_det / jnp.array([w_img, h_img])\n", "\n", " # L2 distance for every (T, D, J)\n", " # reshape for broadcasting: (T,1,J,2) vs (1,D,J,2)\n", " diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]\n", " dist2d: Float[Array, \"T D J\"] = jnp.linalg.norm(diff2d, axis=-1)\n", "\n", " # Compute per-keypoint 2D affinity\n", " delta_t_broadcast = delta_t[:, :, None] # (T, D, 1)\n", " affinity_2d = (\n", " w_2d\n", " * (1 - dist2d / (alpha_2d * delta_t_broadcast))\n", " * jnp.exp(-lambda_a * delta_t_broadcast)\n", " )\n", "\n", " # ------------------------------------------------------------------\n", " # ---------- 3D affinity -------------------------------------------\n", " # ------------------------------------------------------------------\n", " # For each detection pre-compute back-projected 3D points lying on z=0 plane.\n", "\n", " backproj_points_list = [\n", " det.camera.unproject_points_to_z_plane(det.keypoints, z=0.0)\n", " for det in camera_detections\n", " ] # each (J,3)\n", " backproj: Float[Array, \"D J 3\"] = jnp.stack(backproj_points_list) # (D, J, 3)\n", "\n", " zero_velocity = jnp.zeros((J, 3))\n", " trk_velocities = jnp.stack(\n", " [\n", " trk.velocity if trk.velocity is not None else zero_velocity\n", " for trk in trackings\n", " ]\n", " )\n", "\n", " predicted_pose: Float[Array, \"T D J 3\"] = (\n", " kps3d_trk[:, None, :, :] # (T,1,J,3)\n", " + trk_velocities[:, None, :, :] * delta_t[:, :, None, None] # (T,D,1,1)\n", " )\n", "\n", " # Camera center – shape (3,) -> will broadcast\n", " cam_center = cam.params.location\n", "\n", " # Compute perpendicular distance using vectorized formula\n", " # p1 = cam_center (3,)\n", " # p2 = backproj (D, J, 3)\n", " # P = predicted_pose (T, D, J, 3)\n", " # Broadcast plan: v1 = P - p1 → (T, D, J, 3)\n", " # v2 = p2[None, ...]-p1 → (1, D, J, 3)\n", " # Shapes now line up; no stray singleton axis.\n", " p1 = cam_center\n", " p2 = backproj\n", " P = predicted_pose\n", " v1 = P - p1\n", " v2 = p2[None, :, :, :] - p1 # (1, D, J, 
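3)\n", " # point-to-ray distance: d = |(P - p1) x (p2 - p1)| / |p2 - p1|, i.e. the\n", " # distance from each predicted joint to the camera ray through its back-projection\n", "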
 cross = jnp.cross(v1, v2) # (T, D, J, 3)\n", " num = jnp.linalg.norm(cross, axis=-1) # (T, D, J)\n", " den = jnp.linalg.norm(v2, axis=-1) # (1, D, J)\n", " dist3d: Float[Array, \"T D J\"] = num / den\n", "\n", " affinity_3d = (\n", " w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)\n", " )\n", "\n", " # ------------------------------------------------------------------\n", " # Combine and reduce across keypoints → (T, D)\n", " # ------------------------------------------------------------------\n", " total_affinity: Float[Array, \"T D\"] = jnp.sum(affinity_2d + affinity_3d, axis=-1)\n", " return total_affinity # type: ignore[return-value]\n", "\n", "\n", "@beartype\n", "def calculate_affinity_matrix(\n", " trackings: Sequence[Tracking],\n", " detections: Sequence[Detection] | Mapping[CameraID, list[Detection]],\n", " w_2d: float,\n", " alpha_2d: float,\n", " w_3d: float,\n", " alpha_3d: float,\n", " lambda_a: float,\n", ") -> dict[CameraID, AffinityResult]:\n", " \"\"\"\n", " Calculate the affinity matrix between a set of trackings and detections.\n", "\n", " Args:\n", " trackings: Sequence of tracking objects\n", " detections: Sequence of detection objects, or detections already grouped by camera ID\n", " w_2d: Weight for 2D affinity\n", " alpha_2d: Normalization factor for 2D distance\n", " w_3d: Weight for 3D affinity\n", " alpha_3d: Normalization factor for 3D distance\n", " lambda_a: Decay rate for time difference\n", " Returns:\n", " A dictionary mapping camera IDs to affinity results.\n", " \"\"\"\n", " if isinstance(detections, Mapping):\n", " detection_by_camera = detections\n", " else:\n", " detection_by_camera = classify_by_camera(detections)\n", "\n", " res: dict[CameraID, AffinityResult] = {}\n", " for camera_id, camera_detections in detection_by_camera.items():\n", " affinity_matrix = calculate_camera_affinity_matrix_jax(\n", " trackings,\n", " camera_detections,\n", " w_2d,\n", " alpha_2d,\n", " w_3d,\n", " alpha_3d,\n", " lambda_a,\n", " )\n", " # row (tracking) and column (detection) indices from the Hungarian assignment\n", " indices_T, indices_D = linear_sum_assignment(affinity_matrix)\n", " affinity_result = AffinityResult(\n", " matrix=affinity_matrix,\n", " trackings=trackings,\n", " detections=camera_detections,\n", " indices_T=indices_T,\n", " indices_D=indices_D,\n", " )\n", " res[camera_id] = affinity_result\n", " return res\n" ] }, { "cell_type": "code", "execution_count": 305, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_values([Tracking(1, 2024-04-02 12:01:08.791667)])" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "global_tracking_state = GlobalTrackingState()\n", "detections = next(sync_gen)\n", "global_tracking_state.add_tracking(detections)\n", "display(global_tracking_state.trackings.values())" ] }, { "cell_type": "code", "execution_count": 306, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Tracking(1, 2024-04-02 12:01:08.791667)]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "W_2D = 1.0\n", "ALPHA_2D = 1.0\n", "LAMBDA_A = 0.1\n", "W_3D = 1.0\n", "ALPHA_3D = 1.0\n", "\n", "detections = next(sync_gen)\n", "trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)\n", "# display(detections)\n", "display(trackings)\n", "# associate current detections with the tracking state\n", "affinities = calculate_affinity_matrix(\n", " trackings=trackings,\n", " detections=detections,\n", " w_2d=W_2D,\n", " alpha_2d=ALPHA_2D,\n", " w_3d=W_3D,\n", " alpha_3d=ALPHA_3D,\n", " lambda_a=LAMBDA_A,\n", ")\n", "\n", "
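# A small helper sketch (hedged): list the (tracking, detection) pairs selected\n", "# by the Hungarian assignment for one camera; assumes indices_T / indices_D\n", "# index into aff.trackings / aff.detections as constructed above.\n", "def matched_pairs(aff: AffinityResult) -> list[tuple[Tracking, Detection]]:\n", " return [\n", " (aff.trackings[int(t)], aff.detections[int(d)])\n", " for t, d in zip(aff.indices_T, aff.indices_D)\n", " ]\n", "\n", "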
# 2D detections associated with each tracking\n", "track_detections = []\n", "for key in affinities.keys():\n", " # per-camera association result\n", " detection = affinities[key].detections\n", " tracking = affinities[key].trackings\n", " if tracking is None:\n", " display(detection)" ] }, { "cell_type": "code", "execution_count": 318, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'AF_03': AffinityResult(matrix=Array([[243.36333]], dtype=float32), trackings=[Tracking(1, 2024-04-02 12:01:08.791667)], detections=[Detection(, 2024-04-02 12:01:08.833333)], indices_T=Array([0], dtype=int32), indices_D=Array([0], dtype=int32, weak_type=True)),\n", " 'AF_05': AffinityResult(matrix=Array([[207.50095]], dtype=float32), trackings=[Tracking(1, 2024-04-02 12:01:08.791667)], detections=[Detection(, 2024-04-02 12:01:08.833333)], indices_T=Array([0], dtype=int32), indices_D=Array([0], dtype=int32, weak_type=True)),\n", " 'AE_1A': AffinityResult(matrix=Array([[221.3469]], dtype=float32), trackings=[Tracking(1, 2024-04-02 12:01:08.791667)], detections=[Detection(, 2024-04-02 12:01:08.833333)], indices_T=Array([0], dtype=int32), indices_D=Array([0], dtype=int32, weak_type=True)),\n", " 'AE_08': AffinityResult(matrix=Array([[235.55804]], dtype=float32), trackings=[Tracking(1, 2024-04-02 12:01:08.791667)], detections=[Detection(, 2024-04-02 12:01:08.833333)], indices_T=Array([0], dtype=int32), indices_D=Array([0], dtype=int32, weak_type=True))}" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "[Tracking(1, 2024-04-02 12:01:08.791667)]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "display(affinities)\n", "display(trackings)" ] }, { "cell_type": "code", "execution_count": 308, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "10" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# use the pre-filtered data\n", "# NOTE: overrides the earlier triangle_from_cluster and returns only the points\n", "def triangle_from_cluster(cluster: list[Detection]) -> Float[Array, \"P 3\"]:\n", " proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])\n", " points = jnp.array([el.keypoints for el in cluster])\n", " confidences = jnp.array([el.confidences for el in cluster])\n", " return triangulate_points_from_multiple_views_linear(\n", " proj_matrices, points, confidences=confidences\n", " )\n", "\n", "# use the pre-filtered data\n", "count = 0\n", "all_frame_detections = []\n", "# iterate over the filtered data (59 frames in total; only 10 consumed here)\n", "while count < 10:\n", " detections = next(sync_gen)\n", " all_frame_detections.append(triangle_from_cluster(detections).tolist())\n", " count += 1\n", "\n", "display(len(all_frame_detections))\n", "# with open(\"samples/QuanCheng_res_new.json\", \"wb\") as f:\n", "# f.write(orjson.dumps(all_frame_detections))\n", "\n", "
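# Hedged sanity-check sketch: reproject the first triangulated frame into one\n", "# camera and inspect the mean pixel error; assumes Camera.project maps\n", "# (J, 3) -> (J, 2) as used in calculate_camera_affinity_matrix_jax above.\n", "# kps3d = jnp.array(all_frame_detections[0])\n", "# err = jnp.linalg.norm(detections[0].camera.project(kps3d) - detections[0].keypoints, axis=-1)\n", "# display(err.mean())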
open(\"samples/QuanCheng_res.json\", \"wb\") as f:\\n f.write(orjson.dumps(all_frame_detections))\\n'" ] }, "execution_count": 309, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# 直接导入全部检测数据,在此处进行筛选\n", "'''\n", "all_frame_detections = []\n", "for i in range(800):\n", " detections = next(sync_gen)\n", " # 筛选关键点\n", "\n", " filter_detections = []\n", " for element in detections:\n", " filter_box_points_3d = calculater_box_3d_points()\n", " box_points_2d = calculater_box_2d_points(filter_box_points_3d, element.camera)\n", " box_triangles_all_points = calculater_box_common_scope(box_points_2d)\n", " union_area, union_polygon = calculate_triangle_union(box_triangles_all_points)\n", " contours = get_contours(union_polygon)\n", " \n", " # 筛选目标框里的数据\n", " if filter_kps_box(element.keypoints, contours):\n", " filter_detections.append(element)\n", " # 判断筛选后的数据是否为空\n", " if len(filter_detections)>2:\n", " all_frame_detections.append(triangle_from_cluster(filter_detections).tolist())\n", "\n", " # 筛选只剩一个人的数据,直接进行DLT\n", "# res = {\"a\": all_frame_detections}\n", "# display(res)\n", "with open(\"samples/QuanCheng_res.json\", \"wb\") as f:\n", " f.write(orjson.dumps(all_frame_detections))\n", "'''" ] }, { "cell_type": "code", "execution_count": 310, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'display(len(detections))'" ] }, "execution_count": 310, "metadata": {}, "output_type": "execute_result" } ], "source": [ "'''display(len(detections))'''" ] }, { "cell_type": "code", "execution_count": 311, "metadata": {}, "outputs": [], "source": [ "# 极限约束\n", "# from app.camera import calculate_affinity_matrix_by_epipolar_constraint\n", "\n", "# sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint(\n", "# detections, alpha_2d=3500\n", "# )" ] }, { "cell_type": "code", "execution_count": 312, "metadata": {}, "outputs": [], "source": [ "# display(\n", "# list(\n", "# map(\n", "# lambda x: {\n", "# \"timestamp\": str(x.timestamp),\n", "# \"camera\": x.camera.id,\n", "# \"keypoint\": x.keypoints.shape,\n", "# },\n", "# sorted_detections,\n", "# )\n", "# )\n", "# )\n", "# with jnp.printoptions(precision=3, suppress=True):\n", "# display(affinity_matrix)" ] }, { "cell_type": "code", "execution_count": 313, "metadata": {}, "outputs": [], "source": [ "# from app.solver._old import GLPKSolver\n", "\n", "\n", "# def clusters_to_detections(\n", "# clusters: list[list[int]], sorted_detections: list[Detection]\n", "# ) -> list[list[Detection]]:\n", "# \"\"\"\n", "# given a list of clusters (which is the indices of the detections in the sorted_detections list),\n", "# extract the detections from the sorted_detections list\n", "\n", "# Args:\n", "# clusters: list of clusters, each cluster is a list of indices of the detections in the `sorted_detections` list\n", "# sorted_detections: list of SORTED detections\n", "\n", "# Returns:\n", "# list of clusters, each cluster is a list of detections\n", "# \"\"\"\n", "# return [[sorted_detections[i] for i in cluster] for cluster in clusters]\n", "\n", "\n", "# solver = GLPKSolver()\n", "# aff_np = np.asarray(affinity_matrix).astype(np.float64)\n", "# clusters, sol_matrix = solver.solve(aff_np)\n", "# display(clusters)\n", "# display(sol_matrix)" ] }, { "cell_type": "code", "execution_count": 314, "metadata": {}, "outputs": [], "source": [ "# WIDTH = 2560\n", "# HEIGHT = 1440\n", "\n", "# clusters_detections = clusters_to_detections(clusters, sorted_detections)\n", "# im = np.zeros((HEIGHT, WIDTH, 3), 
dtype=np.uint8)\n", "# for el in clusters_detections[0]:\n", "# im = visualize_whole_body(np.asarray(el.keypoints), im)\n", "\n", "# p = plt.imshow(im)\n", "# display(p)" ] }, { "cell_type": "code", "execution_count": 315, "metadata": {}, "outputs": [], "source": [ "# im_prime = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)\n", "# for el in clusters_detections[1]:\n", "# im_prime = visualize_whole_body(np.asarray(el.keypoints), im_prime)\n", "\n", "# p_prime = plt.imshow(im_prime)\n", "# display(p_prime)" ] }, { "cell_type": "code", "execution_count": 316, "metadata": {}, "outputs": [], "source": [ "# def triangle_from_cluster(cluster: list[Detection]) -> Float[Array, \"3\"]:\n", "# proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])\n", "# points = jnp.array([el.keypoints for el in cluster])\n", "# confidences = jnp.array([el.confidences for el in cluster])\n", "# return triangulate_points_from_multiple_views_linear(\n", "# proj_matrices, points, confidences=confidences\n", "# )\n", "\n", "\n", "# res = {\n", "# \"a\": triangle_from_cluster(clusters_detections[0]).tolist(),\n", "# \"b\": triangle_from_cluster(clusters_detections[1]).tolist(),\n", "# }\n", "# with open(\"samples/QuanCheng_res.json\", \"wb\") as f:\n", "# f.write(orjson.dumps(res))" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 2 }