refactor: things
This commit is contained in:
@@ -0,0 +1,502 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Compare two camera pose sets from different world frames using rigid alignment.
|
||||
Assumes both pose sets are in world_from_cam convention.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
import click
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
|
||||
|
||||
def parse_pose(pose_str: str, context: str = "") -> np.ndarray:
    """
    Parse 16 whitespace-separated floats into a validated 4x4 rigid transform.

    Raises ValueError (prefixed with *context*) when the value count is wrong,
    the last row is not [0, 0, 0, 1], or the 3x3 block is not a proper rotation.
    """
    values = [float(token) for token in pose_str.split()]
    if len(values) != 16:
        raise ValueError(f"[{context}] Expected 16 values for pose, got {len(values)}")
    pose = np.array(values).reshape((4, 4))

    # Homogeneous-transform sanity: bottom row must be [0, 0, 0, 1].
    bottom = pose[3, :]
    if not np.allclose(bottom, np.array([0, 0, 0, 1], dtype=float), atol=1e-5):
        raise ValueError(
            f"[{context}] Invalid last row in transformation matrix: {bottom}. "
            f"Expected [0, 0, 0, 1]"
        )

    # Rotation block must be orthonormal: R @ R.T == I within tolerance.
    rotation = pose[:3, :3]
    if not np.allclose(rotation @ rotation.T, np.eye(3), atol=1e-3):
        raise ValueError(
            f"[{context}] Rotation block is not orthonormal (R @ R.T != I)."
        )

    # det(R) == +1 rules out reflections and scaling.
    det = np.linalg.det(rotation)
    if not np.allclose(det, 1.0, atol=1e-3):
        raise ValueError(
            f"[{context}] Rotation block determinant is {det:.6f}, expected 1.0 (improper rotation or scaling)."
        )

    return pose
|
||||
|
||||
|
||||
def load_poses_from_json(path: str) -> dict[str, np.ndarray]:
    """
    Heuristically load camera poses from a JSON file.

    Two layouts are recognized, checked in this order per entry:
      1) nested Fusion: {"serial": {"FusionConfiguration": {"pose": "..."}}}
      2) flat: {"serial": {"pose": "..."}}

    Raises click.UsageError when no entry yields a parsable pose.
    """
    with open(path, "r") as f:
        data = json.load(f)

    poses: dict[str, np.ndarray] = {}
    for serial, entry in data.items():
        # Non-dict entries carry no pose information; skip them.
        if not isinstance(entry, dict):
            continue

        context = f"File: {path}, Serial: {serial}"

        # A dict-valued FusionConfiguration takes precedence over a flat
        # "pose" key — even when it lacks a "pose" of its own.
        fusion = entry.get("FusionConfiguration")
        if isinstance(fusion, dict):
            if "pose" in fusion:
                poses[str(serial)] = parse_pose(fusion["pose"], context=context)
        elif "pose" in entry:
            poses[str(serial)] = parse_pose(entry["pose"], context=context)

    if not poses:
        raise click.UsageError(
            f"No parsable poses found in {path}.\n"
            "Expected formats:\n"
            '  1) Flat: {"serial": {"pose": "..."}}\n'
            '  2) Nested: {"serial": {"FusionConfiguration": {"pose": "..."}}}'
        )
    return poses
|
||||
|
||||
|
||||
def serialize_pose(pose: np.ndarray) -> str:
    """Flatten a pose row-major into a space-separated string, 6 decimals per value."""
    return " ".join(format(value, ".6f") for value in pose.flatten())
