# NOTE: the following lines are residue from the web page this file was copied
# from; they are kept as comments so that the file remains valid Python.
# 1
# 0
# forked from HQU-gxy/CVTH3PE
# Files
# CVTH3PE/playground.py
# 2025-07-11 15:52:46 +08:00
#
# 1286 lines
# 49 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.17.0
# kernelspec:
# display_name: .venv
# language: python
# name: python3
# ---
# %%
from collections import OrderedDict
from copy import copy as shallow_copy
from copy import deepcopy as deep_copy
from dataclasses import dataclass
from datetime import datetime, timedelta
from functools import partial, reduce
from itertools import chain
from pathlib import Path
from typing import (
Any,
Generator,
Optional,
TypeAlias,
TypedDict,
TypeVar,
cast,
overload,
Iterable,
)
import awkward as ak
import jax
import jax.numpy as jnp
import numpy as np
from beartype import beartype
from beartype.typing import Mapping, Sequence
from cv2 import undistortPoints
from IPython.display import display
from jaxtyping import Array, Float, Num, jaxtyped
from matplotlib import pyplot as plt
from numpy.typing import ArrayLike
from optax.assignment import hungarian_algorithm as linear_sum_assignment
from pyrsistent import pvector, v, m, pmap, PMap, freeze, thaw
from scipy.spatial.transform import Rotation as R
from typing_extensions import deprecated
from collections import defaultdict
from app.camera import (
Camera,
CameraID,
CameraParams,
Detection,
calculate_affinity_matrix_by_epipolar_constraint,
classify_by_camera,
)
from app.solver._old import GLPKSolver
from app.tracking import (
TrackingID,
AffinityResult,
LastDifferenceVelocityFilter,
Tracking,
TrackingState,
)
from app.visualize.whole_body import visualize_whole_body
# Alias used in jaxtyping annotations for plain NumPy arrays (as opposed to
# `jax.Array`, imported as `Array`).
NDArray: TypeAlias = np.ndarray
# %%
# Root directory of the sample dataset used by this notebook.
DATASET_PATH = Path("samples") / "04_02"
# Camera parameters (one row per camera), read from a parquet file.
AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet")
# Lower bound for time differences; Δt is clamped to this value before it is
# used as a divisor in the affinity computations below.
DELTA_T_MIN = timedelta(milliseconds=1)
# Show the loaded camera parameters in the notebook.
display(AK_CAMERA_DATASET)
# %%
class Resolution(TypedDict):
    """Image resolution in pixels."""

    width: int
    height: int


class Intrinsic(TypedDict):
    """Camera intrinsic parameters."""

    # 3x3 intrinsic matrix (reshaped to (3, 3) in `from_camera_params`)
    camera_matrix: Num[Array, "3 3"]
    """
    K
    """
    distortion_coefficients: Num[Array, "N"]
    """
    distortion coefficients; usually 5
    """


class Extrinsic(TypedDict):
    """Camera extrinsic parameters."""

    # rotation vector (axis-angle); turned into a rotation matrix with
    # `Rotation.from_rotvec` in `to_transformation_matrix`
    rvec: Num[NDArray, "3"]
    # translation vector
    tvec: Num[NDArray, "3"]


class ExternalCameraParams(TypedDict):
    """Parameters of one camera as stored in the external camera dataset."""

    name: str  # used as the `Camera.id` by `from_camera_params`
    port: int  # the camera's keypoint file is `<port>.parquet` (see `read_dataset_by_port`)
    intrinsic: Intrinsic
    extrinsic: Extrinsic
    resolution: Resolution
# %%
def read_dataset_by_port(port: int) -> ak.Array:
    """Load the keypoint dataset recorded by the camera listening on ``port``.

    The data is read from ``DATASET_PATH / "<port>.parquet"``.
    """
    return ak.from_parquet(DATASET_PATH / f"{port}.parquet")
# Keypoint datasets keyed by camera port: one parquet file is loaded for every
# port listed in AK_CAMERA_DATASET.
KEYPOINT_DATASET = {
    int(p): read_dataset_by_port(p) for p in ak.to_numpy(AK_CAMERA_DATASET["port"])
}
# %%
class KeypointDataset(TypedDict):
    """One frame of 2D keypoint detections from a single camera."""

    frame_index: int  # frame index (converted to a timestamp via the FPS)
    boxes: Num[NDArray, "N 4"]  # N bounding boxes, 4 coordinates each
    kps: Num[NDArray, "N J 2"]  # N detected objects x J keypoints x 2D coordinates
    kps_scores: Num[NDArray, "N J"]  # one confidence score per keypoint
@jaxtyped(typechecker=beartype)
def to_transformation_matrix(
    rvec: Num[NDArray, "3"], tvec: Num[NDArray, "3"]
) -> Num[NDArray, "4 4"]:
    """Build a 4x4 homogeneous rigid transform from a rotation and a translation.

    Args:
        rvec: rotation vector (axis-angle, radians)
        tvec: translation vector

    Returns:
        ``[[R, t], [0, 0, 0, 1]]`` where ``R`` is the rotation matrix of ``rvec``.
    """
    rotation = R.from_rotvec(rvec).as_matrix()
    transform = np.eye(4)
    transform[:3, :3] = rotation
    transform[:3, 3] = tvec
    return transform
@jaxtyped(typechecker=beartype)
def undistort_points(
    points: Num[NDArray, "M 2"],
    camera_matrix: Num[NDArray, "3 3"],
    dist_coeffs: Num[NDArray, "N"],
) -> Num[NDArray, "M 2"]:
    """Remove lens distortion from 2D image points using OpenCV.

    The camera matrix is also passed as the new projection ``P``, so the
    undistorted points stay in pixel coordinates instead of normalized
    camera coordinates.

    Args:
        points: (M, 2) image points (x, y)
        camera_matrix: 3x3 intrinsic matrix K
        dist_coeffs: distortion coefficients

    Returns:
        (M, 2) undistorted image points
    """
    undistorted = undistortPoints(points, camera_matrix, dist_coeffs, P=camera_matrix)  # type: ignore
    # OpenCV returns shape (M, 1, 2); flatten it back to (M, 2).
    return undistorted.reshape(-1, 2)
