from collections import OrderedDict, defaultdict
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional, TypeAlias, TypedDict

import jax
import numpy as np
from beartype import beartype
from cv2 import undistortPoints
from jax import numpy as jnp
from jaxtyping import Array, Num, jaxtyped

NDArray: TypeAlias = np.ndarray
CameraID: TypeAlias = str  # pylint: disable=invalid-name


@jaxtyped(typechecker=beartype)
def undistort_points(
    points: Num[NDArray, "M 2"],
    camera_matrix: Num[NDArray, "3 3"],
    dist_coeffs: Num[NDArray, "N"],
) -> Num[NDArray, "M 2"]:
    """
    A thin wrapper of `cv2.undistortPoints`.

    The undistorted points are re-projected with ``P=camera_matrix``, so the
    result stays in pixel coordinates (not normalized camera coordinates).

    Args:
        points: distorted 2D points in pixel coordinates
        camera_matrix: camera intrinsic matrix
        dist_coeffs: distortion coefficients, in OpenCV order

    Returns:
        Undistorted 2D points in pixel coordinates
    """
    K = camera_matrix  # pylint: disable=invalid-name
    dist = dist_coeffs
    pts = np.asarray(points)
    # OpenCV only accepts float32/float64 point arrays; integer pixel
    # coordinates (allowed by the `Num` annotation) would be rejected.
    if pts.dtype not in (np.float32, np.float64):
        pts = pts.astype(np.float64)
    res = undistortPoints(pts, K, dist, P=K)  # type: ignore
    return res.reshape(-1, 2)


@jax.jit
@jaxtyped(typechecker=beartype)
def distortion(
    points_2d: Num[Array, "N 2"],
    K: Num[Array, "3 3"],  # pylint: disable=invalid-name
    dist_coeffs: Num[Array, "5"],
) -> Num[Array, "N 2"]:
    """
    Apply lens distortion to 2D points in pixel coordinates.

    Args:
        points_2d: 2D points in pixel coordinates
        K: Camera intrinsic matrix
        dist_coeffs: Distortion coefficients [k1, k2, p1, p2, k3]

    Returns:
        Distorted 2D points in pixel coordinates

    Note:
        The function converts between pixel coordinates and normalized
        coordinates internally. It expects points_2d in pixel coordinates
        and returns distorted points in pixel coordinates.

        Implementation based on OpenCV's distortion model:
        https://docs.opencv.org/4.10.0/d9/d0c/group__calib3d.html#ga69f2545a8b62a6b0fc2ee060dc30559d
    """
    # unpack
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]
    k1, k2, p1, p2, k3 = dist_coeffs

    # normalize
    x = (points_2d[:, 0] - cx) / fx
    y = (points_2d[:, 1] - cy) / fy

    # precompute r^2, r^4, r^6
    r2 = x * x + y * y
    r4 = r2 * r2
    r6 = r4 * r2

    # radial term
    radial = 1 + k1 * r2 + k2 * r4 + k3 * r6

    # tangential term
    x_tan = 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
    y_tan = p1 * (r2 + 2 * y * y) + 2 * p2 * x * y

    # apply both
    x_dist = x * radial + x_tan
    y_dist = y * radial + y_tan

    # back to pixels
    u = x_dist * fx + cx
    v = y_dist * fy + cy

    return jnp.stack([u, v], axis=1)