|
||||
|
||||
|
||||
def rigid_transform_3d(A: np.ndarray, B: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute the rigid alignment (R, t) with R @ a + t ~= b for matched rows.

    A and B are corresponding (N, 3) point sets. Uses the Kabsch algorithm:
    SVD of the cross-covariance of the centered clouds.
    """
    assert A.shape == B.shape
    mean_a = A.mean(axis=0)
    mean_b = B.mean(axis=0)

    # Cross-covariance of the centered point clouds.
    centered_a = A - mean_a
    centered_b = B - mean_b
    cov = centered_a.T @ centered_b

    U, _, Vt = np.linalg.svd(cov)
    rotation = Vt.T @ U.T

    # det < 0 means SVD returned a reflection; flip the last row of Vt
    # (smallest singular value) to recover a proper rotation.
    if np.linalg.det(rotation) < 0:
        Vt[2, :] *= -1
        rotation = Vt.T @ U.T

    translation = mean_b - rotation @ mean_a
    return rotation, translation
|
||||
|
||||
|
||||
def get_camera_center(pose: np.ndarray) -> np.ndarray:
    """Return the camera center, i.e. the translation column of a world_from_cam pose."""
    translation = pose[:3, 3]
    return translation
|
||||
|
||||
|
||||
def get_camera_up(pose: np.ndarray) -> np.ndarray:
    """
    World-space "up" direction of the camera.

    OpenCV convention puts +Y down in the image, so up is the negated
    second column (Y axis) of the rotation block.
    """
    return -1.0 * pose[:3, 1]
|
||||
|
||||
|
||||
def rotation_error_deg(R1: np.ndarray, R2: np.ndarray) -> float:
    """Geodesic angle in degrees between two 3x3 rotation matrices."""
    relative = R1.T @ R2
    # trace(R) = 1 + 2*cos(theta); clip guards arccos against float noise.
    cos_angle = np.clip((np.trace(relative) - 1.0) / 2.0, -1.0, 1.0)
    return np.degrees(np.arccos(cos_angle))
|
||||
|
||||
|
||||
def angle_between_vectors_deg(v1: np.ndarray, v2: np.ndarray) -> float:
    """Unsigned angle in degrees between two 3D vectors (magnitude-independent)."""
    unit_a = v1 / np.linalg.norm(v1)
    unit_b = v2 / np.linalg.norm(v2)
    # Clip guards arccos against tiny float overshoot past +/-1.
    cos_angle = np.clip(np.dot(unit_a, unit_b), -1.0, 1.0)
    return np.degrees(np.arccos(cos_angle))
|
||||
|
||||
|
||||
def add_camera_trace(
    fig: go.Figure,
    pose: np.ndarray,
    label: str,
    scale: float = 0.2,
    frustum_scale: float = 0.5,
    fov_deg: float = 60.0,
    color: str = "blue",
):
    """
    Draw one camera on *fig*: a wireframe frustum, a labeled center marker,
    and RGB segments for the camera's local X/Y/Z axes.

    *pose* is world_from_cam in OpenCV convention (X right, Y down, Z forward).
    """
    rotation = pose[:3, :3]
    center = pose[:3, 3]

    # World-space camera axes: rotation applied to the local unit vectors.
    axes_world = [
        rotation @ np.array(unit)
        for unit in ([1, 0, 0], [0, 1, 0], [0, 0, 1])
    ]

    # Image-plane half extents at depth `frustum_scale` (4:3 aspect assumed).
    half_w = frustum_scale * np.tan(np.radians(fov_deg) / 2.0)
    half_h = half_w * 0.75

    corners_local = np.array(
        [
            [0, 0, 0],                          # apex at the camera center
            [-half_w, -half_h, frustum_scale],  # top-left
            [half_w, -half_h, frustum_scale],   # top-right
            [half_w, half_h, frustum_scale],    # bottom-right
            [-half_w, half_h, frustum_scale],   # bottom-left
        ]
    )
    # (R @ P.T).T == P @ R.T — transform all corners to world coordinates.
    corners_world = corners_local @ rotation.T + center

    # Four apex-to-corner rays, then the image-plane rectangle.
    edges = [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4), (4, 1)]
    xs: list = []
    ys: list = []
    zs: list = []
    for i, j in edges:
        xs += [corners_world[i, 0], corners_world[j, 0], None]
        ys += [corners_world[i, 1], corners_world[j, 1], None]
        zs += [corners_world[i, 2], corners_world[j, 2], None]

    fig.add_trace(
        go.Scatter3d(
            x=xs,
            y=ys,
            z=zs,
            mode="lines",
            line=dict(color=color, width=2),
            name=f"{label} Frustum",
            showlegend=False,
            hoverinfo="skip",
        )
    )

    # Labeled marker at the optical center.
    fig.add_trace(
        go.Scatter3d(
            x=[center[0]],
            y=[center[1]],
            z=[center[2]],
            mode="markers+text",
            marker=dict(size=4, color="black"),
            text=[label],
            textposition="top center",
            name=label,
            showlegend=True,
        )
    )

    # Local axes rendered as RGB = XYZ segments of length `scale`.
    for direction, axis_color in zip(axes_world, ["red", "green", "blue"]):
        tip = center + direction * scale
        fig.add_trace(
            go.Scatter3d(
                x=[center[0], tip[0]],
                y=[center[1], tip[1]],
                z=[center[2], tip[2]],
                mode="lines",
                line=dict(color=axis_color, width=3),
                showlegend=False,
                hoverinfo="skip",
            )
        )
|
||||
|
||||
|
||||
@click.command()
@click.option(
    "--pose-a-json",
    type=click.Path(exists=True),
    required=True,
    help="Pose set A. Supports flat {'serial': {'pose': '...'}} or nested FusionConfiguration format.",
)
@click.option(
    "--pose-b-json",
    type=click.Path(exists=True),
    required=True,
    help="Pose set B. Supports flat {'serial': {'pose': '...'}} or nested FusionConfiguration format.",
)
@click.option(
    "--report-json",
    type=click.Path(),
    required=True,
    help="Output path for comparison report",
)
@click.option(
    "--aligned-pose-b-json",
    type=click.Path(),
    help="Output path for aligned pose B set",
)
@click.option(
    "--plot-output",
    type=click.Path(),
    help="Output path for visualization (HTML or PNG)",
)
@click.option(
    "--show-plot",
    is_flag=True,
    default=False,
    help="Show the plot interactively",
)
@click.option(
    "--frustum-scale",
    type=float,
    default=0.3,
    help="Scale of the camera frustum",
)
@click.option(
    "--axis-scale",
    type=float,
    default=0.1,
    help="Scale of the camera axes",
)
def main(
    pose_a_json: str,
    pose_b_json: str,
    report_json: str,
    aligned_pose_b_json: str | None,
    plot_output: str | None,
    show_plot: bool,
    frustum_scale: float,
    axis_scale: float,
):
    """
    Compare two camera pose sets from different world frames using rigid alignment.
    Both are treated as T_world_from_cam.

    Supports symmetric, heuristic input parsing for both A and B:
    1) flat: {"serial": {"pose": "..."}}
    2) nested Fusion: {"serial": {"FusionConfiguration": {"pose": "..."}}}
    """
    poses_a = load_poses_from_json(pose_a_json)
    poses_b = load_poses_from_json(pose_b_json)

    # Serials are the join key; rigid alignment needs >= 3 correspondences.
    # (sorted() already returns a list — no need for an intermediate list().)
    shared_serials = sorted(set(poses_a) & set(poses_b))
    if len(shared_serials) < 3:
        click.echo(
            f"Error: Found only {len(shared_serials)} shared serials ({shared_serials}). Need at least 3.",
            err=True,
        )
        sys.exit(1)

    pts_b = np.array([get_camera_center(poses_b[s]) for s in shared_serials])
    pts_a = np.array([get_camera_center(poses_a[s]) for s in shared_serials])

    # Align B to A: R_align @ pts_b + t_align approx pts_a
    R_align, t_align = rigid_transform_3d(pts_b, pts_a)

    # Pack (R, t) into a homogeneous 4x4 so it composes with the pose matrices.
    T_align = np.eye(4)
    T_align[:3, :3] = R_align
    T_align[:3, 3] = t_align

    per_cam_results = []
    pos_errors = []
    rot_errors = []
    up_errors = []

    for s in shared_serials:
        T_b = poses_b[s]
        T_a = poses_a[s]

        # T_world_a_from_cam = T_world_a_from_world_b * T_world_b_from_cam
        T_b_aligned = T_align @ T_b

        pos_err = np.linalg.norm(
            get_camera_center(T_b_aligned) - get_camera_center(T_a)
        )
        rot_err = rotation_error_deg(T_b_aligned[:3, :3], T_a[:3, :3])

        # Up-vector agreement is a gravity-direction sanity check on top of
        # the full rotation error.
        up_b = get_camera_up(T_b_aligned)
        up_a = get_camera_up(T_a)
        up_err = angle_between_vectors_deg(up_b, up_a)

        per_cam_results.append(
            {
                "serial": s,
                "position_error_m": float(pos_err),
                "rotation_error_deg": float(rot_err),
                "up_consistency_error_deg": float(up_err),
            }
        )

        pos_errors.append(pos_err)
        rot_errors.append(rot_err)
        up_errors.append(up_err)

    report = {
        "shared_serials": shared_serials,
        "alignment": {
            "R_align": R_align.tolist(),
            "t_align": t_align.tolist(),
            "T_align": T_align.tolist(),
        },
        "per_camera": per_cam_results,
        "summary": {
            "mean_position_error_m": float(np.mean(pos_errors)),
            "max_position_error_m": float(np.max(pos_errors)),
            "mean_rotation_error_deg": float(np.mean(rot_errors)),
            "max_rotation_error_deg": float(np.max(rot_errors)),
            "mean_up_consistency_error_deg": float(np.mean(up_errors)),
            "max_up_consistency_error_deg": float(np.max(up_errors)),
        },
    }

    Path(report_json).parent.mkdir(parents=True, exist_ok=True)
    with open(report_json, "w") as f:
        json.dump(report, f, indent=4)
    click.echo(f"Report written to {report_json}")

    if aligned_pose_b_json:
        # All B poses (not just the shared ones) re-expressed in A's world frame.
        aligned_data = {}
        for s, T_b in poses_b.items():
            T_b_aligned = T_align @ T_b
            aligned_data[s] = {"pose": serialize_pose(T_b_aligned)}

        Path(aligned_pose_b_json).parent.mkdir(parents=True, exist_ok=True)
        with open(aligned_pose_b_json, "w") as f:
            json.dump(aligned_data, f, indent=4)
        click.echo(f"Aligned pose B set written to {aligned_pose_b_json}")

    if plot_output or show_plot:
        fig = go.Figure()

        # Debug toggle: draw the world-frame axes at the origin.
        show_axis: Final[bool] = True
        if show_axis:
            for axis, color in zip(
                [np.eye(3)[:, 0], np.eye(3)[:, 1], np.eye(3)[:, 2]],
                ["red", "green", "blue"],
            ):
                fig.add_trace(
                    go.Scatter3d(
                        x=[0, axis[0] * axis_scale],
                        y=[0, axis[1] * axis_scale],
                        z=[0, axis[2] * axis_scale],
                        mode="lines",
                        line=dict(color=color, width=4),
                        name=f"World {'XYZ'[np.argmax(axis)]}",
                        showlegend=True,
                    )
                )

        # Debug toggle: translucent ground plane at y=0 (off by default).
        show_ground: Final[bool] = False
        if show_ground:
            ground_size = 5.0
            half_size = ground_size / 2.0
            x_grid = np.linspace(-half_size, half_size, 2)
            z_grid = np.linspace(-half_size, half_size, 2)
            x_mesh, z_mesh = np.meshgrid(x_grid, z_grid)
            y_mesh = np.zeros_like(x_mesh)
            fig.add_trace(
                go.Surface(
                    x=x_mesh,
                    y=y_mesh,
                    z=z_mesh,
                    showscale=False,
                    opacity=0.1,
                    colorscale=[[0, "gray"], [1, "gray"]],
                    name="Ground Plane",
                    hoverinfo="skip",
                )
            )

        # Set A in blue, aligned set B in orange.
        for s in sorted(poses_a.keys()):
            add_camera_trace(
                fig,
                poses_a[s],
                f"a_{s}",
                scale=axis_scale,
                frustum_scale=frustum_scale,
                color="blue",
            )

        for s in sorted(poses_b.keys()):
            T_b_aligned = T_align @ poses_b[s]
            add_camera_trace(
                fig,
                T_b_aligned,
                f"b_{s}",
                scale=axis_scale,
                frustum_scale=frustum_scale,
                color="orange",
            )

        fig.update_layout(
            title="Pose A vs Aligned Pose B",
            scene=dict(
                xaxis_title="X (Right)",
                yaxis_title="Y (Down)",
                zaxis_title="Z (Forward)",
                aspectmode="data",
                camera=dict(
                    up=dict(x=0, y=-1, z=0),
                    eye=dict(x=1.5, y=-1.5, z=1.5),
                ),
            ),
            margin=dict(l=0, r=0, b=0, t=40),
        )

        if plot_output:
            if plot_output.endswith(".html"):
                fig.write_html(plot_output)
                click.echo(f"Plot saved to {plot_output}")
            else:
                try:
                    fig.write_image(plot_output)
                    click.echo(f"Plot saved to {plot_output}")
                except Exception as e:
                    click.echo(f"Error saving image (ensure kaleido is installed): {e}")
                    # We are in the non-.html branch, so the previous
                    # `if not plot_output.endswith(".html")` guard here was
                    # always true — fall back to HTML unconditionally.
                    html_out = str(Path(plot_output).with_suffix(".html"))
                    fig.write_html(html_out)
                    click.echo(f"Fallback: Plot saved to {html_out}")

        if show_plot:
            fig.show()
|
||||
|
||||
|
||||
# CLI entry point: click parses sys.argv itself, so the bare main() call is
# correct despite the function's declared parameters (hence the pylint disable).
if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
|
||||
Reference in New Issue
Block a user