import numpy as np
from typing import Optional, Tuple, List, Dict, Any
from typing import TYPE_CHECKING
from dataclasses import dataclass, field

from jaxtyping import Float
import open3d as o3d
import plotly.graph_objects as go

# FIX: this import was previously buried mid-file between function definitions;
# imports belong at the top of the module (PEP 8).
from .alignment import rotation_align_vectors

if TYPE_CHECKING:
    # Shape-annotated aliases for static checkers only.
    Vec3 = Float[np.ndarray, "3"]
    Mat44 = Float[np.ndarray, "4 4"]
    PointsNC = Float[np.ndarray, "N 3"]
else:
    # At runtime the aliases degrade to plain ndarrays (jaxtyping types are
    # not meant to be instantiated).
    Vec3 = np.ndarray
    Mat44 = np.ndarray
    PointsNC = np.ndarray


@dataclass
class FloorPlane:
    """A plane in Hessian normal form: normal . p + d = 0."""

    normal: Vec3
    d: float
    num_inliers: int = 0


@dataclass
class FloorCorrection:
    """Result of a floor-alignment computation.

    ``transform`` is a 4x4 rigid transform; ``valid`` is False when the
    correction was rejected, with ``reason`` explaining why.
    """

    transform: Mat44
    valid: bool
    reason: str = ""


@dataclass
class GroundPlaneConfig:
    """Tunable parameters for ground plane refinement."""

    enabled: bool = True
    target_y: float = 0.0
    # Pixel stride when unprojecting depth maps (larger = fewer points).
    stride: int = 8
    depth_min: float = 0.2
    depth_max: float = 5.0
    # RANSAC plane segmentation parameters (Open3D segment_plane).
    ransac_dist_thresh: float = 0.02
    ransac_n: int = 3
    ransac_iters: int = 1000
    # Safety limits: corrections larger than these are rejected.
    max_rotation_deg: float = 5.0
    max_translation_m: float = 0.1
    # Per-camera plane acceptance criteria.
    min_inliers: int = 500
    min_inlier_ratio: float = 0.15
    min_valid_cameras: int = 2
    # Minimum |normal_y| for a plane to be considered "floor-like".
    normal_vertical_thresh: float = 0.9
    # Per-camera deviation limits relative to the consensus plane.
    max_consensus_deviation_deg: float = 10.0
    max_consensus_deviation_m: float = 0.5
    seed: Optional[int] = None


@dataclass
class GroundPlaneMetrics:
    """Diagnostics produced by :func:`refine_ground_from_depth`."""

    success: bool = False
    correction_applied: bool = False
    num_cameras_total: int = 0
    num_cameras_valid: int = 0
    # Per-camera corrections
    camera_corrections: Dict[str, Mat44] = field(default_factory=dict)
    skipped_cameras: List[str] = field(default_factory=list)
    # Summary stats (optional, maybe average of corrections?)
    rotation_deg: float = 0.0
    translation_m: float = 0.0
    camera_planes: Dict[str, FloorPlane] = field(default_factory=dict)
    consensus_plane: Optional[FloorPlane] = None
    message: str = ""


def unproject_depth_to_points(
    depth_map: np.ndarray,
    K: np.ndarray,
    stride: int = 1,
    depth_min: float = 0.1,
    depth_max: float = 10.0,
) -> PointsNC:
    """Unproject a depth map to a point cloud.

    Args:
        depth_map: (H, W) depth image in meters.
        K: 3x3 camera intrinsics matrix.
        stride: sample every ``stride``-th pixel to reduce point count.
        depth_min: discard samples at or below this depth.
        depth_max: discard samples at or above this depth.

    Returns:
        (N, 3) float64 array of points in the camera frame.
    """
    h, w = depth_map.shape
    fx = K[0, 0]
    fy = K[1, 1]
    cx = K[0, 2]
    cy = K[1, 2]

    # Create meshgrid of pixel coordinates.
    # Use stride to reduce number of points.
    u_coords = np.arange(0, w, stride)
    v_coords = np.arange(0, h, stride)
    u, v = np.meshgrid(u_coords, v_coords)

    # Sample depth map at the same strided locations.
    z = depth_map[0:h:stride, 0:w:stride]

    # Filter by depth bounds and finiteness.
    valid_mask = (z > depth_min) & (z < depth_max) & np.isfinite(z)

    # Apply mask
    z_valid = z[valid_mask]
    u_valid = u[valid_mask]
    v_valid = v[valid_mask]

    # Unproject via the pinhole model.
    x_valid = (u_valid - cx) * z_valid / fx
    y_valid = (v_valid - cy) * z_valid / fy

    # Stack into (N, 3) array
    points = np.stack((x_valid, y_valid, z_valid), axis=-1)
    return points.astype(np.float64)


