feat: Add beartype for runtime type checking and update dependencies

- Add beartype dependency to pyproject.toml and uv.lock
- Replace typeguard with beartype in type checking
- Create camera module with type-safe camera parameter definitions
- Migrate utility function to use beartype and JAX numpy
This commit is contained in:
2025-03-06 18:35:43 +08:00
parent 245c4b502d
commit 3ce5b564bf
5 changed files with 77 additions and 12 deletions

View File

54
app/camera/__init__.py Normal file
View File

@ -0,0 +1,54 @@
from typing import TypedDict, TypeAlias, Any
from typing_extensions import NotRequired
from jaxtyping import Num, jaxtyped
from beartype import beartype
from jax import numpy as jnp, Array
CameraID: TypeAlias = str
@jaxtyped(typechecker=beartype)
class CameraParams(TypedDict):
"""
Camera parameters: intrinsic matrix, extrinsic matrix, and distortion coefficients
"""
K: Num[Array, "3 3"]
"""
intrinsic matrix
"""
Rt: Num[Array, "4 4"]
"""
[R|t] extrinsic matrix
R and t are the rotation and translation that describe the change of
coordinates from world to camera coordinate systems (or camera frame)
"""
dist_coeffs: Num[Array, "N"]
"""
An array of distortion coefficients of the form
[k1, k2, [p1, p2, [k3]]], where ki is the ith
radial distortion coefficient and pi is the ith
tangential distortion coeff.
"""
@jaxtyped(typechecker=beartype)
class Camera(TypedDict):
"""
a description of a camera
"""
id: CameraID
"""
Camera ID
"""
params: CameraParams
"""
Camera parameters
"""
size: tuple[int, int]
"""
Image size
"""

View File

@ -1,23 +1,21 @@
from typing import Any
import numpy as np
from jaxtyping import Float, Num, jaxtyped
from typeguard import typechecked
from app._typing import NDArray
from jaxtyping import Num, jaxtyped
from beartype import beartype
from jax import numpy as jnp, Array
@jaxtyped(typechecker=typechecked)
@jaxtyped(typechecker=beartype)
def calculate_perpendicular_distance(
point: Num[NDArray, "2"],
line: Num[NDArray, "2 2"],
) -> np.floating[Any]:
point: Num[Array, "2"],
line: Num[Array, "2 2"],
) -> jnp.floating[Any]:
"""
Calculate the perpendicular distance between a point and a line.
"""
line_start, line_end = line
distance = np.linalg.norm(
np.cross(line_end - line_start, line_start - point)
) / np.linalg.norm(line_end - line_start)
distance = jnp.linalg.norm(
jnp.cross(line_end - line_start, line_start - point)
) / jnp.linalg.norm(line_end - line_start)
return distance

View File

@ -7,6 +7,7 @@ requires-python = ">=3.10"
dependencies = [
"anyio>=4.8.0",
"awkward>=2.7.4",
"beartype>=0.20.0",
"cvxopt>=1.3.2",
"jax[cuda12]>=0.5.1",
"jaxtyping>=0.2.38",

12
uv.lock generated
View File

@ -1,4 +1,5 @@
version = 1
revision = 1
requires-python = ">=3.10"
resolution-markers = [
"python_full_version >= '3.13' and sys_platform == 'darwin'",
@ -187,6 +188,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537 },
]
[[package]]
name = "beartype"
version = "0.20.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/7d/48/fe1177f4272a1bc344f3371414aa5b76e19c30d7280d711ce90c5335a6f5/beartype-0.20.0.tar.gz", hash = "sha256:599ecc86b88549bcb6d1af626f44d85ffbb9151ace5d7f9f3b493dce2ffee529", size = 1390635 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/61/ce/9cf60b72e3fbd36a6dd16e31ecbe9ae137a8ead77cb24a5f75302adfc9a9/beartype-0.20.0-py3-none-any.whl", hash = "sha256:090d10e3540b3fca209a0ab5f1c15f9652a075da0a7249c2e6713011e9e5f6ef", size = 1139097 },
]
[[package]]
name = "beautifulsoup4"
version = "4.13.3"
@ -372,6 +382,7 @@ source = { virtual = "." }
dependencies = [
{ name = "anyio" },
{ name = "awkward" },
{ name = "beartype" },
{ name = "cvxopt" },
{ name = "jax", extra = ["cuda12"] },
{ name = "jaxtyping" },
@ -391,6 +402,7 @@ dev = [
requires-dist = [
{ name = "anyio", specifier = ">=4.8.0" },
{ name = "awkward", specifier = ">=2.7.4" },
{ name = "beartype", specifier = ">=0.20.0" },
{ name = "cvxopt", specifier = ">=1.3.2" },
{ name = "jax", extras = ["cuda12"], specifier = ">=0.5.1" },
{ name = "jaxtyping", specifier = ">=0.2.38" },