forked from HQU-gxy/CVTH3PE
refactor: Optimize distortion function for clarity and performance
- Added JAX Just-In-Time (JIT) compilation to the `distortion` function for improved performance. - Reorganized variable unpacking and calculations for better readability and efficiency. - Enhanced comments to clarify the steps involved in the distortion process, including normalization and radial/tangential distortion calculations. - Updated the return values to directly reflect pixel coordinates after distortion.
This commit is contained in:
@ -4,14 +4,15 @@ from datetime import datetime
|
|||||||
from typing import Any, TypeAlias, TypedDict, Optional
|
from typing import Any, TypeAlias, TypedDict, Optional
|
||||||
|
|
||||||
from beartype import beartype
|
from beartype import beartype
|
||||||
from jax import Array
|
import jax
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
from jaxtyping import Num, jaxtyped
|
from jaxtyping import Num, jaxtyped, Array
|
||||||
from typing_extensions import NotRequired
|
from typing_extensions import NotRequired
|
||||||
|
|
||||||
CameraID: TypeAlias = str # pylint: disable=invalid-name
|
CameraID: TypeAlias = str # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
def distortion(
|
def distortion(
|
||||||
points_2d: Num[Array, "N 2"],
|
points_2d: Num[Array, "N 2"],
|
||||||
@ -37,31 +38,36 @@ def distortion(
|
|||||||
Implementation based on OpenCV's distortion model:
|
Implementation based on OpenCV's distortion model:
|
||||||
https://docs.opencv.org/4.10.0/d9/d0c/group__calib3d.html#ga69f2545a8b62a6b0fc2ee060dc30559d
|
https://docs.opencv.org/4.10.0/d9/d0c/group__calib3d.html#ga69f2545a8b62a6b0fc2ee060dc30559d
|
||||||
"""
|
"""
|
||||||
|
# unpack
|
||||||
|
fx, fy = K[0, 0], K[1, 1]
|
||||||
|
cx, cy = K[0, 2], K[1, 2]
|
||||||
k1, k2, p1, p2, k3 = dist_coeffs
|
k1, k2, p1, p2, k3 = dist_coeffs
|
||||||
|
|
||||||
# Get principal point and focal length
|
# normalize
|
||||||
cx, cy = K[0, 2], K[1, 2]
|
|
||||||
fx, fy = K[0, 0], K[1, 1]
|
|
||||||
|
|
||||||
# Convert to normalized coordinates
|
|
||||||
x = (points_2d[:, 0] - cx) / fx
|
x = (points_2d[:, 0] - cx) / fx
|
||||||
y = (points_2d[:, 1] - cy) / fy
|
y = (points_2d[:, 1] - cy) / fy
|
||||||
|
|
||||||
|
# precompute r^2, r^4, r^6
|
||||||
r2 = x * x + y * y
|
r2 = x * x + y * y
|
||||||
|
r4 = r2 * r2
|
||||||
|
r6 = r4 * r2
|
||||||
|
|
||||||
# Radial distortion
|
# radial term
|
||||||
x_distort = x * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
|
radial = 1 + k1 * r2 + k2 * r4 + k3 * r6
|
||||||
y_distort = y * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
|
|
||||||
|
|
||||||
# Tangential distortion
|
# tangential term
|
||||||
x_distort = x_distort + 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
|
x_tan = 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
|
||||||
y_distort = y_distort + p1 * (r2 + 2 * y * y) + 2 * p2 * x * y
|
y_tan = p1 * (r2 + 2 * y * y) + 2 * p2 * x * y
|
||||||
|
|
||||||
# Back to absolute coordinates
|
# apply both
|
||||||
x_distort = x_distort * fx + cx
|
x_dist = x * radial + x_tan
|
||||||
y_distort = y_distort * fy + cy
|
y_dist = y * radial + y_tan
|
||||||
|
|
||||||
# Combine distorted coordinates
|
# back to pixels
|
||||||
return jnp.stack([x_distort, y_distort], axis=1)
|
u = x_dist * fx + cx
|
||||||
|
v = y_dist * fy + cy
|
||||||
|
|
||||||
|
return jnp.stack([u, v], axis=1)
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
|
|||||||
Reference in New Issue
Block a user