def detect_floor_plane(
    points: PointsNC,
    distance_threshold: float = 0.02,
    ransac_n: int = 3,
    num_iterations: int = 1000,
    seed: Optional[int] = None,
) -> Optional[FloorPlane]:
    """
    Detect the floor plane from a point cloud using RANSAC.
    Returns FloorPlane or None if detection fails.
    """
    if points.shape[0] < ransac_n:
        return None

    # Convert to Open3D PointCloud
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)

    # Set seed for determinism if provided
    if seed is not None:
        o3d.utility.random.seed(seed)

    # Segment plane; Open3D returns [a, b, c, d] with a*x + b*y + c*z + d = 0.
    plane_model, inliers = pcd.segment_plane(
        distance_threshold=distance_threshold,
        ransac_n=ransac_n,
        num_iterations=num_iterations,
    )

    # Check if we found enough inliers
    if len(inliers) < ransac_n:
        return None

    # plane_model is [a, b, c, d]
    a, b, c, d = plane_model
    normal = np.array([a, b, c], dtype=np.float64)

    # Normalize normal (Open3D usually returns normalized, but be safe).
    # d is rescaled with the normal so the plane equation is preserved.
    norm = np.linalg.norm(normal)
    if norm > 1e-6:
        normal /= norm
        d /= norm

    return FloorPlane(normal=normal, d=d, num_inliers=len(inliers))


def compute_consensus_plane(
    planes: List[FloorPlane],
    weights: Optional[List[float]] = None,
) -> FloorPlane:
    """
    Compute a consensus plane from multiple plane detections.
    Uses a robust median-like approach to reject outliers.

    Raises:
        ValueError: if ``planes`` is empty or ``weights`` length mismatches.
    """
    if not planes:
        raise ValueError("No planes provided for consensus.")

    n_planes = len(planes)
    if weights is None:
        weights = [1.0] * n_planes
    if len(weights) != n_planes:
        raise ValueError(
            f"Weights length {len(weights)} must match planes length {n_planes}"
        )

    # 1. Align all normals to be in the upper hemisphere (y > 0).
    # This simplifies averaging (a plane's (n, d) and (-n, -d) are equivalent).
    aligned_planes = []
    for p in planes:
        normal = p.normal.copy()
        d = p.d
        if normal[1] < 0:
            normal = -normal
            d = -d
        aligned_planes.append(FloorPlane(normal=normal, d=d, num_inliers=p.num_inliers))

    # 2. Compute median normal and d to be robust against outliers
    normals = np.array([p.normal for p in aligned_planes])
    ds = np.array([p.d for p in aligned_planes])

    # Median of each component for normal (approximate robust mean)
    median_normal = np.median(normals, axis=0)
    norm = np.linalg.norm(median_normal)
    if norm > 1e-6:
        median_normal /= norm
    else:
        median_normal = np.array([0.0, 1.0, 0.0])
    median_d = float(np.median(ds))

    # 3. Filter outliers based on deviation from median
    valid_indices = []
    for i, p in enumerate(aligned_planes):
        # Angle between normal and median normal
        dot = np.clip(np.dot(p.normal, median_normal), -1.0, 1.0)
        angle_deg = np.rad2deg(np.arccos(dot))
        # Distance deviation
        dist_diff = abs(p.d - median_d)
        # Thresholds for outlier rejection (hardcoded for now, could be config)
        if angle_deg < 15.0 and dist_diff < 0.5:
            valid_indices.append(i)

    if not valid_indices:
        # Fallback to all if everything is rejected (should be rare)
        valid_indices = list(range(n_planes))

    # 4. Weighted average of valid planes
    accum_normal = np.zeros(3, dtype=np.float64)
    accum_d = 0.0
    total_weight = 0.0
    for i in valid_indices:
        w = weights[i]
        p = aligned_planes[i]
        accum_normal += p.normal * w
        accum_d += p.d * w
        total_weight += w

    if total_weight <= 0:
        # Should not happen given checks above
        return FloorPlane(normal=median_normal, d=median_d)

    avg_normal = accum_normal / total_weight
    avg_d = accum_d / total_weight

    # Re-normalize normal (averaging unit vectors shrinks the norm).
    norm = np.linalg.norm(avg_normal)
    if norm > 1e-6:
        avg_normal /= norm
        avg_d /= norm
    else:
        avg_normal = np.array([0.0, 1.0, 0.0])
        avg_d = 0.0

    return FloorPlane(normal=avg_normal, d=float(avg_d))


