CVTH3PE/playground.py
crosstyan 29c8ef3990 fix: fix the timestamp precision error that caused the jax version to give incorrect results
Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out
sub-second detail (one float32 ulp at that magnitude is 128 s). Keep them in
float64 until after subtraction so we preserve Δt on the order of milliseconds.

- Introduced a `_DEBUG_CURRENT_TRACKING` variable to track the current indices of tracking and detection during calculations.
- Added a `_global_current_tracking_str` function to format the current tracking state for debugging purposes.
- Enhanced `calculate_distance_2d` and `calculate_tracking_detection_affinity` functions with debug print statements to log intermediate values, improving traceability of calculations.
- Updated `perpendicular_distance_camera_2d_points_to_tracking_raycasting` to accept `delta_t` from the caller while ensuring it adheres to a minimum threshold.
- Refactored timestamp handling in `calculate_camera_affinity_matrix_jax` to maintain precision during calculations.
2025-04-29 12:56:58 +08:00

# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.17.0
# kernelspec:
# display_name: .venv
# language: python
# name: python3
# ---
# %%
from collections import OrderedDict
from copy import copy as shallow_copy
from copy import deepcopy as deep_copy
from dataclasses import dataclass
from datetime import datetime, timedelta
from functools import partial, reduce
from pathlib import Path
from typing import (
Any,
Generator,
Optional,
TypeAlias,
TypedDict,
TypeVar,
cast,
overload,
)
import awkward as ak
import jax
import jax.numpy as jnp
import numpy as np
import orjson
from beartype import beartype
from beartype.typing import Mapping, Sequence
from cv2 import undistortPoints
from IPython.display import display
from jaxtyping import Array, Float, Num, jaxtyped
from matplotlib import pyplot as plt
from numpy.typing import ArrayLike
from scipy.optimize import linear_sum_assignment
from scipy.spatial.transform import Rotation as R
from typing_extensions import deprecated
from app.camera import (
Camera,
CameraID,
CameraParams,
Detection,
calculate_affinity_matrix_by_epipolar_constraint,
classify_by_camera,
)
from app.solver._old import GLPKSolver
from app.tracking import Tracking
from app.visualize.whole_body import visualize_whole_body
NDArray: TypeAlias = np.ndarray
# %%
DATASET_PATH = Path("samples") / "04_02"
AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet")
DELTA_T_MIN = timedelta(milliseconds=10)
display(AK_CAMERA_DATASET)
_DEBUG_CURRENT_TRACKING: tuple[int, int] = (0, 0)
def _global_current_tracking_str():
return str(_DEBUG_CURRENT_TRACKING)
# %%
class Resolution(TypedDict):
width: int
height: int
class Intrinsic(TypedDict):
camera_matrix: Num[Array, "3 3"]
"""
K
"""
distortion_coefficients: Num[Array, "N"]
"""
distortion coefficients; usually 5
"""
class Extrinsic(TypedDict):
rvec: Num[NDArray, "3"]
tvec: Num[NDArray, "3"]
class ExternalCameraParams(TypedDict):
name: str
port: int
intrinsic: Intrinsic
extrinsic: Extrinsic
resolution: Resolution
# %%
def read_dataset_by_port(port: int) -> ak.Array:
P = DATASET_PATH / f"{port}.parquet"
return ak.from_parquet(P)
KEYPOINT_DATASET = {
int(p): read_dataset_by_port(p) for p in ak.to_numpy(AK_CAMERA_DATASET["port"])
}
# %%
class KeypointDataset(TypedDict):
frame_index: int
boxes: Num[NDArray, "N 4"]
kps: Num[NDArray, "N J 2"]
kps_scores: Num[NDArray, "N J"]
@jaxtyped(typechecker=beartype)
def to_transformation_matrix(
rvec: Num[NDArray, "3"], tvec: Num[NDArray, "3"]
) -> Num[NDArray, "4 4"]:
res = np.eye(4)
res[:3, :3] = R.from_rotvec(rvec).as_matrix()
res[:3, 3] = tvec
return res
@jaxtyped(typechecker=beartype)
def undistort_points(
points: Num[NDArray, "M 2"],
camera_matrix: Num[NDArray, "3 3"],
dist_coeffs: Num[NDArray, "N"],
) -> Num[NDArray, "M 2"]:
K = camera_matrix
dist = dist_coeffs
res = undistortPoints(points, K, dist, P=K) # type: ignore
return res.reshape(-1, 2)
def from_camera_params(camera: ExternalCameraParams) -> Camera:
rt = jnp.array(
to_transformation_matrix(
ak.to_numpy(camera["extrinsic"]["rvec"]),
ak.to_numpy(camera["extrinsic"]["tvec"]),
)
)
K = jnp.array(camera["intrinsic"]["camera_matrix"]).reshape(3, 3)
dist_coeffs = jnp.array(camera["intrinsic"]["distortion_coefficients"])
image_size = jnp.array(
(camera["resolution"]["width"], camera["resolution"]["height"])
)
return Camera(
id=camera["name"],
params=CameraParams(
K=K,
Rt=rt,
dist_coeffs=dist_coeffs,
image_size=image_size,
),
)
def preprocess_keypoint_dataset(
dataset: Sequence[KeypointDataset],
camera: Camera,
fps: float,
start_timestamp: datetime,
) -> Generator[Detection, None, None]:
frame_interval_s = 1 / fps
for el in dataset:
frame_index = el["frame_index"]
timestamp = start_timestamp + timedelta(seconds=frame_index * frame_interval_s)
for kp, kp_score in zip(el["kps"], el["kps_scores"]):
yield Detection(
keypoints=jnp.array(kp),
confidences=jnp.array(kp_score),
camera=camera,
timestamp=timestamp,
)
# %%
DetectionGenerator: TypeAlias = Generator[Detection, None, None]
def sync_batch_gen(gens: Sequence[DetectionGenerator], diff: timedelta):
"""
given a list of detection generators, return a generator that yields a batch of detections
Args:
gens: list of detection generators
diff: maximum timestamp difference between detections to consider them part of the same batch
"""
N = len(gens)
last_batch_timestamp: Optional[datetime] = None
next_batch_timestamp: Optional[datetime] = None
current_batch: list[Detection] = []
next_batch: list[Detection] = []
paused: list[bool] = [False] * N
finished: list[bool] = [False] * N
def reset_paused():
"""
reset paused list based on finished list
"""
for i in range(N):
if not finished[i]:
paused[i] = False
else:
paused[i] = True
EPS = 1e-6
# a small epsilon to avoid floating point precision issues
diff_eps = diff - timedelta(seconds=EPS)
while True:
for i, gen in enumerate(gens):
try:
if finished[i] or paused[i]:
continue
val = next(gen)
if last_batch_timestamp is None:
last_batch_timestamp = val.timestamp
current_batch.append(val)
else:
if abs(val.timestamp - last_batch_timestamp) >= diff_eps:
next_batch.append(val)
if next_batch_timestamp is None:
next_batch_timestamp = val.timestamp
paused[i] = True
if all(paused):
yield current_batch
current_batch = next_batch
next_batch = []
last_batch_timestamp = next_batch_timestamp
next_batch_timestamp = None
reset_paused()
else:
current_batch.append(val)
except StopIteration:
finished[i] = True
paused[i] = True
if all(finished):
if len(current_batch) > 0:
# All generators exhausted, flush remaining batch and exit
yield current_batch
break
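# %%
# A minimal sanity check of `sync_batch_gen` (illustrative only): instead of the
# real `Detection` class it uses a hypothetical stand-in that only carries the
# `.timestamp` attribute the generator actually reads. Two 24 fps streams that
# start together should be grouped into batches of two detections each.
@dataclass
class _FakeDetection:
    timestamp: datetime
def _fake_stream(start: datetime, n: int, fps: float):
    for i in range(n):
        yield _FakeDetection(start + timedelta(seconds=i / fps))
_t0 = datetime(2024, 4, 2, 12, 0, 0)
_batches = list(
    sync_batch_gen(
        [_fake_stream(_t0, 3, 24), _fake_stream(_t0, 3, 24)],
        timedelta(seconds=1 / 24),
    )
)
display([len(b) for b in _batches])  # expected: [2, 2, 2]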
# %%
@overload
def to_projection_matrix(
transformation_matrix: Num[NDArray, "4 4"], camera_matrix: Num[NDArray, "3 3"]
) -> Num[NDArray, "3 4"]: ...
@overload
def to_projection_matrix(
transformation_matrix: Num[Array, "4 4"], camera_matrix: Num[Array, "3 3"]
) -> Num[Array, "3 4"]: ...
@jaxtyped(typechecker=beartype)
def to_projection_matrix(
transformation_matrix: Num[Any, "4 4"],
camera_matrix: Num[Any, "3 3"],
) -> Num[Any, "3 4"]:
return camera_matrix @ transformation_matrix[:3, :]
to_projection_matrix_jit = jax.jit(to_projection_matrix)
@jaxtyped(typechecker=beartype)
def dlt(
H1: Num[NDArray, "3 4"],
H2: Num[NDArray, "3 4"],
p1: Num[NDArray, "2"],
p2: Num[NDArray, "2"],
) -> Num[NDArray, "3"]:
"""
Direct Linear Transformation
"""
A = [
p1[1] * H1[2, :] - H1[1, :],
H1[0, :] - p1[0] * H1[2, :],
p2[1] * H2[2, :] - H2[1, :],
H2[0, :] - p2[0] * H2[2, :],
]
A = np.array(A).reshape((4, 4))
B = A.transpose() @ A
from scipy import linalg
U, s, Vh = linalg.svd(B, full_matrices=False)
return Vh[3, 0:3] / Vh[3, 3]
@overload
def homogeneous_to_euclidean(points: Num[NDArray, "N 4"]) -> Num[NDArray, "N 3"]: ...
@overload
def homogeneous_to_euclidean(points: Num[Array, "N 4"]) -> Num[Array, "N 3"]: ...
@jaxtyped(typechecker=beartype)
def homogeneous_to_euclidean(
points: Num[Any, "N 4"],
) -> Num[Any, "N 3"]:
"""
Convert homogeneous coordinates to Euclidean coordinates.
Args:
points: homogeneous coordinates (x, y, z, w) in numpy array or jax array
Returns:
euclidean coordinates (x, y, z) in numpy array or jax array
"""
return points[..., :-1] / points[..., -1:]
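# Quick illustration (not part of the dataset): the homogeneous point
# (2, 4, 6, 2) maps to the Euclidean point (1, 2, 3).
display(homogeneous_to_euclidean(jnp.array([[2.0, 4.0, 6.0, 2.0]])))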
# %%
FPS = 24
image_gen_5600 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5600], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5600][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
image_gen_5601 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5601], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5601][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
image_gen_5602 = preprocess_keypoint_dataset(KEYPOINT_DATASET[5602], from_camera_params(AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == 5602][0]), FPS, datetime(2024, 4, 2, 12, 0, 0)) # type: ignore
display(1 / FPS)
sync_gen = sync_batch_gen(
[image_gen_5600, image_gen_5601, image_gen_5602], timedelta(seconds=1 / FPS)
)
# %%
sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint(
next(sync_gen), alpha_2d=2000
)
display(sorted_detections)
# %%
display(
list(
map(
lambda x: {"timestamp": str(x.timestamp), "camera": x.camera.id},
sorted_detections,
)
)
)
with jnp.printoptions(precision=3, suppress=True):
display(affinity_matrix)
# %%
def clusters_to_detections(
clusters: Sequence[Sequence[int]], sorted_detections: Sequence[Detection]
) -> list[list[Detection]]:
"""
given a list of clusters (which is the indices of the detections in the sorted_detections list),
extract the detections from the sorted_detections list
Args:
clusters: list of clusters, each cluster is a list of indices of the detections in the `sorted_detections` list
sorted_detections: list of SORTED detections
Returns:
list of clusters, each cluster is a list of detections
"""
return [[sorted_detections[i] for i in cluster] for cluster in clusters]
solver = GLPKSolver()
aff_np = np.asarray(affinity_matrix).astype(np.float64)
clusters, sol_matrix = solver.solve(aff_np)
display(clusters)
display(sol_matrix)
# %%
T = TypeVar("T")
def flatten_values(
d: Mapping[Any, Sequence[T]],
) -> list[T]:
"""
Flatten a dictionary of sequences into a single list of values.
"""
return [v for vs in d.values() for v in vs]
def flatten_values_len(
d: Mapping[Any, Sequence[T]],
) -> int:
"""
Count the total number of values across all sequences in the dictionary.
"""
val = reduce(lambda acc, xs: acc + len(xs), d.values(), 0)
return val
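# Tiny illustration with made-up keys: `flatten_values` concatenates the values,
# `flatten_values_len` counts them.
display(flatten_values({"a": [1, 2], "b": [3]}))  # [1, 2, 3]
display(flatten_values_len({"a": [1, 2], "b": [3]}))  # 3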
# %%
WIDTH = 2560
HEIGHT = 1440
clusters_detections = clusters_to_detections(clusters, sorted_detections)
im = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
for el in clusters_detections[0]:
im = visualize_whole_body(np.asarray(el.keypoints), im)
p = plt.imshow(im)
display(p)
# %%
im_prime = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
for el in clusters_detections[1]:
im_prime = visualize_whole_body(np.asarray(el.keypoints), im_prime)
p_prime = plt.imshow(im_prime)
display(p_prime)
# %%
@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear(
proj_matrices: Float[Array, "N 3 4"],
points: Num[Array, "N 2"],
confidences: Optional[Float[Array, "N"]] = None,
) -> Float[Array, "3"]:
"""
Args:
proj_matrices: projection matrices with shape (N, 3, 4)
points: 2D point coordinates with shape (N, 2)
confidences: per-view confidences with shape (N,), in the range [0.0, 1.0]
Returns:
point_3d: the triangulated 3D point with shape (3,)
"""
assert len(proj_matrices) == len(points)
N = len(proj_matrices)
confi: Float[Array, "N"]
if confidences is None:
confi = jnp.ones(N, dtype=np.float32)
else:
# Use square root of confidences for weighting - more balanced approach
confi = jnp.sqrt(jnp.clip(confidences, 0, 1))
A = jnp.zeros((N * 2, 4), dtype=np.float32)
for i in range(N):
x, y = points[i]
A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])
A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])
A = A.at[2 * i].mul(confi[i])
A = A.at[2 * i + 1].mul(confi[i])
# https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.svd.html
_, _, vh = jnp.linalg.svd(A, full_matrices=False)
point_3d_homo = vh[-1] # shape (4,)
# replace the Python `if` with a jnp.where
point_3d_homo = jnp.where(
point_3d_homo[3] < 0, # predicate (scalar bool tracer)
-point_3d_homo, # if True
point_3d_homo, # if False
)
point_3d = point_3d_homo[:3] / point_3d_homo[3]
return point_3d
@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear(
proj_matrices: Float[Array, "N 3 4"],
points: Num[Array, "N P 2"],
confidences: Optional[Float[Array, "N P"]] = None,
) -> Float[Array, "P 3"]:
"""
Batch-triangulate P points observed by N cameras, linearly via SVD.
Args:
proj_matrices: (N, 3, 4) projection matrices
points: (N, P, 2) image coordinates per view
confidences: (N, P) optional per-view confidences in [0, 1]
Returns:
(P, 3) 3D point for each of the P tracks
"""
N, P, _ = points.shape
assert proj_matrices.shape[0] == N
if confidences is None:
conf = jnp.ones((N, P), dtype=jnp.float32)
else:
conf = jnp.sqrt(jnp.clip(confidences, 0.0, 1.0))
# vectorize your one-point routine over P
vmap_triangulate = jax.vmap(
triangulate_one_point_from_multiple_views_linear,
in_axes=(None, 1, 1), # proj_matrices static, map over points[:,p,:], conf[:,p]
out_axes=0,
)
return vmap_triangulate(proj_matrices, points, conf)
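# %%
# A small synthetic check (illustrative only, with made-up intrinsics and
# extrinsics): build two hypothetical cameras with `to_projection_matrix`,
# project a known 3D point into both views, and confirm the linear
# triangulation recovers it.
_K_demo = jnp.array([[1000.0, 0.0, 640.0], [0.0, 1000.0, 360.0], [0.0, 0.0, 1.0]])
_Rt_a = jnp.eye(4)  # camera A at the world origin
_Rt_b = jnp.eye(4).at[0, 3].set(-1.0)  # camera B translated along x
_P_a = to_projection_matrix(_Rt_a, _K_demo)
_P_b = to_projection_matrix(_Rt_b, _K_demo)
_X_gt = jnp.array([0.2, -0.1, 5.0])  # ground-truth 3D point in front of both cameras
def _project_demo(P, X):
    x = P @ jnp.append(X, 1.0)
    return x[:2] / x[2]
_pts_demo = jnp.stack([_project_demo(_P_a, _X_gt), _project_demo(_P_b, _X_gt)])[:, None, :]  # (N=2, P=1, 2)
_X_hat = triangulate_points_from_multiple_views_linear(jnp.stack([_P_a, _P_b]), _pts_demo)
display(_X_hat)  # expected to be close to [0.2, -0.1, 5.0]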
# %%
@jaxtyped(typechecker=beartype)
def triangle_from_cluster(
cluster: Sequence[Detection],
) -> tuple[Float[Array, "N 3"], datetime]:
proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])
points = jnp.array([el.keypoints_undistorted for el in cluster])
confidences = jnp.array([el.confidences for el in cluster])
latest_timestamp = max(el.timestamp for el in cluster)
return (
triangulate_points_from_multiple_views_linear(
proj_matrices, points, confidences=confidences
),
latest_timestamp,
)
# %%
class GlobalTrackingState:
_last_id: int
_trackings: dict[int, Tracking]
def __init__(self):
self._last_id = 0
self._trackings = {}
def __repr__(self) -> str:
return (
f"GlobalTrackingState(last_id={self._last_id}, trackings={self._trackings})"
)
@property
def trackings(self) -> dict[int, Tracking]:
return shallow_copy(self._trackings)
def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
kps_3d, latest_timestamp = triangle_from_cluster(cluster)
next_id = self._last_id + 1
tracking = Tracking(
id=next_id, keypoints=kps_3d, last_active_timestamp=latest_timestamp
)
self._trackings[next_id] = tracking
self._last_id = next_id
return tracking
global_tracking_state = GlobalTrackingState()
for cluster in clusters_detections:
global_tracking_state.add_tracking(cluster)
display(global_tracking_state)
# %%
next_group = next(sync_gen)
display(next_group)
# %%
@jaxtyped(typechecker=beartype)
def calculate_distance_2d(
left: Num[Array, "J 2"],
right: Num[Array, "J 2"],
image_size: tuple[int, int] = (1, 1),
) -> Float[Array, "J"]:
"""
Calculate the *normalized* distance between two sets of keypoints.
Args:
left: The left keypoints
right: The right keypoints
image_size: The size of the image
Returns:
Array of normalized Euclidean distances between corresponding keypoints
"""
w, h = image_size
if w == 1 and h == 1:
# already normalized
left_normalized = left
right_normalized = right
else:
left_normalized = left / jnp.array([w, h])
right_normalized = right / jnp.array([w, h])
dist = jnp.linalg.norm(left_normalized - right_normalized, axis=-1)
lt = left_normalized[:6]
rt = right_normalized[:6]
jax.debug.print(
"[REF]{} norm_trk first6 = {}",
_global_current_tracking_str(),
lt,
)
jax.debug.print(
"[REF]{} norm_det first6 = {}",
_global_current_tracking_str(),
rt,
)
jax.debug.print(
"[REF]{} dist2d first6 = {}",
_global_current_tracking_str(),
dist[:6],
)
return dist
@jaxtyped(typechecker=beartype)
def calculate_affinity_2d(
distance_2d: Float[Array, "J"],
delta_t: timedelta,
w_2d: float,
alpha_2d: float,
lambda_a: float,
) -> Float[Array, "J"]:
"""
Calculate per-keypoint affinities between two detections based on the distances between their keypoints.
Each keypoint's affinity is:
A_2D = w_2D * (1 - distance_2D / (alpha_2D * delta_t)) * exp(-lambda_a * delta_t)
Args:
distance_2d: The normalized distances between keypoints (array with one value per keypoint)
w_2d: The weight for 2D affinity
alpha_2d: The normalization factor for distance
lambda_a: The decay rate for time difference
delta_t: The time delta between the two detections (a `timedelta`)
Returns:
Per-keypoint affinity scores (shape (J,)); the caller sums them
"""
delta_t_s = delta_t.total_seconds()
affinity_per_keypoint = (
w_2d
* (1 - distance_2d / (alpha_2d * delta_t_s))
* jnp.exp(-lambda_a * delta_t_s)
)
return affinity_per_keypoint
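# Tiny worked example with assumed parameter values (w_2d=1, alpha_2d=1,
# lambda_a=0.1, delta_t=1/24 s): a normalized distance of 0 gives the maximum
# per-keypoint affinity exp(-0.1/24) ≈ 0.9958, while larger distances reduce it
# and can even make it negative.
display(
    calculate_affinity_2d(
        jnp.array([0.0, 0.01, 0.05]),
        timedelta(seconds=1 / 24),
        w_2d=1.0,
        alpha_2d=1.0,
        lambda_a=0.1,
    )
)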
@jaxtyped(typechecker=beartype)
def perpendicular_distance_point_to_line_two_points(
point: Num[Array, "3"], line: tuple[Num[Array, "3"], Num[Array, "3"]]
) -> Float[Array, ""]:
"""
Calculate the perpendicular distance between a point and a line.
where `line` is represented by two points: `(line_start, line_end)`
Args:
point: The point to calculate the distance to
line: The line to calculate the distance to, represented by two points
Returns:
The perpendicular distance between the point and the line
(should be a scalar in `float`)
"""
line_start, line_end = line
distance = jnp.linalg.norm(
jnp.cross(line_end - line_start, line_start - point)
) / jnp.linalg.norm(line_end - line_start)
return distance
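# Quick sanity check (illustrative): the point (0, 1, 0) lies at distance 1
# from the line through the origin along the x-axis.
display(
    perpendicular_distance_point_to_line_two_points(
        jnp.array([0.0, 1.0, 0.0]),
        (jnp.array([0.0, 0.0, 0.0]), jnp.array([1.0, 0.0, 0.0])),
    )
)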
@jaxtyped(typechecker=beartype)
def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
detection: Detection,
tracking: Tracking,
delta_t: timedelta,
) -> Float[Array, "J"]:
"""
NOTE: `delta_t` is now taken from the caller and NOT recomputed internally.
Calculate the perpendicular distances between predicted 3D tracking points
and the rays cast from camera center through the 2D image points.
Args:
detection: The detection object containing 2D keypoints and camera parameters
tracking: The tracking object containing 3D keypoints
delta_t: Time delta between the tracking's last update and current observation
Returns:
Array of perpendicular distances for each keypoint
"""
camera = detection.camera
# Use the delta_t supplied by the caller, but clamp to DELTA_T_MIN to
# avoid division-by-zero / exploding affinities.
delta_t = max(delta_t, DELTA_T_MIN)
delta_t_s = delta_t.total_seconds()
predicted_pose = tracking.predict(delta_t_s)
# Back-project the 2D points to 3D space
# intersection with z=0 plane
back_projected_points = detection.camera.unproject_points_to_z_plane(
detection.keypoints, z=0.0
)
camera_center = camera.params.location
def calc_distance(predicted_point, back_projected_point):
return perpendicular_distance_point_to_line_two_points(
predicted_point, (camera_center, back_projected_point)
)
# Vectorize over all keypoints
vmap_calc_distance = jax.vmap(calc_distance)
distances: Float[Array, "J"] = vmap_calc_distance(
predicted_pose, back_projected_points
)
return distances
@jaxtyped(typechecker=beartype)
def calculate_affinity_3d(
distances: Float[Array, "J"],
delta_t: timedelta,
w_3d: float,
alpha_3d: float,
lambda_a: float,
) -> Float[Array, "J"]:
"""
Calculate per-keypoint 3D affinity scores between a tracking and a detection.
Each keypoint's affinity is:
A_3D = w_3D * (1 - dl / alpha_3D) * exp(-lambda_a * delta_t)
Args:
distances: Array of perpendicular distances for each keypoint
delta_t: Time difference between tracking and detection
w_3d: Weight for 3D affinity
alpha_3d: Normalization factor for distance
lambda_a: Decay rate for time difference
Returns:
Per-keypoint affinity scores (shape (J,)); the caller sums them
"""
delta_t_s = delta_t.total_seconds()
affinity_per_keypoint = (
w_3d * (1 - distances / alpha_3d) * jnp.exp(-lambda_a * delta_t_s)
)
return affinity_per_keypoint
@beartype
def calculate_tracking_detection_affinity(
tracking: Tracking,
detection: Detection,
w_2d: float,
alpha_2d: float,
w_3d: float,
alpha_3d: float,
lambda_a: float,
) -> float:
"""
Calculate the affinity between a tracking and a detection.
Args:
tracking: The tracking object
detection: The detection object
w_2d: Weight for 2D affinity
alpha_2d: Normalization factor for 2D distance
w_3d: Weight for 3D affinity
alpha_3d: Normalization factor for 3D distance
lambda_a: Decay rate for time difference
Returns:
Combined affinity score
"""
camera = detection.camera
delta_t_raw = detection.timestamp - tracking.last_active_timestamp
# Clamp delta_t to avoid division-by-zero / exploding affinity.
delta_t = max(delta_t_raw, DELTA_T_MIN)
# Calculate 2D affinity
tracking_2d_projection = camera.project(tracking.keypoints)
w, h = camera.params.image_size
distance_2d = calculate_distance_2d(
tracking_2d_projection,
detection.keypoints,
image_size=(int(w), int(h)),
)
affinity_2d = calculate_affinity_2d(
distance_2d,
delta_t,
w_2d=w_2d,
alpha_2d=alpha_2d,
lambda_a=lambda_a,
)
# Calculate 3D affinity
distances = perpendicular_distance_camera_2d_points_to_tracking_raycasting(
detection, tracking, delta_t
)
affinity_3d = calculate_affinity_3d(
distances,
delta_t,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
jax.debug.print(
"[REF] aff2d{} first6 = {}",
_global_current_tracking_str(),
affinity_2d[:6],
)
jax.debug.print(
"[REF] aff3d{} first6 = {}", _global_current_tracking_str(), affinity_3d[:6]
)
jax.debug.print(
"[REF] aff2d.shape={}; aff3d.shape={}",
affinity_2d.shape,
affinity_3d.shape,
)
# Combine affinities
total_affinity = affinity_2d + affinity_3d
return jnp.sum(total_affinity).item()
# %%
@deprecated(
"Use `calculate_camera_affinity_matrix` instead. This implementation has the problem of under-utilizing views from different cameras."
)
@beartype
def calculate_affinity_matrix(
trackings: Sequence[Tracking],
detections: Sequence[Detection] | OrderedDict[CameraID, list[Detection]],
w_2d: float,
alpha_2d: float,
w_3d: float,
alpha_3d: float,
lambda_a: float,
) -> tuple[Float[Array, "T D"], OrderedDict[CameraID, list[Detection]]]:
"""
Calculate the affinity matrix between a set of trackings and detections.
Args:
trackings: Sequence of tracking objects
detections: Sequence of detection objects
w_2d: Weight for 2D affinity
alpha_2d: Normalization factor for 2D distance
w_3d: Weight for 3D affinity
alpha_3d: Normalization factor for 3D distance
lambda_a: Decay rate for time difference
Returns:
- affinity matrix of shape (T, D) where T is number of trackings and D
is number of detections
- dictionary mapping camera IDs to lists of detections from that camera,
since it is an `OrderedDict` you can flatten it to recover the indices of
detections in the affinity matrix
Matrix Layout:
The affinity matrix has shape (T, D), where:
- T = number of trackings (rows)
- D = total number of detections across all cameras (columns)
The matrix is organized as follows:
```
| Camera 1 | Camera 2 | Camera c |
| d1 d2 ... | d1 d2 ... | d1 d2 ... |
---------+-------------+-------------+-------------+
Track 1 | a11 a12 ... | a11 a12 ... | a11 a12 ... |
Track 2 | a21 a22 ... | a21 a22 ... | a21 a22 ... |
... | ... | ... | ... |
Track t | at1 at2 ... | at1 at2 ... | at1 at2 ... |
```
Where:
- Rows are ordered by tracking ID
- Columns are ordered by camera, then by detection within each camera
- Each cell aij represents the affinity between tracking i and detection j
The detection ordering in columns follows the exact same order as iterating
through the detection_by_camera dictionary, which is returned alongside
the matrix to maintain this relationship.
"""
if isinstance(detections, OrderedDict):
D = flatten_values_len(detections)
affinity = jnp.zeros((len(trackings), D))
detection_by_camera = detections
else:
affinity = jnp.zeros((len(trackings), len(detections)))
detection_by_camera = classify_by_camera(detections)
for i, tracking in enumerate(trackings):
j = 0
for _, camera_detections in detection_by_camera.items():
for det in camera_detections:
affinity_value = calculate_tracking_detection_affinity(
tracking,
det,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
affinity = affinity.at[i, j].set(affinity_value)
j += 1
return affinity, detection_by_camera
@beartype
def calculate_camera_affinity_matrix(
trackings: Sequence[Tracking],
camera_detections: Sequence[Detection],
w_2d: float,
alpha_2d: float,
w_3d: float,
alpha_3d: float,
lambda_a: float,
) -> Float[Array, "T D"]:
"""
Calculate an affinity matrix between trackings and detections from a single camera.
This follows the iterative camera-by-camera approach from the paper
"Cross-View Tracking for Multi-Human 3D Pose Estimation at over 100 FPS".
Instead of creating one large matrix for all cameras, this creates
a separate matrix for each camera, which can be processed independently.
Args:
trackings: Sequence of tracking objects
camera_detections: Sequence of detection objects, from the same camera
w_2d: Weight for 2D affinity
alpha_2d: Normalization factor for 2D distance
w_3d: Weight for 3D affinity
alpha_3d: Normalization factor for 3D distance
lambda_a: Decay rate for time difference
Returns:
Affinity matrix of shape (T, D) where:
- T = number of trackings (rows)
- D = number of detections from this specific camera (columns)
Matrix Layout:
The affinity matrix for a single camera has shape (T, D), where:
- T = number of trackings (rows)
- D = number of detections from this camera (columns)
The matrix is organized as follows:
```
| Detections from Camera c |
| d1 d2 d3 ... |
---------+------------------------+
Track 1 | a11 a12 a13 ... |
Track 2 | a21 a22 a23 ... |
... | ... ... ... ... |
Track t | at1 at2 at3 ... |
```
Each cell aij represents the affinity between tracking i and detection j,
computed using both 2D and 3D geometric correspondences.
"""
def verify_all_detection_from_same_camera(detections: Sequence[Detection]):
if not detections:
return True
camera_id = next(iter(detections)).camera.id
return all(map(lambda d: d.camera.id == camera_id, detections))
if not verify_all_detection_from_same_camera(camera_detections):
raise ValueError("All detections must be from the same camera")
affinity = jnp.zeros((len(trackings), len(camera_detections)))
for i, tracking in enumerate(trackings):
for j, det in enumerate(camera_detections):
global _DEBUG_CURRENT_TRACKING
_DEBUG_CURRENT_TRACKING = (i, j)
affinity_value = calculate_tracking_detection_affinity(
tracking,
det,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
affinity = affinity.at[i, j].set(affinity_value)
return affinity
@beartype
def calculate_camera_affinity_matrix_jax(
trackings: Sequence[Tracking],
camera_detections: Sequence[Detection],
w_2d: float,
alpha_2d: float,
w_3d: float,
alpha_3d: float,
lambda_a: float,
) -> Float[Array, "T D"]:
"""
Vectorized implementation to compute an affinity matrix between *trackings*
and *detections* coming from **one** camera.
Compared with the simple double-for-loop version, this leverages `jax`'s
broadcasting + `vmap` facilities and avoids Python loops over every
(tracking, detection) pair. The mathematical definition of the affinity
is **unchanged**, so the result should match the reference implementation
used in the tests up to floating-point error.
TODO: this still gives a wrong result somehow (possibly my own mistake),
and I need to find a way to debug it.
"""
# ------------------------------------------------------------------
# Quick validations / early-exit guards
# ------------------------------------------------------------------
if len(trackings) == 0 or len(camera_detections) == 0:
# Return an empty affinity matrix with appropriate shape.
return jnp.zeros((len(trackings), len(camera_detections))) # type: ignore[return-value]
cam = next(iter(camera_detections)).camera
# Ensure every detection truly belongs to the same camera (guard clause)
cam_id = cam.id
if any(det.camera.id != cam_id for det in camera_detections):
raise ValueError(
"All detections passed to `calculate_camera_affinity_matrix` must come from one camera."
)
# We will rely on a single `Camera` instance (all detections share it)
w_img_, h_img_ = cam.params.image_size
w_img, h_img = float(w_img_), float(h_img_)
# ------------------------------------------------------------------
# Gather data into ndarray / DeviceArray batches so that we can compute
# everything in a single (or a few) fused kernels.
# ------------------------------------------------------------------
# === Tracking-side tensors ===
kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
[trk.keypoints for trk in trackings]
) # (T, J, 3)
J = kps3d_trk.shape[1]
# === Detection-side tensors ===
kps2d_det: Float[Array, "D J 2"] = jnp.stack(
[det.keypoints for det in camera_detections]
) # (D, J, 2)
# ------------------------------------------------------------------
# Compute Δt matrix shape (T, D)
# ------------------------------------------------------------------
# Epoch timestamps are ~1.7×10⁹; storing them in float32 wipes out
# sub-second detail (one float32 ulp at that magnitude is 128 s). Keep them
# in float64 until after subtraction so we preserve Δt on the order of milliseconds.
# --- timestamps ----------
tracking0 = next(iter(trackings))
detection0 = next(iter(camera_detections))
t0 = min(
tracking0.last_active_timestamp, detection0.timestamp
).timestamp() # common origin (float)
ts_trk = jnp.array(
[trk.last_active_timestamp.timestamp() - t0 for trk in trackings],
dtype=jnp.float32, # now small, ms-scale fits in fp32
)
ts_det = jnp.array(
[det.timestamp.timestamp() - t0 for det in camera_detections],
dtype=jnp.float32,
)
# Δt in seconds, fp32 throughout
delta_t = ts_det[None, :] - ts_trk[:, None] # (T,D)
min_dt_s = float(DELTA_T_MIN.total_seconds())
delta_t = jnp.clip(delta_t, a_min=min_dt_s, a_max=None)
# ------------------------------------------------------------------
# ---------- 2D affinity -------------------------------------------
# ------------------------------------------------------------------
# Project each tracking's 3D keypoints onto the image once.
# `Camera.project` works per-sample, so we vmap over the first axis.
proj_fn = jax.vmap(cam.project, in_axes=0) # maps over the keypoint sets
kps2d_trk_proj: Float[Array, "T J 2"] = proj_fn(kps3d_trk) # (T, J, 2)
# Normalise keypoints by image size so absolute units do not bias distance
norm_trk = kps2d_trk_proj / jnp.array([w_img, h_img])
norm_det = kps2d_det / jnp.array([w_img, h_img])
# L2 distance for every (T, D, J)
# reshape for broadcasting: (T,1,J,2) vs (1,D,J,2)
diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1)
jax.debug.print(
"[JAX] norm_trk[0,0,:6] = {}", norm_trk[0, :, :6] # shape (J,2) 取前6
)
jax.debug.print("[JAX] norm_det[0,:6] = {}", norm_det[0, :6]) # shape (J,2)
jax.debug.print("[JAX] dist2d(T0,D0) first6 = {}", dist2d[0, 0, :6])
# Compute per-keypoint 2D affinity
delta_t_broadcast = delta_t[:, :, None] # (T, D, 1)
affinity_2d = (
w_2d
* (1 - dist2d / (alpha_2d * delta_t_broadcast))
* jnp.exp(-lambda_a * delta_t_broadcast)
)
# ------------------------------------------------------------------
# ---------- 3D affinity -------------------------------------------
# ------------------------------------------------------------------
# For each detection pre-compute back-projected 3D points lying on z=0 plane.
backproj_points_list = [
det.camera.unproject_points_to_z_plane(det.keypoints, z=0.0)
for det in camera_detections
] # each (J,3)
backproj: Float[Array, "D J 3"] = jnp.stack(backproj_points_list) # (D, J, 3)
zero_velocity = jnp.zeros((J, 3))
trk_velocities = jnp.stack(
[
trk.velocity if trk.velocity is not None else zero_velocity
for trk in trackings
]
)
predicted_pose: Float[Array, "T D J 3"] = (
kps3d_trk[:, None, :, :] # (T,1,J,3)
+ trk_velocities[:, None, :, :] * delta_t[:, :, None, None] # (T,D,1,1)
)
# Camera center shape (3,) -> will broadcast
cam_center = cam.params.location
# Compute perpendicular distance using vectorized formula
# p1 = cam_center (3,)
# p2 = backproj (D, J, 3)
# P = predicted_pose (T, D, J, 3)
# Broadcast plan: v1 = P - p1 → (T, D, J, 3)
# v2 = p2[None, ...]-p1 → (1, D, J, 3)
# Shapes now line up; no stray singleton axis.
p1 = cam_center
p2 = backproj
P = predicted_pose
v1 = P - p1
v2 = p2[None, :, :, :] - p1 # (1, D, J, 3)
cross = jnp.cross(v1, v2) # (T, D, J, 3)
num = jnp.linalg.norm(cross, axis=-1) # (T, D, J)
den = jnp.linalg.norm(v2, axis=-1) # (1, D, J)
dist3d: Float[Array, "T D J"] = num / den
affinity_3d = (
w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
)
jax.debug.print("[JAX] aff3d(T0,D0) first6 = {}", affinity_3d[0, 0, :6])
jax.debug.print("[JAX] aff2d(T0,D0) first6 = {}", affinity_2d[0, 0, :6])
jax.debug.print(
"[JAX] aff2d.shape={}; aff3d.shape={}", affinity_2d.shape, affinity_3d.shape
)
# ------------------------------------------------------------------
# Combine and reduce across keypoints → (T, D)
# ------------------------------------------------------------------
total_affinity: Float[Array, "T D"] = jnp.sum(affinity_2d + affinity_3d, axis=-1)
return total_affinity # type: ignore[return-value]
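# %%
# Illustration of the float32 precision issue described above (assumed values,
# not from the dataset): near an epoch timestamp of ~1.7e9 s, one float32 ulp
# is 128 s, so a 10 ms difference disappears unless the subtraction is done in
# float64 (or against a common origin) first.
_t_a = 1_700_000_000.000
_t_b = 1_700_000_000.010  # 10 ms later
display(np.float32(_t_b) - np.float32(_t_a))  # 0.0 -- sub-second detail is gone
display(np.float32(np.float64(_t_b) - np.float64(_t_a)))  # ~0.01 -- preserved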
# ------------------------------------------------------------------
# Debug helper compare JAX vs reference implementation
# ------------------------------------------------------------------
@beartype
def debug_compare_affinity_matrices(
trackings: Sequence[Tracking],
camera_detections: Sequence[Detection],
*,
w_2d: float,
alpha_2d: float,
w_3d: float,
alpha_3d: float,
lambda_a: float,
atol: float = 1e-5,
rtol: float = 1e-3,
) -> None:
"""
Compute both affinity matrices and print out the max absolute / relative
difference. If any entry differs more than atol+rtol*|ref|, dump the
offending indices so you can inspect individual terms.
"""
aff_jax = calculate_camera_affinity_matrix_jax(
trackings,
camera_detections,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
aff_ref = calculate_camera_affinity_matrix(
trackings,
camera_detections,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
diff = jnp.abs(aff_jax - aff_ref)
max_abs = float(diff.max())
max_rel = float((diff / (jnp.abs(aff_ref) + 1e-12)).max())
jax.debug.print(f"[DEBUG] max abs diff {max_abs:.6g}, max rel diff {max_rel:.6g}")
bad = jnp.where(diff > atol + rtol * jnp.abs(aff_ref))
if bad[0].size > 0:
for t, d in zip(*[arr.tolist() for arr in bad]):
jax.debug.print(
f" ↳ mismatch at (T={t}, D={d}): "
f"jax={aff_jax[t,d]:.6g}, ref={aff_ref[t,d]:.6g}"
)
else:
jax.debug.print("✅ matrices match within tolerance")
# %%
# let's do cross-view association
W_2D = 1.0
ALPHA_2D = 1.0
LAMBDA_A = 0.1
W_3D = 1.0
ALPHA_3D = 1.0
trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
unmatched_detections = shallow_copy(next_group)
camera_detections = classify_by_camera(unmatched_detections)
camera_detections_next_batch = camera_detections["AE_08"]
debug_compare_affinity_matrices(
trackings,
camera_detections_next_batch,
w_2d=W_2D,
alpha_2d=ALPHA_2D,
w_3d=W_3D,
alpha_3d=ALPHA_3D,
lambda_a=LAMBDA_A,
)
# %%