forked from HQU-gxy/CVTH3PE
feat: Enhance 3D pose prediction and camera affinity calculations
- Introduced a JAX-friendly implementation of `predict_pose_3d` to predict 3D poses based on tracking velocities, improving performance by avoiding Python control flow in the traced math.
- Updated `calculate_camera_affinity_matrix_jax` to streamline the affinity matrix calculation, incorporating velocity handling for trackings and clarifying the distance computation.
- Added properties for normalized keypoints in the `Detection` class, allowing lazy evaluation and improved usability.
- Enhanced documentation throughout to clarify function purposes and parameters.
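The normalized-keypoint properties on `Detection` are not visible in the hunks below. As an illustration of the lazy-evaluation idea only, here is a minimal sketch; the attribute names (`keypoints`, `camera.params.image_size`) and the property name are assumptions, not the repo's actual API:

```python
# Hypothetical sketch of a lazily evaluated normalized-keypoints property;
# attribute names are assumptions, not the repo's actual API.
from dataclasses import dataclass
from functools import cached_property

import jax.numpy as jnp
from jaxtyping import Array, Float


@dataclass
class Detection:
    keypoints: Float[Array, "J 2"]  # pixel-space (u, v) keypoints
    camera: "Camera"                # owning camera, exposes params.image_size

    @cached_property
    def keypoints_normalized(self) -> Float[Array, "J 2"]:
        # Computed once on first access, then cached on the instance.
        w, h = self.camera.params.image_size
        return self.keypoints / jnp.array([float(w), float(h)])
```

`functools.cached_property` gives the lazy-plus-cached behavior the commit message describes: the division only runs if and when the property is first read.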
@@ -746,16 +746,34 @@ def calculate_affinity_3d(
     return affinity_per_keypoint


 @jaxtyped(typechecker=beartype)
 def predict_pose_3d(
     tracking: Tracking,
     delta_t_s: float,
 ) -> Float[Array, "J 3"]:
     """
     Predict the 3D pose of a tracking based on its velocity.
     JAX-friendly implementation that avoids Python control flow.

     Args:
         tracking: The tracking object containing keypoints and optional velocity
         delta_t_s: Time delta in seconds

     Returns:
         Predicted 3D pose keypoints
     """
+    # ------------------------------------------------------------------
+    # Step 1 – decide velocity on the Python side
+    # ------------------------------------------------------------------
     if tracking.velocity is None:
-        return tracking.keypoints
-    return tracking.keypoints + tracking.velocity * delta_t_s
+        velocity = jnp.zeros_like(tracking.keypoints)  # (J, 3)
+    else:
+        velocity = tracking.velocity  # (J, 3)
+
+    # ------------------------------------------------------------------
+    # Step 2 – pure JAX math
+    # ------------------------------------------------------------------
+    return tracking.keypoints + velocity * delta_t_s


 @beartype
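Because the `velocity is None` branch runs in ordinary Python at trace time, the function stays `jax.jit`-compatible: JAX simply retraces when the pytree structure (velocity present vs. absent) changes. A self-contained sketch of that behavior; `MiniTracking` is a hypothetical stand-in for the repo's `Tracking`:

```python
# Illustrative only: `MiniTracking` is a stand-in for the repo's `Tracking`.
from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp
from jax import Array


class MiniTracking(NamedTuple):
    keypoints: Array                  # (J, 3)
    velocity: Optional[Array] = None  # (J, 3) or None


def predict(tracking: MiniTracking, delta_t_s: float) -> Array:
    # Same logic as `predict_pose_3d` above, restated so the snippet runs alone.
    if tracking.velocity is None:  # resolved in Python at trace time
        velocity = jnp.zeros_like(tracking.keypoints)
    else:
        velocity = tracking.velocity
    return tracking.keypoints + velocity * delta_t_s


kps = jnp.ones((17, 3))
vel = jnp.full((17, 3), 0.5)
jitted = jax.jit(predict, static_argnames="delta_t_s")

print(jitted(MiniTracking(kps), 0.1))       # velocity absent -> keypoints unchanged
print(jitted(MiniTracking(kps, vel), 0.1))  # keypoints + 0.5 * 0.1
```

Marking `delta_t_s` static is optional here; passing it as a traced scalar would also work.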
@@ -1017,17 +1035,17 @@ def calculate_camera_affinity_matrix_jax(
         # Return an empty affinity matrix with appropriate shape.
         return jnp.zeros((len(trackings), len(camera_detections)))  # type: ignore[return-value]

+    cam = next(iter(camera_detections)).camera
     # Ensure every detection truly belongs to the same camera (guard clause)
-    cam_id = camera_detections[0].camera.id
+    cam_id = cam.id
     if any(det.camera.id != cam_id for det in camera_detections):
         raise ValueError(
             "All detections passed to `calculate_camera_affinity_matrix` must come from one camera."
         )

     # We will rely on a single `Camera` instance (all detections share it)
-    cam = camera_detections[0].camera
-    w_img, h_img = cam.params.image_size
-    w_img, h_img = float(w_img), float(h_img)
+    w_img_, h_img_ = cam.params.image_size
+    w_img, h_img = float(w_img_), float(h_img_)

     # ------------------------------------------------------------------
     # Gather data into ndarray / DeviceArray batches so that we can compute
@@ -1038,6 +1056,7 @@ def calculate_camera_affinity_matrix_jax(
     kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
         [trk.keypoints for trk in trackings]
     )  # (T, J, 3)
+    J = kps3d_trk.shape[1]
     ts_trk = jnp.array(
         [trk.last_active_timestamp.timestamp() for trk in trackings], dtype=jnp.float32
     )  # (T,)
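`delta_t` with shape `(T, D)` is consumed in the next hunk, but its construction sits outside the visible context. Presumably it holds the pairwise gap between each tracking's `last_active_timestamp` and each detection's timestamp; a sketch under that assumption (`ts_det` is a made-up name):

```python
import jax.numpy as jnp

# Assumed: T tracking timestamps and D detection timestamps, both in seconds.
ts_trk = jnp.array([10.0, 10.2, 10.4])        # (T,)
ts_det = jnp.array([10.5, 10.5, 10.6, 10.7])  # (D,)

# Broadcast to the pairwise (T, D) grid used by the affinity computation.
delta_t = ts_det[None, :] - ts_trk[:, None]
assert delta_t.shape == (3, 4)
```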
@@ -1094,22 +1113,34 @@ def calculate_camera_affinity_matrix_jax(
     ]  # each (J,3)
     backproj: Float[Array, "D J 3"] = jnp.stack(backproj_points_list)  # (D, J, 3)

-    # Predicted 3D pose for each tracking (no velocity yet ⇒ same as stored kps)
-    # shape (T, J, 3)
-    predicted_pose: Float[Array, "T J 3"] = kps3d_trk  # velocity handled outside
+    zero_velocity = jnp.zeros((J, 3))
+    trk_velocities = jnp.stack(
+        [
+            trk.velocity if trk.velocity is not None else zero_velocity
+            for trk in trackings
+        ]
+    )
+
+    predicted_pose: Float[Array, "T D J 3"] = (
+        kps3d_trk[:, None, :, :]  # (T,1,J,3)
+        + trk_velocities[:, None, :, :] * delta_t[:, :, None, None]  # (T,D,1,1)
+    )

     # Camera center – shape (3,) -> will broadcast
     cam_center = cam.params.location  # (3,)

-    # Compute perpendicular distance using vectorised formula
-    # distance = || (p2-p1) × (p1 - P) || / ||p2 - p1||
-    # p1 == cam_center, p2 == backproj, P == predicted_pose
-
-    v1 = backproj[None, :, :, :] - cam_center  # (1, D, J, 3)
-    v2 = cam_center - predicted_pose[:, None, :, :]  # (T, 1, J, 3)
+    # Compute perpendicular distance using vectorized formula
+    # distance = || (P - p1) × (p2 - p1) || / ||p2 - p1||
+    # p1 = cam_center, p2 = backproj, P = predicted_pose
+    p1 = cam_center
+    p2 = backproj
+    P = predicted_pose
+    v1 = P - p1
+    v2 = p2[None, :, :, :] - p1  # (1, D, J, 3)
+    # jax.debug.print
     cross = jnp.cross(v1, v2)  # (T, D, J, 3)
     num = jnp.linalg.norm(cross, axis=-1)  # (T, D, J)
-    den = jnp.linalg.norm(v1, axis=-1)  # (1, D, J)
+    den = jnp.linalg.norm(v2, axis=-1)  # (1, D, J)
     dist3d: Float[Array, "T D J"] = num / den

     affinity_3d = (
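The corrected comment matches the standard point-to-line distance: for the line through p1 and p2 and a query point P, distance = || (P − p1) × (p2 − p1) || / ||p2 − p1||, which is exactly what `num / den` computes once `den` is taken over `v2`. A quick numeric sanity check on a case with a known answer:

```python
# Numeric check of distance = ||(P - p1) x (p2 - p1)|| / ||p2 - p1||.
# Line along the x-axis through the origin; P sits 5 units off the line.
import jax.numpy as jnp

p1 = jnp.array([0.0, 0.0, 0.0])   # camera-center stand-in
p2 = jnp.array([10.0, 0.0, 0.0])  # back-projected point stand-in
P = jnp.array([2.0, 3.0, 4.0])    # query point, sqrt(3**2 + 4**2) = 5 off the line

v1 = P - p1
v2 = p2 - p1
dist = jnp.linalg.norm(jnp.cross(v1, v2)) / jnp.linalg.norm(v2)
print(dist)  # 5.0
```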
@@ -1136,7 +1167,7 @@ unmatched_detections = shallow_copy(next_group)
 camera_detections = classify_by_camera(unmatched_detections)

 camera_detections_next_batch = camera_detections["AE_08"]
-affinity = calculate_camera_affinity_matrix(
+affinity = calculate_camera_affinity_matrix_jax(
     trackings,
     camera_detections_next_batch,
     w_2d=W_2D,