from collections import OrderedDict, defaultdict
from dataclasses import dataclass
from datetime import datetime
from typing import Any, TypeAlias, TypedDict, Optional

from beartype import beartype
import jax
from jax import numpy as jnp
from jaxtyping import Num, jaxtyped, Array, Float
from cv2 import undistortPoints
import numpy as np

NDArray: TypeAlias = np.ndarray
CameraID: TypeAlias = str  # pylint: disable=invalid-name


@jaxtyped(typechecker=beartype)
def undistort_points(
    points: Num[NDArray, "M 2"],
    camera_matrix: Num[NDArray, "3 3"],
    dist_coeffs: Num[NDArray, "N"],
) -> Num[NDArray, "M 2"]:
    """
    A thin wrapper around cv2.undistortPoints that keeps the result in pixel
    coordinates (by passing P=K instead of returning normalized coordinates).
    """
    K = camera_matrix
    dist = dist_coeffs
    res = undistortPoints(points, K, dist, P=K)  # type: ignore
    return res.reshape(-1, 2)


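# A minimal usage sketch for `undistort_points`, with hypothetical intrinsics
# and all-zero distortion coefficients (in which case the call is close to the
# identity). The numbers are illustrative, not from a calibrated camera.
def _example_undistort_points() -> NDArray:
    points = np.array([[320.5, 240.1], [100.0, 50.0]], dtype=np.float64)
    K = np.array(
        [[800.0, 0.0, 320.0], [0.0, 800.0, 240.0], [0.0, 0.0, 1.0]],
        dtype=np.float64,
    )
    dist = np.zeros(5, dtype=np.float64)  # zero coefficients: output ~= input
    return undistort_points(points, K, dist)

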
@jax.jit
@jaxtyped(typechecker=beartype)
def distortion(
    points_2d: Num[Array, "N 2"],
    K: Num[Array, "3 3"],  # pylint: disable=invalid-name
    dist_coeffs: Num[Array, "5"],
) -> Num[Array, "N 2"]:
    """
    Apply distortion to 2D points in pixel coordinates

    Args:
        points_2d: 2D points in pixel coordinates
        K: Camera intrinsic matrix
        dist_coeffs: Distortion coefficients [k1, k2, p1, p2, k3]

    Returns:
        Distorted 2D points in pixel coordinates

    Note:
        The function handles the conversion between pixel coordinates and
        normalized coordinates internally. It expects points_2d to be in pixel
        coordinates, and returns distorted points in pixel coordinates.

        Implementation based on OpenCV's distortion model:
        https://docs.opencv.org/4.10.0/d9/d0c/group__calib3d.html#ga69f2545a8b62a6b0fc2ee060dc30559d
    """
    # unpack
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]
    k1, k2, p1, p2, k3 = dist_coeffs

    # normalize
    x = (points_2d[:, 0] - cx) / fx
    y = (points_2d[:, 1] - cy) / fy

    # precompute r^2, r^4, r^6
    r2 = x * x + y * y
    r4 = r2 * r2
    r6 = r4 * r2

    # radial term
    radial = 1 + k1 * r2 + k2 * r4 + k3 * r6

    # tangential term
    x_tan = 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
    y_tan = p1 * (r2 + 2 * y * y) + 2 * p2 * x * y

    # apply both
    x_dist = x * radial + x_tan
    y_dist = y * radial + y_tan

    # back to pixels
    u = x_dist * fx + cx
    v = y_dist * fy + cy

    return jnp.stack([u, v], axis=1)


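# A small sketch of `distortion` with hypothetical pinhole intrinsics and mild
# radial distortion. With all-zero coefficients the mapping reduces to the
# identity, which makes a convenient sanity check.
def _example_distortion() -> Num[Array, "N 2"]:
    K = jnp.array([[800.0, 0.0, 320.0], [0.0, 800.0, 240.0], [0.0, 0.0, 1.0]])
    dist_coeffs = jnp.array([-0.1, 0.01, 0.0, 0.0, 0.0])  # [k1, k2, p1, p2, k3]
    points = jnp.array([[320.0, 240.0], [400.0, 300.0]])
    # the principal point (320, 240) is a fixed point of the model (r = 0)
    return distortion(points, K, dist_coeffs)

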
@jaxtyped(typechecker=beartype)
def unproject_points_onto_plane(
    points_2d: Float[Array, "N 2"],
    plane_normal: Float[Array, "3"],
    plane_point: Float[Array, "3"],
    K: Float[Array, "3 3"],  # pylint: disable=invalid-name
    dist_coeffs: Float[Array, "5"],
    pose_matrix: Float[Array, "4 4"],
) -> Float[Array, "N 3"]:
    """
    Un-project 2D image points onto an arbitrary 3D plane
    (i.e. back-project points onto a plane).

    Each entry of `points_2d` defines a ray from the camera center through the
    image plane, so this function computes the ray-plane intersections.

    Args:
        points_2d: [..., 2] image pixel coordinates
        plane_normal: (3,) normal vector of the plane in world coords
        plane_point: (3,) a known point on the plane in world coords
        K: Camera intrinsic matrix
        dist_coeffs: Distortion coefficients
        pose_matrix: Camera-to-World (C2W) transformation matrix

    Note:
        `pose_matrix` is NOT the camera extrinsic (World-to-Camera, W2C),
        but its inverse.

        Rays parallel to the plane have no intersection; for those points the
        result contains non-finite values.

    Returns:
        [..., 3] world-space intersection points
    """
    # Step 1: undistort to ideal pixel coordinates
    # (a no-op when dist_coeffs are all zero)
    pts = undistort_points(
        np.asarray(points_2d), np.asarray(K), np.asarray(dist_coeffs)
    )

    # Step 2: normalize image coordinates into camera rays
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]
    dirs_cam = jnp.stack(
        [(pts[:, 0] - cx) / fx, (pts[:, 1] - cy) / fy, jnp.ones_like(pts[:, 0])],
        axis=-1,
    )  # (N, 3)

    # Step 3: transform rays into world space
    c2w = pose_matrix
    ray_orig = c2w[:3, 3]  # (3,)
    R_world = c2w[:3, :3]  # (3,3)
    ray_dirs = (R_world @ dirs_cam.T).T  # (N, 3)

    # Step 4: plane intersection, t = n . (p0 - o) / (n . d) for each ray
    n = plane_normal / jnp.linalg.norm(plane_normal)
    p0 = plane_point
    denom = jnp.dot(ray_dirs, n)  # (N,)
    numer = jnp.dot((p0 - ray_orig), n)  # scalar
    t = numer / denom  # (N,)
    points_world = ray_orig + ray_dirs * t[:, None]
    return points_world


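# A sketch of `unproject_points_onto_plane` against a hypothetical vertical
# wall plane y = 2 (normal along +y), seen from a camera at the world origin
# with an identity pose. Pixels on the row of the principal point would give
# rays parallel to this plane, so the example deliberately avoids them.
def _example_unproject_onto_plane() -> Float[Array, "N 3"]:
    K = jnp.array([[800.0, 0.0, 320.0], [0.0, 800.0, 240.0], [0.0, 0.0, 1.0]])
    dist_coeffs = jnp.zeros(5)
    pose = jnp.eye(4)  # camera at the origin, looking down the +z axis
    points_2d = jnp.array([[320.0, 400.0], [480.0, 320.0]])
    plane_normal = jnp.array([0.0, 1.0, 0.0])
    plane_point = jnp.array([0.0, 2.0, 0.0])
    # expected: [[0, 2, 10], [4, 2, 20]]
    return unproject_points_onto_plane(
        points_2d, plane_normal, plane_point, K, dist_coeffs, pose
    )

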
@jaxtyped(typechecker=beartype)
def unproject_points_to_z_plane(
    points_2d: Float[Array, "N 2"],
    K: Float[Array, "3 3"],
    dist_coeffs: Float[Array, "5"],
    pose_matrix: Float[Array, "4 4"],
    z: float = 0.0,
) -> Float[Array, "N 3"]:
    """
    Un-project 2D points to 3D points on a plane at z = constant.

    Args:
        points_2d: 2D points in pixel coordinates
        K: Camera intrinsic matrix
        dist_coeffs: Distortion coefficients
        pose_matrix: Camera-to-World (C2W) transformation matrix
        z: z-coordinate of the plane (default: 0.0, i.e. the ground/floor plane)

    Returns:
        [..., 3] world-space intersection points
    """
    plane_normal = jnp.array([0.0, 0.0, 1.0])
    plane_point = jnp.array([0.0, 0.0, z])
    return unproject_points_onto_plane(
        points_2d, plane_normal, plane_point, K, dist_coeffs, pose_matrix
    )


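# A sketch of `unproject_points_to_z_plane` for a hypothetical overhead camera
# 5 m above the ground, looking straight down (camera +z along world -z).
def _example_unproject_to_ground() -> Float[Array, "N 3"]:
    K = jnp.array([[800.0, 0.0, 320.0], [0.0, 800.0, 240.0], [0.0, 0.0, 1.0]])
    dist_coeffs = jnp.zeros(5)
    # C2W pose: camera center at (0, 0, 5), rotation diag(1, -1, -1)
    pose = jnp.array(
        [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, -1.0, 0.0, 0.0],
            [0.0, 0.0, -1.0, 5.0],
            [0.0, 0.0, 0.0, 1.0],
        ]
    )
    points_2d = jnp.array([[320.0, 240.0], [400.0, 240.0]])
    # expected: [[0, 0, 0], [0.5, 0, 0]]
    return unproject_points_to_z_plane(points_2d, K, dist_coeffs, pose, z=0.0)

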
@jaxtyped(typechecker=beartype)
def project(
    points_3d: Num[Array, "N 3"],
    projection_matrix: Num[Array, "3 4"],
    K: Optional[Num[Array, "3 3"]] = None,  # pylint: disable=invalid-name
    dist_coeffs: Optional[Num[Array, "5"]] = None,
    image_size: Optional[Num[Array, "2"]] = None,
) -> Num[Array, "N 2"]:
    """
    Project 3D points to 2D points in pixel coordinates

    Args:
        points_3d: 3D points in world coordinates
        projection_matrix: pre-computed projection matrix (K @ Rt[:3, :]) that
            projects to pixel coordinates
        K: (optional) Camera intrinsic matrix, unnormalized
        dist_coeffs: (optional) Distortion coefficients
        image_size: (optional) Image dimensions [width, height] for the valid
            point check. If not provided, assumes (0, 1) normalized coordinates.

    Note:
        K and dist_coeffs must be provided together, or both be None.
        If K is provided, it is assumed that the projection matrix was
        calculated from the same K.

    Returns:
        2D points in pixel coordinates
    """
    P = projection_matrix  # pylint: disable=invalid-name
    p3d_homogeneous = jnp.hstack(
        (points_3d, jnp.ones((points_3d.shape[0], 1), dtype=points_3d.dtype))
    )

    # Project points
    p2d_homogeneous = p3d_homogeneous @ P.T

    # Perspective division
    p2d = p2d_homogeneous[:, 0:2] / p2d_homogeneous[:, 2:3]

    if dist_coeffs is not None and K is not None:
        # Check if valid points (within image boundaries)
        if image_size is not None:
            # Use image dimensions for valid point check in pixel space
            valid = (
                jnp.all(p2d >= 0, axis=1)
                & (p2d[:, 0] < image_size[0])
                & (p2d[:, 1] < image_size[1])
            )
        else:
            # Fall back to normalized coordinates if image_size not provided
            valid = jnp.all(p2d >= 0, axis=1) & jnp.all(p2d < 1, axis=1)

        # only valid points need distortion
        if jnp.any(valid):
            valid_p2d = p2d[valid]
            distorted_valid = distortion(valid_p2d, K, dist_coeffs)
            p2d = p2d.at[valid].set(distorted_valid)
    elif dist_coeffs is None and K is None:
        pass
    else:
        raise ValueError(
            "dist_coeffs and K must be provided together to compute distortion"
        )

    # return the (N, 2) array as-is; squeezing here would break the declared
    # "N 2" return shape whenever N == 1
    return p2d


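# A sketch of `project` that round-trips the overhead-camera example above:
# ground-plane points map back to the expected pixels. The extrinsic below is
# hypothetical and happens to be its own inverse, so it equals the C2W pose
# used in `_example_unproject_to_ground`.
def _example_project() -> Num[Array, "N 2"]:
    K = jnp.array([[800.0, 0.0, 320.0], [0.0, 800.0, 240.0], [0.0, 0.0, 1.0]])
    Rt = jnp.array(
        [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, -1.0, 0.0, 0.0],
            [0.0, 0.0, -1.0, 5.0],
            [0.0, 0.0, 0.0, 1.0],
        ]
    )
    P = K @ Rt[:3, :]
    points_3d = jnp.array([[0.0, 0.0, 0.0], [0.5, 0.0, 0.0]])
    # expected: [[320, 240], [400, 240]]
    return project(points_3d, P)

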
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class CameraParams:
    """
    Camera parameters: intrinsic matrix, extrinsic matrix, and distortion coefficients
    """

    K: Num[Array, "3 3"]
    """
    intrinsic matrix
    """
    Rt: Num[Array, "4 4"]
    """
    [R|t] extrinsic matrix

    R and t are the rotation and translation that describe the change of
    coordinates from world to camera coordinate systems (or camera frame)

    Rt is expected to be the World-to-Camera (W2C) transformation matrix,
    i.e. the result of `solvePnP` in OpenCV (converted to homogeneous
    coordinates).

    World-to-Camera (W2C): Transforms points from world coordinates to camera coordinates

    - The world origin is transformed to camera space
    - Used for projecting 3D world points onto the camera's image plane
    - Required for rendering/projection
    """
    dist_coeffs: Num[Array, "5"]
    """
    An array of distortion coefficients [k1, k2, p1, p2, k3],
    where ki is the ith radial distortion coefficient and
    pi is the ith tangential distortion coefficient.
    """
    image_size: Num[Array, "2"]
    """
    The size of the image plane (width, height)
    """

    @property
    def pose_matrix(self) -> Num[Array, "4 4"]:
        """
        The inverse of the extrinsic matrix, which gives the Camera-to-World (C2W) transformation matrix.

        Camera-to-World (C2W): Transforms points from camera coordinates to world coordinates

        - The camera is the origin in camera space
        - This transformation tells where the camera is positioned in world space
        - Often used for camera positioning/orientation

        The result is cached on first access. (lazy evaluation)
        """
        t = getattr(self, "_pose", None)
        if t is None:
            t = jnp.linalg.inv(self.Rt)
            # object.__setattr__ bypasses the frozen-dataclass blocker
            object.__setattr__(self, "_pose", t)
        return t

    @property
    def projection_matrix(self) -> Num[Array, "3 4"]:
        """
        Returns the 3x4 projection matrix K @ [R|t].

        The result is cached on first access. (lazy evaluation)
        """
        pm = getattr(self, "_proj", None)
        if pm is None:
            pm = self.K @ self.Rt[:3, :]
            # object.__setattr__ bypasses the frozen-dataclass blocker
            object.__setattr__(self, "_proj", pm)
        return pm


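# A minimal sketch of constructing `CameraParams` with hypothetical calibration
# values; a real Rt would come from e.g. cv2.solvePnP, lifted to a homogeneous
# 4x4 matrix.
def _example_camera_params() -> CameraParams:
    return CameraParams(
        K=jnp.array(
            [[800.0, 0.0, 320.0], [0.0, 800.0, 240.0], [0.0, 0.0, 1.0]]
        ),
        Rt=jnp.eye(4),  # camera frame coincides with the world frame
        dist_coeffs=jnp.zeros(5),
        image_size=jnp.array([640, 480]),
    )

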
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Camera:
    """
    a description of a camera
    """

    id: CameraID
    """
    Camera ID
    """
    params: CameraParams
    """
    Camera parameters
    """

    def __repr__(self) -> str:
        return f"<Camera id={self.id}>"

    def project(self, points_3d: Num[Array, "N 3"]) -> Num[Array, "N 2"]:
        """
        Project 3D points to 2D points using this camera's parameters

        Args:
            points_3d: 3D points in world coordinates

        Returns:
            2D points in pixel coordinates
        """
        return project(
            points_3d,
            projection_matrix=self.params.projection_matrix,
            K=self.params.K,
            dist_coeffs=self.params.dist_coeffs,
            image_size=self.params.image_size,
        )

    def project_ideal(self, points_3d: Num[Array, "N 3"]) -> Num[Array, "N 2"]:
        """
        Project 3D points to 2D points using this camera's parameters, without distortion

        Args:
            points_3d: 3D points in world coordinates

        Returns:
            2D points in pixel coordinates
        """
        return project(
            points_3d,
            projection_matrix=self.params.projection_matrix,
            image_size=self.params.image_size,
        )

    def distortion(self, points_2d: Num[Array, "N 2"]) -> Num[Array, "N 2"]:
        """
        Apply distortion to 2D points using this camera's parameters

        Args:
            points_2d: 2D points in pixel coordinates

        Returns:
            Distorted 2D points in pixel coordinates
        """
        return distortion(
            points_2d=points_2d,
            K=self.params.K,
            dist_coeffs=self.params.dist_coeffs,
        )

    def unproject_points_to_z_plane(
        self, points_2d: Num[Array, "N 2"], z: float = 0.0
    ) -> Num[Array, "N 3"]:
        """
        Un-project 2D points to 3D points on a plane at z = constant.

        Args:
            points_2d: 2D points in pixel coordinates
            z: z-coordinate of the plane (default: 0.0, i.e. the ground/floor plane)

        Returns:
            [..., 3] world-space intersection points
        """
        return unproject_points_to_z_plane(
            points_2d,
            self.params.K,
            self.params.dist_coeffs,
            self.params.pose_matrix,
            z,
        )


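# A sketch tying `Camera` together: project a hypothetical world point, then
# unproject the resulting pixel back onto the plane through that point. With
# zero distortion the round trip recovers the input up to floating-point error.
def _example_camera_round_trip() -> Num[Array, "N 3"]:
    cam = Camera(id="cam0", params=_example_camera_params())
    point = jnp.array([[0.3, 0.2, 3.0]])
    pixels = cam.project(point)
    return cam.unproject_points_to_z_plane(pixels, z=3.0)  # ~= point

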
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Detection:
    """
    One detection from a camera
    """

    keypoints: Num[Array, "N 2"]
    """
    Keypoints in pixel coordinates (with camera distortion applied).

    Use `keypoints_undistorted` to get undistorted keypoints.
    """
    confidences: Num[Array, "N"]
    """
    Confidences
    """
    camera: Camera
    """
    Camera
    """
    timestamp: datetime
    """
    Timestamp of the detection
    """

    @property
    def keypoints_undistorted(self) -> Num[Array, "N 2"]:
        """
        Returns undistorted keypoints.

        The result is cached on first access. (lazy evaluation)
        """
        kpu = getattr(self, "_kp_undistorted", None)
        if kpu is None:
            kpu_np = undistort_points(
                np.asarray(self.keypoints),
                np.asarray(self.camera.params.K),
                np.asarray(self.camera.params.dist_coeffs),
            )
            kpu = jnp.asarray(kpu_np)
            object.__setattr__(self, "_kp_undistorted", kpu)
        return kpu

    def __repr__(self) -> str:
        return f"Detection({self.camera}, {self.timestamp})"


def classify_by_camera(
    detections: list[Detection],
) -> OrderedDict[CameraID, list[Detection]]:
    """
    Classify detections by camera
    """
    # or use setdefault
    camera_wise_split: dict[CameraID, list[Detection]] = defaultdict(list)
    for entry in detections:
        camera_id = entry.camera.id
        camera_wise_split[camera_id].append(entry)
    return OrderedDict(camera_wise_split)


@jaxtyped(typechecker=beartype)
def to_homogeneous(points: Num[Array, "N 2"] | Num[Array, "N 3"]) -> Num[Array, "N 3"]:
    """
    Convert points to homogeneous coordinates
    """
    if points.shape[-1] == 2:
        # match the input dtype to avoid silent promotion
        return jnp.hstack((points, jnp.ones((points.shape[0], 1), dtype=points.dtype)))
    elif points.shape[-1] == 3:
        return points
    else:
        raise ValueError(f"Invalid shape for points: {points.shape}")


@jaxtyped(typechecker=beartype)
def point_line_distance(
    points: Num[Array, "N 3"] | Num[Array, "N 2"],
    line: Num[Array, "N 3"],
    eps: float = 1e-9,
) -> Num[Array, "N"]:
    """
    Calculate the distance from each point to the corresponding line

    Args:
        points: (possibly homogeneous) points :math:`(N, 2 or 3)`.
        line: lines coefficients :math:`(a, b, c)` with shape :math:`(N, 3)`, where :math:`ax + by + c = 0`.
        eps: Small constant for safe division.

    Returns:
        the computed distance with shape :math:`(N)`.

    See also:
        https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line
    """
    numerator = abs(line[:, 0] * points[:, 0] + line[:, 1] * points[:, 1] + line[:, 2])
    denominator = jnp.sqrt(line[:, 0] * line[:, 0] + line[:, 1] * line[:, 1])
    return numerator / (denominator + eps)


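# A worked sketch of `point_line_distance`: the vertical line x = 1 has
# coefficients (a, b, c) = (1, 0, -1), and the point (3, 0) lies at distance 2.
def _example_point_line_distance() -> Num[Array, "N"]:
    points = jnp.array([[3.0, 0.0]])
    line = jnp.array([[1.0, 0.0, -1.0]])
    return point_line_distance(points, line)  # ~= [2.0]

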
@jaxtyped(typechecker=beartype)
def left_to_right_epipolar_distance(
    left: Num[Array, "N 3"],
    right: Num[Array, "N 3"],
    fundamental_matrix: Num[Array, "3 3"],
) -> Num[Array, "N"]:
    """
    Return the one-sided epipolar distance for correspondences given the fundamental matrix:
    the distance of each `right` point to the epipolar line of its `left` counterpart.

    Args:
        left: points in the left image (homogeneous) :math:`(N, 3)`
        right: points in the right image (homogeneous) :math:`(N, 3)`
        fundamental_matrix: fundamental matrix :math:`(3, 3)`

    Returns:
        the computed distance with shape :math:`(N)`.

    See also:
        https://en.wikipedia.org/wiki/Fundamental_matrix_%28computer_vision%29

        $$x^{\\prime T}Fx = 0$$
    """
    F_t = fundamental_matrix.transpose()
    line1_in_2 = jnp.matmul(left, F_t)
    return point_line_distance(right, line1_in_2)


@jaxtyped(typechecker=beartype)
def right_to_left_epipolar_distance(
    left: Num[Array, "N 3"],
    right: Num[Array, "N 3"],
    fundamental_matrix: Num[Array, "3 3"],
) -> Num[Array, "N"]:
    """
    Return the one-sided epipolar distance for correspondences given the fundamental matrix:
    the distance of each `left` point to the epipolar line of its `right` counterpart.

    Args:
        left: points in the left image (homogeneous) :math:`(N, 3)`
        right: points in the right image (homogeneous) :math:`(N, 3)`
        fundamental_matrix: fundamental matrix :math:`(3, 3)`

    Returns:
        the computed distance with shape :math:`(N)`.

    See also:
        https://en.wikipedia.org/wiki/Fundamental_matrix_%28computer_vision%29

        $$x^{\\prime T}Fx = 0$$
    """
    line2_in_1 = jnp.matmul(right, fundamental_matrix)
    return point_line_distance(left, line2_in_1)


def distance_between_epipolar_lines(
    x1: Num[Array, "N 2"] | Num[Array, "N 3"],
    x2: Num[Array, "N 2"] | Num[Array, "N 3"],
    fundamental_matrix: Num[Array, "3 3"],
):
    """
    Calculate the symmetric epipolar distance between x1 and x2:
    the sum of the mean one-sided distances in both directions.
    """
    if x1.shape[0] != x2.shape[0]:
        raise ValueError(
            f"x1 and x2 must have the same number of points: {x1.shape[0]} != {x2.shape[0]}"
        )

    if x1.shape[-1] == 2:
        points1 = to_homogeneous(x1)
    elif x1.shape[-1] == 3:
        points1 = x1
    else:
        raise ValueError(f"Invalid shape for correspondence1: {x1.shape}")

    if x2.shape[-1] == 2:
        points2 = to_homogeneous(x2)
    elif x2.shape[-1] == 3:
        points2 = x2
    else:
        raise ValueError(f"Invalid shape for correspondence2: {x2.shape}")

    if fundamental_matrix.shape != (3, 3):
        raise ValueError(
            f"Invalid shape for fundamental_matrix: {fundamental_matrix.shape}"
        )

    # points 1 and 2 are unnormalized points
    dist_1 = jnp.mean(
        right_to_left_epipolar_distance(points1, points2, fundamental_matrix)
    )
    dist_2 = jnp.mean(
        left_to_right_epipolar_distance(points1, points2, fundamental_matrix)
    )
    distance = dist_1 + dist_2
    return distance


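# A sketch of the symmetric epipolar distance using the canonical fundamental
# matrix of a rectified stereo pair (hypothetical correspondences). For
# F = [[0, 0, 0], [0, 0, -1], [0, 1, 0]] the epipolar lines are image rows, so
# the distance below is driven purely by the 2-pixel vertical offset.
def _example_epipolar_distance() -> Array:
    F = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
    x1 = jnp.array([[100.0, 50.0]])
    x2 = jnp.array([[90.0, 52.0]])
    return distance_between_epipolar_lines(x1, x2, F)  # ~= 2 + 2 = 4

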
@jaxtyped(typechecker=beartype)
def calculate_fundamental_matrix(
    camera_left: Camera, camera_right: Camera
) -> Num[Array, "3 3"]:
    """
    Calculate the fundamental matrix for the given cameras.
    """

    # Intrinsics
    K1 = camera_left.params.K
    K2 = camera_right.params.K

    # Extrinsics (World-to-Camera transforms)
    Rt1 = camera_left.params.Rt
    Rt2 = camera_right.params.Rt

    # Relative transform from the left camera frame to the right camera frame:
    # X_c2 = Rt2 @ inv(Rt1) @ X_c1, since each Rt maps world to camera (W2C)
    T_rel: Array = Rt2 @ jnp.linalg.inv(Rt1)

    R = T_rel[:3, :3]
    t = T_rel[:3, 3:]

    # Skew-symmetric matrix for cross product
    def skew(v: Num[Array, "3"]) -> Num[Array, "3 3"]:
        return jnp.array(
            [
                [0, -v[2], v[1]],
                [v[2], 0, -v[0]],
                [-v[1], v[0], 0],
            ]
        )

    t_skew = skew(t.reshape(-1))

    # Essential Matrix
    E = t_skew @ R

    # Fundamental Matrix
    F = jnp.linalg.inv(K2).T @ E @ jnp.linalg.inv(K1)

    return F


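# A sanity-check sketch for `calculate_fundamental_matrix`: for any 3D point X
# visible in both cameras, the projections x1, x2 must satisfy x2^T F x1 ~= 0.
# The cameras are hypothetical: the identity-pose example camera and a copy
# shifted 1 m along the world x axis.
def _example_fundamental_matrix_residual() -> Array:
    cam_left = Camera(id="cam0", params=_example_camera_params())
    cam_right = Camera(
        id="cam1",
        params=CameraParams(
            K=cam_left.params.K,
            Rt=jnp.eye(4).at[0, 3].set(-1.0),  # W2C: x_cam = x_world - 1
            dist_coeffs=jnp.zeros(5),
            image_size=cam_left.params.image_size,
        ),
    )
    F = calculate_fundamental_matrix(cam_left, cam_right)
    X = jnp.array([[0.3, 0.2, 3.0]])
    x1 = to_homogeneous(cam_left.project_ideal(X))
    x2 = to_homogeneous(cam_right.project_ideal(X))
    return jnp.einsum("ni,ij,nj->n", x2, F, x1)  # ~= [0.0]

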
def compute_affinity_epipolar_constraint_with_pairs(
    left: Detection, right: Detection, alpha_2d: float
):
    """
    Compute the affinity between two detections by the epipolar constraint,
    where the camera parameters are taken from the detections themselves.

    Note:
        alpha_2d comes from the paper as a scaling factor for the epipolar-error
        affinity. Its role is mainly to normalize the error into the [0, 1]
        range, but large errors can produce negative affinities.
        An alternative is a normalized epipolar error relative to the image
        size with a soft cutoff, e.g. exp(-error / threshold), for better
        interpretability and stability.
    """
    fundamental_matrix = calculate_fundamental_matrix(left.camera, right.camera)
    d = distance_between_epipolar_lines(
        left.keypoints, right.keypoints, fundamental_matrix
    )
    return 1 - (d / alpha_2d)


def calculate_affinity_matrix_by_epipolar_constraint(
    detections: list[Detection] | dict[CameraID, list[Detection]],
    alpha_2d: float,
) -> tuple[list[Detection], Num[Array, "N N"]]:
    """
    Calculate the affinity matrix by epipolar constraint

    This function evaluates the geometric consistency of every pair of detections
    across different cameras using the fundamental matrix. Detections from the
    same camera are not comparable, so their entries are never computed.

    The affinity is computed by:
    1. Calculating the fundamental matrix between the two cameras.
    2. Measuring the symmetric mean point-to-epipolar-line distance over all keypoints.
    3. Mapping the distance to affinity with the formula: 1 - (distance / alpha_2d).

    Args:
        detections: Either a flat list of Detection or a dict grouping Detection by CameraID.
        alpha_2d: Image resolution-dependent threshold controlling affinity scaling.
            Typically related to the expected pixel displacement.

    Returns:
        sorted_detections: Flattened list of detections sorted by camera order.
        affinity_matrix: Array of shape (N, N), where N is the number of detections.

    Notes:
        - Same-camera pairs (including the diagonal) keep the -inf
          initialization, marking them as excluded.
        - Affinity decays linearly with epipolar error until 0 (or potentially negative).
        - Consider switching to exp(-error / scale) style for non-negative affinity.
        - alpha_2d should be adjusted based on image resolution or empirical observation.

    Affinity Matrix layout (excluded same-camera entries shown as 0):
        assuming we have 3 cameras
        C0 has 3 detections: D0_C0, D1_C0, D2_C0
        C1 has 2 detections: D0_C1, D1_C1
        C2 has 2 detections: D0_C2, D1_C2

                 D0_C0(0) D1_C0(1) D2_C0(2) D0_C1(3) D1_C1(4) D0_C2(5) D1_C2(6)
        D0_C0(0)    0        0        0       a_03     a_04     a_05     a_06
        D1_C0(1)    0        0        0       a_13     a_14     a_15     a_16
        ...
        D0_C1(3)   a_30     a_31     a_32      0        0       a_35     a_36
        ...
        D1_C2(6)   a_60     a_61     a_62     a_63     a_64      0        0
    """
    if isinstance(detections, dict):
        camera_wise_split = detections
    else:
        camera_wise_split = classify_by_camera(detections)
    num_entries = sum(len(entries) for entries in camera_wise_split.values())
    affinity_matrix = jnp.ones((num_entries, num_entries), dtype=jnp.float32) * -jnp.inf
    affinity_matrix_mask = jnp.zeros((num_entries, num_entries), dtype=jnp.bool_)

    acc = 0
    total_indices = set(range(num_entries))
    camera_id_index_map: dict[CameraID, set[int]] = defaultdict(set)
    camera_id_index_map_inverse: dict[CameraID, set[int]] = defaultdict(set)
    # sorted by [D0_C0, D1_C0, D2_C0, D0_C1, D1_C1, D0_C2, D1_C2...]
    sorted_detections: list[Detection] = []
    for camera_id, entries in camera_wise_split.items():
        for entry in entries:
            camera_id_index_map[camera_id].add(acc)
            sorted_detections.append(entry)
            acc += 1
        camera_id_index_map_inverse[camera_id] = (
            total_indices - camera_id_index_map[camera_id]
        )

    # ignore self-affinity
    # ignore same-camera affinity
    # assuming the affinity is commutative
    for i, det in enumerate(sorted_detections):
        other_indices = camera_id_index_map_inverse[det.camera.id]
        for j in other_indices:
            if i == j:
                continue
            if affinity_matrix_mask[i, j] or affinity_matrix_mask[j, i]:
                continue
            a = compute_affinity_epipolar_constraint_with_pairs(
                det, sorted_detections[j], alpha_2d
            )
            affinity_matrix = affinity_matrix.at[i, j].set(a)
            affinity_matrix = affinity_matrix.at[j, i].set(a)
            affinity_matrix_mask = affinity_matrix_mask.at[i, j].set(True)
            affinity_matrix_mask = affinity_matrix_mask.at[j, i].set(True)

    return sorted_detections, affinity_matrix


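# An end-to-end sketch of the affinity pipeline with hypothetical single-joint
# detections on the stereo pair from the fundamental-matrix example above.
# With perfect correspondences the cross-camera affinities come out ~= 1 and
# the same-camera block stays at -inf.
def _example_affinity_matrix() -> tuple[list[Detection], Num[Array, "N N"]]:
    cam_left = Camera(id="cam0", params=_example_camera_params())
    cam_right = Camera(
        id="cam1",
        params=CameraParams(
            K=cam_left.params.K,
            Rt=jnp.eye(4).at[0, 3].set(-1.0),
            dist_coeffs=jnp.zeros(5),
            image_size=cam_left.params.image_size,
        ),
    )
    X = jnp.array([[0.3, 0.2, 3.0]])  # one shared 3D keypoint
    now = datetime.now()
    detections = [
        Detection(
            keypoints=cam.project(X),
            confidences=jnp.ones(1),
            camera=cam,
            timestamp=now,
        )
        for cam in (cam_left, cam_right)
    ]
    # alpha_2d is resolution-dependent; 200 px is an arbitrary example value
    return calculate_affinity_matrix_by_epipolar_constraint(detections, 200.0)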