# File: zed-playground/py_workspace/aruco/ground_plane.py
# Listing metadata: 584 lines, 17 KiB, Python

import numpy as np
from typing import Optional, Tuple, List, Dict, Any
from jaxtyping import Float
from typing import TYPE_CHECKING
import open3d as o3d
from dataclasses import dataclass, field
import plotly.graph_objects as go
# Array aliases used in signatures throughout this module.
# At runtime they are plain ``np.ndarray``; the shape-annotated jaxtyping
# forms are only built for static type checkers.
if not TYPE_CHECKING:
    Vec3 = np.ndarray
    Mat44 = np.ndarray
    PointsNC = np.ndarray
else:
    Vec3 = Float[np.ndarray, "3"]
    Mat44 = Float[np.ndarray, "4 4"]
    PointsNC = Float[np.ndarray, "N 3"]
@dataclass
class FloorPlane:
    """Plane in the form ``normal . p + d = 0`` for points ``p`` on the plane."""

    # Plane normal (3-vector). detect_floor_plane and compute_consensus_plane
    # rescale it to unit length whenever its norm is above 1e-6.
    normal: Vec3
    # Offset term of the plane equation normal . p + d = 0.
    d: float
    # Number of RANSAC inliers supporting the plane; 0 when not supplied
    # (e.g. for planes built by compute_consensus_plane).
    num_inliers: int = 0
@dataclass
class FloorCorrection:
    """Outcome of compute_floor_correction: a 4x4 transform plus a validity flag."""

    # 4x4 homogeneous correction; compute_floor_correction returns the
    # identity here whenever the correction is rejected.
    transform: Mat44
    # False when alignment failed or a rotation/translation limit was exceeded.
    valid: bool
    # Human-readable rejection reason; left empty for valid corrections.
    reason: str = ""
@dataclass
class GroundPlaneConfig:
    """Tunable parameters read by refine_ground_from_depth."""

    # When False, refinement returns the input extrinsics untouched.
    enabled: bool = True
    # Desired world Y coordinate of the floor after correction.
    target_y: float = 0.0
    # Pixel step used when sampling the depth map (every stride-th row/column).
    stride: int = 8
    # Exclusive depth bounds; samples outside (depth_min, depth_max) are dropped.
    # NOTE(review): presumably metres, matching max_translation_m -- confirm.
    depth_min: float = 0.2
    depth_max: float = 5.0
    # RANSAC settings forwarded to Open3D's segment_plane
    # (distance_threshold, ransac_n, num_iterations).
    ransac_dist_thresh: float = 0.02
    ransac_n: int = 3
    ransac_iters: int = 1000
    # Per-camera corrections larger than these limits are rejected.
    max_rotation_deg: float = 5.0
    max_translation_m: float = 0.1
    # Minimum number of unprojected points AND minimum number of plane inliers
    # a camera needs for its plane to be accepted.
    min_inliers: int = 500
    # Minimum inliers / points ratio; the check is skipped when this is <= 0.
    min_inlier_ratio: float = 0.0
    # Minimum number of cameras with an accepted plane before anything is corrected.
    min_valid_cameras: int = 2
    # Seed passed to Open3D's random generator before each plane fit; None leaves it unseeded.
    seed: Optional[int] = None
@dataclass
class GroundPlaneMetrics:
    """Diagnostics filled in by refine_ground_from_depth and returned with the extrinsics."""

    # Set to True once the per-camera correction stage has run, even if every
    # camera ended up being skipped there.
    success: bool = False
    # True when at least one camera actually received a correction.
    correction_applied: bool = False
    # Number of entries in the camera_data input.
    num_cameras_total: int = 0
    # Number of cameras whose detected plane passed the inlier checks.
    num_cameras_valid: int = 0
    # Per-camera corrections (serial -> 4x4 transform that was left-multiplied
    # onto that camera's extrinsics).
    camera_corrections: Dict[str, Mat44] = field(default_factory=dict)
    # Serials from the extrinsics dict that had no accepted plane or whose
    # correction was rejected.
    skipped_cameras: List[str] = field(default_factory=list)
    # Summary stats: the LARGEST rotation angle (degrees) and translation norm
    # among the corrections that were applied (a maximum, not an average).
    rotation_deg: float = 0.0
    translation_m: float = 0.0
    # Accepted floor plane per camera serial, expressed in the world frame.
    camera_planes: Dict[str, FloorPlane] = field(default_factory=dict)
    # Unweighted consensus of the accepted planes; used for reporting and plotting only.
    consensus_plane: Optional[FloorPlane] = None
    # Human-readable status / failure description.
    message: str = ""
