From 3ccf790bac9b0a7e69d180d62110ebc90cf8a7f7 Mon Sep 17 00:00:00 2001 From: crosstyan Date: Tue, 22 Apr 2025 11:43:55 +0800 Subject: [PATCH] refactor: Optimize distortion function for clarity and performance - Added JAX Just-In-Time (JIT) compilation to the `distortion` function for improved performance. - Reorganized variable unpacking and calculations for better readability and efficiency. - Enhanced comments to clarify the steps involved in the distortion process, including normalization and radial/tangential distortion calculations. - Updated the return values to directly reflect pixel coordinates after distortion. --- app/camera/__init__.py | 42 ++++++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/app/camera/__init__.py b/app/camera/__init__.py index 947eca2..d0d0f96 100644 --- a/app/camera/__init__.py +++ b/app/camera/__init__.py @@ -4,14 +4,15 @@ from datetime import datetime from typing import Any, TypeAlias, TypedDict, Optional from beartype import beartype -from jax import Array +import jax from jax import numpy as jnp -from jaxtyping import Num, jaxtyped +from jaxtyping import Num, jaxtyped, Array from typing_extensions import NotRequired CameraID: TypeAlias = str # pylint: disable=invalid-name +@jax.jit @jaxtyped(typechecker=beartype) def distortion( points_2d: Num[Array, "N 2"], @@ -37,31 +38,36 @@ def distortion( Implementation based on OpenCV's distortion model: https://docs.opencv.org/4.10.0/d9/d0c/group__calib3d.html#ga69f2545a8b62a6b0fc2ee060dc30559d """ + # unpack + fx, fy = K[0, 0], K[1, 1] + cx, cy = K[0, 2], K[1, 2] k1, k2, p1, p2, k3 = dist_coeffs - # Get principal point and focal length - cx, cy = K[0, 2], K[1, 2] - fx, fy = K[0, 0], K[1, 1] - - # Convert to normalized coordinates + # normalize x = (points_2d[:, 0] - cx) / fx y = (points_2d[:, 1] - cy) / fy + + # precompute r^2, r^4, r^6 r2 = x * x + y * y + r4 = r2 * r2 + r6 = r4 * r2 - # Radial distortion - x_distort = x * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2) - y_distort = y * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2) + # radial term + radial = 1 + k1 * r2 + k2 * r4 + k3 * r6 - # Tangential distortion - x_distort = x_distort + 2 * p1 * x * y + p2 * (r2 + 2 * x * x) - y_distort = y_distort + p1 * (r2 + 2 * y * y) + 2 * p2 * x * y + # tangential term + x_tan = 2 * p1 * x * y + p2 * (r2 + 2 * x * x) + y_tan = p1 * (r2 + 2 * y * y) + 2 * p2 * x * y - # Back to absolute coordinates - x_distort = x_distort * fx + cx - y_distort = y_distort * fy + cy + # apply both + x_dist = x * radial + x_tan + y_dist = y * radial + y_tan - # Combine distorted coordinates - return jnp.stack([x_distort, y_distort], axis=1) + # back to pixels + u = x_dist * fx + cx + v = y_dist * fy + cy + + return jnp.stack([u, v], axis=1) @jaxtyped(typechecker=beartype)