forked from HQU-gxy/CVTH3PE

refactor: Update affinity matrix calculation and dependencies

- Replaced the `linear_sum_assignment` import from `scipy.optimize` with `hungarian_algorithm` from `optax.assignment` to enhance performance in affinity matrix calculations.
- Introduced a new `AffinityResult` class to encapsulate results of affinity computations, including trackings and detections, improving the structure of the affinity calculation process.
- Removed deprecated functions and debug print statements to streamline the codebase.
- Updated `pyproject.toml` and `uv.lock` to include `optax` as a dependency, ensuring compatibility with the new implementation.
- Refactored imports and type hints for better organization and consistency across the codebase.
This commit is contained in:
2025-04-29 15:45:24 +08:00
parent ce1d5f3cf7
commit 29ca66ad47
6 changed files with 152 additions and 529 deletions
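The heart of the change is the assignment-solver swap: `scipy.optimize.linear_sum_assignment` is replaced by `optax.assignment.hungarian_algorithm`, which returns the same style of row/column index pair but operates on JAX arrays. A minimal sketch of the equivalence, assuming optax >= 0.2.4 and a made-up cost matrix (nothing below is taken from the repository):

```
# Illustrative sketch, not from the repository: a made-up 3x3 cost matrix.
# Assumes scipy and optax>=0.2.4 are installed.
import jax.numpy as jnp
import numpy as np
from optax.assignment import hungarian_algorithm
from scipy.optimize import linear_sum_assignment

cost = np.array(
    [
        [4.0, 1.0, 3.0],
        [2.0, 0.0, 5.0],
        [3.0, 2.0, 2.0],
    ]
)

# SciPy: minimizes total cost, returns (row indices, column indices).
rows_ref, cols_ref = linear_sum_assignment(cost)

# optax: the same style of result, computed on a JAX array.
rows_jax, cols_jax = hungarian_algorithm(jnp.asarray(cost))

# Both should pick the assignment with total cost 1 + 2 + 2 = 5.
print(cost[rows_ref, cols_ref].sum())
print(cost[np.asarray(rows_jax), np.asarray(cols_jax)].sum())
```

Both solvers minimize total cost and return index arrays, which is why the commit can simply alias the optax function to the old name and leave call sites reading the same; an affinity matrix (higher is better) is typically negated before being handed to such a solver.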

View File

@@ -2,8 +2,10 @@ from dataclasses import dataclass
from datetime import datetime
from typing import (
    Any,
+    Callable,
    Generator,
    Optional,
+    Sequence,
    TypeAlias,
    TypedDict,
    TypeVar,
@@ -14,7 +16,11 @@ from typing import (
import jax
import jax.numpy as jnp
from beartype import beartype
-from jaxtyping import Array, Float, jaxtyped
+from beartype.typing import Mapping, Sequence
+from jax import Array
+from jaxtyping import Array, Float, Int, jaxtyped
+from app.camera import Detection

@jaxtyped(typechecker=beartype)
@@ -67,3 +73,30 @@ class Tracking:
# Step 2 pure JAX math
# ------------------------------------------------------------------
return self.keypoints + velocity * delta_t_s
@jaxtyped(typechecker=beartype)
@dataclass
class AffinityResult:
"""
Result of affinity computation between trackings and detections.
"""
matrix: Float[Array, "T D"]
trackings: Sequence[Tracking]
detections: Sequence[Detection]
indices_T: Int[Array, "T"] # pylint: disable=invalid-name
indices_D: Int[Array, "D"] # pylint: disable=invalid-name
def tracking_detections(
self,
) -> Generator[tuple[float, Tracking, Detection], None, None]:
"""
iterate over the best matching trackings and detections
"""
for t, d in zip(self.indices_T, self.indices_D):
yield (
self.matrix[t, d].item(),
self.trackings[t],
self.detections[d],
)
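For orientation, the `tracking_detections()` generator pairs each row/column chosen by the assignment step with its affinity score, so callers can filter weak matches before applying them. A hypothetical consumer (the name `apply_matches` and its `threshold` parameter are illustrative, not part of the commit):

```
# Sketch only; assumes the repository's own AffinityResult added above.
from app.tracking import AffinityResult


def apply_matches(result: AffinityResult, threshold: float) -> None:
    """Print the assigned pairs whose affinity clears a threshold."""
    for score, tracking, detection in result.tracking_detections():
        if score < threshold:
            # A real pipeline would likely leave low-affinity pairs unmatched.
            continue
        print(
            f"tracking {tracking.id} <- detection from camera "
            f"{detection.camera.id} (affinity {score:.3f})"
        )
```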

View File

@@ -1,37 +0,0 @@
from dataclasses import dataclass
from typing import Sequence, Callable, Generator
from app.camera import Detection
from . import Tracking
from beartype.typing import Sequence, Mapping
from jaxtyping import jaxtyped, Float, Int
from jax import Array
@dataclass
class AffinityResult:
"""
Result of affinity computation between trackings and detections.
"""
matrix: Float[Array, "T D"]
"""
Affinity matrix between trackings and detections.
"""
trackings: Sequence[Tracking]
"""
Trackings used to compute the affinity matrix.
"""
detections: Sequence[Detection]
"""
Detections used to compute the affinity matrix.
"""
indices_T: Sequence[int]
indices_D: Sequence[int]
def tracking_detections(self) -> Generator[tuple[Tracking, Detection]]:
for t, d in zip(self.indices_T, self.indices_D):
yield (self.trackings[t], self.detections[d])

View File

@@ -45,7 +45,7 @@ from IPython.display import display
from jaxtyping import Array, Float, Num, jaxtyped
from matplotlib import pyplot as plt
from numpy.typing import ArrayLike
-from scipy.optimize import linear_sum_assignment
+from optax.assignment import hungarian_algorithm as linear_sum_assignment
from scipy.spatial.transform import Rotation as R
from typing_extensions import deprecated
@@ -58,7 +58,7 @@ from app.camera import (
    classify_by_camera,
)
from app.solver._old import GLPKSolver
-from app.tracking import Tracking
+from app.tracking import AffinityResult, Tracking
from app.visualize.whole_body import visualize_whole_body

NDArray: TypeAlias = np.ndarray
@@ -69,12 +69,6 @@ AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parq
DELTA_T_MIN = timedelta(milliseconds=10)
display(AK_CAMERA_DATASET)
-_DEBUG_CURRENT_TRACKING: tuple[int, int] = (0, 0)
-def _global_current_tracking_str():
-    return str(_DEBUG_CURRENT_TRACKING)

# %%

class Resolution(TypedDict):
@@ -594,23 +588,6 @@ def calculate_distance_2d(
left_normalized = left / jnp.array([w, h])
right_normalized = right / jnp.array([w, h])
dist = jnp.linalg.norm(left_normalized - right_normalized, axis=-1)
lt = left_normalized[:6]
rt = right_normalized[:6]
jax.debug.print(
"[REF]{} norm_trk first6 = {}",
_global_current_tracking_str(),
lt,
)
jax.debug.print(
"[REF]{} norm_det first6 = {}",
_global_current_tracking_str(),
rt,
)
jax.debug.print(
"[REF]{} dist2d first6 = {}",
_global_current_tracking_str(),
dist[:6],
)
return dist
@@ -806,191 +783,12 @@ def calculate_tracking_detection_affinity(
lambda_a=lambda_a,
)
jax.debug.print(
"[REF] aff2d{} first6 = {}",
_global_current_tracking_str(),
affinity_2d[:6],
)
jax.debug.print(
"[REF] aff3d{} first6 = {}", _global_current_tracking_str(), affinity_3d[:6]
)
jax.debug.print(
"[REF] aff2d.shape={}; aff3d.shape={}",
affinity_2d.shape,
affinity_3d.shape,
)
# Combine affinities
total_affinity = affinity_2d + affinity_3d
return jnp.sum(total_affinity).item()

# %%
@deprecated(
"Use `calculate_camera_affinity_matrix` instead. This implementation has the problem of under-utilizing views from different cameras."
)
@beartype
def calculate_affinity_matrix(
trackings: Sequence[Tracking],
detections: Sequence[Detection] | OrderedDict[CameraID, list[Detection]],
w_2d: float,
alpha_2d: float,
w_3d: float,
alpha_3d: float,
lambda_a: float,
) -> tuple[Float[Array, "T D"], OrderedDict[CameraID, list[Detection]]]:
"""
Calculate the affinity matrix between a set of trackings and detections.
Args:
trackings: Sequence of tracking objects
detections: Sequence of detection objects
w_2d: Weight for 2D affinity
alpha_2d: Normalization factor for 2D distance
w_3d: Weight for 3D affinity
alpha_3d: Normalization factor for 3D distance
lambda_a: Decay rate for time difference
Returns:
- affinity matrix of shape (T, D) where T is number of trackings and D
is number of detections
- dictionary mapping camera IDs to lists of detections from that camera,
since it's a `OrderDict` you could flat it out to get the indices of
detections in the affinity matrix
Matrix Layout:
The affinity matrix has shape (T, D), where:
- T = number of trackings (rows)
- D = total number of detections across all cameras (columns)
The matrix is organized as follows:
```
| Camera 1 | Camera 2 | Camera c |
| d1 d2 ... | d1 d2 ... | d1 d2 ... |
---------+-------------+-------------+-------------+
Track 1 | a11 a12 ... | a11 a12 ... | a11 a12 ... |
Track 2 | a21 a22 ... | a21 a22 ... | a21 a22 ... |
... | ... | ... | ... |
Track t | at1 at2 ... | at1 at2 ... | at1 at2 ... |
```
Where:
- Rows are ordered by tracking ID
- Columns are ordered by camera, then by detection within each camera
- Each cell aij represents the affinity between tracking i and detection j
The detection ordering in columns follows the exact same order as iterating
through the detection_by_camera dictionary, which is returned alongside
the matrix to maintain this relationship.
"""
if isinstance(detections, OrderedDict):
D = flatten_values_len(detections)
affinity = jnp.zeros((len(trackings), D))
detection_by_camera = detections
else:
affinity = jnp.zeros((len(trackings), len(detections)))
detection_by_camera = classify_by_camera(detections)
for i, tracking in enumerate(trackings):
j = 0
for _, camera_detections in detection_by_camera.items():
for det in camera_detections:
affinity_value = calculate_tracking_detection_affinity(
tracking,
det,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
affinity = affinity.at[i, j].set(affinity_value)
j += 1
return affinity, detection_by_camera
@beartype
def calculate_camera_affinity_matrix(
trackings: Sequence[Tracking],
camera_detections: Sequence[Detection],
w_2d: float,
alpha_2d: float,
w_3d: float,
alpha_3d: float,
lambda_a: float,
) -> Float[Array, "T D"]:
"""
Calculate an affinity matrix between trackings and detections from a single camera.
This follows the iterative camera-by-camera approach from the paper
"Cross-View Tracking for Multi-Human 3D Pose Estimation at over 100 FPS".
Instead of creating one large matrix for all cameras, this creates
a separate matrix for each camera, which can be processed independently.
Args:
trackings: Sequence of tracking objects
camera_detections: Sequence of detection objects, from the same camera
w_2d: Weight for 2D affinity
alpha_2d: Normalization factor for 2D distance
w_3d: Weight for 3D affinity
alpha_3d: Normalization factor for 3D distance
lambda_a: Decay rate for time difference
Returns:
Affinity matrix of shape (T, D) where:
- T = number of trackings (rows)
- D = number of detections from this specific camera (columns)
Matrix Layout:
The affinity matrix for a single camera has shape (T, D), where:
- T = number of trackings (rows)
- D = number of detections from this camera (columns)
The matrix is organized as follows:
```
| Detections from Camera c |
| d1 d2 d3 ... |
---------+------------------------+
Track 1 | a11 a12 a13 ... |
Track 2 | a21 a22 a23 ... |
... | ... ... ... ... |
Track t | at1 at2 at3 ... |
```
Each cell aij represents the affinity between tracking i and detection j,
computed using both 2D and 3D geometric correspondences.
"""
def verify_all_detection_from_same_camera(detections: Sequence[Detection]):
if not detections:
return True
camera_id = next(iter(detections)).camera.id
return all(map(lambda d: d.camera.id == camera_id, detections))
if not verify_all_detection_from_same_camera(camera_detections):
raise ValueError("All detections must be from the same camera")
affinity = jnp.zeros((len(trackings), len(camera_detections)))
for i, tracking in enumerate(trackings):
for j, det in enumerate(camera_detections):
global _DEBUG_CURRENT_TRACKING
_DEBUG_CURRENT_TRACKING = (i, j)
affinity_value = calculate_tracking_detection_affinity(
tracking,
det,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
affinity = affinity.at[i, j].set(affinity_value)
return affinity
@beartype
def calculate_camera_affinity_matrix_jax(
    trackings: Sequence[Tracking],
@@ -1010,9 +808,6 @@ def calculate_camera_affinity_matrix_jax(
    (tracking, detection) pair. The mathematical definition of the affinity
    is **unchanged**, so the result remains bit-identical to the reference
    implementation used in the tests.
-
-    TODO: It gives a wrong result (maybe it's my problem?) somehow,
-    and I need to find a way to debug this.
    """

    # ------------------------------------------------------------------
@@ -1052,8 +847,8 @@
    # ------------------------------------------------------------------
    # Compute Δt matrix shape (T, D)
    # ------------------------------------------------------------------
-    # Epoch timestamps are ~1.7×10⁹; storing them in float32 wipes out
-    # subsecond detail (resolution ≈ 200ms). Keep them in float64 until
+    # Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out
+    # subsecond detail (resolution ≈ 200 ms). Keep them in float64 until
    # after subtraction so we preserve Δt on the order of milliseconds.
    # --- timestamps ----------
    t0 = min(
@@ -1093,12 +888,6 @@
    diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
    dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1)
-
-    jax.debug.print(
-        "[JAX] norm_trk[0,0,:6] = {}", norm_trk[0, :, :6]  # shape (J, 2), take first 6
-    )
-    jax.debug.print("[JAX] norm_det[0,:6] = {}", norm_det[0, :6])  # shape (J, 2)
-    jax.debug.print("[JAX] dist2d(T0,D0) first6 = {}", dist2d[0, 0, :6])

    # Compute per-keypoint 2D affinity
    delta_t_broadcast = delta_t[:, :, None]  # (T, D, 1)
    affinity_2d = (
@@ -1155,11 +944,6 @@
        w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
    )
-    jax.debug.print("[JAX] aff3d(T0,D0) first6 = {}", affinity_3d[0, 0, :6])
-    jax.debug.print("[JAX] aff2d(T0,D0) first6 = {}", affinity_2d[0, 0, :6])
-    jax.debug.print(
-        "[JAX] aff2d.shape={}; aff3d.shape={}", affinity_2d.shape, affinity_3d.shape
-    )

    # ------------------------------------------------------------------
    # Combine and reduce across keypoints → (T, D)
    # ------------------------------------------------------------------
@@ -1167,60 +951,57 @@
    return total_affinity  # type: ignore[return-value]

-# ------------------------------------------------------------------
-# Debug helper compare JAX vs reference implementation
-# ------------------------------------------------------------------
-@beartype
-def debug_compare_affinity_matrices(
-    trackings: Sequence[Tracking],
-    camera_detections: Sequence[Detection],
-    *,
-    w_2d: float,
-    alpha_2d: float,
-    w_3d: float,
-    alpha_3d: float,
-    lambda_a: float,
-    atol: float = 1e-5,
-    rtol: float = 1e-3,
-) -> None:
-    """
-    Compute both affinity matrices and print out the max absolute / relative
-    difference. If any entry differs more than atol+rtol*|ref|, dump the
-    offending indices so you can inspect individual terms.
-    """
-    aff_jax = calculate_camera_affinity_matrix_jax(
-        trackings,
-        camera_detections,
-        w_2d=w_2d,
-        alpha_2d=alpha_2d,
-        w_3d=w_3d,
-        alpha_3d=alpha_3d,
-        lambda_a=lambda_a,
-    )
-    aff_ref = calculate_camera_affinity_matrix(
-        trackings,
-        camera_detections,
-        w_2d=w_2d,
-        alpha_2d=alpha_2d,
-        w_3d=w_3d,
-        alpha_3d=alpha_3d,
-        lambda_a=lambda_a,
-    )
-    diff = jnp.abs(aff_jax - aff_ref)
-    max_abs = float(diff.max())
-    max_rel = float((diff / (jnp.abs(aff_ref) + 1e-12)).max())
-    jax.debug.print(f"[DEBUG] max abs diff {max_abs:.6g}, max rel diff {max_rel:.6g}")
-    bad = jnp.where(diff > atol + rtol * jnp.abs(aff_ref))
-    if bad[0].size > 0:
-        for t, d in zip(*[arr.tolist() for arr in bad]):
-            jax.debug.print(
-                f"  ↳ mismatch at (T={t}, D={d}): "
-                f"jax={aff_jax[t,d]:.6g}, ref={aff_ref[t,d]:.6g}"
-            )
-    else:
-        jax.debug.print("✅ matrices match within tolerance")
+@beartype
+def calculate_affinity_matrix(
+    trackings: Sequence[Tracking],
+    detections: Sequence[Detection] | Mapping[CameraID, list[Detection]],
+    w_2d: float,
+    alpha_2d: float,
+    w_3d: float,
+    alpha_3d: float,
+    lambda_a: float,
+) -> dict[CameraID, AffinityResult]:
+    """
+    Calculate the affinity matrix between a set of trackings and detections.
+
+    Args:
+        trackings: Sequence of tracking objects
+        detections: Sequence of detection objects, or detections grouped by camera ID
+        w_2d: Weight for 2D affinity
+        alpha_2d: Normalization factor for 2D distance
+        w_3d: Weight for 3D affinity
+        alpha_3d: Normalization factor for 3D distance
+        lambda_a: Decay rate for time difference
+
+    Returns:
+        A dictionary mapping camera IDs to affinity results.
+    """
+    if isinstance(detections, Mapping):
+        detection_by_camera = detections
+    else:
+        detection_by_camera = classify_by_camera(detections)
+    res: dict[CameraID, AffinityResult] = {}
+    for camera_id, camera_detections in detection_by_camera.items():
+        affinity_matrix = calculate_camera_affinity_matrix_jax(
+            trackings,
+            camera_detections,
+            w_2d,
+            alpha_2d,
+            w_3d,
+            alpha_3d,
+            lambda_a,
+        )
+        # row, col
+        indices_T, indices_D = linear_sum_assignment(affinity_matrix)
+        affinity_result = AffinityResult(
+            matrix=affinity_matrix,
+            trackings=trackings,
+            detections=camera_detections,
+            indices_T=indices_T,
+            indices_D=indices_D,
+        )
+        res[camera_id] = affinity_result
+    return res

# %%
@@ -1235,15 +1016,15 @@ trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
unmatched_detections = shallow_copy(next_group)
camera_detections = classify_by_camera(unmatched_detections)

-camera_detections_next_batch = camera_detections["AE_08"]
-debug_compare_affinity_matrices(
+affinities = calculate_affinity_matrix(
    trackings,
-    camera_detections_next_batch,
+    unmatched_detections,
    w_2d=W_2D,
    alpha_2d=ALPHA_2D,
    w_3d=W_3D,
    alpha_3d=ALPHA_3D,
    lambda_a=LAMBDA_A,
)
+display(affinities)

# %%
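With this change the cell receives a dictionary keyed by camera ID instead of a single matrix. A rough sketch of flattening those per-camera results for inspection (the helper `summarize_matches` is illustrative, not part of the commit):

```
# Sketch only: turn the per-camera AffinityResult objects into plain rows.
def summarize_matches(affinities) -> list[tuple]:
    rows = []
    for camera_id, result in affinities.items():
        # Each result pairs this camera's detections with the global trackings.
        for score, tracking, _detection in result.tracking_detections():
            rows.append((camera_id, tracking.id, score))
    return rows


# In the notebook, this could follow the cell above:
# display(summarize_matches(affinities))
```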

View File

@@ -15,6 +15,7 @@ dependencies = [
    "jupytext>=1.17.0",
    "matplotlib>=3.10.1",
    "opencv-python-headless>=4.11.0.86",
+    "optax>=0.2.4",
    "orjson>=3.10.15",
    "pandas>=2.2.3",
    "plotly>=6.0.1",

View File

@@ -1,224 +0,0 @@
from datetime import datetime, timedelta
import time
import jax.numpy as jnp
import numpy as np
import pytest
from hypothesis import given, settings, HealthCheck
from hypothesis import strategies as st
from app.camera import Camera, CameraParams
from playground import (
Detection,
Tracking,
calculate_affinity_matrix,
calculate_camera_affinity_matrix,
)
# ----------------------------------------------------------------------------
# Helper functions to generate synthetic cameras / trackings / detections
# ----------------------------------------------------------------------------
def _make_dummy_camera(cam_id: str, rng: np.random.Generator) -> Camera:
K = jnp.eye(3)
Rt = jnp.eye(4)
dist = jnp.zeros(5)
image_size = jnp.array([1000, 1000])
params = CameraParams(K=K, Rt=Rt, dist_coeffs=dist, image_size=image_size)
return Camera(id=cam_id, params=params)
def _random_keypoints_3d(rng: np.random.Generator, J: int):
return jnp.asarray(rng.uniform(-1.0, 1.0, size=(J, 3)).astype(np.float32))
def _random_keypoints_2d(rng: np.random.Generator, J: int):
return jnp.asarray(rng.uniform(0.0, 1000.0, size=(J, 2)).astype(np.float32))
def _make_trackings(rng: np.random.Generator, camera: Camera, T: int, J: int):
now = datetime.now()
trackings = []
for i in range(T):
kps3d = _random_keypoints_3d(rng, J)
trk = Tracking(
id=i + 1,
keypoints=kps3d,
last_active_timestamp=now
- timedelta(milliseconds=int(rng.integers(20, 50))),
)
trackings.append(trk)
return trackings
def _make_detections(rng: np.random.Generator, camera: Camera, D: int, J: int):
now = datetime.now()
detections = []
for _ in range(D):
kps2d = _random_keypoints_2d(rng, J)
det = Detection(
keypoints=kps2d,
confidences=jnp.ones(J, dtype=jnp.float32),
camera=camera,
timestamp=now,
)
detections.append(det)
return detections
# ----------------------------------------------------------------------------
# Property-based test: per-camera vs naive slice should match
# ----------------------------------------------------------------------------
@settings(max_examples=3, deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(
T=st.integers(min_value=1, max_value=4),
D=st.integers(min_value=1, max_value=4),
J=st.integers(min_value=5, max_value=15),
seed=st.integers(min_value=0, max_value=10000),
)
def test_per_camera_matches_naive(T, D, J, seed):
rng = np.random.default_rng(seed)
cam = _make_dummy_camera("C0", rng)
trackings = _make_trackings(rng, cam, T, J)
detections = _make_detections(rng, cam, D, J)
# Parameters
W_2D = 1.0
ALPHA_2D = 1.0
LAMBDA_A = 0.1
W_3D = 1.0
ALPHA_3D = 1.0
# Compute per-camera affinity (fast)
A_fast = calculate_camera_affinity_matrix(
trackings,
detections,
w_2d=W_2D,
alpha_2d=ALPHA_2D,
w_3d=W_3D,
alpha_3d=ALPHA_3D,
lambda_a=LAMBDA_A,
)
# Compute naive multi-camera affinity and slice out this camera
from collections import OrderedDict
det_dict = OrderedDict({"C0": detections})
A_naive, _ = calculate_affinity_matrix(
trackings,
det_dict,
w_2d=W_2D,
alpha_2d=ALPHA_2D,
w_3d=W_3D,
alpha_3d=ALPHA_3D,
lambda_a=LAMBDA_A,
)
# both fast and naive implementation gives NaN
# we need to inject real-world data
# print("A_fast")
# print(A_fast)
# print("A_naive")
# print(A_naive)
# They should be close
np.testing.assert_allclose(A_fast, np.asarray(A_naive), rtol=1e-5, atol=1e-5)
@pytest.mark.parametrize("T,D,J", [(2, 3, 10), (4, 4, 15), (6, 8, 20)])
def test_benchmark_affinity_matrix(T, D, J):
"""Compare performance between naive and fast affinity matrix calculation."""
seed = 42
rng = np.random.default_rng(seed)
cam = _make_dummy_camera("C0", rng)
trackings = _make_trackings(rng, cam, T, J)
detections = _make_detections(rng, cam, D, J)
# Parameters
w_2d = 1.0
alpha_2d = 1.0
w_3d = 1.0
alpha_3d = 1.0
lambda_a = 0.1
# Setup for naive
from collections import OrderedDict
det_dict = OrderedDict({"C0": detections})
# First run to compile
A_fast = calculate_camera_affinity_matrix(
trackings,
detections,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
A_naive, _ = calculate_affinity_matrix(
trackings,
det_dict,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
# Assert they match before timing
np.testing.assert_allclose(A_fast, np.asarray(A_naive), rtol=1e-5, atol=1e-5)
# Timing
num_runs = 3
# Time the vectorized version
start = time.perf_counter()
for _ in range(num_runs):
calculate_camera_affinity_matrix(
trackings,
detections,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
end = time.perf_counter()
vectorized_time = (end - start) / num_runs
# Time the naive version
start = time.perf_counter()
for _ in range(num_runs):
calculate_affinity_matrix(
trackings,
det_dict,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
end = time.perf_counter()
naive_time = (end - start) / num_runs
speedup = naive_time / vectorized_time
print(f"\nBenchmark T={T}, D={D}, J={J}:")
print(f" Vectorized: {vectorized_time*1000:.2f}ms per run")
print(f" Naive: {naive_time*1000:.2f}ms per run")
print(f" Speedup: {speedup:.2f}x")
# Sanity check - vectorized should be faster!
assert speedup > 1.0, "Vectorized implementation should be faster"
if __name__ == "__main__" and pytest is not None:
# python -m pytest -xvs -k test_benchmark
# pytest.main([__file__])
pytest.main(["-xvs", __file__, "-k", "test_benchmark"])

uv.lock generated
View File

@@ -16,6 +16,15 @@ resolution-markers = [
    "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
]
[[package]]
name = "absl-py"
version = "2.2.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b5/f0/e6342091061ed3a46aadc116b13edd7bb5249c3ab1b3ef07f24b0c248fc3/absl_py-2.2.2.tar.gz", hash = "sha256:bf25b2c2eed013ca456918c453d687eab4e8309fba81ee2f4c1a6aa2494175eb", size = 119982 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f6/d4/349f7f4bd5ea92dab34f5bb0fe31775ef6c311427a14d5a5b31ecb442341/absl_py-2.2.2-py3-none-any.whl", hash = "sha256:e5797bc6abe45f64fd95dc06394ca3f2bedf3b5d895e9da691c9ee3397d70092", size = 135565 },
]
[[package]]
name = "anyio"
version = "4.8.0"
@@ -354,6 +363,24 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767 },
]
[[package]]
name = "chex"
version = "0.1.89"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "absl-py" },
{ name = "jax" },
{ name = "jaxlib" },
{ name = "numpy" },
{ name = "setuptools", marker = "python_full_version >= '3.12'" },
{ name = "toolz" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ca/ac/504a8019f7ef372fc6cc3999ec9e3d0fbb38e6992f55d845d5b928010c11/chex-0.1.89.tar.gz", hash = "sha256:78f856e6a0a8459edfcbb402c2c044d2b8102eac4b633838cbdfdcdb09c6c8e0", size = 90676 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/5e/6c/309972937d931069816dc8b28193a650485bc35cca92c04c8c15c4bd181e/chex-0.1.89-py3-none-any.whl", hash = "sha256:145241c27d8944adb634fb7d472a460e1c1b643f561507d4031ad5156ef82dfa", size = 99908 },
]
[[package]]
name = "colorama"
version = "0.4.6"
@@ -454,6 +481,7 @@ dependencies = [
    { name = "jupytext" },
    { name = "matplotlib" },
    { name = "opencv-python-headless" },
+    { name = "optax" },
    { name = "orjson" },
    { name = "pandas" },
    { name = "plotly" },
@@ -482,6 +510,7 @@ requires-dist = [
    { name = "jupytext", specifier = ">=1.17.0" },
    { name = "matplotlib", specifier = ">=3.10.1" },
    { name = "opencv-python-headless", specifier = ">=4.11.0.86" },
+    { name = "optax", specifier = ">=0.2.4" },
    { name = "orjson", specifier = ">=3.10.15" },
    { name = "pandas", specifier = ">=2.2.3" },
    { name = "plotly", specifier = ">=6.0.1" },
@@ -583,6 +612,20 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604 },
]
[[package]]
name = "etils"
version = "1.12.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/e4/12/1cc11e88a0201280ff389bc4076df7c3432e39d9f22cba8b71aa263f67b8/etils-1.12.2.tar.gz", hash = "sha256:c6b9e1f0ce66d1bbf54f99201b08a60ba396d3446d9eb18d4bc39b26a2e1a5ee", size = 104711 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/dd/71/40ee142e564b8a34a7ae9546e99e665e0001011a3254d5bbbe113d72ccba/etils-1.12.2-py3-none-any.whl", hash = "sha256:4600bec9de6cf5cb043a171e1856e38b5f273719cf3ecef90199f7091a6b3912", size = 167613 },
]
[package.optional-dependencies]
epy = [
{ name = "typing-extensions" },
]
[[package]]
name = "exceptiongroup"
version = "1.2.2"
@@ -1925,6 +1968,23 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932 },
]
[[package]]
name = "optax"
version = "0.2.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "absl-py" },
{ name = "chex" },
{ name = "etils", extra = ["epy"] },
{ name = "jax" },
{ name = "jaxlib" },
{ name = "numpy" },
]
sdist = { url = "https://files.pythonhosted.org/packages/af/b5/f88a0d851547b2e6b2c7e7e6509ad66236b3e7019f1f095bb03dbaa61fa1/optax-0.2.4.tar.gz", hash = "sha256:4e05d3d5307e6dde4c319187ae36e6cd3a0c035d4ed25e9e992449a304f47336", size = 229717 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/5c/24/28d0bb21600a78e46754947333ec9a297044af884d360092eb8561575fe9/optax-0.2.4-py3-none-any.whl", hash = "sha256:db35c04e50b52596662efb002334de08c2a0a74971e4da33f467e84fac08886a", size = 319212 },
]
[[package]]
name = "orjson"
version = "3.10.15"
@@ -2834,6 +2894,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257 },
]
[[package]]
name = "toolz"
version = "1.0.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/8a/0b/d80dfa675bf592f636d1ea0b835eab4ec8df6e9415d8cfd766df54456123/toolz-1.0.0.tar.gz", hash = "sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02", size = 66790 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/03/98/eb27cc78ad3af8e302c9d8ff4977f5026676e130d28dd7578132a457170c/toolz-1.0.0-py3-none-any.whl", hash = "sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236", size = 56383 },
]
[[package]]
name = "torch"
version = "2.6.0"