forked from HQU-gxy/CVTH3PE
feat: Enhance play notebook and camera module with new unprojection functionalities
- Updated the play notebook to include new methods for unprojecting 2D points onto a 3D plane. - Introduced `unproject_points_onto_plane` and `unproject_points_to_z_plane` functions in the camera module for improved point handling. - Enhanced the `Camera` class with a method for unprojecting points to a specified z-plane. - Cleaned up execution counts in the notebook for better organization and clarity.
This commit is contained in:
@ -6,7 +6,7 @@ from typing import Any, TypeAlias, TypedDict, Optional
|
||||
from beartype import beartype
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jaxtyping import Num, jaxtyped, Array
|
||||
from jaxtyping import Num, jaxtyped, Array, Float
|
||||
from cv2 import undistortPoints
|
||||
import numpy as np
|
||||
|
||||
@ -87,6 +87,81 @@ def distortion(
|
||||
return jnp.stack([u, v], axis=1)
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def unproject_points_onto_plane(
|
||||
points_2d: Float[Array, "N 2"],
|
||||
plane_normal: Float[Array, "3"],
|
||||
plane_point: Float[Array, "3"],
|
||||
K: Float[Array, "3 3"], # pylint: disable=invalid-name
|
||||
dist_coeffs: Float[Array, "5"],
|
||||
pose_matrix: Float[Array, "4 4"],
|
||||
) -> Float[Array, "N 3"]:
|
||||
"""
|
||||
Un-project 2D image points onto an arbitrary 3D plane.
|
||||
This function computes the ray-plane intersections, since every `points_2d`
|
||||
could be treated as a ray.
|
||||
|
||||
(i.e. back-project points onto a plane)
|
||||
|
||||
Args:
|
||||
points_2d: [..., 2] image pixel coordinates
|
||||
plane_normal: (3,) normal vector of the plane in world coords
|
||||
plane_point: (3,) a known point on the plane in world coords
|
||||
K: Camera intrinsic matrix
|
||||
dist_coeffs: Distortion coefficients
|
||||
pose_matrix: Camera-to-World (C2W) transformation matrix
|
||||
|
||||
Note:
|
||||
`pose_matrix` is NOT the same as camera extrinsic (World-to-Camera, W2C),
|
||||
but the inverse of it.
|
||||
|
||||
Returns:
|
||||
[..., 3] world-space intersection points
|
||||
"""
|
||||
# Step 1: undistort (no-op here)
|
||||
pts = undistort_points(
|
||||
np.asarray(points_2d), np.asarray(K), np.asarray(dist_coeffs)
|
||||
)
|
||||
|
||||
# Step 2: normalize image coordinates into camera rays
|
||||
fx, fy = K[0, 0], K[1, 1]
|
||||
cx, cy = K[0, 2], K[1, 2]
|
||||
dirs_cam = jnp.stack(
|
||||
[(pts[:, 0] - cx) / fx, (pts[:, 1] - cy) / fy, jnp.ones_like(pts[:, 0])],
|
||||
axis=-1,
|
||||
) # (N, 3)
|
||||
|
||||
# Step 3: transform rays into world space
|
||||
c2w = pose_matrix
|
||||
ray_orig = c2w[:3, 3] # (3,)
|
||||
R_world = c2w[:3, :3] # (3,3)
|
||||
ray_dirs = (R_world @ dirs_cam.T).T # (N, 3)
|
||||
|
||||
# Step 4: plane intersection
|
||||
n = plane_normal / jnp.linalg.norm(plane_normal)
|
||||
p0 = plane_point
|
||||
denom = jnp.dot(ray_dirs, n) # (N,)
|
||||
numer = jnp.dot((p0 - ray_orig), n) # scalar
|
||||
t = numer / denom # (N,)
|
||||
points_world = ray_orig + ray_dirs * t[:, None]
|
||||
return points_world
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def unproject_points_to_z_plane(
|
||||
points_2d: Float[Array, "N 2"],
|
||||
K: Float[Array, "3 3"],
|
||||
dist_coeffs: Float[Array, "5"],
|
||||
pose_matrix: Float[Array, "4 4"],
|
||||
z: float = 0.0,
|
||||
) -> Float[Array, "N 3"]:
|
||||
plane_normal = jnp.array([0.0, 0.0, 1.0])
|
||||
plane_point = jnp.array([0.0, 0.0, z])
|
||||
return unproject_points_onto_plane(
|
||||
points_2d, plane_normal, plane_point, K, dist_coeffs, pose_matrix
|
||||
)
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def project(
|
||||
points_3d: Num[Array, "N 3"],
|
||||
@ -242,6 +317,9 @@ class Camera:
|
||||
Camera parameters
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Camera id={self.id}>"
|
||||
|
||||
def project(self, points_3d: Num[Array, "N 3"]) -> Num[Array, "N 2"]:
|
||||
"""
|
||||
Project 3D points to 2D points using this camera's parameters
|
||||
@ -292,6 +370,20 @@ class Camera:
|
||||
dist_coeffs=self.params.dist_coeffs,
|
||||
)
|
||||
|
||||
def unproject_points_to_z_plane(
|
||||
self, points_2d: Num[Array, "N 2"], z: float = 0.0
|
||||
) -> Num[Array, "N 3"]:
|
||||
"""
|
||||
Unproject 2D points to 3D points on a plane at z = constant.
|
||||
"""
|
||||
return unproject_points_to_z_plane(
|
||||
points_2d,
|
||||
self.params.K,
|
||||
self.params.dist_coeffs,
|
||||
self.params.pose_matrix,
|
||||
z,
|
||||
)
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
@dataclass(frozen=True)
|
||||
@ -337,6 +429,9 @@ class Detection:
|
||||
object.__setattr__(self, "_kp_undistorted", kpu)
|
||||
return kpu
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Detection({self.camera}, {self.timestamp})"
|
||||
|
||||
|
||||
def classify_by_camera(
|
||||
detections: list[Detection],
|
||||
|
||||
Reference in New Issue
Block a user