def from_camera_params(camera: ExternalCameraParams) -> Camera:
    """Convert one row of the external camera dataset into a `Camera`.

    The extrinsic rotation/translation vectors are combined into a 4x4
    homogeneous transform; the intrinsic matrix is reshaped to 3x3 and the
    resolution is stored as ``(width, height)``. The camera name becomes
    its id.
    """
    extrinsic = camera["extrinsic"]
    intrinsic = camera["intrinsic"]
    resolution = camera["resolution"]

    transform = jnp.array(
        to_transformation_matrix(
            ak.to_numpy(extrinsic["rvec"]),
            ak.to_numpy(extrinsic["tvec"]),
        )
    )
    params = CameraParams(
        K=jnp.array(intrinsic["camera_matrix"]).reshape(3, 3),
        Rt=transform,
        dist_coeffs=jnp.array(intrinsic["distortion_coefficients"]),
        image_size=jnp.array((resolution["width"], resolution["height"])),
    )
    return Camera(id=camera["name"], params=params)
def preprocess_keypoint_dataset(
    dataset: Sequence[KeypointDataset],
    camera: Camera,
    fps: float,
    start_timestamp: datetime,
) -> Generator[Detection, None, None]:
    """Lazily turn per-frame keypoint records into `Detection` objects.

    A frame's timestamp is ``start_timestamp + frame_index / fps`` seconds.
    One `Detection` is yielded per detected object (row of ``kps``) in each
    frame. Bounding boxes are not used.

    Args:
        dataset: keypoint records, one per frame
        camera: camera that recorded the dataset
        fps: frame rate (frames per second)
        start_timestamp: timestamp of frame index 0
    """
    seconds_per_frame = 1 / fps
    for frame in dataset:
        frame_time = start_timestamp + timedelta(
            seconds=frame["frame_index"] * seconds_per_frame
        )
        for keypoints, scores in zip(frame["kps"], frame["kps_scores"]):
            yield Detection(
                keypoints=jnp.array(keypoints),
                confidences=jnp.array(scores),
                camera=camera,
                timestamp=frame_time,
            )
# %%
DetectionGenerator: TypeAlias = Generator[Detection, None, None] #别名定义
#将多个异步的检测流按时间戳同步,生成时间上 “对齐” 的批次
def sync_batch_gen(gens: Sequence[DetectionGenerator], diff: timedelta): #gens: 检测生成器列表diff: 允许的时间戳最大差异,用于判断两个检测是否属于同一批次
"""
given a list of detection generators, return a generator that yields a batch of detections
Args:
gens: list of detection generators
diff: maximum timestamp difference between detections to consider them part of the same batch
"""
N = len(gens) # 生成器数量
last_batch_timestamp: Optional[datetime] = None # 当前批次的时间戳
next_batch_timestamp: Optional[datetime] = None # 下一批次的时间戳
current_batch: list[Detection] = [] # 当前批次的检测结果
next_batch: list[Detection] = [] # 下一批次的检测结果
paused: list[bool] = [False] * N # 标记每个生成器是否暂停
finished: list[bool] = [False] * N # 标记每个生成器是否已耗尽
def reset_paused():
"""
reset paused list based on finished list
"""
for i in range(N):
if not finished[i]:
paused[i] = False
else:
paused[i] = True
EPS = 1e-6
# a small epsilon to avoid floating point precision issues
diff_esp = diff - timedelta(seconds=EPS) #用于处理浮点数精度问题,避免因微小时间差导致误判。
while True:
for i, gen in enumerate(gens):
try:
if finished[i] or paused[i]:
continue
val = next(gen) # 获取下一个检测结果
if last_batch_timestamp is None: # ... 时间戳比较与批次分配 ...
last_batch_timestamp = val.timestamp
current_batch.append(val) # 初始化第一批
else:
if abs(val.timestamp - last_batch_timestamp) >= diff_esp:
next_batch.append(val) # 时间差超过阈值,放入下一批
if next_batch_timestamp is None:
next_batch_timestamp = val.timestamp
paused[i] = True # 暂停该生成器,等待批次切换
if all(paused):
yield current_batch # 所有生成器都暂停时,输出当前批次
current_batch = next_batch
next_batch = []
last_batch_timestamp = next_batch_timestamp
next_batch_timestamp = None
reset_paused() # 重置暂停状态
else:
current_batch.append(val) # 时间差在阈值内,加入当前批次
except StopIteration:
finished[i] = True
paused[i] = True
if all(finished):
if len(current_batch) > 0:
# All generators exhausted, flush remaining batch and exit
yield current_batch # 输出最后一批
break
# %%
@overload
def to_projection_matrix(
    transformation_matrix: Num[NDArray, "4 4"], camera_matrix: Num[NDArray, "3 3"]
) -> Num[NDArray, "3 4"]: ...
@overload
def to_projection_matrix(
    transformation_matrix: Num[Array, "4 4"], camera_matrix: Num[Array, "3 3"]
) -> Num[Array, "3 4"]: ...
@jaxtyped(typechecker=beartype)
def to_projection_matrix(
    transformation_matrix: Num[Any, "4 4"],
    camera_matrix: Num[Any, "3 3"],
) -> Num[Any, "3 4"]:
    """Compose the 3x4 projection matrix ``K @ [R | t]``.

    Accepts NumPy or JAX arrays. Only the top three rows of the 4x4
    homogeneous transform are used.
    """
    rt_3x4 = transformation_matrix[:3]
    return camera_matrix @ rt_3x4


# JIT-compiled variant; the plain `to_projection_matrix` above is not jitted.
to_projection_matrix_jit = jax.jit(to_projection_matrix)
@jaxtyped(typechecker=beartype)
def dlt(
    H1: Num[NDArray, "3 4"],
    H2: Num[NDArray, "3 4"],
    p1: Num[NDArray, "2"],
    p2: Num[NDArray, "2"],
) -> Num[NDArray, "3"]:
    """
    Direct Linear Transformation

    Triangulate one 3D point from its projections in two views. Each view
    contributes the rows ``v * H[2] - H[1]`` and ``H[0] - u * H[2]`` to the
    homogeneous system ``A X = 0``. The solution is the right singular vector
    of ``A`` for its smallest singular value, de-homogenized.

    Args:
        H1: 3x4 projection matrix of the first camera
        H2: 3x4 projection matrix of the second camera
        p1: (u1, v1) projection of the point in the first camera
        p2: (u2, v2) projection of the point in the second camera

    Returns:
        (X, Y, Z) coordinates of the triangulated point
    """
    A = np.stack(
        [
            p1[1] * H1[2, :] - H1[1, :],
            H1[0, :] - p1[0] * H1[2, :],
            p2[1] * H2[2, :] - H2[1, :],
            H2[0, :] - p2[0] * H2[2, :],
        ]
    )  # (4, 4)
    # Decompose A itself rather than A^T A. Forming A^T A squares the
    # condition number and loses precision, while the null-space direction
    # (last row of Vh) is the same.
    _, _, Vh = np.linalg.svd(A)
    X = Vh[-1]
    return X[:3] / X[3]
