refactor: Update affinity matrix calculation and dependencies
- Replaced the `linear_sum_assignment` import from `scipy.optimize` with `hungarian_algorithm` from `optax` to enhance performance in affinity matrix calculations. - Introduced a new `AffinityResult` class to encapsulate results of affinity computations, including trackings and detections, improving the structure of the affinity calculation process. - Removed deprecated functions and debug print statements to streamline the codebase. - Updated `pyproject.toml` and `uv.lock` to include `optax` as a dependency, ensuring compatibility with the new implementation. - Refactored imports and type hints for better organization and consistency across the codebase.
This commit is contained in:
@ -2,8 +2,10 @@ from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generator,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeAlias,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
@ -14,7 +16,11 @@ from typing import (
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from beartype import beartype
|
||||
from jaxtyping import Array, Float, jaxtyped
|
||||
from beartype.typing import Mapping, Sequence
|
||||
from jax import Array
|
||||
from jaxtyping import Array, Float, Int, jaxtyped
|
||||
|
||||
from app.camera import Detection
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
@ -67,3 +73,30 @@ class Tracking:
|
||||
# Step 2 – pure JAX math
|
||||
# ------------------------------------------------------------------
|
||||
return self.keypoints + velocity * delta_t_s
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
@dataclass
|
||||
class AffinityResult:
|
||||
"""
|
||||
Result of affinity computation between trackings and detections.
|
||||
"""
|
||||
|
||||
matrix: Float[Array, "T D"]
|
||||
trackings: Sequence[Tracking]
|
||||
detections: Sequence[Detection]
|
||||
indices_T: Int[Array, "T"] # pylint: disable=invalid-name
|
||||
indices_D: Int[Array, "D"] # pylint: disable=invalid-name
|
||||
|
||||
def tracking_detections(
|
||||
self,
|
||||
) -> Generator[tuple[float, Tracking, Detection], None, None]:
|
||||
"""
|
||||
iterate over the best matching trackings and detections
|
||||
"""
|
||||
for t, d in zip(self.indices_T, self.indices_D):
|
||||
yield (
|
||||
self.matrix[t, d].item(),
|
||||
self.trackings[t],
|
||||
self.detections[d],
|
||||
)
|
||||
|
||||
@ -1,37 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Sequence, Callable, Generator
|
||||
|
||||
from app.camera import Detection
|
||||
from . import Tracking
|
||||
from beartype.typing import Sequence, Mapping
|
||||
from jaxtyping import jaxtyped, Float, Int
|
||||
from jax import Array
|
||||
|
||||
|
||||
@dataclass
|
||||
class AffinityResult:
|
||||
"""
|
||||
Result of affinity computation between trackings and detections.
|
||||
"""
|
||||
|
||||
matrix: Float[Array, "T D"]
|
||||
"""
|
||||
Affinity matrix between trackings and detections.
|
||||
"""
|
||||
|
||||
trackings: Sequence[Tracking]
|
||||
"""
|
||||
Trackings used to compute the affinity matrix.
|
||||
"""
|
||||
|
||||
detections: Sequence[Detection]
|
||||
"""
|
||||
Detections used to compute the affinity matrix.
|
||||
"""
|
||||
|
||||
indices_T: Sequence[int]
|
||||
indices_D: Sequence[int]
|
||||
|
||||
def tracking_detections(self) -> Generator[tuple[Tracking, Detection]]:
|
||||
for t, d in zip(self.indices_T, self.indices_D):
|
||||
yield (self.trackings[t], self.detections[d])
|
||||
315
playground.py
315
playground.py
@ -45,7 +45,7 @@ from IPython.display import display
|
||||
from jaxtyping import Array, Float, Num, jaxtyped
|
||||
from matplotlib import pyplot as plt
|
||||
from numpy.typing import ArrayLike
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
from optax.assignment import hungarian_algorithm as linear_sum_assignment
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
from typing_extensions import deprecated
|
||||
|
||||
@ -58,7 +58,7 @@ from app.camera import (
|
||||
classify_by_camera,
|
||||
)
|
||||
from app.solver._old import GLPKSolver
|
||||
from app.tracking import Tracking
|
||||
from app.tracking import AffinityResult, Tracking
|
||||
from app.visualize.whole_body import visualize_whole_body
|
||||
|
||||
NDArray: TypeAlias = np.ndarray
|
||||
@ -69,12 +69,6 @@ AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parq
|
||||
DELTA_T_MIN = timedelta(milliseconds=10)
|
||||
display(AK_CAMERA_DATASET)
|
||||
|
||||
_DEBUG_CURRENT_TRACKING: tuple[int, int] = (0, 0)
|
||||
|
||||
|
||||
def _global_current_tracking_str():
|
||||
return str(_DEBUG_CURRENT_TRACKING)
|
||||
|
||||
|
||||
# %%
|
||||
class Resolution(TypedDict):
|
||||
@ -594,23 +588,6 @@ def calculate_distance_2d(
|
||||
left_normalized = left / jnp.array([w, h])
|
||||
right_normalized = right / jnp.array([w, h])
|
||||
dist = jnp.linalg.norm(left_normalized - right_normalized, axis=-1)
|
||||
lt = left_normalized[:6]
|
||||
rt = right_normalized[:6]
|
||||
jax.debug.print(
|
||||
"[REF]{} norm_trk first6 = {}",
|
||||
_global_current_tracking_str(),
|
||||
lt,
|
||||
)
|
||||
jax.debug.print(
|
||||
"[REF]{} norm_det first6 = {}",
|
||||
_global_current_tracking_str(),
|
||||
rt,
|
||||
)
|
||||
jax.debug.print(
|
||||
"[REF]{} dist2d first6 = {}",
|
||||
_global_current_tracking_str(),
|
||||
dist[:6],
|
||||
)
|
||||
return dist
|
||||
|
||||
|
||||
@ -806,191 +783,12 @@ def calculate_tracking_detection_affinity(
|
||||
lambda_a=lambda_a,
|
||||
)
|
||||
|
||||
jax.debug.print(
|
||||
"[REF] aff2d{} first6 = {}",
|
||||
_global_current_tracking_str(),
|
||||
affinity_2d[:6],
|
||||
)
|
||||
jax.debug.print(
|
||||
"[REF] aff3d{} first6 = {}", _global_current_tracking_str(), affinity_3d[:6]
|
||||
)
|
||||
jax.debug.print(
|
||||
"[REF] aff2d.shape={}; aff3d.shape={}",
|
||||
affinity_2d.shape,
|
||||
affinity_3d.shape,
|
||||
)
|
||||
# Combine affinities
|
||||
total_affinity = affinity_2d + affinity_3d
|
||||
return jnp.sum(total_affinity).item()
|
||||
|
||||
|
||||
# %%
|
||||
@deprecated(
|
||||
"Use `calculate_camera_affinity_matrix` instead. This implementation has the problem of under-utilizing views from different cameras."
|
||||
)
|
||||
@beartype
|
||||
def calculate_affinity_matrix(
|
||||
trackings: Sequence[Tracking],
|
||||
detections: Sequence[Detection] | OrderedDict[CameraID, list[Detection]],
|
||||
w_2d: float,
|
||||
alpha_2d: float,
|
||||
w_3d: float,
|
||||
alpha_3d: float,
|
||||
lambda_a: float,
|
||||
) -> tuple[Float[Array, "T D"], OrderedDict[CameraID, list[Detection]]]:
|
||||
"""
|
||||
Calculate the affinity matrix between a set of trackings and detections.
|
||||
|
||||
Args:
|
||||
trackings: Sequence of tracking objects
|
||||
detections: Sequence of detection objects
|
||||
w_2d: Weight for 2D affinity
|
||||
alpha_2d: Normalization factor for 2D distance
|
||||
w_3d: Weight for 3D affinity
|
||||
alpha_3d: Normalization factor for 3D distance
|
||||
lambda_a: Decay rate for time difference
|
||||
|
||||
Returns:
|
||||
- affinity matrix of shape (T, D) where T is number of trackings and D
|
||||
is number of detections
|
||||
- dictionary mapping camera IDs to lists of detections from that camera,
|
||||
since it's a `OrderDict` you could flat it out to get the indices of
|
||||
detections in the affinity matrix
|
||||
|
||||
Matrix Layout:
|
||||
The affinity matrix has shape (T, D), where:
|
||||
- T = number of trackings (rows)
|
||||
- D = total number of detections across all cameras (columns)
|
||||
|
||||
The matrix is organized as follows:
|
||||
|
||||
```
|
||||
| Camera 1 | Camera 2 | Camera c |
|
||||
| d1 d2 ... | d1 d2 ... | d1 d2 ... |
|
||||
---------+-------------+-------------+-------------+
|
||||
Track 1 | a11 a12 ... | a11 a12 ... | a11 a12 ... |
|
||||
Track 2 | a21 a22 ... | a21 a22 ... | a21 a22 ... |
|
||||
... | ... | ... | ... |
|
||||
Track t | at1 at2 ... | at1 at2 ... | at1 at2 ... |
|
||||
```
|
||||
|
||||
Where:
|
||||
- Rows are ordered by tracking ID
|
||||
- Columns are ordered by camera, then by detection within each camera
|
||||
- Each cell aij represents the affinity between tracking i and detection j
|
||||
|
||||
The detection ordering in columns follows the exact same order as iterating
|
||||
through the detection_by_camera dictionary, which is returned alongside
|
||||
the matrix to maintain this relationship.
|
||||
"""
|
||||
if isinstance(detections, OrderedDict):
|
||||
D = flatten_values_len(detections)
|
||||
affinity = jnp.zeros((len(trackings), D))
|
||||
detection_by_camera = detections
|
||||
else:
|
||||
affinity = jnp.zeros((len(trackings), len(detections)))
|
||||
detection_by_camera = classify_by_camera(detections)
|
||||
|
||||
for i, tracking in enumerate(trackings):
|
||||
j = 0
|
||||
for _, camera_detections in detection_by_camera.items():
|
||||
for det in camera_detections:
|
||||
affinity_value = calculate_tracking_detection_affinity(
|
||||
tracking,
|
||||
det,
|
||||
w_2d=w_2d,
|
||||
alpha_2d=alpha_2d,
|
||||
w_3d=w_3d,
|
||||
alpha_3d=alpha_3d,
|
||||
lambda_a=lambda_a,
|
||||
)
|
||||
affinity = affinity.at[i, j].set(affinity_value)
|
||||
j += 1
|
||||
|
||||
return affinity, detection_by_camera
|
||||
|
||||
|
||||
@beartype
|
||||
def calculate_camera_affinity_matrix(
|
||||
trackings: Sequence[Tracking],
|
||||
camera_detections: Sequence[Detection],
|
||||
w_2d: float,
|
||||
alpha_2d: float,
|
||||
w_3d: float,
|
||||
alpha_3d: float,
|
||||
lambda_a: float,
|
||||
) -> Float[Array, "T D"]:
|
||||
"""
|
||||
Calculate an affinity matrix between trackings and detections from a single camera.
|
||||
|
||||
This follows the iterative camera-by-camera approach from the paper
|
||||
"Cross-View Tracking for Multi-Human 3D Pose Estimation at over 100 FPS".
|
||||
Instead of creating one large matrix for all cameras, this creates
|
||||
a separate matrix for each camera, which can be processed independently.
|
||||
|
||||
Args:
|
||||
trackings: Sequence of tracking objects
|
||||
camera_detections: Sequence of detection objects, from the same camera
|
||||
w_2d: Weight for 2D affinity
|
||||
alpha_2d: Normalization factor for 2D distance
|
||||
w_3d: Weight for 3D affinity
|
||||
alpha_3d: Normalization factor for 3D distance
|
||||
lambda_a: Decay rate for time difference
|
||||
|
||||
Returns:
|
||||
Affinity matrix of shape (T, D) where:
|
||||
- T = number of trackings (rows)
|
||||
- D = number of detections from this specific camera (columns)
|
||||
|
||||
Matrix Layout:
|
||||
The affinity matrix for a single camera has shape (T, D), where:
|
||||
- T = number of trackings (rows)
|
||||
- D = number of detections from this camera (columns)
|
||||
|
||||
The matrix is organized as follows:
|
||||
|
||||
```
|
||||
| Detections from Camera c |
|
||||
| d1 d2 d3 ... |
|
||||
---------+------------------------+
|
||||
Track 1 | a11 a12 a13 ... |
|
||||
Track 2 | a21 a22 a23 ... |
|
||||
... | ... ... ... ... |
|
||||
Track t | at1 at2 at3 ... |
|
||||
```
|
||||
|
||||
Each cell aij represents the affinity between tracking i and detection j,
|
||||
computed using both 2D and 3D geometric correspondences.
|
||||
"""
|
||||
|
||||
def verify_all_detection_from_same_camera(detections: Sequence[Detection]):
|
||||
if not detections:
|
||||
return True
|
||||
camera_id = next(iter(detections)).camera.id
|
||||
return all(map(lambda d: d.camera.id == camera_id, detections))
|
||||
|
||||
if not verify_all_detection_from_same_camera(camera_detections):
|
||||
raise ValueError("All detections must be from the same camera")
|
||||
|
||||
affinity = jnp.zeros((len(trackings), len(camera_detections)))
|
||||
|
||||
for i, tracking in enumerate(trackings):
|
||||
for j, det in enumerate(camera_detections):
|
||||
global _DEBUG_CURRENT_TRACKING
|
||||
_DEBUG_CURRENT_TRACKING = (i, j)
|
||||
affinity_value = calculate_tracking_detection_affinity(
|
||||
tracking,
|
||||
det,
|
||||
w_2d=w_2d,
|
||||
alpha_2d=alpha_2d,
|
||||
w_3d=w_3d,
|
||||
alpha_3d=alpha_3d,
|
||||
lambda_a=lambda_a,
|
||||
)
|
||||
affinity = affinity.at[i, j].set(affinity_value)
|
||||
return affinity
|
||||
|
||||
|
||||
@beartype
|
||||
def calculate_camera_affinity_matrix_jax(
|
||||
trackings: Sequence[Tracking],
|
||||
@ -1010,9 +808,6 @@ def calculate_camera_affinity_matrix_jax(
|
||||
(tracking, detection) pair. The mathematical definition of the affinity
|
||||
is **unchanged**, so the result remains bit-identical to the reference
|
||||
implementation used in the tests.
|
||||
|
||||
TODO: It gives a wrong result (maybe it's my problem?) somehow,
|
||||
and I need to find a way to debug this.
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@ -1052,8 +847,8 @@ def calculate_camera_affinity_matrix_jax(
|
||||
# ------------------------------------------------------------------
|
||||
# Compute Δt matrix – shape (T, D)
|
||||
# ------------------------------------------------------------------
|
||||
# Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out
|
||||
# sub‑second detail (resolution ≈ 200 ms). Keep them in float64 until
|
||||
# Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out
|
||||
# sub‑second detail (resolution ≈ 200 ms). Keep them in float64 until
|
||||
# after subtraction so we preserve Δt‑on‑the‑order‑of‑milliseconds.
|
||||
# --- timestamps ----------
|
||||
t0 = min(
|
||||
@ -1093,12 +888,6 @@ def calculate_camera_affinity_matrix_jax(
|
||||
diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
|
||||
dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1)
|
||||
|
||||
jax.debug.print(
|
||||
"[JAX] norm_trk[0,0,:6] = {}", norm_trk[0, :, :6] # shape (J,2) 取前6
|
||||
)
|
||||
jax.debug.print("[JAX] norm_det[0,:6] = {}", norm_det[0, :6]) # shape (J,2)
|
||||
jax.debug.print("[JAX] dist2d(T0,D0) first6 = {}", dist2d[0, 0, :6])
|
||||
|
||||
# Compute per-keypoint 2D affinity
|
||||
delta_t_broadcast = delta_t[:, :, None] # (T, D, 1)
|
||||
affinity_2d = (
|
||||
@ -1155,11 +944,6 @@ def calculate_camera_affinity_matrix_jax(
|
||||
w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
|
||||
)
|
||||
|
||||
jax.debug.print("[JAX] aff3d(T0,D0) first6 = {}", affinity_3d[0, 0, :6])
|
||||
jax.debug.print("[JAX] aff2d(T0,D0) first6 = {}", affinity_2d[0, 0, :6])
|
||||
jax.debug.print(
|
||||
"[JAX] aff2d.shape={}; aff3d.shape={}", affinity_2d.shape, affinity_3d.shape
|
||||
)
|
||||
# ------------------------------------------------------------------
|
||||
# Combine and reduce across keypoints → (T, D)
|
||||
# ------------------------------------------------------------------
|
||||
@ -1167,60 +951,57 @@ def calculate_camera_affinity_matrix_jax(
|
||||
return total_affinity # type: ignore[return-value]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Debug helper – compare JAX vs reference implementation
|
||||
# ------------------------------------------------------------------
|
||||
@beartype
|
||||
def debug_compare_affinity_matrices(
|
||||
def calculate_affinity_matrix(
|
||||
trackings: Sequence[Tracking],
|
||||
camera_detections: Sequence[Detection],
|
||||
*,
|
||||
detections: Sequence[Detection] | Mapping[CameraID, list[Detection]],
|
||||
w_2d: float,
|
||||
alpha_2d: float,
|
||||
w_3d: float,
|
||||
alpha_3d: float,
|
||||
lambda_a: float,
|
||||
atol: float = 1e-5,
|
||||
rtol: float = 1e-3,
|
||||
) -> None:
|
||||
) -> dict[CameraID, AffinityResult]:
|
||||
"""
|
||||
Compute both affinity matrices and print out the max absolute / relative
|
||||
difference. If any entry differs more than atol+rtol*|ref|, dump the
|
||||
offending indices so you can inspect individual terms.
|
||||
Calculate the affinity matrix between a set of trackings and detections.
|
||||
|
||||
Args:
|
||||
trackings: Sequence of tracking objects
|
||||
detections: Sequence of detection objects or a group detections by ID
|
||||
w_2d: Weight for 2D affinity
|
||||
alpha_2d: Normalization factor for 2D distance
|
||||
w_3d: Weight for 3D affinity
|
||||
alpha_3d: Normalization factor for 3D distance
|
||||
lambda_a: Decay rate for time difference
|
||||
Returns:
|
||||
A dictionary mapping camera IDs to affinity results.
|
||||
"""
|
||||
aff_jax = calculate_camera_affinity_matrix_jax(
|
||||
trackings,
|
||||
camera_detections,
|
||||
w_2d=w_2d,
|
||||
alpha_2d=alpha_2d,
|
||||
w_3d=w_3d,
|
||||
alpha_3d=alpha_3d,
|
||||
lambda_a=lambda_a,
|
||||
)
|
||||
aff_ref = calculate_camera_affinity_matrix(
|
||||
trackings,
|
||||
camera_detections,
|
||||
w_2d=w_2d,
|
||||
alpha_2d=alpha_2d,
|
||||
w_3d=w_3d,
|
||||
alpha_3d=alpha_3d,
|
||||
lambda_a=lambda_a,
|
||||
)
|
||||
|
||||
diff = jnp.abs(aff_jax - aff_ref)
|
||||
max_abs = float(diff.max())
|
||||
max_rel = float((diff / (jnp.abs(aff_ref) + 1e-12)).max())
|
||||
jax.debug.print(f"[DEBUG] max abs diff {max_abs:.6g}, max rel diff {max_rel:.6g}")
|
||||
|
||||
bad = jnp.where(diff > atol + rtol * jnp.abs(aff_ref))
|
||||
if bad[0].size > 0:
|
||||
for t, d in zip(*[arr.tolist() for arr in bad]):
|
||||
jax.debug.print(
|
||||
f" ↳ mismatch at (T={t}, D={d}): "
|
||||
f"jax={aff_jax[t,d]:.6g}, ref={aff_ref[t,d]:.6g}"
|
||||
)
|
||||
if isinstance(detections, Mapping):
|
||||
detection_by_camera = detections
|
||||
else:
|
||||
jax.debug.print("✅ matrices match within tolerance")
|
||||
detection_by_camera = classify_by_camera(detections)
|
||||
|
||||
res: dict[CameraID, AffinityResult] = {}
|
||||
for camera_id, camera_detections in detection_by_camera.items():
|
||||
affinity_matrix = calculate_camera_affinity_matrix_jax(
|
||||
trackings,
|
||||
camera_detections,
|
||||
w_2d,
|
||||
alpha_2d,
|
||||
w_3d,
|
||||
alpha_3d,
|
||||
lambda_a,
|
||||
)
|
||||
# row, col
|
||||
indices_T, indices_D = linear_sum_assignment(affinity_matrix)
|
||||
affinity_result = AffinityResult(
|
||||
matrix=affinity_matrix,
|
||||
trackings=trackings,
|
||||
detections=camera_detections,
|
||||
indices_T=indices_T,
|
||||
indices_D=indices_D,
|
||||
)
|
||||
res[camera_id] = affinity_result
|
||||
return res
|
||||
|
||||
|
||||
# %%
|
||||
@ -1235,15 +1016,15 @@ trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
|
||||
unmatched_detections = shallow_copy(next_group)
|
||||
camera_detections = classify_by_camera(unmatched_detections)
|
||||
|
||||
camera_detections_next_batch = camera_detections["AE_08"]
|
||||
debug_compare_affinity_matrices(
|
||||
affinities = calculate_affinity_matrix(
|
||||
trackings,
|
||||
camera_detections_next_batch,
|
||||
unmatched_detections,
|
||||
w_2d=W_2D,
|
||||
alpha_2d=ALPHA_2D,
|
||||
w_3d=W_3D,
|
||||
alpha_3d=ALPHA_3D,
|
||||
lambda_a=LAMBDA_A,
|
||||
)
|
||||
display(affinities)
|
||||
|
||||
# %%
|
||||
|
||||
@ -15,6 +15,7 @@ dependencies = [
|
||||
"jupytext>=1.17.0",
|
||||
"matplotlib>=3.10.1",
|
||||
"opencv-python-headless>=4.11.0.86",
|
||||
"optax>=0.2.4",
|
||||
"orjson>=3.10.15",
|
||||
"pandas>=2.2.3",
|
||||
"plotly>=6.0.1",
|
||||
|
||||
@ -1,224 +0,0 @@
|
||||
from datetime import datetime, timedelta
|
||||
import time
|
||||
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import pytest
|
||||
from hypothesis import given, settings, HealthCheck
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from app.camera import Camera, CameraParams
|
||||
from playground import (
|
||||
Detection,
|
||||
Tracking,
|
||||
calculate_affinity_matrix,
|
||||
calculate_camera_affinity_matrix,
|
||||
)
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# Helper functions to generate synthetic cameras / trackings / detections
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_dummy_camera(cam_id: str, rng: np.random.Generator) -> Camera:
|
||||
K = jnp.eye(3)
|
||||
Rt = jnp.eye(4)
|
||||
dist = jnp.zeros(5)
|
||||
image_size = jnp.array([1000, 1000])
|
||||
params = CameraParams(K=K, Rt=Rt, dist_coeffs=dist, image_size=image_size)
|
||||
return Camera(id=cam_id, params=params)
|
||||
|
||||
|
||||
def _random_keypoints_3d(rng: np.random.Generator, J: int):
|
||||
return jnp.asarray(rng.uniform(-1.0, 1.0, size=(J, 3)).astype(np.float32))
|
||||
|
||||
|
||||
def _random_keypoints_2d(rng: np.random.Generator, J: int):
|
||||
return jnp.asarray(rng.uniform(0.0, 1000.0, size=(J, 2)).astype(np.float32))
|
||||
|
||||
|
||||
def _make_trackings(rng: np.random.Generator, camera: Camera, T: int, J: int):
|
||||
now = datetime.now()
|
||||
trackings = []
|
||||
for i in range(T):
|
||||
kps3d = _random_keypoints_3d(rng, J)
|
||||
trk = Tracking(
|
||||
id=i + 1,
|
||||
keypoints=kps3d,
|
||||
last_active_timestamp=now
|
||||
- timedelta(milliseconds=int(rng.integers(20, 50))),
|
||||
)
|
||||
trackings.append(trk)
|
||||
return trackings
|
||||
|
||||
|
||||
def _make_detections(rng: np.random.Generator, camera: Camera, D: int, J: int):
|
||||
now = datetime.now()
|
||||
detections = []
|
||||
for _ in range(D):
|
||||
kps2d = _random_keypoints_2d(rng, J)
|
||||
det = Detection(
|
||||
keypoints=kps2d,
|
||||
confidences=jnp.ones(J, dtype=jnp.float32),
|
||||
camera=camera,
|
||||
timestamp=now,
|
||||
)
|
||||
detections.append(det)
|
||||
return detections
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# Property-based test: per-camera vs naive slice should match
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@settings(max_examples=3, deadline=None, suppress_health_check=[HealthCheck.too_slow])
|
||||
@given(
|
||||
T=st.integers(min_value=1, max_value=4),
|
||||
D=st.integers(min_value=1, max_value=4),
|
||||
J=st.integers(min_value=5, max_value=15),
|
||||
seed=st.integers(min_value=0, max_value=10000),
|
||||
)
|
||||
def test_per_camera_matches_naive(T, D, J, seed):
|
||||
rng = np.random.default_rng(seed)
|
||||
|
||||
cam = _make_dummy_camera("C0", rng)
|
||||
|
||||
trackings = _make_trackings(rng, cam, T, J)
|
||||
detections = _make_detections(rng, cam, D, J)
|
||||
|
||||
# Parameters
|
||||
W_2D = 1.0
|
||||
ALPHA_2D = 1.0
|
||||
LAMBDA_A = 0.1
|
||||
W_3D = 1.0
|
||||
ALPHA_3D = 1.0
|
||||
|
||||
# Compute per-camera affinity (fast)
|
||||
A_fast = calculate_camera_affinity_matrix(
|
||||
trackings,
|
||||
detections,
|
||||
w_2d=W_2D,
|
||||
alpha_2d=ALPHA_2D,
|
||||
w_3d=W_3D,
|
||||
alpha_3d=ALPHA_3D,
|
||||
lambda_a=LAMBDA_A,
|
||||
)
|
||||
|
||||
# Compute naive multi-camera affinity and slice out this camera
|
||||
from collections import OrderedDict
|
||||
|
||||
det_dict = OrderedDict({"C0": detections})
|
||||
A_naive, _ = calculate_affinity_matrix(
|
||||
trackings,
|
||||
det_dict,
|
||||
w_2d=W_2D,
|
||||
alpha_2d=ALPHA_2D,
|
||||
w_3d=W_3D,
|
||||
alpha_3d=ALPHA_3D,
|
||||
lambda_a=LAMBDA_A,
|
||||
)
|
||||
# both fast and naive implementation gives NaN
|
||||
# we need to inject real-world data
|
||||
|
||||
# print("A_fast")
|
||||
# print(A_fast)
|
||||
# print("A_naive")
|
||||
# print(A_naive)
|
||||
|
||||
# They should be close
|
||||
np.testing.assert_allclose(A_fast, np.asarray(A_naive), rtol=1e-5, atol=1e-5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("T,D,J", [(2, 3, 10), (4, 4, 15), (6, 8, 20)])
|
||||
def test_benchmark_affinity_matrix(T, D, J):
|
||||
"""Compare performance between naive and fast affinity matrix calculation."""
|
||||
seed = 42
|
||||
rng = np.random.default_rng(seed)
|
||||
cam = _make_dummy_camera("C0", rng)
|
||||
|
||||
trackings = _make_trackings(rng, cam, T, J)
|
||||
detections = _make_detections(rng, cam, D, J)
|
||||
|
||||
# Parameters
|
||||
w_2d = 1.0
|
||||
alpha_2d = 1.0
|
||||
w_3d = 1.0
|
||||
alpha_3d = 1.0
|
||||
lambda_a = 0.1
|
||||
|
||||
# Setup for naive
|
||||
from collections import OrderedDict
|
||||
|
||||
det_dict = OrderedDict({"C0": detections})
|
||||
|
||||
# First run to compile
|
||||
A_fast = calculate_camera_affinity_matrix(
|
||||
trackings,
|
||||
detections,
|
||||
w_2d=w_2d,
|
||||
alpha_2d=alpha_2d,
|
||||
w_3d=w_3d,
|
||||
alpha_3d=alpha_3d,
|
||||
lambda_a=lambda_a,
|
||||
)
|
||||
A_naive, _ = calculate_affinity_matrix(
|
||||
trackings,
|
||||
det_dict,
|
||||
w_2d=w_2d,
|
||||
alpha_2d=alpha_2d,
|
||||
w_3d=w_3d,
|
||||
alpha_3d=alpha_3d,
|
||||
lambda_a=lambda_a,
|
||||
)
|
||||
|
||||
# Assert they match before timing
|
||||
np.testing.assert_allclose(A_fast, np.asarray(A_naive), rtol=1e-5, atol=1e-5)
|
||||
|
||||
# Timing
|
||||
num_runs = 3
|
||||
|
||||
# Time the vectorized version
|
||||
start = time.perf_counter()
|
||||
for _ in range(num_runs):
|
||||
calculate_camera_affinity_matrix(
|
||||
trackings,
|
||||
detections,
|
||||
w_2d=w_2d,
|
||||
alpha_2d=alpha_2d,
|
||||
w_3d=w_3d,
|
||||
alpha_3d=alpha_3d,
|
||||
lambda_a=lambda_a,
|
||||
)
|
||||
end = time.perf_counter()
|
||||
vectorized_time = (end - start) / num_runs
|
||||
|
||||
# Time the naive version
|
||||
start = time.perf_counter()
|
||||
for _ in range(num_runs):
|
||||
calculate_affinity_matrix(
|
||||
trackings,
|
||||
det_dict,
|
||||
w_2d=w_2d,
|
||||
alpha_2d=alpha_2d,
|
||||
w_3d=w_3d,
|
||||
alpha_3d=alpha_3d,
|
||||
lambda_a=lambda_a,
|
||||
)
|
||||
end = time.perf_counter()
|
||||
naive_time = (end - start) / num_runs
|
||||
|
||||
speedup = naive_time / vectorized_time
|
||||
print(f"\nBenchmark T={T}, D={D}, J={J}:")
|
||||
print(f" Vectorized: {vectorized_time*1000:.2f}ms per run")
|
||||
print(f" Naive: {naive_time*1000:.2f}ms per run")
|
||||
print(f" Speedup: {speedup:.2f}x")
|
||||
|
||||
# Sanity check - vectorized should be faster!
|
||||
assert speedup > 1.0, "Vectorized implementation should be faster"
|
||||
|
||||
|
||||
if __name__ == "__main__" and pytest is not None:
|
||||
# python -m pytest -xvs -k test_benchmark
|
||||
# pytest.main([__file__])
|
||||
pytest.main(["-xvs", __file__, "-k", "test_benchmark"])
|
||||
69
uv.lock
generated
69
uv.lock
generated
@ -16,6 +16,15 @@ resolution-markers = [
|
||||
"(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "absl-py"
|
||||
version = "2.2.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b5/f0/e6342091061ed3a46aadc116b13edd7bb5249c3ab1b3ef07f24b0c248fc3/absl_py-2.2.2.tar.gz", hash = "sha256:bf25b2c2eed013ca456918c453d687eab4e8309fba81ee2f4c1a6aa2494175eb", size = 119982 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/d4/349f7f4bd5ea92dab34f5bb0fe31775ef6c311427a14d5a5b31ecb442341/absl_py-2.2.2-py3-none-any.whl", hash = "sha256:e5797bc6abe45f64fd95dc06394ca3f2bedf3b5d895e9da691c9ee3397d70092", size = 135565 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anyio"
|
||||
version = "4.8.0"
|
||||
@ -354,6 +363,24 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "chex"
|
||||
version = "0.1.89"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "absl-py" },
|
||||
{ name = "jax" },
|
||||
{ name = "jaxlib" },
|
||||
{ name = "numpy" },
|
||||
{ name = "setuptools", marker = "python_full_version >= '3.12'" },
|
||||
{ name = "toolz" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ca/ac/504a8019f7ef372fc6cc3999ec9e3d0fbb38e6992f55d845d5b928010c11/chex-0.1.89.tar.gz", hash = "sha256:78f856e6a0a8459edfcbb402c2c044d2b8102eac4b633838cbdfdcdb09c6c8e0", size = 90676 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5e/6c/309972937d931069816dc8b28193a650485bc35cca92c04c8c15c4bd181e/chex-0.1.89-py3-none-any.whl", hash = "sha256:145241c27d8944adb634fb7d472a460e1c1b643f561507d4031ad5156ef82dfa", size = 99908 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colorama"
|
||||
version = "0.4.6"
|
||||
@ -454,6 +481,7 @@ dependencies = [
|
||||
{ name = "jupytext" },
|
||||
{ name = "matplotlib" },
|
||||
{ name = "opencv-python-headless" },
|
||||
{ name = "optax" },
|
||||
{ name = "orjson" },
|
||||
{ name = "pandas" },
|
||||
{ name = "plotly" },
|
||||
@ -482,6 +510,7 @@ requires-dist = [
|
||||
{ name = "jupytext", specifier = ">=1.17.0" },
|
||||
{ name = "matplotlib", specifier = ">=3.10.1" },
|
||||
{ name = "opencv-python-headless", specifier = ">=4.11.0.86" },
|
||||
{ name = "optax", specifier = ">=0.2.4" },
|
||||
{ name = "orjson", specifier = ">=3.10.15" },
|
||||
{ name = "pandas", specifier = ">=2.2.3" },
|
||||
{ name = "plotly", specifier = ">=6.0.1" },
|
||||
@ -583,6 +612,20 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "etils"
|
||||
version = "1.12.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e4/12/1cc11e88a0201280ff389bc4076df7c3432e39d9f22cba8b71aa263f67b8/etils-1.12.2.tar.gz", hash = "sha256:c6b9e1f0ce66d1bbf54f99201b08a60ba396d3446d9eb18d4bc39b26a2e1a5ee", size = 104711 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/dd/71/40ee142e564b8a34a7ae9546e99e665e0001011a3254d5bbbe113d72ccba/etils-1.12.2-py3-none-any.whl", hash = "sha256:4600bec9de6cf5cb043a171e1856e38b5f273719cf3ecef90199f7091a6b3912", size = 167613 },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
epy = [
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.2.2"
|
||||
@ -1925,6 +1968,23 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "optax"
|
||||
version = "0.2.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "absl-py" },
|
||||
{ name = "chex" },
|
||||
{ name = "etils", extra = ["epy"] },
|
||||
{ name = "jax" },
|
||||
{ name = "jaxlib" },
|
||||
{ name = "numpy" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/af/b5/f88a0d851547b2e6b2c7e7e6509ad66236b3e7019f1f095bb03dbaa61fa1/optax-0.2.4.tar.gz", hash = "sha256:4e05d3d5307e6dde4c319187ae36e6cd3a0c035d4ed25e9e992449a304f47336", size = 229717 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/24/28d0bb21600a78e46754947333ec9a297044af884d360092eb8561575fe9/optax-0.2.4-py3-none-any.whl", hash = "sha256:db35c04e50b52596662efb002334de08c2a0a74971e4da33f467e84fac08886a", size = 319212 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "orjson"
|
||||
version = "3.10.15"
|
||||
@ -2834,6 +2894,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toolz"
|
||||
version = "1.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/8a/0b/d80dfa675bf592f636d1ea0b835eab4ec8df6e9415d8cfd766df54456123/toolz-1.0.0.tar.gz", hash = "sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02", size = 66790 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/03/98/eb27cc78ad3af8e302c9d8ff4977f5026676e130d28dd7578132a457170c/toolz-1.0.0-py3-none-any.whl", hash = "sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236", size = 56383 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "torch"
|
||||
version = "2.6.0"
|
||||
|
||||
Reference in New Issue
Block a user