def unproject_depth_to_points(
    depth_map: np.ndarray,
    K: np.ndarray,
    stride: int = 1,
    depth_min: float = 0.1,
    depth_max: float = 10.0,
) -> PointsNC:
    """
    Unproject a depth map to a point cloud in the camera frame.

    Every ``stride``-th pixel (rows and columns) is sampled. A sample is kept
    when its depth is finite and strictly inside ``(depth_min, depth_max)``,
    and is back-projected with the pinhole model:
    ``x = (u - cx) * z / fx``, ``y = (v - cy) * z / fy``.

    Args:
        depth_map: 2-D array of depth values indexed as ``[row (v), column (u)]``.
        K: Camera matrix; ``fx``, ``fy``, ``cx``, ``cy`` are read from
            ``K[0, 0]``, ``K[1, 1]``, ``K[0, 2]``, ``K[1, 2]``.
        stride: Positive pixel step used for subsampling.
        depth_min: Exclusive lower depth bound.
        depth_max: Exclusive upper depth bound.

    Returns:
        ``(N, 3)`` float64 array of ``[x, y, z]`` points in row-major pixel
        order; shape ``(0, 3)`` when no sample is valid.

    Raises:
        ValueError: If ``stride`` is smaller than 1 or ``depth_map`` is not 2-D.
    """
    # A zero stride used to surface as an opaque error deep inside numpy and a
    # negative one silently produced an empty cloud; reject both up front.
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}")

    depth_map = np.asarray(depth_map)
    if depth_map.ndim != 2:
        raise ValueError(
            f"depth_map must be a 2-D array, got shape {depth_map.shape}"
        )

    K = np.asarray(K)
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]

    # Subsample first so the mask is only computed on the pixels we keep.
    z = depth_map[::stride, ::stride]
    valid_mask = (z > depth_min) & (z < depth_max) & np.isfinite(z)

    # Indices into the subsampled grid, in row-major order (the same order
    # boolean-mask indexing uses); scale back to full-resolution pixel coords.
    v_idx, u_idx = np.nonzero(valid_mask)
    z_valid = z[v_idx, u_idx]
    u_valid = u_idx * stride
    v_valid = v_idx * stride

    # Pinhole back-projection.
    x_valid = (u_valid - cx) * z_valid / fx
    y_valid = (v_valid - cy) * z_valid / fy

    points = np.stack((x_valid, y_valid, z_valid), axis=-1)
    return points.astype(np.float64)
def detect_floor_plane(
    points: PointsNC,
    distance_threshold: float = 0.02,
    ransac_n: int = 3,
    num_iterations: int = 1000,
    seed: Optional[int] = None,
) -> Optional[FloorPlane]:
    """
    Fit a single plane to a point cloud with Open3D's RANSAC segmentation.

    Args:
        points: Point array handed to ``o3d.utility.Vector3dVector``.
        distance_threshold: Forwarded to ``segment_plane``.
        ransac_n: Forwarded to ``segment_plane``; also the minimum number of
            points and of inliers required for a result.
        num_iterations: Forwarded to ``segment_plane``.
        seed: When not None, Open3D's random generator is seeded with it
            before the fit.

    Returns:
        A FloorPlane with the (re-normalised) plane coefficients and the
        inlier count, or None when there are too few points or inliers.
    """
    # Too few samples to form even one plane hypothesis.
    if points.shape[0] < ransac_n:
        return None

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points)

    # Seeding makes repeated runs reproducible.
    if seed is not None:
        o3d.utility.random.seed(seed)

    coefficients, inlier_indices = cloud.segment_plane(
        distance_threshold=distance_threshold,
        ransac_n=ransac_n,
        num_iterations=num_iterations,
    )

    support = len(inlier_indices)
    if support < ransac_n:
        return None

    # coefficients hold [a, b, c, d] of a*x + b*y + c*z + d = 0.
    nx, ny, nz, offset = coefficients
    unit_normal = np.array([nx, ny, nz], dtype=np.float64)

    # Rescale normal and offset together so the plane itself is unchanged.
    length = np.linalg.norm(unit_normal)
    if length > 1e-6:
        unit_normal = unit_normal / length
        offset = offset / length

    return FloorPlane(normal=unit_normal, d=offset, num_inliers=support)
