From cb369b35daa7275e9f7527ce217ea02f504e77d1 Mon Sep 17 00:00:00 2001 From: crosstyan Date: Tue, 22 Apr 2025 11:41:54 +0800 Subject: [PATCH] refactor: Update distortion and projection functions for clarity and usability - Enhanced docstrings for `distortion` and `project` functions to clarify input expectations and output formats, emphasizing pixel coordinates. - Improved variable naming for distortion calculations to enhance readability. - Added checks for valid points in the `project` function, ensuring proper handling of distortion parameters. - Introduced a new method `project_ideal` in the `Camera` class for projecting 3D points without distortion, improving usability. --- app/camera/__init__.py | 99 ++++++++++++++++++++++++++++++------------ 1 file changed, 72 insertions(+), 27 deletions(-) diff --git a/app/camera/__init__.py b/app/camera/__init__.py index a2a66e3..947eca2 100644 --- a/app/camera/__init__.py +++ b/app/camera/__init__.py @@ -15,19 +15,27 @@ CameraID: TypeAlias = str # pylint: disable=invalid-name @jaxtyped(typechecker=beartype) def distortion( points_2d: Num[Array, "N 2"], - K: Num[Array, "3 3"], + K: Num[Array, "3 3"], # pylint: disable=invalid-name dist_coeffs: Num[Array, "5"], ) -> Num[Array, "N 2"]: """ - Apply distortion to 2D points + Apply distortion to 2D points in pixel coordinates Args: - points_2d: 2D points in image coordinates + points_2d: 2D points in pixel coordinates K: Camera intrinsic matrix dist_coeffs: Distortion coefficients [k1, k2, p1, p2, k3] Returns: - Distorted 2D points + Distorted 2D points in pixel coordinates + + Note: + The function handles the conversion between pixel coordinates and normalized coordinates + internally. It expects points_2d to be in pixel coordinates, and returns distorted + points in pixel coordinates. + + Implementation based on OpenCV's distortion model: + https://docs.opencv.org/4.10.0/d9/d0c/group__calib3d.html#ga69f2545a8b62a6b0fc2ee060dc30559d """ k1, k2, p1, p2, k3 = dist_coeffs @@ -41,41 +49,48 @@ def distortion( r2 = x * x + y * y # Radial distortion - xdistort = x * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2) - ydistort = y * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2) + x_distort = x * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2) + y_distort = y * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2) # Tangential distortion - xdistort = xdistort + 2 * p1 * x * y + p2 * (r2 + 2 * x * x) - ydistort = ydistort + p1 * (r2 + 2 * y * y) + 2 * p2 * x * y + x_distort = x_distort + 2 * p1 * x * y + p2 * (r2 + 2 * x * x) + y_distort = y_distort + p1 * (r2 + 2 * y * y) + 2 * p2 * x * y # Back to absolute coordinates - xdistort = xdistort * fx + cx - ydistort = ydistort * fy + cy + x_distort = x_distort * fx + cx + y_distort = y_distort * fy + cy # Combine distorted coordinates - return jnp.stack([xdistort, ydistort], axis=1) + return jnp.stack([x_distort, y_distort], axis=1) @jaxtyped(typechecker=beartype) def project( points_3d: Num[Array, "N 3"], projection_matrix: Num[Array, "3 4"], - K: Num[Array, "3 3"], - dist_coeffs: Num[Array, "5"], + K: Optional[Num[Array, "3 3"]] = None, # pylint: disable=invalid-name + dist_coeffs: Optional[Num[Array, "5"]] = None, + image_size: Optional[Num[Array, "2"]] = None, ) -> Num[Array, "N 2"]: """ - Project 3D points to 2D points + Project 3D points to 2D points in pixel coordinates Args: points_3d: 3D points in world coordinates - projection_matrix: pre-computed projection matrix (K @ Rt[:3, :]) - K: Camera intrinsic matrix - dist_coeffs: Distortion coefficients + projection_matrix: pre-computed projection matrix (K @ Rt[:3, :]) that projects to pixel coordinates + K: (optional) Camera intrinsic matrix, unnormalized + dist_coeffs: (optional) Distortion coefficients + image_size: (optional) Image dimensions [width, height] for valid point check. + If not provided, uses (0,1) normalized coordinates. + + Note: + K and dist_coeffs must be provided together, or both be None. + If K is provided, it assumes that the projection matrix is calculated from the same K. Returns: - 2D points in image coordinates + 2D points in pixel coordinates """ - P = projection_matrix + P = projection_matrix # pylint: disable=invalid-name p3d_homogeneous = jnp.hstack( (points_3d, jnp.ones((points_3d.shape[0], 1), dtype=points_3d.dtype)) ) @@ -86,17 +101,30 @@ def project( # Perspective division p2d = p2d_homogeneous[:, 0:2] / p2d_homogeneous[:, 2:3] - # Apply distortion if needed - if dist_coeffs is not None: - # Check if valid points (between 0 and 1) - valid = jnp.all(p2d > 0, axis=1) & jnp.all(p2d < 1, axis=1) + if dist_coeffs is not None and K is not None: + # Check if valid points (within image boundaries) + if image_size is not None: + # Use image dimensions for valid point check in pixel space + valid = ( + jnp.all(p2d >= 0, axis=1) + & (p2d[:, 0] < image_size[0]) + & (p2d[:, 1] < image_size[1]) + ) + else: + # Fall back to normalized coordinates if image_size not provided + valid = jnp.all(p2d >= 0, axis=1) & jnp.all(p2d < 1, axis=1) - # Only apply distortion if there are valid points + # only valid points need distortion if jnp.any(valid): - # Only distort valid points valid_p2d = p2d[valid] distorted_valid = distortion(valid_p2d, K, dist_coeffs) p2d = p2d.at[valid].set(distorted_valid) + elif dist_coeffs is None and K is None: + pass + else: + raise ValueError( + "dist_coeffs and K must be provided together to compute distortion" + ) return jnp.squeeze(p2d) @@ -199,13 +227,30 @@ class Camera: points_3d: 3D points in world coordinates Returns: - 2D points in image coordinates + 2D points in pixel coordinates """ return project( - points_3d=points_3d, + points_3d, + projection_matrix=self.params.projection_matrix, K=self.params.K, dist_coeffs=self.params.dist_coeffs, + image_size=self.params.image_size, + ) + + def project_ideal(self, points_3d: Num[Array, "N 3"]) -> Num[Array, "N 2"]: + """ + Project 3D points to 2D points using this camera's parameters, without distortion + + Args: + points_3d: 3D points in world coordinates + + Returns: + 2D points in pixel coordinates + """ + return project( + points_3d, projection_matrix=self.params.projection_matrix, + image_size=self.params.image_size, ) def distortion(self, points_2d: Num[Array, "N 2"]) -> Num[Array, "N 2"]: