forked from HQU-gxy/CVTH3PE
refactor: Optimize distortion function for clarity and performance
- Added JAX Just-In-Time (JIT) compilation to the `distortion` function for improved performance. - Reorganized variable unpacking and calculations for better readability and efficiency. - Enhanced comments to clarify the steps involved in the distortion process, including normalization and radial/tangential distortion calculations. - Updated the return values to directly reflect pixel coordinates after distortion.
This commit is contained in:
@ -4,14 +4,15 @@ from datetime import datetime
|
||||
from typing import Any, TypeAlias, TypedDict, Optional
|
||||
|
||||
from beartype import beartype
|
||||
from jax import Array
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jaxtyping import Num, jaxtyped
|
||||
from jaxtyping import Num, jaxtyped, Array
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
CameraID: TypeAlias = str # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@jax.jit
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def distortion(
|
||||
points_2d: Num[Array, "N 2"],
|
||||
@ -37,31 +38,36 @@ def distortion(
|
||||
Implementation based on OpenCV's distortion model:
|
||||
https://docs.opencv.org/4.10.0/d9/d0c/group__calib3d.html#ga69f2545a8b62a6b0fc2ee060dc30559d
|
||||
"""
|
||||
# unpack
|
||||
fx, fy = K[0, 0], K[1, 1]
|
||||
cx, cy = K[0, 2], K[1, 2]
|
||||
k1, k2, p1, p2, k3 = dist_coeffs
|
||||
|
||||
# Get principal point and focal length
|
||||
cx, cy = K[0, 2], K[1, 2]
|
||||
fx, fy = K[0, 0], K[1, 1]
|
||||
|
||||
# Convert to normalized coordinates
|
||||
# normalize
|
||||
x = (points_2d[:, 0] - cx) / fx
|
||||
y = (points_2d[:, 1] - cy) / fy
|
||||
|
||||
# precompute r^2, r^4, r^6
|
||||
r2 = x * x + y * y
|
||||
r4 = r2 * r2
|
||||
r6 = r4 * r2
|
||||
|
||||
# Radial distortion
|
||||
x_distort = x * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
|
||||
y_distort = y * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
|
||||
# radial term
|
||||
radial = 1 + k1 * r2 + k2 * r4 + k3 * r6
|
||||
|
||||
# Tangential distortion
|
||||
x_distort = x_distort + 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
|
||||
y_distort = y_distort + p1 * (r2 + 2 * y * y) + 2 * p2 * x * y
|
||||
# tangential term
|
||||
x_tan = 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
|
||||
y_tan = p1 * (r2 + 2 * y * y) + 2 * p2 * x * y
|
||||
|
||||
# Back to absolute coordinates
|
||||
x_distort = x_distort * fx + cx
|
||||
y_distort = y_distort * fy + cy
|
||||
# apply both
|
||||
x_dist = x * radial + x_tan
|
||||
y_dist = y * radial + y_tan
|
||||
|
||||
# Combine distorted coordinates
|
||||
return jnp.stack([x_distort, y_distort], axis=1)
|
||||
# back to pixels
|
||||
u = x_dist * fx + cx
|
||||
v = y_dist * fy + cy
|
||||
|
||||
return jnp.stack([u, v], axis=1)
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
|
||||
Reference in New Issue
Block a user