feat: add explicit 4x4 transformation matrix validation to compare_pose_sets.py
This commit is contained in:
@@ -7,17 +7,85 @@ Assumes both pose sets are in world_from_cam convention.
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
import click
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
|
||||
|
||||
def parse_pose(pose_str: str) -> np.ndarray:
|
||||
def parse_pose(pose_str: str, context: str = "") -> np.ndarray:
|
||||
vals = [float(x) for x in pose_str.split()]
|
||||
if len(vals) != 16:
|
||||
raise ValueError(f"Expected 16 values for pose, got {len(vals)}")
|
||||
return np.array(vals).reshape((4, 4))
|
||||
raise ValueError(f"[{context}] Expected 16 values for pose, got {len(vals)}")
|
||||
pose = np.array(vals).reshape((4, 4))
|
||||
|
||||
# Validate transformation matrix properties
|
||||
# 1. Last row check [0, 0, 0, 1]
|
||||
last_row = pose[3, :]
|
||||
expected_last_row = np.array([0, 0, 0, 1], dtype=float)
|
||||
if not np.allclose(last_row, expected_last_row, atol=1e-5):
|
||||
raise ValueError(
|
||||
f"[{context}] Invalid last row in transformation matrix: {last_row}. "
|
||||
f"Expected [0, 0, 0, 1]"
|
||||
)
|
||||
|
||||
# 2. Rotation block orthonormality
|
||||
R = pose[:3, :3]
|
||||
# R @ R.T approx I
|
||||
identity_check = R @ R.T
|
||||
if not np.allclose(identity_check, np.eye(3), atol=1e-3):
|
||||
raise ValueError(
|
||||
f"[{context}] Rotation block is not orthonormal (R @ R.T != I)."
|
||||
)
|
||||
|
||||
# 3. Determinant check det(R) approx 1
|
||||
det = np.linalg.det(R)
|
||||
if not np.allclose(det, 1.0, atol=1e-3):
|
||||
raise ValueError(
|
||||
f"[{context}] Rotation block determinant is {det:.6f}, expected 1.0 (improper rotation or scaling)."
|
||||
)
|
||||
|
||||
return pose
|
||||
|
||||
|
||||
def load_poses_from_json(path: str) -> dict[str, np.ndarray]:
|
||||
"""
|
||||
Heuristically load poses from a JSON file.
|
||||
Supports:
|
||||
1) flat: {"serial": {"pose": "..."}}
|
||||
2) nested Fusion: {"serial": {"FusionConfiguration": {"pose": "..."}}}
|
||||
"""
|
||||
with open(path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
poses: dict[str, np.ndarray] = {}
|
||||
for serial, entry in data.items():
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
|
||||
context = f"File: {path}, Serial: {serial}"
|
||||
|
||||
# Check nested FusionConfiguration first
|
||||
if "FusionConfiguration" in entry and isinstance(
|
||||
entry["FusionConfiguration"], dict
|
||||
):
|
||||
if "pose" in entry["FusionConfiguration"]:
|
||||
poses[str(serial)] = parse_pose(
|
||||
entry["FusionConfiguration"]["pose"], context=context
|
||||
)
|
||||
# Then check flat
|
||||
elif "pose" in entry:
|
||||
poses[str(serial)] = parse_pose(entry["pose"], context=context)
|
||||
|
||||
if not poses:
|
||||
raise click.UsageError(
|
||||
f"No parsable poses found in {path}.\n"
|
||||
"Expected formats:\n"
|
||||
' 1) Flat: {"serial": {"pose": "..."}}\n'
|
||||
' 2) Nested: {"serial": {"FusionConfiguration": {"pose": "..."}}}'
|
||||
)
|
||||
return poses
|
||||
|
||||
|
||||
def serialize_pose(pose: np.ndarray) -> str:
|
||||
@@ -183,13 +251,13 @@ def add_camera_trace(
|
||||
"--pose-a-json",
|
||||
type=click.Path(exists=True),
|
||||
required=True,
|
||||
help="Pose set A (serial -> {pose: '...'})",
|
||||
help="Pose set A. Supports flat {'serial': {'pose': '...'}} or nested FusionConfiguration format.",
|
||||
)
|
||||
@click.option(
|
||||
"--pose-b-json",
|
||||
type=click.Path(exists=True),
|
||||
required=True,
|
||||
help="Pose set B (serial -> {pose: '...'} or inside_network format)",
|
||||
help="Pose set B. Supports flat {'serial': {'pose': '...'}} or nested FusionConfiguration format.",
|
||||
)
|
||||
@click.option(
|
||||
"--report-json",
|
||||
@@ -210,6 +278,7 @@ def add_camera_trace(
|
||||
@click.option(
|
||||
"--show-plot",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Show the plot interactively",
|
||||
)
|
||||
@click.option(
|
||||
@@ -237,25 +306,13 @@ def main(
|
||||
"""
|
||||
Compare two camera pose sets from different world frames using rigid alignment.
|
||||
Both are treated as T_world_from_cam.
|
||||
|
||||
Supports symmetric, heuristic input parsing for both A and B:
|
||||
1) flat: {"serial": {"pose": "..."}}
|
||||
2) nested Fusion: {"serial": {"FusionConfiguration": {"pose": "..."}}}
|
||||
"""
|
||||
with open(pose_a_json, "r") as f:
|
||||
data_a = json.load(f)
|
||||
|
||||
with open(pose_b_json, "r") as f:
|
||||
data_b = json.load(f)
|
||||
|
||||
poses_a: dict[str, np.ndarray] = {}
|
||||
for serial, data in data_a.items():
|
||||
if "pose" in data:
|
||||
poses_a[str(serial)] = parse_pose(data["pose"])
|
||||
|
||||
poses_b: dict[str, np.ndarray] = {}
|
||||
for serial, data in data_b.items():
|
||||
# Support both standard and inside_network.json nested format
|
||||
if "FusionConfiguration" in data and "pose" in data["FusionConfiguration"]:
|
||||
poses_b[str(serial)] = parse_pose(data["FusionConfiguration"]["pose"])
|
||||
elif "pose" in data:
|
||||
poses_b[str(serial)] = parse_pose(data["pose"])
|
||||
poses_a = load_poses_from_json(pose_a_json)
|
||||
poses_b = load_poses_from_json(pose_b_json)
|
||||
|
||||
shared_serials = sorted(list(set(poses_a.keys()) & set(poses_b.keys())))
|
||||
if len(shared_serials) < 3:
|
||||
@@ -347,23 +404,25 @@ def main(
|
||||
if plot_output or show_plot:
|
||||
fig = go.Figure()
|
||||
|
||||
for axis, color in zip(
|
||||
[np.eye(3)[:, 0], np.eye(3)[:, 1], np.eye(3)[:, 2]],
|
||||
["red", "green", "blue"],
|
||||
):
|
||||
fig.add_trace(
|
||||
go.Scatter3d(
|
||||
x=[0, axis[0] * axis_scale * 2],
|
||||
y=[0, axis[1] * axis_scale * 2],
|
||||
z=[0, axis[2] * axis_scale * 2],
|
||||
mode="lines",
|
||||
line=dict(color=color, width=4),
|
||||
name=f"World {'XYZ'[np.argmax(axis)]}",
|
||||
showlegend=True,
|
||||
show_axis: Final[bool] = True
|
||||
if show_axis:
|
||||
for axis, color in zip(
|
||||
[np.eye(3)[:, 0], np.eye(3)[:, 1], np.eye(3)[:, 2]],
|
||||
["red", "green", "blue"],
|
||||
):
|
||||
fig.add_trace(
|
||||
go.Scatter3d(
|
||||
x=[0, axis[0] * axis_scale],
|
||||
y=[0, axis[1] * axis_scale],
|
||||
z=[0, axis[2] * axis_scale],
|
||||
mode="lines",
|
||||
line=dict(color=color, width=4),
|
||||
name=f"World {'XYZ'[np.argmax(axis)]}",
|
||||
showlegend=True,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
show_ground = False
|
||||
show_ground: Final[bool] = False
|
||||
if show_ground:
|
||||
ground_size = 5.0
|
||||
half_size = ground_size / 2.0
|
||||
@@ -440,4 +499,4 @@ def main(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
|
||||
Reference in New Issue
Block a user