@overload
def homogeneous_to_euclidean(points: Num[NDArray, "N 4"]) -> Num[NDArray, "N 3"]: ...
@overload
def homogeneous_to_euclidean(points: Num[Array, "N 4"]) -> Num[Array, "N 3"]: ...
@jaxtyped(typechecker=beartype)
def homogeneous_to_euclidean(
    points: Num[Any, "N 4"],
) -> Num[Any, "N 3"]:
    """
    Convert homogeneous coordinates to Euclidean coordinates.

    Args:
        points: homogeneous coordinates (x, y, z, w) in numpy array or jax array

    Returns:
        euclidean coordinates (x / w, y / w, z / w) in numpy array or jax array
    """
    xyz = points[..., :-1]
    w = points[..., -1:]  # keep the trailing axis so it broadcasts over xyz
    return xyz / w
# %% # Build one keypoint-detection generator per camera and merge them into time-aligned batches with `sync_batch_gen`.
FPS = 24  # frame rate (frames per second)
# One detection generator per camera; ports 5600, 5601 and 5602 are assumed to be three different cameras.
image_gen_5600 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5600], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5600][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
image_gen_5601 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5601], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5601][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
image_gen_5602 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5602], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5602][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
display(1 / FPS)  # frame interval, about 0.0417 s
# Synchronize the three generators with a window of 1/FPS seconds, i.e.
# detections within one batch are less than one frame apart.
sync_gen = sync_batch_gen(
    [image_gen_5600, image_gen_5601, image_gen_5602], timedelta(seconds=1 / FPS)
)
# %% Compute the affinity matrix between detections of different cameras from the epipolar constraint; returns the sorted detections and the affinity matrix.
# Inputs:
#   next(sync_gen): one batch of time-aligned detections (several cameras, nearby timestamps)
#   alpha_2d=2000: parameter of the 2D distance term (exact semantics are defined in
#     `app.camera.calculate_affinity_matrix_by_epipolar_constraint` — not visible here)
# Outputs:
#   sorted_detections: the batch's detections in sorted order
#   affinity_matrix: affinity_matrix[i][j] is the affinity between detection i and
#     detection j (higher = more likely to be the same target)
sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint(
    next(sync_gen), alpha_2d=2000
)
display(sorted_detections)
# %% # Inspect the batch: timestamp and camera id of every sorted detection, then the affinity matrix.
display(
    [
        {"timestamp": str(det.timestamp), "camera": det.camera.id}
        for det in sorted_detections
    ]
)
# Print the affinity matrix with 3 decimals and without scientific notation.
with jnp.printoptions(precision=3, suppress=True):
    display(affinity_matrix)
# %% # Cluster the detections that likely belong to the same target, based on the affinity matrix.
def clusters_to_detections(
    clusters: Sequence[Sequence[int]], sorted_detections: Sequence[Detection]
) -> list[list[Detection]]:
    """
    Map index clusters back to the detections they refer to.

    Args:
        clusters: list of clusters, each cluster is a list of indices of the detections in the `sorted_detections` list
        sorted_detections: list of SORTED detections

    Returns:
        one list of detections per cluster, in the order of ``clusters``
    """
    grouped: list[list[Detection]] = []
    for cluster in clusters:
        grouped.append([sorted_detections[index] for index in cluster])
    return grouped
# Solve the clustering problem on the affinity matrix with the GLPK-based solver.
solver = GLPKSolver()
# The solver receives a float64 NumPy copy of the (JAX) affinity matrix.
aff_np = np.asarray(affinity_matrix).astype(np.float64)
# `clusters` holds lists of indices into `sorted_detections` (see clusters_to_detections).
clusters, sol_matrix = solver.solve(aff_np)
display(clusters)
display(sol_matrix)
# %% # Helpers for nested (dict-of-sequences) data.
T = TypeVar("T")


def flatten_values(
    d: Mapping[Any, Sequence[T]],
) -> list[T]:
    """
    Flatten a dictionary of sequences into a single list of values.

    Values follow the dict's iteration order, then each sequence's order.
    """
    return [v for vs in d.values() for v in vs]


def flatten_values_len(
    d: Mapping[Any, Sequence[T]],
) -> int:
    """
    Count the values in a dictionary of sequences.

    Equivalent to ``len(flatten_values(d))`` without building the list.
    """
    return sum(len(vs) for vs in d.values())
# %% # Draw the keypoints of one cluster (the same target seen by different cameras) on a single canvas to check the association visually.
# Canvas size in pixels.
WIDTH = 2560
HEIGHT = 1440
# Turn the index clusters into lists of Detection objects.
clusters_detections = clusters_to_detections(clusters, sorted_detections)
# Blank (black) canvas.
im = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
# Draw every detection of the first cluster; each uses its own camera's image coordinates.
for el in clusters_detections[0]:
    im = visualize_whole_body(np.asarray(el.keypoints), im)
# Show the resulting image.
p = plt.imshow(im)
display(p)
# %% # Same as the previous cell for the second cluster (presumably a second individual in the batch).
im_prime = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
for el in clusters_detections[1]:
    im_prime = visualize_whole_body(np.asarray(el.keypoints), im_prime)
p_prime = plt.imshow(im_prime)
display(p_prime)
# %% # Linear (DLT) triangulation of 3D points from multiple views.
@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N 2"],
    confidences: Optional[Float[Array, "N"]] = None,
) -> Float[Array, "3"]:
    """
    Triangulate a single 3D point observed by N cameras.

    Every view contributes the rows ``x * P[2] - P[0]`` and ``y * P[2] - P[1]``
    to the homogeneous system ``A X = 0``. Both rows are scaled by
    ``sqrt(clip(confidence, 0, 1))`` of that view. The solution is the right
    singular vector of ``A`` for the smallest singular value.

    Args:
        proj_matrices: (N, 3, 4) projection matrices
        points: (N, 2) image coordinates of the point in each view
        confidences: optional (N,) confidences in [0.0, 1.0]

    Returns:
        point_3d: (3,) triangulated point
    """
    assert len(proj_matrices) == len(points)
    num_views = len(proj_matrices)
    if confidences is None:
        weights = jnp.ones(num_views, dtype=np.float32)
    else:
        # Square-root weighting: a softer falloff than the raw confidence.
        weights = jnp.sqrt(jnp.clip(confidences, 0, 1))

    design = jnp.zeros((num_views * 2, 4), dtype=np.float32)
    for view in range(num_views):
        P = proj_matrices[view]
        u, v = points[view]
        row_u, row_v = 2 * view, 2 * view + 1
        design = design.at[row_u].set(P[2] * u - P[0])
        design = design.at[row_v].set(P[2] * v - P[1])
        design = design.at[row_u].mul(weights[view])
        design = design.at[row_v].mul(weights[view])

    # https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.svd.html
    _, _, vh = jnp.linalg.svd(design, full_matrices=False)
    homogeneous = vh[-1]  # shape (4,)
    # Make w non-negative; jnp.where instead of a Python `if` keeps this traceable.
    homogeneous = jnp.where(homogeneous[3] < 0, -homogeneous, homogeneous)
    return homogeneous[:3] / homogeneous[3]
