"""
Utility script to visualize camera extrinsics from a JSON file.

The input JSON must be an object mapping camera identifiers to objects with a
``"pose"`` entry: 16 whitespace-separated floats read row-major into a 4x4
matrix. Entries that are not objects, or that lack ``"pose"``, are skipped.
"""

import argparse
import json
import sys
from typing import Any, Iterable

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

# Imported for its side effect: on older matplotlib versions this registers
# the "3d" projection used by main().
from mpl_toolkits.mplot3d import Axes3D  # type: ignore  # noqa: F401

# Colors for the camera's local X, Y and Z axes (also used by the legend).
_AXIS_COLORS = ("red", "green", "blue")

# Margin added around the camera centers on every plotted axis, in the same
# units as the pose translations.
_LIMIT_PADDING = 0.5


def parse_pose(pose_str: str) -> np.ndarray:
    """Parses a 16-float pose string into a 4x4 matrix.

    The values are whitespace-separated and read in row-major order.

    Raises:
        ValueError: If a token is not a float or there are not exactly 16 values.
    """
    try:
        vals = [float(x) for x in pose_str.split()]
    except (AttributeError, ValueError) as e:
        raise ValueError(f"Failed to parse pose string: {e}") from e
    if len(vals) != 16:
        raise ValueError(
            f"Failed to parse pose string: Expected 16 values, got {len(vals)}"
        )
    return np.array(vals).reshape((4, 4))


def _center_and_axes(pose: np.ndarray, convention: str) -> tuple[np.ndarray, np.ndarray]:
    """Return (camera center, 3x3 matrix whose columns are the camera axes) in world coordinates."""
    R = pose[:3, :3]
    t = pose[:3, 3]
    if convention == "cam_from_world":
        # Camera center in world coordinates: C = -R^T * t.
        # Camera orientation in world coordinates: R_world_from_cam = R^T,
        # whose columns are the axes.
        return -R.T @ t, R.T
    # world_from_cam (also the fallback for any other value): the translation
    # is the center and the columns of R are the axes.
    return t, R


def _xz_spread(poses: Iterable[np.ndarray], convention: str) -> float:
    """Area of the X-Z bounding box of the camera centers under `convention`."""
    centers = np.array([_center_and_axes(p, convention)[0] for p in poses])
    dx = centers[:, 0].max() - centers[:, 0].min()
    dz = centers[:, 2].max() - centers[:, 2].min()
    return float(dx * dz)


def plot_camera(
    ax: Any,
    pose: np.ndarray,
    label: str,
    scale: float = 0.2,
    birdseye: bool = False,
    convention: str = "world_from_cam",
) -> None:
    """
    Plots a camera center and its orientation axes.

    X=red, Y=green, Z=blue (right-handed convention)
    World convention: Y-up (vertical), X-Z (ground plane)

    Args:
        ax: A 2D matplotlib axes when ``birdseye`` is true, otherwise a 3D axes.
        pose: 4x4 pose matrix; only its top 3x4 part is read.
        label: Text drawn next to the camera center.
        scale: Drawn length of each unit axis. In bird's-eye mode it is
            inverted for ``quiver``'s ``scale`` argument, so it must be non-zero.
        birdseye: Project onto the X-Z plane instead of drawing in 3D.
        convention: ``"cam_from_world"``; any other value is treated as
            ``"world_from_cam"``.
    """
    center, axes = _center_and_axes(pose, convention)

    if birdseye:
        # Bird-eye view: X-Z plane (looking down +Y). With scale_units="xy",
        # quiver draws a vector of magnitude |v| as |v| / scale data units,
        # so passing 1 / scale yields arrows of length |v| * scale.
        ax.scatter(center[0], center[2], color="black", s=20)
        ax.text(center[0], center[2], label, fontsize=9)
        for col, color in enumerate(_AXIS_COLORS):
            direction = axes[:, col]
            ax.quiver(
                center[0],
                center[2],
                direction[0],
                direction[2],
                color=color,
                scale=1 / scale,
                scale_units="xy",
                angles="xy",
            )
    else:
        ax.scatter(center[0], center[1], center[2], color="black", s=20)
        ax.text(center[0], center[1], center[2], label, fontsize=9)
        for col, color in enumerate(_AXIS_COLORS):
            direction = axes[:, col]
            ax.quiver(
                center[0],
                center[1],
                center[2],
                direction[0],
                direction[1],
                direction[2],
                length=scale,
                color=color,
            )


