feat: add Plotly diagnostic visualization for ground plane refinement

This commit is contained in:
2026-02-09 07:41:44 +00:00
parent 248510f5bb
commit 0f7d7a9a63
2 changed files with 196 additions and 0 deletions
+151
View File
@@ -4,6 +4,7 @@ from jaxtyping import Float
from typing import TYPE_CHECKING
import open3d as o3d
from dataclasses import dataclass, field
import plotly.graph_objects as go
if TYPE_CHECKING:
Vec3 = Float[np.ndarray, "3"]
@@ -417,3 +418,153 @@ def refine_ground_from_depth(
)
return new_extrinsics, metrics
def create_ground_diagnostic_plot(
metrics: GroundPlaneMetrics,
camera_data: Dict[str, Dict[str, Any]],
extrinsics_before: Dict[str, Mat44],
extrinsics_after: Dict[str, Mat44],
) -> go.Figure:
"""
Create a Plotly diagnostic visualization for ground plane refinement.
"""
fig = go.Figure()
# 1. Add World Origin Axes
axis_scale = 0.5
for axis, color, name in zip(
[np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1])],
["red", "green", "blue"],
["X", "Y", "Z"],
):
fig.add_trace(
go.Scatter3d(
x=[0, axis[0] * axis_scale],
y=[0, axis[1] * axis_scale],
z=[0, axis[2] * axis_scale],
mode="lines",
line=dict(color=color, width=4),
name=f"World {name}",
showlegend=True,
)
)
# 2. Add Consensus Plane (if available)
if metrics.consensus_plane:
plane = metrics.consensus_plane
# Create a surface for the plane
size = 5.0
x = np.linspace(-size, size, 2)
z = np.linspace(-size, size, 2)
xx, zz = np.meshgrid(x, z)
# n.p + d = 0 => n0*x + n1*y + n2*z + d = 0 => y = -(n0*x + n2*z + d) / n1
if abs(plane.normal[1]) > 1e-6:
yy = (
-(plane.normal[0] * xx + plane.normal[2] * zz + plane.d)
/ plane.normal[1]
)
fig.add_trace(
go.Surface(
x=xx,
y=yy,
z=zz,
showscale=False,
opacity=0.3,
colorscale=[[0, "lightgray"], [1, "lightgray"]],
name="Consensus Plane",
)
)
# 3. Add Floor Points per camera
for serial, data in camera_data.items():
if serial not in extrinsics_before:
continue
depth_map = data.get("depth")
K = data.get("K")
if depth_map is None or K is None:
continue
# Use a larger stride for visualization to keep it responsive
viz_stride = 8
points_cam = unproject_depth_to_points(depth_map, K, stride=viz_stride)
if len(points_cam) == 0:
continue
# Transform to world frame (before)
T_before = extrinsics_before[serial]
R_b = T_before[:3, :3]
t_b = T_before[:3, 3]
points_world = (points_cam @ R_b.T) + t_b
fig.add_trace(
go.Scatter3d(
x=points_world[:, 0],
y=points_world[:, 1],
z=points_world[:, 2],
mode="markers",
marker=dict(size=2, opacity=0.5),
name=f"Points {serial}",
)
)
# 4. Add Camera Positions Before/After
for serial in extrinsics_before:
T_b = extrinsics_before[serial]
pos_b = T_b[:3, 3]
fig.add_trace(
go.Scatter3d(
x=[pos_b[0]],
y=[pos_b[1]],
z=[pos_b[2]],
mode="markers+text",
marker=dict(size=5, color="red"),
text=[f"{serial} (before)"],
name=f"Cam {serial} (before)",
)
)
if serial in extrinsics_after:
T_a = extrinsics_after[serial]
pos_a = T_a[:3, 3]
fig.add_trace(
go.Scatter3d(
x=[pos_a[0]],
y=[pos_a[1]],
z=[pos_a[2]],
mode="markers+text",
marker=dict(size=5, color="green"),
text=[f"{serial} (after)"],
name=f"Cam {serial} (after)",
)
)
fig.update_layout(
title="Ground Plane Refinement Diagnostics",
scene=dict(
xaxis_title="X",
yaxis_title="Y",
zaxis_title="Z",
aspectmode="data",
camera=dict(
up=dict(x=0, y=-1, z=0), # Y-down convention for visualization
eye=dict(x=1.5, y=-1.5, z=1.5),
),
),
margin=dict(l=0, r=0, b=0, t=40),
)
return fig
def save_diagnostic_plot(fig: go.Figure, path: str) -> None:
"""
Save the diagnostic plot to an HTML file.
"""
import os
os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
fig.write_html(path)
+45
View File
@@ -6,6 +6,8 @@ from aruco.ground_plane import (
compute_consensus_plane,
compute_floor_correction,
refine_ground_from_depth,
create_ground_diagnostic_plot,
save_diagnostic_plot,
FloorPlane,
FloorCorrection,
GroundPlaneConfig,
@@ -543,3 +545,46 @@ def test_refine_ground_from_depth_partial_success():
# Cam 1 extrinsics should be changed
assert not np.array_equal(new_extrinsics["cam1"], extrinsics["cam1"])
def test_create_ground_diagnostic_plot_smoke():
# Create minimal metrics and data
metrics = GroundPlaneMetrics(
success=True,
consensus_plane=FloorPlane(normal=np.array([0, 1, 0]), d=1.0),
)
camera_data = {
"cam1": {
"depth": np.full((10, 10), 2.0, dtype=np.float32),
"K": np.eye(3),
}
}
extrinsics_before = {"cam1": np.eye(4)}
extrinsics_after = {"cam1": np.eye(4)}
extrinsics_after["cam1"][1, 3] = 1.0
import plotly.graph_objects as go
fig = create_ground_diagnostic_plot(
metrics, camera_data, extrinsics_before, extrinsics_after
)
assert isinstance(fig, go.Figure)
# Check for some expected traces
trace_names = [t.name for t in fig.data]
assert any("World X" in name for name in trace_names if name)
assert any("Consensus Plane" in name for name in trace_names if name)
assert any("Points cam1" in name for name in trace_names if name)
assert any("Cam cam1 (before)" in name for name in trace_names if name)
assert any("Cam cam1 (after)" in name for name in trace_names if name)
def test_save_diagnostic_plot_smoke(tmp_path):
import plotly.graph_objects as go
fig = go.Figure()
plot_path = tmp_path / "diag.html"
save_diagnostic_plot(fig, str(plot_path))
assert plot_path.exists()
assert plot_path.stat().st_size > 0