def compute_consensus_plane(
    planes: List[FloorPlane],
    weights: Optional[List[float]] = None,
) -> FloorPlane:
    """
    Merge several plane detections into one weighted-average plane.

    Each plane is first oriented like the first plane in the list (its normal
    and offset are negated when the normal points the opposite way), then
    normals and offsets are averaged with the given weights and the result is
    rescaled to a unit normal.

    Args:
        planes: Detections to merge; must not be empty.
        weights: One weight per plane; defaults to equal weights.

    Returns:
        The consensus FloorPlane (``num_inliers`` is left at its default).

    Raises:
        ValueError: If ``planes`` is empty, the weight count does not match,
            or the weights do not sum to a positive value.
    """
    if not planes:
        raise ValueError("No planes provided for consensus.")

    n_planes = len(planes)
    if weights is None:
        weights = [1.0] * n_planes
    if len(weights) != n_planes:
        raise ValueError(
            f"Weights length {len(weights)} must match planes length {n_planes}"
        )

    # The first detection fixes which side of the plane counts as "positive".
    reference = planes[0].normal

    normal_sum = np.zeros(3, dtype=np.float64)
    offset_sum = 0.0
    weight_sum = 0.0
    for candidate, weight in zip(planes, weights):
        opposed = np.dot(candidate.normal, reference) < 0
        # (n, d) and (-n, -d) are the same plane; pick the reference-facing one.
        oriented_normal = -candidate.normal if opposed else candidate.normal
        oriented_offset = -candidate.d if opposed else candidate.d
        normal_sum += oriented_normal * weight
        offset_sum += oriented_offset * weight
        weight_sum += weight

    if weight_sum <= 0:
        raise ValueError("Total weight must be positive.")

    mean_normal = normal_sum / weight_sum
    mean_offset = offset_sum / weight_sum

    length = np.linalg.norm(mean_normal)
    if not (length > 1e-6):
        # Degenerate average (normals cancelled out): fall back to a flat
        # plane through the origin facing +Y.
        return FloorPlane(normal=np.array([0.0, 1.0, 0.0]), d=float(0.0))

    # Divide normal and offset by the same factor to keep the plane equation consistent.
    return FloorPlane(normal=mean_normal / length, d=float(mean_offset / length))
