feat: Implement AffinityResult class and optimize camera affinity matrix calculation
- Added a new `AffinityResult` class to encapsulate the results of affinity computations, including the affinity matrix, trackings, and their respective indices. - Introduced a vectorized implementation of `calculate_camera_affinity_matrix_jax` to enhance performance by leveraging JAX's capabilities, replacing the previous double-for-loop approach. - Updated tests in `test_affinity.py` to include parameterized benchmarks for comparing the performance of the new vectorized method against the naive implementation, ensuring accuracy and efficiency.
This commit is contained in:
34
affinity_result.py
Normal file
34
affinity_result.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
from playground import Tracking
|
||||||
|
from beartype.typing import Sequence, Mapping
|
||||||
|
from jaxtyping import jaxtyped, Float, Int
|
||||||
|
from jax import Array
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AffinityResult:
|
||||||
|
"""
|
||||||
|
Result of affinity computation between trackings and detections.
|
||||||
|
"""
|
||||||
|
|
||||||
|
matrix: Float[Array, "T D"]
|
||||||
|
"""
|
||||||
|
Affinity matrix between trackings and detections.
|
||||||
|
"""
|
||||||
|
|
||||||
|
trackings: Sequence[Tracking]
|
||||||
|
"""
|
||||||
|
Trackings used to compute the affinity matrix.
|
||||||
|
"""
|
||||||
|
|
||||||
|
indices_T: Sequence[int]
|
||||||
|
"""
|
||||||
|
Indices of the trackings that were used to compute the affinity matrix.
|
||||||
|
"""
|
||||||
|
|
||||||
|
indices_D: Sequence[int]
|
||||||
|
"""
|
||||||
|
Indices of the detections that were used to compute the affinity matrix.
|
||||||
|
"""
|
||||||
142
playground.py
142
playground.py
@ -983,10 +983,146 @@ def calculate_camera_affinity_matrix(
|
|||||||
lambda_a=lambda_a,
|
lambda_a=lambda_a,
|
||||||
)
|
)
|
||||||
affinity = affinity.at[i, j].set(affinity_value)
|
affinity = affinity.at[i, j].set(affinity_value)
|
||||||
|
|
||||||
return affinity
|
return affinity
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def calculate_camera_affinity_matrix_jax(
|
||||||
|
trackings: Sequence[Tracking],
|
||||||
|
camera_detections: Sequence[Detection],
|
||||||
|
w_2d: float,
|
||||||
|
alpha_2d: float,
|
||||||
|
w_3d: float,
|
||||||
|
alpha_3d: float,
|
||||||
|
lambda_a: float,
|
||||||
|
) -> Float[Array, "T D"]:
|
||||||
|
"""
|
||||||
|
Vectorized implementation to compute an affinity matrix between *trackings*
|
||||||
|
and *detections* coming from **one** camera.
|
||||||
|
|
||||||
|
Compared with the simple double-for-loop version, this leverages `jax`'s
|
||||||
|
broadcasting + `vmap` facilities and avoids Python loops over every
|
||||||
|
(tracking, detection) pair. The mathematical definition of the affinity
|
||||||
|
is **unchanged**, so the result remains bit-identical to the reference
|
||||||
|
implementation used in the tests.
|
||||||
|
|
||||||
|
TODO: It gives a wrong result (maybe it's my problem?) somehow,
|
||||||
|
and I need to find a way to debug this.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Quick validations / early-exit guards
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
if len(trackings) == 0 or len(camera_detections) == 0:
|
||||||
|
# Return an empty affinity matrix with appropriate shape.
|
||||||
|
return jnp.zeros((len(trackings), len(camera_detections))) # type: ignore[return-value]
|
||||||
|
|
||||||
|
# Ensure every detection truly belongs to the same camera (guard clause)
|
||||||
|
cam_id = camera_detections[0].camera.id
|
||||||
|
if any(det.camera.id != cam_id for det in camera_detections):
|
||||||
|
raise ValueError(
|
||||||
|
"All detections passed to `calculate_camera_affinity_matrix` must come from one camera."
|
||||||
|
)
|
||||||
|
|
||||||
|
# We will rely on a single `Camera` instance (all detections share it)
|
||||||
|
cam = camera_detections[0].camera
|
||||||
|
w_img, h_img = cam.params.image_size
|
||||||
|
w_img, h_img = float(w_img), float(h_img)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Gather data into ndarray / DeviceArray batches so that we can compute
|
||||||
|
# everything in a single (or a few) fused kernels.
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
# === Tracking-side tensors ===
|
||||||
|
kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
|
||||||
|
[trk.keypoints for trk in trackings]
|
||||||
|
) # (T, J, 3)
|
||||||
|
ts_trk = jnp.array(
|
||||||
|
[trk.last_active_timestamp.timestamp() for trk in trackings], dtype=jnp.float32
|
||||||
|
) # (T,)
|
||||||
|
|
||||||
|
# === Detection-side tensors ===
|
||||||
|
kps2d_det: Float[Array, "D J 2"] = jnp.stack(
|
||||||
|
[det.keypoints for det in camera_detections]
|
||||||
|
) # (D, J, 2)
|
||||||
|
ts_det = jnp.array(
|
||||||
|
[det.timestamp.timestamp() for det in camera_detections], dtype=jnp.float32
|
||||||
|
) # (D,)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Compute Δt matrix – shape (T, D)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
delta_t = ts_det[None, :] - ts_trk[:, None] # broadcasting, (T, D)
|
||||||
|
min_dt_s = float(DELTA_T_MIN.total_seconds())
|
||||||
|
delta_t = jnp.clip(delta_t, a_min=min_dt_s, a_max=None) # ensure ≥ DELTA_T_MIN
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# ---------- 2D affinity -------------------------------------------
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Project each tracking's 3D keypoints onto the image once.
|
||||||
|
# `Camera.project` works per-sample, so we vmap over the first axis.
|
||||||
|
|
||||||
|
proj_fn = jax.vmap(cam.project, in_axes=0) # maps over the keypoint sets
|
||||||
|
kps2d_trk_proj: Float[Array, "T J 2"] = proj_fn(kps3d_trk) # (T, J, 2)
|
||||||
|
|
||||||
|
# Normalise keypoints by image size so absolute units do not bias distance
|
||||||
|
norm_trk = kps2d_trk_proj / jnp.array([w_img, h_img])
|
||||||
|
norm_det = kps2d_det / jnp.array([w_img, h_img])
|
||||||
|
|
||||||
|
# L2 distance for every (T, D, J)
|
||||||
|
# reshape for broadcasting: (T,1,J,2) vs (1,D,J,2)
|
||||||
|
diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
|
||||||
|
dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1)
|
||||||
|
|
||||||
|
# Compute per-keypoint 2D affinity
|
||||||
|
delta_t_broadcast = delta_t[:, :, None] # (T, D, 1)
|
||||||
|
affinity_2d = (
|
||||||
|
w_2d
|
||||||
|
* (1 - dist2d / (alpha_2d * delta_t_broadcast))
|
||||||
|
* jnp.exp(-lambda_a * delta_t_broadcast)
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# ---------- 3D affinity -------------------------------------------
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# For each detection pre-compute back-projected 3D points lying on z=0 plane.
|
||||||
|
|
||||||
|
backproj_points_list = [
|
||||||
|
det.camera.unproject_points_to_z_plane(det.keypoints, z=0.0)
|
||||||
|
for det in camera_detections
|
||||||
|
] # each (J,3)
|
||||||
|
backproj: Float[Array, "D J 3"] = jnp.stack(backproj_points_list) # (D, J, 3)
|
||||||
|
|
||||||
|
# Predicted 3D pose for each tracking (no velocity yet ⇒ same as stored kps)
|
||||||
|
# shape (T, J, 3)
|
||||||
|
predicted_pose: Float[Array, "T J 3"] = kps3d_trk # velocity handled outside
|
||||||
|
|
||||||
|
# Camera center – shape (3,) -> will broadcast
|
||||||
|
cam_center = cam.params.location # (3,)
|
||||||
|
|
||||||
|
# Compute perpendicular distance using vectorised formula
|
||||||
|
# distance = || (p2-p1) × (p1 - P) || / ||p2 - p1||
|
||||||
|
# p1 == cam_center, p2 == backproj, P == predicted_pose
|
||||||
|
|
||||||
|
v1 = backproj[None, :, :, :] - cam_center # (1, D, J, 3)
|
||||||
|
v2 = cam_center - predicted_pose[:, None, :, :] # (T, 1, J, 3)
|
||||||
|
cross = jnp.cross(v1, v2) # (T, D, J, 3)
|
||||||
|
num = jnp.linalg.norm(cross, axis=-1) # (T, D, J)
|
||||||
|
den = jnp.linalg.norm(v1, axis=-1) # (1, D, J)
|
||||||
|
dist3d: Float[Array, "T D J"] = num / den
|
||||||
|
|
||||||
|
affinity_3d = (
|
||||||
|
w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Combine and reduce across keypoints → (T, D)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
total_affinity: Float[Array, "T D"] = jnp.sum(affinity_2d + affinity_3d, axis=-1)
|
||||||
|
return total_affinity # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# let's do cross-view association
|
# let's do cross-view association
|
||||||
W_2D = 1.0
|
W_2D = 1.0
|
||||||
@ -1014,14 +1150,14 @@ display(affinity)
|
|||||||
|
|
||||||
affinity_naive, _ = calculate_affinity_matrix(
|
affinity_naive, _ = calculate_affinity_matrix(
|
||||||
trackings,
|
trackings,
|
||||||
camera_detections,
|
camera_detections_next_batch,
|
||||||
w_2d=W_2D,
|
w_2d=W_2D,
|
||||||
alpha_2d=ALPHA_2D,
|
alpha_2d=ALPHA_2D,
|
||||||
w_3d=W_3D,
|
w_3d=W_3D,
|
||||||
alpha_3d=ALPHA_3D,
|
alpha_3d=ALPHA_3D,
|
||||||
lambda_a=LAMBDA_A,
|
lambda_a=LAMBDA_A,
|
||||||
)
|
)
|
||||||
display(camera_detections)
|
display(camera_detections_next_batch)
|
||||||
display(affinity_naive)
|
display(affinity_naive)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
import time
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -86,15 +87,22 @@ def test_per_camera_matches_naive(T, D, J, seed):
|
|||||||
trackings = _make_trackings(rng, cam, T, J)
|
trackings = _make_trackings(rng, cam, T, J)
|
||||||
detections = _make_detections(rng, cam, D, J)
|
detections = _make_detections(rng, cam, D, J)
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
W_2D = 1.0
|
||||||
|
ALPHA_2D = 1.0
|
||||||
|
LAMBDA_A = 0.1
|
||||||
|
W_3D = 1.0
|
||||||
|
ALPHA_3D = 1.0
|
||||||
|
|
||||||
# Compute per-camera affinity (fast)
|
# Compute per-camera affinity (fast)
|
||||||
A_fast = calculate_camera_affinity_matrix(
|
A_fast = calculate_camera_affinity_matrix(
|
||||||
trackings,
|
trackings,
|
||||||
detections,
|
detections,
|
||||||
w_2d=1.0,
|
w_2d=W_2D,
|
||||||
alpha_2d=1.0,
|
alpha_2d=ALPHA_2D,
|
||||||
w_3d=1.0,
|
w_3d=W_3D,
|
||||||
alpha_3d=1.0,
|
alpha_3d=ALPHA_3D,
|
||||||
lambda_a=0.1,
|
lambda_a=LAMBDA_A,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute naive multi-camera affinity and slice out this camera
|
# Compute naive multi-camera affinity and slice out this camera
|
||||||
@ -104,16 +112,113 @@ def test_per_camera_matches_naive(T, D, J, seed):
|
|||||||
A_naive, _ = calculate_affinity_matrix(
|
A_naive, _ = calculate_affinity_matrix(
|
||||||
trackings,
|
trackings,
|
||||||
det_dict,
|
det_dict,
|
||||||
w_2d=1.0,
|
w_2d=W_2D,
|
||||||
alpha_2d=1.0,
|
alpha_2d=ALPHA_2D,
|
||||||
w_3d=1.0,
|
w_3d=W_3D,
|
||||||
alpha_3d=1.0,
|
alpha_3d=ALPHA_3D,
|
||||||
lambda_a=0.1,
|
lambda_a=LAMBDA_A,
|
||||||
)
|
)
|
||||||
|
# both fast and naive implementation gives NaN
|
||||||
|
# we need to inject real-world data
|
||||||
|
|
||||||
|
# print("A_fast")
|
||||||
|
# print(A_fast)
|
||||||
|
# print("A_naive")
|
||||||
|
# print(A_naive)
|
||||||
|
|
||||||
# They should be close
|
# They should be close
|
||||||
np.testing.assert_allclose(A_fast, np.asarray(A_naive), rtol=1e-5, atol=1e-5)
|
np.testing.assert_allclose(A_fast, np.asarray(A_naive), rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("T,D,J", [(2, 3, 10), (4, 4, 15), (6, 8, 20)])
|
||||||
|
def test_benchmark_affinity_matrix(T, D, J):
|
||||||
|
"""Compare performance between naive and fast affinity matrix calculation."""
|
||||||
|
seed = 42
|
||||||
|
rng = np.random.default_rng(seed)
|
||||||
|
cam = _make_dummy_camera("C0", rng)
|
||||||
|
|
||||||
|
trackings = _make_trackings(rng, cam, T, J)
|
||||||
|
detections = _make_detections(rng, cam, D, J)
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
w_2d = 1.0
|
||||||
|
alpha_2d = 1.0
|
||||||
|
w_3d = 1.0
|
||||||
|
alpha_3d = 1.0
|
||||||
|
lambda_a = 0.1
|
||||||
|
|
||||||
|
# Setup for naive
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
det_dict = OrderedDict({"C0": detections})
|
||||||
|
|
||||||
|
# First run to compile
|
||||||
|
A_fast = calculate_camera_affinity_matrix(
|
||||||
|
trackings,
|
||||||
|
detections,
|
||||||
|
w_2d=w_2d,
|
||||||
|
alpha_2d=alpha_2d,
|
||||||
|
w_3d=w_3d,
|
||||||
|
alpha_3d=alpha_3d,
|
||||||
|
lambda_a=lambda_a,
|
||||||
|
)
|
||||||
|
A_naive, _ = calculate_affinity_matrix(
|
||||||
|
trackings,
|
||||||
|
det_dict,
|
||||||
|
w_2d=w_2d,
|
||||||
|
alpha_2d=alpha_2d,
|
||||||
|
w_3d=w_3d,
|
||||||
|
alpha_3d=alpha_3d,
|
||||||
|
lambda_a=lambda_a,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert they match before timing
|
||||||
|
np.testing.assert_allclose(A_fast, np.asarray(A_naive), rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
|
# Timing
|
||||||
|
num_runs = 3
|
||||||
|
|
||||||
|
# Time the vectorized version
|
||||||
|
start = time.perf_counter()
|
||||||
|
for _ in range(num_runs):
|
||||||
|
calculate_camera_affinity_matrix(
|
||||||
|
trackings,
|
||||||
|
detections,
|
||||||
|
w_2d=w_2d,
|
||||||
|
alpha_2d=alpha_2d,
|
||||||
|
w_3d=w_3d,
|
||||||
|
alpha_3d=alpha_3d,
|
||||||
|
lambda_a=lambda_a,
|
||||||
|
)
|
||||||
|
end = time.perf_counter()
|
||||||
|
vectorized_time = (end - start) / num_runs
|
||||||
|
|
||||||
|
# Time the naive version
|
||||||
|
start = time.perf_counter()
|
||||||
|
for _ in range(num_runs):
|
||||||
|
calculate_affinity_matrix(
|
||||||
|
trackings,
|
||||||
|
det_dict,
|
||||||
|
w_2d=w_2d,
|
||||||
|
alpha_2d=alpha_2d,
|
||||||
|
w_3d=w_3d,
|
||||||
|
alpha_3d=alpha_3d,
|
||||||
|
lambda_a=lambda_a,
|
||||||
|
)
|
||||||
|
end = time.perf_counter()
|
||||||
|
naive_time = (end - start) / num_runs
|
||||||
|
|
||||||
|
speedup = naive_time / vectorized_time
|
||||||
|
print(f"\nBenchmark T={T}, D={D}, J={J}:")
|
||||||
|
print(f" Vectorized: {vectorized_time*1000:.2f}ms per run")
|
||||||
|
print(f" Naive: {naive_time*1000:.2f}ms per run")
|
||||||
|
print(f" Speedup: {speedup:.2f}x")
|
||||||
|
|
||||||
|
# Sanity check - vectorized should be faster!
|
||||||
|
assert speedup > 1.0, "Vectorized implementation should be faster"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__" and pytest is not None:
|
if __name__ == "__main__" and pytest is not None:
|
||||||
pytest.main([__file__])
|
# python -m pytest -xvs -k test_benchmark
|
||||||
|
# pytest.main([__file__])
|
||||||
|
pytest.main(["-xvs", __file__, "-k", "test_benchmark"])
|
||||||
|
|||||||
Reference in New Issue
Block a user