Files
zed-playground/py_workspace/refine_ground_plane.py
T

290 lines
9.0 KiB
Python

import click
import json
import numpy as np
from pathlib import Path
from typing import Dict, Any, Optional
from loguru import logger
import sys
from aruco.ground_plane import (
GroundPlaneConfig,
refine_ground_from_depth,
create_ground_diagnostic_plot,
save_diagnostic_plot,
GroundPlaneMetrics,
Mat44,
)
from aruco.depth_save import load_depth_data
# ---------------------------------------------------------------------------
# Private helpers used by the `main` CLI entry point below.
# ---------------------------------------------------------------------------


def _parse_pose(serial: str, pose_str: str) -> Mat44:
    """Parse a whitespace-separated pose string into a 4x4 float matrix.

    Args:
        serial: Camera identifier, used only to make error messages actionable.
        pose_str: Text holding exactly 16 numeric tokens.

    Returns:
        The 16 values reshaped to (4, 4), in the order they appear in the text.

    Raises:
        ValueError: If a token is not numeric or the token count is not 16.
    """
    # Parse token by token so a malformed value is a hard error instead of a
    # silently truncated array.
    try:
        values = [float(token) for token in pose_str.split()]
    except ValueError as err:
        raise ValueError(
            f"Camera {serial}: pose contains a non-numeric value ({err})"
        ) from err
    if len(values) != 16:
        raise ValueError(
            f"Camera {serial}: expected 16 pose values, got {len(values)}"
        )
    return np.array(values, dtype=np.float64).reshape(4, 4)


def _format_pose(T: Mat44) -> str:
    """Serialize a matrix as space-separated values with six decimals."""
    return " ".join(f"{x:.6f}" for x in T.flatten())


def _parse_extrinsics(extrinsics_data: Dict[str, Any]) -> Dict[str, Mat44]:
    """Extract one 4x4 pose per camera from the loaded extrinsics JSON.

    The "_meta" entry and any entry that is not a mapping with a "pose" key
    are skipped.
    """
    extrinsics: Dict[str, Mat44] = {}
    for serial, data in extrinsics_data.items():
        if serial == "_meta":
            continue
        if isinstance(data, dict) and "pose" in data:
            extrinsics[serial] = _parse_pose(serial, data["pose"])
    return extrinsics