from .alignment import rotation_align_vectors
def compute_floor_correction(
    current_floor_plane: FloorPlane,
    target_floor_y: float = 0.0,
    max_rotation_deg: float = 5.0,
    max_translation_m: float = 0.1,
) -> FloorCorrection:
    """
    Compute the transform that levels a floor plane and moves it to the target height.

    The correction rotates the plane normal onto ``[0, 1, 0]`` and then shifts
    along Y so the floor ends up at ``y = target_floor_y``. Corrections that
    exceed the rotation or translation limits are rejected.

    Args:
        current_floor_plane: Detected floor plane (``normal . p + d = 0``).
            The sign of the normal does not matter: a plane whose normal has
            a negative Y component is negated (normal and d) before use.
        target_floor_y: Y coordinate the floor should have after correction.
        max_rotation_deg: Largest accepted rotation angle, in degrees.
        max_translation_m: Largest accepted absolute vertical shift.

    Returns:
        FloorCorrection with ``valid=True`` and the 4x4 transform on success;
        otherwise ``valid=False``, an identity transform and a ``reason``.
    """
    # Target normal is always [0, 1, 0] (Y-up)
    target_normal = np.array([0.0, 1.0, 0.0])

    current_normal = np.asarray(current_floor_plane.normal, dtype=np.float64)
    current_d = float(current_floor_plane.d)

    # (n, d) and (-n, -d) describe the same plane, and a plane fit does not fix
    # which of the two it reports. Without this flip a downward-facing normal
    # would look like a ~180 degree rotation and be rejected by the limit
    # below (or, with a loose limit, turn the scene upside down).
    if np.dot(current_normal, target_normal) < 0:
        current_normal = -current_normal
        current_d = -current_d

    # 1. Rotation that takes the current normal onto the target normal.
    # NOTE(review): rotation_align_vectors is defined in .alignment; it is
    # assumed to return a 3x3 rotation and to signal failure with ValueError.
    try:
        R_align = rotation_align_vectors(current_normal, target_normal)
    except ValueError as e:
        return FloorCorrection(
            transform=np.eye(4), valid=False, reason=f"Rotation alignment failed: {e}"
        )

    # Rotation angle from the trace: angle = acos((trace(R) - 1) / 2).
    # Clip to avoid numerical errors outside [-1, 1].
    cos_angle = np.clip((np.trace(R_align) - 1) / 2, -1.0, 1.0)
    angle_deg = np.rad2deg(np.arccos(cos_angle))
    if angle_deg > max_rotation_deg:
        return FloorCorrection(
            transform=np.eye(4),
            valid=False,
            reason=f"Rotation {angle_deg:.1f} deg exceeds limit {max_rotation_deg:.1f} deg",
        )

    # 2. Vertical shift.
    # After the rotation the plane is y + d = 0, i.e. the floor sits at
    # y = -current_d. Moving it to target_floor_y needs
    # shift = target_floor_y - (-current_d) = target_floor_y + current_d.
    t_y = target_floor_y + current_d
    if abs(t_y) > max_translation_m:
        return FloorCorrection(
            transform=np.eye(4),
            valid=False,
            reason=f"Translation {t_y:.3f} m exceeds limit {max_translation_m:.3f} m",
        )

    # 3. Assemble the homogeneous transform p' = R p + t.
    T = np.eye(4)
    T[:3, :3] = R_align
    # Translation is applied in the rotated frame (aligned to target normal)
    T[:3, 3] = target_normal * t_y
    return FloorCorrection(transform=T.astype(np.float64), valid=True)
