refactor: Enhance type hints and documentation for distance calculation functions in playground.py
- Updated the return type of `perpendicular_distance_point_to_line_two_points` to `Float[Array, ""]` for improved type clarity. - Added detailed docstrings to `perpendicular_distance_point_to_line_two_points`, specifying arguments and return values for better understanding. - Streamlined the distance calculation in `perpendicular_distance_camera_2d_points_to_tracking_raycasting` by removing unnecessary type casting, enhancing code readability.
This commit is contained in:
@ -620,11 +620,18 @@ def calculate_affinity_2d(
|
|||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
def perpendicular_distance_point_to_line_two_points(
|
def perpendicular_distance_point_to_line_two_points(
|
||||||
point: Num[Array, "3"], line: tuple[Num[Array, "3"], Num[Array, "3"]]
|
point: Num[Array, "3"], line: tuple[Num[Array, "3"], Num[Array, "3"]]
|
||||||
) -> jnp.floating[Any]:
|
) -> Float[Array, ""]:
|
||||||
"""
|
"""
|
||||||
Calculate the perpendicular distance between a point and a line.
|
Calculate the perpendicular distance between a point and a line.
|
||||||
|
|
||||||
where `line` is represented by two points: `(line_start, line_end)`
|
where `line` is represented by two points: `(line_start, line_end)`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
point: The point to calculate the distance to
|
||||||
|
line: The line to calculate the distance to, represented by two points
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The perpendicular distance between the point and the line
|
||||||
|
(should be a scalar in `float`)
|
||||||
"""
|
"""
|
||||||
line_start, line_end = line
|
line_start, line_end = line
|
||||||
distance = jnp.linalg.norm(
|
distance = jnp.linalg.norm(
|
||||||
@ -652,21 +659,16 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
|
|||||||
Array of perpendicular distances for each keypoint
|
Array of perpendicular distances for each keypoint
|
||||||
"""
|
"""
|
||||||
camera = detection.camera
|
camera = detection.camera
|
||||||
# Convert timedelta to seconds for prediction
|
|
||||||
delta_t_s = delta_t.total_seconds()
|
delta_t_s = delta_t.total_seconds()
|
||||||
|
|
||||||
# Predict the 3D pose based on tracking and delta_t
|
|
||||||
predicted_pose = predict_pose_3d(tracking, delta_t_s)
|
predicted_pose = predict_pose_3d(tracking, delta_t_s)
|
||||||
|
|
||||||
# Back-project the 2D points to 3D space (assuming z=0 plane)
|
# Back-project the 2D points to 3D space
|
||||||
|
# intersection with z=0 plane
|
||||||
back_projected_points = detection.camera.unproject_points_to_z_plane(
|
back_projected_points = detection.camera.unproject_points_to_z_plane(
|
||||||
detection.keypoints, z=0.0
|
detection.keypoints, z=0.0
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get camera center from the camera parameters
|
|
||||||
camera_center = camera.params.location
|
camera_center = camera.params.location
|
||||||
|
|
||||||
# Define function to calculate distance between a predicted point and its corresponding ray
|
|
||||||
def calc_distance(predicted_point, back_projected_point):
|
def calc_distance(predicted_point, back_projected_point):
|
||||||
return perpendicular_distance_point_to_line_two_points(
|
return perpendicular_distance_point_to_line_two_points(
|
||||||
predicted_point, (camera_center, back_projected_point)
|
predicted_point, (camera_center, back_projected_point)
|
||||||
@ -674,8 +676,8 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
|
|||||||
|
|
||||||
# Vectorize over all keypoints
|
# Vectorize over all keypoints
|
||||||
vmap_calc_distance = jax.vmap(calc_distance)
|
vmap_calc_distance = jax.vmap(calc_distance)
|
||||||
distances: Float[Array, "J"] = cast(
|
distances: Float[Array, "J"] = vmap_calc_distance(
|
||||||
Any, vmap_calc_distance(predicted_pose, back_projected_points)
|
predicted_pose, back_projected_points
|
||||||
)
|
)
|
||||||
|
|
||||||
return distances
|
return distances
|
||||||
|
|||||||
Reference in New Issue
Block a user