feat: implement refine_ground_plane.py CLI

This commit is contained in:
2026-02-09 07:50:16 +00:00
parent 0f7d7a9a63
commit 9d9e95de81
4 changed files with 489 additions and 1 deletions
+289
View File
@@ -0,0 +1,289 @@
import click
import json
import numpy as np
from pathlib import Path
from typing import Dict, Any, Optional
from loguru import logger
import sys
from aruco.ground_plane import (
GroundPlaneConfig,
refine_ground_from_depth,
create_ground_diagnostic_plot,
save_diagnostic_plot,
GroundPlaneMetrics,
Mat44,
)
from aruco.depth_save import load_depth_data
def _format_pose(T: Mat44) -> str:
    """Serialize a transform matrix as space-separated values (row-major, 6 decimals)."""
    return " ".join(f"{x:.6f}" for x in T.flatten())


def _parse_extrinsics(extrinsics_data: Dict[str, Any]) -> Dict[str, Mat44]:
    """Extract per-camera 4x4 poses from the loaded extrinsics JSON.

    The ``_meta`` key, entries that are not JSON objects, and entries without a
    ``pose`` field are skipped.

    Raises:
        ValueError: if a pose string does not yield exactly 16 numbers.
    """
    extrinsics: Dict[str, Mat44] = {}
    for serial, entry in extrinsics_data.items():
        if serial == "_meta" or not isinstance(entry, dict) or "pose" not in entry:
            continue
        values = np.fromstring(entry["pose"], sep=" ")
        # Check the element count up front so the failure names the camera
        # instead of surfacing as an anonymous reshape error.
        if values.size != 16:
            raise ValueError(
                f"Camera {serial}: pose must contain 16 values, got {values.size}"
            )
        extrinsics[serial] = values.reshape(4, 4)
    return extrinsics