@jaxtyped(typechecker=beartype)
def project(
    points_3d: Num[Array, "N 3"],
    projection_matrix: Num[Array, "3 4"],
    K: Optional[Num[Array, "3 3"]] = None,  # pylint: disable=invalid-name
    dist_coeffs: Optional[Num[Array, "5"]] = None,
    image_size: Optional[Num[Array, "2"]] = None,
) -> Num[Array, "N 2"]:
    """
    Project 3D points to 2D points in pixel coordinates.

    Args:
        points_3d: 3D points in world coordinates
        projection_matrix: pre-computed projection matrix (K @ Rt[:3, :]) that
            projects to pixel coordinates
        K: (optional) Camera intrinsic matrix, unnormalized
        dist_coeffs: (optional) Distortion coefficients
        image_size: (optional) Image dimensions [width, height] used to decide
            which projected points lie inside the image and get distorted.
            If not provided, the (0, 1) range is used instead. Because the
            projection is in pixel coordinates, that fallback only keeps
            points inside the unit pixel square, so pass image_size whenever
            distortion matters.

    Note:
        K and dist_coeffs must be provided together, or both be None.
        If K is provided, the projection matrix is assumed to have been
        computed from the same K.
        Only points inside the image bounds are distorted. Points outside
        (or with non-finite coordinates) keep their ideal projection.

    Returns:
        2D points in pixel coordinates, always of shape (N, 2) (also for N == 1)

    Raises:
        ValueError: if only one of K / dist_coeffs is given
    """
    if (K is None) != (dist_coeffs is None):
        raise ValueError(
            "dist_coeffs and K must be provided together to compute distortion"
        )

    P = projection_matrix  # pylint: disable=invalid-name
    ones = jnp.ones((points_3d.shape[0], 1), dtype=points_3d.dtype)
    p3d_homogeneous = jnp.hstack((points_3d, ones))

    # Project points, then perspective division
    p2d_homogeneous = p3d_homogeneous @ P.T
    p2d = p2d_homogeneous[:, 0:2] / p2d_homogeneous[:, 2:3]

    if K is None or dist_coeffs is None:
        return p2d

    # Which points lie within the image boundaries
    if image_size is not None:
        valid = (
            jnp.all(p2d >= 0, axis=1)
            & (p2d[:, 0] < image_size[0])
            & (p2d[:, 1] < image_size[1])
        )
    else:
        valid = jnp.all(p2d >= 0, axis=1) & jnp.all(p2d < 1, axis=1)

    # Distort every point, but only keep the result for valid ones. A mask
    # (instead of boolean indexing) keeps shapes static, so the jitted
    # `distortion` is not retraced for every distinct number of valid points.
    distorted = distortion(p2d, K, dist_coeffs).astype(p2d.dtype)
    return jnp.where(valid[:, None], distorted, p2d)


@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class CameraParams:
    """
    Camera parameters: intrinsic matrix, extrinsic matrix, and distortion
    coefficients.
    """

    K: Num[Array, "3 3"]
    """
    intrinsic matrix
    """

    Rt: Num[Array, "4 4"]
    """
    [R|t] extrinsic matrix, in homogeneous (4x4) form

    R and t are the rotation and translation that describe the change of
    coordinates from the world to the camera coordinate system (camera frame).

    Rt is expected to be the World-to-Camera (W2C) transformation matrix,
    i.e. the result of `solvePnP` in OpenCV (converted to homogeneous
    coordinates).

    World-to-Camera (W2C): Transforms points from world coordinates to camera coordinates
    - The world origin is transformed to camera space
    - Used for projecting 3D world points onto the camera's image plane
    - Required for rendering/projection
    """

    dist_coeffs: Num[Array, "5"]
    """
    Distortion coefficients [k1, k2, p1, p2, k3] (exactly five values),
    where ki is the ith radial distortion coefficient and pi is the ith
    tangential distortion coefficient.
    """

    image_size: Num[Array, "2"]
    """
    The size of the image plane (width, height)
    """

    @property
    def pose_matrix(self) -> Num[Array, "4 4"]:
        """
        The inverse of the extrinsic matrix, i.e. the Camera-to-World (C2W)
        transformation matrix.

        Camera-to-World (C2W): Transforms points from camera coordinates to world coordinates
        - The camera is the origin in camera space
        - This transformation tells where the camera is positioned in world space
        - Often used for camera positioning/orientation

        The result is cached on first access (lazy evaluation).
        """
        t = getattr(self, "_pose", None)
        if t is None:
            t = jnp.linalg.inv(self.Rt)
            # object.__setattr__ bypasses the frozen-dataclass blocker
            object.__setattr__(self, "_pose", t)
        return t

    @property
    def projection_matrix(self) -> Num[Array, "3 4"]:
        """
        The 3x4 projection matrix K @ [R|t].

        The result is cached on first access (lazy evaluation).
        """
        pm = getattr(self, "_proj", None)
        if pm is None:
            pm = self.K @ self.Rt[:3, :]
            # object.__setattr__ bypasses the frozen-dataclass blocker
            object.__setattr__(self, "_proj", pm)
        return pm


