1
0
forked from HQU-gxy/CVTH3PE

feat: Enhance play notebook and camera module with new unprojection functionalities

- Updated the play notebook to include new methods for unprojecting 2D points onto a 3D plane.
- Introduced `unproject_points_onto_plane` and `unproject_points_to_z_plane` functions in the camera module for improved point handling.
- Enhanced the `Camera` class with a method for unprojecting points to a specified z-plane.
- Cleaned up execution counts in the notebook for better organization and clarity.
This commit is contained in:
2025-04-24 18:55:24 +08:00
parent 00481a0d6f
commit c3c93f6ca6
3 changed files with 160 additions and 37 deletions

View File

@ -6,7 +6,7 @@ from typing import Any, TypeAlias, TypedDict, Optional
from beartype import beartype
import jax
from jax import numpy as jnp
from jaxtyping import Num, jaxtyped, Array
from jaxtyping import Num, jaxtyped, Array, Float
from cv2 import undistortPoints
import numpy as np
@ -87,6 +87,81 @@ def distortion(
return jnp.stack([u, v], axis=1)
@jaxtyped(typechecker=beartype)
def unproject_points_onto_plane(
points_2d: Float[Array, "N 2"],
plane_normal: Float[Array, "3"],
plane_point: Float[Array, "3"],
K: Float[Array, "3 3"], # pylint: disable=invalid-name
dist_coeffs: Float[Array, "5"],
pose_matrix: Float[Array, "4 4"],
) -> Float[Array, "N 3"]:
"""
Un-project 2D image points onto an arbitrary 3D plane.
This function computes the ray-plane intersections, since every `points_2d`
could be treated as a ray.
(i.e. back-project points onto a plane)
Args:
points_2d: [..., 2] image pixel coordinates
plane_normal: (3,) normal vector of the plane in world coords
plane_point: (3,) a known point on the plane in world coords
K: Camera intrinsic matrix
dist_coeffs: Distortion coefficients
pose_matrix: Camera-to-World (C2W) transformation matrix
Note:
`pose_matrix` is NOT the same as camera extrinsic (World-to-Camera, W2C),
but the inverse of it.
Returns:
[..., 3] world-space intersection points
"""
# Step 1: undistort (no-op here)
pts = undistort_points(
np.asarray(points_2d), np.asarray(K), np.asarray(dist_coeffs)
)
# Step 2: normalize image coordinates into camera rays
fx, fy = K[0, 0], K[1, 1]
cx, cy = K[0, 2], K[1, 2]
dirs_cam = jnp.stack(
[(pts[:, 0] - cx) / fx, (pts[:, 1] - cy) / fy, jnp.ones_like(pts[:, 0])],
axis=-1,
) # (N, 3)
# Step 3: transform rays into world space
c2w = pose_matrix
ray_orig = c2w[:3, 3] # (3,)
R_world = c2w[:3, :3] # (3,3)
ray_dirs = (R_world @ dirs_cam.T).T # (N, 3)
# Step 4: plane intersection
n = plane_normal / jnp.linalg.norm(plane_normal)
p0 = plane_point
denom = jnp.dot(ray_dirs, n) # (N,)
numer = jnp.dot((p0 - ray_orig), n) # scalar
t = numer / denom # (N,)
points_world = ray_orig + ray_dirs * t[:, None]
return points_world
@jaxtyped(typechecker=beartype)
def unproject_points_to_z_plane(
points_2d: Float[Array, "N 2"],
K: Float[Array, "3 3"],
dist_coeffs: Float[Array, "5"],
pose_matrix: Float[Array, "4 4"],
z: float = 0.0,
) -> Float[Array, "N 3"]:
plane_normal = jnp.array([0.0, 0.0, 1.0])
plane_point = jnp.array([0.0, 0.0, z])
return unproject_points_onto_plane(
points_2d, plane_normal, plane_point, K, dist_coeffs, pose_matrix
)
@jaxtyped(typechecker=beartype)
def project(
points_3d: Num[Array, "N 3"],
@ -242,6 +317,9 @@ class Camera:
Camera parameters
"""
def __repr__(self) -> str:
return f"<Camera id={self.id}>"
def project(self, points_3d: Num[Array, "N 3"]) -> Num[Array, "N 2"]:
"""
Project 3D points to 2D points using this camera's parameters
@ -292,6 +370,20 @@ class Camera:
dist_coeffs=self.params.dist_coeffs,
)
def unproject_points_to_z_plane(
self, points_2d: Num[Array, "N 2"], z: float = 0.0
) -> Num[Array, "N 3"]:
"""
Unproject 2D points to 3D points on a plane at z = constant.
"""
return unproject_points_to_z_plane(
points_2d,
self.params.K,
self.params.dist_coeffs,
self.params.pose_matrix,
z,
)
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
@ -337,6 +429,9 @@ class Detection:
object.__setattr__(self, "_kp_undistorted", kpu)
return kpu
def __repr__(self) -> str:
return f"Detection({self.camera}, {self.timestamp})"
def classify_by_camera(
detections: list[Detection],