"""Ground-plane detection from depth maps and extrinsic refinement.

Pipeline implemented here:

1. Unproject each camera's depth map to a camera-frame point cloud.
2. Transform the points to the world frame and RANSAC-fit a plane per camera.
3. Average the per-camera planes into a consensus plane.
4. Compute a small rigid correction that makes the consensus plane the
   horizontal plane ``y = target_y`` (world Y is treated as "up").
5. Left-multiply every camera extrinsic by that correction.
"""

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import numpy as np
import open3d as o3d
from jaxtyping import Float

from .alignment import rotation_align_vectors

if TYPE_CHECKING:
    Vec3 = Float[np.ndarray, "3"]
    Mat44 = Float[np.ndarray, "4 4"]
    PointsNC = Float[np.ndarray, "N 3"]
else:
    Vec3 = np.ndarray
    Mat44 = np.ndarray
    PointsNC = np.ndarray


@dataclass
class FloorPlane:
    """A plane in the form ``normal . p + d = 0``."""

    # Plane normal (a, b, c); producers in this module normalise it.
    normal: Vec3
    # Plane offset; with a unit normal this is the signed distance term.
    d: float
    # Number of RANSAC inliers supporting the plane (0 when not applicable,
    # e.g. for a consensus plane).
    num_inliers: int = 0


@dataclass
class FloorCorrection:
    """Result of computing a floor-alignment transform."""

    # 4x4 rigid transform; identity when ``valid`` is False.
    transform: Mat44
    # True when the correction is within the configured limits.
    valid: bool
    # Human-readable explanation when ``valid`` is False.
    reason: str = ""


@dataclass
class GroundPlaneConfig:
    """Tunable parameters for :func:`refine_ground_from_depth`."""

    # When False, refinement is skipped and extrinsics are returned as-is.
    enabled: bool = True
    # World-frame Y coordinate the floor should end up at.
    target_y: float = 0.0
    # Pixel subsampling step used when unprojecting depth maps.
    stride: int = 8
    # Depth values outside the open interval (depth_min, depth_max) are dropped.
    depth_min: float = 0.2
    depth_max: float = 5.0
    # RANSAC inlier distance threshold (same units as the point cloud).
    ransac_dist_thresh: float = 0.02
    # Number of points sampled per RANSAC hypothesis.
    ransac_n: int = 3
    # Number of RANSAC iterations.
    ransac_iters: int = 1000
    # Corrections with a larger rotation / vertical shift are rejected.
    max_rotation_deg: float = 5.0
    max_translation_m: float = 0.1
    # Minimum point count per camera AND minimum plane inliers per camera.
    min_inliers: int = 500
    # Minimum number of cameras with an accepted plane.
    min_valid_cameras: int = 2


@dataclass
class GroundPlaneMetrics:
    """Diagnostics reported by :func:`refine_ground_from_depth`."""

    success: bool = False
    correction_applied: bool = False
    # Number of entries in ``camera_data``.
    num_cameras_total: int = 0
    # Number of cameras whose plane passed the inlier threshold.
    num_cameras_valid: int = 0
    # Transform returned by the correction step (identity if rejected).
    correction_transform: Mat44 = field(default_factory=lambda: np.eye(4))
    rotation_deg: float = 0.0
    translation_m: float = 0.0
    # Accepted per-camera planes keyed by camera serial.
    camera_planes: Dict[str, FloorPlane] = field(default_factory=dict)
    consensus_plane: Optional[FloorPlane] = None
    message: str = ""


def _rotation_angle_deg(rotation: np.ndarray) -> float:
    """Return the rotation angle (degrees) of a 3x3 rotation matrix."""
    # angle = acos((trace(R) - 1) / 2); clip guards against round-off pushing
    # the cosine slightly outside [-1, 1].
    cos_angle = np.clip((np.trace(rotation) - 1.0) / 2.0, -1.0, 1.0)
    return float(np.rad2deg(np.arccos(cos_angle)))


