wip
This commit is contained in:
@ -1074,11 +1074,15 @@ def calculate_camera_affinity_matrix_jax(
|
||||
)
|
||||
|
||||
# Camera center – shape (3,) -> will broadcast
|
||||
cam_center = cam.params.location # (3,)
|
||||
cam_center = cam.params.location
|
||||
|
||||
# Compute perpendicular distance using vectorized formula
|
||||
# distance = || (P - p1) × (p2 - p1) || / ||p2 - p1||
|
||||
# p1 = cam_center, p2 = backproj, P = predicted_pose
|
||||
# p1 = cam_center (3,)
|
||||
# p2 = backproj (D, J, 3)
|
||||
# P = predicted_pose (T, D, J, 3)
|
||||
# Broadcast plan: v1 = P - p1 → (T, D, J, 3)
|
||||
# v2 = p2[None, ...]-p1 → (1, D, J, 3)
|
||||
# Shapes now line up; no stray singleton axis.
|
||||
p1 = cam_center
|
||||
p2 = backproj
|
||||
P = predicted_pose
|
||||
@ -1101,6 +1105,62 @@ def calculate_camera_affinity_matrix_jax(
|
||||
return total_affinity # type: ignore[return-value]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Debug helper – compare JAX vs reference implementation
|
||||
# ------------------------------------------------------------------
|
||||
@beartype
|
||||
def debug_compare_affinity_matrices(
|
||||
trackings: Sequence[Tracking],
|
||||
camera_detections: Sequence[Detection],
|
||||
*,
|
||||
w_2d: float,
|
||||
alpha_2d: float,
|
||||
w_3d: float,
|
||||
alpha_3d: float,
|
||||
lambda_a: float,
|
||||
atol: float = 1e-5,
|
||||
rtol: float = 1e-3,
|
||||
) -> None:
|
||||
"""
|
||||
Compute both affinity matrices and print out the max absolute / relative
|
||||
difference. If any entry differs more than atol+rtol*|ref|, dump the
|
||||
offending indices so you can inspect individual terms.
|
||||
"""
|
||||
aff_jax = calculate_camera_affinity_matrix_jax(
|
||||
trackings,
|
||||
camera_detections,
|
||||
w_2d=w_2d,
|
||||
alpha_2d=alpha_2d,
|
||||
w_3d=w_3d,
|
||||
alpha_3d=alpha_3d,
|
||||
lambda_a=lambda_a,
|
||||
)
|
||||
aff_ref = calculate_camera_affinity_matrix(
|
||||
trackings,
|
||||
camera_detections,
|
||||
w_2d=w_2d,
|
||||
alpha_2d=alpha_2d,
|
||||
w_3d=w_3d,
|
||||
alpha_3d=alpha_3d,
|
||||
lambda_a=lambda_a,
|
||||
)
|
||||
|
||||
diff = jnp.abs(aff_jax - aff_ref)
|
||||
max_abs = float(diff.max())
|
||||
max_rel = float((diff / (jnp.abs(aff_ref) + 1e-12)).max())
|
||||
jax.debug.print(f"[DEBUG] max abs diff {max_abs:.6g}, max rel diff {max_rel:.6g}")
|
||||
|
||||
bad = jnp.where(diff > atol + rtol * jnp.abs(aff_ref))
|
||||
if bad[0].size > 0:
|
||||
for t, d in zip(*[arr.tolist() for arr in bad]):
|
||||
jax.debug.print(
|
||||
f" ↳ mismatch at (T={t}, D={d}): "
|
||||
f"jax={aff_jax[t,d]:.6g}, ref={aff_ref[t,d]:.6g}"
|
||||
)
|
||||
else:
|
||||
jax.debug.print("✅ matrices match within tolerance")
|
||||
|
||||
|
||||
# %%
|
||||
# let's do cross-view association
|
||||
W_2D = 1.0
|
||||
@ -1138,6 +1198,15 @@ affinity_naive, _ = calculate_affinity_matrix(
|
||||
display(camera_detections_next_batch)
|
||||
display(affinity_naive)
|
||||
|
||||
debug_compare_affinity_matrices(
|
||||
trackings,
|
||||
camera_detections_next_batch,
|
||||
w_2d=W_2D,
|
||||
alpha_2d=ALPHA_2D,
|
||||
w_3d=W_3D,
|
||||
alpha_3d=ALPHA_3D,
|
||||
lambda_a=LAMBDA_A,
|
||||
)
|
||||
|
||||
# %%
|
||||
# Perform Hungarian algorithm for assignment for each camera
|
||||
|
||||
Reference in New Issue
Block a user