import copy
import json
import sys
from pathlib import Path
from typing import Dict, Any, Optional

import click
import numpy as np
from loguru import logger

from aruco.ground_plane import (
    GroundPlaneConfig,
    refine_ground_from_depth,
    create_ground_diagnostic_plot,
    save_diagnostic_plot,
    Mat44,
)
from aruco.depth_save import load_depth_data
from aruco.icp_registration import refine_with_icp, ICPConfig
def _parse_pose(pose_str: str) -> np.ndarray:
    """Parse a whitespace-separated, row-major 4x4 pose string into a matrix.

    Raises:
        ValueError: if a token is not a number or the string does not hold
            exactly 16 values.
    """
    # np.fromstring(..., sep=" ") silently stops at the first unparsable token
    # (only emitting a DeprecationWarning), which then surfaces as a cryptic
    # reshape error. Parse strictly instead.
    values = np.array(pose_str.split(), dtype=np.float64)
    if values.size != 16:
        raise ValueError(f"expected 16 values in pose string, got {values.size}")
    return values.reshape(4, 4)


def _load_extrinsics(path: str) -> tuple[Dict[str, Any], Dict[str, Mat44]]:
    """Read an extrinsics JSON file.

    Returns:
        The raw JSON object (kept so untouched fields can be written back), and
        a mapping from camera serial to its 4x4 pose for every top-level entry
        other than ``_meta`` that has a ``"pose"`` field.

    Raises:
        ValueError: if a camera's pose string cannot be parsed.
    """
    with open(path, "r", encoding="utf-8") as f:
        raw = json.load(f)
    poses: Dict[str, Mat44] = {}
    for serial, entry in raw.items():
        if serial == "_meta" or "pose" not in entry:
            continue
        try:
            poses[serial] = _parse_pose(entry["pose"])
        except ValueError as err:
            raise ValueError(f"Invalid pose for camera {serial}: {err}") from err
    return raw, poses


