From b3ed20296a816e62c622f5cf5365c134ff4c83f6 Mon Sep 17 00:00:00 2001
From: crosstyan
Date: Mon, 28 Apr 2025 18:01:24 +0800
Subject: [PATCH] revert

---
 app/camera/__init__.py |  14 +--
 playground.py          | 232 +++++++++++++++++------------------
 2 files changed, 103 insertions(+), 143 deletions(-)

diff --git a/app/camera/__init__.py b/app/camera/__init__.py
index b583566..f09352b 100644
--- a/app/camera/__init__.py
+++ b/app/camera/__init__.py
@@ -227,11 +227,13 @@ def project(

             # Fall back to normalized coordinates if image_size not provided
             valid = jnp.all(p2d >= 0, axis=1) & jnp.all(p2d < 1, axis=1)
-            # only valid points need distortion
-            if jnp.any(valid):
-                valid_p2d = p2d[valid]
-                distorted_valid = distortion(valid_p2d, K, dist_coeffs)
-                p2d = p2d.at[valid].set(distorted_valid)
+            # Distort *all* points, then blend results using `where` to keep
+            # numerical traces inside JAX – this avoids Python ``if`` with a traced
+            # value (which triggers TracerBoolConversionError when the function is
+            # vmapped/jitted).
+            distorted_all = distortion(p2d, K, dist_coeffs)
+            # Broadcast the valid mask over the last (x,y) dimension
+            p2d = jnp.where(valid[:, None], distorted_all, p2d)
     elif dist_coeffs is None and K is None:
         pass
     else:
@@ -239,7 +241,7 @@ def project(
             "dist_coeffs and K must be provided together to compute distortion"
         )

-    return jnp.squeeze(p2d)
+    return p2d  # type: ignore


 @jaxtyped(typechecker=beartype)
diff --git a/playground.py b/playground.py
index 8e9133a..e2f33f1 100644
--- a/playground.py
+++ b/playground.py
@@ -46,6 +46,7 @@
 from jaxtyping import Array, Float, Num, jaxtyped
 from matplotlib import pyplot as plt
 from numpy.typing import ArrayLike
 from scipy.optimize import linear_sum_assignment
+from functools import partial, reduce
 from scipy.spatial.transform import Rotation as R
 from typing_extensions import deprecated
@@ -65,6 +66,7 @@
 NDArray: TypeAlias = np.ndarray

 # %%
 DATASET_PATH = Path("samples") / "04_02"
 AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet")
+DELTA_T_MIN = timedelta(milliseconds=10)
 display(AK_CAMERA_DATASET)
@@ -388,6 +390,16 @@ def flatten_values(
     return [v for vs in d.values() for v in vs]


+def flatten_values_len(
+    d: Mapping[Any, Sequence[T]],
+) -> int:
+    """
+    Return the total number of values across all sequences in the mapping.
+    """
+    val = reduce(lambda acc, xs: acc + len(xs), d.values(), 0)
+    return val
+
+
 # %%
 WIDTH = 2560
 HEIGHT = 1440
@@ -676,6 +688,10 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
         Array of perpendicular distances for each keypoint
     """
     camera = detection.camera
+    assert detection.timestamp >= tracking.last_active_timestamp
+    delta_t_raw = detection.timestamp - tracking.last_active_timestamp
+    # Clamp delta_t to avoid division-by-zero / exploding affinity.
+    delta_t = max(delta_t_raw, DELTA_T_MIN)
     delta_t_s = delta_t.total_seconds()

     predicted_pose = predict_pose_3d(tracking, delta_t_s)
@@ -769,7 +785,9 @@ def calculate_tracking_detection_affinity(
         Combined affinity score
     """
     camera = detection.camera
-    delta_t = detection.timestamp - tracking.last_active_timestamp
+    delta_t_raw = detection.timestamp - tracking.last_active_timestamp
+    # Clamp delta_t to avoid division-by-zero / exploding affinity.
+    delta_t = max(delta_t_raw, DELTA_T_MIN)

     # Calculate 2D affinity
     tracking_2d_projection = camera.project(tracking.keypoints)
@@ -811,7 +829,7 @@ def calculate_tracking_detection_affinity(
 @beartype
 def calculate_affinity_matrix(
     trackings: Sequence[Tracking],
-    detections: Sequence[Detection],
+    detections: Sequence[Detection] | OrderedDict[CameraID, list[Detection]],
     w_2d: float,
     alpha_2d: float,
     w_3d: float,
@@ -863,8 +881,13 @@
     through the detection_by_camera dictionary, which is returned alongside
     the matrix to maintain this relationship.
     """
-    affinity = jnp.zeros((len(trackings), len(detections)))
-    detection_by_camera = classify_by_camera(detections)
+    if isinstance(detections, OrderedDict):
+        D = flatten_values_len(detections)
+        affinity = jnp.zeros((len(trackings), D))
+        detection_by_camera = detections
+    else:
+        affinity = jnp.zeros((len(trackings), len(detections)))
+        detection_by_camera = classify_by_camera(detections)

     for i, tracking in enumerate(trackings):
         j = 0
@@ -896,152 +919,73 @@ def calculate_camera_affinity_matrix(
     lambda_a: float,
 ) -> Float[Array, "T D"]:
     """
-    Vectorized version (with JAX) that computes the affinity matrix between a set
-    of *trackings* and *detections* coming from **one** camera.
+    Calculate an affinity matrix between trackings and detections from a single camera.

-    The whole computation is done with JAX array operations and `vmap` – no
-    explicit Python ``for``-loops over the (T, D) pairs. This makes the routine
-    fully parallelisable on CPU/GPU/TPU without any extra `jit` compilation.
+    This follows the iterative camera-by-camera approach from the paper
+    "Cross-View Tracking for Multi-Human 3D Pose Estimation at over 100 FPS".
+    Instead of creating one large matrix for all cameras, this creates
+    a separate matrix for each camera, which can be processed independently.

-    Args
-    -----
-    trackings : Sequence[Tracking]
-        Existing 3-D track states (length = T)
-    camera_detections : Sequence[Detection]
-        Detections from *a single* camera (length = D). All detections **must**
-        share the same ``detection.camera`` instance.
-    w_2d, alpha_2d, w_3d, alpha_3d, lambda_a : float
-        Hyper-parameters exactly as defined in the paper (and earlier helper
-        functions).
+    Args:
+        trackings: Sequence of tracking objects
+        camera_detections: Sequence of detection objects, from the same camera
+        w_2d: Weight for 2D affinity
+        alpha_2d: Normalization factor for 2D distance
+        w_3d: Weight for 3D affinity
+        alpha_3d: Normalization factor for 3D distance
+        lambda_a: Decay rate for time difference

-    Returns
-    -------
-    affinity : jnp.ndarray (T x D)
-        Affinity matrix between each tracking (row) and detection (column).
+    Returns:
+        Affinity matrix of shape (T, D) where:
+        - T = number of trackings (rows)
+        - D = number of detections from this specific camera (columns)

-    Matrix Layout
-    -------
-    The affinity matrix for a single camera has shape (T, D), where:
-    - T = number of trackings (rows)
-    - D = number of detections from this camera (columns)
+    Matrix Layout:
+        The affinity matrix for a single camera has shape (T, D), where:
+        - T = number of trackings (rows)
+        - D = number of detections from this camera (columns)

-    The matrix is organized as follows:
+        The matrix is organized as follows:

-    ```
-    |        Detections from Camera c |
-    |        d1    d2    d3    ...    |
-    ---------+------------------------+
-    Track 1  | a11   a12   a13   ...  |
-    Track 2  | a21   a22   a23   ...  |
-    ...      | ...   ...   ...   ...  |
-    Track t  | at1   at2   at3   ...  |
-    ```
+        ```
+        |        Detections from Camera c |
+        |        d1    d2    d3    ...    |
+        ---------+------------------------+
+        Track 1  | a11   a12   a13   ...  |
+        Track 2  | a21   a22   a23   ...  |
+        ...      | ...   ...   ...   ...  |
+        Track t  | at1   at2   at3   ...  |
+        ```

-    Each cell aij represents the affinity between tracking i and detection j,
-    computed using both 2D and 3D geometric correspondences.
+        Each cell aij represents the affinity between tracking i and detection j,
+        computed using both 2D and 3D geometric correspondences.
     """
-    # ---------- Safety checks & early exits --------------------------------
-    if len(trackings) == 0 or len(camera_detections) == 0:
-        return jnp.zeros((len(trackings), len(camera_detections)))  # pragma: no cover
+    def verify_all_detection_from_same_camera(detections: Sequence[Detection]):
+        if not detections:
+            return True
+        camera_id = next(iter(detections)).camera.id
+        return all(map(lambda d: d.camera.id == camera_id, detections))

-    # Ensure all detections come from the *same* camera
-    cam_id_ref = camera_detections[0].camera.id
-    if any(det.camera.id != cam_id_ref for det in camera_detections):
-        raise ValueError(
-            "All detections given to calculate_camera_affinity_matrix must come from the same camera."
-        )
+    if not verify_all_detection_from_same_camera(camera_detections):
+        raise ValueError("All detections must be from the same camera")

-    camera = camera_detections[0].camera  # shared camera object
-    cam_w, cam_h = map(int, camera.params.image_size)
-    cam_center = camera.params.location  # (3,)
+    affinity = jnp.zeros((len(trackings), len(camera_detections)))

-    # ---------- Pack tracking data into JAX arrays -------------------------
-    # (T, J, 3)
-    track_kps_3d = jnp.stack([trk.keypoints for trk in trackings])
-
-    # (T, 3) velocity – zero if None
-    velocities = jnp.stack(
-        [
-            (
-                trk.velocity
-                if trk.velocity is not None
-                else jnp.zeros(3, dtype=jnp.float32)
+    for i, tracking in enumerate(trackings):
+        for j, det in enumerate(camera_detections):
+            affinity_value = calculate_tracking_detection_affinity(
+                tracking,
+                det,
+                w_2d=w_2d,
+                alpha_2d=alpha_2d,
+                w_3d=w_3d,
+                alpha_3d=alpha_3d,
+                lambda_a=lambda_a,
             )
-            for trk in trackings
-        ]
-    )
+            affinity = affinity.at[i, j].set(affinity_value)

-    # (T,) last update timestamps (float seconds)
-    track_last_ts = jnp.array(
-        [trk.last_active_timestamp.timestamp() for trk in trackings]
-    )
-
-    # Pre-project 3-D tracking points into 2-D for *this* camera – (T, J, 2)
-    track_proj_2d = jax.vmap(camera.project)(track_kps_3d)
-
-    # ---------- Pack detection data ----------------------------------------
-    # (D, J, 2)
-    det_kps_2d = jnp.stack([det.keypoints for det in camera_detections])
-
-    # (D,) detection timestamps (float seconds)
-    det_ts = jnp.array([det.timestamp.timestamp() for det in camera_detections])
-
-    # Back-project detection 2-D points to the z=0 plane in world coords – (D, J, 3)
-    det_backproj_3d = camera.unproject_points_to_z_plane(det_kps_2d, z=0.0)
-
-    # ---------- Broadcast / compute pair-wise quantities --------------------
-    # Time differences Δt (T, D) – always non-negative because detections are newer
-    delta_t = jnp.maximum(det_ts[None, :] - track_last_ts[:, None], 0.0)
-
-    # ---------- 2-D affinity --------------------------------------------------
-    # Normalise 2-D points by image size (already handled in helper but easier here)
-    track_proj_norm = track_proj_2d / jnp.array([cam_w, cam_h])  # (T, J, 2)
-    det_kps_norm = det_kps_2d / jnp.array([cam_w, cam_h])  # (D, J, 2)
-
-    # (T, D, J) Euclidean distances in normalised image space
-    dist_2d = jnp.linalg.norm(
-        track_proj_norm[:, None, :, :] - det_kps_norm[None, :, :, :],
-        axis=-1,
-    )
-
-    # (T, D, 1) for broadcasting with J dimension
-    delta_t_exp = delta_t[:, :, None]
-
-    affinity_2d_per_kp = (
-        w_2d
-        * (1.0 - dist_2d / (alpha_2d * jnp.clip(delta_t_exp, a_min=1e-6)))
-        * jnp.exp(-lambda_a * delta_t_exp)
-    )
-    affinity_2d = jnp.sum(affinity_2d_per_kp, axis=-1)  # (T, D)
-
-    # ---------- 3-D affinity --------------------------------------------------
-    # Predict 3-D pose at detection time for each (T, D) pair – (T, D, J, 3)
-    predicted_pose = (
-        track_kps_3d[:, None, :, :]
-        + velocities[:, None, None, :] * delta_t_exp[..., None]
-    )
-
-    # Camera ray for each detection/keypoint – (1, D, J, 3)
-    line_vec = det_backproj_3d[None, :, :, :] - cam_center  # broadcast (T, D, J, 3)
-
-    # Vector from camera centre to predicted point – (T, D, J, 3)
-    vec_cam_to_pred = cam_center - predicted_pose
-
-    # Cross-product norm and distance
-    cross_prod = jnp.cross(line_vec, vec_cam_to_pred)
-    numer = jnp.linalg.norm(cross_prod, axis=-1)  # (T, D, J)
-    denom = jnp.linalg.norm(line_vec, axis=-1)  # (1, D, J) broadcast automatically
-    dist_3d = numer / jnp.clip(denom, a_min=1e-6)
-
-    affinity_3d_per_kp = (
-        w_3d * (1.0 - dist_3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_exp)
-    )
-    affinity_3d = jnp.sum(affinity_3d_per_kp, axis=-1)  # (T, D)
-
-    # ---------- Final affinity ----------------------------------------------
-    affinity_total = affinity_2d + affinity_3d  # (T, D)
-    return affinity_total
+    return affinity


 # %%
@@ -1056,17 +1000,31 @@
 trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
 unmatched_detections = shallow_copy(next_group)
 camera_detections = classify_by_camera(unmatched_detections)
+camera_detections_next_batch = camera_detections["AE_08"]
 affinity = calculate_camera_affinity_matrix(
     trackings,
-    next(iter(camera_detections.values())),
+    camera_detections_next_batch,
     w_2d=W_2D,
     alpha_2d=ALPHA_2D,
     w_3d=W_3D,
     alpha_3d=ALPHA_3D,
     lambda_a=LAMBDA_A,
 )
+display(camera_detections_next_batch)
 display(affinity)

+affinity_naive, _ = calculate_affinity_matrix(
+    trackings,
+    camera_detections,
+    w_2d=W_2D,
+    alpha_2d=ALPHA_2D,
+    w_3d=W_3D,
+    alpha_3d=ALPHA_3D,
+    lambda_a=LAMBDA_A,
+)
+display(camera_detections)
+display(affinity_naive)
+
 # %%
 # Perform Hungarian algorithm for assignment for each camera
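
For reference, two short sketches follow; neither is part of the patch above.

First, the mask-and-blend pattern that the `app/camera/__init__.py` hunk switches to, shown standalone: both branches are computed for every point and `jnp.where` selects per point, so the function stays traceable under `jax.vmap`/`jax.jit`, whereas a Python `if jnp.any(valid)` on a traced value raises `TracerBoolConversionError`. The `masked_transform` helper and the dummy scaling that stands in for `distortion` are illustrative only.

```
import jax
import jax.numpy as jnp


def masked_transform(p2d: jnp.ndarray) -> jnp.ndarray:
    # (N,) mask of points inside the normalized [0, 1) square
    valid = jnp.all(p2d >= 0, axis=1) & jnp.all(p2d < 1, axis=1)
    # Stand-in for distortion(p2d, K, dist_coeffs): transform *all* points
    transformed = p2d * 0.5
    # Blend per point; the mask broadcasts over the trailing (x, y) axis
    return jnp.where(valid[:, None], transformed, p2d)


# Works unchanged over a batch of point sets, e.g. shape (B, N, 2)
batched = jax.vmap(masked_transform)(jnp.zeros((4, 17, 2)))
```

Second, a minimal sketch of the per-camera assignment step referenced by the closing comment, assuming `calculate_camera_affinity_matrix`, `classify_by_camera`, and the `W_2D`/`ALPHA_2D`/`W_3D`/`ALPHA_3D`/`LAMBDA_A` constants defined in `playground.py` above; `assign_detections_per_camera` and `AFFINITY_THRESHOLD` are hypothetical names, not from the patch.

```
import numpy as np
from scipy.optimize import linear_sum_assignment

AFFINITY_THRESHOLD = 0.0  # hypothetical cut-off for accepting a match


def assign_detections_per_camera(trackings, detections_by_camera):
    """Solve one Hungarian assignment per camera on its (T, D) affinity matrix."""
    matches = {}  # camera_id -> list of (tracking_index, detection_index)
    for camera_id, camera_detections in detections_by_camera.items():
        if not trackings or not camera_detections:
            continue
        affinity = calculate_camera_affinity_matrix(
            trackings,
            camera_detections,
            w_2d=W_2D,
            alpha_2d=ALPHA_2D,
            w_3d=W_3D,
            alpha_3d=ALPHA_3D,
            lambda_a=LAMBDA_A,
        )
        # linear_sum_assignment minimises total cost, so negate the affinity
        rows, cols = linear_sum_assignment(-np.asarray(affinity))
        matches[camera_id] = [
            (int(r), int(c))
            for r, c in zip(rows, cols)
            if float(affinity[r, c]) > AFFINITY_THRESHOLD
        ]
    return matches
```

Row and column indices map back to `trackings` and the per-camera detection list in the order used to build the matrix, matching the layout described in the `calculate_camera_affinity_matrix` docstring.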