@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N P 2"],
    confidences: Optional[Float[Array, "N P"]] = None,
) -> Float[Array, "P 3"]:
    """
    Batch-triangulate P points observed by N cameras, linearly via SVD.

    Confidences are forwarded unchanged to
    `triangulate_one_point_from_multiple_views_linear`, which applies the
    ``sqrt(clip(c, 0, 1))`` weighting itself. Applying it here as well would
    weight each view by ``c ** 0.25`` instead of ``sqrt(c)``.

    Args:
        proj_matrices: (N, 3, 4) projection matrices
        points: (N, P, 2) image-coordinates per view
        confidences: (N, P) optional per-view confidences in [0,1]

    Returns:
        (P, 3) 3D point for each of the P tracks
    """
    N, P, _ = points.shape
    assert proj_matrices.shape[0] == N
    if confidences is None:
        conf = jnp.ones((N, P), dtype=jnp.float32)
    else:
        conf = confidences
    # vectorize the one-point routine over P
    vmap_triangulate = jax.vmap(
        triangulate_one_point_from_multiple_views_linear,
        in_axes=(None, 1, 1),  # proj_matrices shared; map over points[:, p, :] and conf[:, p]
        out_axes=0,
    )
    return vmap_triangulate(proj_matrices, points, conf)
# %% # Multi-view triangulation with time weighting (single point and batched).
@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear_time_weighted(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N 2"],
    delta_t: Num[Array, "N"],
    lambda_t: float = 10.0,
    confidences: Optional[Float[Array, "N"]] = None,
) -> Float[Array, "3"]:
    """
    Triangulate one point from multiple views with time-weighted linear least squares.

    Implements the incremental reconstruction method from "Cross-View Tracking for Multi-Human 3D Pose"
    with weighting formula: w_i = exp(-λ_t(t-t_i)) / ||c^i^T||_2

    Each row of the linear system is divided by its own L2 norm and scaled by
    the time decay and by ``sqrt(clip(confidence, 0, 1))``. A degenerate
    all-zero row is left at zero rather than divided by zero, because the
    resulting NaN would corrupt the whole SVD solution.

    Args:
        proj_matrices: Shape (N, 3, 4) projection matrices sequence
        points: Shape (N, 2) point coordinates sequence
        delta_t: Time differences between current time and each observation (in seconds)
        lambda_t: Time penalty rate (higher values decrease influence of older observations)
        confidences: Shape (N,) confidence values in range [0.0, 1.0]

    Returns:
        point_3d: Shape (3,) triangulated 3D point
    """
    assert len(proj_matrices) == len(points)
    assert len(delta_t) == len(points)
    N = len(proj_matrices)
    # Prepare confidence weights
    confi: Float[Array, "N"]
    if confidences is None:
        confi = jnp.ones(N, dtype=np.float32)
    else:
        confi = jnp.sqrt(jnp.clip(confidences, 0, 1))
    A = jnp.zeros((N * 2, 4), dtype=np.float32)
    # First build the coefficient matrix without weights
    for i in range(N):
        x, y = points[i]
        A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])
        A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])
    # Then apply the time-based and confidence weights
    for i in range(N):
        # Time-decay weight: e^(-λ_t * Δt)
        time_weight = jnp.exp(-lambda_t * delta_t[i])
        # Normalization factor: ||c^i^T||_2
        row_norm_1 = jnp.linalg.norm(A[2 * i])
        row_norm_2 = jnp.linalg.norm(A[2 * i + 1])
        # Guard against 0/0 for an all-zero row: dividing by 1 keeps it at zero.
        row_norm_1 = jnp.where(row_norm_1 > 0, row_norm_1, 1.0)
        row_norm_2 = jnp.where(row_norm_2 > 0, row_norm_2, 1.0)
        # Combined weight: time_weight / row_norm * confidence
        w1 = (time_weight / row_norm_1) * confi[i]
        w2 = (time_weight / row_norm_2) * confi[i]
        A = A.at[2 * i].mul(w1)
        A = A.at[2 * i + 1].mul(w2)
    # Solve using SVD
    _, _, vh = jnp.linalg.svd(A, full_matrices=False)
    point_3d_homo = vh[-1]  # shape (4,)
    # Ensure the homogeneous coordinate is non-negative
    point_3d_homo = jnp.where(
        point_3d_homo[3] < 0,
        -point_3d_homo,
        point_3d_homo,
    )
    # Convert from homogeneous to Euclidean coordinates
    point_3d = point_3d_homo[:3] / point_3d_homo[3]
    return point_3d
@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear_time_weighted(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N P 2"],
    delta_t: Num[Array, "N"],
    lambda_t: float = 10.0,
    confidences: Optional[Float[Array, "N P"]] = None,
) -> Float[Array, "P 3"]:
    """
    Triangulate P points from N camera views with time weighting, in parallel.

    Applies `triangulate_one_point_from_multiple_views_linear_time_weighted`
    to every point with `jax.vmap`.

    Args:
        proj_matrices: (N, 3, 4) projection matrices of the N cameras
        points: (N, P, 2) 2D points for P keypoints across N cameras
        delta_t: (N,) time differences between the current time and each camera's timestamp (seconds)
        lambda_t: time penalty rate (higher values decrease the influence of older observations)
        confidences: optional (N, P) confidence of each point in each camera; all ones when omitted

    Returns:
        points_3d: (P, 3) triangulated 3D points
    """
    num_cams, num_points, _ = points.shape
    assert (
        proj_matrices.shape[0] == num_cams
    ), "Number of projection matrices must match number of cameras"
    assert delta_t.shape[0] == num_cams, "Number of time deltas must match number of cameras"

    conf = (
        jnp.ones((num_cams, num_points), dtype=jnp.float32)
        if confidences is None
        else confidences
    )

    # Map over axis 1 (the point axis) of `points` and `conf`; the projection
    # matrices, the time deltas and lambda_t are shared by every point.
    per_point = jax.vmap(
        triangulate_one_point_from_multiple_views_linear_time_weighted,
        in_axes=(None, 1, None, None, 1),
        out_axes=0,
    )
    return per_point(proj_matrices, points, delta_t, lambda_t, conf)
