fix: fix the timestamp precision error that caused the JAX version to give incorrect results
Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out sub‑second detail (the float32 spacing at that magnitude is 128 s). Keep them in float64 until after subtracting a common origin, so that Δt on the order of milliseconds is preserved. - Introduced a `_DEBUG_CURRENT_TRACKING` variable to track the current tracking and detection indices during calculations. - Added a `_global_current_tracking_str` function to format the current tracking state for debugging purposes. - Enhanced the `calculate_distance_2d` and `calculate_tracking_detection_affinity` functions with debug print statements that log intermediate values, improving traceability of calculations. - Updated `perpendicular_distance_camera_2d_points_to_tracking_raycasting` to use the `delta_t` supplied by the caller, clamped to a minimum of `DELTA_T_MIN`. - Refactored timestamp handling in `calculate_camera_affinity_matrix_jax` to maintain precision during calculations.
This commit is contained in:
122
playground.py
122
playground.py
@ -68,6 +68,12 @@ AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parq
|
|||||||
DELTA_T_MIN = timedelta(milliseconds=10)
|
DELTA_T_MIN = timedelta(milliseconds=10)
|
||||||
display(AK_CAMERA_DATASET)
|
display(AK_CAMERA_DATASET)
|
||||||
|
|
||||||
|
_DEBUG_CURRENT_TRACKING: tuple[int, int] = (0, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def _global_current_tracking_str():
|
||||||
|
return str(_DEBUG_CURRENT_TRACKING)
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
class Resolution(TypedDict):
|
class Resolution(TypedDict):
|
||||||
@ -586,7 +592,25 @@ def calculate_distance_2d(
|
|||||||
else:
|
else:
|
||||||
left_normalized = left / jnp.array([w, h])
|
left_normalized = left / jnp.array([w, h])
|
||||||
right_normalized = right / jnp.array([w, h])
|
right_normalized = right / jnp.array([w, h])
|
||||||
return jnp.linalg.norm(left_normalized - right_normalized, axis=-1)
|
dist = jnp.linalg.norm(left_normalized - right_normalized, axis=-1)
|
||||||
|
lt = left_normalized[:6]
|
||||||
|
rt = right_normalized[:6]
|
||||||
|
jax.debug.print(
|
||||||
|
"[REF]{} norm_trk first6 = {}",
|
||||||
|
_global_current_tracking_str(),
|
||||||
|
lt,
|
||||||
|
)
|
||||||
|
jax.debug.print(
|
||||||
|
"[REF]{} norm_det first6 = {}",
|
||||||
|
_global_current_tracking_str(),
|
||||||
|
rt,
|
||||||
|
)
|
||||||
|
jax.debug.print(
|
||||||
|
"[REF]{} dist2d first6 = {}",
|
||||||
|
_global_current_tracking_str(),
|
||||||
|
dist[:6],
|
||||||
|
)
|
||||||
|
return dist
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
@ -652,6 +676,7 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
|
|||||||
delta_t: timedelta,
|
delta_t: timedelta,
|
||||||
) -> Float[Array, "J"]:
|
) -> Float[Array, "J"]:
|
||||||
"""
|
"""
|
||||||
|
NOTE: `delta_t` is now taken from the caller and NOT recomputed internally.
|
||||||
Calculate the perpendicular distances between predicted 3D tracking points
|
Calculate the perpendicular distances between predicted 3D tracking points
|
||||||
and the rays cast from camera center through the 2D image points.
|
and the rays cast from camera center through the 2D image points.
|
||||||
|
|
||||||
@ -664,10 +689,9 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
|
|||||||
Array of perpendicular distances for each keypoint
|
Array of perpendicular distances for each keypoint
|
||||||
"""
|
"""
|
||||||
camera = detection.camera
|
camera = detection.camera
|
||||||
assert detection.timestamp >= tracking.last_active_timestamp
|
# Use the delta_t supplied by the caller, but clamp to DELTA_T_MIN to
|
||||||
delta_t_raw = detection.timestamp - tracking.last_active_timestamp
|
# avoid division-by-zero / exploding affinities.
|
||||||
# Clamp delta_t to avoid division-by-zero / exploding affinity.
|
delta_t = max(delta_t, DELTA_T_MIN)
|
||||||
delta_t = max(delta_t_raw, DELTA_T_MIN)
|
|
||||||
delta_t_s = delta_t.total_seconds()
|
delta_t_s = delta_t.total_seconds()
|
||||||
predicted_pose = tracking.predict(delta_t_s)
|
predicted_pose = tracking.predict(delta_t_s)
|
||||||
|
|
||||||
@ -781,6 +805,19 @@ def calculate_tracking_detection_affinity(
|
|||||||
lambda_a=lambda_a,
|
lambda_a=lambda_a,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
jax.debug.print(
|
||||||
|
"[REF] aff2d{} first6 = {}",
|
||||||
|
_global_current_tracking_str(),
|
||||||
|
affinity_2d[:6],
|
||||||
|
)
|
||||||
|
jax.debug.print(
|
||||||
|
"[REF] aff3d{} first6 = {}", _global_current_tracking_str(), affinity_3d[:6]
|
||||||
|
)
|
||||||
|
jax.debug.print(
|
||||||
|
"[REF] aff2d.shape={}; aff3d.shape={}",
|
||||||
|
affinity_2d.shape,
|
||||||
|
affinity_3d.shape,
|
||||||
|
)
|
||||||
# Combine affinities
|
# Combine affinities
|
||||||
total_affinity = affinity_2d + affinity_3d
|
total_affinity = affinity_2d + affinity_3d
|
||||||
return jnp.sum(total_affinity).item()
|
return jnp.sum(total_affinity).item()
|
||||||
@ -938,6 +975,8 @@ def calculate_camera_affinity_matrix(
|
|||||||
|
|
||||||
for i, tracking in enumerate(trackings):
|
for i, tracking in enumerate(trackings):
|
||||||
for j, det in enumerate(camera_detections):
|
for j, det in enumerate(camera_detections):
|
||||||
|
global _DEBUG_CURRENT_TRACKING
|
||||||
|
_DEBUG_CURRENT_TRACKING = (i, j)
|
||||||
affinity_value = calculate_tracking_detection_affinity(
|
affinity_value = calculate_tracking_detection_affinity(
|
||||||
tracking,
|
tracking,
|
||||||
det,
|
det,
|
||||||
@ -1004,24 +1043,35 @@ def calculate_camera_affinity_matrix_jax(
|
|||||||
[trk.keypoints for trk in trackings]
|
[trk.keypoints for trk in trackings]
|
||||||
) # (T, J, 3)
|
) # (T, J, 3)
|
||||||
J = kps3d_trk.shape[1]
|
J = kps3d_trk.shape[1]
|
||||||
ts_trk = jnp.array(
|
|
||||||
[trk.last_active_timestamp.timestamp() for trk in trackings], dtype=jnp.float32
|
|
||||||
) # (T,)
|
|
||||||
|
|
||||||
# === Detection-side tensors ===
|
# === Detection-side tensors ===
|
||||||
kps2d_det: Float[Array, "D J 2"] = jnp.stack(
|
kps2d_det: Float[Array, "D J 2"] = jnp.stack(
|
||||||
[det.keypoints for det in camera_detections]
|
[det.keypoints for det in camera_detections]
|
||||||
) # (D, J, 2)
|
) # (D, J, 2)
|
||||||
ts_det = jnp.array(
|
|
||||||
[det.timestamp.timestamp() for det in camera_detections], dtype=jnp.float32
|
|
||||||
) # (D,)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Compute Δt matrix – shape (T, D)
|
# Compute Δt matrix – shape (T, D)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
delta_t = ts_det[None, :] - ts_trk[:, None] # broadcasting, (T, D)
|
# Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out
|
||||||
|
# sub‑second detail (float32 spacing there is 128 s). Keep them in float64 until
|
||||||
|
# after subtraction so we preserve Δt‑on‑the‑order‑of‑milliseconds.
|
||||||
|
# --- timestamps ----------
|
||||||
|
tracking0 = next(iter(trackings))
|
||||||
|
detection0 = next(iter(camera_detections))
|
||||||
|
t0 = min(
|
||||||
|
tracking0.last_active_timestamp, detection0.timestamp
|
||||||
|
).timestamp() # common origin (float)
|
||||||
|
ts_trk = jnp.array(
|
||||||
|
[trk.last_active_timestamp.timestamp() - t0 for trk in trackings],
|
||||||
|
dtype=jnp.float32, # now small, ms-scale fits in fp32
|
||||||
|
)
|
||||||
|
ts_det = jnp.array(
|
||||||
|
[det.timestamp.timestamp() - t0 for det in camera_detections],
|
||||||
|
dtype=jnp.float32,
|
||||||
|
)
|
||||||
|
# Δt in seconds, fp32 throughout
|
||||||
|
delta_t = ts_det[None, :] - ts_trk[:, None] # (T,D)
|
||||||
min_dt_s = float(DELTA_T_MIN.total_seconds())
|
min_dt_s = float(DELTA_T_MIN.total_seconds())
|
||||||
delta_t = jnp.clip(delta_t, a_min=min_dt_s, a_max=None) # ensure ≥ DELTA_T_MIN
|
delta_t = jnp.clip(delta_t, a_min=min_dt_s, a_max=None)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# ---------- 2D affinity -------------------------------------------
|
# ---------- 2D affinity -------------------------------------------
|
||||||
@ -1041,6 +1091,12 @@ def calculate_camera_affinity_matrix_jax(
|
|||||||
diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
|
diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
|
||||||
dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1)
|
dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1)
|
||||||
|
|
||||||
|
jax.debug.print(
|
||||||
|
"[JAX] norm_trk[0,0,:6] = {}", norm_trk[0, :, :6] # shape (J,2) 取前6
|
||||||
|
)
|
||||||
|
jax.debug.print("[JAX] norm_det[0,:6] = {}", norm_det[0, :6]) # shape (J,2)
|
||||||
|
jax.debug.print("[JAX] dist2d(T0,D0) first6 = {}", dist2d[0, 0, :6])
|
||||||
|
|
||||||
# Compute per-keypoint 2D affinity
|
# Compute per-keypoint 2D affinity
|
||||||
delta_t_broadcast = delta_t[:, :, None] # (T, D, 1)
|
delta_t_broadcast = delta_t[:, :, None] # (T, D, 1)
|
||||||
affinity_2d = (
|
affinity_2d = (
|
||||||
@ -1097,7 +1153,11 @@ def calculate_camera_affinity_matrix_jax(
|
|||||||
w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
|
w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
|
||||||
)
|
)
|
||||||
|
|
||||||
jax.debug.print("Shapes: dist2d {} dist3d {}", dist2d.shape, dist3d.shape)
|
jax.debug.print("[JAX] aff3d(T0,D0) first6 = {}", affinity_3d[0, 0, :6])
|
||||||
|
jax.debug.print("[JAX] aff2d(T0,D0) first6 = {}", affinity_2d[0, 0, :6])
|
||||||
|
jax.debug.print(
|
||||||
|
"[JAX] aff2d.shape={}; aff3d.shape={}", affinity_2d.shape, affinity_3d.shape
|
||||||
|
)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Combine and reduce across keypoints → (T, D)
|
# Combine and reduce across keypoints → (T, D)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@ -1174,30 +1234,6 @@ unmatched_detections = shallow_copy(next_group)
|
|||||||
camera_detections = classify_by_camera(unmatched_detections)
|
camera_detections = classify_by_camera(unmatched_detections)
|
||||||
|
|
||||||
camera_detections_next_batch = camera_detections["AE_08"]
|
camera_detections_next_batch = camera_detections["AE_08"]
|
||||||
affinity = calculate_camera_affinity_matrix_jax(
|
|
||||||
trackings,
|
|
||||||
camera_detections_next_batch,
|
|
||||||
w_2d=W_2D,
|
|
||||||
alpha_2d=ALPHA_2D,
|
|
||||||
w_3d=W_3D,
|
|
||||||
alpha_3d=ALPHA_3D,
|
|
||||||
lambda_a=LAMBDA_A,
|
|
||||||
)
|
|
||||||
display(camera_detections_next_batch)
|
|
||||||
display(affinity)
|
|
||||||
|
|
||||||
affinity_naive, _ = calculate_affinity_matrix(
|
|
||||||
trackings,
|
|
||||||
camera_detections_next_batch,
|
|
||||||
w_2d=W_2D,
|
|
||||||
alpha_2d=ALPHA_2D,
|
|
||||||
w_3d=W_3D,
|
|
||||||
alpha_3d=ALPHA_3D,
|
|
||||||
lambda_a=LAMBDA_A,
|
|
||||||
)
|
|
||||||
display(camera_detections_next_batch)
|
|
||||||
display(affinity_naive)
|
|
||||||
|
|
||||||
debug_compare_affinity_matrices(
|
debug_compare_affinity_matrices(
|
||||||
trackings,
|
trackings,
|
||||||
camera_detections_next_batch,
|
camera_detections_next_batch,
|
||||||
@ -1209,11 +1245,3 @@ debug_compare_affinity_matrices(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
# Perform Hungarian algorithm for assignment for each camera
|
|
||||||
indices_T, indices_D = linear_sum_assignment(affinity, maximize=True)
|
|
||||||
indices_T = cast(Sequence[int], indices_T)
|
|
||||||
indices_D = cast(Sequence[int], indices_D)
|
|
||||||
display(indices_T)
|
|
||||||
display(indices_D)
|
|
||||||
|
|
||||||
# %%
|
|
||||||
|
|||||||
Reference in New Issue
Block a user