feat: implement refine_ground_plane.py CLI
This commit is contained in:
@@ -0,0 +1,289 @@
|
||||
import click
|
||||
import json
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
import sys
|
||||
|
||||
from aruco.ground_plane import (
|
||||
GroundPlaneConfig,
|
||||
refine_ground_from_depth,
|
||||
create_ground_diagnostic_plot,
|
||||
save_diagnostic_plot,
|
||||
GroundPlaneMetrics,
|
||||
Mat44,
|
||||
)
|
||||
from aruco.depth_save import load_depth_data
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
"--input-extrinsics",
|
||||
"-i",
|
||||
required=True,
|
||||
type=click.Path(exists=True, dir_okay=False),
|
||||
help="Input extrinsics JSON file.",
|
||||
)
|
||||
@click.option(
|
||||
"--input-depth",
|
||||
"-d",
|
||||
required=True,
|
||||
type=click.Path(exists=True, dir_okay=False),
|
||||
help="Input depth HDF5 file.",
|
||||
)
|
||||
@click.option(
|
||||
"--output-extrinsics",
|
||||
"-o",
|
||||
required=True,
|
||||
type=click.Path(dir_okay=False),
|
||||
help="Output extrinsics JSON file.",
|
||||
)
|
||||
@click.option(
|
||||
"--metrics-json",
|
||||
type=click.Path(dir_okay=False),
|
||||
help="Optional path to save metrics JSON.",
|
||||
)
|
||||
@click.option(
|
||||
"--plot/--no-plot",
|
||||
default=True,
|
||||
help="Generate diagnostic plot.",
|
||||
)
|
||||
@click.option(
|
||||
"--plot-output",
|
||||
type=click.Path(dir_okay=False),
|
||||
help="Path for diagnostic plot HTML (defaults to output_extrinsics_base.html).",
|
||||
)
|
||||
@click.option(
|
||||
"--max-rotation-deg",
|
||||
default=5.0,
|
||||
help="Maximum allowed rotation correction in degrees.",
|
||||
)
|
||||
@click.option(
|
||||
"--max-translation-m",
|
||||
default=0.1,
|
||||
help="Maximum allowed translation correction in meters.",
|
||||
)
|
||||
@click.option(
|
||||
"--ransac-threshold",
|
||||
default=0.02,
|
||||
help="RANSAC distance threshold in meters.",
|
||||
)
|
||||
@click.option(
|
||||
"--min-inlier-ratio",
|
||||
default=0.0,
|
||||
help="Minimum ratio of inliers to total points (0.0-1.0).",
|
||||
)
|
||||
@click.option(
|
||||
"--height-range",
|
||||
nargs=2,
|
||||
type=float,
|
||||
default=(0.2, 5.0),
|
||||
help="Min and max height (depth) range in meters.",
|
||||
)
|
||||
@click.option(
|
||||
"--stride",
|
||||
default=8,
|
||||
help="Pixel stride for depth sampling.",
|
||||
)
|
||||
@click.option(
|
||||
"--seed",
|
||||
type=int,
|
||||
help="Random seed for RANSAC determinism.",
|
||||
)
|
||||
@click.option(
|
||||
"--debug/--no-debug",
|
||||
default=False,
|
||||
help="Enable debug logging.",
|
||||
)
|
||||
def main(
|
||||
input_extrinsics: str,
|
||||
input_depth: str,
|
||||
output_extrinsics: str,
|
||||
metrics_json: Optional[str],
|
||||
plot: bool,
|
||||
plot_output: Optional[str],
|
||||
max_rotation_deg: float,
|
||||
max_translation_m: float,
|
||||
ransac_threshold: float,
|
||||
min_inlier_ratio: float,
|
||||
height_range: tuple[float, float],
|
||||
stride: int,
|
||||
seed: Optional[int],
|
||||
debug: bool,
|
||||
):
|
||||
"""
|
||||
Refine camera extrinsics by aligning the ground plane detected in depth maps.
|
||||
|
||||
Loads existing extrinsics and depth data, detects the floor plane in each camera,
|
||||
and computes a correction transform to align the floor to Y=0.
|
||||
"""
|
||||
# Configure logging
|
||||
logger.remove()
|
||||
logger.add(
|
||||
sys.stderr,
|
||||
level="DEBUG" if debug else "INFO",
|
||||
format="<green>{time:HH:mm:ss}</green> | <level>{message}</level>",
|
||||
)
|
||||
|
||||
try:
|
||||
# 1. Load Extrinsics
|
||||
logger.info(f"Loading extrinsics from {input_extrinsics}")
|
||||
with open(input_extrinsics, "r") as f:
|
||||
extrinsics_data = json.load(f)
|
||||
|
||||
# Parse extrinsics into Dict[str, Mat44]
|
||||
extrinsics: Dict[str, Mat44] = {}
|
||||
for serial, data in extrinsics_data.items():
|
||||
if serial == "_meta":
|
||||
continue
|
||||
if "pose" in data:
|
||||
pose_str = data["pose"]
|
||||
T = np.fromstring(pose_str, sep=" ").reshape(4, 4)
|
||||
extrinsics[serial] = T
|
||||
|
||||
if not extrinsics:
|
||||
raise click.UsageError("No valid camera poses found in input extrinsics.")
|
||||
|
||||
# 2. Load Depth Data
|
||||
logger.info(f"Loading depth data from {input_depth}")
|
||||
depth_data = load_depth_data(input_depth)
|
||||
|
||||
# Prepare camera data for refinement
|
||||
camera_data_for_refine: Dict[str, Dict[str, Any]] = {}
|
||||
for serial, data in depth_data.items():
|
||||
# Use pooled depth if available, otherwise check raw frames
|
||||
depth_map = data.get("pooled_depth")
|
||||
if depth_map is None:
|
||||
# Fallback to first raw frame if available
|
||||
raw_frames = data.get("raw_frames", [])
|
||||
if raw_frames:
|
||||
depth_map = raw_frames[0].get("depth_map")
|
||||
|
||||
if depth_map is not None:
|
||||
camera_data_for_refine[serial] = {
|
||||
"depth": depth_map,
|
||||
"K": data["intrinsics"],
|
||||
}
|
||||
|
||||
if not camera_data_for_refine:
|
||||
raise click.UsageError("No depth maps found in input depth file.")
|
||||
|
||||
# 3. Configure Refinement
|
||||
config = GroundPlaneConfig(
|
||||
enabled=True,
|
||||
target_y=0.0,
|
||||
stride=stride,
|
||||
depth_min=height_range[0],
|
||||
depth_max=height_range[1],
|
||||
ransac_dist_thresh=ransac_threshold,
|
||||
max_rotation_deg=max_rotation_deg,
|
||||
max_translation_m=max_translation_m,
|
||||
min_inlier_ratio=min_inlier_ratio,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
# 4. Run Refinement
|
||||
logger.info("Running ground plane refinement...")
|
||||
new_extrinsics, metrics = refine_ground_from_depth(
|
||||
camera_data_for_refine, extrinsics, config
|
||||
)
|
||||
|
||||
logger.info(f"Refinement result: {metrics.message}")
|
||||
if metrics.success:
|
||||
logger.info(f"Max rotation: {metrics.rotation_deg:.2f} deg")
|
||||
logger.info(f"Max translation: {metrics.translation_m:.3f} m")
|
||||
|
||||
# 5. Save Output Extrinsics
|
||||
output_data = extrinsics_data.copy()
|
||||
|
||||
for serial, T_new in new_extrinsics.items():
|
||||
if serial in output_data:
|
||||
pose_str = " ".join(f"{x:.6f}" for x in T_new.flatten())
|
||||
output_data[serial]["pose"] = pose_str
|
||||
|
||||
if serial in metrics.camera_corrections:
|
||||
T_corr = metrics.camera_corrections[serial]
|
||||
trace = np.trace(T_corr[:3, :3])
|
||||
cos_angle = np.clip((trace - 1) / 2, -1.0, 1.0)
|
||||
rot_deg = float(np.rad2deg(np.arccos(cos_angle)))
|
||||
trans_m = float(np.linalg.norm(T_corr[:3, 3]))
|
||||
|
||||
output_data[serial]["ground_refine"] = {
|
||||
"corrected": True,
|
||||
"delta_rot_deg": rot_deg,
|
||||
"delta_trans_m": trans_m,
|
||||
}
|
||||
else:
|
||||
output_data[serial]["ground_refine"] = {
|
||||
"corrected": False,
|
||||
"reason": "skipped_or_failed",
|
||||
}
|
||||
|
||||
if "_meta" not in output_data:
|
||||
output_data["_meta"] = {}
|
||||
|
||||
output_data["_meta"]["ground_refined"] = {
|
||||
"timestamp": str(np.datetime64("now")),
|
||||
"config": {
|
||||
"max_rotation_deg": max_rotation_deg,
|
||||
"max_translation_m": max_translation_m,
|
||||
"ransac_threshold": ransac_threshold,
|
||||
"height_range": height_range,
|
||||
},
|
||||
"metrics": {
|
||||
"success": metrics.success,
|
||||
"num_cameras_total": metrics.num_cameras_total,
|
||||
"num_cameras_valid": metrics.num_cameras_valid,
|
||||
"correction_applied": metrics.correction_applied,
|
||||
"max_rotation_deg": metrics.rotation_deg,
|
||||
"max_translation_m": metrics.translation_m,
|
||||
},
|
||||
}
|
||||
|
||||
logger.info(f"Saving refined extrinsics to {output_extrinsics}")
|
||||
with open(output_extrinsics, "w") as f:
|
||||
json.dump(output_data, f, indent=4, sort_keys=True)
|
||||
|
||||
# 6. Save Metrics JSON (Optional)
|
||||
if metrics_json:
|
||||
metrics_data = {
|
||||
"success": metrics.success,
|
||||
"message": metrics.message,
|
||||
"num_cameras_total": metrics.num_cameras_total,
|
||||
"num_cameras_valid": metrics.num_cameras_valid,
|
||||
"skipped_cameras": metrics.skipped_cameras,
|
||||
"max_rotation_deg": metrics.rotation_deg,
|
||||
"max_translation_m": metrics.translation_m,
|
||||
"camera_corrections": {
|
||||
s: " ".join(f"{x:.6f}" for x in T.flatten())
|
||||
for s, T in metrics.camera_corrections.items()
|
||||
},
|
||||
}
|
||||
logger.info(f"Saving metrics to {metrics_json}")
|
||||
with open(metrics_json, "w") as f:
|
||||
json.dump(metrics_data, f, indent=4)
|
||||
|
||||
# 7. Generate Plot (Optional)
|
||||
if plot:
|
||||
if not plot_output:
|
||||
plot_output = str(Path(output_extrinsics).with_suffix(".html"))
|
||||
|
||||
logger.info(f"Generating diagnostic plot to {plot_output}")
|
||||
fig = create_ground_diagnostic_plot(
|
||||
metrics,
|
||||
camera_data_for_refine,
|
||||
extrinsics, # before
|
||||
new_extrinsics, # after
|
||||
)
|
||||
save_diagnostic_plot(fig, plot_output)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}")
|
||||
if debug:
|
||||
raise
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user