# ------------------------------------------------------------------
# Debug helper – compare JAX vs reference implementation
# ------------------------------------------------------------------
@beartype
def debug_compare_affinity_matrices(
    trackings: Sequence[Tracking],
    camera_detections: Sequence[Detection],
    *,
    w_2d: float,
    alpha_2d: float,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
    atol: float = 1e-5,
    rtol: float = 1e-3,
) -> None:
    """
    Print how far the JAX affinity matrix deviates from the reference one.

    Both ``calculate_camera_affinity_matrix_jax`` and
    ``calculate_camera_affinity_matrix`` are evaluated with identical
    arguments. The maximum absolute and relative differences are printed,
    followed by every entry that differs by more than
    ``atol + rtol * |ref|``.

    Args:
        trackings: Trackings forwarded unchanged to both implementations.
        camera_detections: Detections forwarded unchanged to both
            implementations.
        w_2d, alpha_2d, w_3d, alpha_3d, lambda_a: Affinity weights,
            forwarded unchanged to both implementations.
        atol: Absolute tolerance of the per-entry comparison.
        rtol: Relative tolerance of the per-entry comparison (scaled by the
            reference value).

    Returns:
        None. All results are written to stdout.

    Note:
        The mismatch report unpacks indices as ``(T, D)``, i.e. it assumes
        both matrices are 2-D (trackings x detections).
    """
    weights = {
        "w_2d": w_2d,
        "alpha_2d": alpha_2d,
        "w_3d": w_3d,
        "alpha_3d": alpha_3d,
        "lambda_a": lambda_a,
    }
    aff_jax = jnp.asarray(
        calculate_camera_affinity_matrix_jax(
            trackings, camera_detections, **weights
        )
    )
    aff_ref = jnp.asarray(
        calculate_camera_affinity_matrix(trackings, camera_detections, **weights)
    )

    # Subtraction would silently broadcast mismatched shapes and then report
    # meaningless differences, so compare shapes explicitly first.
    if aff_jax.shape != aff_ref.shape:
        print(f"[DEBUG] shape mismatch: jax={aff_jax.shape}, ref={aff_ref.shape}")
        return

    # A max-reduction over a zero-size array raises; with no trackings or no
    # detections there is simply nothing to compare.
    if aff_jax.size == 0:
        print(f"[DEBUG] affinity matrices are empty (shape {aff_jax.shape})")
        return

    diff = jnp.abs(aff_jax - aff_ref)
    max_abs = float(diff.max())
    # The small epsilon keeps the relative difference finite where ref == 0.
    max_rel = float((diff / (jnp.abs(aff_ref) + 1e-12)).max())
    # Values are already concrete Python floats, so plain print is enough
    # (jax.debug.print would re-run str.format on the finished message).
    print(f"[DEBUG] max abs diff {max_abs:.6g}, max rel diff {max_rel:.6g}")

    # isclose applies |jax - ref| <= atol + rtol * |ref|. Unlike a bare
    # `diff > tol` test it also flags a NaN present in only one matrix and
    # accepts identical infinities, whose difference would be NaN.
    bad_mask = ~jnp.isclose(aff_jax, aff_ref, rtol=rtol, atol=atol, equal_nan=True)
    bad = jnp.where(bad_mask)
    if bad[0].size > 0:
        rows, cols = (idx.tolist() for idx in bad)
        # Gather the offending values once instead of indexing per entry.
        jax_vals = aff_jax[bad].tolist()
        ref_vals = aff_ref[bad].tolist()
        for t, d, v_jax, v_ref in zip(rows, cols, jax_vals, ref_vals):
            print(
                f" ↳ mismatch at (T={t}, D={d}): "
                f"jax={v_jax:.6g}, ref={v_ref:.6g}"
            )
    else:
        print("✅ matrices match within tolerance")