def compute_floor_correction(
    current_floor_plane: FloorPlane,
    target_floor_y: float = 0.0,
    max_rotation_deg: float = 5.0,
    max_translation_m: float = 0.1,
    target_plane: Optional[FloorPlane] = None,
) -> FloorCorrection:
    """
    Compute the correction transform to align the current floor plane
    to the target floor height. Constrains correction to pitch/roll and
    vertical translation only.

    If target_plane is provided, aligns current plane to target_plane
    (relative correction). Otherwise, aligns to absolute Y=target_floor_y
    (absolute correction).
    """
    current_normal = current_floor_plane.normal
    current_d = current_floor_plane.d

    # Target normal is always [0, 1, 0] (Y-up)
    target_normal = np.array([0.0, 1.0, 0.0])

    if target_plane is not None:
        # Use target_plane.normal as the target normal
        align_target_normal = target_plane.normal
        # Ensure it points roughly up
        if align_target_normal[1] < 0:
            align_target_normal = -align_target_normal
    else:
        align_target_normal = target_normal

    # 1. Compute rotation to align normals
    try:
        R_align = rotation_align_vectors(current_normal, align_target_normal)
    except ValueError as e:
        return FloorCorrection(
            transform=np.eye(4), valid=False, reason=f"Rotation alignment failed: {e}"
        )

    # Check rotation magnitude.
    # Angle of rotation is acos((trace(R) - 1) / 2)
    trace = np.trace(R_align)
    # Clip to avoid numerical errors outside [-1, 1]
    cos_angle = np.clip((trace - 1) / 2, -1.0, 1.0)
    angle_rad = np.arccos(cos_angle)
    angle_deg = np.rad2deg(angle_rad)

    if angle_deg > max_rotation_deg:
        return FloorCorrection(
            transform=np.eye(4),
            valid=False,
            reason=f"Rotation {angle_deg:.1f} deg exceeds limit {max_rotation_deg:.1f} deg",
        )

    # 2. Compute translation
    if target_plane is not None:
        # Relative correction: align d to target_plane.d.
        # If normals are aligned (which we ensured with R_align and
        # align_target_normal), d values are comparable directly.
        # target_plane.d might correspond to a flipped normal, so use the d
        # corresponding to align_target_normal.
        target_d = target_plane.d
        if np.dot(target_plane.normal, align_target_normal) < 0:
            target_d = -target_d
        # current_d is relative to current_normal; after R_align (a rotation
        # about the origin) the distance to the origin is preserved, so we
        # can compare d values directly.
        t_mag = current_d - target_d
        trans_dir = align_target_normal
    else:
        # Absolute correction to target_y.
        # Plane is y + current_d = 0 (floor at y = -current_d), so a shift of
        # target_floor_y + current_d along +Y puts the floor at target_floor_y.
        t_mag = target_floor_y + current_d
        trans_dir = target_normal

    # Check translation magnitude
    if abs(t_mag) > max_translation_m:
        return FloorCorrection(
            transform=np.eye(4),
            valid=False,
            reason=f"Translation {t_mag:.3f} m exceeds limit {max_translation_m:.3f} m",
        )

    # Construct T: rotation first, then translation along the (already
    # aligned) target normal.
    T = np.eye(4)
    T[:3, :3] = R_align
    T[:3, 3] = trans_dir * t_mag

    return FloorCorrection(transform=T.astype(np.float64), valid=True)


