diff --git a/py_workspace/aruco/ground_plane.py b/py_workspace/aruco/ground_plane.py index 7dfdddb..38df630 100644 --- a/py_workspace/aruco/ground_plane.py +++ b/py_workspace/aruco/ground_plane.py @@ -4,6 +4,7 @@ from jaxtyping import Float from typing import TYPE_CHECKING import open3d as o3d from dataclasses import dataclass, field +import plotly.graph_objects as go if TYPE_CHECKING: Vec3 = Float[np.ndarray, "3"] @@ -417,3 +418,153 @@ def refine_ground_from_depth( ) return new_extrinsics, metrics + + +def create_ground_diagnostic_plot( + metrics: GroundPlaneMetrics, + camera_data: Dict[str, Dict[str, Any]], + extrinsics_before: Dict[str, Mat44], + extrinsics_after: Dict[str, Mat44], +) -> go.Figure: + """ + Create a Plotly diagnostic visualization for ground plane refinement. + """ + fig = go.Figure() + + # 1. Add World Origin Axes + axis_scale = 0.5 + for axis, color, name in zip( + [np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1])], + ["red", "green", "blue"], + ["X", "Y", "Z"], + ): + fig.add_trace( + go.Scatter3d( + x=[0, axis[0] * axis_scale], + y=[0, axis[1] * axis_scale], + z=[0, axis[2] * axis_scale], + mode="lines", + line=dict(color=color, width=4), + name=f"World {name}", + showlegend=True, + ) + ) + + # 2. Add Consensus Plane (if available) + if metrics.consensus_plane: + plane = metrics.consensus_plane + # Create a surface for the plane + size = 5.0 + x = np.linspace(-size, size, 2) + z = np.linspace(-size, size, 2) + xx, zz = np.meshgrid(x, z) + # n.p + d = 0 => n0*x + n1*y + n2*z + d = 0 => y = -(n0*x + n2*z + d) / n1 + if abs(plane.normal[1]) > 1e-6: + yy = ( + -(plane.normal[0] * xx + plane.normal[2] * zz + plane.d) + / plane.normal[1] + ) + fig.add_trace( + go.Surface( + x=xx, + y=yy, + z=zz, + showscale=False, + opacity=0.3, + colorscale=[[0, "lightgray"], [1, "lightgray"]], + name="Consensus Plane", + ) + ) + + # 3. Add Floor Points per camera + for serial, data in camera_data.items(): + if serial not in extrinsics_before: + continue + + depth_map = data.get("depth") + K = data.get("K") + if depth_map is None or K is None: + continue + + # Use a larger stride for visualization to keep it responsive + viz_stride = 8 + points_cam = unproject_depth_to_points(depth_map, K, stride=viz_stride) + + if len(points_cam) == 0: + continue + + # Transform to world frame (before) + T_before = extrinsics_before[serial] + R_b = T_before[:3, :3] + t_b = T_before[:3, 3] + points_world = (points_cam @ R_b.T) + t_b + + fig.add_trace( + go.Scatter3d( + x=points_world[:, 0], + y=points_world[:, 1], + z=points_world[:, 2], + mode="markers", + marker=dict(size=2, opacity=0.5), + name=f"Points {serial}", + ) + ) + + # 4. Add Camera Positions Before/After + for serial in extrinsics_before: + T_b = extrinsics_before[serial] + pos_b = T_b[:3, 3] + + fig.add_trace( + go.Scatter3d( + x=[pos_b[0]], + y=[pos_b[1]], + z=[pos_b[2]], + mode="markers+text", + marker=dict(size=5, color="red"), + text=[f"{serial} (before)"], + name=f"Cam {serial} (before)", + ) + ) + + if serial in extrinsics_after: + T_a = extrinsics_after[serial] + pos_a = T_a[:3, 3] + fig.add_trace( + go.Scatter3d( + x=[pos_a[0]], + y=[pos_a[1]], + z=[pos_a[2]], + mode="markers+text", + marker=dict(size=5, color="green"), + text=[f"{serial} (after)"], + name=f"Cam {serial} (after)", + ) + ) + + fig.update_layout( + title="Ground Plane Refinement Diagnostics", + scene=dict( + xaxis_title="X", + yaxis_title="Y", + zaxis_title="Z", + aspectmode="data", + camera=dict( + up=dict(x=0, y=-1, z=0), # Y-down convention for visualization + eye=dict(x=1.5, y=-1.5, z=1.5), + ), + ), + margin=dict(l=0, r=0, b=0, t=40), + ) + + return fig + + +def save_diagnostic_plot(fig: go.Figure, path: str) -> None: + """ + Save the diagnostic plot to an HTML file. + """ + import os + + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + fig.write_html(path) diff --git a/py_workspace/tests/test_ground_plane.py b/py_workspace/tests/test_ground_plane.py index caaf9a5..1e96eea 100644 --- a/py_workspace/tests/test_ground_plane.py +++ b/py_workspace/tests/test_ground_plane.py @@ -6,6 +6,8 @@ from aruco.ground_plane import ( compute_consensus_plane, compute_floor_correction, refine_ground_from_depth, + create_ground_diagnostic_plot, + save_diagnostic_plot, FloorPlane, FloorCorrection, GroundPlaneConfig, @@ -543,3 +545,46 @@ def test_refine_ground_from_depth_partial_success(): # Cam 1 extrinsics should be changed assert not np.array_equal(new_extrinsics["cam1"], extrinsics["cam1"]) + + +def test_create_ground_diagnostic_plot_smoke(): + # Create minimal metrics and data + metrics = GroundPlaneMetrics( + success=True, + consensus_plane=FloorPlane(normal=np.array([0, 1, 0]), d=1.0), + ) + camera_data = { + "cam1": { + "depth": np.full((10, 10), 2.0, dtype=np.float32), + "K": np.eye(3), + } + } + extrinsics_before = {"cam1": np.eye(4)} + extrinsics_after = {"cam1": np.eye(4)} + extrinsics_after["cam1"][1, 3] = 1.0 + + import plotly.graph_objects as go + + fig = create_ground_diagnostic_plot( + metrics, camera_data, extrinsics_before, extrinsics_after + ) + + assert isinstance(fig, go.Figure) + # Check for some expected traces + trace_names = [t.name for t in fig.data] + assert any("World X" in name for name in trace_names if name) + assert any("Consensus Plane" in name for name in trace_names if name) + assert any("Points cam1" in name for name in trace_names if name) + assert any("Cam cam1 (before)" in name for name in trace_names if name) + assert any("Cam cam1 (after)" in name for name in trace_names if name) + + +def test_save_diagnostic_plot_smoke(tmp_path): + import plotly.graph_objects as go + + fig = go.Figure() + plot_path = tmp_path / "diag.html" + save_diagnostic_plot(fig, str(plot_path)) + + assert plot_path.exists() + assert plot_path.stat().st_size > 0