refactor: Optimize distortion function for clarity and performance

- Added JAX Just-In-Time (JIT) compilation to the `distortion` function for improved performance.
- Reorganized variable unpacking and calculations for better readability and efficiency.
- Enhanced comments to clarify the steps involved in the distortion process, including normalization and radial/tangential distortion calculations.
- Updated the return values to directly reflect pixel coordinates after distortion.
This commit is contained in:
2025-04-22 11:43:55 +08:00
parent cb369b35da
commit 3ccf790bac

View File

@ -4,14 +4,15 @@ from datetime import datetime
from typing import Any, TypeAlias, TypedDict, Optional from typing import Any, TypeAlias, TypedDict, Optional
from beartype import beartype from beartype import beartype
from jax import Array import jax
from jax import numpy as jnp from jax import numpy as jnp
from jaxtyping import Num, jaxtyped from jaxtyping import Num, jaxtyped, Array
from typing_extensions import NotRequired from typing_extensions import NotRequired
CameraID: TypeAlias = str # pylint: disable=invalid-name CameraID: TypeAlias = str # pylint: disable=invalid-name
@jax.jit
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
def distortion( def distortion(
points_2d: Num[Array, "N 2"], points_2d: Num[Array, "N 2"],
@ -37,31 +38,36 @@ def distortion(
Implementation based on OpenCV's distortion model: Implementation based on OpenCV's distortion model:
https://docs.opencv.org/4.10.0/d9/d0c/group__calib3d.html#ga69f2545a8b62a6b0fc2ee060dc30559d https://docs.opencv.org/4.10.0/d9/d0c/group__calib3d.html#ga69f2545a8b62a6b0fc2ee060dc30559d
""" """
# unpack
fx, fy = K[0, 0], K[1, 1]
cx, cy = K[0, 2], K[1, 2]
k1, k2, p1, p2, k3 = dist_coeffs k1, k2, p1, p2, k3 = dist_coeffs
# Get principal point and focal length # normalize
cx, cy = K[0, 2], K[1, 2]
fx, fy = K[0, 0], K[1, 1]
# Convert to normalized coordinates
x = (points_2d[:, 0] - cx) / fx x = (points_2d[:, 0] - cx) / fx
y = (points_2d[:, 1] - cy) / fy y = (points_2d[:, 1] - cy) / fy
# precompute r^2, r^4, r^6
r2 = x * x + y * y r2 = x * x + y * y
r4 = r2 * r2
r6 = r4 * r2
# Radial distortion # radial term
x_distort = x * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2) radial = 1 + k1 * r2 + k2 * r4 + k3 * r6
y_distort = y * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
# Tangential distortion # tangential term
x_distort = x_distort + 2 * p1 * x * y + p2 * (r2 + 2 * x * x) x_tan = 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
y_distort = y_distort + p1 * (r2 + 2 * y * y) + 2 * p2 * x * y y_tan = p1 * (r2 + 2 * y * y) + 2 * p2 * x * y
# Back to absolute coordinates # apply both
x_distort = x_distort * fx + cx x_dist = x * radial + x_tan
y_distort = y_distort * fy + cy y_dist = y * radial + y_tan
# Combine distorted coordinates # back to pixels
return jnp.stack([x_distort, y_distort], axis=1) u = x_dist * fx + cx
v = y_dist * fy + cy
return jnp.stack([u, v], axis=1)
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)