forked from HQU-gxy/CVTH3PE
753 lines
49 KiB
Plaintext
753 lines
49 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"id": "65c62d87",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from _collections_abc import dict_values\n",
|
||
"from math import isnan\n",
|
||
"from pathlib import Path\n",
|
||
"from re import L\n",
|
||
"import awkward as ak\n",
|
||
"from typing import (\n",
|
||
" Any,\n",
|
||
" Generator,\n",
|
||
" Iterable,\n",
|
||
" Optional,\n",
|
||
" Sequence,\n",
|
||
" TypeAlias,\n",
|
||
" TypedDict,\n",
|
||
" cast,\n",
|
||
" TypeVar,\n",
|
||
")\n",
|
||
"from datetime import datetime, timedelta\n",
|
||
"from jaxtyping import Array, Float, Num, jaxtyped\n",
|
||
"import numpy as np\n",
|
||
"from cv2 import undistortPoints\n",
|
||
"from app.camera import Camera, CameraParams, Detection\n",
|
||
"import jax.numpy as jnp\n",
|
||
"from beartype import beartype\n",
|
||
"from scipy.spatial.transform import Rotation as R\n",
|
||
"from app.camera import (\n",
|
||
" Camera,\n",
|
||
" CameraID,\n",
|
||
" CameraParams,\n",
|
||
" Detection,\n",
|
||
" calculate_affinity_matrix_by_epipolar_constraint,\n",
|
||
" classify_by_camera,\n",
|
||
")\n",
|
||
"from copy import copy as shallow_copy\n",
|
||
"import jax\n",
|
||
"from beartype.typing import Mapping, Sequence\n",
|
||
"from itertools import chain\n",
|
||
"import orjson\n",
|
||
"from app.visualize.whole_body import visualize_whole_body\n",
|
||
"from matplotlib import pyplot as plt"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"id": "74ec95dd",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"NDArray: TypeAlias = np.ndarray\n",
|
||
"DetectionGenerator: TypeAlias = Generator[Detection, None, None]\n",
|
||
"\n",
|
||
"DELTA_T_MIN = timedelta(milliseconds=1)\n",
|
||
"\"\"\"所有类型\"\"\"\n",
|
||
"\n",
|
||
"T = TypeVar(\"T\")\n",
|
||
"\n",
|
||
"\n",
|
||
"def unwrap(val: Optional[T]) -> T:\n",
|
||
" if val is None:\n",
|
||
" raise ValueError(\"None\")\n",
|
||
" return val\n",
|
||
"\n",
|
||
"\n",
|
||
"class KeypointDataset(TypedDict):\n",
|
||
" frame_index: int\n",
|
||
" boxes: Num[NDArray, \"N 4\"]\n",
|
||
" kps: Num[NDArray, \"N J 2\"]\n",
|
||
" kps_scores: Num[NDArray, \"N J\"]\n",
|
||
"\n",
|
||
"\n",
|
||
"class Resolution(TypedDict):\n",
|
||
" width: int\n",
|
||
" height: int\n",
|
||
"\n",
|
||
"\n",
|
||
"class Intrinsic(TypedDict):\n",
|
||
" camera_matrix: Num[Array, \"3 3\"]\n",
|
||
" \"\"\"\n",
|
||
" K\n",
|
||
" \"\"\"\n",
|
||
" distortion_coefficients: Num[Array, \"N\"]\n",
|
||
" \"\"\"\n",
|
||
" distortion coefficients; usually 5\n",
|
||
" \"\"\"\n",
|
||
"\n",
|
||
"\n",
|
||
"class Extrinsic(TypedDict):\n",
|
||
" rvec: Num[NDArray, \"3\"]\n",
|
||
" tvec: Num[NDArray, \"3\"]\n",
|
||
"\n",
|
||
"\n",
|
||
"class ExternalCameraParams(TypedDict):\n",
|
||
" name: str\n",
|
||
" port: int\n",
|
||
" intrinsic: Intrinsic\n",
|
||
" extrinsic: Extrinsic\n",
|
||
" resolution: Resolution\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"id": "2e192496",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"\"\"\"获得所有机位的相机内外参\"\"\"\n",
|
||
"def get_camera_params(camera_path: Path) -> ak.Array:\n",
|
||
" camera_dataset: ak.Array = ak.from_parquet(camera_path / \"camera_params.parquet\")\n",
|
||
" return camera_dataset\n",
|
||
"\n",
|
||
"\"\"\"获取所有机位的2d检测数据\"\"\"\n",
|
||
"def get_camera_detect(\n",
|
||
" detect_path: Path, camera_port: list[int], camera_dataset: ak.Array\n",
|
||
") -> dict[int, ak.Array]:\n",
|
||
" keypoint_data: dict[int, ak.Array] = {}\n",
|
||
" for element_port in ak.to_numpy(camera_dataset[\"port\"]):\n",
|
||
" if element_port in camera_port:\n",
|
||
" keypoint_data[int(element_port)] = ak.from_parquet(\n",
|
||
" detect_path / f\"{element_port}.parquet\"\n",
|
||
" )\n",
|
||
" return keypoint_data"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"id": "6eee3591",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# 相机内外参路径\n",
|
||
"CAMERA_PATH = Path(\"/home/admin/Documents/2025_4_17/camera_params\")\n",
|
||
"# 所有机位的相机内外参\n",
|
||
"AK_CAMERA_DATASET: ak.Array = get_camera_params(CAMERA_PATH)\n",
|
||
"# 2d检测数据路径\n",
|
||
"DATASET_PATH = Path(\"/home/admin/Documents/2025_4_17/detect_result/many_people_01/\")\n",
|
||
"# 指定机位的2d检测数据\n",
|
||
"camera_port = [5607, 5608, 5609]\n",
|
||
"KEYPOINT_DATASET = get_camera_detect(DATASET_PATH, camera_port, AK_CAMERA_DATASET)\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"id": "ce225126",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"/tmp/ipykernel_2333927/1636344639.py:1: DeprecationWarning: __array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments. To learn more, see the migration guide https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword\n",
|
||
" kps_5607 =np.array(KEYPOINT_DATASET[5607]['kps'])\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(710, 2, 133, 2)"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"/tmp/ipykernel_2333927/1636344639.py:3: DeprecationWarning: __array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments. To learn more, see the migration guide https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword\n",
|
||
" kps_5607_socers = np.array(KEYPOINT_DATASET[5607]['kps_scores'])\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(710, 2, 133)"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"kps_5607 =np.array(KEYPOINT_DATASET[5607]['kps'])\n",
|
||
"display(kps_5607.shape)\n",
|
||
"kps_5607_socers = np.array(KEYPOINT_DATASET[5607]['kps_scores'])\n",
|
||
"display(kps_5607_socers.shape)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"id": "abf50aa8",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"\n",
|
||
"\"\"\"将所有2d检测数据打包\"\"\"\n",
|
||
"\n",
|
||
"\n",
|
||
"@jaxtyped(typechecker=beartype)\n",
|
||
"def undistort_points(\n",
|
||
" points: Num[NDArray, \"M 2\"],\n",
|
||
" camera_matrix: Num[NDArray, \"3 3\"],\n",
|
||
" dist_coeffs: Num[NDArray, \"N\"],\n",
|
||
") -> Num[NDArray, \"M 2\"]:\n",
|
||
" K = camera_matrix\n",
|
||
" dist = dist_coeffs\n",
|
||
" res = undistortPoints(points, K, dist, P=K) # type: ignore\n",
|
||
" return res.reshape(-1, 2)\n",
|
||
"\n",
|
||
"\n",
|
||
"@jaxtyped(typechecker=beartype)\n",
|
||
"def to_transformation_matrix(\n",
|
||
" rvec: Num[NDArray, \"3\"], tvec: Num[NDArray, \"3\"]\n",
|
||
") -> Num[NDArray, \"4 4\"]:\n",
|
||
" res = np.eye(4)\n",
|
||
" res[:3, :3] = R.from_rotvec(rvec).as_matrix()\n",
|
||
" res[:3, 3] = tvec\n",
|
||
" return res\n",
|
||
"\n",
|
||
"\n",
|
||
"def from_camera_params(camera: ExternalCameraParams) -> Camera:\n",
|
||
" rt = jnp.array(\n",
|
||
" to_transformation_matrix(\n",
|
||
" ak.to_numpy(camera[\"extrinsic\"][\"rvec\"]),\n",
|
||
" ak.to_numpy(camera[\"extrinsic\"][\"tvec\"]),\n",
|
||
" )\n",
|
||
" )\n",
|
||
" K = jnp.array(camera[\"intrinsic\"][\"camera_matrix\"]).reshape(3, 3)\n",
|
||
" dist_coeffs = jnp.array(camera[\"intrinsic\"][\"distortion_coefficients\"])\n",
|
||
" image_size = jnp.array(\n",
|
||
" (camera[\"resolution\"][\"width\"], camera[\"resolution\"][\"height\"])\n",
|
||
" )\n",
|
||
" return Camera(\n",
|
||
" id=camera[\"name\"],\n",
|
||
" params=CameraParams(\n",
|
||
" K=K,\n",
|
||
" Rt=rt,\n",
|
||
" dist_coeffs=dist_coeffs,\n",
|
||
" image_size=image_size,\n",
|
||
" ),\n",
|
||
" )\n",
|
||
"\n",
|
||
"\n",
|
||
"def preprocess_keypoint_dataset(\n",
|
||
" dataset: Sequence[KeypointDataset],\n",
|
||
" camera: Camera,\n",
|
||
" fps: float,\n",
|
||
" start_timestamp: datetime,\n",
|
||
") -> Generator[Detection, None, None]:\n",
|
||
" frame_interval_s = 1 / fps\n",
|
||
" for el in dataset:\n",
|
||
" frame_index = el[\"frame_index\"]\n",
|
||
" timestamp = start_timestamp + timedelta(seconds=frame_index * frame_interval_s)\n",
|
||
" for kp, kp_score, boxes in zip(el[\"kps\"], el[\"kps_scores\"], el[\"boxes\"]):\n",
|
||
" kp = undistort_points(\n",
|
||
" np.asarray(kp),\n",
|
||
" np.asarray(camera.params.K),\n",
|
||
" np.asarray(camera.params.dist_coeffs),\n",
|
||
" )\n",
|
||
"\n",
|
||
" yield Detection(\n",
|
||
" keypoints=jnp.array(kp),\n",
|
||
" confidences=jnp.array(kp_score),\n",
|
||
" camera=camera,\n",
|
||
" timestamp=timestamp,\n",
|
||
" )\n",
|
||
"\n",
|
||
"\n",
|
||
"def sync_batch_gen(\n",
|
||
" gens: list[DetectionGenerator], diff: timedelta\n",
|
||
") -> Generator[list[Detection], Any, None]:\n",
|
||
" from more_itertools import partition\n",
|
||
"\n",
|
||
" \"\"\"\n",
|
||
" given a list of detection generators, return a generator that yields a batch of detections\n",
|
||
"\n",
|
||
" Args:\n",
|
||
" gens: list of detection generators\n",
|
||
" diff: maximum timestamp difference between detections to consider them part of the same batch\n",
|
||
" \"\"\"\n",
|
||
" N = len(gens)\n",
|
||
" last_batch_timestamp: Optional[datetime] = None\n",
|
||
" current_batch: list[Detection] = []\n",
|
||
" paused: list[bool] = [False] * N\n",
|
||
" finished: list[bool] = [False] * N\n",
|
||
" unmached_detections: list[Detection] = []\n",
|
||
"\n",
|
||
" def reset_paused():\n",
|
||
" \"\"\"\n",
|
||
" reset paused list based on finished list\n",
|
||
" \"\"\"\n",
|
||
" for i in range(N):\n",
|
||
" if not finished[i]:\n",
|
||
" paused[i] = False\n",
|
||
" else:\n",
|
||
" paused[i] = True\n",
|
||
"\n",
|
||
" EPS = 1e-6\n",
|
||
" # a small epsilon to avoid floating point precision issues\n",
|
||
" diff_esp = diff - timedelta(seconds=EPS)\n",
|
||
" while True:\n",
|
||
" for i, gen in enumerate(gens):\n",
|
||
" try:\n",
|
||
" if finished[i] or paused[i]:\n",
|
||
" if all(finished):\n",
|
||
" if len(current_batch) > 0:\n",
|
||
" # All generators exhausted, flush remaining batch and exit\n",
|
||
" yield current_batch\n",
|
||
" return\n",
|
||
" else:\n",
|
||
" continue\n",
|
||
" val = next(gen)\n",
|
||
" if last_batch_timestamp is None:\n",
|
||
" last_batch_timestamp = val.timestamp\n",
|
||
" current_batch.append(val)\n",
|
||
" else:\n",
|
||
" if abs(val.timestamp - last_batch_timestamp) >= diff_esp:\n",
|
||
" unmached_detections.append(val)\n",
|
||
" paused[i] = True\n",
|
||
" if all(paused):\n",
|
||
" yield current_batch\n",
|
||
" reset_paused()\n",
|
||
" last_batch_timestamp = last_batch_timestamp + diff\n",
|
||
" bad, good = partition(\n",
|
||
" lambda x: x.timestamp < unwrap(last_batch_timestamp),\n",
|
||
" unmached_detections,\n",
|
||
" )\n",
|
||
" current_batch = list(good)\n",
|
||
" unmached_detections = list(bad)\n",
|
||
" else:\n",
|
||
" current_batch.append(val)\n",
|
||
" except StopIteration:\n",
|
||
" return\n",
|
||
"\n",
|
||
"\n",
|
||
"def get_batch_detect(\n",
|
||
" keypoint_dataset,\n",
|
||
" camera_dataset,\n",
|
||
" camera_port: list[int],\n",
|
||
" FPS: int = 24,\n",
|
||
" batch_fps: int = 10,\n",
|
||
") -> Generator[list[Detection], Any, None]:\n",
|
||
" gen_data = [\n",
|
||
" preprocess_keypoint_dataset(\n",
|
||
" keypoint_dataset[port],\n",
|
||
" from_camera_params(camera_dataset[camera_dataset[\"port\"] == port][0]),\n",
|
||
" FPS,\n",
|
||
" datetime(2024, 4, 2, 12, 0, 0),\n",
|
||
" )\n",
|
||
" for port in camera_port\n",
|
||
" ]\n",
|
||
"\n",
|
||
" sync_gen: Generator[list[Detection], Any, None] = sync_batch_gen(\n",
|
||
" gen_data,\n",
|
||
" timedelta(seconds=1 / batch_fps),\n",
|
||
" )\n",
|
||
" return sync_gen\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"id": "82ac4c3e",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# 将所有的2d检测数据打包\n",
|
||
"sync_gen: Generator[list[Detection], Any, None] = get_batch_detect(\n",
|
||
" KEYPOINT_DATASET,\n",
|
||
" AK_CAMERA_DATASET,\n",
|
||
" camera_port,\n",
|
||
" batch_fps=24,\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 86,
|
||
"id": "33559b73",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"[Detection(<Camera id=AE_01>, 2024-04-02 12:00:00.333333),\n",
|
||
" Detection(<Camera id=AE_1A>, 2024-04-02 12:00:00.333333),\n",
|
||
" Detection(<Camera id=AE_08>, 2024-04-02 12:00:00.333333),\n",
|
||
" Detection(<Camera id=AE_01>, 2024-04-02 12:00:00.375000),\n",
|
||
" Detection(<Camera id=AE_1A>, 2024-04-02 12:00:00.375000),\n",
|
||
" Detection(<Camera id=AE_08>, 2024-04-02 12:00:00.375000)]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"6"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"detections = next(sync_gen)\n",
|
||
"display(detections)\n",
|
||
"display(len(detections))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 87,
|
||
"id": "87d44153",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# 极限约束\n",
|
||
"from app.camera import calculate_affinity_matrix_by_epipolar_constraint\n",
|
||
"\n",
|
||
"sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint(\n",
|
||
" detections, alpha_2d=3500\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 88,
|
||
"id": "821ca702",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"[{'timestamp': '2024-04-02 12:00:00.333333',\n",
|
||
" 'camera': 'AE_01',\n",
|
||
" 'keypoint': (133, 2)},\n",
|
||
" {'timestamp': '2024-04-02 12:00:00.375000',\n",
|
||
" 'camera': 'AE_01',\n",
|
||
" 'keypoint': (133, 2)},\n",
|
||
" {'timestamp': '2024-04-02 12:00:00.333333',\n",
|
||
" 'camera': 'AE_1A',\n",
|
||
" 'keypoint': (133, 2)},\n",
|
||
" {'timestamp': '2024-04-02 12:00:00.375000',\n",
|
||
" 'camera': 'AE_1A',\n",
|
||
" 'keypoint': (133, 2)},\n",
|
||
" {'timestamp': '2024-04-02 12:00:00.333333',\n",
|
||
" 'camera': 'AE_08',\n",
|
||
" 'keypoint': (133, 2)},\n",
|
||
" {'timestamp': '2024-04-02 12:00:00.375000',\n",
|
||
" 'camera': 'AE_08',\n",
|
||
" 'keypoint': (133, 2)}]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"Array([[ -inf, -inf, 0.729, 0.63 , 0.549, 0.453],\n",
|
||
" [ -inf, -inf, 0.786, 0.702, 0.651, 0.559],\n",
|
||
" [0.729, 0.786, -inf, -inf, 0.846, 0.787],\n",
|
||
" [0.63 , 0.702, -inf, -inf, 0.907, 0.847],\n",
|
||
" [0.549, 0.651, 0.846, 0.907, -inf, -inf],\n",
|
||
" [0.453, 0.559, 0.787, 0.847, -inf, -inf]], dtype=float32)"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"display(\n",
|
||
" list(\n",
|
||
" map(\n",
|
||
" lambda x: {\n",
|
||
" \"timestamp\": str(x.timestamp),\n",
|
||
" \"camera\": x.camera.id,\n",
|
||
" \"keypoint\": x.keypoints.shape,\n",
|
||
" },\n",
|
||
" sorted_detections,\n",
|
||
" )\n",
|
||
" )\n",
|
||
")\n",
|
||
"with jnp.printoptions(precision=3, suppress=True):\n",
|
||
" display(affinity_matrix)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 89,
|
||
"id": "10499f36",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"[[0, 2, 4], [1, 3, 5]]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"array([[0, 0, 1, 0, 1, 0],\n",
|
||
" [0, 0, 0, 1, 0, 1],\n",
|
||
" [1, 0, 0, 0, 1, 0],\n",
|
||
" [0, 1, 0, 0, 0, 1],\n",
|
||
" [1, 0, 1, 0, 0, 0],\n",
|
||
" [0, 1, 0, 1, 0, 0]])"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"from app.solver._old import GLPKSolver\n",
|
||
"\n",
|
||
"def clusters_to_detections(\n",
|
||
" clusters: list[list[int]], sorted_detections: list[Detection]\n",
|
||
") -> list[list[Detection]]:\n",
|
||
" \"\"\"\n",
|
||
" given a list of clusters (which is the indices of the detections in the sorted_detections list),\n",
|
||
" extract the detections from the sorted_detections list\n",
|
||
"\n",
|
||
" Args:\n",
|
||
" clusters: list of clusters, each cluster is a list of indices of the detections in the `sorted_detections` list\n",
|
||
" sorted_detections: list of SORTED detections\n",
|
||
"\n",
|
||
" Returns:\n",
|
||
" list of clusters, each cluster is a list of detections\n",
|
||
" \"\"\"\n",
|
||
" return [[sorted_detections[i] for i in cluster] for cluster in clusters]\n",
|
||
"\n",
|
||
"\n",
|
||
"solver = GLPKSolver()\n",
|
||
"aff_np = np.asarray(affinity_matrix).astype(np.float64)\n",
|
||
"clusters, sol_matrix = solver.solve(aff_np)\n",
|
||
"display(clusters)\n",
|
||
"display(sol_matrix)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 90,
|
||
"id": "b7a05c6b",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"<matplotlib.image.AxesImage at 0x7fab74369250>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"WIDTH = 2560\n",
|
||
"HEIGHT = 1440\n",
|
||
"\n",
|
||
"clusters_detections = clusters_to_detections(clusters, sorted_detections)\n",
|
||
"im = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)\n",
|
||
"for el in clusters_detections[0]:\n",
|
||
" im = visualize_whole_body(np.asarray(el.keypoints), im)\n",
|
||
"\n",
|
||
"p = plt.imshow(im)\n",
|
||
"display(p)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 91,
|
||
"id": "eac843c2",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# im_prime = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)\n",
|
||
"# for el in clusters_detections[1]:\n",
|
||
"# im_prime = visualize_whole_body(np.asarray(el.keypoints), im_prime)\n",
|
||
"\n",
|
||
"# p_prime = plt.imshow(im_prime)\n",
|
||
"# display(p_prime)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 92,
|
||
"id": "037dcc22",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"@jaxtyped(typechecker=beartype)\n",
|
||
"def triangulate_one_point_from_multiple_views_linear(\n",
|
||
" proj_matrices: Float[Array, \"N 3 4\"],\n",
|
||
" points: Num[Array, \"N 2\"],\n",
|
||
" confidences: Optional[Float[Array, \"N\"]] = None,\n",
|
||
" conf_threshold: float = 0.4, # 0.2\n",
|
||
") -> Float[Array, \"3\"]:\n",
|
||
" \"\"\"\n",
|
||
" Args:\n",
|
||
" proj_matrices: 形状为(N, 3, 4)的投影矩阵序列\n",
|
||
" points: 形状为(N, 2)的点坐标序列\n",
|
||
" confidences: 形状为(N,)的置信度序列,范围[0.0, 1.0]\n",
|
||
" conf_threshold: 置信度阈值,低于该值的观测不参与DLT\n",
|
||
" Returns:\n",
|
||
" point_3d: 形状为(3,)的三角测量得到的3D点\n",
|
||
" \"\"\"\n",
|
||
" assert len(proj_matrices) == len(points)\n",
|
||
" N = len(proj_matrices)\n",
|
||
" # 置信度加权DLT\n",
|
||
" if confidences is None:\n",
|
||
" weights = jnp.ones(N, dtype=jnp.float32)\n",
|
||
" else:\n",
|
||
" valid_mask = confidences >= conf_threshold\n",
|
||
" weights = jnp.where(valid_mask, confidences, 0.0)\n",
|
||
" sum_weights = jnp.sum(weights)\n",
|
||
" weights = jnp.where(sum_weights > 0, weights / sum_weights, weights)\n",
|
||
"\n",
|
||
" A = jnp.zeros((N * 2, 4), dtype=jnp.float32)\n",
|
||
" for i in range(N):\n",
|
||
" x, y = points[i]\n",
|
||
" row1 = proj_matrices[i, 2] * x - proj_matrices[i, 0]\n",
|
||
" row2 = proj_matrices[i, 2] * y - proj_matrices[i, 1]\n",
|
||
" A = A.at[2 * i].set(row1 * weights[i])\n",
|
||
" A = A.at[2 * i + 1].set(row2 * weights[i])\n",
|
||
"\n",
|
||
" _, _, vh = jnp.linalg.svd(A, full_matrices=False)\n",
|
||
" point_3d_homo = vh[-1]\n",
|
||
" point_3d_homo = jnp.where(point_3d_homo[3] < 0, -point_3d_homo, point_3d_homo)\n",
|
||
" is_zero_weight = jnp.sum(weights) == 0\n",
|
||
" point_3d = jnp.where(\n",
|
||
" is_zero_weight,\n",
|
||
" jnp.full((3,), jnp.nan, dtype=jnp.float32),\n",
|
||
" jnp.where(\n",
|
||
" jnp.abs(point_3d_homo[3]) > 1e-8,\n",
|
||
" point_3d_homo[:3] / point_3d_homo[3],\n",
|
||
" jnp.full((3,), jnp.nan, dtype=jnp.float32),\n",
|
||
" ),\n",
|
||
" )\n",
|
||
" return point_3d\n",
|
||
"\n",
|
||
"\n",
|
||
"jaxtyped(typechecker=beartype)\n",
|
||
"def triangulate_points_from_multiple_views_linear(\n",
|
||
" proj_matrices: Float[Array, \"N 3 4\"],\n",
|
||
" points: Num[Array, \"N P 2\"],\n",
|
||
" confidences: Optional[Float[Array, \"N P\"]] = None,\n",
|
||
") -> Float[Array, \"P 3\"]:\n",
|
||
" \"\"\"\n",
|
||
" Batch‐triangulate P points observed by N cameras, linearly via SVD.\n",
|
||
"\n",
|
||
" Args:\n",
|
||
" proj_matrices: (N, 3, 4) projection matrices\n",
|
||
" points: (N, P, 2) image-coordinates per view\n",
|
||
" confidences: (N, P, 1) optional per-view confidences in [0,1]\n",
|
||
"\n",
|
||
" Returns:\n",
|
||
" (P, 3) 3D point for each of the P tracks\n",
|
||
" \"\"\"\n",
|
||
" N, P, _ = points.shape\n",
|
||
" assert proj_matrices.shape[0] == N\n",
|
||
"\n",
|
||
" if confidences is None:\n",
|
||
" conf = jnp.ones((N, P), dtype=jnp.float32)\n",
|
||
" else:\n",
|
||
" conf = jnp.array(confidences)\n",
|
||
"\n",
|
||
" vmap_triangulate = jax.vmap(\n",
|
||
" triangulate_one_point_from_multiple_views_linear,\n",
|
||
" in_axes=(None, 1, 1),\n",
|
||
" out_axes=0,\n",
|
||
" )\n",
|
||
" return vmap_triangulate(proj_matrices, points, conf)\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 93,
|
||
"id": "8fc0074d",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def triangle_from_cluster(cluster: list[Detection]) -> Float[Array, \"3\"]:\n",
|
||
" proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])\n",
|
||
" points = jnp.array([el.keypoints for el in cluster])\n",
|
||
" confidences = jnp.array([el.confidences for el in cluster])\n",
|
||
" return triangulate_points_from_multiple_views_linear(\n",
|
||
" proj_matrices, points, confidences=confidences\n",
|
||
" )\n",
|
||
"\n",
|
||
"\n",
|
||
"res = {\n",
|
||
" \"a\": [triangle_from_cluster(clusters_detections[0]).tolist()],\n",
|
||
" \"b\": [triangle_from_cluster(clusters_detections[1]).tolist()],\n",
|
||
"}\n",
|
||
"with open(\"samples/Test_YEU.json\", \"wb\") as f:\n",
|
||
" f.write(orjson.dumps(res))"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "cvth3pe",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.12.9"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|