forked from HQU-gxy/CVTH3PE
feat: Enhance camera module with new data structures and utility functions
- Introduced dataclass structures for CameraParams and Camera to improve type safety and clarity. - Added Detection dataclass to encapsulate detection data, including keypoints and timestamps. - Implemented classify_by_camera function to organize detections by camera. - Added utility functions for converting points to homogeneous coordinates and calculating distances to lines. - Updated dependencies in pyproject.toml to include new libraries for enhanced functionality.
This commit is contained in:
@ -4,12 +4,16 @@ from typing_extensions import NotRequired
|
||||
from jaxtyping import Num, jaxtyped
|
||||
from beartype import beartype
|
||||
from jax import numpy as jnp, Array
|
||||
from dataclasses import dataclass
|
||||
from collections import defaultdict, OrderedDict
|
||||
from datetime import datetime
|
||||
|
||||
CameraID: TypeAlias = str
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
class CameraParams(TypedDict):
|
||||
@dataclass
|
||||
class CameraParams:
|
||||
"""
|
||||
Camera parameters: intrinsic matrix, extrinsic matrix, and distortion coefficients
|
||||
"""
|
||||
@ -32,10 +36,15 @@ class CameraParams(TypedDict):
|
||||
radial distortion coefficient and pi is the ith
|
||||
tangential distortion coeff.
|
||||
"""
|
||||
image_size: Num[Array, "2"]
|
||||
"""
|
||||
The size of image plane (width, height)
|
||||
"""
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
class Camera(TypedDict):
|
||||
@dataclass
|
||||
class Camera:
|
||||
"""
|
||||
a description of a camera
|
||||
"""
|
||||
@ -48,7 +57,288 @@ class Camera(TypedDict):
|
||||
"""
|
||||
Camera parameters
|
||||
"""
|
||||
size: tuple[int, int]
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
@dataclass
|
||||
class Detection:
|
||||
"""
|
||||
Image size
|
||||
One detection from a camera
|
||||
"""
|
||||
|
||||
keypoints: Num[Array, "N 2"]
|
||||
"""
|
||||
Keypoints
|
||||
"""
|
||||
confidences: Num[Array, "N"]
|
||||
"""
|
||||
Confidences
|
||||
"""
|
||||
camera: Camera
|
||||
"""
|
||||
Camera
|
||||
"""
|
||||
timestamp: datetime
|
||||
"""
|
||||
Timestamp of the detection
|
||||
"""
|
||||
|
||||
|
||||
def classify_by_camera(
|
||||
detections: list[Detection],
|
||||
) -> OrderedDict[CameraID, list[Detection]]:
|
||||
"""
|
||||
Classify detections by camera
|
||||
"""
|
||||
# or use setdefault
|
||||
camera_wise_split: dict[CameraID, list[Detection]] = defaultdict(list)
|
||||
for entry in detections:
|
||||
camera_id = entry.camera.id
|
||||
camera_wise_split[camera_id].append(entry)
|
||||
return OrderedDict(camera_wise_split)
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def to_homogeneous(points: Num[Array, "N 2"] | Num[Array, "N 3"]) -> Num[Array, "N 3"]:
|
||||
"""
|
||||
Convert points to homogeneous coordinates
|
||||
"""
|
||||
if points.shape[-1] == 2:
|
||||
return jnp.hstack((points, jnp.ones((points.shape[0], 1))))
|
||||
elif points.shape[-1] == 3:
|
||||
return points
|
||||
else:
|
||||
raise ValueError(f"Invalid shape for points: {points.shape}")
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def point_line_distance(
|
||||
point: Num[Array, "N 3"] | Num[Array, "N 2"],
|
||||
line: Num[Array, "N 3"],
|
||||
eps: float = 1e-9,
|
||||
):
|
||||
"""
|
||||
Calculate the distance from a point to a line
|
||||
|
||||
Args:
|
||||
point: (possibly homogeneous) points :math:`(N, 2 or 3)`.
|
||||
line: lines coefficients :math:`(a, b, c)` with shape :math:`(N, 3)`, where :math:`ax + by + c = 0`.
|
||||
eps: Small constant for safe sqrt.
|
||||
|
||||
Returns:
|
||||
the computed distance with shape :math:`(N)`.
|
||||
|
||||
See also:
|
||||
https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line
|
||||
"""
|
||||
numerator = abs(line[:, 0] * point[:, 0] + line[:, 1] * point[:, 1] + line[:, 2])
|
||||
denominator = jnp.sqrt(line[:, 0] * line[:, 0] + line[:, 1] * line[:, 1])
|
||||
return numerator / (denominator + eps)
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def left_to_right_epipolar_distance(
|
||||
left: Num[Array, "N 3"],
|
||||
right: Num[Array, "N 3"],
|
||||
fundamental_matrix: Num[Array, "3 3"],
|
||||
):
|
||||
"""
|
||||
Return one-sided epipolar distance for correspondences given the fundamental matrix.
|
||||
|
||||
Args:
|
||||
left: points in the left image (homogeneous) :math:`(N, 3)`
|
||||
right: points in the right image (homogeneous) :math:`(N, 3)`
|
||||
fundamental_matrix: fundamental matrix :math:`(3, 3)`
|
||||
|
||||
Returns:
|
||||
the computed distance with shape :math:`(N)`.
|
||||
|
||||
See also:
|
||||
https://en.wikipedia.org/wiki/Fundamental_matrix_%28computer_vision%29
|
||||
|
||||
$$x^{\\prime T}Fx = 0$$
|
||||
"""
|
||||
F_t = fundamental_matrix.transpose()
|
||||
line1_in_2 = jnp.matmul(left, F_t)
|
||||
return point_line_distance(right, line1_in_2)
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def right_to_left_epipolar_distance(
|
||||
left: Num[Array, "N 3"],
|
||||
right: Num[Array, "N 3"],
|
||||
fundamental_matrix: Num[Array, "3 3"],
|
||||
):
|
||||
"""
|
||||
Return one-sided epipolar distance for correspondences given the fundamental matrix.
|
||||
|
||||
Args:
|
||||
left: points in the left image (homogeneous) :math:`(N, 3)`
|
||||
right: points in the right image (homogeneous) :math:`(N, 3)`
|
||||
fundamental_matrix: fundamental matrix :math:`(3, 3)`
|
||||
|
||||
Returns:
|
||||
the computed distance with shape :math:`(N)`.
|
||||
|
||||
See also:
|
||||
https://en.wikipedia.org/wiki/Fundamental_matrix_%28computer_vision%29
|
||||
|
||||
$$x^{\\prime T}Fx = 0$$
|
||||
"""
|
||||
line2_in_1 = jnp.matmul(right, fundamental_matrix)
|
||||
return point_line_distance(left, line2_in_1)
|
||||
|
||||
|
||||
def distance_between_epipolar_lines(
|
||||
x1: Num[Array, "N 2"] | Num[Array, "N 3"],
|
||||
x2: Num[Array, "N 2"] | Num[Array, "N 3"],
|
||||
fundamental_matrix: Num[Array, "3 3"],
|
||||
):
|
||||
"""
|
||||
Calculate the total epipolar line distance between x1 and x2.
|
||||
"""
|
||||
if x1.shape[0] != x2.shape[0]:
|
||||
raise ValueError(
|
||||
f"x1 and x2 must have the same number of points: {x1.shape[0]} != {x2.shape[0]}"
|
||||
)
|
||||
|
||||
if x1.shape[-1] == 2:
|
||||
point1 = to_homogeneous(x1)
|
||||
elif x1.shape[-1] == 3:
|
||||
point1 = x1
|
||||
else:
|
||||
raise ValueError(f"Invalid shape for correspondence1: {x1.shape}")
|
||||
|
||||
if x2.shape[-1] == 2:
|
||||
point2 = to_homogeneous(x2)
|
||||
elif x2.shape[-1] == 3:
|
||||
point2 = x2
|
||||
else:
|
||||
raise ValueError(f"Invalid shape for correspondence2: {x2.shape}")
|
||||
|
||||
if fundamental_matrix.shape != (3, 3):
|
||||
raise ValueError(
|
||||
f"Invalid shape for fundamental_matrix: {fundamental_matrix.shape}"
|
||||
)
|
||||
|
||||
# points 1 and 2 are unnormalized points
|
||||
dist_1 = jnp.mean(
|
||||
right_to_left_epipolar_distance(point1, point2, fundamental_matrix)
|
||||
)
|
||||
dist_2 = jnp.mean(
|
||||
left_to_right_epipolar_distance(point1, point2, fundamental_matrix)
|
||||
)
|
||||
distance = dist_1 + dist_2
|
||||
return distance
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def calculate_fundamental_matrix(
|
||||
camera_left: Camera, camera_right: Camera
|
||||
) -> Num[Array, "3 3"]:
|
||||
"""
|
||||
Calculate the fundamental matrix for the given cameras.
|
||||
"""
|
||||
|
||||
# Intrinsics
|
||||
K1 = camera_left.params.K
|
||||
K2 = camera_right.params.K
|
||||
|
||||
# Extrinsics (World to Camera transforms)
|
||||
Rt1 = camera_left.params.Rt
|
||||
Rt2 = camera_right.params.Rt
|
||||
|
||||
# Convert to Camera to World (Inverse)
|
||||
T2: Array = jnp.linalg.inv(Rt2)
|
||||
|
||||
# Relative transform from Left to Right
|
||||
T_rel = T2 @ Rt1
|
||||
|
||||
R = T_rel[:3, :3]
|
||||
t = T_rel[:3, 3:]
|
||||
|
||||
# Skew-symmetric matrix for cross product
|
||||
def skew(v: Num[Array, "3"]) -> Num[Array, "3 3"]:
|
||||
return jnp.array(
|
||||
[
|
||||
[0, -v[2], v[1]],
|
||||
[v[2], 0, -v[0]],
|
||||
[-v[1], v[0], 0],
|
||||
]
|
||||
)
|
||||
|
||||
t_skew = skew(t.reshape(-1))
|
||||
|
||||
# Essential Matrix
|
||||
E = t_skew @ R
|
||||
|
||||
# Fundamental Matrix
|
||||
F = jnp.linalg.inv(K2).T @ E @ jnp.linalg.inv(K1)
|
||||
|
||||
return F
|
||||
|
||||
|
||||
def compute_affinity_epipolar_constraint_with_pairs(
|
||||
left: Detection, right: Detection, alpha_2D: float
|
||||
):
|
||||
fundamental_matrix = calculate_fundamental_matrix(left.camera, right.camera)
|
||||
d = distance_between_epipolar_lines(
|
||||
left.keypoints, right.keypoints, fundamental_matrix
|
||||
)
|
||||
return 1 - (d / alpha_2D)
|
||||
|
||||
|
||||
def get_affinity_matrix_epipolar_constraint(
|
||||
detections: list[Detection],
|
||||
alpha_2D: float,
|
||||
) -> Num[Array, "N N"]:
|
||||
camera_wise_split = classify_by_camera(detections)
|
||||
num_entries = sum(len(entries) for entries in camera_wise_split.values())
|
||||
affinity_matrix = jnp.zeros((num_entries, num_entries), dtype=jnp.float32)
|
||||
affinity_matrix_mask = jnp.zeros((num_entries, num_entries), dtype=jnp.bool_)
|
||||
|
||||
acc = 0
|
||||
total_indices = set(range(num_entries))
|
||||
camera_id_index_map: dict[CameraID, set[int]] = defaultdict(set)
|
||||
camera_id_index_map_inverse: dict[CameraID, set[int]] = defaultdict(set)
|
||||
# sorted by [D0_C0, D1_C0, D2_C0, D0_C1, D1_C1, D0_C2, D1_C2...]
|
||||
sorted_detections: list[Detection] = []
|
||||
for camera_id, entries in camera_wise_split.items():
|
||||
for i, _ in enumerate(entries):
|
||||
camera_id_index_map[camera_id].add(acc + i)
|
||||
sorted_detections.append(entries[i])
|
||||
acc += 1
|
||||
camera_id_index_map_inverse[camera_id] = (
|
||||
total_indices - camera_id_index_map[camera_id]
|
||||
)
|
||||
|
||||
# assuming we have 3 cameras
|
||||
# C0 has 3 detections: D0_C0, D1_C0, D2_C0
|
||||
# C1 has 2 detections: D0_C1, D1_C1
|
||||
# C2 has 2 detections: D0_C2, D1_C2
|
||||
#
|
||||
# D0_C0(0), D1_C0(1), D2_C0(2), D0_C1(3), D1_C1(4), D0_C2(5), D1_C2(6)
|
||||
# D0_C0(0) 0 0 0 a_03 a_04 a_05 a_06
|
||||
# D1_C0(1) 0 0 0 a_13 a_14 a_15 a_16
|
||||
# ...
|
||||
# D0_C1(3) a_30 a_31 a_32 0 0 a_35 a_36
|
||||
# ...
|
||||
# D1_C2(6) a_60 a_61 a_62 a_63 a_64 0 0
|
||||
|
||||
# ignore self-affinity
|
||||
# ignore same-camera affinity
|
||||
# assuming commutative property of epipolar constraint
|
||||
for i, det in enumerate(sorted_detections):
|
||||
other_indices = camera_id_index_map_inverse[det.camera.id]
|
||||
for j in other_indices:
|
||||
if affinity_matrix_mask[i, j] or affinity_matrix_mask[j, i]:
|
||||
continue
|
||||
a = compute_affinity_epipolar_constraint_with_pairs(
|
||||
det, sorted_detections[j], alpha_2D
|
||||
)
|
||||
affinity_matrix = affinity_matrix.at[i, j].set(a)
|
||||
affinity_matrix = affinity_matrix.at[j, i].set(a)
|
||||
affinity_matrix_mask = affinity_matrix_mask.at[i, j].set(True)
|
||||
affinity_matrix_mask = affinity_matrix_mask.at[j, i].set(True)
|
||||
|
||||
return affinity_matrix
|
||||
|
||||
Reference in New Issue
Block a user