feat: Enhance playground.py with new 3D tracking and affinity calculations

- Added functions for calculating perpendicular distances between predicted 3D tracking points and the camera rays cast through detected 2D keypoints, improving 3D tracking accuracy.
- Introduced a new function for calculating 3D affinity scores from these distances and time differences, so 3D evidence can be combined with the existing 2D affinity during cross-view association.
- Updated existing functions to accept the new data types (per-keypoint distance arrays, `timedelta` time deltas) and improved the documentation of parameters and return values.
- Refactored the affinity and distance calculations to use JAX (`jnp`, `jax.vmap`) for vectorized distance computations. A standalone sketch of the core geometry follows.
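
For illustration only (not code from this commit): a minimal, self-contained sketch of the point-to-ray perpendicular distance that the new raycasting function vectorizes with `jax.vmap`. The names (`point_to_ray_distance` and the toy arrays) are hypothetical.

import jax
import jax.numpy as jnp

def point_to_ray_distance(point, origin, through):
    # Direction of the ray from the camera center through the back-projected point
    d = through - origin
    v = point - origin
    # Remove the component of v along the ray; what remains is perpendicular to it
    perp = v - (jnp.dot(v, d) / jnp.dot(d, d)) * d
    return jnp.linalg.norm(perp)

# vmap over J keypoints, mirroring the diff's use of jax.vmap
points = jnp.array([[0.0, 0.0, 1.0], [1.0, 1.0, 1.0]])   # predicted 3D keypoints
origin = jnp.zeros(3)                                     # camera center
through = jnp.array([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])  # back-projected points
print(jax.vmap(point_to_ray_distance, in_axes=(0, None, 0))(points, origin, through))
# -> [1.0, 1.4142135]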
2025-04-27 16:56:49 +08:00
parent 5b5ccbd92b
commit a4cc34f599
2 changed files with 154 additions and 17 deletions


@@ -568,7 +568,7 @@ def calculate_distance_2d(
     left: Num[Array, "J 2"],
     right: Num[Array, "J 2"],
     image_size: tuple[int, int] = (1, 1),
-):
+) -> Float[Array, "J"]:
     """
     Calculate the *normalized* distance between two sets of keypoints.
@@ -576,6 +576,9 @@ def calculate_distance_2d(
         left: The left keypoints
         right: The right keypoints
         image_size: The size of the image
+
+    Returns:
+        Array of normalized Euclidean distances between corresponding keypoints
     """
     w, h = image_size
     if w == 1 and h == 1:
@@ -590,25 +593,41 @@
 @jaxtyped(typechecker=beartype)
 def calculate_affinity_2d(
-    distance_2d: float, w_2d: float, alpha_2d: float, lambda_a: float, delta_t: float
+    distance_2d: Float[Array, "J"],
+    delta_t: timedelta,
+    w_2d: float,
+    alpha_2d: float,
+    lambda_a: float,
 ) -> float:
     """
-    Calculate the affinity between two detections based on the distance between their keypoints.
+    Calculate the affinity between two detections based on the distances between their keypoints.
+
+    The affinity score is calculated by summing individual keypoint affinities:
+        A_2D = sum(w_2d * (1 - distance_2d / (alpha_2d * delta_t)) * exp(-lambda_a * delta_t)) over keypoints
+
     Args:
-        distance_2d: The normalized distance between the two keypoints (see `calculate_distance_2d`)
-        w_2d: The weight of the distance (parameter)
-        alpha_2d: The alpha value for the distance calculation (parameter)
-        lambda_a: The lambda value for the distance calculation (parameter)
+        distance_2d: The normalized distances between keypoints, one value per
+            keypoint (see `calculate_distance_2d`)
+        delta_t: The time delta between the two detections
+        w_2d: The weight for 2D affinity
+        alpha_2d: The normalization factor for distance
+        lambda_a: The decay rate for time difference
+
+    Returns:
+        Sum of affinity scores across all keypoints
     """
-    return w_2d * (1 - distance_2d / (alpha_2d * delta_t)) * np.exp(-lambda_a * delta_t)
+    delta_t_s = delta_t.total_seconds()
+    affinity_per_keypoint = (
+        w_2d
+        * (1 - distance_2d / (alpha_2d * delta_t_s))
+        * jnp.exp(-lambda_a * delta_t_s)
+    )
+    return jnp.sum(affinity_per_keypoint).item()
+
+
 @jaxtyped(typechecker=beartype)
 def perpendicular_distance_point_to_line_two_points(
-    point: Num[Array, "2"], line: tuple[Num[Array, "2"], Num[Array, "2"]]
-):
+    point: Num[Array, "3"], line: tuple[Num[Array, "3"], Num[Array, "3"]]
+) -> Float[Array, ""]:
     """
     Calculate the perpendicular distance between a point and a line.
@@ -621,20 +640,106 @@ def perpendicular_distance_point_to_line_two_points(
     return distance
+
+
+@jaxtyped(typechecker=beartype)
+def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
+    detection: Detection,
+    tracking: Tracking,
+    delta_t: timedelta,
+) -> Float[Array, "J"]:
+    """
+    Calculate the perpendicular distances between predicted 3D tracking points
+    and the rays cast from the camera center through the 2D image points.
+
+    Args:
+        detection: The detection object containing 2D keypoints and camera parameters
+        tracking: The tracking object containing 3D keypoints
+        delta_t: Time delta between the tracking's last update and the current observation
+
+    Returns:
+        Array of perpendicular distances for each keypoint
+    """
+    camera = detection.camera
+    # Convert timedelta to seconds for prediction
+    delta_t_s = delta_t.total_seconds()
+    # Predict the 3D pose based on tracking and delta_t
+    predicted_pose = predict_pose_3d(tracking, delta_t_s)
+    # Back-project the 2D points to 3D space (assuming the z=0 plane)
+    back_projected_points = camera.unproject_points_to_z_plane(
+        detection.keypoints, z=0.0
+    )
+    # Get the camera center from the camera parameters
+    camera_center = camera.params.location
+
+    # Distance between a predicted point and its corresponding camera ray
+    def calc_distance(predicted_point, back_projected_point):
+        return perpendicular_distance_point_to_line_two_points(
+            predicted_point, (camera_center, back_projected_point)
+        )
+
+    # Vectorize over all keypoints
+    vmap_calc_distance = jax.vmap(calc_distance)
+    # Calculate and return distances for all keypoints
+    return vmap_calc_distance(predicted_pose, back_projected_points)
+
+
+@jaxtyped(typechecker=beartype)
+def calculate_affinity_3d(
+    distances: Float[Array, "J"],
+    delta_t: timedelta,
+    w_3d: float,
+    alpha_3d: float,
+    lambda_a: float,
+) -> float:
+    """
+    Calculate the 3D affinity score between a tracking and a detection.
+
+    The affinity score is calculated by summing individual keypoint affinities:
+        A_3D = sum(w_3d * (1 - d / alpha_3d) * exp(-lambda_a * delta_t)) over keypoints
+
+    Args:
+        distances: Array of perpendicular distances for each keypoint
+        delta_t: Time difference between tracking and detection
+        w_3d: Weight for 3D affinity
+        alpha_3d: Normalization factor for distance
+        lambda_a: Decay rate for time difference
+
+    Returns:
+        Sum of affinity scores across all keypoints
+    """
+    delta_t_s = delta_t.total_seconds()
+    affinity_per_keypoint = (
+        w_3d * (1 - distances / alpha_3d) * jnp.exp(-lambda_a * delta_t_s)
+    )
+    # Sum affinities across all keypoints
+    return jnp.sum(affinity_per_keypoint).item()
+
+
 def predict_pose_3d(
     tracking: Tracking,
-    delta_t: float,
+    delta_t_s: float,
 ) -> Float[Array, "J 3"]:
     """
     Predict the 3D pose of a tracking based on its velocity.
     """
     if tracking.velocity is None:
         return tracking.keypoints
-    return tracking.keypoints + tracking.velocity * delta_t
+    return tracking.keypoints + tracking.velocity * delta_t_s
 # %%
 # let's do cross-view association
 W_2D = 1.0
 ALPHA_2D = 1.0
 LAMBDA_A = 0.1
+W_3D = 1.0
+ALPHA_3D = 1.0
 
 trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
 unmatched_detections = shallow_copy(next_group)
 
 # cross-view association matrix with shape (T, D), where T is the number of
@@ -647,12 +752,35 @@ unmatched_detections = shallow_copy(next_group)
 #
 # where T <- [t1..tt]; D <- join(c1..cc), where `cn` is a collection of
 # detections from camera `n`
-affinity = np.zeros((len(trackings), len(unmatched_detections)))
+affinity = jnp.zeros((len(trackings), len(unmatched_detections)))
 detection_by_camera = classify_by_camera(unmatched_detections)
 for i, tracking in enumerate(trackings):
     j = 0
     for c, detections in detection_by_camera.items():
         camera = next(iter(detections)).camera
         # pixel space, unnormalized
         tracking_2d_projection = camera.project(tracking.keypoints)
         for det in detections:
-            ...
+            delta_t = det.timestamp - tracking.last_active_timestamp
+            distance_2d = calculate_distance_2d(tracking_2d_projection, det.keypoints)
+            affinity_2d = calculate_affinity_2d(
+                distance_2d,
+                delta_t,
+                w_2d=W_2D,
+                alpha_2d=ALPHA_2D,
+                lambda_a=LAMBDA_A,
+            )
+            distances = perpendicular_distance_camera_2d_points_to_tracking_raycasting(
+                det, tracking, delta_t
+            )
+            affinity_3d = calculate_affinity_3d(
+                distances,
+                delta_t,
+                w_3d=W_3D,
+                alpha_3d=ALPHA_3D,
+                lambda_a=LAMBDA_A,
+            )
+            affinity_sum = affinity_2d + affinity_3d
+            # JAX arrays are immutable; update via .at[].set()
+            affinity = affinity.at[i, j].set(affinity_sum)
+            j += 1
 display(affinity)
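
A plausible next step, not part of this commit, is turning the (T, D) affinity matrix into one-to-one tracking-detection matches. A minimal sketch, assuming Hungarian assignment via `scipy.optimize.linear_sum_assignment` and a hypothetical `MIN_AFFINITY` gate:

import jax.numpy as jnp
from scipy.optimize import linear_sum_assignment

MIN_AFFINITY = 0.0  # hypothetical threshold; tune for real data

# toy (T=2, D=3) affinity matrix standing in for the one built above
affinity = jnp.array([
    [0.9, 0.1, 0.3],
    [0.2, 0.8, -0.5],
])

# the Hungarian algorithm minimizes cost, so negate to maximize affinity
rows, cols = linear_sum_assignment(-affinity)
matches = [(int(t), int(d)) for t, d in zip(rows, cols) if affinity[t, d] > MIN_AFFINITY]
print(matches)  # [(0, 0), (1, 1)]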