def refine_ground_from_depth(
    camera_data: Dict[str, Dict[str, Any]],
    extrinsics: Dict[str, Mat44],
    config: Optional[GroundPlaneConfig] = None,
) -> Tuple[Dict[str, Mat44], GroundPlaneMetrics]:
    """
    Orchestrate ground plane refinement across multiple cameras.

    Args:
        camera_data: Dict mapping serial -> {'depth': np.ndarray, 'K': np.ndarray}
        extrinsics: Dict mapping serial -> world_from_cam matrix (4x4)
        config: Configuration parameters (a fresh default is used when None)

    Returns:
        Tuple of (new_extrinsics, metrics)
    """
    # FIX: the config default was previously `GroundPlaneConfig()` in the
    # signature — a mutable default instance shared across all calls.
    if config is None:
        config = GroundPlaneConfig()

    metrics = GroundPlaneMetrics()
    metrics.num_cameras_total = len(camera_data)

    if not config.enabled:
        metrics.message = "Ground plane refinement disabled in config"
        return extrinsics, metrics

    valid_planes: List[FloorPlane] = []
    valid_serials: List[str] = []

    # 1. Detect planes in each camera
    for serial, data in camera_data.items():
        if serial not in extrinsics:
            continue
        depth_map = data.get("depth")
        K = data.get("K")
        if depth_map is None or K is None:
            continue

        # Unproject to camera frame
        points_cam = unproject_depth_to_points(
            depth_map,
            K,
            stride=config.stride,
            depth_min=config.depth_min,
            depth_max=config.depth_max,
        )
        # min_inliers doubles as a minimum point-count gate here: fewer points
        # than that could never yield an acceptable plane.
        if len(points_cam) < config.min_inliers:
            continue

        # Transform to world frame
        T_world_cam = extrinsics[serial]
        R = T_world_cam[:3, :3]
        t = T_world_cam[:3, 3]
        points_world = (points_cam @ R.T) + t

        # Detect plane
        plane = detect_floor_plane(
            points_world,
            distance_threshold=config.ransac_dist_thresh,
            ransac_n=config.ransac_n,
            num_iterations=config.ransac_iters,
            seed=config.seed,
        )
        if plane is not None:
            # Check inlier count
            if plane.num_inliers < config.min_inliers:
                continue
            # Check inlier ratio if configured
            if config.min_inlier_ratio > 0:
                ratio = plane.num_inliers / len(points_world)
                if ratio < config.min_inlier_ratio:
                    continue
            # Check normal orientation (must be roughly vertical).
            # We expect floor normal to be roughly [0, 1, 0] or [0, -1, 0].
            if abs(plane.normal[1]) < config.normal_vertical_thresh:
                continue
            metrics.camera_planes[serial] = plane
            valid_planes.append(plane)
            valid_serials.append(serial)

    metrics.num_cameras_valid = len(valid_planes)

    # 2. Check minimum requirements
    if len(valid_planes) < config.min_valid_cameras:
        metrics.message = f"Found {len(valid_planes)} valid planes, required {config.min_valid_cameras}"
        return extrinsics, metrics

    # 3. Compute consensus (for reporting/metrics only)
    try:
        consensus = compute_consensus_plane(valid_planes)
        metrics.consensus_plane = consensus
    except ValueError as e:
        metrics.message = f"Consensus computation failed: {e}"
        return extrinsics, metrics

    # 4. Compute and apply per-camera correction
    new_extrinsics = extrinsics.copy()

    # Track max rotation/translation for summary metrics
    max_rot = 0.0
    max_trans = 0.0
    corrections_count = 0

    for serial, T_old in extrinsics.items():
        # If we didn't find a plane for this camera, skip it
        if serial not in metrics.camera_planes:
            metrics.skipped_cameras.append(serial)
            continue

        plane = metrics.camera_planes[serial]
        correction = compute_floor_correction(
            plane,
            target_floor_y=config.target_y,
            max_rotation_deg=config.max_rotation_deg,
            max_translation_m=config.max_translation_m,
            target_plane=metrics.consensus_plane,
        )
        if not correction.valid:
            metrics.skipped_cameras.append(serial)
            continue

        # Validate against consensus if available.
        # This prevents a single bad camera from getting a huge correction
        # even if it passed individual checks (e.g. it found a wall instead
        # of the floor).
        if metrics.consensus_plane is not None:
            # Angle check (handle flipped normals via |dot|)
            dot = np.clip(
                np.dot(plane.normal, metrics.consensus_plane.normal), -1.0, 1.0
            )
            if dot < 0:
                dot = -dot
            angle_deg = np.rad2deg(np.arccos(dot))
            if angle_deg > config.max_consensus_deviation_deg:
                metrics.skipped_cameras.append(serial)
                continue

            # Distance check: compare |d| values (assuming normals aligned)
            d_diff = abs(abs(plane.d) - abs(metrics.consensus_plane.d))
            if d_diff > config.max_consensus_deviation_m:
                metrics.skipped_cameras.append(serial)
                continue

        T_corr = correction.transform
        metrics.camera_corrections[serial] = T_corr
        # Apply correction: T_new = T_corr @ T_old
        new_extrinsics[serial] = T_corr @ T_old

        # Update summary metrics (rotation angle from the trace formula)
        trace = np.trace(T_corr[:3, :3])
        cos_angle = np.clip((trace - 1) / 2, -1.0, 1.0)
        rot_deg = float(np.rad2deg(np.arccos(cos_angle)))
        trans_m = float(np.linalg.norm(T_corr[:3, 3]))
        max_rot = max(max_rot, rot_deg)
        max_trans = max(max_trans, trans_m)
        corrections_count += 1

    metrics.rotation_deg = max_rot
    metrics.translation_m = max_trans
    metrics.success = True
    metrics.correction_applied = corrections_count > 0
    metrics.message = (
        f"Corrected {corrections_count} cameras, skipped {len(metrics.skipped_cameras)}"
    )

    return new_extrinsics, metrics


