refactor: Enhance type hints and documentation for distance calculation functions in playground.py

- Updated the return type of `perpendicular_distance_point_to_line_two_points` to `Float[Array, ""]` for improved type clarity.
- Added detailed docstrings to `perpendicular_distance_point_to_line_two_points`, specifying arguments and return values for better understanding.
- Streamlined the distance calculation in `perpendicular_distance_camera_2d_points_to_tracking_raycasting` by removing unnecessary type casting, enhancing code readability.
This commit is contained in:
2025-04-27 18:04:58 +08:00
parent 9639bcb794
commit d9aaa96d0a

View File

@ -620,11 +620,18 @@ def calculate_affinity_2d(
@jaxtyped(typechecker=beartype)
def perpendicular_distance_point_to_line_two_points(
point: Num[Array, "3"], line: tuple[Num[Array, "3"], Num[Array, "3"]]
) -> jnp.floating[Any]:
) -> Float[Array, ""]:
"""
Calculate the perpendicular distance between a point and a line.
where `line` is represented by two points: `(line_start, line_end)`
Args:
point: The point to calculate the distance to
line: The line to calculate the distance to, represented by two points
Returns:
The perpendicular distance between the point and the line
(should be a scalar in `float`)
"""
line_start, line_end = line
distance = jnp.linalg.norm(
@ -652,21 +659,16 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
Array of perpendicular distances for each keypoint
"""
camera = detection.camera
# Convert timedelta to seconds for prediction
delta_t_s = delta_t.total_seconds()
# Predict the 3D pose based on tracking and delta_t
predicted_pose = predict_pose_3d(tracking, delta_t_s)
# Back-project the 2D points to 3D space (assuming z=0 plane)
# Back-project the 2D points to 3D space
# intersection with z=0 plane
back_projected_points = detection.camera.unproject_points_to_z_plane(
detection.keypoints, z=0.0
)
# Get camera center from the camera parameters
camera_center = camera.params.location
# Define function to calculate distance between a predicted point and its corresponding ray
def calc_distance(predicted_point, back_projected_point):
return perpendicular_distance_point_to_line_two_points(
predicted_point, (camera_center, back_projected_point)
@ -674,8 +676,8 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
# Vectorize over all keypoints
vmap_calc_distance = jax.vmap(calc_distance)
distances: Float[Array, "J"] = cast(
Any, vmap_calc_distance(predicted_pose, back_projected_points)
distances: Float[Array, "J"] = vmap_calc_distance(
predicted_pose, back_projected_points
)
return distances