refactor: Enhance type hints and documentation for distance calculation functions in playground.py
- Updated the return type of `perpendicular_distance_point_to_line_two_points` to `Float[Array, ""]` for improved type clarity. - Added detailed docstrings to `perpendicular_distance_point_to_line_two_points`, specifying arguments and return values for better understanding. - Streamlined the distance calculation in `perpendicular_distance_camera_2d_points_to_tracking_raycasting` by removing unnecessary type casting, enhancing code readability.
This commit is contained in:
@ -620,11 +620,18 @@ def calculate_affinity_2d(
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def perpendicular_distance_point_to_line_two_points(
|
||||
point: Num[Array, "3"], line: tuple[Num[Array, "3"], Num[Array, "3"]]
|
||||
) -> jnp.floating[Any]:
|
||||
) -> Float[Array, ""]:
|
||||
"""
|
||||
Calculate the perpendicular distance between a point and a line.
|
||||
|
||||
where `line` is represented by two points: `(line_start, line_end)`
|
||||
|
||||
Args:
|
||||
point: The point to calculate the distance to
|
||||
line: The line to calculate the distance to, represented by two points
|
||||
|
||||
Returns:
|
||||
The perpendicular distance between the point and the line
|
||||
(should be a scalar in `float`)
|
||||
"""
|
||||
line_start, line_end = line
|
||||
distance = jnp.linalg.norm(
|
||||
@ -652,21 +659,16 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
|
||||
Array of perpendicular distances for each keypoint
|
||||
"""
|
||||
camera = detection.camera
|
||||
# Convert timedelta to seconds for prediction
|
||||
delta_t_s = delta_t.total_seconds()
|
||||
|
||||
# Predict the 3D pose based on tracking and delta_t
|
||||
predicted_pose = predict_pose_3d(tracking, delta_t_s)
|
||||
|
||||
# Back-project the 2D points to 3D space (assuming z=0 plane)
|
||||
# Back-project the 2D points to 3D space
|
||||
# intersection with z=0 plane
|
||||
back_projected_points = detection.camera.unproject_points_to_z_plane(
|
||||
detection.keypoints, z=0.0
|
||||
)
|
||||
|
||||
# Get camera center from the camera parameters
|
||||
camera_center = camera.params.location
|
||||
|
||||
# Define function to calculate distance between a predicted point and its corresponding ray
|
||||
def calc_distance(predicted_point, back_projected_point):
|
||||
return perpendicular_distance_point_to_line_two_points(
|
||||
predicted_point, (camera_center, back_projected_point)
|
||||
@ -674,8 +676,8 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
|
||||
|
||||
# Vectorize over all keypoints
|
||||
vmap_calc_distance = jax.vmap(calc_distance)
|
||||
distances: Float[Array, "J"] = cast(
|
||||
Any, vmap_calc_distance(predicted_pose, back_projected_points)
|
||||
distances: Float[Array, "J"] = vmap_calc_distance(
|
||||
predicted_pose, back_projected_points
|
||||
)
|
||||
|
||||
return distances
|
||||
|
||||
Reference in New Issue
Block a user