diff --git a/playground.py b/playground.py index d2f09e1..e1a3519 100644 --- a/playground.py +++ b/playground.py @@ -620,11 +620,18 @@ def calculate_affinity_2d( @jaxtyped(typechecker=beartype) def perpendicular_distance_point_to_line_two_points( point: Num[Array, "3"], line: tuple[Num[Array, "3"], Num[Array, "3"]] -) -> jnp.floating[Any]: +) -> Float[Array, ""]: """ Calculate the perpendicular distance between a point and a line. - where `line` is represented by two points: `(line_start, line_end)` + + Args: + point: The point to calculate the distance to + line: The line to calculate the distance to, represented by two points + + Returns: + The perpendicular distance between the point and the line + (should be a scalar in `float`) """ line_start, line_end = line distance = jnp.linalg.norm( @@ -652,21 +659,16 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting( Array of perpendicular distances for each keypoint """ camera = detection.camera - # Convert timedelta to seconds for prediction delta_t_s = delta_t.total_seconds() - - # Predict the 3D pose based on tracking and delta_t predicted_pose = predict_pose_3d(tracking, delta_t_s) - # Back-project the 2D points to 3D space (assuming z=0 plane) + # Back-project the 2D points to 3D space + # intersection with z=0 plane back_projected_points = detection.camera.unproject_points_to_z_plane( detection.keypoints, z=0.0 ) - - # Get camera center from the camera parameters camera_center = camera.params.location - # Define function to calculate distance between a predicted point and its corresponding ray def calc_distance(predicted_point, back_projected_point): return perpendicular_distance_point_to_line_two_points( predicted_point, (camera_center, back_projected_point) @@ -674,8 +676,8 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting( # Vectorize over all keypoints vmap_calc_distance = jax.vmap(calc_distance) - distances: Float[Array, "J"] = cast( - Any, vmap_calc_distance(predicted_pose, back_projected_points) + distances: Float[Array, "J"] = vmap_calc_distance( + predicted_pose, back_projected_points ) return distances