# %% # Triangulate one cluster of detections into 3D keypoints and report the cluster's latest timestamp.
@jaxtyped(typechecker=beartype)
def triangle_from_cluster(
    cluster: Sequence[Detection],
) -> tuple[Float[Array, "N 3"], datetime]:
    """Triangulate the 3D keypoints of one cluster of detections.

    Uses each detection's camera projection matrix, its undistorted
    keypoints and its per-keypoint confidences.

    Returns:
        ``(keypoints_3d, latest_timestamp)`` where ``latest_timestamp`` is the
        newest detection timestamp in the cluster.
    """
    proj_matrices = jnp.array([det.camera.params.projection_matrix for det in cluster])
    points = jnp.array([det.keypoints_undistorted for det in cluster])
    confidences = jnp.array([det.confidences for det in cluster])
    latest_timestamp = max(det.timestamp for det in cluster)
    keypoints_3d = triangulate_points_from_multiple_views_linear(
        proj_matrices, points, confidences=confidences
    )
    return keypoints_3d, latest_timestamp
# %% # Core of the multi-target tracking logic: create and manage the global tracking state from clustered detections.
def group_by_cluster_by_camera(
    cluster: Sequence[Detection],
) -> PMap[CameraID, Detection]:
    """
    group the detections by camera, and preserve the latest detection for each camera

    A camera seen for the first time is always recorded. When a camera appears
    more than once, the detection with the newest timestamp is kept; on a tie
    the earlier one in ``cluster`` wins.
    """
    r: dict[CameraID, Detection] = {}
    for el in cluster:
        cam_id = el.camera.id
        kept = r.get(cam_id)
        if kept is None or el.timestamp > kept.timestamp:
            r[cam_id] = el
    return pmap(r)
class GlobalTrackingState:
    """Registry of all trackings, keyed by an auto-incrementing integer id.

    Ids start at 1 and are handed out in increasing order.
    """

    _last_id: int
    _trackings: dict[int, Tracking]

    def __init__(self):
        self._trackings = {}
        self._last_id = 0

    def __repr__(self) -> str:
        return (
            f"GlobalTrackingState(last_id={self._last_id}, trackings={self._trackings})"
        )

    @property
    def trackings(self) -> dict[int, Tracking]:
        """Shallow copy of the id -> Tracking mapping.

        Mutating the returned dict does not affect the registry; the Tracking
        objects themselves are shared.
        """
        return shallow_copy(self._trackings)

    def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
        """Create, register and return a new tracking from one cluster.

        The 3D keypoints are triangulated from the cluster, and the latest
        detection per camera is kept as history.

        Raises:
            ValueError: if the cluster has fewer than 2 detections
                (triangulation needs at least two views).
        """
        if len(cluster) < 2:
            raise ValueError(
                "cluster must contain at least 2 detections to form a tracking"
            )
        keypoints_3d, timestamp = triangle_from_cluster(cluster)
        new_id = self._last_id + 1
        tracking = Tracking(
            id=new_id,
            state=TrackingState(
                keypoints=keypoints_3d,
                last_active_timestamp=timestamp,
                historical_detections_by_camera=group_by_cluster_by_camera(cluster),
            ),
            velocity_filter=LastDifferenceVelocityFilter(keypoints_3d, timestamp),
        )
        self._trackings[new_id] = tracking
        self._last_id = new_id
        return tracking
# Start a fresh global state and open one tracking per cluster found above.
# NOTE(review): add_tracking raises ValueError for clusters with fewer than 2
# detections, so a single-camera cluster would abort this cell.
global_tracking_state = GlobalTrackingState()
for cluster in clusters_detections:
    global_tracking_state.add_tracking(cluster)
display(global_tracking_state)
# %% # Pull the next time-aligned batch of detections from `sync_gen` and display it.
next_group = next(sync_gen)  # next batch of detections from the synchronized generator
display(next_group)  # show the batch in Jupyter
# %% # Core affinity computations used to associate trackings with detections across cameras.
@jaxtyped(typechecker=beartype)
def calculate_distance_2d(
    left: Num[Array, "J 2"],
    right: Num[Array, "J 2"],
    image_size: tuple[int, int] = (1, 1),
) -> Float[Array, "J"]:
    """
    Calculate the *normalized* distance between two sets of keypoints.

    Both sets are divided by ``(width, height)`` before the Euclidean
    distance is taken. The default ``(1, 1)`` means the inputs are already
    normalized and are used as-is.

    Args:
        left: the left keypoints
        right: the right keypoints
        image_size: ``(width, height)`` of the image

    Returns:
        (J,) normalized Euclidean distance between corresponding keypoints
    """
    w, h = image_size
    already_normalized = w == 1 and h == 1
    if not already_normalized:
        scale = jnp.array([w, h])
        left = left / scale
        right = right / scale
    return jnp.linalg.norm(left - right, axis=-1)
@jaxtyped(typechecker=beartype)
def calculate_affinity_2d(
    distance_2d: Float[Array, "J"],
    delta_t: timedelta,
    w_2d: float,
    alpha_2d: float,
    lambda_a: float,
) -> Float[Array, "J"]:
    """
    Per-keypoint 2D affinity between a tracking and a detection.

        A_2D[j] = w_2D * (1 - distance_2D[j] / (alpha_2D * delta_t)) * exp(-lambda_a * delta_t)

    ``delta_t`` (in seconds) divides the distance, so it must be positive;
    callers in this file clamp it to ``DELTA_T_MIN``.

    Args:
        distance_2d: (J,) normalized distances between keypoints
        delta_t: time delta between the two detections
        w_2d: weight of the 2D affinity
        alpha_2d: normalization factor for the distance
        lambda_a: decay rate for the time difference

    Returns:
        (J,) affinity per keypoint (summing is left to the caller)
    """
    dt = delta_t.total_seconds()
    time_decay = jnp.exp(-lambda_a * dt)
    return w_2d * (1 - distance_2d / (alpha_2d * dt)) * time_decay
@jaxtyped(typechecker=beartype)
def perpendicular_distance_point_to_line_two_points(
    point: Num[Array, "3"], line: tuple[Num[Array, "3"], Num[Array, "3"]]
) -> Float[Array, ""]:
    """
    Perpendicular distance from a 3D point to a line through two points.

    With ``d = line_end - line_start`` the distance is
    ``|d x (line_start - point)| / |d|``: the area of the parallelogram
    spanned by the two vectors divided by its base length.

    Args:
        point: the point to measure from
        line: ``(line_start, line_end)``, two distinct points on the line

    Returns:
        the distance as a 0-d array
    """
    start, end = line
    direction = end - start
    area = jnp.linalg.norm(jnp.cross(direction, start - point))
    return area / jnp.linalg.norm(direction)
