1
0
forked from HQU-gxy/CVTH3PE
Files
CVTH3PE/app/camera/__init__.py
crosstyan 92477b18d2 feat: Enhance camera module with new data structures and utility functions
- Introduced dataclass structures for CameraParams and Camera to improve type safety and clarity.
- Added Detection dataclass to encapsulate detection data, including keypoints and timestamps.
- Implemented classify_by_camera function to organize detections by camera.
- Added utility functions for converting points to homogeneous coordinates and calculating distances to lines.
- Updated dependencies in pyproject.toml to include new libraries for enhanced functionality.
2025-04-15 10:20:49 +08:00

345 lines
9.8 KiB
Python

from typing import TypedDict, TypeAlias, Any
from typing_extensions import NotRequired
from jaxtyping import Num, jaxtyped
from beartype import beartype
from jax import numpy as jnp, Array
from dataclasses import dataclass
from collections import defaultdict, OrderedDict
from datetime import datetime
CameraID: TypeAlias = str
@jaxtyped(typechecker=beartype)
@dataclass
class CameraParams:
    """
    Calibration of a single camera.

    The array shapes in the field annotations are checked by the
    ``@jaxtyped(typechecker=beartype)`` decorator when an instance is built.

    Fields:

    - ``K``: 3x3 intrinsic matrix.
    - ``Rt``: 4x4 extrinsic matrix, the homogeneous form of ``[R|t]``
      (presumably with bottom row ``[0, 0, 0, 1]`` -- confirm with the
      calibration source). R and t describe the change of coordinates from
      the world frame to the camera frame.
    - ``dist_coeffs``: distortion coefficients laid out as
      ``[k1, k2, [p1, p2, [k3]]]``; ``ki`` is the i-th radial and ``pi``
      the i-th tangential coefficient.
    - ``image_size``: size of the image plane as ``(width, height)``.
    """

    # intrinsic matrix
    K: Num[Array, "3 3"]
    # world -> camera extrinsic transform, 4x4 homogeneous [R|t]
    Rt: Num[Array, "4 4"]
    # distortion coefficients [k1, k2, [p1, p2, [k3]]]
    dist_coeffs: Num[Array, "N"]
    # (width, height) of the image plane
    image_size: Num[Array, "2"]
@jaxtyped(typechecker=beartype)
@dataclass
class Camera:
    """
    A camera: its identifier paired with its calibration parameters.
    """

    # identifier of the camera
    id: CameraID
    # intrinsics, extrinsics, distortion coefficients and image size
    params: CameraParams
@jaxtyped(typechecker=beartype)
@dataclass
class Detection:
    """
    One detection from a camera.

    ``keypoints`` and ``confidences`` share the dimension name ``N`` in their
    annotations, so the jaxtyping check on construction requires one
    confidence value per keypoint row.
    """

    # 2D keypoints, one row per keypoint
    keypoints: Num[Array, "N 2"]
    # per-keypoint confidences
    confidences: Num[Array, "N"]
    # camera that produced the detection
    camera: Camera
    # timestamp of the detection
    timestamp: datetime
def classify_by_camera(
    detections: list[Detection],
) -> OrderedDict[CameraID, list[Detection]]:
    """
    Group detections by the id of the camera that produced them.

    Cameras appear in the result in the order of their first detection, and
    each camera's list keeps its detections in input order.
    """
    grouped: dict[CameraID, list[Detection]] = {}
    for det in detections:
        grouped.setdefault(det.camera.id, []).append(det)
    return OrderedDict(grouped)
@jaxtyped(typechecker=beartype)
def to_homogeneous(points: Num[Array, "N 2"] | Num[Array, "N 3"]) -> Num[Array, "N 3"]:
"""
Convert points to homogeneous coordinates
"""
if points.shape[-1] == 2:
return jnp.hstack((points, jnp.ones((points.shape[0], 1))))
elif points.shape[-1] == 3:
return points
else:
raise ValueError(f"Invalid shape for points: {points.shape}")
@jaxtyped(typechecker=beartype)
def point_line_distance(
    point: Num[Array, "N 3"] | Num[Array, "N 2"],
    line: Num[Array, "N 3"],
    eps: float = 1e-9,
):
    """
    Row-wise Euclidean distance from points to 2D lines.

    Only the first two components of each point are read, so homogeneous
    input is treated as if already normalized (w == 1); a third component,
    if present, is ignored.

    Args:
        point: points, optionally homogeneous, with shape :math:`(N, 2)` or :math:`(N, 3)`.
        line: line coefficients :math:`(a, b, c)` with shape :math:`(N, 3)`, describing :math:`ax + by + c = 0`.
        eps: small constant added to the denominator to avoid division by zero.

    Returns:
        the distances, with shape :math:`(N)`.

    See also:
        https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line
    """
    a, b, c = line[:, 0], line[:, 1], line[:, 2]
    x, y = point[:, 0], point[:, 1]
    residual = abs(a * x + b * y + c)
    normal_length = jnp.sqrt(a * a + b * b)
    return residual / (normal_length + eps)
