forked from HQU-gxy/CVTH3PE
wip
This commit is contained in:
@ -1074,11 +1074,15 @@ def calculate_camera_affinity_matrix_jax(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Camera center – shape (3,) -> will broadcast
|
# Camera center – shape (3,) -> will broadcast
|
||||||
cam_center = cam.params.location # (3,)
|
cam_center = cam.params.location
|
||||||
|
|
||||||
# Compute perpendicular distance using vectorized formula
|
# Compute perpendicular distance using vectorized formula
|
||||||
# distance = || (P - p1) × (p2 - p1) || / ||p2 - p1||
|
# p1 = cam_center (3,)
|
||||||
# p1 = cam_center, p2 = backproj, P = predicted_pose
|
# p2 = backproj (D, J, 3)
|
||||||
|
# P = predicted_pose (T, D, J, 3)
|
||||||
|
# Broadcast plan: v1 = P - p1 → (T, D, J, 3)
|
||||||
|
# v2 = p2[None, ...]-p1 → (1, D, J, 3)
|
||||||
|
# Shapes now line up; no stray singleton axis.
|
||||||
p1 = cam_center
|
p1 = cam_center
|
||||||
p2 = backproj
|
p2 = backproj
|
||||||
P = predicted_pose
|
P = predicted_pose
|
||||||
@ -1101,6 +1105,62 @@ def calculate_camera_affinity_matrix_jax(
|
|||||||
return total_affinity # type: ignore[return-value]
|
return total_affinity # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Debug helper – compare JAX vs reference implementation
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
@beartype
|
||||||
|
def debug_compare_affinity_matrices(
|
||||||
|
trackings: Sequence[Tracking],
|
||||||
|
camera_detections: Sequence[Detection],
|
||||||
|
*,
|
||||||
|
w_2d: float,
|
||||||
|
alpha_2d: float,
|
||||||
|
w_3d: float,
|
||||||
|
alpha_3d: float,
|
||||||
|
lambda_a: float,
|
||||||
|
atol: float = 1e-5,
|
||||||
|
rtol: float = 1e-3,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Compute both affinity matrices and print out the max absolute / relative
|
||||||
|
difference. If any entry differs more than atol+rtol*|ref|, dump the
|
||||||
|
offending indices so you can inspect individual terms.
|
||||||
|
"""
|
||||||
|
aff_jax = calculate_camera_affinity_matrix_jax(
|
||||||
|
trackings,
|
||||||
|
camera_detections,
|
||||||
|
w_2d=w_2d,
|
||||||
|
alpha_2d=alpha_2d,
|
||||||
|
w_3d=w_3d,
|
||||||
|
alpha_3d=alpha_3d,
|
||||||
|
lambda_a=lambda_a,
|
||||||
|
)
|
||||||
|
aff_ref = calculate_camera_affinity_matrix(
|
||||||
|
trackings,
|
||||||
|
camera_detections,
|
||||||
|
w_2d=w_2d,
|
||||||
|
alpha_2d=alpha_2d,
|
||||||
|
w_3d=w_3d,
|
||||||
|
alpha_3d=alpha_3d,
|
||||||
|
lambda_a=lambda_a,
|
||||||
|
)
|
||||||
|
|
||||||
|
diff = jnp.abs(aff_jax - aff_ref)
|
||||||
|
max_abs = float(diff.max())
|
||||||
|
max_rel = float((diff / (jnp.abs(aff_ref) + 1e-12)).max())
|
||||||
|
jax.debug.print(f"[DEBUG] max abs diff {max_abs:.6g}, max rel diff {max_rel:.6g}")
|
||||||
|
|
||||||
|
bad = jnp.where(diff > atol + rtol * jnp.abs(aff_ref))
|
||||||
|
if bad[0].size > 0:
|
||||||
|
for t, d in zip(*[arr.tolist() for arr in bad]):
|
||||||
|
jax.debug.print(
|
||||||
|
f" ↳ mismatch at (T={t}, D={d}): "
|
||||||
|
f"jax={aff_jax[t,d]:.6g}, ref={aff_ref[t,d]:.6g}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
jax.debug.print("✅ matrices match within tolerance")
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# let's do cross-view association
|
# let's do cross-view association
|
||||||
W_2D = 1.0
|
W_2D = 1.0
|
||||||
@ -1138,6 +1198,15 @@ affinity_naive, _ = calculate_affinity_matrix(
|
|||||||
display(camera_detections_next_batch)
|
display(camera_detections_next_batch)
|
||||||
display(affinity_naive)
|
display(affinity_naive)
|
||||||
|
|
||||||
|
debug_compare_affinity_matrices(
|
||||||
|
trackings,
|
||||||
|
camera_detections_next_batch,
|
||||||
|
w_2d=W_2D,
|
||||||
|
alpha_2d=ALPHA_2D,
|
||||||
|
w_3d=W_3D,
|
||||||
|
alpha_3d=ALPHA_3D,
|
||||||
|
lambda_a=LAMBDA_A,
|
||||||
|
)
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# Perform Hungarian algorithm for assignment for each camera
|
# Perform Hungarian algorithm for assignment for each camera
|
||||||
|
|||||||
Reference in New Issue
Block a user