feat: Add beartype for runtime type checking and update dependencies

- Add beartype dependency to pyproject.toml and uv.lock
- Replace typeguard with beartype in type checking
- Create camera module with type-safe camera parameter definitions
- Migrate utility function to use beartype and JAX numpy
This commit is contained in:
2025-03-06 18:35:43 +08:00
parent 245c4b502d
commit 3ce5b564bf
5 changed files with 77 additions and 12 deletions

View File

54
app/camera/__init__.py Normal file
View File

@ -0,0 +1,54 @@
from typing import TypedDict, TypeAlias, Any
from typing_extensions import NotRequired
from jaxtyping import Num, jaxtyped
from beartype import beartype
from jax import numpy as jnp, Array
CameraID: TypeAlias = str
@jaxtyped(typechecker=beartype)
class CameraParams(TypedDict):
"""
Camera parameters: intrinsic matrix, extrinsic matrix, and distortion coefficients
"""
K: Num[Array, "3 3"]
"""
intrinsic matrix
"""
Rt: Num[Array, "4 4"]
"""
[R|t] extrinsic matrix
R and t are the rotation and translation that describe the change of
coordinates from world to camera coordinate systems (or camera frame)
"""
dist_coeffs: Num[Array, "N"]
"""
An array of distortion coefficients of the form
[k1, k2, [p1, p2, [k3]]], where ki is the ith
radial distortion coefficient and pi is the ith
tangential distortion coeff.
"""
@jaxtyped(typechecker=beartype)
class Camera(TypedDict):
"""
a description of a camera
"""
id: CameraID
"""
Camera ID
"""
params: CameraParams
"""
Camera parameters
"""
size: tuple[int, int]
"""
Image size
"""

View File

@ -1,23 +1,21 @@
from typing import Any from typing import Any
import numpy as np from jaxtyping import Num, jaxtyped
from jaxtyping import Float, Num, jaxtyped from beartype import beartype
from typeguard import typechecked from jax import numpy as jnp, Array
from app._typing import NDArray
@jaxtyped(typechecker=typechecked) @jaxtyped(typechecker=beartype)
def calculate_perpendicular_distance( def calculate_perpendicular_distance(
point: Num[NDArray, "2"], point: Num[Array, "2"],
line: Num[NDArray, "2 2"], line: Num[Array, "2 2"],
) -> np.floating[Any]: ) -> jnp.floating[Any]:
""" """
Calculate the perpendicular distance between a point and a line. Calculate the perpendicular distance between a point and a line.
""" """
line_start, line_end = line line_start, line_end = line
distance = np.linalg.norm( distance = jnp.linalg.norm(
np.cross(line_end - line_start, line_start - point) jnp.cross(line_end - line_start, line_start - point)
) / np.linalg.norm(line_end - line_start) ) / jnp.linalg.norm(line_end - line_start)
return distance return distance

View File

@ -7,6 +7,7 @@ requires-python = ">=3.10"
dependencies = [ dependencies = [
"anyio>=4.8.0", "anyio>=4.8.0",
"awkward>=2.7.4", "awkward>=2.7.4",
"beartype>=0.20.0",
"cvxopt>=1.3.2", "cvxopt>=1.3.2",
"jax[cuda12]>=0.5.1", "jax[cuda12]>=0.5.1",
"jaxtyping>=0.2.38", "jaxtyping>=0.2.38",

12
uv.lock generated
View File

@ -1,4 +1,5 @@
version = 1 version = 1
revision = 1
requires-python = ">=3.10" requires-python = ">=3.10"
resolution-markers = [ resolution-markers = [
"python_full_version >= '3.13' and sys_platform == 'darwin'", "python_full_version >= '3.13' and sys_platform == 'darwin'",
@ -187,6 +188,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537 }, { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537 },
] ]
[[package]]
name = "beartype"
version = "0.20.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/7d/48/fe1177f4272a1bc344f3371414aa5b76e19c30d7280d711ce90c5335a6f5/beartype-0.20.0.tar.gz", hash = "sha256:599ecc86b88549bcb6d1af626f44d85ffbb9151ace5d7f9f3b493dce2ffee529", size = 1390635 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/61/ce/9cf60b72e3fbd36a6dd16e31ecbe9ae137a8ead77cb24a5f75302adfc9a9/beartype-0.20.0-py3-none-any.whl", hash = "sha256:090d10e3540b3fca209a0ab5f1c15f9652a075da0a7249c2e6713011e9e5f6ef", size = 1139097 },
]
[[package]] [[package]]
name = "beautifulsoup4" name = "beautifulsoup4"
version = "4.13.3" version = "4.13.3"
@ -372,6 +382,7 @@ source = { virtual = "." }
dependencies = [ dependencies = [
{ name = "anyio" }, { name = "anyio" },
{ name = "awkward" }, { name = "awkward" },
{ name = "beartype" },
{ name = "cvxopt" }, { name = "cvxopt" },
{ name = "jax", extra = ["cuda12"] }, { name = "jax", extra = ["cuda12"] },
{ name = "jaxtyping" }, { name = "jaxtyping" },
@ -391,6 +402,7 @@ dev = [
requires-dist = [ requires-dist = [
{ name = "anyio", specifier = ">=4.8.0" }, { name = "anyio", specifier = ">=4.8.0" },
{ name = "awkward", specifier = ">=2.7.4" }, { name = "awkward", specifier = ">=2.7.4" },
{ name = "beartype", specifier = ">=0.20.0" },
{ name = "cvxopt", specifier = ">=1.3.2" }, { name = "cvxopt", specifier = ">=1.3.2" },
{ name = "jax", extras = ["cuda12"], specifier = ">=0.5.1" }, { name = "jax", extras = ["cuda12"], specifier = ">=0.5.1" },
{ name = "jaxtyping", specifier = ">=0.2.38" }, { name = "jaxtyping", specifier = ">=0.2.38" },