from typing import Any

from beartype import beartype
from jax import Array
from jax import numpy as jnp
from jaxtyping import Float, Num, jaxtyped


@jaxtyped(typechecker=beartype)
def calculate_perpendicular_distance(
    point: Num[Array, "2"],
    line: Num[Array, "2 2"],
) -> Float[Array, ""]:
    """Return the perpendicular distance from a 2-D point to a line.

    The line is the infinite line through the two points stored in the rows
    of ``line``. The distance is ``|d x (a - p)| / |d|`` where ``a`` is the
    first line point, ``d`` the direction between the two line points and
    ``p`` the query point.

    Args:
        point: Coordinates of the query point, shape ``(2,)``.
        line: Two points defining the line, shape ``(2, 2)``; row 0 is the
            start point and row 1 the end point.

    Returns:
        A scalar (0-d) floating-point array holding the distance. If both
        line points coincide, the line has no direction, so the distance
        from ``point`` to that single point is returned instead of NaN.
    """
    line_start, line_end = line
    direction = line_end - line_start
    offset = line_start - point

    # z-component of the 2-D cross product; its magnitude is the area of the
    # parallelogram spanned by `direction` and `offset`.
    cross_z = direction[0] * offset[1] - direction[1] * offset[0]
    length = jnp.linalg.norm(direction)

    # Guard the divisor so the unselected branch of `where` never evaluates
    # 0 / 0 (keeps the function usable under jit without NaNs).
    is_degenerate = length == 0
    safe_length = jnp.where(is_degenerate, 1, length)

    return jnp.where(
        is_degenerate,
        jnp.linalg.norm(offset),
        jnp.abs(cross_z) / safe_length,
    )