forked from HQU-gxy/CVTH3PE
refactor: Update distortion and projection functions for clarity and usability
- Enhanced docstrings for `distortion` and `project` functions to clarify input expectations and output formats, emphasizing pixel coordinates.
- Improved variable naming for distortion calculations to enhance readability.
- Added checks for valid points in the `project` function, ensuring proper handling of distortion parameters.
- Introduced a new method `project_ideal` in the `Camera` class for projecting 3D points without distortion, improving usability.
This commit is contained in:
@ -15,19 +15,27 @@ CameraID: TypeAlias = str # pylint: disable=invalid-name
|
|||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
def distortion(
|
def distortion(
|
||||||
points_2d: Num[Array, "N 2"],
|
points_2d: Num[Array, "N 2"],
|
||||||
K: Num[Array, "3 3"],
|
K: Num[Array, "3 3"], # pylint: disable=invalid-name
|
||||||
dist_coeffs: Num[Array, "5"],
|
dist_coeffs: Num[Array, "5"],
|
||||||
) -> Num[Array, "N 2"]:
|
) -> Num[Array, "N 2"]:
|
||||||
"""
|
"""
|
||||||
Apply distortion to 2D points
|
Apply distortion to 2D points in pixel coordinates
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
points_2d: 2D points in image coordinates
|
points_2d: 2D points in pixel coordinates
|
||||||
K: Camera intrinsic matrix
|
K: Camera intrinsic matrix
|
||||||
dist_coeffs: Distortion coefficients [k1, k2, p1, p2, k3]
|
dist_coeffs: Distortion coefficients [k1, k2, p1, p2, k3]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Distorted 2D points
|
Distorted 2D points in pixel coordinates
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The function handles the conversion between pixel coordinates and normalized coordinates
|
||||||
|
internally. It expects points_2d to be in pixel coordinates, and returns distorted
|
||||||
|
points in pixel coordinates.
|
||||||
|
|
||||||
|
Implementation based on OpenCV's distortion model:
|
||||||
|
https://docs.opencv.org/4.10.0/d9/d0c/group__calib3d.html#ga69f2545a8b62a6b0fc2ee060dc30559d
|
||||||
"""
|
"""
|
||||||
k1, k2, p1, p2, k3 = dist_coeffs
|
k1, k2, p1, p2, k3 = dist_coeffs
|
||||||
|
|
||||||
@ -41,41 +49,48 @@ def distortion(
|
|||||||
r2 = x * x + y * y
|
r2 = x * x + y * y
|
||||||
|
|
||||||
# Radial distortion
|
# Radial distortion
|
||||||
xdistort = x * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
|
x_distort = x * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
|
||||||
ydistort = y * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
|
y_distort = y * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
|
||||||
|
|
||||||
# Tangential distortion
|
# Tangential distortion
|
||||||
xdistort = xdistort + 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
|
x_distort = x_distort + 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
|
||||||
ydistort = ydistort + p1 * (r2 + 2 * y * y) + 2 * p2 * x * y
|
y_distort = y_distort + p1 * (r2 + 2 * y * y) + 2 * p2 * x * y
|
||||||
|
|
||||||
# Back to absolute coordinates
|
# Back to absolute coordinates
|
||||||
xdistort = xdistort * fx + cx
|
x_distort = x_distort * fx + cx
|
||||||
ydistort = ydistort * fy + cy
|
y_distort = y_distort * fy + cy
|
||||||
|
|
||||||
# Combine distorted coordinates
|
# Combine distorted coordinates
|
||||||
return jnp.stack([xdistort, ydistort], axis=1)
|
return jnp.stack([x_distort, y_distort], axis=1)
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
def project(
|
def project(
|
||||||
points_3d: Num[Array, "N 3"],
|
points_3d: Num[Array, "N 3"],
|
||||||
projection_matrix: Num[Array, "3 4"],
|
projection_matrix: Num[Array, "3 4"],
|
||||||
K: Num[Array, "3 3"],
|
K: Optional[Num[Array, "3 3"]] = None, # pylint: disable=invalid-name
|
||||||
dist_coeffs: Num[Array, "5"],
|
dist_coeffs: Optional[Num[Array, "5"]] = None,
|
||||||
|
image_size: Optional[Num[Array, "2"]] = None,
|
||||||
) -> Num[Array, "N 2"]:
|
) -> Num[Array, "N 2"]:
|
||||||
"""
|
"""
|
||||||
Project 3D points to 2D points
|
Project 3D points to 2D points in pixel coordinates
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
points_3d: 3D points in world coordinates
|
points_3d: 3D points in world coordinates
|
||||||
projection_matrix: pre-computed projection matrix (K @ Rt[:3, :])
|
projection_matrix: pre-computed projection matrix (K @ Rt[:3, :]) that projects to pixel coordinates
|
||||||
K: Camera intrinsic matrix
|
K: (optional) Camera intrinsic matrix, unnormalized
|
||||||
dist_coeffs: Distortion coefficients
|
dist_coeffs: (optional) Distortion coefficients
|
||||||
|
image_size: (optional) Image dimensions [width, height] for valid point check.
|
||||||
|
If not provided, uses (0,1) normalized coordinates.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
K and dist_coeffs must be provided together, or both be None.
|
||||||
|
If K is provided, it assumes that the projection matrix is calculated from the same K.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
2D points in image coordinates
|
2D points in pixel coordinates
|
||||||
"""
|
"""
|
||||||
P = projection_matrix
|
P = projection_matrix # pylint: disable=invalid-name
|
||||||
p3d_homogeneous = jnp.hstack(
|
p3d_homogeneous = jnp.hstack(
|
||||||
(points_3d, jnp.ones((points_3d.shape[0], 1), dtype=points_3d.dtype))
|
(points_3d, jnp.ones((points_3d.shape[0], 1), dtype=points_3d.dtype))
|
||||||
)
|
)
|
||||||
@ -86,17 +101,30 @@ def project(
|
|||||||
# Perspective division
|
# Perspective division
|
||||||
p2d = p2d_homogeneous[:, 0:2] / p2d_homogeneous[:, 2:3]
|
p2d = p2d_homogeneous[:, 0:2] / p2d_homogeneous[:, 2:3]
|
||||||
|
|
||||||
# Apply distortion if needed
|
if dist_coeffs is not None and K is not None:
|
||||||
if dist_coeffs is not None:
|
# Check if valid points (within image boundaries)
|
||||||
# Check if valid points (between 0 and 1)
|
if image_size is not None:
|
||||||
valid = jnp.all(p2d > 0, axis=1) & jnp.all(p2d < 1, axis=1)
|
# Use image dimensions for valid point check in pixel space
|
||||||
|
valid = (
|
||||||
|
jnp.all(p2d >= 0, axis=1)
|
||||||
|
& (p2d[:, 0] < image_size[0])
|
||||||
|
& (p2d[:, 1] < image_size[1])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fall back to normalized coordinates if image_size not provided
|
||||||
|
valid = jnp.all(p2d >= 0, axis=1) & jnp.all(p2d < 1, axis=1)
|
||||||
|
|
||||||
# Only apply distortion if there are valid points
|
# only valid points need distortion
|
||||||
if jnp.any(valid):
|
if jnp.any(valid):
|
||||||
# Only distort valid points
|
|
||||||
valid_p2d = p2d[valid]
|
valid_p2d = p2d[valid]
|
||||||
distorted_valid = distortion(valid_p2d, K, dist_coeffs)
|
distorted_valid = distortion(valid_p2d, K, dist_coeffs)
|
||||||
p2d = p2d.at[valid].set(distorted_valid)
|
p2d = p2d.at[valid].set(distorted_valid)
|
||||||
|
elif dist_coeffs is None and K is None:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"dist_coeffs and K must be provided together to compute distortion"
|
||||||
|
)
|
||||||
|
|
||||||
return jnp.squeeze(p2d)
|
return jnp.squeeze(p2d)
|
||||||
|
|
||||||
@ -199,13 +227,30 @@ class Camera:
|
|||||||
points_3d: 3D points in world coordinates
|
points_3d: 3D points in world coordinates
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
2D points in image coordinates
|
2D points in pixel coordinates
|
||||||
"""
|
"""
|
||||||
return project(
|
return project(
|
||||||
points_3d=points_3d,
|
points_3d,
|
||||||
|
projection_matrix=self.params.projection_matrix,
|
||||||
K=self.params.K,
|
K=self.params.K,
|
||||||
dist_coeffs=self.params.dist_coeffs,
|
dist_coeffs=self.params.dist_coeffs,
|
||||||
|
image_size=self.params.image_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def project_ideal(self, points_3d: Num[Array, "N 3"]) -> Num[Array, "N 2"]:
|
||||||
|
"""
|
||||||
|
Project 3D points to 2D points using this camera's parameters, without distortion
|
||||||
|
|
||||||
|
Args:
|
||||||
|
points_3d: 3D points in world coordinates
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
2D points in pixel coordinates
|
||||||
|
"""
|
||||||
|
return project(
|
||||||
|
points_3d,
|
||||||
projection_matrix=self.params.projection_matrix,
|
projection_matrix=self.params.projection_matrix,
|
||||||
|
image_size=self.params.image_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
def distortion(self, points_2d: Num[Array, "N 2"]) -> Num[Array, "N 2"]:
|
def distortion(self, points_2d: Num[Array, "N 2"]) -> Num[Array, "N 2"]:
|
||||||
|
|||||||
Reference in New Issue
Block a user