forked from HQU-gxy/CVTH3PE
refactor: Update distortion and projection functions for clarity and usability
- Enhanced docstrings for `distortion` and `project` functions to clarify input expectations and output formats, emphasizing pixel coordinates.
- Improved variable naming for distortion calculations to enhance readability.
- Added checks for valid points in the `project` function, ensuring proper handling of distortion parameters.
- Introduced a new method `project_ideal` in the `Camera` class for projecting 3D points without distortion, improving usability.
This commit is contained in:
@ -15,19 +15,27 @@ CameraID: TypeAlias = str # pylint: disable=invalid-name
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def distortion(
|
||||
points_2d: Num[Array, "N 2"],
|
||||
K: Num[Array, "3 3"],
|
||||
K: Num[Array, "3 3"], # pylint: disable=invalid-name
|
||||
dist_coeffs: Num[Array, "5"],
|
||||
) -> Num[Array, "N 2"]:
|
||||
"""
|
||||
Apply distortion to 2D points
|
||||
Apply distortion to 2D points in pixel coordinates
|
||||
|
||||
Args:
|
||||
points_2d: 2D points in image coordinates
|
||||
points_2d: 2D points in pixel coordinates
|
||||
K: Camera intrinsic matrix
|
||||
dist_coeffs: Distortion coefficients [k1, k2, p1, p2, k3]
|
||||
|
||||
Returns:
|
||||
Distorted 2D points
|
||||
Distorted 2D points in pixel coordinates
|
||||
|
||||
Note:
|
||||
The function handles the conversion between pixel coordinates and normalized coordinates
|
||||
internally. It expects points_2d to be in pixel coordinates, and returns distorted
|
||||
points in pixel coordinates.
|
||||
|
||||
Implementation based on OpenCV's distortion model:
|
||||
https://docs.opencv.org/4.10.0/d9/d0c/group__calib3d.html#ga69f2545a8b62a6b0fc2ee060dc30559d
|
||||
"""
|
||||
k1, k2, p1, p2, k3 = dist_coeffs
|
||||
|
||||
@ -41,41 +49,48 @@ def distortion(
|
||||
r2 = x * x + y * y
|
||||
|
||||
# Radial distortion
|
||||
xdistort = x * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
|
||||
ydistort = y * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
|
||||
x_distort = x * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
|
||||
y_distort = y * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
|
||||
|
||||
# Tangential distortion
|
||||
xdistort = xdistort + 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
|
||||
ydistort = ydistort + p1 * (r2 + 2 * y * y) + 2 * p2 * x * y
|
||||
x_distort = x_distort + 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
|
||||
y_distort = y_distort + p1 * (r2 + 2 * y * y) + 2 * p2 * x * y
|
||||
|
||||
# Back to absolute coordinates
|
||||
xdistort = xdistort * fx + cx
|
||||
ydistort = ydistort * fy + cy
|
||||
x_distort = x_distort * fx + cx
|
||||
y_distort = y_distort * fy + cy
|
||||
|
||||
# Combine distorted coordinates
|
||||
return jnp.stack([xdistort, ydistort], axis=1)
|
||||
return jnp.stack([x_distort, y_distort], axis=1)
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def project(
|
||||
points_3d: Num[Array, "N 3"],
|
||||
projection_matrix: Num[Array, "3 4"],
|
||||
K: Num[Array, "3 3"],
|
||||
dist_coeffs: Num[Array, "5"],
|
||||
K: Optional[Num[Array, "3 3"]] = None, # pylint: disable=invalid-name
|
||||
dist_coeffs: Optional[Num[Array, "5"]] = None,
|
||||
image_size: Optional[Num[Array, "2"]] = None,
|
||||
) -> Num[Array, "N 2"]:
|
||||
"""
|
||||
Project 3D points to 2D points
|
||||
Project 3D points to 2D points in pixel coordinates
|
||||
|
||||
Args:
|
||||
points_3d: 3D points in world coordinates
|
||||
projection_matrix: pre-computed projection matrix (K @ Rt[:3, :])
|
||||
K: Camera intrinsic matrix
|
||||
dist_coeffs: Distortion coefficients
|
||||
projection_matrix: pre-computed projection matrix (K @ Rt[:3, :]) that projects to pixel coordinates
|
||||
K: (optional) Camera intrinsic matrix, unnormalized
|
||||
dist_coeffs: (optional) Distortion coefficients
|
||||
image_size: (optional) Image dimensions [width, height] for valid point check.
|
||||
If not provided, uses (0,1) normalized coordinates.
|
||||
|
||||
Note:
|
||||
K and dist_coeffs must be provided together, or both be None.
|
||||
If K is provided, it assumes that the projection matrix is calculated from the same K.
|
||||
|
||||
Returns:
|
||||
2D points in image coordinates
|
||||
2D points in pixel coordinates
|
||||
"""
|
||||
P = projection_matrix
|
||||
P = projection_matrix # pylint: disable=invalid-name
|
||||
p3d_homogeneous = jnp.hstack(
|
||||
(points_3d, jnp.ones((points_3d.shape[0], 1), dtype=points_3d.dtype))
|
||||
)
|
||||
@ -86,17 +101,30 @@ def project(
|
||||
# Perspective division
|
||||
p2d = p2d_homogeneous[:, 0:2] / p2d_homogeneous[:, 2:3]
|
||||
|
||||
# Apply distortion if needed
|
||||
if dist_coeffs is not None:
|
||||
# Check if valid points (between 0 and 1)
|
||||
valid = jnp.all(p2d > 0, axis=1) & jnp.all(p2d < 1, axis=1)
|
||||
if dist_coeffs is not None and K is not None:
|
||||
# Check if valid points (within image boundaries)
|
||||
if image_size is not None:
|
||||
# Use image dimensions for valid point check in pixel space
|
||||
valid = (
|
||||
jnp.all(p2d >= 0, axis=1)
|
||||
& (p2d[:, 0] < image_size[0])
|
||||
& (p2d[:, 1] < image_size[1])
|
||||
)
|
||||
else:
|
||||
# Fall back to normalized coordinates if image_size not provided
|
||||
valid = jnp.all(p2d >= 0, axis=1) & jnp.all(p2d < 1, axis=1)
|
||||
|
||||
# Only apply distortion if there are valid points
|
||||
# only valid points need distortion
|
||||
if jnp.any(valid):
|
||||
# Only distort valid points
|
||||
valid_p2d = p2d[valid]
|
||||
distorted_valid = distortion(valid_p2d, K, dist_coeffs)
|
||||
p2d = p2d.at[valid].set(distorted_valid)
|
||||
elif dist_coeffs is None and K is None:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
"dist_coeffs and K must be provided together to compute distortion"
|
||||
)
|
||||
|
||||
return jnp.squeeze(p2d)
|
||||
|
||||
@ -199,13 +227,30 @@ class Camera:
|
||||
points_3d: 3D points in world coordinates
|
||||
|
||||
Returns:
|
||||
2D points in image coordinates
|
||||
2D points in pixel coordinates
|
||||
"""
|
||||
return project(
|
||||
points_3d=points_3d,
|
||||
points_3d,
|
||||
projection_matrix=self.params.projection_matrix,
|
||||
K=self.params.K,
|
||||
dist_coeffs=self.params.dist_coeffs,
|
||||
image_size=self.params.image_size,
|
||||
)
|
||||
|
||||
def project_ideal(self, points_3d: Num[Array, "N 3"]) -> Num[Array, "N 2"]:
|
||||
"""
|
||||
Project 3D points to 2D points using this camera's parameters, without distortion
|
||||
|
||||
Args:
|
||||
points_3d: 3D points in world coordinates
|
||||
|
||||
Returns:
|
||||
2D points in pixel coordinates
|
||||
"""
|
||||
return project(
|
||||
points_3d,
|
||||
projection_matrix=self.params.projection_matrix,
|
||||
image_size=self.params.image_size,
|
||||
)
|
||||
|
||||
def distortion(self, points_2d: Num[Array, "N 2"]) -> Num[Array, "N 2"]:
|
||||
|
||||
Reference in New Issue
Block a user