From 9639bcb79454befa8ac71d2b1cbe2fe5dc132fd8 Mon Sep 17 00:00:00 2001 From: crosstyan Date: Sun, 27 Apr 2025 18:00:56 +0800 Subject: [PATCH] refactor: Update type hints and streamline distance calculations in playground.py - Changed the return type of `perpendicular_distance_point_to_line_two_points` to `jnp.floating[Any]` for improved type accuracy. - Refactored the `perpendicular_distance_camera_2d_points_to_tracking_raycasting` function to store distances in a typed variable, enhancing clarity and type safety. - Improved the overall readability of the distance calculation logic by returning the computed distances directly. --- playground.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/playground.py b/playground.py index d3a1eb8..d2f09e1 100644 --- a/playground.py +++ b/playground.py @@ -620,7 +620,7 @@ def calculate_affinity_2d( @jaxtyped(typechecker=beartype) def perpendicular_distance_point_to_line_two_points( point: Num[Array, "3"], line: tuple[Num[Array, "3"], Num[Array, "3"]] -) -> Float[Array, ""]: +) -> jnp.floating[Any]: """ Calculate the perpendicular distance between a point and a line. @@ -674,9 +674,11 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting( # Vectorize over all keypoints vmap_calc_distance = jax.vmap(calc_distance) + distances: Float[Array, "J"] = cast( + Any, vmap_calc_distance(predicted_pose, back_projected_points) + ) - # Calculate and return distances for all keypoints - return vmap_calc_distance(predicted_pose, back_projected_points) + return distances @jaxtyped(typechecker=beartype)