# Ray-distance term of the 3D affinity: how far each predicted 3D keypoint lies
# from the viewing ray through the matching 2D keypoint.
@jaxtyped(typechecker=beartype)
def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
    detection: Detection,
    tracking: Tracking,
    delta_t: timedelta,
) -> Float[Array, "J"]:
    """
    Distance from the tracking's predicted 3D keypoints to the camera rays.

    NOTE: `delta_t` is supplied by the caller and is NOT recomputed here.

    For every keypoint, a ray is cast from the camera location through the
    point where the 2D keypoint back-projects onto the z = 0 plane. The
    result is the perpendicular distance between that ray and the tracking's
    keypoint predicted ``delta_t`` ahead.

    Args:
        detection: detection with 2D keypoints and its camera
        tracking: tracking with 3D keypoints
        delta_t: time between the tracking's last update and the detection

    Returns:
        (J,) perpendicular distance per keypoint
    """
    camera = detection.camera
    predicted_pose = tracking.predict(delta_t)
    # Second point of each ray: intersection with the z = 0 plane.
    ray_points = camera.unproject_points_to_z_plane(detection.keypoints, z=0.0)
    origin = camera.params.location

    def distance_to_ray(predicted_point, ray_point):
        return perpendicular_distance_point_to_line_two_points(
            predicted_point, (origin, ray_point)
        )

    return jax.vmap(distance_to_ray)(predicted_pose, ray_points)
