diff --git a/app/camera/__init__.py b/app/camera/__init__.py
index f09352b..68819c7 100644
--- a/app/camera/__init__.py
+++ b/app/camera/__init__.py
@@ -29,7 +29,6 @@ def undistort_points(
     return res.reshape(-1, 2)


-@jax.jit
 @jaxtyped(typechecker=beartype)
 def distortion(
     points_2d: Num[Array, "N 2"],
@@ -169,6 +168,9 @@ def unproject_points_to_z_plane(

     Returns:
         [..., 3] world-space intersection points
+
+    Note:
+        This function is not JAX-friendly, since it uses `cv2.undistortPoints` internally.
     """
     plane_normal = jnp.array([0.0, 0.0, 1.0])
     plane_point = jnp.array([0.0, 0.0, z])
@@ -406,6 +408,9 @@ class Camera:

         Returns:
             [..., 3] world-space intersection points
+
+        Note:
+            This function is not JAX-friendly, since it uses `cv2.undistortPoints` internally.
         """
         return unproject_points_to_z_plane(
             points_2d,
@@ -460,6 +465,32 @@ class Detection:
             object.__setattr__(self, "_kp_undistorted", kpu)
         return kpu

+    @property
+    def normalized_keypoints_undistorted(self) -> Num[Array, "N 2"]:
+        """
+        Returns keypoints normalized by the image size, with lens distortion removed.
+
+        The result is cached on first access (lazy evaluation).
+        """
+        norm_kps = getattr(self, "_norm_kps_undistorted", None)
+        if norm_kps is None:
+            norm_kps = self.keypoints_undistorted / self.camera.params.image_size
+            object.__setattr__(self, "_norm_kps_undistorted", norm_kps)
+        return norm_kps
+
+    @property
+    def normalized_keypoints(self) -> Num[Array, "N 2"]:
+        """
+        Returns keypoints normalized by the image size.
+
+        The result is cached on first access (lazy evaluation).
+        """
+        norm_kps = getattr(self, "_norm_kps", None)
+        if norm_kps is None:
+            norm_kps = self.keypoints / self.camera.params.image_size
+            object.__setattr__(self, "_norm_kps", norm_kps)
+        return norm_kps
+
     def __repr__(self) -> str:
         return f"Detection({self.camera}, {self.timestamp})"

diff --git a/playground.py b/playground.py
index dfdce25..6135f7a 100644
--- a/playground.py
+++ b/playground.py
@@ -746,16 +746,34 @@ def calculate_affinity_3d(
     return affinity_per_keypoint


+@jaxtyped(typechecker=beartype)
 def predict_pose_3d(
     tracking: Tracking,
     delta_t_s: float,
 ) -> Float[Array, "J 3"]:
     """
     Predict the 3D pose of a tracking based on its velocity.
+    JAX-friendly implementation: the optional velocity is resolved at trace time.
+
+    Args:
+        tracking: The tracking object containing keypoints and an optional velocity
+        delta_t_s: Time delta in seconds
+
+    Returns:
+        Predicted 3D pose keypoints, shape (J, 3)
     """
+    # ------------------------------------------------------------------
+    # Step 1 – decide velocity on the Python side
+    # ------------------------------------------------------------------
     if tracking.velocity is None:
-        return tracking.keypoints
-    return tracking.keypoints + tracking.velocity * delta_t_s
+        velocity = jnp.zeros_like(tracking.keypoints)  # (J, 3)
+    else:
+        velocity = tracking.velocity  # (J, 3)
+
+    # ------------------------------------------------------------------
+    # Step 2 – pure JAX math
+    # ------------------------------------------------------------------
+    return tracking.keypoints + velocity * delta_t_s


 @beartype
@@ -1017,17 +1035,17 @@ def calculate_camera_affinity_matrix_jax(
         # Return an empty affinity matrix with appropriate shape.
        return jnp.zeros((len(trackings), len(camera_detections)))  # type: ignore[return-value]

+    cam = next(iter(camera_detections)).camera
     # Ensure every detection truly belongs to the same camera (guard clause)
-    cam_id = camera_detections[0].camera.id
+    cam_id = cam.id
     if any(det.camera.id != cam_id for det in camera_detections):
         raise ValueError(
             "All detections passed to `calculate_camera_affinity_matrix` must come from one camera."
         )

     # We will rely on a single `Camera` instance (all detections share it)
-    cam = camera_detections[0].camera
-    w_img, h_img = cam.params.image_size
-    w_img, h_img = float(w_img), float(h_img)
+    w_img_, h_img_ = cam.params.image_size
+    w_img, h_img = float(w_img_), float(h_img_)

     # ------------------------------------------------------------------
     # Gather data into ndarray / DeviceArray batches so that we can compute
@@ -1038,6 +1056,7 @@ def calculate_camera_affinity_matrix_jax(
     kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
         [trk.keypoints for trk in trackings]
     )  # (T, J, 3)
+    J = kps3d_trk.shape[1]
     ts_trk = jnp.array(
         [trk.last_active_timestamp.timestamp() for trk in trackings], dtype=jnp.float32
     )  # (T,)
@@ -1094,22 +1113,33 @@ def calculate_camera_affinity_matrix_jax(
     ]  # each (J,3)
     backproj: Float[Array, "D J 3"] = jnp.stack(backproj_points_list)  # (D, J, 3)

-    # Predicted 3D pose for each tracking (no velocity yet ⇒ same as stored kps)
-    # shape (T, J, 3)
-    predicted_pose: Float[Array, "T J 3"] = kps3d_trk  # velocity handled outside
+    zero_velocity = jnp.zeros((J, 3))
+    trk_velocities = jnp.stack(
+        [
+            trk.velocity if trk.velocity is not None else zero_velocity
+            for trk in trackings
+        ]
+    )  # (T, J, 3)
+
+    predicted_pose: Float[Array, "T D J 3"] = (
+        kps3d_trk[:, None, :, :]  # (T, 1, J, 3)
+        + trk_velocities[:, None, :, :] * delta_t[:, :, None, None]  # (T, 1, J, 3) * (T, D, 1, 1)
+    )

     # Camera center – shape (3,) -> will broadcast
     cam_center = cam.params.location  # (3,)

-    # Compute perpendicular distance using vectorised formula
-    # distance = || (p2-p1) × (p1 - P) || / ||p2 - p1||
-    # p1 == cam_center, p2 == backproj, P == predicted_pose
-
-    v1 = backproj[None, :, :, :] - cam_center  # (1, D, J, 3)
-    v2 = cam_center - predicted_pose[:, None, :, :]  # (T, 1, J, 3)
+    # Compute perpendicular distance using the vectorized formula
+    # distance = || (P - p1) × (p2 - p1) || / ||p2 - p1||
+    # p1 = cam_center, p2 = backproj, P = predicted_pose
+    p1 = cam_center
+    p2 = backproj
+    P = predicted_pose
+    v1 = P - p1  # (T, D, J, 3)
+    v2 = p2[None, :, :, :] - p1  # (1, D, J, 3)
     cross = jnp.cross(v1, v2)  # (T, D, J, 3)
     num = jnp.linalg.norm(cross, axis=-1)  # (T, D, J)
-    den = jnp.linalg.norm(v1, axis=-1)  # (1, D, J)
+    den = jnp.linalg.norm(v2, axis=-1)  # (1, D, J)
     dist3d: Float[Array, "T D J"] = num / den

     affinity_3d = (
@@ -1136,7 +1167,7 @@ unmatched_detections = shallow_copy(next_group)
 camera_detections = classify_by_camera(unmatched_detections)
 camera_detections_next_batch = camera_detections["AE_08"]

-affinity = calculate_camera_affinity_matrix(
+affinity = calculate_camera_affinity_matrix_jax(
     trackings,
     camera_detections_next_batch,
     w_2d=W_2D,
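Not part of the patch: a quick numerical check of the point-to-line distance used above. For the ray through the camera center p1 and a backprojected point p2, the perpendicular distance of a predicted keypoint P is || (P - p1) × (p2 - p1) || / ||p2 - p1||. The coordinates below are made up so the answer is easy to verify by hand:

import jax.numpy as jnp

p1 = jnp.array([0.0, 0.0, 0.0])  # camera center: the ray is the z-axis
p2 = jnp.array([0.0, 0.0, 5.0])  # backprojected point on the ray
P = jnp.array([3.0, 4.0, 2.0])   # predicted keypoint

num = jnp.linalg.norm(jnp.cross(P - p1, p2 - p1))
den = jnp.linalg.norm(p2 - p1)
print(num / den)  # 5.0 == sqrt(3**2 + 4**2), the xy-distance of P from the z-axis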
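Likewise, a minimal sketch of the broadcasting that yields the (T, D, J, 3) predicted poses; the sizes T=2 trackings, D=3 detections, J=17 joints are hypothetical:

import jax.numpy as jnp

T, D, J = 2, 3, 17                    # hypothetical sizes
kps3d_trk = jnp.zeros((T, J, 3))      # per-tracking keypoints
trk_velocities = jnp.ones((T, J, 3))  # zero-filled where velocity is None
delta_t = jnp.ones((T, D))            # per (tracking, detection) time gaps

predicted_pose = (
    kps3d_trk[:, None, :, :]           # (T, 1, J, 3)
    + trk_velocities[:, None, :, :]    # (T, 1, J, 3)
    * delta_t[:, :, None, None]        # (T, D, 1, 1)
)
assert predicted_pose.shape == (T, D, J, 3)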
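Finally, a sketch of why `predict_pose_3d` can keep the `if tracking.velocity is None` branch and still be JAX-friendly: the check runs in Python while the function is traced, so no data-dependent branching enters the computation graph. `Stub` is a hypothetical stand-in for the real `Tracking`:

from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp

class Stub(NamedTuple):  # hypothetical stand-in for Tracking
    keypoints: jax.Array           # (J, 3)
    velocity: Optional[jax.Array]  # (J, 3) or None

def predict(trk: Stub, dt: float) -> jax.Array:
    # Python-side branch, resolved once at trace time
    vel = trk.velocity if trk.velocity is not None else jnp.zeros_like(trk.keypoints)
    return trk.keypoints + vel * dt  # pure JAX math

# jit retraces when the pytree structure changes (velocity None vs. array),
# but each trace contains only pure JAX operations.
print(jax.jit(predict)(Stub(jnp.ones((17, 3)), None), 0.1).shape)  # (17, 3)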