feat: Enhance playground.py with new 3D tracking and affinity calculations
- Added functions for calculating perpendicular distances between predicted 3D tracking points and camera rays, improving 3D tracking accuracy.
- Introduced a new function for calculating 3D affinity scores based on distances and time differences, enhancing the integration of 3D tracking with existing systems.
- Updated existing functions to support new data types and improved documentation for clarity on parameters and return values.
- Refactored affinity calculation logic to utilize JAX for performance optimization in distance computations.
This commit is contained in:
@ -103,8 +103,10 @@ def unproject_points_onto_plane(
|
|||||||
|
|
||||||
(i.e. back-project points onto a plane)
|
(i.e. back-project points onto a plane)
|
||||||
|
|
||||||
|
`intersect_image_rays_with_plane`/`compute_ray_plane_intersections`
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
points_2d: [..., 2] image pixel coordinates
|
points_2d: [..., 2] image pixel coordinates (with camera distortion)
|
||||||
plane_normal: (3,) normal vector of the plane in world coords
|
plane_normal: (3,) normal vector of the plane in world coords
|
||||||
plane_point: (3,) a known point on the plane in world coords
|
plane_point: (3,) a known point on the plane in world coords
|
||||||
K: Camera intrinsic matrix
|
K: Camera intrinsic matrix
|
||||||
@ -118,7 +120,7 @@ def unproject_points_onto_plane(
|
|||||||
Returns:
|
Returns:
|
||||||
[..., 3] world-space intersection points
|
[..., 3] world-space intersection points
|
||||||
"""
|
"""
|
||||||
# Step 1: undistort (no-op here)
|
# Step 1: undistort
|
||||||
pts = undistort_points(
|
pts = undistort_points(
|
||||||
np.asarray(points_2d), np.asarray(K), np.asarray(dist_coeffs)
|
np.asarray(points_2d), np.asarray(K), np.asarray(dist_coeffs)
|
||||||
)
|
)
|
||||||
@ -313,6 +315,13 @@ class CameraParams:
|
|||||||
object.__setattr__(self, "_proj", pm)
|
object.__setattr__(self, "_proj", pm)
|
||||||
return pm
|
return pm
|
||||||
|
|
||||||
|
@property
|
||||||
|
def location(self) -> Num[Array, "3"]:
|
||||||
|
"""
|
||||||
|
The 3D location of camera (relative to world coordinate system)
|
||||||
|
"""
|
||||||
|
return self.pose_matrix[:3, -1].reshape((3,))
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@ -390,7 +399,7 @@ class Camera:
|
|||||||
Un-project 2D points to 3D points on a plane at z = constant.
|
Un-project 2D points to 3D points on a plane at z = constant.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
points_2d: 2D points in pixel coordinates
|
points_2d: 2D points in pixel coordinates (with camera distortion)
|
||||||
z: z-coordinate of the plane (default: 0.0, i.e. ground/horizon/floor plane)
|
z: z-coordinate of the plane (default: 0.0, i.e. ground/horizon/floor plane)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|||||||
156
playground.py
156
playground.py
@ -568,7 +568,7 @@ def calculate_distance_2d(
|
|||||||
left: Num[Array, "J 2"],
|
left: Num[Array, "J 2"],
|
||||||
right: Num[Array, "J 2"],
|
right: Num[Array, "J 2"],
|
||||||
image_size: tuple[int, int] = (1, 1),
|
image_size: tuple[int, int] = (1, 1),
|
||||||
):
|
) -> Float[Array, "J"]:
|
||||||
"""
|
"""
|
||||||
Calculate the *normalized* distance between two sets of keypoints.
|
Calculate the *normalized* distance between two sets of keypoints.
|
||||||
|
|
||||||
@ -576,6 +576,9 @@ def calculate_distance_2d(
|
|||||||
left: The left keypoints
|
left: The left keypoints
|
||||||
right: The right keypoints
|
right: The right keypoints
|
||||||
image_size: The size of the image
|
image_size: The size of the image
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Array of normalized Euclidean distances between corresponding keypoints
|
||||||
"""
|
"""
|
||||||
w, h = image_size
|
w, h = image_size
|
||||||
if w == 1 and h == 1:
|
if w == 1 and h == 1:
|
||||||
@ -590,25 +593,41 @@ def calculate_distance_2d(
|
|||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
def calculate_affinity_2d(
    distance_2d: Float[Array, "J"],
    delta_t: timedelta,
    w_2d: float,
    alpha_2d: float,
    lambda_a: float,
) -> float:
    """
    Calculate the affinity between two detections based on the distances
    between their keypoints.

    The affinity score is calculated by summing individual keypoint affinities:
    A_2D = sum(w_2D * (1 - distance_2D / (alpha_2D * delta_t)) * exp(-lambda_a * delta_t))
    over all keypoints.

    Args:
        distance_2d: The normalized distances between keypoints, one value per
            keypoint (see `calculate_distance_2d`)
        delta_t: The time delta between the two detections (converted to
            seconds internally; must be non-zero to avoid division by zero)
        w_2d: The weight for 2D affinity
        alpha_2d: The normalization factor for distance
        lambda_a: The decay rate for time difference

    Returns:
        Sum of affinity scores across all keypoints, as a Python float
    """
    delta_t_s = delta_t.total_seconds()
    # Per-keypoint affinity: distance term shrinks the score, the exponential
    # decays it with elapsed time.
    affinity_per_keypoint = (
        w_2d
        * (1 - distance_2d / (alpha_2d * delta_t_s))
        * jnp.exp(-lambda_a * delta_t_s)
    )
    # .item() converts the 0-d jnp array to a plain Python float.
    return jnp.sum(affinity_per_keypoint).item()
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
def perpendicular_distance_point_to_line_two_points(
|
def perpendicular_distance_point_to_line_two_points(
|
||||||
point: Num[Array, "2"], line: tuple[Num[Array, "2"], Num[Array, "2"]]
|
point: Num[Array, "3"], line: tuple[Num[Array, "3"], Num[Array, "3"]]
|
||||||
):
|
) -> Float[Array, ""]:
|
||||||
"""
|
"""
|
||||||
Calculate the perpendicular distance between a point and a line.
|
Calculate the perpendicular distance between a point and a line.
|
||||||
|
|
||||||
@ -621,20 +640,106 @@ def perpendicular_distance_point_to_line_two_points(
|
|||||||
return distance
|
return distance
|
||||||
|
|
||||||
|
|
||||||
|
@jaxtyped(typechecker=beartype)
def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
    detection: Detection,
    tracking: Tracking,
    delta_t: timedelta,
) -> Float[Array, "J"]:
    """
    Calculate the perpendicular distances between predicted 3D tracking points
    and the rays cast from the camera center through the 2D image points.

    Args:
        detection: The detection object containing 2D keypoints and camera parameters
        tracking: The tracking object containing 3D keypoints
        delta_t: Time delta between the tracking's last update and current observation

    Returns:
        Array of perpendicular distances, one per keypoint
    """
    camera = detection.camera

    # Convert timedelta to seconds for the constant-velocity prediction.
    delta_t_s = delta_t.total_seconds()

    # Predict where the tracked 3D keypoints should be after delta_t.
    predicted_pose = predict_pose_3d(tracking, delta_t_s)

    # Back-project the detection's 2D points onto the z=0 plane; each
    # back-projected point, together with the camera center, defines the
    # camera ray through that image point.
    # (Consistently use the local `camera` binding instead of re-reading
    # `detection.camera`.)
    back_projected_points = camera.unproject_points_to_z_plane(
        detection.keypoints, z=0.0
    )

    # Camera center in world coordinates — the common origin of all rays.
    camera_center = camera.params.location

    # Distance from one predicted 3D point to its corresponding camera ray.
    def calc_distance(predicted_point, back_projected_point):
        return perpendicular_distance_point_to_line_two_points(
            predicted_point, (camera_center, back_projected_point)
        )

    # Vectorize over all J keypoints at once.
    vmap_calc_distance = jax.vmap(calc_distance)

    return vmap_calc_distance(predicted_pose, back_projected_points)
|
|
||||||
|
|
||||||
|
@jaxtyped(typechecker=beartype)
def calculate_affinity_3d(
    distances: Float[Array, "J"],
    delta_t: timedelta,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> float:
    """
    Calculate 3D affinity score between a tracking and detection.

    Each keypoint contributes
    w_3D * (1 - dl / alpha_3D) * exp(-lambda_a * delta_t)
    and the final score is the sum over all keypoints.

    Args:
        distances: Array of perpendicular distances, one per keypoint
        delta_t: Time difference between tracking and detection
        w_3d: Weight for 3D affinity
        alpha_3d: Normalization factor for distance
        lambda_a: Decay rate for time difference

    Returns:
        Sum of affinity scores across all keypoints, as a Python float
    """
    seconds = delta_t.total_seconds()
    # Shared time-decay factor for every keypoint.
    decay = jnp.exp(-lambda_a * seconds)
    per_keypoint = w_3d * (1 - distances / alpha_3d) * decay
    # Reduce to a plain Python float.
    return jnp.sum(per_keypoint).item()
|
|
||||||
|
|
||||||
def predict_pose_3d(
    tracking: Tracking,
    delta_t_s: float,
) -> Float[Array, "J 3"]:
    """
    Predict the 3D pose of a tracking based on its velocity.

    Applies a constant-velocity motion model: the keypoints are advanced by
    velocity * delta_t_s. When the tracking carries no velocity estimate,
    the current keypoints are returned unchanged.
    """
    velocity = tracking.velocity
    if velocity is None:
        # No motion estimate yet — best prediction is the last known pose.
        return tracking.keypoints
    return tracking.keypoints + velocity * delta_t_s
|
|
||||||
|
|
||||||
# %%
# let's do cross-view association

# Affinity hyper-parameters (see calculate_affinity_2d / calculate_affinity_3d).
W_2D = 1.0      # weight of the 2D affinity term
ALPHA_2D = 1.0  # distance normalization factor for the 2D term
W_3D = 1.0      # weight of the 3D affinity term
ALPHA_3D = 1.0  # distance normalization factor for the 3D term
# Time-decay rate shared by both affinity terms. NOTE: this was previously
# assigned twice with the same value; the duplicate assignment is removed.
LAMBDA_A = 0.1

# Stable ordering of trackings so matrix rows are deterministic.
trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
unmatched_detections = shallow_copy(next_group)
# cross-view association matrix with shape (T, D), where T is the number of
|
||||||
@ -647,12 +752,35 @@ unmatched_detections = shallow_copy(next_group)
|
|||||||
#
# where T <- [t1..tt]; D <- join(c1..cc), where `cn` is a collection of
# detections from camera `n`
affinity = jnp.zeros((len(trackings), len(unmatched_detections)))
detection_by_camera = classify_by_camera(unmatched_detections)
for i, tracking in enumerate(trackings):
    # j is the flat column index across all cameras' detections; it advances
    # once per detection so columns line up with join(c1..cc) above.
    j = 0
    for c, detections in detection_by_camera.items():
        # All detections in this group share one camera; grab it from the first.
        camera = next(iter(detections)).camera
        # pixel space, unnormalized
        tracking_2d_projection = camera.project(tracking.keypoints)
        for det in detections:
            # Time elapsed since this tracking was last updated.
            delta_t = det.timestamp - tracking.last_active_timestamp
            # 2D term: projected tracking keypoints vs detected keypoints.
            distance_2d = calculate_distance_2d(tracking_2d_projection, det.keypoints)
            affinity_2d = calculate_affinity_2d(
                distance_2d,
                delta_t,
                w_2d=W_2D,
                alpha_2d=ALPHA_2D,
                lambda_a=LAMBDA_A,
            )
            # 3D term: predicted 3D keypoints vs rays through the 2D points.
            distances = perpendicular_distance_camera_2d_points_to_tracking_raycasting(
                det, tracking, delta_t
            )
            affinity_3d = calculate_affinity_3d(
                distances,
                delta_t,
                w_3d=W_3D,
                alpha_3d=ALPHA_3D,
                lambda_a=LAMBDA_A,
            )
            affinity_sum = affinity_2d + affinity_3d
            # jnp arrays are immutable — functional update via .at[].set().
            affinity = affinity.at[i, j].set(affinity_sum)
            j += 1
display(affinity)
|
|||||||
Reference in New Issue
Block a user