@jaxtyped(typechecker=beartype)
def left_to_right_epipolar_distance(
    left: Num[Array, "N 3"],
    right: Num[Array, "N 3"],
    fundamental_matrix: Num[Array, "3 3"],
):
    """
    One-sided epipolar distance, measured in the right image.

    Every left point ``x`` is mapped to its epipolar line ``l' = F x`` in the
    right image. The distance from the matching right point to that line is
    returned.

    Args:
        left: homogeneous points in the left image, :math:`(N, 3)`.
        right: homogeneous points in the right image, :math:`(N, 3)`.
        fundamental_matrix: fundamental matrix :math:`(3, 3)`.

    Returns:
        the distances, with shape :math:`(N)`.

    See also:
        https://en.wikipedia.org/wiki/Fundamental_matrix_%28computer_vision%29
        $$x^{\\prime T}Fx = 0$$
    """
    # row-vector form of l' = F x for every row of `left`
    epipolar_lines_right = left @ fundamental_matrix.T
    return point_line_distance(right, epipolar_lines_right)
@jaxtyped(typechecker=beartype)
def right_to_left_epipolar_distance(
    left: Num[Array, "N 3"],
    right: Num[Array, "N 3"],
    fundamental_matrix: Num[Array, "3 3"],
):
    """
    One-sided epipolar distance, measured in the left image.

    Every right point ``x'`` is mapped to its epipolar line ``l = F^T x'`` in
    the left image. The distance from the matching left point to that line is
    returned.

    Args:
        left: homogeneous points in the left image, :math:`(N, 3)`.
        right: homogeneous points in the right image, :math:`(N, 3)`.
        fundamental_matrix: fundamental matrix :math:`(3, 3)`.

    Returns:
        the distances, with shape :math:`(N)`.

    See also:
        https://en.wikipedia.org/wiki/Fundamental_matrix_%28computer_vision%29
        $$x^{\\prime T}Fx = 0$$
    """
    # row-vector form of l = F^T x' for every row of `right`
    epipolar_lines_left = right @ fundamental_matrix
    return point_line_distance(left, epipolar_lines_left)
def distance_between_epipolar_lines(
    x1: Num[Array, "N 2"] | Num[Array, "N 3"],
    x2: Num[Array, "N 2"] | Num[Array, "N 3"],
    fundamental_matrix: Num[Array, "3 3"],
):
    """
    Symmetric epipolar distance between two sets of corresponding points.

    The result is the mean distance of ``x1`` to the epipolar lines induced by
    ``x2``, plus the mean distance of ``x2`` to the epipolar lines induced by
    ``x1``.

    Args:
        x1: points in the first ("left") image, :math:`(N, 2)` or homogeneous :math:`(N, 3)`.
        x2: corresponding points in the second ("right") image, same layout.
        fundamental_matrix: :math:`(3, 3)` matrix F with :math:`x_2^T F x_1 = 0`.

    Returns:
        a scalar: the sum of the two mean one-sided distances.

    Raises:
        ValueError: on mismatched point counts, unsupported point widths, or a
            fundamental matrix that is not 3x3.
    """
    n1, n2 = x1.shape[0], x2.shape[0]
    if n1 != n2:
        raise ValueError(
            f"x1 and x2 must have the same number of points: {n1} != {n2}"
        )

    def lift(pts, label):
        # 2-column points get a ones column; 3-column points pass through
        width = pts.shape[-1]
        if width == 2:
            return to_homogeneous(pts)
        if width == 3:
            return pts
        raise ValueError(f"Invalid shape for {label}: {pts.shape}")

    point1 = lift(x1, "correspondence1")
    point2 = lift(x2, "correspondence2")

    if fundamental_matrix.shape != (3, 3):
        raise ValueError(
            f"Invalid shape for fundamental_matrix: {fundamental_matrix.shape}"
        )

    # points 1 and 2 are unnormalized (pixel) points
    in_first_image = right_to_left_epipolar_distance(
        point1, point2, fundamental_matrix
    )
    in_second_image = left_to_right_epipolar_distance(
        point1, point2, fundamental_matrix
    )
    return jnp.mean(in_first_image) + jnp.mean(in_second_image)