def main() -> int:
    """Command-line entry point.

    Returns:
        Process exit status: 0 on success, 1 if the input cannot be read, is
        not a JSON object, or contains no valid camera poses. Invalid
        arguments make argparse exit with status 2.
    """
    parser = argparse.ArgumentParser(
        description="Visualize camera extrinsics from JSON."
    )
    parser.add_argument("--input", "-i", required=True, help="Path to input JSON file.")
    parser.add_argument(
        "--output", "-o", help="Path to save the output visualization (PNG)."
    )
    parser.add_argument(
        "--show", action="store_true", help="Show the plot interactively."
    )
    parser.add_argument(
        "--scale", type=float, default=0.2, help="Scale of the camera axes."
    )
    parser.add_argument(
        "--birdseye",
        action="store_true",
        help="Show a top-down bird-eye view (X-Z plane in Y-up convention).",
    )
    parser.add_argument(
        "--pose-convention",
        choices=["auto", "world_from_cam", "cam_from_world"],
        default="auto",
        help="Interpretation of the pose matrix in JSON. 'auto' selects based on plausible spread.",
    )
    args = parser.parse_args()

    # Written as `not (x > 0)` so that NaN is rejected too. A zero scale would
    # otherwise divide by zero in bird's-eye mode.
    if not (args.scale > 0):
        parser.error("--scale must be a positive number.")

    try:
        with open(args.input, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON and UTF-8 decode errors
        print(f"Error reading input file: {e}", file=sys.stderr)
        return 1

    if not isinstance(data, dict):
        print(
            "Error reading input file: expected a JSON object mapping camera ids to camera data.",
            file=sys.stderr,
        )
        return 1

    # First pass: parse all poses, skipping malformed entries.
    poses: dict[Any, np.ndarray] = {}
    for serial, cam_data in data.items():
        if not isinstance(cam_data, dict) or "pose" not in cam_data:
            continue
        try:
            poses[serial] = parse_pose(str(cam_data["pose"]))
        except ValueError as e:
            print(f"Warning: Skipping camera {serial} due to error: {e}", file=sys.stderr)

    if not poses:
        print("No valid camera poses found in the input file.", file=sys.stderr)
        return 1

    # Determine convention.
    convention = args.pose_convention
    if convention == "auto":
        # Try both and see which one gives a larger X-Z spread. Ties (e.g. a
        # single camera) fall back to world_from_cam.
        s1 = _xz_spread(poses.values(), "world_from_cam")
        s2 = _xz_spread(poses.values(), "cam_from_world")
        convention = "world_from_cam" if s1 >= s2 else "cam_from_world"
        print(
            f"Auto-selected pose convention: {convention} (spreads: {s1:.2f} vs {s2:.2f})"
        )

    fig = plt.figure(figsize=(10, 8))
    if args.birdseye:
        ax = fig.add_subplot(111)
    else:
        ax = fig.add_subplot(111, projection="3d")

    camera_centers: list[np.ndarray] = []
    for serial, pose in poses.items():
        plot_camera(
            ax,
            pose,
            str(serial),
            scale=float(args.scale),
            birdseye=bool(args.birdseye),
            convention=convention,
        )
        camera_centers.append(_center_and_axes(pose, convention)[0])

    # Use a single half-extent for every axis so all axes share a common
    # range around the centroid of the bounding box.
    centers = np.array(camera_centers)
    lows = centers.min(axis=0)
    highs = centers.max(axis=0)
    mid_x, mid_y, mid_z = (float(v) for v in (lows + highs) * 0.5)
    half = float((highs - lows).max() / 2.0) + _LIMIT_PADDING

    if args.birdseye:
        ax.set_xlim(mid_x - half, mid_x + half)
        ax.set_ylim(mid_z - half, mid_z + half)
        ax.set_xlabel("X (m)")
        ax.set_ylabel("Z (m)")
        ax.set_aspect("equal")
        ax.set_title(f"Camera Extrinsics (Bird-eye, {convention}): {args.input}")
        ax.grid(True)
    else:
        # We know ax is a 3D axis here.
        ax_3d: Any = ax
        ax_3d.set_xlim(mid_x - half, mid_x + half)
        ax_3d.set_ylim(mid_y - half, mid_y + half)
        ax_3d.set_zlim(mid_z - half, mid_z + half)
        # Equal limits only give an undistorted view if the box itself is a
        # cube, which is not matplotlib's default 3D box aspect.
        # set_box_aspect is missing on older matplotlib, hence the guard.
        if hasattr(ax_3d, "set_box_aspect"):
            ax_3d.set_box_aspect((1, 1, 1))
        # Draw world Y vertically to match the Y-up convention. Best effort:
        # older matplotlib versions lack `vertical_axis` and keep Z vertical.
        try:
            ax_3d.view_init(vertical_axis="y")
        except TypeError:
            pass
        ax_3d.set_xlabel("X (m)")
        ax_3d.set_ylabel("Y (Up) (m)")
        ax_3d.set_zlabel("Z (m)")
        ax_3d.set_title(f"Camera Extrinsics ({convention}): {args.input}")

    legend_elements = [
        Line2D([0], [0], color=color, lw=2, label=name)
        for name, color in zip(("X", "Y", "Z"), _AXIS_COLORS)
    ]
    ax.legend(handles=legend_elements, loc="upper right")

    if args.output:
        plt.savefig(str(args.output))
        print(f"Visualization saved to {args.output}")

    if args.show:
        plt.show()
    elif not args.output:
        print(
            "No output path specified and --show not passed. Plot not saved or shown."
        )
    return 0


if __name__ == "__main__":
    sys.exit(main())