refactor: Optimize distortion function for clarity and performance

- Added JAX Just-In-Time (JIT) compilation to the `distortion` function for improved performance.
- Reorganized variable unpacking and calculations for better readability and efficiency.
- Enhanced comments to clarify the steps involved in the distortion process, including normalization and radial/tangential distortion calculations.
- Updated the return values to directly reflect pixel coordinates after distortion.
This commit is contained in:
2025-04-22 11:43:55 +08:00
parent cb369b35da
commit 3ccf790bac

View File

@ -4,14 +4,15 @@ from datetime import datetime
from typing import Any, TypeAlias, TypedDict, Optional
from beartype import beartype
from jax import Array
import jax
from jax import numpy as jnp
from jaxtyping import Num, jaxtyped
from jaxtyping import Num, jaxtyped, Array
from typing_extensions import NotRequired
CameraID: TypeAlias = str # pylint: disable=invalid-name
@jax.jit
@jaxtyped(typechecker=beartype)
def distortion(
points_2d: Num[Array, "N 2"],
@ -37,31 +38,36 @@ def distortion(
Implementation based on OpenCV's distortion model:
https://docs.opencv.org/4.10.0/d9/d0c/group__calib3d.html#ga69f2545a8b62a6b0fc2ee060dc30559d
"""
# unpack
fx, fy = K[0, 0], K[1, 1]
cx, cy = K[0, 2], K[1, 2]
k1, k2, p1, p2, k3 = dist_coeffs
# Get principal point and focal length
cx, cy = K[0, 2], K[1, 2]
fx, fy = K[0, 0], K[1, 1]
# Convert to normalized coordinates
# normalize
x = (points_2d[:, 0] - cx) / fx
y = (points_2d[:, 1] - cy) / fy
# precompute r^2, r^4, r^6
r2 = x * x + y * y
r4 = r2 * r2
r6 = r4 * r2
# Radial distortion
x_distort = x * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
y_distort = y * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
# radial term
radial = 1 + k1 * r2 + k2 * r4 + k3 * r6
# Tangential distortion
x_distort = x_distort + 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
y_distort = y_distort + p1 * (r2 + 2 * y * y) + 2 * p2 * x * y
# tangential term
x_tan = 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
y_tan = p1 * (r2 + 2 * y * y) + 2 * p2 * x * y
# Back to absolute coordinates
x_distort = x_distort * fx + cx
y_distort = y_distort * fy + cy
# apply both
x_dist = x * radial + x_tan
y_dist = y * radial + y_tan
# Combine distorted coordinates
return jnp.stack([x_distort, y_distort], axis=1)
# back to pixels
u = x_dist * fx + cx
v = y_dist * fy + cy
return jnp.stack([u, v], axis=1)
@jaxtyped(typechecker=beartype)