@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Camera:
    """
    A description of a camera: its identifier and its parameters.
    """

    id: CameraID
    """
    Camera ID
    """

    params: CameraParams
    """
    Camera parameters
    """

    def project(self, points_3d: Num[Array, "N 3"]) -> Num[Array, "N 2"]:
        """
        Project 3D points to 2D points using this camera's parameters,
        including lens distortion.

        Args:
            points_3d: 3D points in world coordinates

        Returns:
            2D points in pixel coordinates
        """
        return project(
            points_3d,
            projection_matrix=self.params.projection_matrix,
            K=self.params.K,
            dist_coeffs=self.params.dist_coeffs,
            image_size=self.params.image_size,
        )

    def project_ideal(self, points_3d: Num[Array, "N 3"]) -> Num[Array, "N 2"]:
        """
        Project 3D points to 2D points using this camera's parameters,
        without distortion.

        Args:
            points_3d: 3D points in world coordinates

        Returns:
            2D points in pixel coordinates
        """
        return project(
            points_3d,
            projection_matrix=self.params.projection_matrix,
            image_size=self.params.image_size,
        )

    def distortion(self, points_2d: Num[Array, "N 2"]) -> Num[Array, "N 2"]:
        """
        Apply distortion to 2D points using this camera's parameters.

        Args:
            points_2d: 2D points in pixel coordinates

        Returns:
            Distorted 2D points in pixel coordinates
        """
        # `distortion` here resolves to the module-level function
        return distortion(
            points_2d=points_2d,
            K=self.params.K,
            dist_coeffs=self.params.dist_coeffs,
        )


@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Detection:
    """
    One detection from a camera.
    """

    keypoints: Num[Array, "N 2"]
    """
    Keypoints in pixel coordinates (with camera distortion).

    Use `keypoints_undistorted` to get undistorted keypoints.
    """

    confidences: Num[Array, "N"]
    """
    Per-keypoint confidences
    """

    camera: Camera
    """
    Camera the detection comes from
    """

    timestamp: datetime
    """
    Timestamp of the detection
    """

    @property
    def keypoints_undistorted(self) -> Num[Array, "N 2"]:
        """
        Undistorted keypoints in pixel coordinates (computed with OpenCV).

        The result is cached on first access (lazy evaluation).
        """
        kpu = getattr(self, "_kp_undistorted", None)
        if kpu is None:
            kpu_np = undistort_points(
                np.asarray(self.keypoints),
                np.asarray(self.camera.params.K),
                np.asarray(self.camera.params.dist_coeffs),
            )
            kpu = jnp.asarray(kpu_np)
            # object.__setattr__ bypasses the frozen-dataclass blocker
            object.__setattr__(self, "_kp_undistorted", kpu)
        return kpu


def classify_by_camera(
    detections: list[Detection],
) -> OrderedDict[CameraID, list[Detection]]:
    """
    Group detections by camera ID.

    Camera groups appear in order of first occurrence, and detections keep
    their input order within each group.
    """
    camera_wise_split: dict[CameraID, list[Detection]] = defaultdict(list)
    for entry in detections:
        camera_wise_split[entry.camera.id].append(entry)
    return OrderedDict(camera_wise_split)


@jaxtyped(typechecker=beartype)
def to_homogeneous(points: Num[Array, "N 2"] | Num[Array, "N 3"]) -> Num[Array, "N 3"]:
    """
    Convert points to homogeneous coordinates.

    2-column points get a column of ones appended. 3-column points are
    returned unchanged.

    Raises:
        ValueError: if points have neither 2 nor 3 columns
    """
    if points.shape[-1] == 2:
        return jnp.hstack((points, jnp.ones((points.shape[0], 1))))
    if points.shape[-1] == 3:
        return points
    raise ValueError(f"Invalid shape for points: {points.shape}")


@jaxtyped(typechecker=beartype)
def point_line_distance(
    points: Num[Array, "N 3"] | Num[Array, "N 2"],
    line: Num[Array, "N 3"],
    eps: float = 1e-9,
) -> Num[Array, "N"]:
    """
    Calculate the distance from each point to its corresponding line.

    Args:
        points: (possibly homogeneous) points :math:`(N, 2 or 3)`.
            Only the first two columns are read, so homogeneous points
            must have w = 1 (as produced by `to_homogeneous`).
        line: line coefficients :math:`(a, b, c)` with shape :math:`(N, 3)`,
            where :math:`ax + by + c = 0`.
        eps: Small constant added to the denominator to avoid division by zero.

    Returns:
        The computed distance with shape :math:`(N)`.

    See also:
        https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line
    """
    numerator = jnp.abs(
        line[:, 0] * points[:, 0] + line[:, 1] * points[:, 1] + line[:, 2]
    )
    denominator = jnp.sqrt(line[:, 0] * line[:, 0] + line[:, 1] * line[:, 1])
    return numerator / (denominator + eps)