@jaxtyped(typechecker=beartype)
def calculate_fundamental_matrix(
    camera_left: Camera, camera_right: Camera
) -> Num[Array, "3 3"]:
    """
    Fundamental matrix F relating the left camera's image to the right one's.

    ``F`` satisfies ``x_r^T F x_l = 0`` for corresponding homogeneous pixel
    coordinates ``x_l`` (left image) and ``x_r`` (right image). In other
    words, ``F x_l`` is the epipolar line of ``x_l`` in the right image.

    Only ``params.K`` and ``params.Rt`` are used. Lens distortion
    (``dist_coeffs``) is ignored, so points are assumed to be undistorted.

    Args:
        camera_left: camera of the first ("left") view.
        camera_right: camera of the second ("right") view.

    Returns:
        the :math:`(3, 3)` fundamental matrix (not scale-normalized).
    """
    # Intrinsics
    K1 = camera_left.params.K
    K2 = camera_right.params.K
    # Extrinsics: world -> camera transforms (4x4 homogeneous)
    Rt1 = camera_left.params.Rt
    Rt2 = camera_right.params.Rt

    # Relative pose from left-camera to right-camera coordinates:
    # undo the left extrinsic (left camera -> world), then apply the right
    # one (world -> right camera). This gives X_r = R @ X_l + t.
    T_rel = Rt2 @ jnp.linalg.inv(Rt1)
    R = T_rel[:3, :3]
    t = T_rel[:3, 3]

    def skew(v: Num[Array, "3"]) -> Num[Array, "3 3"]:
        # [v]_x, so that skew(v) @ w == cross(v, w)
        return jnp.array(
            [
                [0, -v[2], v[1]],
                [v[2], 0, -v[0]],
                [-v[1], v[0], 0],
            ]
        )

    # Essential matrix: x_r_norm^T E x_l_norm = 0 in normalized coordinates
    E = skew(t) @ R
    # Fundamental matrix: move E from normalized to pixel coordinates
    F = jnp.linalg.inv(K2).T @ E @ jnp.linalg.inv(K1)
    return F
def compute_affinity_epipolar_constraint_with_pairs(
    left: Detection, right: Detection, alpha_2D: float
):
    """
    Epipolar affinity between two detections.

    The affinity is ``1 - d / alpha_2D``, where ``d`` is the symmetric epipolar
    distance between the two keypoint sets. It is computed with the
    fundamental matrix that maps the left camera's image to the right
    camera's image. It equals 1 for ``d == 0``, decreases linearly with ``d``,
    and becomes negative once ``d > alpha_2D``; no clipping is applied.

    Args:
        left: detection treated as the first ("left") view.
        right: detection treated as the second ("right") view.
        alpha_2D: distance scale; presumably in pixels -- confirm with caller.

    Returns:
        the affinity as a scalar array.
    """
    F_lr = calculate_fundamental_matrix(left.camera, right.camera)
    distance = distance_between_epipolar_lines(
        left.keypoints, right.keypoints, F_lr
    )
    return 1 - distance / alpha_2D
def get_affinity_matrix_epipolar_constraint(
    detections: list[Detection],
    alpha_2D: float,
) -> Num[Array, "N N"]:
    """
    Pairwise epipolar affinity matrix over all detections.

    Rows and columns are ordered by camera, not by input position. Detections
    are grouped with ``classify_by_camera``: cameras come in order of first
    appearance and keep input order within each camera. For example, with
    3 cameras:

        D0_C0(0), D1_C0(1), D2_C0(2), D0_C1(3), D1_C1(4), D0_C2(5), D1_C2(6)
        D0_C0(0)   0    0    0  a_03 a_04 a_05 a_06
        D1_C0(1)   0    0    0  a_13 a_14 a_15 a_16
        ...
        D0_C1(3) a_30 a_31 a_32   0    0  a_35 a_36
        ...
        D1_C2(6) a_60 a_61 a_62 a_63 a_64   0    0

    Self-affinity and same-camera affinity are left at 0. The epipolar
    affinity is treated as symmetric: each unordered cross-camera pair is
    evaluated once, with the lower-index detection as the "left" view, and
    written to both ``[i, j]`` and ``[j, i]``.

    Args:
        detections: detections from any number of cameras.
        alpha_2D: distance scale forwarded to
            ``compute_affinity_epipolar_constraint_with_pairs``.

    Returns:
        a float32 ``(N, N)`` matrix, where N is ``len(detections)``.
    """
    from itertools import combinations

    camera_wise_split = classify_by_camera(detections)
    # Flattened camera-grouped order; index k in the matrix is sorted_detections[k].
    sorted_detections: list[Detection] = [
        det for entries in camera_wise_split.values() for det in entries
    ]
    num_entries = len(sorted_detections)

    rows: list[int] = []
    cols: list[int] = []
    values: list[Array] = []
    for (i, det_i), (j, det_j) in combinations(enumerate(sorted_detections), 2):
        if det_i.camera.id == det_j.camera.id:
            # same-camera pairs carry no epipolar information
            continue
        a = compute_affinity_epipolar_constraint_with_pairs(det_i, det_j, alpha_2D)
        rows.extend((i, j))
        cols.extend((j, i))
        values.extend((a, a))

    affinity_matrix = jnp.zeros((num_entries, num_entries), dtype=jnp.float32)
    if values:
        # one scatter instead of a full-array copy per element update
        affinity_matrix = affinity_matrix.at[
            jnp.asarray(rows), jnp.asarray(cols)
        ].set(jnp.stack(values).astype(jnp.float32))
    return affinity_matrix