From 3f32333de45c370c186d8a62eb7f789e3615a6df Mon Sep 17 00:00:00 2001 From: crosstyan Date: Tue, 15 Apr 2025 11:06:34 +0800 Subject: [PATCH] refactor: Update camera module for improved type handling and utility functions - Reorganized imports for better clarity and consistency. - Renamed variables in distance calculation functions for improved readability. - Enhanced `compute_affinity_epipolar_constraint_with_pairs` function with detailed docstring explaining its purpose and parameters. - Updated function signature to accept both list and dictionary formats for detections, improving flexibility. - Renamed the `alpha_2D` parameter to `alpha_2d` for consistent naming; the affinity calculation itself is unchanged. - Renamed `get_affinity_matrix_epipolar_constraint` to `calculate_affinity_matrix_by_epipolar_constraint`; it now returns `(sorted_detections, affinity_matrix)` instead of only the matrix, so existing callers must be updated. --- app/camera/__init__.py | 123 +++++++++++++++++++++++++++-------------- 1 file changed, 83 insertions(+), 40 deletions(-) diff --git a/app/camera/__init__.py b/app/camera/__init__.py index 4b10eb9..66717fa 100644 --- a/app/camera/__init__.py +++ b/app/camera/__init__.py @@ -1,14 +1,15 @@ -from typing import TypedDict, TypeAlias, Any +from collections import OrderedDict, defaultdict +from dataclasses import dataclass +from datetime import datetime +from typing import Any, TypeAlias, TypedDict + +from beartype import beartype +from jax import Array +from jax import numpy as jnp +from jaxtyping import Num, jaxtyped from typing_extensions import NotRequired -from jaxtyping import Num, jaxtyped -from beartype import beartype -from jax import numpy as jnp, Array -from dataclasses import dataclass -from collections import defaultdict, OrderedDict -from datetime import datetime - -CameraID: TypeAlias = str +CameraID: TypeAlias = str # pylint: disable=invalid-name @jaxtyped(typechecker=beartype) @@ -113,7 +114,7 @@ def to_homogeneous(points: Num[Array, "N 2"] | Num[Array, "N 3"]) -> Num[Array, @jaxtyped(typechecker=beartype) def point_line_distance( - point: Num[Array, "N 3"] | Num[Array, "N 2"], + points: Num[Array, "N 3"] | Num[Array, "N 2"], line: Num[Array, "N 3"], eps: float = 
1e-9, ): @@ -131,7 +132,7 @@ def point_line_distance( See also: https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line """ - numerator = abs(line[:, 0] * point[:, 0] + line[:, 1] * point[:, 1] + line[:, 2]) + numerator = abs(line[:, 0] * points[:, 0] + line[:, 1] * points[:, 1] + line[:, 2]) denominator = jnp.sqrt(line[:, 0] * line[:, 0] + line[:, 1] * line[:, 1]) return numerator / (denominator + eps) @@ -203,16 +204,16 @@ def distance_between_epipolar_lines( ) if x1.shape[-1] == 2: - point1 = to_homogeneous(x1) + points1 = to_homogeneous(x1) elif x1.shape[-1] == 3: - point1 = x1 + points1 = x1 else: raise ValueError(f"Invalid shape for correspondence1: {x1.shape}") if x2.shape[-1] == 2: - point2 = to_homogeneous(x2) + points2 = to_homogeneous(x2) elif x2.shape[-1] == 3: - point2 = x2 + points2 = x2 else: raise ValueError(f"Invalid shape for correspondence2: {x2.shape}") @@ -223,10 +224,10 @@ def distance_between_epipolar_lines( # points 1 and 2 are unnormalized points dist_1 = jnp.mean( - right_to_left_epipolar_distance(point1, point2, fundamental_matrix) + right_to_left_epipolar_distance(points1, points2, fundamental_matrix) ) dist_2 = jnp.mean( - left_to_right_epipolar_distance(point1, point2, fundamental_matrix) + left_to_right_epipolar_distance(points1, points2, fundamental_matrix) ) distance = dist_1 + dist_2 return distance @@ -279,20 +280,75 @@ def calculate_fundamental_matrix( def compute_affinity_epipolar_constraint_with_pairs( - left: Detection, right: Detection, alpha_2D: float + left: Detection, right: Detection, alpha_2d: float ): + """ + Compute the affinity between two groups of detections by epipolar constraint, + where camera parameters are included in the detections. + + Note: + Originally, alpha_2d comes from the paper as a scaling factor for epipolar error affinity. + Its role is mainly to normalize error into [0,1] range, but it could lead to negative affinity. 
+ An alternative approach is using normalized epipolar error relative to image size, with soft cutoff, + like exp(-error / threshold), for better interpretability and stability. + """ fundamental_matrix = calculate_fundamental_matrix(left.camera, right.camera) d = distance_between_epipolar_lines( left.keypoints, right.keypoints, fundamental_matrix ) - return 1 - (d / alpha_2D) + return 1 - (d / alpha_2d) -def get_affinity_matrix_epipolar_constraint( - detections: list[Detection], - alpha_2D: float, -) -> Num[Array, "N N"]: - camera_wise_split = classify_by_camera(detections) +def calculate_affinity_matrix_by_epipolar_constraint( + detections: list[Detection] | dict[CameraID, list[Detection]], + alpha_2d: float, +) -> tuple[list[Detection], Num[Array, "N N"]]: + """ + Calculate the affinity matrix by epipolar constraint + + This function evaluates the geometric consistency of every pair of detections + across different cameras using the fundamental matrix. It assumes that + detections from the same camera are not comparable and should have zero affinity. + + The affinity is computed by: + 1. Calculating the fundamental matrix between the two cameras. + 2. Measuring the average point-to-epipolar-line distance for all keypoints. + 3. Mapping the distance to affinity with the formula: 1 - (distance / alpha_2d). + + Args: + detections: Either a flat list of Detection or a dict grouping Detection by CameraID. + alpha_2d: Image resolution-dependent threshold controlling affinity scaling. + Typically relates to expected pixel displacement rate. + + Returns: + sorted_detections: Flattened list of detections sorted by camera order. + affinity_matrix: Array of shape (N, N), where N is the number of detections. + + Notes: + - Detections from the same camera always have affinity = 0. + - Affinity decays linearly with epipolar error until 0 (or potentially negative). + - Consider switching to exp(-error / scale) style for non-negative affinity. 
+ - alpha_2d should be adjusted based on image resolution or empirical observation. + + + Affinity Matrix layout: + assuming we have 3 cameras + C0 has 3 detections: D0_C0, D1_C0, D2_C0 + C1 has 2 detections: D0_C1, D1_C1 + C2 has 2 detections: D0_C2, D1_C2 + + D0_C0(0), D1_C0(1), D2_C0(2), D0_C1(3), D1_C1(4), D0_C2(5), D1_C2(6) + D0_C0(0) 0 0 0 a_03 a_04 a_05 a_06 + D1_C0(1) 0 0 0 a_13 a_14 a_15 a_16 + ... + D0_C1(3) a_30 a_31 a_32 0 0 a_35 a_36 + ... + D1_C2(6) a_60 a_61 a_62 a_63 a_64 0 0 + """ + if isinstance(detections, dict): + camera_wise_split = detections + else: + camera_wise_split = classify_by_camera(detections) num_entries = sum(len(entries) for entries in camera_wise_split.values()) affinity_matrix = jnp.zeros((num_entries, num_entries), dtype=jnp.float32) affinity_matrix_mask = jnp.zeros((num_entries, num_entries), dtype=jnp.bool_) @@ -312,33 +368,20 @@ def get_affinity_matrix_epipolar_constraint( total_indices - camera_id_index_map[camera_id] ) - # assuming we have 3 cameras - # C0 has 3 detections: D0_C0, D1_C0, D2_C0 - # C1 has 2 detections: D0_C1, D1_C1 - # C2 has 2 detections: D0_C2, D1_C2 - # - # D0_C0(0), D1_C0(1), D2_C0(2), D0_C1(3), D1_C1(4), D0_C2(5), D1_C2(6) - # D0_C0(0) 0 0 0 a_03 a_04 a_05 a_06 - # D1_C0(1) 0 0 0 a_13 a_14 a_15 a_16 - # ... - # D0_C1(3) a_30 a_31 a_32 0 0 a_35 a_36 - # ... 
- # D1_C2(6) a_60 a_61 a_62 a_63 a_64 0 0 - # ignore self-affinity # ignore same-camera affinity - # assuming commutative property of epipolar constraint + # assuming the affinity is symmetric (a_ij == a_ji), so each pair is computed once for i, det in enumerate(sorted_detections): other_indices = camera_id_index_map_inverse[det.camera.id] for j in other_indices: if affinity_matrix_mask[i, j] or affinity_matrix_mask[j, i]: continue a = compute_affinity_epipolar_constraint_with_pairs( - det, sorted_detections[j], alpha_2D + det, sorted_detections[j], alpha_2d ) affinity_matrix = affinity_matrix.at[i, j].set(a) affinity_matrix = affinity_matrix.at[j, i].set(a) affinity_matrix_mask = affinity_matrix_mask.at[i, j].set(True) affinity_matrix_mask = affinity_matrix_mask.at[j, i].set(True) - return affinity_matrix + return sorted_detections, affinity_matrix