1
0
forked from HQU-gxy/CVTH3PE
This commit is contained in:
2025-04-29 12:18:19 +08:00
parent 65cc646927
commit 7dd703edd6

View File

@ -1074,11 +1074,15 @@ def calculate_camera_affinity_matrix_jax(
)
# Camera center shape (3,) -> will broadcast
cam_center = cam.params.location # (3,)
cam_center = cam.params.location
# Compute perpendicular distance using vectorized formula
# distance = || (P - p1) × (p2 - p1) || / ||p2 - p1||
# p1 = cam_center, p2 = backproj, P = predicted_pose
# p1 = cam_center (3,)
# p2 = backproj (D, J, 3)
# P = predicted_pose (T, D, J, 3)
# Broadcast plan: v1 = P - p1 → (T, D, J, 3)
# v2 = p2[None, ...]-p1 → (1, D, J, 3)
# Shapes now line up; no stray singleton axis.
p1 = cam_center
p2 = backproj
P = predicted_pose
@ -1101,6 +1105,62 @@ def calculate_camera_affinity_matrix_jax(
return total_affinity # type: ignore[return-value]
# ------------------------------------------------------------------
# Debug helper compare JAX vs reference implementation
# ------------------------------------------------------------------
@beartype
def debug_compare_affinity_matrices(
trackings: Sequence[Tracking],
camera_detections: Sequence[Detection],
*,
w_2d: float,
alpha_2d: float,
w_3d: float,
alpha_3d: float,
lambda_a: float,
atol: float = 1e-5,
rtol: float = 1e-3,
) -> None:
"""
Compute both affinity matrices and print out the max absolute / relative
difference. If any entry differs more than atol+rtol*|ref|, dump the
offending indices so you can inspect individual terms.
"""
aff_jax = calculate_camera_affinity_matrix_jax(
trackings,
camera_detections,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
aff_ref = calculate_camera_affinity_matrix(
trackings,
camera_detections,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
diff = jnp.abs(aff_jax - aff_ref)
max_abs = float(diff.max())
max_rel = float((diff / (jnp.abs(aff_ref) + 1e-12)).max())
jax.debug.print(f"[DEBUG] max abs diff {max_abs:.6g}, max rel diff {max_rel:.6g}")
bad = jnp.where(diff > atol + rtol * jnp.abs(aff_ref))
if bad[0].size > 0:
for t, d in zip(*[arr.tolist() for arr in bad]):
jax.debug.print(
f" ↳ mismatch at (T={t}, D={d}): "
f"jax={aff_jax[t,d]:.6g}, ref={aff_ref[t,d]:.6g}"
)
else:
jax.debug.print("✅ matrices match within tolerance")
# %%
# let's do cross-view association
W_2D = 1.0
@ -1138,6 +1198,15 @@ affinity_naive, _ = calculate_affinity_matrix(
display(camera_detections_next_batch)
display(affinity_naive)
debug_compare_affinity_matrices(
trackings,
camera_detections_next_batch,
w_2d=W_2D,
alpha_2d=ALPHA_2D,
w_3d=W_3D,
alpha_3d=ALPHA_3D,
lambda_a=LAMBDA_A,
)
# %%
# Perform Hungarian algorithm for assignment for each camera