"""CLI for refining camera extrinsics by aligning the depth-detected ground plane."""

import copy
import json
import sys
from pathlib import Path
from typing import Any, Dict, Optional

import click
import numpy as np
from loguru import logger

from aruco.ground_plane import (
    GroundPlaneConfig,
    refine_ground_from_depth,
    create_ground_diagnostic_plot,
    save_diagnostic_plot,
    GroundPlaneMetrics,
    Mat44,
)
from aruco.depth_save import load_depth_data


def _parse_pose(serial: str, pose_str: Any) -> Mat44:
    """Parse a whitespace-separated string of 16 numbers into a 4x4 matrix.

    Args:
        serial: Camera identifier, used only to make error messages actionable.
        pose_str: The ``"pose"`` value read from the extrinsics JSON.

    Raises:
        ValueError: If the value is not a string of exactly 16 numeric tokens.
    """
    # Strict parsing: ``np.fromstring(..., sep=" ")`` only warns on malformed
    # text and returns truncated data, which hides the real cause.
    try:
        values = np.array(pose_str.split(), dtype=float)
    except (AttributeError, ValueError) as err:
        raise ValueError(
            f"Camera {serial}: pose is not a whitespace-separated list of "
            f"numbers: {pose_str!r}"
        ) from err
    if values.size != 16:
        raise ValueError(
            f"Camera {serial}: pose must contain 16 values, got {values.size}"
        )
    return values.reshape(4, 4)


def _load_extrinsics(extrinsics_data: Any) -> Dict[str, Mat44]:
    """Extract per-camera 4x4 poses from the decoded extrinsics JSON.

    The ``"_meta"`` entry and entries without a ``"pose"`` key are skipped.

    Raises:
        ValueError: If the top-level JSON value is not an object, or a pose
            string is malformed.
    """
    if not isinstance(extrinsics_data, dict):
        raise ValueError("Extrinsics JSON must be an object keyed by camera serial.")

    extrinsics: Dict[str, Mat44] = {}
    for serial, data in extrinsics_data.items():
        if serial == "_meta":
            continue
        # Non-dict entries would make ``"pose" in data`` a substring test or
        # a TypeError, so they are ignored explicitly.
        if not isinstance(data, dict) or "pose" not in data:
            continue
        extrinsics[serial] = _parse_pose(serial, data["pose"])
    return extrinsics


