From 29c8ef399091ec1a5f4a6d9913952813bcd3e54b Mon Sep 17 00:00:00 2001 From: crosstyan Date: Tue, 29 Apr 2025 12:56:58 +0800 Subject: [PATCH] fix: fix the timestamp precision error that caused the jax version to not give the correct result MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out sub‑second detail (resolution ≈ 200 ms). Keep them in float64 until after subtraction so we preserve Δt‑on‑the‑order‑of‑milliseconds. - Introduced a `_DEBUG_CURRENT_TRACKING` variable to track the current indices of tracking and detection during calculations. - Added a `_global_current_tracking_str` function to format the current tracking state for debugging purposes. - Enhanced `calculate_distance_2d` and `calculate_tracking_detection_affinity` functions with debug print statements to log intermediate values, improving traceability of calculations. - Updated `perpendicular_distance_camera_2d_points_to_tracking_raycasting` to accept `delta_t` from the caller while ensuring it adheres to a minimum threshold. - Refactored timestamp handling in `calculate_camera_affinity_matrix_jax` to maintain precision during calculations. 
--- playground.py | 122 +++++++++++++++++++++++++++++++------------------- 1 file changed, 75 insertions(+), 47 deletions(-) diff --git a/playground.py b/playground.py index 2531df0..9ff3b69 100644 --- a/playground.py +++ b/playground.py @@ -68,6 +68,12 @@ AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parq DELTA_T_MIN = timedelta(milliseconds=10) display(AK_CAMERA_DATASET) +_DEBUG_CURRENT_TRACKING: tuple[int, int] = (0, 0) + + +def _global_current_tracking_str(): + return str(_DEBUG_CURRENT_TRACKING) + # %% class Resolution(TypedDict): @@ -586,7 +592,25 @@ def calculate_distance_2d( else: left_normalized = left / jnp.array([w, h]) right_normalized = right / jnp.array([w, h]) - return jnp.linalg.norm(left_normalized - right_normalized, axis=-1) + dist = jnp.linalg.norm(left_normalized - right_normalized, axis=-1) + lt = left_normalized[:6] + rt = right_normalized[:6] + jax.debug.print( + "[REF]{} norm_trk first6 = {}", + _global_current_tracking_str(), + lt, + ) + jax.debug.print( + "[REF]{} norm_det first6 = {}", + _global_current_tracking_str(), + rt, + ) + jax.debug.print( + "[REF]{} dist2d first6 = {}", + _global_current_tracking_str(), + dist[:6], + ) + return dist @jaxtyped(typechecker=beartype) @@ -652,6 +676,7 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting( delta_t: timedelta, ) -> Float[Array, "J"]: """ + NOTE: `delta_t` is now taken from the caller and NOT recomputed internally. Calculate the perpendicular distances between predicted 3D tracking points and the rays cast from camera center through the 2D image points. @@ -664,10 +689,9 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting( Array of perpendicular distances for each keypoint """ camera = detection.camera - assert detection.timestamp >= tracking.last_active_timestamp - delta_t_raw = detection.timestamp - tracking.last_active_timestamp - # Clamp delta_t to avoid division-by-zero / exploding affinity. 
- delta_t = max(delta_t_raw, DELTA_T_MIN) + # Use the delta_t supplied by the caller, but clamp to DELTA_T_MIN to + # avoid division-by-zero / exploding affinities. + delta_t = max(delta_t, DELTA_T_MIN) delta_t_s = delta_t.total_seconds() predicted_pose = tracking.predict(delta_t_s) @@ -781,6 +805,19 @@ def calculate_tracking_detection_affinity( lambda_a=lambda_a, ) + jax.debug.print( + "[REF] aff2d{} first6 = {}", + _global_current_tracking_str(), + affinity_2d[:6], + ) + jax.debug.print( + "[REF] aff3d{} first6 = {}", _global_current_tracking_str(), affinity_3d[:6] + ) + jax.debug.print( + "[REF] aff2d.shape={}; aff3d.shape={}", + affinity_2d.shape, + affinity_3d.shape, + ) # Combine affinities total_affinity = affinity_2d + affinity_3d return jnp.sum(total_affinity).item() @@ -938,6 +975,8 @@ def calculate_camera_affinity_matrix( for i, tracking in enumerate(trackings): for j, det in enumerate(camera_detections): + global _DEBUG_CURRENT_TRACKING + _DEBUG_CURRENT_TRACKING = (i, j) affinity_value = calculate_tracking_detection_affinity( tracking, det, @@ -1004,24 +1043,35 @@ def calculate_camera_affinity_matrix_jax( [trk.keypoints for trk in trackings] ) # (T, J, 3) J = kps3d_trk.shape[1] - ts_trk = jnp.array( - [trk.last_active_timestamp.timestamp() for trk in trackings], dtype=jnp.float32 - ) # (T,) - # === Detection-side tensors === kps2d_det: Float[Array, "D J 2"] = jnp.stack( [det.keypoints for det in camera_detections] ) # (D, J, 2) - ts_det = jnp.array( - [det.timestamp.timestamp() for det in camera_detections], dtype=jnp.float32 - ) # (D,) # ------------------------------------------------------------------ # Compute Δt matrix – shape (T, D) # ------------------------------------------------------------------ - delta_t = ts_det[None, :] - ts_trk[:, None] # broadcasting, (T, D) + # Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out + # sub‑second detail (resolution ≈ 200 ms). 
Keep them in float64 until + # after subtraction so we preserve Δt‑on‑the‑order‑of‑milliseconds. + # --- timestamps ---------- + tracking0 = next(iter(trackings)) + detection0 = next(iter(camera_detections)) + t0 = min( + tracking0.last_active_timestamp, detection0.timestamp + ).timestamp() # common origin (float) + ts_trk = jnp.array( + [trk.last_active_timestamp.timestamp() - t0 for trk in trackings], + dtype=jnp.float32, # now small, ms-scale fits in fp32 + ) + ts_det = jnp.array( + [det.timestamp.timestamp() - t0 for det in camera_detections], + dtype=jnp.float32, + ) + # Δt in seconds, fp32 throughout + delta_t = ts_det[None, :] - ts_trk[:, None] # (T,D) min_dt_s = float(DELTA_T_MIN.total_seconds()) - delta_t = jnp.clip(delta_t, a_min=min_dt_s, a_max=None) # ensure ≥ DELTA_T_MIN + delta_t = jnp.clip(delta_t, a_min=min_dt_s, a_max=None) # ------------------------------------------------------------------ # ---------- 2D affinity ------------------------------------------- @@ -1041,6 +1091,12 @@ def calculate_camera_affinity_matrix_jax( diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :] dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1) + jax.debug.print( + "[JAX] norm_trk[0,0,:6] = {}", norm_trk[0, :, :6] # shape (J,2) 取前6 + ) + jax.debug.print("[JAX] norm_det[0,:6] = {}", norm_det[0, :6]) # shape (J,2) + jax.debug.print("[JAX] dist2d(T0,D0) first6 = {}", dist2d[0, 0, :6]) + # Compute per-keypoint 2D affinity delta_t_broadcast = delta_t[:, :, None] # (T, D, 1) affinity_2d = ( @@ -1097,7 +1153,11 @@ def calculate_camera_affinity_matrix_jax( w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast) ) - jax.debug.print("Shapes: dist2d {} dist3d {}", dist2d.shape, dist3d.shape) + jax.debug.print("[JAX] aff3d(T0,D0) first6 = {}", affinity_3d[0, 0, :6]) + jax.debug.print("[JAX] aff2d(T0,D0) first6 = {}", affinity_2d[0, 0, :6]) + jax.debug.print( + "[JAX] aff2d.shape={}; aff3d.shape={}", affinity_2d.shape, affinity_3d.shape + ) # 
------------------------------------------------------------------ # Combine and reduce across keypoints → (T, D) # ------------------------------------------------------------------ @@ -1174,30 +1234,6 @@ unmatched_detections = shallow_copy(next_group) camera_detections = classify_by_camera(unmatched_detections) camera_detections_next_batch = camera_detections["AE_08"] -affinity = calculate_camera_affinity_matrix_jax( - trackings, - camera_detections_next_batch, - w_2d=W_2D, - alpha_2d=ALPHA_2D, - w_3d=W_3D, - alpha_3d=ALPHA_3D, - lambda_a=LAMBDA_A, -) -display(camera_detections_next_batch) -display(affinity) - -affinity_naive, _ = calculate_affinity_matrix( - trackings, - camera_detections_next_batch, - w_2d=W_2D, - alpha_2d=ALPHA_2D, - w_3d=W_3D, - alpha_3d=ALPHA_3D, - lambda_a=LAMBDA_A, -) -display(camera_detections_next_batch) -display(affinity_naive) - debug_compare_affinity_matrices( trackings, camera_detections_next_batch, @@ -1209,11 +1245,3 @@ debug_compare_affinity_matrices( ) # %% -# Perform Hungarian algorithm for assignment for each camera -indices_T, indices_D = linear_sum_assignment(affinity, maximize=True) -indices_T = cast(Sequence[int], indices_T) -indices_D = cast(Sequence[int], indices_D) -display(indices_T) -display(indices_D) - -# %%