1
0
forked from HQU-gxy/CVTH3PE

feat: Enhance play notebook and camera module with new projection and distortion functionalities

- Updated play notebook to include new tracking and clustering functionalities.
- Introduced `distortion` and `project` functions for applying distortion to 2D points and projecting 3D points to 2D, respectively.
- Enhanced `CameraParams` and `Camera` classes with methods for distortion and projection, improving usability.
- Cleaned up execution counts in the notebook for better organization.
This commit is contained in:
2025-04-21 18:38:08 +08:00
parent 40f3150417
commit 032eb684ec
2 changed files with 504 additions and 36 deletions

View File

@ -1,7 +1,7 @@
from collections import OrderedDict, defaultdict
from dataclasses import dataclass
from datetime import datetime
from typing import Any, TypeAlias, TypedDict
from typing import Any, TypeAlias, TypedDict, Optional
from beartype import beartype
from jax import Array
@ -12,6 +12,95 @@ from typing_extensions import NotRequired
CameraID: TypeAlias = str # pylint: disable=invalid-name
@jaxtyped(typechecker=beartype)
def distortion(
points_2d: Num[Array, "N 2"],
K: Num[Array, "3 3"],
dist_coeffs: Num[Array, "5"],
) -> Num[Array, "N 2"]:
"""
Apply distortion to 2D points
Args:
points_2d: 2D points in image coordinates
K: Camera intrinsic matrix
dist_coeffs: Distortion coefficients [k1, k2, p1, p2, k3]
Returns:
Distorted 2D points
"""
k1, k2, p1, p2, k3 = dist_coeffs
# Get principal point and focal length
cx, cy = K[0, 2], K[1, 2]
fx, fy = K[0, 0], K[1, 1]
# Convert to normalized coordinates
x = (points_2d[:, 0] - cx) / fx
y = (points_2d[:, 1] - cy) / fy
r2 = x * x + y * y
# Radial distortion
xdistort = x * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
ydistort = y * (1 + k1 * r2 + k2 * r2 * r2 + k3 * r2 * r2 * r2)
# Tangential distortion
xdistort = xdistort + 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
ydistort = ydistort + p1 * (r2 + 2 * y * y) + 2 * p2 * x * y
# Back to absolute coordinates
xdistort = xdistort * fx + cx
ydistort = ydistort * fy + cy
# Combine distorted coordinates
return jnp.stack([xdistort, ydistort], axis=1)
@jaxtyped(typechecker=beartype)
def project(
points_3d: Num[Array, "N 3"],
projection_matrix: Num[Array, "3 4"],
K: Num[Array, "3 3"],
dist_coeffs: Num[Array, "5"],
) -> Num[Array, "N 2"]:
"""
Project 3D points to 2D points
Args:
points_3d: 3D points in world coordinates
projection_matrix: pre-computed projection matrix (K @ Rt[:3, :])
K: Camera intrinsic matrix
dist_coeffs: Distortion coefficients
Returns:
2D points in image coordinates
"""
P = projection_matrix
p3d_homogeneous = jnp.hstack(
(points_3d, jnp.ones((points_3d.shape[0], 1), dtype=points_3d.dtype))
)
# Project points
p2d_homogeneous = p3d_homogeneous @ P.T
# Perspective division
p2d = p2d_homogeneous[:, 0:2] / p2d_homogeneous[:, 2:3]
# Apply distortion if needed
if dist_coeffs is not None:
# Check if valid points (between 0 and 1)
valid = jnp.all(p2d > 0, axis=1) & jnp.all(p2d < 1, axis=1)
# Only apply distortion if there are valid points
if jnp.any(valid):
# Only distort valid points
valid_p2d = p2d[valid]
distorted_valid = distortion(valid_p2d, K, dist_coeffs)
p2d = p2d.at[valid].set(distorted_valid)
return jnp.squeeze(p2d)
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class CameraParams:
@ -29,8 +118,18 @@ class CameraParams:
R and t are the rotation and translation that describe the change of
coordinates from world to camera coordinate systems (or camera frame)
Rt is expected to be World-to-Camera (W2C) transformation matrix,
which is the result of `solvePnP` in OpenCV. (but converted to homogeneous coordinates)
World-to-Camera (W2C): Transforms points from world coordinates to camera coordinates
- The world origin is transformed to camera space
- Used for projecting 3D world points onto the camera's image plane
- Required for rendering/projection
"""
dist_coeffs: Num[Array, "N"]
dist_coeffs: Num[Array, "5"]
"""
An array of distortion coefficients of the form
[k1, k2, [p1, p2, [k3]]], where ki is the ith
@ -42,6 +141,25 @@ class CameraParams:
The size of image plane (width, height)
"""
@property
def pose_matrix(self) -> Num[Array, "4 4"]:
"""
The inversion of the extrinsic matrix, which gives Camera-to-World (C2W) transformation matrix.
Camera-to-World (C2W): Transforms points from camera coordinates to world coordinates
- The camera is the origin in camera space
- This transformation tells where the camera is positioned in world space
- Often used for camera positioning/orientation
The result is cached on first access. (lazy evaluation)
"""
t = getattr(self, "_pose", None)
if t is None:
t = jnp.linalg.inv(self.Rt)
object.__setattr__(self, "_pose", t)
return t
@property
def projection_matrix(self) -> Num[Array, "3 4"]:
"""
@ -73,6 +191,39 @@ class Camera:
Camera parameters
"""
def project(self, points_3d: Num[Array, "N 3"]) -> Num[Array, "N 2"]:
"""
Project 3D points to 2D points using this camera's parameters
Args:
points_3d: 3D points in world coordinates
Returns:
2D points in image coordinates
"""
return project(
points_3d=points_3d,
K=self.params.K,
dist_coeffs=self.params.dist_coeffs,
projection_matrix=self.params.projection_matrix,
)
def distortion(self, points_2d: Num[Array, "N 2"]) -> Num[Array, "N 2"]:
"""
Apply distortion to 2D points using this camera's parameters
Args:
points_2d: 2D points in image coordinates
Returns:
Distorted 2D points
"""
return distortion(
points_2d=points_2d,
K=self.params.K,
dist_coeffs=self.params.dist_coeffs,
)
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)