def create_ground_diagnostic_plot(
    metrics: GroundPlaneMetrics,
    camera_data: Dict[str, Dict[str, Any]],
    extrinsics_before: Dict[str, Mat44],
    extrinsics_after: Dict[str, Mat44],
) -> go.Figure:
    """
    Create a Plotly diagnostic visualization for ground plane refinement.
    """
    fig = go.Figure()

    # 1. Add World Origin Axes
    axis_scale = 0.5
    for axis, color, name in zip(
        [np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1])],
        ["red", "green", "blue"],
        ["X", "Y", "Z"],
    ):
        fig.add_trace(
            go.Scatter3d(
                x=[0, axis[0] * axis_scale],
                y=[0, axis[1] * axis_scale],
                z=[0, axis[2] * axis_scale],
                mode="lines",
                line=dict(color=color, width=4),
                name=f"World {name}",
                showlegend=True,
            )
        )

    # 2. Add Consensus Plane (if available)
    if metrics.consensus_plane:
        plane = metrics.consensus_plane
        # Create a surface for the plane
        size = 5.0
        x = np.linspace(-size, size, 2)
        z = np.linspace(-size, size, 2)
        xx, zz = np.meshgrid(x, z)
        # n.p + d = 0 => n0*x + n1*y + n2*z + d = 0 => y = -(n0*x + n2*z + d) / n1
        if abs(plane.normal[1]) > 1e-6:
            yy = (
                -(plane.normal[0] * xx + plane.normal[2] * zz + plane.d)
                / plane.normal[1]
            )
            fig.add_trace(
                go.Surface(
                    x=xx,
                    y=yy,
                    z=zz,
                    showscale=False,
                    opacity=0.3,
                    colorscale=[[0, "lightgray"], [1, "lightgray"]],
                    name="Consensus Plane",
                )
            )

    # 3. Add Floor Points per camera
    for serial, data in camera_data.items():
        if serial not in extrinsics_before:
            continue
        depth_map = data.get("depth")
        K = data.get("K")
        if depth_map is None or K is None:
            continue

        # Use a larger stride for visualization to keep it responsive
        viz_stride = 8
        points_cam = unproject_depth_to_points(depth_map, K, stride=viz_stride)
        if len(points_cam) == 0:
            continue

        # Transform to world frame (before)
        T_before = extrinsics_before[serial]
        R_b = T_before[:3, :3]
        t_b = T_before[:3, 3]
        points_world = (points_cam @ R_b.T) + t_b

        fig.add_trace(
            go.Scatter3d(
                x=points_world[:, 0],
                y=points_world[:, 1],
                z=points_world[:, 2],
                mode="markers",
                marker=dict(size=2, opacity=0.5),
                name=f"Points {serial}",
            )
        )

    # 4. Add Camera Positions Before/After
    for serial in extrinsics_before:
        T_b = extrinsics_before[serial]
        pos_b = T_b[:3, 3]
        fig.add_trace(
            go.Scatter3d(
                x=[pos_b[0]],
                y=[pos_b[1]],
                z=[pos_b[2]],
                mode="markers+text",
                marker=dict(size=5, color="red"),
                text=[f"{serial} (before)"],
                name=f"Cam {serial} (before)",
            )
        )
        if serial in extrinsics_after:
            T_a = extrinsics_after[serial]
            pos_a = T_a[:3, 3]
            fig.add_trace(
                go.Scatter3d(
                    x=[pos_a[0]],
                    y=[pos_a[1]],
                    z=[pos_a[2]],
                    mode="markers+text",
                    marker=dict(size=5, color="green"),
                    text=[f"{serial} (after)"],
                    name=f"Cam {serial} (after)",
                )
            )

    fig.update_layout(
        title="Ground Plane Refinement Diagnostics",
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z",
            aspectmode="data",
            camera=dict(
                up=dict(x=0, y=-1, z=0),  # Y-down convention for visualization
                eye=dict(x=1.5, y=-1.5, z=1.5),
            ),
        ),
        margin=dict(l=0, r=0, b=0, t=40),
    )
    return fig


def save_diagnostic_plot(fig: go.Figure, path: str) -> None:
    """
    Save the diagnostic plot to an HTML file.

    Creates the parent directory of ``path`` if it does not exist.
    """
    import os

    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
    fig.write_html(path)