fix: timestamp precision error causing the JAX version to give incorrect results

Epoch timestamps are ~1.7 × 10⁹ s; storing them in float32 wipes out
sub-second detail (float32 spacing near 1.7 × 10⁹ is 128 s). Keep them in
float64 until after subtraction so we preserve Δt on the order of
milliseconds.
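
A minimal, self-contained sketch of the failure mode and of the common-origin
workaround applied in the diff below (illustrative values, not from the code):

    import jax.numpy as jnp

    # Two events 5 ms apart, expressed as absolute epoch seconds (~1.7e9).
    t_trk, t_det = 1.7e9, 1.7e9 + 0.005

    # Casting the absolute values to float32 first destroys the difference:
    # the float32 spacing near 1.7e9 is 128 s, so both round to the same number.
    bad = jnp.array(t_det, dtype=jnp.float32) - jnp.array(t_trk, dtype=jnp.float32)
    print(bad)   # 0.0

    # Subtracting a common origin in float64 *before* the cast keeps Δt intact.
    t0 = min(t_trk, t_det)
    good = jnp.array(t_det - t0, dtype=jnp.float32) - jnp.array(t_trk - t0, dtype=jnp.float32)
    print(good)  # ~0.005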

- Introduced a `_DEBUG_CURRENT_TRACKING` variable to track the current indices of tracking and detection during calculations.
- Added a `_global_current_tracking_str` function to format the current tracking state for debugging purposes.
- Enhanced `calculate_distance_2d` and `calculate_tracking_detection_affinity` functions with debug print statements to log intermediate values, improving traceability of calculations.
- Updated `perpendicular_distance_camera_2d_points_to_tracking_raycasting` to accept `delta_t` from the caller while still clamping it to the `DELTA_T_MIN` threshold (see the sketch after this list).
- Refactored timestamp handling in `calculate_camera_affinity_matrix_jax` to maintain precision during calculations.
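
As a rough sketch of that clamping (reusing the notebook's `DELTA_T_MIN`; the Δt values here are made up), the per-pair reference path clamps a `timedelta`, while the batched JAX path clamps a (T, D) matrix of seconds:

    from datetime import timedelta

    import jax.numpy as jnp

    DELTA_T_MIN = timedelta(milliseconds=10)

    # Reference path (per tracking/detection pair): clamp the timedelta itself.
    delta_t = timedelta(milliseconds=2)
    delta_t = max(delta_t, DELTA_T_MIN)            # -> 10 ms
    delta_t_s = delta_t.total_seconds()            # -> 0.01

    # JAX path (batched): clamp the whole matrix of Δt in seconds at once.
    delta_t_mat = jnp.array([[0.002, 0.050]])      # (T=1, D=2)
    min_dt_s = float(DELTA_T_MIN.total_seconds())
    delta_t_mat = jnp.clip(delta_t_mat, min_dt_s)  # -> [[0.01, 0.05]]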
2025-04-29 12:56:58 +08:00
parent 7dd703edd6
commit 29c8ef3990


@@ -68,6 +68,12 @@ AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parq
DELTA_T_MIN = timedelta(milliseconds=10)
display(AK_CAMERA_DATASET)
_DEBUG_CURRENT_TRACKING: tuple[int, int] = (0, 0)
def _global_current_tracking_str():
return str(_DEBUG_CURRENT_TRACKING)
# %%
class Resolution(TypedDict):
@@ -586,7 +592,25 @@ def calculate_distance_2d(
else:
left_normalized = left / jnp.array([w, h])
right_normalized = right / jnp.array([w, h])
return jnp.linalg.norm(left_normalized - right_normalized, axis=-1)
dist = jnp.linalg.norm(left_normalized - right_normalized, axis=-1)
lt = left_normalized[:6]
rt = right_normalized[:6]
jax.debug.print(
"[REF]{} norm_trk first6 = {}",
_global_current_tracking_str(),
lt,
)
jax.debug.print(
"[REF]{} norm_det first6 = {}",
_global_current_tracking_str(),
rt,
)
jax.debug.print(
"[REF]{} dist2d first6 = {}",
_global_current_tracking_str(),
dist[:6],
)
return dist
@jaxtyped(typechecker=beartype)
@@ -652,6 +676,7 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
delta_t: timedelta,
) -> Float[Array, "J"]:
"""
NOTE: `delta_t` is now taken from the caller and NOT recomputed internally.
Calculate the perpendicular distances between predicted 3D tracking points
and the rays cast from camera center through the 2D image points.
@@ -664,10 +689,9 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
Array of perpendicular distances for each keypoint
"""
camera = detection.camera
assert detection.timestamp >= tracking.last_active_timestamp
delta_t_raw = detection.timestamp - tracking.last_active_timestamp
# Clamp delta_t to avoid division-by-zero / exploding affinity.
delta_t = max(delta_t_raw, DELTA_T_MIN)
# Use the delta_t supplied by the caller, but clamp to DELTA_T_MIN to
# avoid division-by-zero / exploding affinities.
delta_t = max(delta_t, DELTA_T_MIN)
delta_t_s = delta_t.total_seconds()
predicted_pose = tracking.predict(delta_t_s)
@@ -781,6 +805,19 @@ def calculate_tracking_detection_affinity(
lambda_a=lambda_a,
)
jax.debug.print(
"[REF] aff2d{} first6 = {}",
_global_current_tracking_str(),
affinity_2d[:6],
)
jax.debug.print(
"[REF] aff3d{} first6 = {}", _global_current_tracking_str(), affinity_3d[:6]
)
jax.debug.print(
"[REF] aff2d.shape={}; aff3d.shape={}",
affinity_2d.shape,
affinity_3d.shape,
)
# Combine affinities
total_affinity = affinity_2d + affinity_3d
return jnp.sum(total_affinity).item()
@@ -938,6 +975,8 @@ def calculate_camera_affinity_matrix(
for i, tracking in enumerate(trackings):
for j, det in enumerate(camera_detections):
global _DEBUG_CURRENT_TRACKING
_DEBUG_CURRENT_TRACKING = (i, j)
affinity_value = calculate_tracking_detection_affinity(
tracking,
det,
@@ -1004,24 +1043,35 @@ def calculate_camera_affinity_matrix_jax(
[trk.keypoints for trk in trackings]
) # (T, J, 3)
J = kps3d_trk.shape[1]
ts_trk = jnp.array(
[trk.last_active_timestamp.timestamp() for trk in trackings], dtype=jnp.float32
) # (T,)
# === Detection-side tensors ===
kps2d_det: Float[Array, "D J 2"] = jnp.stack(
[det.keypoints for det in camera_detections]
) # (D, J, 2)
ts_det = jnp.array(
[det.timestamp.timestamp() for det in camera_detections], dtype=jnp.float32
) # (D,)
# ------------------------------------------------------------------
# Compute Δt matrix shape (T, D)
# ------------------------------------------------------------------
delta_t = ts_det[None, :] - ts_trk[:, None] # broadcasting, (T, D)
# Epoch timestamps are ~1.7 × 10⁹ s; storing them in float32 wipes out
# sub-second detail (float32 spacing near 1.7 × 10⁹ is 128 s). Keep them in
# float64 until after subtraction so we preserve Δt on the order of milliseconds.
# --- timestamps ----------
tracking0 = next(iter(trackings))
detection0 = next(iter(camera_detections))
t0 = min(
tracking0.last_active_timestamp, detection0.timestamp
).timestamp() # common origin (float)
ts_trk = jnp.array(
[trk.last_active_timestamp.timestamp() - t0 for trk in trackings],
dtype=jnp.float32, # now small, ms-scale fits in fp32
)
ts_det = jnp.array(
[det.timestamp.timestamp() - t0 for det in camera_detections],
dtype=jnp.float32,
)
# Δt in seconds, fp32 throughout
delta_t = ts_det[None, :] - ts_trk[:, None] # (T,D)
min_dt_s = float(DELTA_T_MIN.total_seconds())
delta_t = jnp.clip(delta_t, a_min=min_dt_s, a_max=None) # ensure ≥ DELTA_T_MIN
delta_t = jnp.clip(delta_t, a_min=min_dt_s, a_max=None)
# ------------------------------------------------------------------
# ---------- 2D affinity -------------------------------------------
@@ -1041,6 +1091,12 @@ def calculate_camera_affinity_matrix_jax(
diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1)
jax.debug.print(
"[JAX] norm_trk[0,0,:6] = {}", norm_trk[0, :, :6] # shape (J,2) 取前6
)
jax.debug.print("[JAX] norm_det[0,:6] = {}", norm_det[0, :6]) # shape (J,2)
jax.debug.print("[JAX] dist2d(T0,D0) first6 = {}", dist2d[0, 0, :6])
# Compute per-keypoint 2D affinity
delta_t_broadcast = delta_t[:, :, None] # (T, D, 1)
affinity_2d = (
@@ -1097,7 +1153,11 @@ def calculate_camera_affinity_matrix_jax(
w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
)
jax.debug.print("Shapes: dist2d {} dist3d {}", dist2d.shape, dist3d.shape)
jax.debug.print("[JAX] aff3d(T0,D0) first6 = {}", affinity_3d[0, 0, :6])
jax.debug.print("[JAX] aff2d(T0,D0) first6 = {}", affinity_2d[0, 0, :6])
jax.debug.print(
"[JAX] aff2d.shape={}; aff3d.shape={}", affinity_2d.shape, affinity_3d.shape
)
# ------------------------------------------------------------------
# Combine and reduce across keypoints → (T, D)
# ------------------------------------------------------------------
@@ -1174,30 +1234,6 @@ unmatched_detections = shallow_copy(next_group)
camera_detections = classify_by_camera(unmatched_detections)
camera_detections_next_batch = camera_detections["AE_08"]
affinity = calculate_camera_affinity_matrix_jax(
trackings,
camera_detections_next_batch,
w_2d=W_2D,
alpha_2d=ALPHA_2D,
w_3d=W_3D,
alpha_3d=ALPHA_3D,
lambda_a=LAMBDA_A,
)
display(camera_detections_next_batch)
display(affinity)
affinity_naive, _ = calculate_affinity_matrix(
trackings,
camera_detections_next_batch,
w_2d=W_2D,
alpha_2d=ALPHA_2D,
w_3d=W_3D,
alpha_3d=ALPHA_3D,
lambda_a=LAMBDA_A,
)
display(camera_detections_next_batch)
display(affinity_naive)
debug_compare_affinity_matrices(
trackings,
camera_detections_next_batch,
@@ -1209,11 +1245,3 @@ debug_compare_affinity_matrices(
)
# %%
# Perform Hungarian algorithm for assignment for each camera
indices_T, indices_D = linear_sum_assignment(affinity, maximize=True)
indices_T = cast(Sequence[int], indices_T)
indices_D = cast(Sequence[int], indices_D)
display(indices_T)
display(indices_D)
# %%