def _collect_camera_data(depth_data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Pick one depth map (plus intrinsics) per camera for refinement.

    Prefers ``pooled_depth``; otherwise falls back to the ``depth_map`` of the
    first entry in ``raw_frames``. Cameras with neither are omitted.
    """
    camera_data: Dict[str, Dict[str, Any]] = {}
    for serial, data in depth_data.items():
        depth_map = data.get("pooled_depth")
        if depth_map is None:
            raw_frames = data.get("raw_frames", [])
            if raw_frames:
                depth_map = raw_frames[0].get("depth_map")
        if depth_map is not None:
            camera_data[serial] = {
                "depth": depth_map,
                "K": data["intrinsics"],
            }
    return camera_data


def _correction_diagnostic(T_corr: np.ndarray) -> Dict[str, Any]:
    """Summarize a 4x4 correction as its rotation angle (deg) and translation norm (m)."""
    # Axis-angle identity: trace(R) = 1 + 2*cos(theta). Clip so round-off
    # cannot push the cosine outside arccos's domain.
    cos_angle = np.clip((np.trace(T_corr[:3, :3]) - 1) / 2, -1.0, 1.0)
    return {
        "corrected": True,
        "delta_rot_deg": float(np.rad2deg(np.arccos(cos_angle))),
        "delta_trans_m": float(np.linalg.norm(T_corr[:3, 3])),
    }


@click.command()
@click.option(
    "--input-extrinsics",
    "-i",
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help="Input extrinsics JSON file.",
)
@click.option(
    "--input-depth",
    "-d",
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help="Input depth HDF5 file.",
)
@click.option(
    "--output-extrinsics",
    "-o",
    required=True,
    type=click.Path(dir_okay=False),
    help="Output extrinsics JSON file.",
)
@click.option(
    "--plot/--no-plot",
    default=True,
    help="Generate diagnostic plot.",
)
@click.option(
    "--plot-output",
    type=click.Path(dir_okay=False),
    help="Path for diagnostic plot HTML (defaults to output_extrinsics_base.html).",
)
@click.option(
    "--max-rotation-deg",
    default=5.0,
    help="Maximum allowed rotation correction in degrees.",
)
@click.option(
    "--max-translation-m",
    default=0.1,
    help="Maximum allowed translation correction in meters.",
)
@click.option(
    "--ransac-threshold",
    default=0.02,
    help="RANSAC distance threshold in meters.",
)
@click.option(
    "--min-inlier-ratio",
    default=0.0,
    type=click.FloatRange(0.0, 1.0),
    help="Minimum ratio of inliers to total points (0.0-1.0).",
)
@click.option(
    "--height-range",
    nargs=2,
    type=float,
    default=(0.2, 5.0),
    help="Min and max height (depth) range in meters.",
)
@click.option(
    "--stride",
    default=8,
    type=click.IntRange(min=1),
    help="Pixel stride for depth sampling.",
)
@click.option(
    "--seed",
    type=int,
    help="Random seed for RANSAC determinism.",
)
@click.option(
    "--icp/--no-icp",
    default=False,
    help="Enable ICP refinement after ground plane alignment.",
)
@click.option(
    "--icp-method",
    type=click.Choice(["point_to_plane", "gicp"]),
    default="point_to_plane",
    help="ICP registration method.",
)
@click.option(
    "--icp-voxel-size",
    type=float,
    default=0.02,
    help="Voxel size for ICP downsampling (meters).",
)
@click.option(
    "--debug/--no-debug",
    default=False,
    help="Enable debug logging.",
)
def main(
    input_extrinsics: str,
    input_depth: str,
    output_extrinsics: str,
    plot: bool,
    plot_output: Optional[str],
    max_rotation_deg: float,
    max_translation_m: float,
    ransac_threshold: float,
    min_inlier_ratio: float,
    height_range: tuple[float, float],
    stride: int,
    seed: Optional[int],
    icp: bool,
    icp_method: str,
    icp_voxel_size: float,
    debug: bool,
) -> None:
    """
    Refine camera extrinsics by aligning the ground plane detected in depth maps.

    Loads existing extrinsics and depth data, detects the floor plane in each camera,
    and computes a correction transform to align the floor to Y=0.
    """
    # Configure logging
    logger.remove()
    logger.add(
        sys.stderr,
        level="DEBUG" if debug else "INFO",
        format="{time:HH:mm:ss} | {message}",
    )
    try:
        if height_range[0] >= height_range[1]:
            raise click.BadParameter(
                "min must be strictly less than max.",
                param_hint="--height-range",
            )

        # 1. Load Extrinsics
        logger.info(f"Loading extrinsics from {input_extrinsics}")
        extrinsics_data, extrinsics = _load_extrinsics(input_extrinsics)
        if not extrinsics:
            raise click.UsageError("No valid camera poses found in input extrinsics.")

        # 2. Load Depth Data
        logger.info(f"Loading depth data from {input_depth}")
        depth_data = load_depth_data(input_depth)
        camera_data_for_refine = _collect_camera_data(depth_data)
        if not camera_data_for_refine:
            raise click.UsageError("No depth maps found in input depth file.")

        # 3. Configure Refinement
        config = GroundPlaneConfig(
            enabled=True,
            target_y=0.0,
            stride=stride,
            depth_min=height_range[0],
            depth_max=height_range[1],
            ransac_dist_thresh=ransac_threshold,
            max_rotation_deg=max_rotation_deg,
            max_translation_m=max_translation_m,
            min_inlier_ratio=min_inlier_ratio,
            seed=seed,
        )

        # 4. Run Refinement
        logger.info("Running ground plane refinement...")
        new_extrinsics, metrics = refine_ground_from_depth(
            camera_data_for_refine, extrinsics, config
        )
        logger.info(f"Refinement result: {metrics.message}")
        if metrics.success:
            logger.info(f"Max rotation: {metrics.rotation_deg:.2f} deg")
            logger.info(f"Max translation: {metrics.translation_m:.3f} m")

        # 4.5 Optional ICP Refinement
        icp_metrics = None
        if icp:
            logger.info(f"Running ICP refinement ({icp_method})...")
            icp_config = ICPConfig(
                method=icp_method,
                voxel_size=icp_voxel_size,
            )
            icp_extrinsics, icp_metrics = refine_with_icp(
                camera_data_for_refine,
                new_extrinsics,
                metrics.camera_planes,
                icp_config,
            )
            if icp_metrics.success:
                logger.info(f"ICP refinement successful: {icp_metrics.message}")
                new_extrinsics = icp_extrinsics
            else:
                logger.warning(
                    f"ICP refinement failed or skipped: {icp_metrics.message}"
                )

        # 5. Save Output Extrinsics
        # Deep copy: a shallow copy would share the per-camera and _meta dicts
        # with the loaded input, so writing poses/metadata would mutate it too.
        output_data = copy.deepcopy(extrinsics_data)
        per_camera_diagnostics: Dict[str, Dict[str, Any]] = {}
        for serial, T_new in new_extrinsics.items():
            if serial not in output_data:
                continue
            output_data[serial]["pose"] = " ".join(f"{x:.6f}" for x in T_new.flatten())
            # Deltas describe the ground-plane correction only; the ICP stage
            # is summarized separately under _meta["icp_refined"].
            if serial in metrics.camera_corrections:
                per_camera_diagnostics[serial] = _correction_diagnostic(
                    metrics.camera_corrections[serial]
                )
            else:
                per_camera_diagnostics[serial] = {
                    "corrected": False,
                    "reason": "skipped_or_failed",
                }

        meta = output_data.setdefault("_meta", {})
        meta["ground_refined"] = {
            "timestamp": str(np.datetime64("now")),
            "config": {
                "max_rotation_deg": max_rotation_deg,
                "max_translation_m": max_translation_m,
                "ransac_threshold": ransac_threshold,
                "height_range": height_range,
                # Also affect the result; recorded for reproducibility.
                "min_inlier_ratio": min_inlier_ratio,
                "stride": stride,
                "seed": seed,
            },
            "metrics": {
                "success": metrics.success,
                "num_cameras_total": metrics.num_cameras_total,
                "num_cameras_valid": metrics.num_cameras_valid,
                "correction_applied": metrics.correction_applied,
                "max_rotation_deg": metrics.rotation_deg,
                "max_translation_m": metrics.translation_m,
            },
            "per_camera": per_camera_diagnostics,
        }
        if icp_metrics is not None:
            meta["icp_refined"] = {
                "timestamp": str(np.datetime64("now")),
                "config": {
                    "method": icp_method,
                    "voxel_size": icp_voxel_size,
                },
                "metrics": {
                    "success": icp_metrics.success,
                    "num_pairs_attempted": icp_metrics.num_pairs_attempted,
                    "num_pairs_converged": icp_metrics.num_pairs_converged,
                    "num_cameras_optimized": icp_metrics.num_cameras_optimized,
                    "num_disconnected": icp_metrics.num_disconnected,
                    "message": icp_metrics.message,
                },
            }

        logger.info(f"Saving refined extrinsics to {output_extrinsics}")
        with open(output_extrinsics, "w", encoding="utf-8") as f:
            json.dump(output_data, f, indent=4, sort_keys=True)

        # 6. Generate Plot (Optional)
        if plot:
            if not plot_output:
                plot_output = str(Path(output_extrinsics).with_suffix(".html"))
            logger.info(f"Generating diagnostic plot to {plot_output}")
            fig = create_ground_diagnostic_plot(
                metrics,
                camera_data_for_refine,
                extrinsics,  # before
                new_extrinsics,  # after
            )
            save_diagnostic_plot(fig, plot_output)
    except click.ClickException:
        # Let click render usage/parameter errors itself (message + exit code 2)
        # rather than flattening them into a generic error with exit code 1.
        raise
    except Exception as e:
        logger.error(f"Error: {e}")
        if debug:
            raise
        sys.exit(1)
if __name__ == "__main__":
    # Click fills every parameter of main() from sys.argv, so pylint's
    # missing-argument warning is a false positive on this call.
    main()  # pylint: disable=no-value-for-parameter