def unproject_depth_to_points(
    depth_map: np.ndarray,
    K: np.ndarray,
    stride: int = 1,
    depth_min: float = 0.1,
    depth_max: float = 10.0,
) -> PointsNC:
    """
    Unproject a depth map to a point cloud.

    Args:
        depth_map: Depth image of shape (H, W); a trailing singleton channel
            (H, W, 1) is also accepted.
        K: Pinhole intrinsics; ``K[0, 0]``, ``K[1, 1]``, ``K[0, 2]`` and
            ``K[1, 2]`` are read as fx, fy, cx, cy.
        stride: Sample every ``stride``-th pixel in both directions (>= 1).
        depth_min: Exclusive lower bound on accepted depth values.
        depth_max: Exclusive upper bound on accepted depth values.

    Returns:
        (N, 3) float64 array of camera-frame points (x, y, z), in row-major
        pixel order. N may be 0.

    Raises:
        ValueError: If ``stride`` is < 1 or ``depth_map`` is not 2-D.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    depth_map = np.asarray(depth_map)
    if depth_map.ndim == 3 and depth_map.shape[-1] == 1:
        depth_map = depth_map[..., 0]
    if depth_map.ndim != 2:
        raise ValueError(
            f"depth_map must have shape (H, W), got {depth_map.shape}"
        )

    fx = K[0, 0]
    fy = K[1, 1]
    cx = K[0, 2]
    cy = K[1, 2]

    # Subsample the depth map
    z = depth_map[::stride, ::stride]

    # Keep finite depths strictly inside (depth_min, depth_max)
    valid_mask = (z > depth_min) & (z < depth_max) & np.isfinite(z)

    # Indices into the subsampled grid, in the same row-major order that
    # boolean-mask indexing uses; scale by stride to get pixel coordinates.
    rows, cols = np.nonzero(valid_mask)
    z_valid = z[valid_mask].astype(np.float64)
    u_valid = cols * stride
    v_valid = rows * stride

    # Pinhole unprojection
    x_valid = (u_valid - cx) * z_valid / fx
    y_valid = (v_valid - cy) * z_valid / fy

    # Stack into (N, 3) array
    points = np.stack((x_valid, y_valid, z_valid), axis=-1)
    return points.astype(np.float64)


def detect_floor_plane(
    points: PointsNC,
    distance_threshold: float = 0.02,
    ransac_n: int = 3,
    num_iterations: int = 1000,
    seed: Optional[int] = None,
) -> Optional[FloorPlane]:
    """
    Detect the dominant plane in a point cloud using RANSAC.

    Note that this returns whichever plane Open3D's RANSAC finds; nothing here
    checks that it is actually horizontal.

    Args:
        points: (N, 3) point array.
        distance_threshold: Max point-to-plane distance for an inlier.
        ransac_n: Number of points sampled per hypothesis.
        num_iterations: Number of RANSAC iterations.
        seed: If given, seeds Open3D's global RNG for determinism.

    Returns:
        FloorPlane with a unit normal, or None if detection fails (too few
        points/inliers or a degenerate normal).
    """
    points = np.asarray(points, dtype=np.float64)

    # Non-finite coordinates would poison the RANSAC fit.
    points = points[np.isfinite(points).all(axis=1)]
    if points.shape[0] < ransac_n:
        return None

    # Convert to Open3D PointCloud
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np.ascontiguousarray(points))

    # Set seed for determinism if provided
    if seed is not None:
        o3d.utility.random.seed(seed)

    # Segment plane
    plane_model, inliers = pcd.segment_plane(
        distance_threshold=distance_threshold,
        ransac_n=ransac_n,
        num_iterations=num_iterations,
    )

    # Check if we found enough inliers
    if len(inliers) < ransac_n:
        return None

    # plane_model is [a, b, c, d]
    a, b, c, d = plane_model
    normal = np.array([a, b, c], dtype=np.float64)

    # Normalize (Open3D usually returns a unit normal, but be safe). A
    # near-zero normal does not describe a plane, so report failure rather
    # than returning it.
    norm = np.linalg.norm(normal)
    if norm <= 1e-6:
        return None
    normal /= norm
    d = float(d) / float(norm)

    return FloorPlane(normal=normal, d=d, num_inliers=len(inliers))


def compute_consensus_plane(
    planes: List[FloorPlane],
    weights: Optional[List[float]] = None,
) -> FloorPlane:
    """
    Compute a consensus plane from multiple plane detections.

    Each plane is first sign-aligned with ``planes[0]`` (a plane and its
    negation are the same surface), then normals and offsets are averaged
    with the given weights and the result is re-normalised.

    Args:
        planes: Non-empty list of planes.
        weights: Optional per-plane weights; defaults to uniform.

    Returns:
        The weighted-average plane (``num_inliers`` left at its default).

    Raises:
        ValueError: If ``planes`` is empty, the weights length does not match,
            or the total weight is not positive.
    """
    if not planes:
        raise ValueError("No planes provided for consensus.")

    n_planes = len(planes)
    if weights is None:
        weights = [1.0] * n_planes
    if len(weights) != n_planes:
        raise ValueError(
            f"Weights length {len(weights)} must match planes length {n_planes}"
        )

    # Use the first plane as reference for orientation
    ref_normal = planes[0].normal

    accum_normal = np.zeros(3, dtype=np.float64)
    accum_d = 0.0
    total_weight = 0.0

    for plane, w in zip(planes, weights):
        normal = plane.normal
        d = plane.d

        # Flip normal and d together to align with the reference
        if np.dot(normal, ref_normal) < 0:
            normal = -normal
            d = -d

        accum_normal += normal * w
        accum_d += d * w
        total_weight += w

    if total_weight <= 0:
        raise ValueError("Total weight must be positive.")

    avg_normal = accum_normal / total_weight
    avg_d = accum_d / total_weight

    # Re-normalize normal
    norm = np.linalg.norm(avg_normal)
    if norm > 1e-6:
        avg_normal /= norm
        # Scale d by 1/norm to maintain plane equation consistency
        avg_d /= norm
    else:
        # Fallback (should be rare if inputs are valid)
        avg_normal = np.array([0.0, 1.0, 0.0])
        avg_d = 0.0

    return FloorPlane(normal=avg_normal, d=float(avg_d))


def compute_floor_correction(
    current_floor_plane: FloorPlane,
    target_floor_y: float = 0.0,
    max_rotation_deg: float = 5.0,
    max_translation_m: float = 0.1,
) -> FloorCorrection:
    """
    Compute the correction transform to align the current floor plane to the
    target floor height.

    Constrains correction to pitch/roll and vertical translation only.

    Args:
        current_floor_plane: Detected floor plane in the world frame.
        target_floor_y: Desired world Y coordinate of the floor.
        max_rotation_deg: Reject corrections rotating by more than this.
        max_translation_m: Reject corrections shifting by more than this.

    Returns:
        FloorCorrection. When ``valid`` is False the transform is identity
        and ``reason`` explains the rejection.
    """
    # Target normal is always [0, 1, 0] (Y-up)
    target_normal = np.array([0.0, 1.0, 0.0])

    current_normal = np.asarray(current_floor_plane.normal, dtype=np.float64)
    current_d = float(current_floor_plane.d)

    # The translation step below reads d as a distance, which is only true
    # for a unit normal. A near-zero normal is passed through untouched and
    # left for rotation_align_vectors to handle (presumably by raising).
    norm = float(np.linalg.norm(current_normal))
    if norm > 1e-6:
        current_normal = current_normal / norm
        current_d /= norm

    # (n, d) and (-n, -d) describe the same plane and the sign coming out of
    # RANSAC is arbitrary. Pick the representation facing the target normal,
    # otherwise a correctly detected floor can look like a ~180 deg rotation
    # and be rejected by the limit check.
    if np.dot(current_normal, target_normal) < 0:
        current_normal = -current_normal
        current_d = -current_d

    # 1. Compute rotation to align normals
    try:
        R_align = rotation_align_vectors(current_normal, target_normal)
    except ValueError as e:
        return FloorCorrection(
            transform=np.eye(4), valid=False, reason=f"Rotation alignment failed: {e}"
        )

    # Check rotation magnitude
    angle_deg = _rotation_angle_deg(R_align)

    if angle_deg > max_rotation_deg:
        return FloorCorrection(
            transform=np.eye(4),
            valid=False,
            reason=f"Rotation {angle_deg:.1f} deg exceeds limit {max_rotation_deg:.1f} deg",
        )

    # 2. Compute translation
    # Plane equation: n . p + d = 0. Rotating about the origin keeps d, so
    # after alignment (n = [0, 1, 0]) the floor sits at y = -current_d.
    # Required shift = target_floor_y - (-current_d).
    t_y = target_floor_y + current_d

    # Check translation magnitude
    if abs(t_y) > max_translation_m:
        return FloorCorrection(
            transform=np.eye(4),
            valid=False,
            reason=f"Translation {t_y:.3f} m exceeds limit {max_translation_m:.3f} m",
        )

    # Construct T
    T = np.eye(4)
    T[:3, :3] = R_align
    # Translation is applied in the rotated frame (aligned to target normal)
    T[:3, 3] = target_normal * t_y

    return FloorCorrection(transform=T.astype(np.float64), valid=True)


def refine_ground_from_depth(
    camera_data: Dict[str, Dict[str, Any]],
    extrinsics: Dict[str, Mat44],
    config: Optional[GroundPlaneConfig] = None,
) -> Tuple[Dict[str, Mat44], GroundPlaneMetrics]:
    """
    Orchestrate ground plane refinement across multiple cameras.

    Args:
        camera_data: Dict mapping serial -> {'depth': np.ndarray, 'K': np.ndarray}
        extrinsics: Dict mapping serial -> world_from_cam matrix (4x4)
        config: Configuration parameters. ``None`` uses a fresh default
            ``GroundPlaneConfig()``.

    Returns:
        Tuple of (new_extrinsics, metrics). On every early-exit path
        (disabled, too few valid cameras, consensus failure, rejected
        correction) the input ``extrinsics`` object itself is returned and
        ``metrics.success`` stays False.
    """
    # Build the default per call: a default instance in the signature would
    # be created once and shared (and mutable) across all calls.
    if config is None:
        config = GroundPlaneConfig()

    metrics = GroundPlaneMetrics()
    metrics.num_cameras_total = len(camera_data)

    if not config.enabled:
        metrics.message = "Ground plane refinement disabled in config"
        return extrinsics, metrics

    valid_planes: List[FloorPlane] = []

    # 1. Detect planes in each camera
    for serial, data in camera_data.items():
        if serial not in extrinsics:
            continue

        depth_map = data.get("depth")
        K = data.get("K")
        if depth_map is None or K is None:
            continue

        # Unproject to camera frame
        points_cam = unproject_depth_to_points(
            depth_map,
            K,
            stride=config.stride,
            depth_min=config.depth_min,
            depth_max=config.depth_max,
        )

        # A cloud smaller than min_inliers can never yield enough inliers.
        if len(points_cam) < config.min_inliers:
            continue

        # Transform to world frame: p_world = R @ p_cam + t, written for
        # row-vector points of shape (N, 3).
        T_world_cam = extrinsics[serial]
        R = T_world_cam[:3, :3]
        t = T_world_cam[:3, 3]
        points_world = (points_cam @ R.T) + t

        # Detect plane
        plane = detect_floor_plane(
            points_world,
            distance_threshold=config.ransac_dist_thresh,
            ransac_n=config.ransac_n,
            num_iterations=config.ransac_iters,
        )

        if plane is not None and plane.num_inliers >= config.min_inliers:
            metrics.camera_planes[serial] = plane
            valid_planes.append(plane)

    metrics.num_cameras_valid = len(valid_planes)

    # 2. Check minimum requirements
    if len(valid_planes) < config.min_valid_cameras:
        metrics.message = f"Found {len(valid_planes)} valid planes, required {config.min_valid_cameras}"
        return extrinsics, metrics

    # 3. Compute consensus
    try:
        consensus = compute_consensus_plane(valid_planes)
        metrics.consensus_plane = consensus
    except ValueError as e:
        metrics.message = f"Consensus computation failed: {e}"
        return extrinsics, metrics

    # 4. Compute correction
    correction = compute_floor_correction(
        consensus,
        target_floor_y=config.target_y,
        max_rotation_deg=config.max_rotation_deg,
        max_translation_m=config.max_translation_m,
    )
    metrics.correction_transform = correction.transform

    if not correction.valid:
        metrics.message = f"Correction invalid: {correction.reason}"
        return extrinsics, metrics

    # 5. Apply correction
    # T_corr moves world points: P' = T_corr @ P = T_corr @ (T_world_cam @ P_cam),
    # so the updated extrinsic is T_corr @ T_world_cam. Every extrinsic is
    # updated, including cameras that contributed no plane.
    T_corr = correction.transform
    new_extrinsics = {
        serial: T_corr @ T_old for serial, T_old in extrinsics.items()
    }

    # Calculate metrics
    metrics.rotation_deg = _rotation_angle_deg(T_corr[:3, :3])
    metrics.translation_m = float(np.linalg.norm(T_corr[:3, 3]))
    metrics.success = True
    metrics.correction_applied = True
    metrics.message = "Success"

    return new_extrinsics, metrics