def refine_ground_from_depth(
    camera_data: Dict[str, Dict[str, Any]],
    extrinsics: Dict[str, Mat44],
    config: Optional[GroundPlaneConfig] = None,
) -> Tuple[Dict[str, Mat44], GroundPlaneMetrics]:
    """
    Orchestrate ground plane refinement across multiple cameras.

    For every camera that has depth, intrinsics and extrinsics, the depth map
    is unprojected, moved to the world frame and a floor plane is fitted.
    Each camera with an accepted plane then gets its own correction, which is
    left-multiplied onto its extrinsics (``T_new = T_corr @ T_old``).

    Args:
        camera_data: Dict mapping serial -> {'depth': np.ndarray, 'K': np.ndarray}
        extrinsics: Dict mapping serial -> world_from_cam matrix (4x4)
        config: Configuration parameters; a default GroundPlaneConfig is
            created per call when omitted.

    Returns:
        Tuple of (new_extrinsics, metrics). When refinement is disabled, too
        few cameras yield a plane, or the consensus fails, the input
        ``extrinsics`` dict itself is returned unchanged. Otherwise a shallow
        copy is returned in which corrected cameras have new matrices.
    """
    # Build the default here rather than in the signature: a default written
    # as GroundPlaneConfig() is evaluated once at definition time and the same
    # mutable instance would be shared by every call.
    if config is None:
        config = GroundPlaneConfig()

    metrics = GroundPlaneMetrics()
    metrics.num_cameras_total = len(camera_data)

    if not config.enabled:
        metrics.message = "Ground plane refinement disabled in config"
        return extrinsics, metrics

    # 1. Detect planes in each camera
    for serial, data in camera_data.items():
        if serial not in extrinsics:
            continue

        depth_map = data.get("depth")
        K = data.get("K")
        if depth_map is None or K is None:
            continue

        # Unproject to camera frame
        points_cam = unproject_depth_to_points(
            depth_map,
            K,
            stride=config.stride,
            depth_min=config.depth_min,
            depth_max=config.depth_max,
        )
        # Not enough samples to ever reach the required inlier count.
        if len(points_cam) < config.min_inliers:
            continue

        # Transform to world frame: p_world = R @ p_cam + t, written for (N, 3) rows.
        T_world_cam = extrinsics[serial]
        points_world = (points_cam @ T_world_cam[:3, :3].T) + T_world_cam[:3, 3]

        plane = detect_floor_plane(
            points_world,
            distance_threshold=config.ransac_dist_thresh,
            ransac_n=config.ransac_n,
            num_iterations=config.ransac_iters,
            seed=config.seed,
        )
        if plane is None or plane.num_inliers < config.min_inliers:
            continue

        # Check inlier ratio if configured
        if config.min_inlier_ratio > 0:
            ratio = plane.num_inliers / len(points_world)
            if ratio < config.min_inlier_ratio:
                continue

        metrics.camera_planes[serial] = plane

    valid_planes: List[FloorPlane] = list(metrics.camera_planes.values())
    metrics.num_cameras_valid = len(valid_planes)

    # 2. Check minimum requirements
    if len(valid_planes) < config.min_valid_cameras:
        metrics.message = f"Found {len(valid_planes)} valid planes, required {config.min_valid_cameras}"
        return extrinsics, metrics

    # 3. Compute consensus (for reporting/metrics only)
    try:
        metrics.consensus_plane = compute_consensus_plane(valid_planes)
    except ValueError as e:
        metrics.message = f"Consensus computation failed: {e}"
        return extrinsics, metrics

    # 4. Compute and apply per-camera correction
    new_extrinsics = extrinsics.copy()

    # Track max rotation/translation for summary metrics
    max_rot = 0.0
    max_trans = 0.0
    corrections_count = 0

    for serial, T_old in extrinsics.items():
        # If we didn't find a plane for this camera, skip it
        plane = metrics.camera_planes.get(serial)
        if plane is None:
            metrics.skipped_cameras.append(serial)
            continue

        correction = compute_floor_correction(
            plane,
            target_floor_y=config.target_y,
            max_rotation_deg=config.max_rotation_deg,
            max_translation_m=config.max_translation_m,
        )
        if not correction.valid:
            metrics.skipped_cameras.append(serial)
            continue

        T_corr = correction.transform
        metrics.camera_corrections[serial] = T_corr

        # Apply correction: T_new = T_corr @ T_old
        new_extrinsics[serial] = T_corr @ T_old

        # Update summary metrics (rotation angle from the trace of R).
        cos_angle = np.clip((np.trace(T_corr[:3, :3]) - 1) / 2, -1.0, 1.0)
        rot_deg = float(np.rad2deg(np.arccos(cos_angle)))
        trans_m = float(np.linalg.norm(T_corr[:3, 3]))
        max_rot = max(max_rot, rot_deg)
        max_trans = max(max_trans, trans_m)
        corrections_count += 1

    metrics.rotation_deg = max_rot
    metrics.translation_m = max_trans
    metrics.success = True
    metrics.correction_applied = corrections_count > 0
    metrics.message = (
        f"Corrected {corrections_count} cameras, skipped {len(metrics.skipped_cameras)}"
    )
    return new_extrinsics, metrics
