refactor: Update camera module for improved type handling and utility functions

- Reorganized imports for better clarity and consistency.
- Renamed variables in distance calculation functions for improved readability.
- Enhanced `compute_affinity_epipolar_constraint_with_pairs` function with detailed docstring explaining its purpose and parameters.
- Updated function signature to accept both list and dictionary formats for detections, improving flexibility.
- Adjusted affinity calculation logic to ensure consistent naming conventions for parameters.
This commit is contained in:
2025-04-15 11:06:34 +08:00
parent 92477b18d2
commit 3f32333de4

View File

@ -1,14 +1,15 @@
from typing import TypedDict, TypeAlias, Any from collections import OrderedDict, defaultdict
from dataclasses import dataclass
from datetime import datetime
from typing import Any, TypeAlias, TypedDict
from beartype import beartype
from jax import Array
from jax import numpy as jnp
from jaxtyping import Num, jaxtyped
from typing_extensions import NotRequired from typing_extensions import NotRequired
from jaxtyping import Num, jaxtyped CameraID: TypeAlias = str # pylint: disable=invalid-name
from beartype import beartype
from jax import numpy as jnp, Array
from dataclasses import dataclass
from collections import defaultdict, OrderedDict
from datetime import datetime
CameraID: TypeAlias = str
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
@ -113,7 +114,7 @@ def to_homogeneous(points: Num[Array, "N 2"] | Num[Array, "N 3"]) -> Num[Array,
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
def point_line_distance( def point_line_distance(
point: Num[Array, "N 3"] | Num[Array, "N 2"], points: Num[Array, "N 3"] | Num[Array, "N 2"],
line: Num[Array, "N 3"], line: Num[Array, "N 3"],
eps: float = 1e-9, eps: float = 1e-9,
): ):
@ -131,7 +132,7 @@ def point_line_distance(
See also: See also:
https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line
""" """
numerator = abs(line[:, 0] * point[:, 0] + line[:, 1] * point[:, 1] + line[:, 2]) numerator = abs(line[:, 0] * points[:, 0] + line[:, 1] * points[:, 1] + line[:, 2])
denominator = jnp.sqrt(line[:, 0] * line[:, 0] + line[:, 1] * line[:, 1]) denominator = jnp.sqrt(line[:, 0] * line[:, 0] + line[:, 1] * line[:, 1])
return numerator / (denominator + eps) return numerator / (denominator + eps)
@ -203,16 +204,16 @@ def distance_between_epipolar_lines(
) )
if x1.shape[-1] == 2: if x1.shape[-1] == 2:
point1 = to_homogeneous(x1) points1 = to_homogeneous(x1)
elif x1.shape[-1] == 3: elif x1.shape[-1] == 3:
point1 = x1 points1 = x1
else: else:
raise ValueError(f"Invalid shape for correspondence1: {x1.shape}") raise ValueError(f"Invalid shape for correspondence1: {x1.shape}")
if x2.shape[-1] == 2: if x2.shape[-1] == 2:
point2 = to_homogeneous(x2) points2 = to_homogeneous(x2)
elif x2.shape[-1] == 3: elif x2.shape[-1] == 3:
point2 = x2 points2 = x2
else: else:
raise ValueError(f"Invalid shape for correspondence2: {x2.shape}") raise ValueError(f"Invalid shape for correspondence2: {x2.shape}")
@ -223,10 +224,10 @@ def distance_between_epipolar_lines(
# points 1 and 2 are unnormalized points # points 1 and 2 are unnormalized points
dist_1 = jnp.mean( dist_1 = jnp.mean(
right_to_left_epipolar_distance(point1, point2, fundamental_matrix) right_to_left_epipolar_distance(points1, points2, fundamental_matrix)
) )
dist_2 = jnp.mean( dist_2 = jnp.mean(
left_to_right_epipolar_distance(point1, point2, fundamental_matrix) left_to_right_epipolar_distance(points1, points2, fundamental_matrix)
) )
distance = dist_1 + dist_2 distance = dist_1 + dist_2
return distance return distance
@ -279,19 +280,74 @@ def calculate_fundamental_matrix(
def compute_affinity_epipolar_constraint_with_pairs( def compute_affinity_epipolar_constraint_with_pairs(
left: Detection, right: Detection, alpha_2D: float left: Detection, right: Detection, alpha_2d: float
): ):
"""
Compute the affinity between two groups of detections by epipolar constraint,
where camera parameters are included in the detections.
Note:
Originally, alpha_2d comes from the paper as a scaling factor for epipolar error affinity.
Its role is mainly to normalize error into [0,1] range, but it could lead to negative affinity.
An alternative approach is using normalized epipolar error relative to image size, with soft cutoff,
like exp(-error / threshold), for better interpretability and stability.
"""
fundamental_matrix = calculate_fundamental_matrix(left.camera, right.camera) fundamental_matrix = calculate_fundamental_matrix(left.camera, right.camera)
d = distance_between_epipolar_lines( d = distance_between_epipolar_lines(
left.keypoints, right.keypoints, fundamental_matrix left.keypoints, right.keypoints, fundamental_matrix
) )
return 1 - (d / alpha_2D) return 1 - (d / alpha_2d)
def get_affinity_matrix_epipolar_constraint( def calculate_affinity_matrix_by_epipolar_constraint(
detections: list[Detection], detections: list[Detection] | dict[CameraID, list[Detection]],
alpha_2D: float, alpha_2d: float,
) -> Num[Array, "N N"]: ) -> tuple[list[Detection], Num[Array, "N N"]]:
"""
Calculate the affinity matrix by epipolar constraint
This function evaluates the geometric consistency of every pair of detections
across different cameras using the fundamental matrix. It assumes that
detections from the same camera are not comparable and should have zero affinity.
The affinity is computed by:
1. Calculating the fundamental matrix between the two cameras.
2. Measuring the average point-to-epipolar-line distance for all keypoints.
3. Mapping the distance to affinity with the formula: 1 - (distance / alpha_2d).
Args:
detections: Either a flat list of Detection or a dict grouping Detection by CameraID.
alpha_2d: Image resolution-dependent threshold controlling affinity scaling.
Typically relates to expected pixel displacement rate.
Returns:
sorted_detections: Flattened list of detections sorted by camera order.
affinity_matrix: Array of shape (N, N), where N is the number of detections.
Notes:
- Detections from the same camera always have affinity = 0.
- Affinity decays linearly with epipolar error until 0 (or potentially negative).
- Consider switching to exp(-error / scale) style for non-negative affinity.
- alpha_2d should be adjusted based on image resolution or empirical observation.
Affinity Matrix layout:
assuming we have 3 cameras
C0 has 3 detections: D0_C0, D1_C0, D2_C0
C1 has 2 detections: D0_C1, D1_C1
C2 has 2 detections: D0_C2, D1_C2
D0_C0(0), D1_C0(1), D2_C0(2), D0_C1(3), D1_C1(4), D0_C2(5), D1_C2(6)
D0_C0(0) 0 0 0 a_03 a_04 a_05 a_06
D1_C0(1) 0 0 0 a_13 a_14 a_15 a_16
...
D0_C1(3) a_30 a_31 a_32 0 0 a_35 a_36
...
D1_C2(6) a_60 a_61 a_62 a_63 a_64 0 0
"""
if isinstance(detections, dict):
camera_wise_split = detections
else:
camera_wise_split = classify_by_camera(detections) camera_wise_split = classify_by_camera(detections)
num_entries = sum(len(entries) for entries in camera_wise_split.values()) num_entries = sum(len(entries) for entries in camera_wise_split.values())
affinity_matrix = jnp.zeros((num_entries, num_entries), dtype=jnp.float32) affinity_matrix = jnp.zeros((num_entries, num_entries), dtype=jnp.float32)
@ -312,33 +368,20 @@ def get_affinity_matrix_epipolar_constraint(
total_indices - camera_id_index_map[camera_id] total_indices - camera_id_index_map[camera_id]
) )
# assuming we have 3 cameras
# C0 has 3 detections: D0_C0, D1_C0, D2_C0
# C1 has 2 detections: D0_C1, D1_C1
# C2 has 2 detections: D0_C2, D1_C2
#
# D0_C0(0), D1_C0(1), D2_C0(2), D0_C1(3), D1_C1(4), D0_C2(5), D1_C2(6)
# D0_C0(0) 0 0 0 a_03 a_04 a_05 a_06
# D1_C0(1) 0 0 0 a_13 a_14 a_15 a_16
# ...
# D0_C1(3) a_30 a_31 a_32 0 0 a_35 a_36
# ...
# D1_C2(6) a_60 a_61 a_62 a_63 a_64 0 0
# ignore self-affinity # ignore self-affinity
# ignore same-camera affinity # ignore same-camera affinity
# assuming commutative property of epipolar constraint # assuming commutative
for i, det in enumerate(sorted_detections): for i, det in enumerate(sorted_detections):
other_indices = camera_id_index_map_inverse[det.camera.id] other_indices = camera_id_index_map_inverse[det.camera.id]
for j in other_indices: for j in other_indices:
if affinity_matrix_mask[i, j] or affinity_matrix_mask[j, i]: if affinity_matrix_mask[i, j] or affinity_matrix_mask[j, i]:
continue continue
a = compute_affinity_epipolar_constraint_with_pairs( a = compute_affinity_epipolar_constraint_with_pairs(
det, sorted_detections[j], alpha_2D det, sorted_detections[j], alpha_2d
) )
affinity_matrix = affinity_matrix.at[i, j].set(a) affinity_matrix = affinity_matrix.at[i, j].set(a)
affinity_matrix = affinity_matrix.at[j, i].set(a) affinity_matrix = affinity_matrix.at[j, i].set(a)
affinity_matrix_mask = affinity_matrix_mask.at[i, j].set(True) affinity_matrix_mask = affinity_matrix_mask.at[i, j].set(True)
affinity_matrix_mask = affinity_matrix_mask.at[j, i].set(True) affinity_matrix_mask = affinity_matrix_mask.at[j, i].set(True)
return affinity_matrix return sorted_detections, affinity_matrix