feat: add explicit 4x4 transformation matrix validation to compare_pose_sets.py

This commit is contained in:
2026-02-09 03:24:36 +00:00
parent d6c7829b1e
commit c497af7783
3 changed files with 165 additions and 39 deletions
+98 -39
View File
@@ -7,17 +7,85 @@ Assumes both pose sets are in world_from_cam convention.
import json
import sys
from pathlib import Path
from typing import Final
import click
import numpy as np
import plotly.graph_objects as go
def parse_pose(pose_str: str, context: str = "") -> np.ndarray:
    """Parse and validate a rigid 4x4 transformation matrix from text.

    The string must hold 16 whitespace-separated numbers in row-major order.
    The parsed matrix is accepted only if every entry is finite, its last row
    is ``[0, 0, 0, 1]``, and its upper-left 3x3 block is a proper rotation
    (orthonormal with determinant +1).

    Args:
        pose_str: Whitespace-separated matrix entries.
        context: Optional label (e.g. file and serial) prefixed to every
            error message so the offending entry can be located.

    Returns:
        The pose as a ``(4, 4)`` float array.

    Raises:
        ValueError: If a token is not numeric, the value count is not 16, or
            any of the validation checks fails.
    """
    try:
        vals = [float(x) for x in pose_str.split()]
    except ValueError as err:
        # Re-raise with the context label; the bare float() message does not
        # say which file/serial the bad token came from.
        raise ValueError(
            f"[{context}] Pose contains a non-numeric value: {err}"
        ) from err

    if len(vals) != 16:
        raise ValueError(f"[{context}] Expected 16 values for pose, got {len(vals)}")

    pose = np.array(vals).reshape((4, 4))

    # 0. Finiteness. float() happily parses "nan"/"inf", and the checks below
    #    never inspect the translation column, so reject those values here.
    if not np.all(np.isfinite(pose)):
        raise ValueError(
            f"[{context}] Pose contains non-finite values (NaN or inf)."
        )

    # 1. Last row must be [0, 0, 0, 1].
    last_row = pose[3, :]
    expected_last_row = np.array([0, 0, 0, 1], dtype=float)
    if not np.allclose(last_row, expected_last_row, atol=1e-5):
        raise ValueError(
            f"[{context}] Invalid last row in transformation matrix: {last_row}. "
            f"Expected [0, 0, 0, 1]"
        )

    # 2. Rotation block must be orthonormal: R @ R.T approx I.
    R = pose[:3, :3]
    if not np.allclose(R @ R.T, np.eye(3), atol=1e-3):
        raise ValueError(
            f"[{context}] Rotation block is not orthonormal (R @ R.T != I)."
        )

    # 3. det(R) approx +1. Orthonormality alone still admits reflections
    #    (det = -1), so the sign has to be checked separately.
    det = np.linalg.det(R)
    if not np.allclose(det, 1.0, atol=1e-3):
        raise ValueError(
            f"[{context}] Rotation block determinant is {det:.6f}, "
            f"expected 1.0 (improper rotation or scaling)."
        )

    return pose