@jaxtyped(typechecker=beartype)
def left_to_right_epipolar_distance(
    left: Num[Array, "N 3"],
    right: Num[Array, "N 3"],
    fundamental_matrix: Num[Array, "3 3"],
) -> Num[Array, "N"]:
    """
    One-sided epipolar distance: distance of each right point to the
    epipolar line F @ left in the right image.

    Args:
        left: points in the left image (homogeneous) :math:`(N, 3)`
        right: points in the right image (homogeneous) :math:`(N, 3)`
        fundamental_matrix: fundamental matrix :math:`(3, 3)` satisfying
            right^T F left = 0

    Returns:
        The computed distance with shape :math:`(N)`.

    See also:
        https://en.wikipedia.org/wiki/Fundamental_matrix_%28computer_vision%29

        $$x^{\\prime T}Fx = 0$$
    """
    # row-wise (F @ x)^T == x^T @ F^T
    line1_in_2 = jnp.matmul(left, fundamental_matrix.transpose())
    return point_line_distance(right, line1_in_2)


@jaxtyped(typechecker=beartype)
def right_to_left_epipolar_distance(
    left: Num[Array, "N 3"],
    right: Num[Array, "N 3"],
    fundamental_matrix: Num[Array, "3 3"],
) -> Num[Array, "N"]:
    """
    One-sided epipolar distance: distance of each left point to the
    epipolar line F^T @ right in the left image.

    Args:
        left: points in the left image (homogeneous) :math:`(N, 3)`
        right: points in the right image (homogeneous) :math:`(N, 3)`
        fundamental_matrix: fundamental matrix :math:`(3, 3)` satisfying
            right^T F left = 0

    Returns:
        The computed distance with shape :math:`(N)`.

    See also:
        https://en.wikipedia.org/wiki/Fundamental_matrix_%28computer_vision%29

        $$x^{\\prime T}Fx = 0$$
    """
    # row-wise (F^T @ x')^T == x'^T @ F
    line2_in_1 = jnp.matmul(right, fundamental_matrix)
    return point_line_distance(left, line2_in_1)


def distance_between_epipolar_lines(
    x1: Num[Array, "N 2"] | Num[Array, "N 3"],
    x2: Num[Array, "N 2"] | Num[Array, "N 3"],
    fundamental_matrix: Num[Array, "3 3"],
):
    """
    Calculate the symmetric epipolar distance between corresponding points.

    The result is the mean left-image distance plus the mean right-image
    distance, where x1 lies in the left image, x2 in the right image, and
    x2^T F x1 = 0.

    Raises:
        ValueError: on mismatched point counts or invalid shapes
    """
    if x1.shape[0] != x2.shape[0]:
        raise ValueError(
            f"x1 and x2 must have the same number of points: {x1.shape[0]} != {x2.shape[0]}"
        )
    if x1.shape[-1] == 2:
        points1 = to_homogeneous(x1)
    elif x1.shape[-1] == 3:
        points1 = x1
    else:
        raise ValueError(f"Invalid shape for correspondence1: {x1.shape}")

    if x2.shape[-1] == 2:
        points2 = to_homogeneous(x2)
    elif x2.shape[-1] == 3:
        points2 = x2
    else:
        raise ValueError(f"Invalid shape for correspondence2: {x2.shape}")

    if fundamental_matrix.shape != (3, 3):
        raise ValueError(
            f"Invalid shape for fundamental_matrix: {fundamental_matrix.shape}"
        )

    # points 1 and 2 are unnormalized (pixel) points
    dist_1 = jnp.mean(
        right_to_left_epipolar_distance(points1, points2, fundamental_matrix)
    )
    dist_2 = jnp.mean(
        left_to_right_epipolar_distance(points1, points2, fundamental_matrix)
    )
    return dist_1 + dist_2


@jaxtyped(typechecker=beartype)
def calculate_fundamental_matrix(
    camera_left: Camera, camera_right: Camera
) -> Num[Array, "3 3"]:
    """
    Calculate the fundamental matrix F for the given pair of cameras.

    F maps pixel coordinates so that x_right^T F x_left = 0 for the
    projections of the same 3D point.
    Both cameras' `Rt` are World-to-Camera transforms.
    """
    # Intrinsics
    K1 = camera_left.params.K  # pylint: disable=invalid-name
    K2 = camera_right.params.K  # pylint: disable=invalid-name

    # Extrinsics (World-to-Camera transforms)
    Rt1 = camera_left.params.Rt  # pylint: disable=invalid-name
    Rt2 = camera_right.params.Rt  # pylint: disable=invalid-name

    # Relative transform from left camera to right camera:
    # X_right = Rt2 @ X_world = Rt2 @ inv(Rt1) @ X_left
    T_rel = Rt2 @ jnp.linalg.inv(Rt1)  # pylint: disable=invalid-name
    R = T_rel[:3, :3]  # pylint: disable=invalid-name
    t = T_rel[:3, 3:]

    # Skew-symmetric matrix for cross product
    def skew(v: Num[Array, "3"]) -> Num[Array, "3 3"]:
        return jnp.array(
            [
                [0, -v[2], v[1]],
                [v[2], 0, -v[0]],
                [-v[1], v[0], 0],
            ]
        )

    # Essential matrix E = [t]_x R, then F = K2^-T E K1^-1
    E = skew(t.reshape(-1)) @ R  # pylint: disable=invalid-name
    return jnp.linalg.inv(K2).T @ E @ jnp.linalg.inv(K1)