def _collect_camera_data(depth_data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Pick one depth map per camera and pair it with that camera's intrinsics.

    The ``"pooled_depth"`` entry is preferred; when it is ``None`` the
    ``"depth_map"`` of the first entry in ``"raw_frames"`` is used. Cameras
    with no depth map at all are omitted from the result.
    """
    camera_data: Dict[str, Dict[str, Any]] = {}
    for serial, data in depth_data.items():
        depth_map = data.get("pooled_depth")
        if depth_map is None:
            raw_frames = data.get("raw_frames", [])
            if raw_frames:
                depth_map = raw_frames[0].get("depth_map")
        if depth_map is not None:
            camera_data[serial] = {
                "depth": depth_map,
                "K": data["intrinsics"],
            }
    return camera_data


def _format_pose(T: Mat44) -> str:
    """Serialise a matrix as space-separated values with six decimals."""
    return " ".join(f"{x:.6f}" for x in T.flatten())


def _correction_magnitude(T_corr: Mat44) -> "tuple[float, float]":
    """Return ``(rotation_deg, translation_m)`` of a 4x4 correction transform."""
    # Rotation angle from the trace of the 3x3 block; the clip guards
    # ``arccos`` against values marginally outside [-1, 1].
    cos_angle = np.clip((np.trace(T_corr[:3, :3]) - 1) / 2, -1.0, 1.0)
    rot_deg = float(np.rad2deg(np.arccos(cos_angle)))
    trans_m = float(np.linalg.norm(T_corr[:3, 3]))
    return rot_deg, trans_m


def _build_output_data(
    extrinsics_data: Dict[str, Any],
    new_extrinsics: Dict[str, Mat44],
    metrics: GroundPlaneMetrics,
) -> Dict[str, Any]:
    """Return a copy of the input JSON with refined poses and per-camera status.

    Only cameras present in both ``new_extrinsics`` and the input JSON are
    updated; each receives a ``"ground_refine"`` entry.
    """
    # Deep copy so the nested per-camera dicts of the loaded input are not
    # mutated while the output is being assembled.
    output_data = copy.deepcopy(extrinsics_data)
    for serial, T_new in new_extrinsics.items():
        if serial not in output_data:
            continue
        entry = output_data[serial]
        entry["pose"] = _format_pose(T_new)
        if serial in metrics.camera_corrections:
            rot_deg, trans_m = _correction_magnitude(
                metrics.camera_corrections[serial]
            )
            entry["ground_refine"] = {
                "corrected": True,
                "delta_rot_deg": rot_deg,
                "delta_trans_m": trans_m,
            }
        else:
            entry["ground_refine"] = {
                "corrected": False,
                "reason": "skipped_or_failed",
            }
    return output_data


@click.command()
@click.option(
    "--input-extrinsics",
    "-i",
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help="Input extrinsics JSON file.",
)
@click.option(
    "--input-depth",
    "-d",
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help="Input depth HDF5 file.",
)
@click.option(
    "--output-extrinsics",
    "-o",
    required=True,
    type=click.Path(dir_okay=False),
    help="Output extrinsics JSON file.",
)
@click.option(
    "--metrics-json",
    type=click.Path(dir_okay=False),
    help="Optional path to save metrics JSON.",
)
@click.option(
    "--plot/--no-plot",
    default=True,
    help="Generate diagnostic plot.",
)
@click.option(
    "--plot-output",
    type=click.Path(dir_okay=False),
    help="Path for diagnostic plot HTML (defaults to output_extrinsics_base.html).",
)
@click.option(
    "--max-rotation-deg",
    default=5.0,
    help="Maximum allowed rotation correction in degrees.",
)
@click.option(
    "--max-translation-m",
    default=0.1,
    help="Maximum allowed translation correction in meters.",
)
@click.option(
    "--ransac-threshold",
    default=0.02,
    help="RANSAC distance threshold in meters.",
)
@click.option(
    "--min-inlier-ratio",
    default=0.0,
    help="Minimum ratio of inliers to total points (0.0-1.0).",
)
@click.option(
    "--height-range",
    nargs=2,
    type=float,
    default=(0.2, 5.0),
    help="Min and max height (depth) range in meters.",
)
@click.option(
    "--stride",
    default=8,
    help="Pixel stride for depth sampling.",
)
@click.option(
    "--seed",
    type=int,
    help="Random seed for RANSAC determinism.",
)
@click.option(
    "--debug/--no-debug",
    default=False,
    help="Enable debug logging.",
)
def main(
    input_extrinsics: str,
    input_depth: str,
    output_extrinsics: str,
    metrics_json: Optional[str],
    plot: bool,
    plot_output: Optional[str],
    max_rotation_deg: float,
    max_translation_m: float,
    ransac_threshold: float,
    min_inlier_ratio: float,
    height_range: tuple[float, float],
    stride: int,
    seed: Optional[int],
    debug: bool,
):
    """
    Refine camera extrinsics by aligning the ground plane detected in depth maps.

    Loads existing extrinsics and depth data, detects the floor plane in each camera,
    and computes a correction transform to align the floor to Y=0.
    """
    # NOTE: the docstring above doubles as the ``--help`` text, so parameter
    # documentation lives in the option ``help=`` strings instead.
    #
    # On any failure the error is logged and the process exits with status 1;
    # with ``--debug`` the original exception is re-raised instead.

    # Configure logging
    logger.remove()
    logger.add(
        sys.stderr,
        level="DEBUG" if debug else "INFO",
        format="{time:HH:mm:ss} | {message}",
    )

    try:
        # 1. Load Extrinsics
        logger.info(f"Loading extrinsics from {input_extrinsics}")
        with open(input_extrinsics, "r", encoding="utf-8") as f:
            extrinsics_data = json.load(f)

        extrinsics = _load_extrinsics(extrinsics_data)
        if not extrinsics:
            raise click.UsageError("No valid camera poses found in input extrinsics.")

        # 2. Load Depth Data
        logger.info(f"Loading depth data from {input_depth}")
        depth_data = load_depth_data(input_depth)

        camera_data_for_refine = _collect_camera_data(depth_data)
        if not camera_data_for_refine:
            raise click.UsageError("No depth maps found in input depth file.")

        # 3. Configure Refinement
        config = GroundPlaneConfig(
            enabled=True,
            target_y=0.0,
            stride=stride,
            depth_min=height_range[0],
            depth_max=height_range[1],
            ransac_dist_thresh=ransac_threshold,
            max_rotation_deg=max_rotation_deg,
            max_translation_m=max_translation_m,
            min_inlier_ratio=min_inlier_ratio,
            seed=seed,
        )

        # 4. Run Refinement
        logger.info("Running ground plane refinement...")
        new_extrinsics, metrics = refine_ground_from_depth(
            camera_data_for_refine, extrinsics, config
        )

        logger.info(f"Refinement result: {metrics.message}")
        if metrics.success:
            logger.info(f"Max rotation: {metrics.rotation_deg:.2f} deg")
            logger.info(f"Max translation: {metrics.translation_m:.3f} m")

        # 5. Save Output Extrinsics
        output_data = _build_output_data(extrinsics_data, new_extrinsics, metrics)

        if "_meta" not in output_data:
            output_data["_meta"] = {}
        output_data["_meta"]["ground_refined"] = {
            "timestamp": str(np.datetime64("now")),
            "config": {
                "max_rotation_deg": max_rotation_deg,
                "max_translation_m": max_translation_m,
                "ransac_threshold": ransac_threshold,
                "height_range": height_range,
            },
            "metrics": {
                "success": metrics.success,
                "num_cameras_total": metrics.num_cameras_total,
                "num_cameras_valid": metrics.num_cameras_valid,
                "correction_applied": metrics.correction_applied,
                "max_rotation_deg": metrics.rotation_deg,
                "max_translation_m": metrics.translation_m,
            },
        }

        logger.info(f"Saving refined extrinsics to {output_extrinsics}")
        with open(output_extrinsics, "w", encoding="utf-8") as f:
            json.dump(output_data, f, indent=4, sort_keys=True)

        # 6. Save Metrics JSON (Optional)
        if metrics_json:
            metrics_data = {
                "success": metrics.success,
                "message": metrics.message,
                "num_cameras_total": metrics.num_cameras_total,
                "num_cameras_valid": metrics.num_cameras_valid,
                "skipped_cameras": metrics.skipped_cameras,
                "max_rotation_deg": metrics.rotation_deg,
                "max_translation_m": metrics.translation_m,
                "camera_corrections": {
                    s: _format_pose(T)
                    for s, T in metrics.camera_corrections.items()
                },
            }
            logger.info(f"Saving metrics to {metrics_json}")
            with open(metrics_json, "w", encoding="utf-8") as f:
                json.dump(metrics_data, f, indent=4)

        # 7. Generate Plot (Optional)
        if plot:
            if not plot_output:
                # Default: same path as the output JSON with an .html suffix.
                plot_output = str(Path(output_extrinsics).with_suffix(".html"))

            logger.info(f"Generating diagnostic plot to {plot_output}")
            fig = create_ground_diagnostic_plot(
                metrics,
                camera_data_for_refine,
                extrinsics,  # before
                new_extrinsics,  # after
            )
            save_diagnostic_plot(fig, plot_output)

    except Exception as e:
        # Top-level CLI boundary: report concisely, keep the traceback
        # available via --debug.
        logger.error(f"Error: {e}")
        if debug:
            raise
        sys.exit(1)


if __name__ == "__main__":
    main()