1
0
forked from HQU-gxy/CVTH3PE

refactor: Update distortion and projection functions for clarity and usability

- Enhanced docstrings for `distortion` and `project` functions to clarify input expectations and output formats, emphasizing pixel coordinates.
- Improved variable naming for distortion calculations to enhance readability.
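- Added validation to the `project` function: `K` and `dist_coeffs` must now be provided together (a `ValueError` is raised otherwise), and the valid-point check uses `image_size` in pixel space when it is given.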
- Introduced a new method `project_ideal` in the `Camera` class for projecting 3D points without distortion, improving usability.
This commit is contained in:
2025-04-22 11:41:54 +08:00
parent 032eb684ec
commit cb369b35da

View File

@ -15,19 +15,27 @@ CameraID: TypeAlias = str # pylint: disable=invalid-name
@jaxtyped(typechecker=beartype)
def distortion(
points_2d: Num[Array, "N 2"],
K: Num[Array, "3 3"],
K: Num[Array, "3 3"], # pylint: disable=invalid-name
dist_coeffs: Num[Array, "5"],
) -> Num[Array, "N 2"]:
"""
Apply distortion to 2D points
Apply distortion to 2D points in pixel coordinates
Args:
points_2d: 2D points in image coordinates
points_2d: 2D points in pixel coordinates
K: Camera intrinsic matrix
dist_coeffs: Distortion coefficients [k1, k2, p1, p2, k3]
Returns:
Distorted 2D points
Distorted 2D points in pixel coordinates
Note:
The function handles the conversion between pixel coordinates and normalized coordinates
internally. It expects points_2d to be in pixel coordinates, and returns distorted
points in pixel coordinates.
Implementation based on OpenCV's distortion model:
https://docs.opencv.org/4.10.0/d9/d0c/group__calib3d.html#ga69f2545a8b62a6b0fc2ee060dc30559d
"""
k1, k2, p1, p2, k3 = dist_coeffs
@ -41,41 +49,48 @@ def distortion(
r2 = x * x + y * y
# Radial distortion
xdistort = x * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
ydistort = y * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
x_distort = x * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
y_distort = y * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
# Tangential distortion
xdistort = xdistort + 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
ydistort = ydistort + p1 * (r2 + 2 * y * y) + 2 * p2 * x * y
x_distort = x_distort + 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
y_distort = y_distort + p1 * (r2 + 2 * y * y) + 2 * p2 * x * y
# Back to absolute coordinates
xdistort = xdistort * fx + cx
ydistort = ydistort * fy + cy
x_distort = x_distort * fx + cx
y_distort = y_distort * fy + cy
# Combine distorted coordinates
return jnp.stack([xdistort, ydistort], axis=1)
return jnp.stack([x_distort, y_distort], axis=1)
@jaxtyped(typechecker=beartype)
def project(
points_3d: Num[Array, "N 3"],
projection_matrix: Num[Array, "3 4"],
K: Num[Array, "3 3"],
dist_coeffs: Num[Array, "5"],
K: Optional[Num[Array, "3 3"]] = None, # pylint: disable=invalid-name
dist_coeffs: Optional[Num[Array, "5"]] = None,
image_size: Optional[Num[Array, "2"]] = None,
) -> Num[Array, "N 2"]:
"""
Project 3D points to 2D points
Project 3D points to 2D points in pixel coordinates
Args:
points_3d: 3D points in world coordinates
projection_matrix: pre-computed projection matrix (K @ Rt[:3, :])
K: Camera intrinsic matrix
dist_coeffs: Distortion coefficients
projection_matrix: pre-computed projection matrix (K @ Rt[:3, :]) that projects to pixel coordinates
K: (optional) Camera intrinsic matrix, unnormalized
dist_coeffs: (optional) Distortion coefficients
image_size: (optional) Image dimensions [width, height] for valid point check.
If not provided, uses (0,1) normalized coordinates.
Note:
K and dist_coeffs must be provided together, or both be None.
If K is provided, it assumes that the projection matrix is calculated from the same K.
Returns:
2D points in image coordinates
2D points in pixel coordinates
"""
P = projection_matrix
P = projection_matrix # pylint: disable=invalid-name
p3d_homogeneous = jnp.hstack(
(points_3d, jnp.ones((points_3d.shape[0], 1), dtype=points_3d.dtype))
)
@ -86,17 +101,30 @@ def project(
# Perspective division
p2d = p2d_homogeneous[:, 0:2] / p2d_homogeneous[:, 2:3]
# Apply distortion if needed
if dist_coeffs is not None:
# Check if valid points (between 0 and 1)
valid = jnp.all(p2d > 0, axis=1) & jnp.all(p2d < 1, axis=1)
if dist_coeffs is not None and K is not None:
# Check if valid points (within image boundaries)
if image_size is not None:
# Use image dimensions for valid point check in pixel space
valid = (
jnp.all(p2d >= 0, axis=1)
& (p2d[:, 0] < image_size[0])
& (p2d[:, 1] < image_size[1])
)
else:
# Fall back to normalized coordinates if image_size not provided
valid = jnp.all(p2d >= 0, axis=1) & jnp.all(p2d < 1, axis=1)
# Only apply distortion if there are valid points
# only valid points need distortion
if jnp.any(valid):
# Only distort valid points
valid_p2d = p2d[valid]
distorted_valid = distortion(valid_p2d, K, dist_coeffs)
p2d = p2d.at[valid].set(distorted_valid)
elif dist_coeffs is None and K is None:
pass
else:
raise ValueError(
"dist_coeffs and K must be provided together to compute distortion"
)
return jnp.squeeze(p2d)
@ -199,13 +227,30 @@ class Camera:
points_3d: 3D points in world coordinates
Returns:
2D points in image coordinates
2D points in pixel coordinates
"""
return project(
points_3d=points_3d,
points_3d,
projection_matrix=self.params.projection_matrix,
K=self.params.K,
dist_coeffs=self.params.dist_coeffs,
image_size=self.params.image_size,
)
def project_ideal(self, points_3d: Num[Array, "N 3"]) -> Num[Array, "N 2"]:
    """
    Project 3D world points through this camera without lens distortion.

    Delegates to the module-level `project`, passing only this camera's
    projection matrix and image size. `K` and `dist_coeffs` are left at
    their `None` defaults, so the distortion step is not applied.

    Args:
        points_3d: 3D points in world coordinates

    Returns:
        2D points in pixel coordinates
    """
    # Bind the parameter bundle once; both fields below are read from it.
    camera_params = self.params
    return project(
        points_3d,
        projection_matrix=camera_params.projection_matrix,
        image_size=camera_params.image_size,
    )
def distortion(self, points_2d: Num[Array, "N 2"]) -> Num[Array, "N 2"]: