diff --git a/playground.py b/playground.py index d3a1eb8..d2f09e1 100644 --- a/playground.py +++ b/playground.py @@ -620,7 +620,7 @@ def calculate_affinity_2d( @jaxtyped(typechecker=beartype) def perpendicular_distance_point_to_line_two_points( point: Num[Array, "3"], line: tuple[Num[Array, "3"], Num[Array, "3"]] -) -> Float[Array, ""]: +) -> jnp.floating[Any]: """ Calculate the perpendicular distance between a point and a line. @@ -674,9 +674,11 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting( # Vectorize over all keypoints vmap_calc_distance = jax.vmap(calc_distance) + distances: Float[Array, "J"] = cast( + Any, vmap_calc_distance(predicted_pose, back_projected_points) + ) - # Calculate and return distances for all keypoints - return vmap_calc_distance(predicted_pose, back_projected_points) + return distances @jaxtyped(typechecker=beartype)