1
0
forked from HQU-gxy/CVTH3PE

feat: Implement AffinityResult class and optimize camera affinity matrix calculation

- Added a new `AffinityResult` class to encapsulate the results of affinity computations, including the affinity matrix, trackings, and their respective indices.
- Introduced a vectorized implementation of `calculate_camera_affinity_matrix_jax` to enhance performance by leveraging JAX's capabilities, replacing the previous double-for-loop approach.
- Updated tests in `test_affinity.py` to include parameterized benchmarks for comparing the performance of the new vectorized method against the naive implementation, ensuring accuracy and efficiency.
This commit is contained in:
2025-04-28 19:08:16 +08:00
parent 487dd4e237
commit da4c51d04f
3 changed files with 289 additions and 14 deletions

View File

@ -983,10 +983,146 @@ def calculate_camera_affinity_matrix(
lambda_a=lambda_a,
)
affinity = affinity.at[i, j].set(affinity_value)
return affinity
@beartype
def calculate_camera_affinity_matrix_jax(
trackings: Sequence[Tracking],
camera_detections: Sequence[Detection],
w_2d: float,
alpha_2d: float,
w_3d: float,
alpha_3d: float,
lambda_a: float,
) -> Float[Array, "T D"]:
"""
Vectorized implementation to compute an affinity matrix between *trackings*
and *detections* coming from **one** camera.
Compared with the simple double-for-loop version, this leverages `jax`'s
broadcasting + `vmap` facilities and avoids Python loops over every
(tracking, detection) pair. The mathematical definition of the affinity
is **unchanged**, so the result remains bit-identical to the reference
implementation used in the tests.
TODO: It gives a wrong result (maybe it's my problem?) somehow,
and I need to find a way to debug this.
"""
# ------------------------------------------------------------------
# Quick validations / early-exit guards
# ------------------------------------------------------------------
if len(trackings) == 0 or len(camera_detections) == 0:
# Return an empty affinity matrix with appropriate shape.
return jnp.zeros((len(trackings), len(camera_detections))) # type: ignore[return-value]
# Ensure every detection truly belongs to the same camera (guard clause)
cam_id = camera_detections[0].camera.id
if any(det.camera.id != cam_id for det in camera_detections):
raise ValueError(
"All detections passed to `calculate_camera_affinity_matrix` must come from one camera."
)
# We will rely on a single `Camera` instance (all detections share it)
cam = camera_detections[0].camera
w_img, h_img = cam.params.image_size
w_img, h_img = float(w_img), float(h_img)
# ------------------------------------------------------------------
# Gather data into ndarray / DeviceArray batches so that we can compute
# everything in a single (or a few) fused kernels.
# ------------------------------------------------------------------
# === Tracking-side tensors ===
kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
[trk.keypoints for trk in trackings]
) # (T, J, 3)
ts_trk = jnp.array(
[trk.last_active_timestamp.timestamp() for trk in trackings], dtype=jnp.float32
) # (T,)
# === Detection-side tensors ===
kps2d_det: Float[Array, "D J 2"] = jnp.stack(
[det.keypoints for det in camera_detections]
) # (D, J, 2)
ts_det = jnp.array(
[det.timestamp.timestamp() for det in camera_detections], dtype=jnp.float32
) # (D,)
# ------------------------------------------------------------------
# Compute Δt matrix shape (T, D)
# ------------------------------------------------------------------
delta_t = ts_det[None, :] - ts_trk[:, None] # broadcasting, (T, D)
min_dt_s = float(DELTA_T_MIN.total_seconds())
delta_t = jnp.clip(delta_t, a_min=min_dt_s, a_max=None) # ensure ≥ DELTA_T_MIN
# ------------------------------------------------------------------
# ---------- 2D affinity -------------------------------------------
# ------------------------------------------------------------------
# Project each tracking's 3D keypoints onto the image once.
# `Camera.project` works per-sample, so we vmap over the first axis.
proj_fn = jax.vmap(cam.project, in_axes=0) # maps over the keypoint sets
kps2d_trk_proj: Float[Array, "T J 2"] = proj_fn(kps3d_trk) # (T, J, 2)
# Normalise keypoints by image size so absolute units do not bias distance
norm_trk = kps2d_trk_proj / jnp.array([w_img, h_img])
norm_det = kps2d_det / jnp.array([w_img, h_img])
# L2 distance for every (T, D, J)
# reshape for broadcasting: (T,1,J,2) vs (1,D,J,2)
diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1)
# Compute per-keypoint 2D affinity
delta_t_broadcast = delta_t[:, :, None] # (T, D, 1)
affinity_2d = (
w_2d
* (1 - dist2d / (alpha_2d * delta_t_broadcast))
* jnp.exp(-lambda_a * delta_t_broadcast)
)
# ------------------------------------------------------------------
# ---------- 3D affinity -------------------------------------------
# ------------------------------------------------------------------
# For each detection pre-compute back-projected 3D points lying on z=0 plane.
backproj_points_list = [
det.camera.unproject_points_to_z_plane(det.keypoints, z=0.0)
for det in camera_detections
] # each (J,3)
backproj: Float[Array, "D J 3"] = jnp.stack(backproj_points_list) # (D, J, 3)
# Predicted 3D pose for each tracking (no velocity yet ⇒ same as stored kps)
# shape (T, J, 3)
predicted_pose: Float[Array, "T J 3"] = kps3d_trk # velocity handled outside
# Camera center shape (3,) -> will broadcast
cam_center = cam.params.location # (3,)
# Compute perpendicular distance using vectorised formula
# distance = || (p2-p1) × (p1 - P) || / ||p2 - p1||
# p1 == cam_center, p2 == backproj, P == predicted_pose
v1 = backproj[None, :, :, :] - cam_center # (1, D, J, 3)
v2 = cam_center - predicted_pose[:, None, :, :] # (T, 1, J, 3)
cross = jnp.cross(v1, v2) # (T, D, J, 3)
num = jnp.linalg.norm(cross, axis=-1) # (T, D, J)
den = jnp.linalg.norm(v1, axis=-1) # (1, D, J)
dist3d: Float[Array, "T D J"] = num / den
affinity_3d = (
w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
)
# ------------------------------------------------------------------
# Combine and reduce across keypoints → (T, D)
# ------------------------------------------------------------------
total_affinity: Float[Array, "T D"] = jnp.sum(affinity_2d + affinity_3d, axis=-1)
return total_affinity # type: ignore[return-value]
# %%
# let's do cross-view association
W_2D = 1.0
@ -1014,14 +1150,14 @@ display(affinity)
affinity_naive, _ = calculate_affinity_matrix(
trackings,
camera_detections,
camera_detections_next_batch,
w_2d=W_2D,
alpha_2d=ALPHA_2D,
w_3d=W_3D,
alpha_3d=ALPHA_3D,
lambda_a=LAMBDA_A,
)
display(camera_detections)
display(camera_detections_next_batch)
display(affinity_naive)