import numpy as np
import open3d as o3d
from typing import Dict, List, Optional, Tuple, Any, TYPE_CHECKING
from dataclasses import dataclass, field
from jaxtyping import Float
from scipy.spatial.transform import Rotation
from loguru import logger

from .pose_math import invert_transform

if TYPE_CHECKING:
    Vec3 = Float[np.ndarray, "3"]
    Mat44 = Float[np.ndarray, "4 4"]
    PointsNC = Float[np.ndarray, "N 3"]
else:
    Vec3 = np.ndarray
    Mat44 = np.ndarray
    PointsNC = np.ndarray


@dataclass
class ICPConfig:
    """Configuration for ICP registration."""

    voxel_size: float = 0.02  # Base voxel size in meters
    max_iterations: list[int] = field(default_factory=lambda: [50, 30, 14])
    method: str = "point_to_plane"  # "point_to_plane" or "gicp"
    band_height: float = 0.3  # Near-floor band height in meters
    min_fitness: float = 0.15  # Min ICP fitness to accept pair
    min_overlap_area: float = 0.5  # Min XZ overlap area in m^2
    overlap_margin: float = 0.5  # Inflate bboxes by this margin (m)
    overlap_mode: str = "xz"  # 'xz' or '3d'
    gravity_penalty_weight: float = 2.0  # Soft constraint on pitch/roll
    max_correspondence_distance_factor: float = 2.5
    max_rotation_deg: float = 10.0  # Safety bound on ICP delta
    max_translation_m: float = 0.3  # Safety bound on ICP delta
    max_pair_rotation_deg: float = 5.0  # Plausibility gate for pairwise updates
    max_pair_translation_m: float = 0.5  # Plausibility gate for pairwise updates
    max_final_rotation_deg: float = (
        15.0  # Stricter post-opt gate for final camera update
    )
    max_final_translation_m: float = (
        1.0  # Stricter post-opt gate for final camera update
    )
    region: str = "floor"  # "floor", "hybrid", or "full"
    robust_kernel: str = "none"  # "none" or "tukey"
    robust_kernel_k: float = 0.1
    global_init: bool = False  # Enable FPFH+RANSAC global pre-alignment


@dataclass
class ICPResult:
    """Result of a pairwise ICP registration."""

    transformation: Mat44  # 4x4
    fitness: float
    inlier_rmse: float
    information_matrix: np.ndarray  # 6x6
    converged: bool


@dataclass
class ICPMetrics:
    """Metrics for the global ICP refinement process."""

    success: bool = False
    num_pairs_attempted: int = 0
    num_pairs_converged: int = 0
    num_cameras_optimized: int = 0
    num_disconnected: int = 0
    per_pair_results: dict[tuple[str, str], ICPResult] = field(default_factory=dict)
    reference_camera: str = ""
    message: str = ""


def preprocess_point_cloud(
    pcd: o3d.geometry.PointCloud,
    voxel_size: float,
) -> o3d.geometry.PointCloud:
    """
    Preprocess point cloud: downsample and remove outliers.
    """
    pcd_down = pcd.voxel_down_sample(voxel_size)
    # SOR: nb_neighbors=20, std_ratio=2.0
    pcd_clean, _ = pcd_down.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
    return pcd_clean


def extract_scene_points(
    points_world: PointsNC,
    floor_y: float,
    floor_normal: Vec3,
    mode: str = "floor",
    band_height: float = 0.3,
) -> PointsNC:
    """
    Extract points based on mode:
    - 'floor': points within band_height of the floor
    - 'hybrid': floor points + vertical structures (walls/pillars)
    - 'full': all points
    """
    if len(points_world) == 0:
        return points_world

    if mode == "full":
        return points_world

    if mode == "floor":
        return extract_near_floor_band(points_world, floor_y, band_height, floor_normal)

    if mode == "hybrid":
        # 1. Get floor points
        floor_pts = extract_near_floor_band(
            points_world, floor_y, band_height, floor_normal
        )

        # 2. Get vertical points. Normals are needed for this, so build a temporary PCD.
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points_world)
        # Estimate normals using a hybrid KD-tree search
        pcd.estimate_normals(
            search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30)
        )
        normals = np.asarray(pcd.normals)
        if len(normals) != len(points_world):
            logger.warning(
                "Normal estimation failed, falling back to floor points only"
            )
            return floor_pts

        # Vertical surfaces have normals perpendicular to the floor normal,
        # so their dot product with it is near 0.
        dots = np.abs(normals @ floor_normal)
        # Threshold 0.3 allows for some noise/slope (approx 72-108 degrees from floor normal)
        vertical_mask = dots < 0.3
        vertical_pts = points_world[vertical_mask]

        if len(vertical_pts) == 0:
            logger.warning(
                "No vertical structure found in hybrid mode, falling back to floor points"
            )
            return floor_pts

        # The floor band (spatial) and the vertical selection (orientation) can overlap,
        # so concatenating would duplicate points. Combine them as a boolean union instead,
        # re-computing the floor mask so both masks index the same array.
        projections = points_world @ floor_normal
        floor_mask = (projections >= floor_y) & (projections <= floor_y + band_height)
        combined_mask = floor_mask | vertical_mask
        return points_world[combined_mask]

    # Default fallback
    return extract_near_floor_band(points_world, floor_y, band_height, floor_normal)

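# Illustrative sketch (not used by the pipeline): how 'hybrid' mode combines the spatial
# floor-band mask with the orientation-based vertical mask. The y-up world, the synthetic
# points, and the hand-written normals below are assumptions made purely for illustration.
def _hybrid_mask_sketch() -> np.ndarray:
    floor_normal = np.array([0.0, 1.0, 0.0])
    floor_y, band_height = 0.0, 0.3
    points = np.array(
        [
            [0.0, 0.1, 0.0],  # inside the floor band
            [0.0, 1.5, 0.0],  # above the band, on a wall
            [0.0, 2.5, 0.0],  # above the band, ceiling-like
        ]
    )
    normals = np.array(
        [
            [0.0, 1.0, 0.0],  # floor-like normal (parallel to floor normal)
            [1.0, 0.0, 0.0],  # wall-like normal (perpendicular to floor normal)
            [0.0, 1.0, 0.0],  # ceiling-like normal
        ]
    )
    projections = points @ floor_normal
    floor_mask = (projections >= floor_y) & (projections <= floor_y + band_height)
    vertical_mask = np.abs(normals @ floor_normal) < 0.3
    return floor_mask | vertical_mask  # array([True, True, False])
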
def extract_near_floor_band(
    points_world: PointsNC,
    floor_y: float,
    band_height: float,
    floor_normal: Vec3,
) -> PointsNC:
    """
    Extract points within a vertical band relative to the floor.
    Points are in world frame.
    """
    if len(points_world) == 0:
        return points_world

    # Project points onto the floor normal.
    # Distance from origin along the normal is p . n; we want points where
    # floor_y <= p . n <= floor_y + band_height.
    projections = points_world @ floor_normal
    mask = (projections >= floor_y) & (projections <= floor_y + band_height)
    return points_world[mask]


def compute_fpfh_features(
    pcd_down: o3d.geometry.PointCloud,
    voxel_size: float,
) -> o3d.pipelines.registration.Feature:
    """
    Compute FPFH features for a downsampled point cloud.
    """
    radius_normal = voxel_size * 2
    pcd_down.estimate_normals(
        o3d.geometry.KDTreeSearchParamHybrid(radius=radius_normal, max_nn=30)
    )
    radius_feature = voxel_size * 5
    pcd_fpfh = o3d.pipelines.registration.compute_fpfh_feature(
        pcd_down,
        o3d.geometry.KDTreeSearchParamHybrid(radius=radius_feature, max_nn=100),
    )
    return pcd_fpfh


def global_registration(
    source_down: o3d.geometry.PointCloud,
    target_down: o3d.geometry.PointCloud,
    source_fpfh: o3d.pipelines.registration.Feature,
    target_fpfh: o3d.pipelines.registration.Feature,
    voxel_size: float,
) -> o3d.pipelines.registration.RegistrationResult:
    """
    Perform RANSAC-based global registration.
    """
    distance_threshold = voxel_size * 1.5
    result = o3d.pipelines.registration.registration_ransac_based_on_feature_matching(
        source_down,
        target_down,
        source_fpfh,
        target_fpfh,
        True,  # mutual_filter
        distance_threshold,
        o3d.pipelines.registration.TransformationEstimationPointToPoint(False),
        3,  # ransac_n
        [
            o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9),
            o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(
                distance_threshold
            ),
        ],
        o3d.pipelines.registration.RANSACConvergenceCriteria(4000000, 500),
    )
    return result


def compute_overlap_xz(
    points_a: PointsNC,
    points_b: PointsNC,
    margin: float = 0.0,
) -> float:
    """
    Compute intersection area of XZ bounding boxes.
    """
    if len(points_a) == 0 or len(points_b) == 0:
        return 0.0

    min_a = np.min(points_a[:, [0, 2]], axis=0) - margin
    max_a = np.max(points_a[:, [0, 2]], axis=0) + margin
    min_b = np.min(points_b[:, [0, 2]], axis=0) - margin
    max_b = np.max(points_b[:, [0, 2]], axis=0) + margin

    inter_min = np.maximum(min_a, min_b)
    inter_max = np.minimum(max_a, max_b)
    dims = np.maximum(0, inter_max - inter_min)
    return float(dims[0] * dims[1])


def compute_overlap_3d(
    points_a: PointsNC,
    points_b: PointsNC,
    margin: float = 0.0,
) -> float:
    """
    Compute intersection volume of 3D AABBs.
    """
    if len(points_a) == 0 or len(points_b) == 0:
        return 0.0

    min_a = np.min(points_a, axis=0) - margin
    max_a = np.max(points_a, axis=0) + margin
    min_b = np.min(points_b, axis=0) - margin
    max_b = np.max(points_b, axis=0) + margin

    inter_min = np.maximum(min_a, min_b)
    inter_max = np.minimum(max_a, max_b)
    dims = np.maximum(0, inter_max - inter_min)
    return float(dims[0] * dims[1] * dims[2])

def apply_gravity_constraint(
    T_icp: Mat44,
    T_original: Mat44,
    penalty_weight: float = 10.0,
) -> Mat44:
    """
    Preserve RANSAC gravity alignment while allowing yaw + XZ + height refinement.
    """
    R_icp = T_icp[:3, :3]
    R_orig = T_original[:3, :3]

    rot_icp = Rotation.from_matrix(R_icp)
    rot_orig = Rotation.from_matrix(R_orig)

    euler_icp = rot_icp.as_euler("xyz")
    euler_orig = rot_orig.as_euler("xyz")

    # Blend pitch (x) and roll (z):
    #   blended = original + (icp - original) / (1 + penalty_weight)
    # Handle angular wrap-around for robustness.
    diff = euler_icp - euler_orig
    diff = (diff + np.pi) % (2 * np.pi) - np.pi
    blended_euler = euler_orig + diff / (1 + penalty_weight)

    # Keep ICP yaw (y)
    blended_euler[1] = euler_icp[1]

    R_constrained = Rotation.from_euler("xyz", blended_euler).as_matrix()

    T_constrained = T_icp.copy()
    T_constrained[:3, :3] = R_constrained
    return T_constrained

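# Worked example of the blend above (a sketch, assuming the same "xyz" Euler convention):
# with penalty_weight=2.0 a hypothetical ICP result with 3 deg of pitch drift and a
# 30 deg yaw correction is damped to roughly 1 deg of pitch (3 / (1 + 2)) while the yaw
# is kept as ICP estimated it.
def _gravity_constraint_sketch() -> np.ndarray:
    T_orig = np.eye(4)
    T_icp = np.eye(4)
    T_icp[:3, :3] = Rotation.from_euler(
        "xyz", [3.0, 30.0, 0.0], degrees=True
    ).as_matrix()
    T_out = apply_gravity_constraint(T_icp, T_orig, penalty_weight=2.0)
    # Approximately [1.0, 30.0, 0.0]: pitch damped, yaw preserved, roll unchanged.
    return Rotation.from_matrix(T_out[:3, :3]).as_euler("xyz", degrees=True)
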
def pairwise_icp(
    source_pcd: o3d.geometry.PointCloud,
    target_pcd: o3d.geometry.PointCloud,
    config: ICPConfig,
    init_transform: Mat44,
) -> ICPResult:
    """
    Multi-scale ICP registration.
    """
    current_transform = init_transform
    voxel_scales = [4, 2, 1]

    # Initialize reg_result to handle empty scales or other issues,
    # but we expect at least one scale.
    reg_result = None

    # Robust kernel setup
    loss = None
    if config.robust_kernel == "tukey":
        loss = o3d.pipelines.registration.TukeyLoss(k=config.robust_kernel_k)

    for i, scale in enumerate(voxel_scales):
        voxel_size = config.voxel_size * scale
        max_iter = config.max_iterations[i]

        source_down = source_pcd.voxel_down_sample(voxel_size)
        target_down = target_pcd.voxel_down_sample(voxel_size)

        source_down.estimate_normals(
            o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size * 2, max_nn=30)
        )
        target_down.estimate_normals(
            o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size * 2, max_nn=30)
        )

        dist_thresh = voxel_size * config.max_correspondence_distance_factor
        criteria = o3d.pipelines.registration.ICPConvergenceCriteria(
            max_iteration=max_iter
        )

        if config.method == "point_to_plane":
            estimation = (
                o3d.pipelines.registration.TransformationEstimationPointToPlane(loss)
                if loss
                else o3d.pipelines.registration.TransformationEstimationPointToPlane()
            )
            reg_result = o3d.pipelines.registration.registration_icp(
                source_down,
                target_down,
                dist_thresh,
                current_transform,
                estimation,
                criteria,
            )
        elif config.method == "gicp":
            estimation = (
                o3d.pipelines.registration.TransformationEstimationForGeneralizedICP(
                    loss
                )
                if loss
                else o3d.pipelines.registration.TransformationEstimationForGeneralizedICP()
            )
            reg_result = o3d.pipelines.registration.registration_generalized_icp(
                source_down,
                target_down,
                dist_thresh,
                current_transform,
                estimation,
                criteria,
            )
        else:
            # Fallback
            estimation = (
                o3d.pipelines.registration.TransformationEstimationPointToPoint()
            )
            reg_result = o3d.pipelines.registration.registration_icp(
                source_down,
                target_down,
                dist_thresh,
                current_transform,
                estimation,
                criteria,
            )

        current_transform = reg_result.transformation

    if reg_result is None:
        return ICPResult(
            transformation=init_transform,
            fitness=0.0,
            inlier_rmse=0.0,
            information_matrix=np.eye(6),
            converged=False,
        )

    # Final information matrix
    info_matrix = o3d.pipelines.registration.get_information_matrix_from_point_clouds(
        source_pcd,
        target_pcd,
        config.voxel_size * config.max_correspondence_distance_factor,
        current_transform,
    )

    return ICPResult(
        transformation=current_transform,
        fitness=reg_result.fitness,
        inlier_rmse=reg_result.inlier_rmse,
        information_matrix=info_matrix,
        converged=reg_result.fitness > config.min_fitness,
    )

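# Worked example of the coarse-to-fine schedule implied by the ICPConfig defaults and the
# hard-coded scales in pairwise_icp (a sketch, not extra configuration): voxel_size=0.02
# with scales [4, 2, 1] yields per-level voxel sizes of 0.08 / 0.04 / 0.02 m, run for
# 50 / 30 / 14 iterations, with correspondence distances of 2.5x the level voxel size
# (0.20 / 0.10 / 0.05 m).
def _icp_schedule_sketch(
    config: Optional[ICPConfig] = None,
) -> list[tuple[float, int, float]]:
    config = config or ICPConfig()
    return [
        (
            config.voxel_size * scale,  # level voxel size
            config.max_iterations[i],  # level iteration budget
            config.voxel_size * scale * config.max_correspondence_distance_factor,
        )
        for i, scale in enumerate([4, 2, 1])
    ]
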
def build_pose_graph(
    serials: List[str],
    extrinsics: Dict[str, Mat44],
    pair_results: Dict[Tuple[str, str], ICPResult],
    reference_serial: str,
) -> o3d.pipelines.registration.PoseGraph:
    """
    Build a PoseGraph from pairwise results.
    Only includes cameras reachable from the reference camera.
    """
    # 1. Detect connected component from reference
    connected = {reference_serial}
    queue = [reference_serial]
    while queue:
        curr = queue.pop(0)
        for (s1, s2), result in pair_results.items():
            if not result.converged:
                continue
            if s1 == curr and s2 not in connected:
                connected.add(s2)
                queue.append(s2)
            elif s2 == curr and s1 not in connected:
                connected.add(s1)
                queue.append(s1)

    # 2. Filter serials to only include connected ones.
    # Keep reference_serial at index 0.
    optimized_serials = [reference_serial] + sorted(
        list(connected - {reference_serial})
    )
    serial_to_idx = {s: i for i, s in enumerate(optimized_serials)}

    # Log disconnected cameras
    disconnected = set(serials) - connected
    if disconnected:
        logger.warning(
            f"Cameras disconnected from reference {reference_serial}: {disconnected}"
        )

    pose_graph = o3d.pipelines.registration.PoseGraph()
    for serial in optimized_serials:
        T_wc = extrinsics[serial]
        pose_graph.nodes.append(o3d.pipelines.registration.PoseGraphNode(T_wc))

    for (s1, s2), result in pair_results.items():
        if not result.converged:
            continue
        if s1 not in serial_to_idx or s2 not in serial_to_idx:
            continue

        idx1 = serial_to_idx[s1]
        idx2 = serial_to_idx[s2]

        # Edge from idx1 to idx2.
        # Open3D PoseGraphEdge(s, t, T) enforces T = Pose(t)^-1 * Pose(s),
        # i.e. T is the pose of s in t's frame (T_t_s).
        # result.transformation is T_c2_c1 (aligns pcd1 to pcd2), i.e. the pose of c1 in c2,
        # so s = c1 (idx1) and t = c2 (idx2).
        edge = o3d.pipelines.registration.PoseGraphEdge(
            idx1, idx2, result.transformation, result.information_matrix, uncertain=True
        )
        pose_graph.edges.append(edge)

    return pose_graph

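# Illustrative check of the edge convention used above (a sketch, not a test): for
# world-from-camera poses T_w1 and T_w2, the relative transform T_21 = T_w2^-1 @ T_w1
# is the pose of camera 1 expressed in camera 2's frame, which is what
# PoseGraphEdge(idx1, idx2, T_21, ...) expects.
def _edge_convention_sketch(T_w1: Mat44, T_w2: Mat44) -> bool:
    T_2w = invert_transform(T_w2)
    T_21 = T_2w @ T_w1
    # Camera 1's origin mapped into camera 2's frame two ways: via the world frame,
    # and via the translation of the relative transform. Both must agree.
    p1_in_world = T_w1[:3, 3]
    p1_in_c2_via_world = T_2w[:3, :3] @ p1_in_world + T_2w[:3, 3]
    p1_in_c2_via_edge = T_21[:3, 3]
    return bool(np.allclose(p1_in_c2_via_world, p1_in_c2_via_edge))
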
def optimize_pose_graph(
    pose_graph: o3d.pipelines.registration.PoseGraph,
) -> o3d.pipelines.registration.PoseGraph:
    """
    Run global optimization.
    """
    option = o3d.pipelines.registration.GlobalOptimizationOption(
        max_correspondence_distance=0.1,
        edge_prune_threshold=0.25,
        reference_node=0,
    )
    o3d.pipelines.registration.global_optimization(
        pose_graph,
        o3d.pipelines.registration.GlobalOptimizationLevenbergMarquardt(),
        o3d.pipelines.registration.GlobalOptimizationConvergenceCriteria(),
        option,
    )
    return pose_graph


def refine_with_icp(
    camera_data: Dict[str, Dict[str, Any]],
    extrinsics: Dict[str, Mat44],
    floor_planes: Dict[str, Any],  # Dict[str, FloorPlane]
    config: ICPConfig,
) -> Tuple[Dict[str, Mat44], ICPMetrics]:
    """
    Main orchestrator for ICP refinement.
    """
    from .ground_plane import unproject_depth_to_points

    metrics = ICPMetrics()
    serials = sorted(list(camera_data.keys()))
    if not serials:
        return extrinsics, metrics

    metrics.reference_camera = serials[0]

    # 1. Extract scene points
    camera_pcds: Dict[str, o3d.geometry.PointCloud] = {}
    camera_points: Dict[str, PointsNC] = {}

    # Auto-set overlap mode based on region
    if config.region == "floor":
        config.overlap_mode = "xz"
    elif config.region in ["hybrid", "full"]:
        config.overlap_mode = "3d"

    for serial in serials:
        if serial not in floor_planes or serial not in extrinsics:
            continue

        data = camera_data[serial]
        plane = floor_planes[serial]

        points_cam = unproject_depth_to_points(data["depth"], data["K"], stride=4)
        T_wc = extrinsics[serial]
        points_world = (points_cam @ T_wc[:3, :3].T) + T_wc[:3, 3]

        # floor_y = -plane.d (distance to origin along normal)
        floor_y = -plane.d

        scene_points = extract_scene_points(
            points_world,
            floor_y,
            plane.normal,
            mode=config.region,
            band_height=config.band_height,
        )

        if len(scene_points) < 100:
            continue

        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(scene_points)

        # Apply SOR preprocessing
        pcd = preprocess_point_cloud(pcd, config.voxel_size)
        if len(pcd.points) < 100:
            continue

        camera_pcds[serial] = pcd
        camera_points[serial] = np.asarray(pcd.points)

    # 2. Pairwise ICP
    valid_serials = sorted(list(camera_pcds.keys()))
    pair_results: Dict[Tuple[str, str], ICPResult] = {}

    for i, s1 in enumerate(valid_serials):
        for j in range(i + 1, len(valid_serials)):
            s2 = valid_serials[j]

            # Dispatch overlap check
            if config.overlap_mode == "3d":
                area = compute_overlap_3d(
                    camera_points[s1], camera_points[s2], config.overlap_margin
                )
            else:
                area = compute_overlap_xz(
                    camera_points[s1], camera_points[s2], config.overlap_margin
                )

            if area < config.min_overlap_area:
                unit = "m^3" if config.overlap_mode == "3d" else "m^2"
                logger.debug(
                    f"Skipping pair ({s1}, {s2}) due to insufficient overlap: "
                    f"{area:.2f} {unit} < {config.min_overlap_area:.2f} {unit}"
                )
                continue

            metrics.num_pairs_attempted += 1

            # Initial relative transform from current extrinsics: T_21 = T_w2^-1 * T_w1
            T_w1 = extrinsics[s1]
            T_w2 = extrinsics[s2]
            init_T = invert_transform(T_w2) @ T_w1

            # Prepare point clouds for ICP.
            # Always transform back to the camera frame to refine the relative transform
            # directly. This avoids world-frame identity traps and ensures consistent
            # behavior across regions.
            pcd1 = o3d.geometry.PointCloud()
            pcd1.points = o3d.utility.Vector3dVector(
                (np.asarray(camera_pcds[s1].points) - T_w1[:3, 3]) @ T_w1[:3, :3]
            )
            pcd2 = o3d.geometry.PointCloud()
            pcd2.points = o3d.utility.Vector3dVector(
                (np.asarray(camera_pcds[s2].points) - T_w2[:3, 3]) @ T_w2[:3, :3]
            )

            current_init = init_T
            if config.global_init:
                # Downsample for global registration
                voxel_size = config.voxel_size
                source_down = pcd1.voxel_down_sample(voxel_size)
                target_down = pcd2.voxel_down_sample(voxel_size)
                source_fpfh = compute_fpfh_features(source_down, voxel_size)
                target_fpfh = compute_fpfh_features(target_down, voxel_size)

                global_result = global_registration(
                    source_down, target_down, source_fpfh, target_fpfh, voxel_size
                )

                if global_result.fitness > 0.1:
                    # Validate against safety bounds relative to the extrinsic init
                    T_global = global_result.transformation
                    T_diff = T_global @ invert_transform(current_init)
                    rot_diff = Rotation.from_matrix(T_diff[:3, :3]).as_euler(
                        "xyz", degrees=True
                    )
                    rot_mag = np.linalg.norm(rot_diff)
                    trans_mag = np.linalg.norm(T_diff[:3, 3])

                    if (
                        rot_mag <= config.max_rotation_deg
                        and trans_mag <= config.max_translation_m
                    ):
                        logger.info(
                            f"Global registration accepted for ({s1}, {s2}): "
                            f"fitness={global_result.fitness:.3f}"
                        )
                        current_init = T_global
                    else:
                        logger.warning(
                            f"Global registration rejected for ({s1}, {s2}): "
                            f"exceeds bounds (rot={rot_mag:.1f}, trans={trans_mag:.3f})"
                        )
                else:
                    logger.warning(
                        f"Global registration failed for ({s1}, {s2}): "
                        f"low fitness {global_result.fitness:.3f}"
                    )

            result = pairwise_icp(pcd1, pcd2, config, current_init)

            # Apply the gravity constraint relative to the original transform.
            # T_original is the initial T_21 (from extrinsics).
            T_original_21 = invert_transform(T_w2) @ T_w1
            result.transformation = apply_gravity_constraint(
                result.transformation, T_original_21, config.gravity_penalty_weight
            )

            # Debug logging for transform deltas
            T_delta = result.transformation @ invert_transform(init_T)
            rot_delta = Rotation.from_matrix(T_delta[:3, :3]).as_euler(
                "xyz", degrees=True
            )
            trans_delta = np.linalg.norm(T_delta[:3, 3])
            logger.debug(
                f"Pair ({s1}, {s2}) delta: rot={np.linalg.norm(rot_delta):.2f}deg, "
                f"trans={trans_delta:.3f}m"
            )

            metrics.per_pair_results[(s1, s2)] = result

            logger.info(
                f"Pair ({s1}, {s2}) ICP result: fitness={result.fitness:.3f}, "
                f"rmse={result.inlier_rmse:.4f}, converged={result.converged}"
            )

            if result.converged:
                # Pair-level plausibility gate
                rot_mag = np.linalg.norm(rot_delta)
                if (
                    rot_mag > config.max_pair_rotation_deg
                    or trans_delta > config.max_pair_translation_m
                ):
                    logger.warning(
                        f"Pair ({s1}, {s2}) converged but rejected by pair-level gate: "
                        f"rot={rot_mag:.2f}deg (max={config.max_pair_rotation_deg}), "
                        f"trans={trans_delta:.3f}m (max={config.max_pair_translation_m})"
                    )
                else:
                    metrics.num_pairs_converged += 1
                    pair_results[(s1, s2)] = result

    if not pair_results:
        # Nothing to optimize; report failure instead of running an empty pose graph.
        metrics.message = "No converged ICP pairs"
        return extrinsics, metrics

    # 3. Pose Graph
    pose_graph = build_pose_graph(
        valid_serials, extrinsics, pair_results, metrics.reference_camera
    )

    # 4. Optimize
    optimize_pose_graph(pose_graph)

    # 5. Extract and Validate
    new_extrinsics = extrinsics.copy()

    # Re-derive optimized_serials to match build_pose_graph's node-to-serial mapping
    connected = {metrics.reference_camera}
    queue = [metrics.reference_camera]
    while queue:
        curr = queue.pop(0)
        for (s1, s2), result in pair_results.items():
            if not result.converged:
                continue
            if s1 == curr and s2 not in connected:
                connected.add(s2)
                queue.append(s2)
            elif s2 == curr and s1 not in connected:
                connected.add(s1)
                queue.append(s1)

    optimized_serials = [metrics.reference_camera] + sorted(
        list(connected - {metrics.reference_camera})
    )
    logger.info(f"Optimized connected component: {optimized_serials}")
    metrics.num_disconnected = len(valid_serials) - len(optimized_serials)

    metrics.num_cameras_optimized = 0
    for i, serial in enumerate(optimized_serials):
        T_optimized = pose_graph.nodes[i].pose
        T_old = extrinsics[serial]

        # Validate delta
        T_delta = T_optimized @ invert_transform(T_old)
        rot_delta = Rotation.from_matrix(T_delta[:3, :3]).as_euler("xyz", degrees=True)
        rot_mag = np.linalg.norm(rot_delta)
        trans_mag = np.linalg.norm(T_delta[:3, 3])

        logger.debug(
            f"Camera {serial} optimization delta: rot={rot_mag:.2f}deg, "
            f"trans={trans_mag:.3f}m"
        )

        # Validate the delta against both the general and the stricter final safety bounds
        if rot_mag > config.max_rotation_deg:
            logger.warning(
                f"Camera {serial} ICP correction exceeds max_rotation_deg: "
                f"{rot_mag:.2f}deg > {config.max_rotation_deg:.2f}deg. Rejecting."
            )
            continue
        if rot_mag > config.max_final_rotation_deg:
            logger.warning(
                f"Camera {serial} ICP correction exceeds max_final_rotation_deg: "
                f"{rot_mag:.2f}deg > {config.max_final_rotation_deg:.2f}deg. Rejecting."
            )
            continue
        if trans_mag > config.max_translation_m:
            logger.warning(
                f"Camera {serial} ICP correction exceeds max_translation_m: "
                f"{trans_mag:.3f}m > {config.max_translation_m:.3f}m. Rejecting."
            )
            continue
        if trans_mag > config.max_final_translation_m:
            logger.warning(
                f"Camera {serial} ICP correction exceeds max_final_translation_m: "
                f"{trans_mag:.3f}m > {config.max_final_translation_m:.3f}m. Rejecting."
            )
            continue

        new_extrinsics[serial] = T_optimized
        metrics.num_cameras_optimized += 1

    metrics.success = metrics.num_cameras_optimized > 0
    metrics.message = f"Optimized {metrics.num_cameras_optimized} cameras"

    return new_extrinsics, metrics
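
# Typical call site (a sketch; `camera_data`, `extrinsics`, and `floor_planes` come from
# the surrounding calibration pipeline, and each floor plane is assumed to expose the
# `.normal` and `.d` attributes used above):
#
#   config = ICPConfig(region="floor", method="point_to_plane", global_init=False)
#   refined_extrinsics, metrics = refine_with_icp(camera_data, extrinsics, floor_planes, config)
#   if metrics.success:
#       extrinsics = refined_extrinsics
#   logger.info(metrics.message)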