1
0
forked from HQU-gxy/CVTH3PE

refactor: Update distortion and projection functions for clarity and usability

- Enhanced docstrings for `distortion` and `project` functions to clarify input expectations and output formats, emphasizing pixel coordinates.
- Improved variable naming for distortion calculations to enhance readability.
- Added checks for valid points in the `project` function, ensuring proper handling of distortion parameters.
- Introduced a new method `project_ideal` in the `Camera` class for projecting 3D points without distortion, improving usability.
This commit is contained in:
2025-04-22 11:41:54 +08:00
parent 032eb684ec
commit cb369b35da

View File

@ -15,19 +15,27 @@ CameraID: TypeAlias = str # pylint: disable=invalid-name
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
def distortion( def distortion(
points_2d: Num[Array, "N 2"], points_2d: Num[Array, "N 2"],
K: Num[Array, "3 3"], K: Num[Array, "3 3"], # pylint: disable=invalid-name
dist_coeffs: Num[Array, "5"], dist_coeffs: Num[Array, "5"],
) -> Num[Array, "N 2"]: ) -> Num[Array, "N 2"]:
""" """
Apply distortion to 2D points Apply distortion to 2D points in pixel coordinates
Args: Args:
points_2d: 2D points in image coordinates points_2d: 2D points in pixel coordinates
K: Camera intrinsic matrix K: Camera intrinsic matrix
dist_coeffs: Distortion coefficients [k1, k2, p1, p2, k3] dist_coeffs: Distortion coefficients [k1, k2, p1, p2, k3]
Returns: Returns:
Distorted 2D points Distorted 2D points in pixel coordinates
Note:
The function handles the conversion between pixel coordinates and normalized coordinates
internally. It expects points_2d to be in pixel coordinates, and returns distorted
points in pixel coordinates.
Implementation based on OpenCV's distortion model:
https://docs.opencv.org/4.10.0/d9/d0c/group__calib3d.html#ga69f2545a8b62a6b0fc2ee060dc30559d
""" """
k1, k2, p1, p2, k3 = dist_coeffs k1, k2, p1, p2, k3 = dist_coeffs
@ -41,41 +49,48 @@ def distortion(
r2 = x * x + y * y r2 = x * x + y * y
# Radial distortion # Radial distortion
xdistort = x * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2) x_distort = x * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
ydistort = y * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2) y_distort = y * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
# Tangential distortion # Tangential distortion
xdistort = xdistort + 2 * p1 * x * y + p2 * (r2 + 2 * x * x) x_distort = x_distort + 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
ydistort = ydistort + p1 * (r2 + 2 * y * y) + 2 * p2 * x * y y_distort = y_distort + p1 * (r2 + 2 * y * y) + 2 * p2 * x * y
# Back to absolute coordinates # Back to absolute coordinates
xdistort = xdistort * fx + cx x_distort = x_distort * fx + cx
ydistort = ydistort * fy + cy y_distort = y_distort * fy + cy
# Combine distorted coordinates # Combine distorted coordinates
return jnp.stack([xdistort, ydistort], axis=1) return jnp.stack([x_distort, y_distort], axis=1)
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
def project( def project(
points_3d: Num[Array, "N 3"], points_3d: Num[Array, "N 3"],
projection_matrix: Num[Array, "3 4"], projection_matrix: Num[Array, "3 4"],
K: Num[Array, "3 3"], K: Optional[Num[Array, "3 3"]] = None, # pylint: disable=invalid-name
dist_coeffs: Num[Array, "5"], dist_coeffs: Optional[Num[Array, "5"]] = None,
image_size: Optional[Num[Array, "2"]] = None,
) -> Num[Array, "N 2"]: ) -> Num[Array, "N 2"]:
""" """
Project 3D points to 2D points Project 3D points to 2D points in pixel coordinates
Args: Args:
points_3d: 3D points in world coordinates points_3d: 3D points in world coordinates
projection_matrix: pre-computed projection matrix (K @ Rt[:3, :]) projection_matrix: pre-computed projection matrix (K @ Rt[:3, :]) that projects to pixel coordinates
K: Camera intrinsic matrix K: (optional) Camera intrinsic matrix, unnormalized
dist_coeffs: Distortion coefficients dist_coeffs: (optional) Distortion coefficients
image_size: (optional) Image dimensions [width, height] for valid point check.
If not provided, uses (0,1) normalized coordinates.
Note:
K and dist_coeffs must be provided together, or both be None.
If K is provided, it assumes that the projection matrix is calculated from the same K.
Returns: Returns:
2D points in image coordinates 2D points in pixel coordinates
""" """
P = projection_matrix P = projection_matrix # pylint: disable=invalid-name
p3d_homogeneous = jnp.hstack( p3d_homogeneous = jnp.hstack(
(points_3d, jnp.ones((points_3d.shape[0], 1), dtype=points_3d.dtype)) (points_3d, jnp.ones((points_3d.shape[0], 1), dtype=points_3d.dtype))
) )
@ -86,17 +101,30 @@ def project(
# Perspective division # Perspective division
p2d = p2d_homogeneous[:, 0:2] / p2d_homogeneous[:, 2:3] p2d = p2d_homogeneous[:, 0:2] / p2d_homogeneous[:, 2:3]
# Apply distortion if needed if dist_coeffs is not None and K is not None:
if dist_coeffs is not None: # Check if valid points (within image boundaries)
# Check if valid points (between 0 and 1) if image_size is not None:
valid = jnp.all(p2d > 0, axis=1) & jnp.all(p2d < 1, axis=1) # Use image dimensions for valid point check in pixel space
valid = (
jnp.all(p2d >= 0, axis=1)
& (p2d[:, 0] < image_size[0])
& (p2d[:, 1] < image_size[1])
)
else:
# Fall back to normalized coordinates if image_size not provided
valid = jnp.all(p2d >= 0, axis=1) & jnp.all(p2d < 1, axis=1)
# Only apply distortion if there are valid points # only valid points need distortion
if jnp.any(valid): if jnp.any(valid):
# Only distort valid points
valid_p2d = p2d[valid] valid_p2d = p2d[valid]
distorted_valid = distortion(valid_p2d, K, dist_coeffs) distorted_valid = distortion(valid_p2d, K, dist_coeffs)
p2d = p2d.at[valid].set(distorted_valid) p2d = p2d.at[valid].set(distorted_valid)
elif dist_coeffs is None and K is None:
pass
else:
raise ValueError(
"dist_coeffs and K must be provided together to compute distortion"
)
return jnp.squeeze(p2d) return jnp.squeeze(p2d)
@ -199,13 +227,30 @@ class Camera:
points_3d: 3D points in world coordinates points_3d: 3D points in world coordinates
Returns: Returns:
2D points in image coordinates 2D points in pixel coordinates
""" """
return project( return project(
points_3d=points_3d, points_3d,
projection_matrix=self.params.projection_matrix,
K=self.params.K, K=self.params.K,
dist_coeffs=self.params.dist_coeffs, dist_coeffs=self.params.dist_coeffs,
image_size=self.params.image_size,
)
def project_ideal(self, points_3d: Num[Array, "N 3"]) -> Num[Array, "N 2"]:
"""
Project 3D points to 2D points using this camera's parameters, without distortion
Args:
points_3d: 3D points in world coordinates
Returns:
2D points in pixel coordinates
"""
return project(
points_3d,
projection_matrix=self.params.projection_matrix, projection_matrix=self.params.projection_matrix,
image_size=self.params.image_size,
) )
def distortion(self, points_2d: Num[Array, "N 2"]) -> Num[Array, "N 2"]: def distortion(self, points_2d: Num[Array, "N 2"]) -> Num[Array, "N 2"]: