fix: timestamp precision error that made the JAX version give incorrect results

Epoch timestamps are ~1.7 × 10⁹ seconds; storing them in float32 wipes out
sub-second detail (float32 resolution at that magnitude is ≈ 128 s). Keep them
in float64 until after subtraction so we preserve a Δt on the order of milliseconds.
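As a minimal standalone illustration (NumPy only, not part of the diff) of why the cast to float32 has to happen after the subtraction:

```python
import numpy as np

t = 1.7e9   # a recent Unix epoch timestamp, in seconds
dt = 0.005  # a 5 ms gap

# Cast first, subtract later: the gap is swallowed by float32 rounding,
# because neighbouring float32 values near 1.7e9 are 128 s apart.
print(np.float32(t + dt) - np.float32(t))  # 0.0

# Subtract first (in float64), cast later: the millisecond gap survives.
print(np.float32((t + dt) - t))            # ~0.005
```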

- Introduced a `_DEBUG_CURRENT_TRACKING` variable to track the current indices of tracking and detection during calculations.
- Added a `_global_current_tracking_str` function to format the current tracking state for debugging purposes.
- Enhanced `calculate_distance_2d` and `calculate_tracking_detection_affinity` functions with debug print statements to log intermediate values, improving traceability of calculations.
- Updated `perpendicular_distance_camera_2d_points_to_tracking_raycasting` to accept `delta_t` from the caller while still clamping it to the `DELTA_T_MIN` floor.
- Refactored timestamp handling in `calculate_camera_affinity_matrix_jax` to preserve precision in the Δt computation (see the sketch after this list).
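The last two bullets amount to the pattern sketched below; the timestamps here are made up for illustration, and the real inputs are the `trackings` / `camera_detections` objects shown in the diff further down:

```python
from datetime import timedelta

import jax.numpy as jnp

DELTA_T_MIN = timedelta(milliseconds=10)

# Illustrative epoch timestamps in seconds (tracking last-active vs. detection).
ts_trk_raw = [1.7e9 + 0.000, 1.7e9 + 0.120]  # T trackings
ts_det_raw = [1.7e9 + 0.004, 1.7e9 + 0.250]  # D detections

# Rebase onto a common origin while the values are still Python float64 ...
t0 = min(ts_trk_raw + ts_det_raw)
ts_trk = jnp.array([t - t0 for t in ts_trk_raw], dtype=jnp.float32)  # small values now
ts_det = jnp.array([t - t0 for t in ts_det_raw], dtype=jnp.float32)

# ... then broadcast to the (T, D) Δt matrix and clamp to the minimum threshold,
# mirroring the jnp.clip(..., a_min=...) call in the diff.
delta_t = ts_det[None, :] - ts_trk[:, None]
delta_t = jnp.maximum(delta_t, float(DELTA_T_MIN.total_seconds()))
```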
2025-04-29 12:56:58 +08:00
parent 7dd703edd6
commit 29c8ef3990


@@ -68,6 +68,12 @@ AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parq
 DELTA_T_MIN = timedelta(milliseconds=10)
 display(AK_CAMERA_DATASET)
+_DEBUG_CURRENT_TRACKING: tuple[int, int] = (0, 0)
+
+
+def _global_current_tracking_str():
+    return str(_DEBUG_CURRENT_TRACKING)
 # %%
 class Resolution(TypedDict):
@@ -586,7 +592,25 @@ def calculate_distance_2d(
     else:
         left_normalized = left / jnp.array([w, h])
         right_normalized = right / jnp.array([w, h])
-    return jnp.linalg.norm(left_normalized - right_normalized, axis=-1)
+    dist = jnp.linalg.norm(left_normalized - right_normalized, axis=-1)
+    lt = left_normalized[:6]
+    rt = right_normalized[:6]
+    jax.debug.print(
+        "[REF]{} norm_trk first6 = {}",
+        _global_current_tracking_str(),
+        lt,
+    )
+    jax.debug.print(
+        "[REF]{} norm_det first6 = {}",
+        _global_current_tracking_str(),
+        rt,
+    )
+    jax.debug.print(
+        "[REF]{} dist2d first6 = {}",
+        _global_current_tracking_str(),
+        dist[:6],
+    )
+    return dist
 @jaxtyped(typechecker=beartype)
@@ -652,6 +676,7 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
     delta_t: timedelta,
 ) -> Float[Array, "J"]:
     """
+    NOTE: `delta_t` is now taken from the caller and NOT recomputed internally.
     Calculate the perpendicular distances between predicted 3D tracking points
     and the rays cast from camera center through the 2D image points.
@@ -664,10 +689,9 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
         Array of perpendicular distances for each keypoint
     """
     camera = detection.camera
-    assert detection.timestamp >= tracking.last_active_timestamp
-    delta_t_raw = detection.timestamp - tracking.last_active_timestamp
-    # Clamp delta_t to avoid division-by-zero / exploding affinity.
-    delta_t = max(delta_t_raw, DELTA_T_MIN)
+    # Use the delta_t supplied by the caller, but clamp to DELTA_T_MIN to
+    # avoid division-by-zero / exploding affinities.
+    delta_t = max(delta_t, DELTA_T_MIN)
     delta_t_s = delta_t.total_seconds()
     predicted_pose = tracking.predict(delta_t_s)
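For readers skimming the diff, the geometry behind that docstring is the standard point-to-ray perpendicular distance; a rough sketch only (not the notebook's actual implementation, whose body lies outside these hunks):

```python
import jax.numpy as jnp

def point_to_ray_distance(points, center, direction):
    """Perpendicular distance from 3D point(s) to the ray center + t * direction."""
    u = direction / jnp.linalg.norm(direction, axis=-1, keepdims=True)  # unit ray direction
    v = points - center                                                 # camera centre -> point
    # |v × u| = |v| · sin(angle between v and u) = perpendicular distance to the ray
    return jnp.linalg.norm(jnp.cross(v, u), axis=-1)
```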
@@ -781,6 +805,19 @@ def calculate_tracking_detection_affinity(
         lambda_a=lambda_a,
     )
+    jax.debug.print(
+        "[REF] aff2d{} first6 = {}",
+        _global_current_tracking_str(),
+        affinity_2d[:6],
+    )
+    jax.debug.print(
+        "[REF] aff3d{} first6 = {}", _global_current_tracking_str(), affinity_3d[:6]
+    )
+    jax.debug.print(
+        "[REF] aff2d.shape={}; aff3d.shape={}",
+        affinity_2d.shape,
+        affinity_3d.shape,
+    )
     # Combine affinities
     total_affinity = affinity_2d + affinity_3d
     return jnp.sum(total_affinity).item()
@@ -938,6 +975,8 @@ def calculate_camera_affinity_matrix(
     for i, tracking in enumerate(trackings):
         for j, det in enumerate(camera_detections):
+            global _DEBUG_CURRENT_TRACKING
+            _DEBUG_CURRENT_TRACKING = (i, j)
             affinity_value = calculate_tracking_detection_affinity(
                 tracking,
                 det,
@@ -1004,24 +1043,35 @@ def calculate_camera_affinity_matrix_jax(
         [trk.keypoints for trk in trackings]
     )  # (T, J, 3)
     J = kps3d_trk.shape[1]
-    ts_trk = jnp.array(
-        [trk.last_active_timestamp.timestamp() for trk in trackings], dtype=jnp.float32
-    )  # (T,)
     # === Detection-side tensors ===
     kps2d_det: Float[Array, "D J 2"] = jnp.stack(
         [det.keypoints for det in camera_detections]
     )  # (D, J, 2)
-    ts_det = jnp.array(
-        [det.timestamp.timestamp() for det in camera_detections], dtype=jnp.float32
-    )  # (D,)
     # ------------------------------------------------------------------
     # Compute Δt matrix, shape (T, D)
     # ------------------------------------------------------------------
-    delta_t = ts_det[None, :] - ts_trk[:, None]  # broadcasting, (T, D)
+    # Epoch timestamps are ~1.7 × 10⁹ s; storing them in float32 wipes out
+    # sub-second detail (resolution ≈ 128 s). Keep them in float64 until
+    # after subtraction so we preserve a Δt on the order of milliseconds.
+    # --- timestamps ----------
+    tracking0 = next(iter(trackings))
+    detection0 = next(iter(camera_detections))
+    t0 = min(
+        tracking0.last_active_timestamp, detection0.timestamp
+    ).timestamp()  # common origin (float)
+    ts_trk = jnp.array(
+        [trk.last_active_timestamp.timestamp() - t0 for trk in trackings],
+        dtype=jnp.float32,  # now small; ms-scale fits in fp32
+    )
+    ts_det = jnp.array(
+        [det.timestamp.timestamp() - t0 for det in camera_detections],
+        dtype=jnp.float32,
+    )
+    # Δt in seconds, fp32 throughout
+    delta_t = ts_det[None, :] - ts_trk[:, None]  # (T, D)
     min_dt_s = float(DELTA_T_MIN.total_seconds())
-    delta_t = jnp.clip(delta_t, a_min=min_dt_s, a_max=None)  # ensure ≥ DELTA_T_MIN
+    delta_t = jnp.clip(delta_t, a_min=min_dt_s, a_max=None)
     # ------------------------------------------------------------------
     # ---------- 2D affinity -------------------------------------------
@@ -1041,6 +1091,12 @@ def calculate_camera_affinity_matrix_jax(
     diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
     dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1)
+    jax.debug.print(
+        "[JAX] norm_trk[0,0,:6] = {}", norm_trk[0, :, :6]  # shape (J, 2), first 6
+    )
+    jax.debug.print("[JAX] norm_det[0,:6] = {}", norm_det[0, :6])  # shape (J, 2)
+    jax.debug.print("[JAX] dist2d(T0,D0) first6 = {}", dist2d[0, 0, :6])
     # Compute per-keypoint 2D affinity
     delta_t_broadcast = delta_t[:, :, None]  # (T, D, 1)
     affinity_2d = (
@@ -1097,7 +1153,11 @@ def calculate_camera_affinity_matrix_jax(
         w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
     )
-    jax.debug.print("Shapes: dist2d {} dist3d {}", dist2d.shape, dist3d.shape)
+    jax.debug.print("[JAX] aff3d(T0,D0) first6 = {}", affinity_3d[0, 0, :6])
+    jax.debug.print("[JAX] aff2d(T0,D0) first6 = {}", affinity_2d[0, 0, :6])
+    jax.debug.print(
+        "[JAX] aff2d.shape={}; aff3d.shape={}", affinity_2d.shape, affinity_3d.shape
+    )
     # ------------------------------------------------------------------
     # Combine and reduce across keypoints → (T, D)
     # ------------------------------------------------------------------
@@ -1174,30 +1234,6 @@ unmatched_detections = shallow_copy(next_group)
 camera_detections = classify_by_camera(unmatched_detections)
 camera_detections_next_batch = camera_detections["AE_08"]
-affinity = calculate_camera_affinity_matrix_jax(
-    trackings,
-    camera_detections_next_batch,
-    w_2d=W_2D,
-    alpha_2d=ALPHA_2D,
-    w_3d=W_3D,
-    alpha_3d=ALPHA_3D,
-    lambda_a=LAMBDA_A,
-)
-display(camera_detections_next_batch)
-display(affinity)
-affinity_naive, _ = calculate_affinity_matrix(
-    trackings,
-    camera_detections_next_batch,
-    w_2d=W_2D,
-    alpha_2d=ALPHA_2D,
-    w_3d=W_3D,
-    alpha_3d=ALPHA_3D,
-    lambda_a=LAMBDA_A,
-)
-display(camera_detections_next_batch)
-display(affinity_naive)
 debug_compare_affinity_matrices(
     trackings,
     camera_detections_next_batch,
@@ -1209,11 +1245,3 @@ debug_compare_affinity_matrices(
 )
 # %%
-# Perform Hungarian algorithm for assignment for each camera
-indices_T, indices_D = linear_sum_assignment(affinity, maximize=True)
-indices_T = cast(Sequence[int], indices_T)
-indices_D = cast(Sequence[int], indices_D)
-display(indices_T)
-display(indices_D)
-# %%