@jaxtyped(typechecker=beartype)
def calculate_affinity_3d(
    distances: Float[Array, "J"],
    delta_t: timedelta,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> Float[Array, "J"]:
    """
    Per-keypoint 3D affinity between a tracking and a detection.

        A_3D[j] = w_3D * (1 - dl[j] / alpha_3D) * exp(-lambda_a * delta_t)

    Args:
        distances: (J,) perpendicular point-to-ray distances, one per keypoint
        delta_t: time difference between tracking and detection
        w_3d: weight of the 3D affinity
        alpha_3d: normalization factor for the distance
        lambda_a: decay rate for the time difference

    Returns:
        (J,) affinity per keypoint (summing is left to the caller)
    """
    time_decay = jnp.exp(-lambda_a * delta_t.total_seconds())
    return w_3d * (1 - distances / alpha_3d) * time_decay
@beartype
def calculate_tracking_detection_affinity(
    tracking: Tracking,
    detection: Detection,
    w_2d: float,
    alpha_2d: float,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> float:
    """
    Scalar affinity between one tracking and one detection.

    The result is the sum over keypoints of the 2D term (reprojected tracking
    keypoints vs. detected keypoints, normalized by image size) and the 3D
    term (predicted 3D keypoints vs. camera rays). Δt is clamped to at least
    ``DELTA_T_MIN``.

    Args:
        tracking: the tracking object
        detection: the detection object
        w_2d: weight for the 2D affinity
        alpha_2d: normalization factor for the 2D distance
        w_3d: weight for the 3D affinity
        alpha_3d: normalization factor for the 3D distance
        lambda_a: decay rate for the time difference

    Returns:
        combined affinity score
    """
    # Clamp delta_t to avoid division-by-zero / exploding affinity.
    delta_t = max(
        detection.timestamp - tracking.state.last_active_timestamp, DELTA_T_MIN
    )
    camera = detection.camera

    # --- 2D term ---
    projected = camera.project(tracking.state.keypoints)
    width, height = camera.params.image_size
    distance_2d = calculate_distance_2d(
        projected,
        detection.keypoints,
        image_size=(int(width), int(height)),
    )
    affinity_2d = calculate_affinity_2d(
        distance_2d, delta_t, w_2d=w_2d, alpha_2d=alpha_2d, lambda_a=lambda_a
    )

    # --- 3D term ---
    ray_distances = perpendicular_distance_camera_2d_points_to_tracking_raycasting(
        detection, tracking, delta_t
    )
    affinity_3d = calculate_affinity_3d(
        ray_distances, delta_t, w_3d=w_3d, alpha_3d=alpha_3d, lambda_a=lambda_a
    )

    return jnp.sum(affinity_2d + affinity_3d).item()
# %% # Efficient affinity-matrix computation between trackings (Tracking) and new detections (Detection) of one camera.
@beartype
def calculate_camera_affinity_matrix_jax(
    trackings: Sequence[Tracking],
    camera_detections: Sequence[Detection],
    w_2d: float,
    alpha_2d: float,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> Float[Array, "T D"]:
    """
    Vectorized implementation to compute an affinity matrix between *trackings*
    and *detections* coming from **one** camera.

    Entry ``[t, d]`` is the sum over keypoints of a 2D term and a 3D term.
    These use the same formulas as `calculate_affinity_2d` /
    `calculate_affinity_3d`, with Δt clamped to at least ``DELTA_T_MIN``.
    Python loops over (tracking, detection) pairs are replaced by
    broadcasting and `vmap`.

    NOTE(review): intended to match the per-pair
    `calculate_tracking_detection_affinity`. Here, however, Δt is computed in
    float32, and the 3D prediction uses ``trk.velocity`` directly instead of
    ``Tracking.predict``. Confirm the two agree (e.g. against the tests)
    before relying on bit-identical results.

    Args:
        trackings: trackings (rows of the result)
        camera_detections: detections from a single camera (columns)
        w_2d, alpha_2d: weight / distance normalization of the 2D term
        w_3d, alpha_3d: weight / distance normalization of the 3D term
        lambda_a: decay rate for the time difference

    Returns:
        (T, D) affinity matrix; an empty (T, D)-shaped matrix if either input is empty.

    Raises:
        ValueError: if the detections come from more than one camera.
    """
    # ------------------------------------------------------------------
    # Quick validations / early-exit guards
    # ------------------------------------------------------------------
    if len(trackings) == 0 or len(camera_detections) == 0:
        # Return an empty affinity matrix with appropriate shape.
        return jnp.zeros((len(trackings), len(camera_detections)))  # type: ignore[return-value]
    cam = next(iter(camera_detections)).camera
    # Ensure every detection truly belongs to the same camera (guard clause)
    cam_id = cam.id
    if any(det.camera.id != cam_id for det in camera_detections):
        raise ValueError(
            "All detections passed to `calculate_camera_affinity_matrix` must come from one camera."
        )
    # We will rely on a single `Camera` instance (all detections share it)
    w_img_, h_img_ = cam.params.image_size
    w_img, h_img = float(w_img_), float(h_img_)
    # ------------------------------------------------------------------
    # Gather data into ndarray / DeviceArray batches so that we can compute
    # everything in a single (or a few) fused kernels.
    # ------------------------------------------------------------------
    # === Tracking-side tensors ===
    kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
        [trk.state.keypoints for trk in trackings]
    )  # (T, J, 3)
    J = kps3d_trk.shape[1]
    # === Detection-side tensors ===
    kps2d_det: Float[Array, "D J 2"] = jnp.stack(
        [det.keypoints for det in camera_detections]
    )  # (D, J, 2)
    # ------------------------------------------------------------------
    # Compute Δt matrix shape (T, D)
    # ------------------------------------------------------------------
    # Epoch timestamps are ~1.7 × 10⁹ s. With float32's 24-bit significand,
    # representable values at that magnitude are 128 s apart, so storing raw
    # timestamps in float32 would wipe out all sub-second detail. Subtract a
    # common origin in Python floats (float64) first and only then cast the
    # small offsets to float32; this preserves Δt on the order of milliseconds.
    # --- timestamps ----------
    t0 = min(
        chain(
            (trk.state.last_active_timestamp for trk in trackings),
            (det.timestamp for det in camera_detections),
        )
    ).timestamp()  # common origin (float)
    ts_trk = jnp.array(
        [trk.state.last_active_timestamp.timestamp() - t0 for trk in trackings],
        dtype=jnp.float32,  # now small, ms-scale fits in fp32
    )
    ts_det = jnp.array(
        [det.timestamp.timestamp() - t0 for det in camera_detections],
        dtype=jnp.float32,
    )
    # Δt in seconds, fp32 throughout
    delta_t = ts_det[None, :] - ts_trk[:, None]  # (T,D)
    # Clamp to DELTA_T_MIN: Δt divides the 2D distance below.
    min_dt_s = float(DELTA_T_MIN.total_seconds())
    delta_t = jnp.clip(delta_t, a_min=min_dt_s, a_max=None)
    # ------------------------------------------------------------------
    # ---------- 2D affinity -------------------------------------------
    # ------------------------------------------------------------------
    # Project each tracking's 3D keypoints onto the image once.
    # `Camera.project` works per-sample, so we vmap over the first axis.
    proj_fn = jax.vmap(cam.project, in_axes=0)  # maps over the keypoint sets
    kps2d_trk_proj: Float[Array, "T J 2"] = proj_fn(kps3d_trk)  # (T, J, 2)
    # Normalise keypoints by image size so absolute units do not bias distance
    norm_trk = kps2d_trk_proj / jnp.array([w_img, h_img])
    norm_det = kps2d_det / jnp.array([w_img, h_img])
    # L2 distance for every (T, D, J)
    # reshape for broadcasting: (T,1,J,2) vs (1,D,J,2)
    diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
    dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1)
    # Compute per-keypoint 2D affinity
    delta_t_broadcast = delta_t[:, :, None]  # (T, D, 1)
    affinity_2d = (
        w_2d
        * (1 - dist2d / (alpha_2d * delta_t_broadcast))
        * jnp.exp(-lambda_a * delta_t_broadcast)
    )
    # ------------------------------------------------------------------
    # ---------- 3D affinity -------------------------------------------
    # ------------------------------------------------------------------
    # For each detection pre-compute back-projected 3D points lying on z=0 plane.
    backproj_points_list = [
        det.camera.unproject_points_to_z_plane(det.keypoints, z=0.0)
        for det in camera_detections
    ]  # each (J,3)
    backproj: Float[Array, "D J 3"] = jnp.stack(backproj_points_list)  # (D, J, 3)
    # Trackings without a velocity estimate are treated as static.
    zero_velocity = jnp.zeros((J, 3))
    trk_velocities = jnp.stack(
        [
            trk.velocity if trk.velocity is not None else zero_velocity
            for trk in trackings
        ]
    )
    # Constant-velocity prediction of every tracking at every detection's time.
    predicted_pose: Float[Array, "T D J 3"] = (
        kps3d_trk[:, None, :, :]  # (T,1,J,3)
        + trk_velocities[:, None, :, :] * delta_t[:, :, None, None]  # (T,D,1,1)
    )
    # Camera center shape (3,) -> will broadcast
    cam_center = cam.params.location
    # Compute perpendicular distance using vectorized formula
    # p1 = cam_center (3,)
    # p2 = backproj (D, J, 3)
    # P = predicted_pose (T, D, J, 3)
    # Broadcast plan: v1 = P - p1 → (T, D, J, 3)
    # v2 = p2[None, ...]-p1 → (1, D, J, 3)
    # Shapes now line up; no stray singleton axis.
    p1 = cam_center
    p2 = backproj
    P = predicted_pose
    v1 = P - p1
    v2 = p2[None, :, :, :] - p1  # (1, D, J, 3)
    # distance = |v1 x v2| / |v2|  (point-to-line distance for the ray p1 -> p2)
    cross = jnp.cross(v1, v2)  # (T, D, J, 3)
    num = jnp.linalg.norm(cross, axis=-1)  # (T, D, J)
    den = jnp.linalg.norm(v2, axis=-1)  # (1, D, J)
    dist3d: Float[Array, "T D J"] = num / den
    affinity_3d = (
        w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
    )
    # ------------------------------------------------------------------
    # Combine and reduce across keypoints → (T, D)
    # ------------------------------------------------------------------
    total_affinity: Float[Array, "T D"] = jnp.sum(affinity_2d + affinity_3d, axis=-1)
    return total_affinity  # type: ignore[return-value]
@beartype
def calculate_affinity_matrix(  # multi-camera affinity matrix computation
    trackings: Sequence[Tracking],
    detections: Sequence[Detection] | Mapping[CameraID, list[Detection]],
    w_2d: float,
    alpha_2d: float,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> dict[CameraID, AffinityResult]:
    """
    Calculate the affinity matrix between a set of trackings and detections,
    one matrix per camera, and solve the per-camera assignment.

    Args:
        trackings: Sequence of tracking objects
        detections: Sequence of detection objects, or detections already
            grouped by camera ID
        w_2d: Weight for 2D affinity
        alpha_2d: Normalization factor for 2D distance
        w_3d: Weight for 3D affinity
        alpha_3d: Normalization factor for 3D distance
        lambda_a: Decay rate for time difference

    Returns:
        A dictionary mapping camera IDs to affinity results. Each result holds
        the (T, D) affinity matrix together with the row (tracking) / column
        (detection) indices of the assignment that maximizes total affinity.
    """
    if isinstance(detections, Mapping):
        detection_by_camera = detections
    else:
        detection_by_camera = classify_by_camera(detections)

    res: dict[CameraID, AffinityResult] = {}
    for camera_id, camera_detections in detection_by_camera.items():
        affinity_matrix = calculate_camera_affinity_matrix_jax(
            trackings,
            camera_detections,
            w_2d,
            alpha_2d,
            w_3d,
            alpha_3d,
            lambda_a,
        )
        # `linear_sum_assignment` (optax's Hungarian algorithm) MINIMIZES the
        # total cost, whereas a larger affinity means a better match. Negate
        # the matrix so the solver returns the maximum-affinity assignment.
        # Returned indices are (row = tracking, col = detection).
        indices_T, indices_D = linear_sum_assignment(-affinity_matrix)
        res[camera_id] = AffinityResult(
            matrix=affinity_matrix,
            trackings=trackings,
            detections=camera_detections,
            indices_T=indices_T,
            indices_D=indices_D,
        )
    return res
# %% # Cross-view association pipeline
# let's do cross-view association

# Affinity hyper-parameters: weights / distance normalizers for the 2D and 3D
# terms, and the decay rate applied to the time difference.
W_2D = 1.0
ALPHA_2D = 1.0
LAMBDA_A = 0.1
W_3D = 1.0
ALPHA_3D = 1.0

# Current trackings, ordered by tracking ID.
trackings = sorted(
    global_tracking_state.trackings.values(),
    key=lambda tracking: tracking.id,
)
# Shallow copy: a new container, the Detection objects themselves are shared.
unmatched_detections = shallow_copy(next_group)
camera_detections = classify_by_camera(unmatched_detections)

affinities = calculate_affinity_matrix(
    trackings,
    unmatched_detections,
    lambda_a=LAMBDA_A,
    w_2d=W_2D,
    alpha_2d=ALPHA_2D,
    w_3d=W_3D,
    alpha_3d=ALPHA_3D,
)
display(affinities)
# %% # The two functions below implement the core logic for aggregating association results and updating trackings
def affinity_result_by_tracking(  # aggregate association results
    results: Iterable[AffinityResult],
    min_affinity: float = 0.0,
) -> dict[TrackingID, list[Detection]]:
    """
    Group the detections associated by each affinity result under their
    tracking ID.

    Every ``(affinity, tracking, detection)`` triple yielded by
    ``AffinityResult.tracking_association()`` contributes its detection to the
    list keyed by ``tracking.id``, unless its affinity is strictly below
    ``min_affinity``.

    Args:
        results: the affinity results to group (e.g. one per camera)
        min_affinity: associations with an affinity below this value are skipped

    Returns:
        a ``defaultdict(list)`` mapping tracking IDs to a list of detections;
        looking up an ID with no associations yields an empty list
    """
    grouped: dict[TrackingID, list[Detection]] = defaultdict(list)
    associations = chain.from_iterable(
        result.tracking_association() for result in results
    )
    for score, matched_tracking, matched_detection in associations:
        if score < min_affinity:
            continue
        grouped[matched_tracking.id].append(matched_detection)
    return grouped
def update_tracking(  # tracking update procedure
    tracking: Tracking,
    detections: Sequence[Detection],
    max_delta_t: timedelta = timedelta(milliseconds=100),
    lambda_t: float = 10.0,
) -> None:
    """
    Update the tracking with a new set of detections.

    The new detections are merged into the tracking's per-camera detection
    history; a newer detection replaces the stored one for the same camera.
    Entries whose age, measured against the newest incoming detection, exceeds
    ``max_delta_t`` are discarded. The remaining detections are triangulated
    into a new 3D pose with time-weighted linear triangulation.

    Args:
        tracking: the tracking to update
        detections: the detections to update the tracking with; if empty,
            the tracking is left unchanged
        max_delta_t: the maximum time difference between a stored per-camera
            detection and the latest detection for it to still be used
        lambda_t: the lambda value for the time difference, forwarded to the
            time-weighted triangulation
    Note:
        the function would mutate the tracking object
    """
    if not detections:
        # Nothing to fuse; `max()` below would raise on an empty sequence.
        return

    last_active_timestamp = tracking.state.last_active_timestamp
    latest_timestamp = max(d.timestamp for d in detections)

    # Mutable copy of the persistent per-camera history, overwritten by the
    # incoming detections (last one wins for a given camera).
    merged = thaw(tracking.state.historical_detections_by_camera)
    for detection in detections:
        merged[detection.camera.id] = detection

    # Drop detections that are too old relative to the newest one. The age is
    # `latest - detection` (non-negative for older entries). Build a new
    # mapping rather than deleting from `merged` while iterating over it,
    # which would raise "dictionary changed size during iteration".
    new_detections = freeze(
        {
            camera_id: detection
            for camera_id, detection in merged.items()
            if latest_timestamp - detection.timestamp <= max_delta_t
        }
    )
    new_detections_list = list(new_detections.values())

    project_matrices = jnp.stack(
        [detection.camera.params.projection_matrix for detection in new_detections_list]
    )
    # Per-view time offset (seconds) from the tracking's last active timestamp.
    delta_t = jnp.array(
        [
            detection.timestamp.timestamp() - last_active_timestamp.timestamp()
            for detection in new_detections_list
        ]
    )
    kps = jnp.stack([detection.keypoints for detection in new_detections_list])
    conf = jnp.stack([detection.confidences for detection in new_detections_list])
    kps_3d = triangulate_points_from_multiple_views_linear_time_weighted(
        project_matrices, kps, delta_t, lambda_t, conf
    )

    new_state = TrackingState(
        keypoints=kps_3d,
        last_active_timestamp=latest_timestamp,
        historical_detections_by_camera=new_detections,
    )
    tracking.update(kps_3d, latest_timestamp)
    # NOTE(review): the state is replaced after `tracking.update`, presumably so
    # the pruned per-camera history is kept; confirm `update` does not need it.
    tracking.state = new_state
# %% # Core tracking-update step of the multi-target tracking pipeline
affinity_results_by_tracking = affinity_result_by_tracking(affinities.values()) # 1. Group the matched detections from all cameras by tracking ID
for tracking_id, detections in affinity_results_by_tracking.items(): # 2. For each tracking ID, update that tracking (in place) with its matched detections
    update_tracking(global_tracking_state.trackings[tracking_id], detections)
# %%