feat: Enhance 3D pose prediction and camera affinity calculations

- Introduced a JAX-friendly implementation of `predict_pose_3d` that predicts 3D poses from tracking velocities, resolving the optional velocity in plain Python so the remaining computation is pure JAX math.
- Updated `calculate_camera_affinity_matrix_jax` to streamline the affinity matrix calculation, adding per-tracking velocity handling and clarifying the perpendicular-distance computation.
- Added lazily evaluated, cached properties for normalized keypoints to the `Detection` class.
- Enhanced documentation throughout to clarify function purposes and parameters.
commit 86fcc5f283 (parent 6194f083cb)
Date: 2025-04-29 12:06:07 +08:00
2 changed files with 80 additions and 18 deletions

File 1 of 2:

@@ -29,7 +29,6 @@ def undistort_points(
     return res.reshape(-1, 2)
 
-@jax.jit
 @jaxtyped(typechecker=beartype)
 def distortion(
     points_2d: Num[Array, "N 2"],
@@ -169,6 +168,9 @@ def unproject_points_to_z_plane(
     Returns:
         [..., 3] world-space intersection points
+
+    Note:
+        This function is not JAX-friendly, since it uses `cv2.undistortPoints` internally.
     """
     plane_normal = jnp.array([0.0, 0.0, 1.0])
     plane_point = jnp.array([0.0, 0.0, z])
@@ -406,6 +408,9 @@ class Camera:
         Returns:
             [..., 3] world-space intersection points
+
+        Note:
+            This function is not JAX-friendly, since it uses `cv2.undistortPoints` internally.
         """
         return unproject_points_to_z_plane(
             points_2d,
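For context on the note above: `cv2.undistortPoints` runs on host NumPy arrays, so it cannot be traced by `jax.jit`. If a traceable variant were ever needed, one option is `jax.pure_callback`, which escapes the trace and calls back into host code at run time. A minimal sketch, assuming identity intrinsics and zero distortion purely for illustration (not the project's actual camera parameters):

    import cv2
    import jax
    import jax.numpy as jnp
    import numpy as np

    def _undistort_np(pts: np.ndarray) -> np.ndarray:
        # Host-side OpenCV call; K and dist are placeholder intrinsics.
        K = np.eye(3, dtype=np.float32)
        dist = np.zeros(5, dtype=np.float32)
        out = cv2.undistortPoints(pts.reshape(-1, 1, 2).astype(np.float32), K, dist)
        return out.reshape(-1, 2)

    @jax.jit
    def undistort_traceable(pts):
        # pure_callback runs _undistort_np on the host at execution time.
        result_shape = jax.ShapeDtypeStruct(pts.shape, jnp.float32)
        return jax.pure_callback(_undistort_np, result_shape, pts)

The host round-trip usually costs more than it saves, which is presumably why the code simply documents the limitation instead.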
@@ -460,6 +465,32 @@ class Detection:
             object.__setattr__(self, "_kp_undistorted", kpu)
         return kpu
 
+    @property
+    def normalized_keypoints_undistorted(self) -> Num[Array, "N 2"]:
+        """
+        Returns normalized keypoints, without distortion.
+
+        The result is cached on first access (lazy evaluation).
+        """
+        norm_kps = getattr(self, "_norm_kps_undistorted", None)
+        if norm_kps is None:
+            norm_kps = self.keypoints_undistorted / self.camera.params.image_size
+            object.__setattr__(self, "_norm_kps_undistorted", norm_kps)
+        return norm_kps
+
+    @property
+    def normalized_keypoints(self) -> Num[Array, "N 2"]:
+        """
+        Returns normalized keypoints.
+
+        The result is cached on first access (lazy evaluation).
+        """
+        norm_kps = getattr(self, "_norm_kps", None)
+        if norm_kps is None:
+            norm_kps = self.keypoints / self.camera.params.image_size
+            object.__setattr__(self, "_norm_kps", norm_kps)
+        return norm_kps
+
     def __repr__(self) -> str:
         return f"Detection({self.camera}, {self.timestamp})"
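The two new properties follow the same memoization idiom as `keypoints_undistorted` above: because `Detection` appears to be a frozen dataclass, the cache slot must be written with `object.__setattr__`, which bypasses the frozen check. A self-contained sketch of the pattern, with hypothetical field names:

    from dataclasses import dataclass

    import jax.numpy as jnp

    @dataclass(frozen=True)
    class Point:
        xy: jnp.ndarray   # raw pixel coordinates, shape (N, 2)
        scale: float      # e.g. the image width

        @property
        def normalized(self) -> jnp.ndarray:
            # Frozen dataclasses forbid normal attribute assignment,
            # so the cached value is stored via object.__setattr__.
            cached = getattr(self, "_normalized", None)
            if cached is None:
                cached = self.xy / self.scale
                object.__setattr__(self, "_normalized", cached)
            return cached

    p = Point(jnp.array([[10.0, 20.0]]), scale=100.0)
    p.normalized  # computed on first access
    p.normalized  # returned from the cache afterwards

This mirrors exactly what the diff does with the `_norm_kps` and `_norm_kps_undistorted` slots.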

File 2 of 2:

@@ -746,16 +746,34 @@ def calculate_affinity_3d(
     return affinity_per_keypoint
 
 
-@jaxtyped(typechecker=beartype)
 def predict_pose_3d(
     tracking: Tracking,
     delta_t_s: float,
 ) -> Float[Array, "J 3"]:
     """
     Predict the 3D pose of a tracking based on its velocity.
+
+    JAX-friendly implementation that avoids Python control flow on traced values.
+
+    Args:
+        tracking: The tracking object containing keypoints and optional velocity
+        delta_t_s: Time delta in seconds
+
+    Returns:
+        Predicted 3D pose keypoints
     """
+    # ------------------------------------------------------------------
+    # Step 1: decide velocity on the Python side
+    # ------------------------------------------------------------------
     if tracking.velocity is None:
-        return tracking.keypoints
-    return tracking.keypoints + tracking.velocity * delta_t_s
+        velocity = jnp.zeros_like(tracking.keypoints)  # (J, 3)
+    else:
+        velocity = tracking.velocity  # (J, 3)
+
+    # ------------------------------------------------------------------
+    # Step 2: pure JAX math
+    # ------------------------------------------------------------------
+    return tracking.keypoints + velocity * delta_t_s
 
 
 @beartype
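Because the `None` check in `predict_pose_3d` runs in ordinary Python, the branch is resolved before any JAX tracing happens, and the returned expression is a single array op. A usage sketch with a minimal stand-in `Tracking` (the real class certainly carries more state):

    from dataclasses import dataclass
    from typing import Optional

    import jax.numpy as jnp

    @dataclass
    class Tracking:  # stand-in for this sketch only
        keypoints: jnp.ndarray            # (J, 3)
        velocity: Optional[jnp.ndarray]   # (J, 3) or None

    def predict_pose_3d(tracking: Tracking, delta_t_s: float) -> jnp.ndarray:
        velocity = (
            jnp.zeros_like(tracking.keypoints)
            if tracking.velocity is None
            else tracking.velocity
        )
        return tracking.keypoints + velocity * delta_t_s

    still = Tracking(jnp.zeros((17, 3)), velocity=None)
    moving = Tracking(jnp.zeros((17, 3)), velocity=jnp.ones((17, 3)))
    print(predict_pose_3d(still, 0.1))   # all zeros: no velocity, pose unchanged
    print(predict_pose_3d(moving, 0.1))  # all 0.1: pose advanced by v * dt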
@@ -1017,17 +1035,17 @@ def calculate_camera_affinity_matrix_jax(
         # Return an empty affinity matrix with appropriate shape.
         return jnp.zeros((len(trackings), len(camera_detections)))  # type: ignore[return-value]
 
+    cam = next(iter(camera_detections)).camera
+
     # Ensure every detection truly belongs to the same camera (guard clause)
-    cam_id = camera_detections[0].camera.id
+    cam_id = cam.id
     if any(det.camera.id != cam_id for det in camera_detections):
         raise ValueError(
             "All detections passed to `calculate_camera_affinity_matrix` must come from one camera."
         )
 
     # We will rely on a single `Camera` instance (all detections share it)
-    cam = camera_detections[0].camera
-    w_img, h_img = cam.params.image_size
-    w_img, h_img = float(w_img), float(h_img)
+    w_img_, h_img_ = cam.params.image_size
+    w_img, h_img = float(w_img_), float(h_img_)
 
     # ------------------------------------------------------------------
     # Gather data into ndarray / DeviceArray batches so that we can compute
@@ -1038,6 +1056,7 @@ def calculate_camera_affinity_matrix_jax(
     kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
         [trk.keypoints for trk in trackings]
     )  # (T, J, 3)
+    J = kps3d_trk.shape[1]
     ts_trk = jnp.array(
         [trk.last_active_timestamp.timestamp() for trk in trackings], dtype=jnp.float32
     )  # (T,)
@@ -1094,22 +1113,34 @@ def calculate_camera_affinity_matrix_jax(
     ]  # each (J,3)
     backproj: Float[Array, "D J 3"] = jnp.stack(backproj_points_list)  # (D, J, 3)
 
-    # Predicted 3D pose for each tracking (no velocity yet ⇒ same as stored kps)
-    # shape (T, J, 3)
-    predicted_pose: Float[Array, "T J 3"] = kps3d_trk  # velocity handled outside
+    zero_velocity = jnp.zeros((J, 3))
+    trk_velocities = jnp.stack(
+        [
+            trk.velocity if trk.velocity is not None else zero_velocity
+            for trk in trackings
+        ]
+    )
+    predicted_pose: Float[Array, "T D J 3"] = (
+        kps3d_trk[:, None, :, :]  # (T, 1, J, 3)
+        + trk_velocities[:, None, :, :] * delta_t[:, :, None, None]  # (T, D, 1, 1)
+    )
 
     # Camera center shape (3,) -> will broadcast
     cam_center = cam.params.location  # (3,)
 
-    # Compute perpendicular distance using vectorised formula
-    # distance = || (p2-p1) × (p1 - P) || / ||p2 - p1||
-    # p1 == cam_center, p2 == backproj, P == predicted_pose
-    v1 = backproj[None, :, :, :] - cam_center  # (1, D, J, 3)
-    v2 = cam_center - predicted_pose[:, None, :, :]  # (T, 1, J, 3)
+    # Compute perpendicular distance using vectorized formula
+    # distance = || (P - p1) × (p2 - p1) || / ||p2 - p1||
+    # p1 = cam_center, p2 = backproj, P = predicted_pose
+    p1 = cam_center
+    p2 = backproj
+    P = predicted_pose
+    v1 = P - p1
+    v2 = p2[None, :, :, :] - p1  # (1, D, J, 3)
     cross = jnp.cross(v1, v2)  # (T, D, J, 3)
     num = jnp.linalg.norm(cross, axis=-1)  # (T, D, J)
     den = jnp.linalg.norm(v2, axis=-1)  # (1, D, J)
     dist3d: Float[Array, "T D J"] = num / den
 
     affinity_3d = (
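The rewritten comment states the standard point-to-line distance: for the ray from p1 = cam_center through p2 = backproj, the perpendicular distance of a point P is ||(P - p1) × (p2 - p1)|| / ||p2 - p1||, which is exactly what `num / den` computes after broadcasting. A quick numeric sanity check with the shapes collapsed to T = D = J = 1:

    import jax.numpy as jnp

    p1 = jnp.array([0.0, 0.0, 0.0])        # camera center
    p2 = jnp.array([[[1.0, 0.0, 0.0]]])    # (D, J, 3) backprojected point
    P = jnp.array([[[[0.5, 2.0, 0.0]]]])   # (T, D, J, 3) predicted pose

    v1 = P - p1                # (T, D, J, 3)
    v2 = p2[None] - p1         # (1, D, J, 3)
    dist = jnp.linalg.norm(jnp.cross(v1, v2), axis=-1) / jnp.linalg.norm(v2, axis=-1)
    print(dist)                # [[[2.]]], the point sits 2 units off the x-axis ray

Note that the old and new arrangements yield the same magnitude (a cross product's norm is unaffected by operand order or sign); the rewrite aligns the variable names with the textbook form rather than fixing a numerical bug.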
@@ -1136,7 +1167,7 @@ unmatched_detections = shallow_copy(next_group)
 camera_detections = classify_by_camera(unmatched_detections)
 camera_detections_next_batch = camera_detections["AE_08"]
 
-affinity = calculate_camera_affinity_matrix(
+affinity = calculate_camera_affinity_matrix_jax(
     trackings,
     camera_detections_next_batch,
     w_2d=W_2D,