def create_ground_diagnostic_plot(
    metrics: GroundPlaneMetrics,
    camera_data: Dict[str, Dict[str, Any]],
    extrinsics_before: Dict[str, Mat44],
    extrinsics_after: Dict[str, Mat44],
) -> go.Figure:
    """
    Build a 3-D Plotly figure summarising a ground plane refinement run.

    Traces are added in this order: world axes, the consensus plane (when
    present and not vertical), each camera's depth points placed with the
    "before" extrinsics, and a marker per camera position before and after
    refinement.

    Args:
        metrics: Refinement metrics; only ``consensus_plane`` is read.
        camera_data: Dict mapping serial -> {'depth': ..., 'K': ...}.
        extrinsics_before: Dict mapping serial -> 4x4 pose before refinement.
        extrinsics_after: Dict mapping serial -> 4x4 pose after refinement.

    Returns:
        The assembled figure.
    """
    fig = go.Figure()

    def add_camera_marker(position, color, text, name):
        # One labelled dot at a camera position.
        fig.add_trace(
            go.Scatter3d(
                x=[position[0]],
                y=[position[1]],
                z=[position[2]],
                mode="markers+text",
                marker={"size": 5, "color": color},
                text=[text],
                name=name,
            )
        )

    # 1. World origin axes, drawn as short coloured segments.
    axis_scale = 0.5
    world_axes = (
        (np.array([1, 0, 0]), "red", "X"),
        (np.array([0, 1, 0]), "green", "Y"),
        (np.array([0, 0, 1]), "blue", "Z"),
    )
    for axis, color, name in world_axes:
        tip = axis * axis_scale
        fig.add_trace(
            go.Scatter3d(
                x=[0, tip[0]],
                y=[0, tip[1]],
                z=[0, tip[2]],
                mode="lines",
                line={"color": color, "width": 4},
                name=f"World {name}",
                showlegend=True,
            )
        )

    # 2. Consensus plane as a translucent quad. The plane is solved for y,
    #    so it is skipped when its normal has (almost) no Y component.
    plane = metrics.consensus_plane
    if plane and abs(plane.normal[1]) > 1e-6:
        size = 5.0
        span = np.linspace(-size, size, 2)
        xx, zz = np.meshgrid(span, span)
        # n0*x + n1*y + n2*z + d = 0  =>  y = -(n0*x + n2*z + d) / n1
        yy = (
            -(plane.normal[0] * xx + plane.normal[2] * zz + plane.d)
            / plane.normal[1]
        )
        fig.add_trace(
            go.Surface(
                x=xx,
                y=yy,
                z=zz,
                showscale=False,
                opacity=0.3,
                colorscale=[[0, "lightgray"], [1, "lightgray"]],
                name="Consensus Plane",
            )
        )

    # 3. Depth points per camera, placed with the pre-refinement pose.
    viz_stride = 8  # coarse sampling keeps the figure responsive
    for serial, data in camera_data.items():
        if serial not in extrinsics_before:
            continue
        depth_map = data.get("depth")
        K = data.get("K")
        if depth_map is None or K is None:
            continue

        points_cam = unproject_depth_to_points(depth_map, K, stride=viz_stride)
        if len(points_cam) == 0:
            continue

        pose_before = extrinsics_before[serial]
        points_world = (points_cam @ pose_before[:3, :3].T) + pose_before[:3, 3]
        fig.add_trace(
            go.Scatter3d(
                x=points_world[:, 0],
                y=points_world[:, 1],
                z=points_world[:, 2],
                mode="markers",
                marker={"size": 2, "opacity": 0.5},
                name=f"Points {serial}",
            )
        )

    # 4. Camera positions: red before, green after (when available).
    for serial, pose_before in extrinsics_before.items():
        add_camera_marker(
            pose_before[:3, 3],
            "red",
            text=f"{serial} (before)",
            name=f"Cam {serial} (before)",
        )
        if serial in extrinsics_after:
            add_camera_marker(
                extrinsics_after[serial][:3, 3],
                "green",
                text=f"{serial} (after)",
                name=f"Cam {serial} (after)",
            )

    fig.update_layout(
        title="Ground Plane Refinement Diagnostics",
        scene={
            "xaxis_title": "X",
            "yaxis_title": "Y",
            "zaxis_title": "Z",
            "aspectmode": "data",
            "camera": {
                # Y-down convention for visualization
                "up": {"x": 0, "y": -1, "z": 0},
                "eye": {"x": 1.5, "y": -1.5, "z": 1.5},
            },
        },
        margin={"l": 0, "r": 0, "b": 0, "t": 40},
    )
    return fig
def save_diagnostic_plot(fig: go.Figure, path: str) -> None:
    """
    Write the diagnostic figure to ``path`` as HTML.

    The parent directory of ``path`` is created first if it does not exist.
    """
    import os

    target_dir = os.path.dirname(os.path.abspath(path))
    os.makedirs(target_dir, exist_ok=True)
    fig.write_html(path)