feat: Enhance playground.py with new 3D tracking and affinity calculations

- Added a function that computes perpendicular distances between predicted 3D tracking points and the camera rays through the detected 2D keypoints.
- Introduced a function that calculates 3D affinity scores from those distances and the time difference, wiring 3D tracking into the existing cross-view association step.
- Updated existing functions with stricter types (jaxtyping array annotations, `timedelta`) and clearer documentation of parameters and return values.
- Refactored the affinity calculation to use JAX for the distance computations.
commit a4cc34f599 (parent 5b5ccbd92b)
Date: 2025-04-27 16:56:49 +08:00
2 changed files with 154 additions and 17 deletions


@@ -103,8 +103,10 @@ def unproject_points_onto_plane(
     (i.e. back-project points onto a plane)
+    `intersect_image_rays_with_plane`/`compute_ray_plane_intersections`
+
     Args:
-        points_2d: [..., 2] image pixel coordinates
+        points_2d: [..., 2] image pixel coordinates (with camera distortion)
         plane_normal: (3,) normal vector of the plane in world coords
         plane_point: (3,) a known point on the plane in world coords
        K: Camera intrinsic matrix
@@ -118,7 +120,7 @@ def unproject_points_onto_plane(
     Returns:
         [..., 3] world-space intersection points
     """
-    # Step 1: undistort (no-op here)
+    # Step 1: undistort
     pts = undistort_points(
         np.asarray(points_2d), np.asarray(K), np.asarray(dist_coeffs)
     )
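For context, here is a minimal sketch of the ray–plane intersection such a function performs after undistortion. The function name, the pinhole model, and the `x_cam = R @ X + t` extrinsics convention are assumptions for illustration, not the file's actual implementation:

```python
import numpy as np

def intersect_rays_with_plane(points_2d, K, R, t, plane_normal, plane_point):
    # Homogeneous pixels -> camera-space ray directions d_cam = K^-1 @ [u, v, 1]
    ones = np.ones((*points_2d.shape[:-1], 1))
    d_cam = np.concatenate([points_2d, ones], axis=-1) @ np.linalg.inv(K).T
    # Rotate rays into world space; the ray origin is the camera center C = -R^T t
    d_world = d_cam @ R            # row-vector form of R^T @ d_cam
    center = -R.T @ t
    # Solve n . (C + s * d - p0) = 0 for the per-ray scale s
    s = ((plane_point - center) @ plane_normal) / (d_world @ plane_normal)
    return center + s[..., None] * d_world   # [..., 3] world-space points
```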
@ -313,6 +315,13 @@ class CameraParams:
object.__setattr__(self, "_proj", pm) object.__setattr__(self, "_proj", pm)
return pm return pm
@property
def location(self) -> Num[Array, "3"]:
"""
The 3D location of camera (relative to world coordinate system)
"""
return self.pose_matrix[:3, -1].reshape((3,))
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
@dataclass(frozen=True) @dataclass(frozen=True)
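The new `location` property reads the translation column of the pose matrix. A small self-contained check of that convention (assuming `pose_matrix` is a 4x4 camera-to-world transform, which the diff does not show explicitly):

```python
import numpy as np

# Hypothetical 4x4 camera-to-world pose: rotation block plus translation column.
pose_matrix = np.eye(4)
pose_matrix[:3, -1] = [1.0, 2.0, 3.0]          # camera sits at (1, 2, 3) in world coords

location = pose_matrix[:3, -1].reshape((3,))   # same slice the property takes
print(location)                                # [1. 2. 3.]
```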
@@ -390,7 +399,7 @@ class Camera:
         Un-project 2D points to 3D points on a plane at z = constant.
         Args:
-            points_2d: 2D points in pixel coordinates
+            points_2d: 2D points in pixel coordinates (with camera distortion)
             z: z-coordinate of the plane (default: 0.0, i.e. ground/horizon/floor plane)
         Returns:


@@ -568,7 +568,7 @@ def calculate_distance_2d(
     left: Num[Array, "J 2"],
     right: Num[Array, "J 2"],
     image_size: tuple[int, int] = (1, 1),
-):
+) -> Float[Array, "J"]:
     """
     Calculate the *normalized* distance between two sets of keypoints.
@@ -576,6 +576,9 @@ def calculate_distance_2d(
         left: The left keypoints
         right: The right keypoints
         image_size: The size of the image
+
+    Returns:
+        Array of normalized Euclidean distances between corresponding keypoints
     """
     w, h = image_size
     if w == 1 and h == 1:
@@ -590,25 +593,41 @@ def calculate_distance_2d(
 @jaxtyped(typechecker=beartype)
 def calculate_affinity_2d(
-    distance_2d: float, w_2d: float, alpha_2d: float, lambda_a: float, delta_t: float
+    distance_2d: Float[Array, "J"],
+    delta_t: timedelta,
+    w_2d: float,
+    alpha_2d: float,
+    lambda_a: float,
 ) -> float:
     """
-    Calculate the affinity between two detections based on the distance between their keypoints.
+    Calculate the affinity between two detections based on the distances between their keypoints.
+
+    The affinity score is calculated by summing individual keypoint affinities:
+
+        A_2D = sum(w_2D * (1 - distance_2D / (alpha_2D * delta_t)) * np.exp(-lambda_a * delta_t)) for each keypoint
+
     Args:
-        distance_2d: The normalized distance between the two keypoints (see `calculate_distance_2d`)
-        w_2d: The weight of the distance (parameter)
-        alpha_2d: The alpha value for the distance calculation (parameter)
-        lambda_a: The lambda value for the distance calculation (parameter)
-        delta_t: The time delta between the two detections, in seconds
+        distance_2d: The normalized distances between keypoints (array with one value per keypoint)
+        delta_t: The time delta between the two detections
+        w_2d: The weight for 2D affinity
+        alpha_2d: The normalization factor for distance
+        lambda_a: The decay rate for time difference
+
+    Returns:
+        Sum of affinity scores across all keypoints
     """
-    return w_2d * (1 - distance_2d / (alpha_2d * delta_t)) * np.exp(-lambda_a * delta_t)
+    delta_t_s = delta_t.total_seconds()
+    affinity_per_keypoint = (
+        w_2d
+        * (1 - distance_2d / (alpha_2d * delta_t_s))
+        * jnp.exp(-lambda_a * delta_t_s)
+    )
+    return jnp.sum(affinity_per_keypoint).item()
 @jaxtyped(typechecker=beartype)
 def perpendicular_distance_point_to_line_two_points(
-    point: Num[Array, "2"], line: tuple[Num[Array, "2"], Num[Array, "2"]]
-):
+    point: Num[Array, "3"], line: tuple[Num[Array, "3"], Num[Array, "3"]]
+) -> Float[Array, ""]:
     """
     Calculate the perpendicular distance between a point and a line.
@@ -621,20 +640,106 @@ def perpendicular_distance_point_to_line_two_points(
     return distance
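The body is not shown in this hunk, but the standard 3D point-to-line distance it presumably computes is |(p - a) × (b - a)| / |b - a|; a minimal sketch:

```python
import jax.numpy as jnp

def perpendicular_distance_sketch(point, line):
    a, b = line
    direction = b - a
    # Cross-product magnitude = area of the parallelogram spanned by (point - a)
    # and the line direction; dividing by the base length gives the height.
    return jnp.linalg.norm(jnp.cross(point - a, direction)) / jnp.linalg.norm(direction)
```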
+
+@jaxtyped(typechecker=beartype)
+def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
+    detection: Detection,
+    tracking: Tracking,
+    delta_t: timedelta,
+) -> Float[Array, "J"]:
+    """
+    Calculate the perpendicular distances between predicted 3D tracking points
+    and the rays cast from the camera center through the 2D image points.
+
+    Args:
+        detection: The detection object containing 2D keypoints and camera parameters
+        tracking: The tracking object containing 3D keypoints
+        delta_t: Time delta between the tracking's last update and the current observation
+
+    Returns:
+        Array of perpendicular distances for each keypoint
+    """
+    camera = detection.camera
+    # Convert timedelta to seconds for prediction
+    delta_t_s = delta_t.total_seconds()
+    # Predict the 3D pose based on tracking and delta_t
+    predicted_pose = predict_pose_3d(tracking, delta_t_s)
+    # Back-project the 2D points to 3D space (assuming z=0 plane)
+    back_projected_points = camera.unproject_points_to_z_plane(
+        detection.keypoints, z=0.0
+    )
+    # Get the camera center from the camera parameters
+    camera_center = camera.params.location
+
+    # Distance between a predicted point and its corresponding ray
+    def calc_distance(predicted_point, back_projected_point):
+        return perpendicular_distance_point_to_line_two_points(
+            predicted_point, (camera_center, back_projected_point)
+        )
+
+    # Vectorize over all keypoints
+    vmap_calc_distance = jax.vmap(calc_distance)
+    # Calculate and return distances for all keypoints
+    return vmap_calc_distance(predicted_pose, back_projected_points)
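One subtlety worth noting: the z=0 back-projection is only used to obtain a second point on each viewing ray besides the camera center. Assuming `perpendicular_distance_point_to_line_two_points` measures distance to the infinite line through its two points (as the cross-product formula above does), any plane the rays actually intersect would serve equally well here.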
+
+@jaxtyped(typechecker=beartype)
+def calculate_affinity_3d(
+    distances: Float[Array, "J"],
+    delta_t: timedelta,
+    w_3d: float,
+    alpha_3d: float,
+    lambda_a: float,
+) -> float:
+    """
+    Calculate the 3D affinity score between a tracking and a detection.
+
+    The affinity score is calculated by summing individual keypoint affinities:
+
+        A_3D = sum(w_3D * (1 - dl / alpha_3D) * np.exp(-lambda_a * delta_t)) for each keypoint
+
+    Args:
+        distances: Array of perpendicular distances for each keypoint
+        delta_t: Time difference between tracking and detection
+        w_3d: Weight for 3D affinity
+        alpha_3d: Normalization factor for distance
+        lambda_a: Decay rate for time difference
+
+    Returns:
+        Sum of affinity scores across all keypoints
+    """
+    delta_t_s = delta_t.total_seconds()
+    affinity_per_keypoint = (
+        w_3d * (1 - distances / alpha_3d) * jnp.exp(-lambda_a * delta_t_s)
+    )
+    # Sum affinities across all keypoints
+    return jnp.sum(affinity_per_keypoint).item()
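Note the asymmetry with `calculate_affinity_2d`: the 2D variant normalizes distance by `alpha_2d * delta_t`, while the 3D variant divides by `alpha_3d` alone, so only the exponential term depends on the time gap here.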
 def predict_pose_3d(
     tracking: Tracking,
-    delta_t: float,
+    delta_t_s: float,
 ) -> Float[Array, "J 3"]:
     """
     Predict the 3D pose of a tracking based on its velocity.
     """
     if tracking.velocity is None:
         return tracking.keypoints
-    return tracking.keypoints + tracking.velocity * delta_t
+    return tracking.keypoints + tracking.velocity * delta_t_s
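The constant-velocity model in action, with made-up numbers:

```python
import jax.numpy as jnp

# One keypoint at the origin moving 1 m/s along x, predicted 0.5 s ahead
# (shapes [J 3], here J = 1).
keypoints = jnp.zeros((1, 3))
velocity = jnp.array([[1.0, 0.0, 0.0]])
predicted = keypoints + velocity * 0.5   # -> [[0.5, 0.0, 0.0]]
```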
 # %%
 # let's do cross-view association
+W_2D = 1.0
+ALPHA_2D = 1.0
+W_3D = 1.0
+ALPHA_3D = 1.0
+LAMBDA_A = 0.1
+
 trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
 unmatched_detections = shallow_copy(next_group)
 # cross-view association matrix with shape (T, D), where T is the number of
@@ -647,12 +752,35 @@ unmatched_detections = shallow_copy(next_group)
 #
 # where T <- [t1..tt]; D <- join(c1..cc), where `cn` is a collection of
 # detections from camera `n`
-affinity = np.zeros((len(trackings), len(unmatched_detections)))
+affinity = jnp.zeros((len(trackings), len(unmatched_detections)))
 detection_by_camera = classify_by_camera(unmatched_detections)
 for i, tracking in enumerate(trackings):
+    j = 0
     for c, detections in detection_by_camera.items():
         camera = next(iter(detections)).camera
         # pixel space, unnormalized
         tracking_2d_projection = camera.project(tracking.keypoints)
         for det in detections:
-            ...
+            delta_t = det.timestamp - tracking.last_active_timestamp
+            distance_2d = calculate_distance_2d(tracking_2d_projection, det.keypoints)
+            affinity_2d = calculate_affinity_2d(
+                distance_2d,
+                delta_t,
+                w_2d=W_2D,
+                alpha_2d=ALPHA_2D,
+                lambda_a=LAMBDA_A,
+            )
+            distances = perpendicular_distance_camera_2d_points_to_tracking_raycasting(
+                det, tracking, delta_t
+            )
+            affinity_3d = calculate_affinity_3d(
+                distances,
+                delta_t,
+                w_3d=W_3D,
+                alpha_3d=ALPHA_3D,
+                lambda_a=LAMBDA_A,
+            )
+            affinity_sum = affinity_2d + affinity_3d
+            affinity = affinity.at[i, j].set(affinity_sum)
+            j += 1
+display(affinity)
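The diff stops at building the affinity matrix; the usual next step (not shown in this commit) is to solve the assignment, e.g. with the Hungarian algorithm, and gate weak associations. A sketch, where `AFFINITY_THRESHOLD` is a hypothetical tuning parameter and `flat_detections` mirrors the per-camera iteration order that `j` indexes:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

AFFINITY_THRESHOLD = 0.0   # hypothetical gate; tune for the tracker

# Flatten detections in the same order the j counter walked them
flat_detections = [d for dets in detection_by_camera.values() for d in dets]

cost = -np.asarray(affinity)                  # maximize affinity = minimize cost
rows, cols = linear_sum_assignment(cost)
matches = [
    (trackings[i], flat_detections[j])
    for i, j in zip(rows, cols)
    if affinity[i, j] > AFFINITY_THRESHOLD    # reject weak associations
]
```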