def _collect_depth_inputs(depth_data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Pick one depth map per camera and pair it with that camera's intrinsics.

    The pooled depth map is preferred; when it is absent the depth map of the
    first raw frame is used. Cameras without any depth map are left out.
    """
    camera_data: Dict[str, Dict[str, Any]] = {}
    for serial, data in depth_data.items():
        depth_map = data.get("pooled_depth")
        if depth_map is None:
            raw_frames = data.get("raw_frames", [])
            if raw_frames:
                depth_map = raw_frames[0].get("depth_map")
        if depth_map is not None:
            camera_data[serial] = {
                "depth": depth_map,
                "K": data["intrinsics"],
            }
    return camera_data


def _correction_magnitude(T_corr: Mat44) -> tuple[float, float]:
    """Return (rotation angle in degrees, translation norm) of a correction transform."""
    trace = np.trace(T_corr[:3, :3])
    # Clip guards arccos against values marginally outside [-1, 1] caused by
    # floating-point round-off.
    cos_angle = np.clip((trace - 1) / 2, -1.0, 1.0)
    rot_deg = float(np.rad2deg(np.arccos(cos_angle)))
    trans_m = float(np.linalg.norm(T_corr[:3, 3]))
    return rot_deg, trans_m


def _build_output(
    extrinsics_data: Dict[str, Any],
    new_extrinsics: Dict[str, Mat44],
    metrics: Any,
    config_record: Dict[str, Any],
) -> Dict[str, Any]:
    """Assemble the output extrinsics document without mutating the loaded input.

    Args:
        extrinsics_data: The JSON document loaded from the input extrinsics file.
        new_extrinsics: Refined pose per camera serial.
        metrics: Metrics object returned by ``refine_ground_from_depth``.
        config_record: Refinement settings to record under ``_meta``.

    Returns:
        A new document with updated ``pose`` strings, a per-camera
        ``ground_refine`` record, and a ``_meta.ground_refined`` summary.
    """
    # Copy each top-level entry one level deep: the entries are modified below
    # and a plain dict.copy() would write through to the caller's data.
    output_data: Dict[str, Any] = {
        key: dict(value) if isinstance(value, dict) else value
        for key, value in extrinsics_data.items()
    }

    for serial, T_new in new_extrinsics.items():
        entry = output_data.get(serial)
        if not isinstance(entry, dict):
            continue
        entry["pose"] = _format_pose(T_new)
        if serial in metrics.camera_corrections:
            rot_deg, trans_m = _correction_magnitude(
                metrics.camera_corrections[serial]
            )
            entry["ground_refine"] = {
                "corrected": True,
                "delta_rot_deg": rot_deg,
                "delta_trans_m": trans_m,
            }
        else:
            entry["ground_refine"] = {
                "corrected": False,
                "reason": "skipped_or_failed",
            }

    output_data.setdefault("_meta", {})["ground_refined"] = {
        "timestamp": str(np.datetime64("now")),
        "config": config_record,
        "metrics": {
            "success": metrics.success,
            "num_cameras_total": metrics.num_cameras_total,
            "num_cameras_valid": metrics.num_cameras_valid,
            "correction_applied": metrics.correction_applied,
            "max_rotation_deg": metrics.rotation_deg,
            "max_translation_m": metrics.translation_m,
        },
    }
    return output_data


def _write_json(path: str, payload: Dict[str, Any], sort_keys: bool = False) -> None:
    """Write ``payload`` to ``path`` as UTF-8 JSON with 4-space indentation."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=4, sort_keys=sort_keys)


@click.command()
@click.option(
    "--input-extrinsics",
    "-i",
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help="Input extrinsics JSON file.",
)
@click.option(
    "--input-depth",
    "-d",
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help="Input depth HDF5 file.",
)
@click.option(
    "--output-extrinsics",
    "-o",
    required=True,
    type=click.Path(dir_okay=False),
    help="Output extrinsics JSON file.",
)
@click.option(
    "--metrics-json",
    type=click.Path(dir_okay=False),
    help="Optional path to save metrics JSON.",
)
@click.option(
    "--plot/--no-plot",
    default=True,
    help="Generate diagnostic plot.",
)
@click.option(
    "--plot-output",
    type=click.Path(dir_okay=False),
    help="Path for diagnostic plot HTML (defaults to output_extrinsics_base.html).",
)
@click.option(
    "--max-rotation-deg",
    default=5.0,
    help="Maximum allowed rotation correction in degrees.",
)
@click.option(
    "--max-translation-m",
    default=0.1,
    help="Maximum allowed translation correction in meters.",
)
@click.option(
    "--ransac-threshold",
    default=0.02,
    help="RANSAC distance threshold in meters.",
)
@click.option(
    "--min-inlier-ratio",
    default=0.0,
    help="Minimum ratio of inliers to total points (0.0-1.0).",
)
@click.option(
    "--height-range",
    nargs=2,
    type=float,
    default=(0.2, 5.0),
    help="Min and max height (depth) range in meters.",
)
@click.option(
    "--stride",
    default=8,
    help="Pixel stride for depth sampling.",
)
@click.option(
    "--seed",
    type=int,
    help="Random seed for RANSAC determinism.",
)
@click.option(
    "--debug/--no-debug",
    default=False,
    help="Enable debug logging.",
)
def main(
    input_extrinsics: str,
    input_depth: str,
    output_extrinsics: str,
    metrics_json: Optional[str],
    plot: bool,
    plot_output: Optional[str],
    max_rotation_deg: float,
    max_translation_m: float,
    ransac_threshold: float,
    min_inlier_ratio: float,
    height_range: tuple[float, float],
    stride: int,
    seed: Optional[int],
    debug: bool,
) -> None:
    """
    Refine camera extrinsics by aligning the ground plane detected in depth maps.
    Loads existing extrinsics and depth data, detects the floor plane in each camera,
    and computes a correction transform to align the floor to Y=0.
    """
    # NOTE: the docstring above is what click prints for --help, so it is kept
    # user-facing; each parameter is documented by its option's help string.
    #
    # Any failure is logged and turned into exit status 1; with --debug the
    # exception is re-raised instead so the traceback is visible.

    # Configure logging
    logger.remove()
    logger.add(
        sys.stderr,
        level="DEBUG" if debug else "INFO",
        format="<green>{time:HH:mm:ss}</green> | <level>{message}</level>",
    )

    try:
        # 1. Load Extrinsics
        logger.info(f"Loading extrinsics from {input_extrinsics}")
        with open(input_extrinsics, "r", encoding="utf-8") as f:
            extrinsics_data = json.load(f)

        extrinsics = _parse_extrinsics(extrinsics_data)
        if not extrinsics:
            raise click.UsageError("No valid camera poses found in input extrinsics.")

        # 2. Load Depth Data
        logger.info(f"Loading depth data from {input_depth}")
        depth_data = load_depth_data(input_depth)

        camera_data_for_refine = _collect_depth_inputs(depth_data)
        if not camera_data_for_refine:
            raise click.UsageError("No depth maps found in input depth file.")

        # 3. Configure Refinement
        config = GroundPlaneConfig(
            enabled=True,
            target_y=0.0,
            stride=stride,
            depth_min=height_range[0],
            depth_max=height_range[1],
            ransac_dist_thresh=ransac_threshold,
            max_rotation_deg=max_rotation_deg,
            max_translation_m=max_translation_m,
            min_inlier_ratio=min_inlier_ratio,
            seed=seed,
        )

        # 4. Run Refinement
        logger.info("Running ground plane refinement...")
        new_extrinsics, metrics = refine_ground_from_depth(
            camera_data_for_refine, extrinsics, config
        )

        logger.info(f"Refinement result: {metrics.message}")
        if metrics.success:
            logger.info(f"Max rotation: {metrics.rotation_deg:.2f} deg")
            logger.info(f"Max translation: {metrics.translation_m:.3f} m")

        # 5. Save Output Extrinsics
        # Record every setting that influences the result so that a run can be
        # reproduced from the output file alone.
        config_record = {
            "max_rotation_deg": max_rotation_deg,
            "max_translation_m": max_translation_m,
            "ransac_threshold": ransac_threshold,
            "height_range": height_range,
            "min_inlier_ratio": min_inlier_ratio,
            "stride": stride,
            "seed": seed,
        }
        output_data = _build_output(
            extrinsics_data, new_extrinsics, metrics, config_record
        )

        logger.info(f"Saving refined extrinsics to {output_extrinsics}")
        _write_json(output_extrinsics, output_data, sort_keys=True)

        # 6. Save Metrics JSON (Optional)
        if metrics_json:
            metrics_data = {
                "success": metrics.success,
                "message": metrics.message,
                "num_cameras_total": metrics.num_cameras_total,
                "num_cameras_valid": metrics.num_cameras_valid,
                "skipped_cameras": metrics.skipped_cameras,
                "max_rotation_deg": metrics.rotation_deg,
                "max_translation_m": metrics.translation_m,
                "camera_corrections": {
                    s: _format_pose(T)
                    for s, T in metrics.camera_corrections.items()
                },
            }
            logger.info(f"Saving metrics to {metrics_json}")
            _write_json(metrics_json, metrics_data)

        # 7. Generate Plot (Optional)
        if plot:
            if not plot_output:
                # Default: next to the output extrinsics, with an .html suffix.
                plot_output = str(Path(output_extrinsics).with_suffix(".html"))

            logger.info(f"Generating diagnostic plot to {plot_output}")
            fig = create_ground_diagnostic_plot(
                metrics,
                camera_data_for_refine,
                extrinsics,  # before
                new_extrinsics,  # after
            )
            save_diagnostic_plot(fig, plot_output)

    except Exception as e:
        logger.error(f"Error: {e}")
        if debug:
            raise
        sys.exit(1)
# Invoke the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()