def load_poses_from_json(path: str) -> dict[str, np.ndarray]:
    """
    Heuristically load poses from a JSON file.

    Supports:
    1) flat: {"serial": {"pose": "..."}}
    2) nested Fusion: {"serial": {"FusionConfiguration": {"pose": "..."}}}

    For each serial the nested form is preferred; the flat "pose" key is used
    whenever no nested pose is available. Entries that are not JSON objects,
    or that carry no pose in either form, are skipped.

    Args:
        path: Path to the JSON file to read.

    Returns:
        Mapping from serial (as ``str``) to the validated 4x4 pose matrix
        returned by ``parse_pose``.

    Raises:
        click.UsageError: If the file yields no pose at all (including when
            its top-level value is not a JSON object).
        ValueError: Propagated from ``parse_pose`` for a malformed pose.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    poses: dict[str, np.ndarray] = {}

    # A non-object top level (list, string, ...) cannot contain serial -> pose
    # entries; treat it as "nothing found" so the caller gets the usage error
    # below instead of an AttributeError from .items().
    entries = data.items() if isinstance(data, dict) else ()

    for serial, entry in entries:
        if not isinstance(entry, dict):
            continue

        # Nested FusionConfiguration takes precedence, but a flat "pose" must
        # still be honoured when the nested dict exists without a pose.
        fusion = entry.get("FusionConfiguration")
        if isinstance(fusion, dict) and "pose" in fusion:
            pose_str = fusion["pose"]
        elif "pose" in entry:
            pose_str = entry["pose"]
        else:
            continue

        context = f"File: {path}, Serial: {serial}"
        poses[str(serial)] = parse_pose(pose_str, context=context)

    if not poses:
        raise click.UsageError(
            f"No parsable poses found in {path}.\n"
            "Expected formats:\n"
            '  1) Flat: {"serial": {"pose": "..."}}\n'
            '  2) Nested: {"serial": {"FusionConfiguration": {"pose": "..."}}}'
        )
    return poses
def serialize_pose(pose: np.ndarray) -> str:
@@ -183,13 +251,13 @@ def add_camera_trace(
"--pose-a-json",
type=click.Path(exists=True),
required=True,
help="Pose set A (serial -> {pose: '...'})",
help="Pose set A. Supports flat {'serial': {'pose': '...'}} or nested FusionConfiguration format.",
)
@click.option(
"--pose-b-json",
type=click.Path(exists=True),
required=True,
help="Pose set B (serial -> {pose: '...'} or inside_network format)",
help="Pose set B. Supports flat {'serial': {'pose': '...'}} or nested FusionConfiguration format.",
)
@click.option(
"--report-json",
@@ -210,6 +278,7 @@ def add_camera_trace(
@click.option(
"--show-plot",
is_flag=True,
default=False,
help="Show the plot interactively",
)
@click.option(
@@ -237,25 +306,13 @@ def main(
"""
Compare two camera pose sets from different world frames using rigid alignment.
Both are treated as T_world_from_cam.
Supports symmetric, heuristic input parsing for both A and B:
1) flat: {"serial": {"pose": "..."}}
2) nested Fusion: {"serial": {"FusionConfiguration": {"pose": "..."}}}
"""
with open(pose_a_json, "r") as f:
data_a = json.load(f)
with open(pose_b_json, "r") as f:
data_b = json.load(f)
poses_a: dict[str, np.ndarray] = {}
for serial, data in data_a.items():
if "pose" in data:
poses_a[str(serial)] = parse_pose(data["pose"])
poses_b: dict[str, np.ndarray] = {}
for serial, data in data_b.items():
# Support both standard and inside_network.json nested format
if "FusionConfiguration" in data and "pose" in data["FusionConfiguration"]:
poses_b[str(serial)] = parse_pose(data["FusionConfiguration"]["pose"])
elif "pose" in data:
poses_b[str(serial)] = parse_pose(data["pose"])
poses_a = load_poses_from_json(pose_a_json)
poses_b = load_poses_from_json(pose_b_json)
shared_serials = sorted(list(set(poses_a.keys()) & set(poses_b.keys())))
if len(shared_serials) < 3:
@@ -347,23 +404,25 @@ def main(
if plot_output or show_plot:
fig = go.Figure()
for axis, color in zip(
[np.eye(3)[:, 0], np.eye(3)[:, 1], np.eye(3)[:, 2]],
["red", "green", "blue"],
):
fig.add_trace(
go.Scatter3d(
x=[0, axis[0] * axis_scale * 2],
y=[0, axis[1] * axis_scale * 2],
z=[0, axis[2] * axis_scale * 2],
mode="lines",
line=dict(color=color, width=4),
name=f"World {'XYZ'[np.argmax(axis)]}",
showlegend=True,
show_axis: Final[bool] = True
if show_axis:
for axis, color in zip(
[np.eye(3)[:, 0], np.eye(3)[:, 1], np.eye(3)[:, 2]],
["red", "green", "blue"],
):
fig.add_trace(
go.Scatter3d(
x=[0, axis[0] * axis_scale],
y=[0, axis[1] * axis_scale],
z=[0, axis[2] * axis_scale],
mode="lines",
line=dict(color=color, width=4),
name=f"World {'XYZ'[np.argmax(axis)]}",
showlegend=True,
)
)
)
show_ground = False
show_ground: Final[bool] = False
if show_ground:
ground_size = 5.0
half_size = ground_size / 2.0
@@ -440,4 +499,4 @@ def main(
if __name__ == "__main__":
    # main() is invoked once with no arguments; the click decorators on it
    # appear to supply the parameters from the command line, hence the
    # pylint suppression for the "missing" arguments.
    main()  # pylint: disable=no-value-for-parameter