def _select_depth_maps(depth_data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Pick one depth map per camera and pair it with that camera's intrinsics.

    The pooled depth map is preferred; when it is absent the depth map of the
    first raw frame is used. Cameras with neither are left out of the result.
    """
    camera_data: Dict[str, Dict[str, Any]] = {}
    for serial, data in depth_data.items():
        depth_map = data.get("pooled_depth")
        if depth_map is None:
            raw_frames = data.get("raw_frames", [])
            if raw_frames:
                depth_map = raw_frames[0].get("depth_map")
        if depth_map is not None:
            camera_data[serial] = {
                "depth": depth_map,
                "K": data["intrinsics"],
            }
    return camera_data


def _correction_magnitude(T_corr: Mat44) -> tuple[float, float]:
    """Return (rotation angle in degrees, translation norm) of a correction.

    The angle is recovered from the trace of the upper-left 3x3 block; the
    cosine is clipped to [-1, 1] so rounding error cannot push arccos out of
    its domain.
    """
    trace = np.trace(T_corr[:3, :3])
    cos_angle = np.clip((trace - 1) / 2, -1.0, 1.0)
    rot_deg = float(np.rad2deg(np.arccos(cos_angle)))
    trans_m = float(np.linalg.norm(T_corr[:3, 3]))
    return rot_deg, trans_m


def _build_output_data(
    extrinsics_data: Dict[str, Any],
    new_extrinsics: Dict[str, Mat44],
    metrics: GroundPlaneMetrics,
    config_summary: Dict[str, Any],
) -> Dict[str, Any]:
    """Assemble the output extrinsics document.

    Every camera present in both ``new_extrinsics`` and the input document
    gets its "pose" rewritten and a "ground_refine" record added; a
    "ground_refined" summary is stored under "_meta".
    """
    # Copy each entry one level deep so the loaded input document is not
    # modified when "pose", "ground_refine" and "_meta" are written below.
    output_data: Dict[str, Any] = {
        key: dict(value) if isinstance(value, dict) else value
        for key, value in extrinsics_data.items()
    }

    for serial, T_new in new_extrinsics.items():
        if serial not in output_data:
            continue
        entry = output_data[serial]
        entry["pose"] = _format_pose(T_new)
        if serial in metrics.camera_corrections:
            rot_deg, trans_m = _correction_magnitude(
                metrics.camera_corrections[serial]
            )
            entry["ground_refine"] = {
                "corrected": True,
                "delta_rot_deg": rot_deg,
                "delta_trans_m": trans_m,
            }
        else:
            entry["ground_refine"] = {
                "corrected": False,
                "reason": "skipped_or_failed",
            }

    meta = output_data.setdefault("_meta", {})
    meta["ground_refined"] = {
        "timestamp": str(np.datetime64("now")),
        "config": config_summary,
        "metrics": {
            "success": metrics.success,
            "num_cameras_total": metrics.num_cameras_total,
            "num_cameras_valid": metrics.num_cameras_valid,
            "correction_applied": metrics.correction_applied,
            "max_rotation_deg": metrics.rotation_deg,
            "max_translation_m": metrics.translation_m,
        },
    }
    return output_data


def _build_metrics_data(metrics: GroundPlaneMetrics) -> Dict[str, Any]:
    """Flatten the refinement metrics into a JSON-friendly dictionary."""
    return {
        "success": metrics.success,
        "message": metrics.message,
        "num_cameras_total": metrics.num_cameras_total,
        "num_cameras_valid": metrics.num_cameras_valid,
        "skipped_cameras": metrics.skipped_cameras,
        "max_rotation_deg": metrics.rotation_deg,
        "max_translation_m": metrics.translation_m,
        "camera_corrections": {
            serial: _format_pose(T_corr)
            for serial, T_corr in metrics.camera_corrections.items()
        },
    }


# NOTE: the docstring of `main` is rendered by click as the `--help` text, so
# developer-facing documentation lives in comments rather than the docstring.
#
# Flow: load extrinsics JSON -> load depth data -> build GroundPlaneConfig ->
# run refine_ground_from_depth -> write refined extrinsics JSON -> optionally
# write a metrics JSON and an HTML diagnostic plot.
#
# Any exception (including the click.UsageError raised for empty inputs) is
# logged and turned into exit status 1; with --debug it is re-raised instead.
@click.command()
@click.option(
    "--input-extrinsics",
    "-i",
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help="Input extrinsics JSON file.",
)
@click.option(
    "--input-depth",
    "-d",
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help="Input depth HDF5 file.",
)
@click.option(
    "--output-extrinsics",
    "-o",
    required=True,
    type=click.Path(dir_okay=False),
    help="Output extrinsics JSON file.",
)
@click.option(
    "--metrics-json",
    type=click.Path(dir_okay=False),
    help="Optional path to save metrics JSON.",
)
@click.option(
    "--plot/--no-plot",
    default=True,
    help="Generate diagnostic plot.",
)
@click.option(
    "--plot-output",
    type=click.Path(dir_okay=False),
    help="Path for diagnostic plot HTML (defaults to output_extrinsics_base.html).",
)
@click.option(
    "--max-rotation-deg",
    default=5.0,
    help="Maximum allowed rotation correction in degrees.",
)
@click.option(
    "--max-translation-m",
    default=0.1,
    help="Maximum allowed translation correction in meters.",
)
@click.option(
    "--ransac-threshold",
    default=0.02,
    help="RANSAC distance threshold in meters.",
)
@click.option(
    "--min-inlier-ratio",
    default=0.0,
    help="Minimum ratio of inliers to total points (0.0-1.0).",
)
@click.option(
    "--height-range",
    nargs=2,
    type=float,
    default=(0.2, 5.0),
    help="Min and max height (depth) range in meters.",
)
@click.option(
    "--stride",
    default=8,
    help="Pixel stride for depth sampling.",
)
@click.option(
    "--seed",
    type=int,
    help="Random seed for RANSAC determinism.",
)
@click.option(
    "--debug/--no-debug",
    default=False,
    help="Enable debug logging.",
)
def main(
    input_extrinsics: str,
    input_depth: str,
    output_extrinsics: str,
    metrics_json: Optional[str],
    plot: bool,
    plot_output: Optional[str],
    max_rotation_deg: float,
    max_translation_m: float,
    ransac_threshold: float,
    min_inlier_ratio: float,
    height_range: tuple[float, float],
    stride: int,
    seed: Optional[int],
    debug: bool,
):
    """
    Refine camera extrinsics by aligning the ground plane detected in depth maps.
    Loads existing extrinsics and depth data, detects the floor plane in each camera,
    and computes a correction transform to align the floor to Y=0.
    """
    # Configure logging
    logger.remove()
    logger.add(
        sys.stderr,
        level="DEBUG" if debug else "INFO",
        format="<green>{time:HH:mm:ss}</green> | <level>{message}</level>",
    )

    try:
        # 1. Load Extrinsics
        logger.info(f"Loading extrinsics from {input_extrinsics}")
        with open(input_extrinsics, "r", encoding="utf-8") as f:
            extrinsics_data = json.load(f)

        extrinsics = _parse_extrinsics(extrinsics_data)
        if not extrinsics:
            raise click.UsageError("No valid camera poses found in input extrinsics.")

        # 2. Load Depth Data
        logger.info(f"Loading depth data from {input_depth}")
        depth_data = load_depth_data(input_depth)

        camera_data_for_refine = _select_depth_maps(depth_data)
        if not camera_data_for_refine:
            raise click.UsageError("No depth maps found in input depth file.")

        # 3. Configure Refinement
        config = GroundPlaneConfig(
            enabled=True,
            target_y=0.0,
            stride=stride,
            depth_min=height_range[0],
            depth_max=height_range[1],
            ransac_dist_thresh=ransac_threshold,
            max_rotation_deg=max_rotation_deg,
            max_translation_m=max_translation_m,
            min_inlier_ratio=min_inlier_ratio,
            seed=seed,
        )

        # 4. Run Refinement
        logger.info("Running ground plane refinement...")
        new_extrinsics, metrics = refine_ground_from_depth(
            camera_data_for_refine, extrinsics, config
        )
        logger.info(f"Refinement result: {metrics.message}")
        if metrics.success:
            logger.info(f"Max rotation: {metrics.rotation_deg:.2f} deg")
            logger.info(f"Max translation: {metrics.translation_m:.3f} m")

        # 5. Save Output Extrinsics (written whether or not refinement
        # succeeded; the outcome is recorded under "_meta").
        output_data = _build_output_data(
            extrinsics_data,
            new_extrinsics,
            metrics,
            {
                "max_rotation_deg": max_rotation_deg,
                "max_translation_m": max_translation_m,
                "ransac_threshold": ransac_threshold,
                "height_range": height_range,
            },
        )
        logger.info(f"Saving refined extrinsics to {output_extrinsics}")
        with open(output_extrinsics, "w", encoding="utf-8") as f:
            json.dump(output_data, f, indent=4, sort_keys=True)

        # 6. Save Metrics JSON (Optional)
        if metrics_json:
            metrics_data = _build_metrics_data(metrics)
            logger.info(f"Saving metrics to {metrics_json}")
            with open(metrics_json, "w", encoding="utf-8") as f:
                json.dump(metrics_data, f, indent=4)

        # 7. Generate Plot (Optional)
        if plot:
            if not plot_output:
                # Default: same path as the output extrinsics, ".html" suffix.
                plot_output = str(Path(output_extrinsics).with_suffix(".html"))
            logger.info(f"Generating diagnostic plot to {plot_output}")
            fig = create_ground_diagnostic_plot(
                metrics,
                camera_data_for_refine,
                extrinsics,  # before
                new_extrinsics,  # after
            )
            save_diagnostic_plot(fig, plot_output)

    except Exception as e:
        # Top-level boundary: report the failure concisely; keep the full
        # traceback available when debugging.
        logger.error(f"Error: {e}")
        if debug:
            raise
        sys.exit(1)
# Script entry point: invoke the click command only when this file is run
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()