feat: add Plotly diagnostic visualization for ground plane refinement
This commit is contained in:
@@ -4,6 +4,7 @@ from jaxtyping import Float
|
||||
from typing import TYPE_CHECKING
|
||||
import open3d as o3d
|
||||
from dataclasses import dataclass, field
|
||||
import plotly.graph_objects as go
|
||||
|
||||
if TYPE_CHECKING:
|
||||
Vec3 = Float[np.ndarray, "3"]
|
||||
@@ -417,3 +418,153 @@ def refine_ground_from_depth(
|
||||
)
|
||||
|
||||
return new_extrinsics, metrics
|
||||
|
||||
|
||||
def create_ground_diagnostic_plot(
|
||||
metrics: GroundPlaneMetrics,
|
||||
camera_data: Dict[str, Dict[str, Any]],
|
||||
extrinsics_before: Dict[str, Mat44],
|
||||
extrinsics_after: Dict[str, Mat44],
|
||||
) -> go.Figure:
|
||||
"""
|
||||
Create a Plotly diagnostic visualization for ground plane refinement.
|
||||
"""
|
||||
fig = go.Figure()
|
||||
|
||||
# 1. Add World Origin Axes
|
||||
axis_scale = 0.5
|
||||
for axis, color, name in zip(
|
||||
[np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1])],
|
||||
["red", "green", "blue"],
|
||||
["X", "Y", "Z"],
|
||||
):
|
||||
fig.add_trace(
|
||||
go.Scatter3d(
|
||||
x=[0, axis[0] * axis_scale],
|
||||
y=[0, axis[1] * axis_scale],
|
||||
z=[0, axis[2] * axis_scale],
|
||||
mode="lines",
|
||||
line=dict(color=color, width=4),
|
||||
name=f"World {name}",
|
||||
showlegend=True,
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Add Consensus Plane (if available)
|
||||
if metrics.consensus_plane:
|
||||
plane = metrics.consensus_plane
|
||||
# Create a surface for the plane
|
||||
size = 5.0
|
||||
x = np.linspace(-size, size, 2)
|
||||
z = np.linspace(-size, size, 2)
|
||||
xx, zz = np.meshgrid(x, z)
|
||||
# n.p + d = 0 => n0*x + n1*y + n2*z + d = 0 => y = -(n0*x + n2*z + d) / n1
|
||||
if abs(plane.normal[1]) > 1e-6:
|
||||
yy = (
|
||||
-(plane.normal[0] * xx + plane.normal[2] * zz + plane.d)
|
||||
/ plane.normal[1]
|
||||
)
|
||||
fig.add_trace(
|
||||
go.Surface(
|
||||
x=xx,
|
||||
y=yy,
|
||||
z=zz,
|
||||
showscale=False,
|
||||
opacity=0.3,
|
||||
colorscale=[[0, "lightgray"], [1, "lightgray"]],
|
||||
name="Consensus Plane",
|
||||
)
|
||||
)
|
||||
|
||||
# 3. Add Floor Points per camera
|
||||
for serial, data in camera_data.items():
|
||||
if serial not in extrinsics_before:
|
||||
continue
|
||||
|
||||
depth_map = data.get("depth")
|
||||
K = data.get("K")
|
||||
if depth_map is None or K is None:
|
||||
continue
|
||||
|
||||
# Use a larger stride for visualization to keep it responsive
|
||||
viz_stride = 8
|
||||
points_cam = unproject_depth_to_points(depth_map, K, stride=viz_stride)
|
||||
|
||||
if len(points_cam) == 0:
|
||||
continue
|
||||
|
||||
# Transform to world frame (before)
|
||||
T_before = extrinsics_before[serial]
|
||||
R_b = T_before[:3, :3]
|
||||
t_b = T_before[:3, 3]
|
||||
points_world = (points_cam @ R_b.T) + t_b
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter3d(
|
||||
x=points_world[:, 0],
|
||||
y=points_world[:, 1],
|
||||
z=points_world[:, 2],
|
||||
mode="markers",
|
||||
marker=dict(size=2, opacity=0.5),
|
||||
name=f"Points {serial}",
|
||||
)
|
||||
)
|
||||
|
||||
# 4. Add Camera Positions Before/After
|
||||
for serial in extrinsics_before:
|
||||
T_b = extrinsics_before[serial]
|
||||
pos_b = T_b[:3, 3]
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter3d(
|
||||
x=[pos_b[0]],
|
||||
y=[pos_b[1]],
|
||||
z=[pos_b[2]],
|
||||
mode="markers+text",
|
||||
marker=dict(size=5, color="red"),
|
||||
text=[f"{serial} (before)"],
|
||||
name=f"Cam {serial} (before)",
|
||||
)
|
||||
)
|
||||
|
||||
if serial in extrinsics_after:
|
||||
T_a = extrinsics_after[serial]
|
||||
pos_a = T_a[:3, 3]
|
||||
fig.add_trace(
|
||||
go.Scatter3d(
|
||||
x=[pos_a[0]],
|
||||
y=[pos_a[1]],
|
||||
z=[pos_a[2]],
|
||||
mode="markers+text",
|
||||
marker=dict(size=5, color="green"),
|
||||
text=[f"{serial} (after)"],
|
||||
name=f"Cam {serial} (after)",
|
||||
)
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
title="Ground Plane Refinement Diagnostics",
|
||||
scene=dict(
|
||||
xaxis_title="X",
|
||||
yaxis_title="Y",
|
||||
zaxis_title="Z",
|
||||
aspectmode="data",
|
||||
camera=dict(
|
||||
up=dict(x=0, y=-1, z=0), # Y-down convention for visualization
|
||||
eye=dict(x=1.5, y=-1.5, z=1.5),
|
||||
),
|
||||
),
|
||||
margin=dict(l=0, r=0, b=0, t=40),
|
||||
)
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def save_diagnostic_plot(fig: go.Figure, path: str) -> None:
|
||||
"""
|
||||
Save the diagnostic plot to an HTML file.
|
||||
"""
|
||||
import os
|
||||
|
||||
os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
|
||||
fig.write_html(path)
|
||||
|
||||
@@ -6,6 +6,8 @@ from aruco.ground_plane import (
|
||||
compute_consensus_plane,
|
||||
compute_floor_correction,
|
||||
refine_ground_from_depth,
|
||||
create_ground_diagnostic_plot,
|
||||
save_diagnostic_plot,
|
||||
FloorPlane,
|
||||
FloorCorrection,
|
||||
GroundPlaneConfig,
|
||||
@@ -543,3 +545,46 @@ def test_refine_ground_from_depth_partial_success():
|
||||
|
||||
# Cam 1 extrinsics should be changed
|
||||
assert not np.array_equal(new_extrinsics["cam1"], extrinsics["cam1"])
|
||||
|
||||
|
||||
def test_create_ground_diagnostic_plot_smoke():
|
||||
# Create minimal metrics and data
|
||||
metrics = GroundPlaneMetrics(
|
||||
success=True,
|
||||
consensus_plane=FloorPlane(normal=np.array([0, 1, 0]), d=1.0),
|
||||
)
|
||||
camera_data = {
|
||||
"cam1": {
|
||||
"depth": np.full((10, 10), 2.0, dtype=np.float32),
|
||||
"K": np.eye(3),
|
||||
}
|
||||
}
|
||||
extrinsics_before = {"cam1": np.eye(4)}
|
||||
extrinsics_after = {"cam1": np.eye(4)}
|
||||
extrinsics_after["cam1"][1, 3] = 1.0
|
||||
|
||||
import plotly.graph_objects as go
|
||||
|
||||
fig = create_ground_diagnostic_plot(
|
||||
metrics, camera_data, extrinsics_before, extrinsics_after
|
||||
)
|
||||
|
||||
assert isinstance(fig, go.Figure)
|
||||
# Check for some expected traces
|
||||
trace_names = [t.name for t in fig.data]
|
||||
assert any("World X" in name for name in trace_names if name)
|
||||
assert any("Consensus Plane" in name for name in trace_names if name)
|
||||
assert any("Points cam1" in name for name in trace_names if name)
|
||||
assert any("Cam cam1 (before)" in name for name in trace_names if name)
|
||||
assert any("Cam cam1 (after)" in name for name in trace_names if name)
|
||||
|
||||
|
||||
def test_save_diagnostic_plot_smoke(tmp_path):
|
||||
import plotly.graph_objects as go
|
||||
|
||||
fig = go.Figure()
|
||||
plot_path = tmp_path / "diag.html"
|
||||
save_diagnostic_plot(fig, str(plot_path))
|
||||
|
||||
assert plot_path.exists()
|
||||
assert plot_path.stat().st_size > 0
|
||||
|
||||
Reference in New Issue
Block a user