def compute_affinity_epipolar_constraint_with_pairs(
    left: Detection, right: Detection, alpha_2d: float
):
    """
    Compute the affinity between two detections by the epipolar constraint.
    Camera parameters are taken from the detections.

    Returns ``1 - d / alpha_2d``, where d is the symmetric epipolar distance
    (`distance_between_epipolar_lines`) between the keypoints.

    Note:
        Originally, alpha_2d comes from the paper as a scaling factor for
        epipolar error affinity. It mainly normalizes the error into the
        [0, 1] range, but it can lead to negative affinity. An alternative is
        a normalized epipolar error relative to image size with a soft
        cutoff, like exp(-error / threshold), for better interpretability and
        stability.
    """
    fundamental_matrix = calculate_fundamental_matrix(left.camera, right.camera)
    d = distance_between_epipolar_lines(
        left.keypoints, right.keypoints, fundamental_matrix
    )
    return 1 - (d / alpha_2d)


def calculate_affinity_matrix_by_epipolar_constraint(
    detections: list[Detection] | dict[CameraID, list[Detection]],
    alpha_2d: float,
) -> tuple[list[Detection], Num[Array, "N N"]]:
    """
    Calculate the affinity matrix by epipolar constraint.

    This function evaluates the geometric consistency of every pair of
    detections across different cameras using the fundamental matrix.
    Detections from the same camera are not comparable and get affinity
    -inf.

    The affinity is computed by:
    1. Calculating the fundamental matrix between the two cameras.
    2. Measuring the average point-to-epipolar-line distance for all keypoints.
    3. Mapping the distance to affinity with the formula: 1 - (distance / alpha_2d).

    Args:
        detections: Either a flat list of Detection or a dict grouping
            Detection by CameraID.
        alpha_2d: Image-resolution-dependent threshold controlling affinity
            scaling. Typically relates to the expected pixel displacement rate.

    Returns:
        sorted_detections: Flattened list of detections sorted by camera order.
        affinity_matrix: float32 array of shape (N, N), where N is the number
            of detections. It is symmetric.

    Notes:
        - Self-pairs and detections from the same camera always have
          affinity = -inf.
        - Affinity decays linearly with epipolar error and may become negative.
        - Consider switching to exp(-error / scale) style for non-negative affinity.
        - alpha_2d should be adjusted based on image resolution or empirical observation.

    Affinity Matrix layout:
        assuming we have 3 cameras
        C0 has 3 detections: D0_C0, D1_C0, D2_C0
        C1 has 2 detections: D0_C1, D1_C1
        C2 has 2 detections: D0_C2, D1_C2

                 D0_C0(0), D1_C0(1), D2_C0(2), D0_C1(3), D1_C1(4), D0_C2(5), D1_C2(6)
        D0_C0(0) -inf      -inf      -inf      a_03      a_04      a_05      a_06
        D1_C0(1) -inf      -inf      -inf      a_13      a_14      a_15      a_16
        ...
        D0_C1(3) a_30      a_31      a_32      -inf      -inf      a_35      a_36
        ...
        D1_C2(6) a_60      a_61      a_62      a_63      a_64      -inf      -inf
    """
    if isinstance(detections, dict):
        camera_wise_split = detections
    else:
        camera_wise_split = classify_by_camera(detections)

    # sorted by [D0_C0, D1_C0, D2_C0, D0_C1, D1_C1, D0_C2, D1_C2, ...]
    sorted_detections: list[Detection] = [
        det for entries in camera_wise_split.values() for det in entries
    ]
    num_entries = len(sorted_detections)

    # Accumulate on the host: updating a JAX array per pair would copy the
    # whole N x N matrix for every single entry.
    affinity = np.full((num_entries, num_entries), -np.inf, dtype=np.float32)
    for i, left in enumerate(sorted_detections):
        # affinity is symmetric, so only evaluate each unordered pair once
        for j in range(i + 1, num_entries):
            right = sorted_detections[j]
            if left.camera.id == right.camera.id:
                continue  # same-camera pairs are not comparable
            a = compute_affinity_epipolar_constraint_with_pairs(left, right, alpha_2d)
            affinity[i, j] = affinity[j, i] = float(a)

    return sorted_detections, jnp.asarray(affinity)