# Source: zed-playground/py_workspace/aruco/ground_plane.py
import numpy as np
from typing import Optional, Tuple, List, Dict, Any
from jaxtyping import Float
from typing import TYPE_CHECKING
import open3d as o3d
from dataclasses import dataclass, field
import plotly.graph_objects as go
if TYPE_CHECKING:
    # Static-analysis-only aliases: jaxtyping annotates dtype and shape
    # ("3" = 3-vector, "4 4" = 4x4 matrix, "N 3" = N-point cloud).
    Vec3 = Float[np.ndarray, "3"]
    Mat44 = Float[np.ndarray, "4 4"]
    PointsNC = Float[np.ndarray, "N 3"]
else:
    # At runtime the aliases are plain ndarrays (no shape checking).
    Vec3 = np.ndarray
    Mat44 = np.ndarray
    PointsNC = np.ndarray
@dataclass
class FloorPlane:
    """A plane stored via the equation n . p + d = 0."""

    # Plane normal; code elsewhere flips its sign freely, so orientation
    # (up vs. down) is not guaranteed.
    normal: Vec3
    # Offset term in n . p + d = 0.
    d: float
    # Number of RANSAC inliers that supported this fit (0 if unknown).
    num_inliers: int = 0
@dataclass
class FloorCorrection:
    """Result of computing a floor-alignment correction."""

    # 4x4 homogeneous transform to pre-multiply onto an extrinsic;
    # identity when the correction was rejected.
    transform: Mat44
    # True when the correction passed the rotation/translation limits.
    valid: bool
    # Human-readable explanation when valid is False.
    reason: str = ""
@dataclass
class GroundPlaneConfig:
    """Tunable parameters for ground-plane refinement."""

    # Master switch: when False, refine_ground_from_depth is a no-op.
    enabled: bool = True
    # Desired world-space floor height (y) after correction.
    target_y: float = 0.0
    # Pixel subsampling stride when unprojecting depth maps.
    stride: int = 8
    # Depth samples outside (depth_min, depth_max) are discarded
    # (meters, presumably — TODO confirm sensor units).
    depth_min: float = 0.2
    depth_max: float = 5.0
    # RANSAC plane fit: inlier distance threshold, sample size, iterations.
    ransac_dist_thresh: float = 0.02
    ransac_n: int = 3
    ransac_iters: int = 1000
    # Safety limits: corrections larger than these are rejected outright.
    max_rotation_deg: float = 5.0
    max_translation_m: float = 0.1
    # Per-camera acceptance: minimum inlier count and inlier/point ratio.
    min_inliers: int = 500
    min_inlier_ratio: float = 0.15
    # Minimum number of cameras with a valid floor plane to proceed.
    min_valid_cameras: int = 2
    # Minimum |normal.y| for a detected plane to count as "floor".
    normal_vertical_thresh: float = 0.9
    # Per-camera deviation limits relative to the consensus plane.
    max_consensus_deviation_deg: float = 10.0
    max_consensus_deviation_m: float = 0.5
    # Optional RNG seed for deterministic RANSAC.
    seed: Optional[int] = None
@dataclass
class GroundPlaneMetrics:
    """Diagnostics and results from one refinement run."""

    # True when refinement ran to completion (even if nothing was corrected).
    success: bool = False
    # True when at least one camera extrinsic was actually corrected.
    correction_applied: bool = False
    num_cameras_total: int = 0
    num_cameras_valid: int = 0
    # Per-camera corrections
    camera_corrections: Dict[str, Mat44] = field(default_factory=dict)
    # Serials skipped: no plane found, invalid correction, or consensus outlier.
    skipped_cameras: List[str] = field(default_factory=list)
    # Summary stats: worst-case rotation/translation over applied corrections.
    rotation_deg: float = 0.0
    translation_m: float = 0.0
    # Per-camera detected floor planes (world frame).
    camera_planes: Dict[str, FloorPlane] = field(default_factory=dict)
    # Robust consensus of all valid per-camera planes, if computed.
    consensus_plane: Optional[FloorPlane] = None
    # Human-readable outcome description.
    message: str = ""
def unproject_depth_to_points(
    depth_map: np.ndarray,
    K: np.ndarray,
    stride: int = 1,
    depth_min: float = 0.1,
    depth_max: float = 10.0,
) -> PointsNC:
    """Unproject a depth map into an (N, 3) point cloud in the camera frame.

    Pixels are subsampled with `stride`, and a sample is kept only when its
    depth is finite and strictly inside (depth_min, depth_max).
    """
    height, width = depth_map.shape
    fx, fy = K[0, 0], K[1, 1]
    cx, cy = K[0, 2], K[1, 2]

    # Subsampled pixel grid (cols = u, rows = v) and the matching depths.
    cols, rows = np.meshgrid(
        np.arange(0, width, stride), np.arange(0, height, stride)
    )
    depths = depth_map[::stride, ::stride]

    # Keep finite depths strictly within the configured bounds.
    keep = (depths > depth_min) & (depths < depth_max) & np.isfinite(depths)
    z = depths[keep]

    # Pinhole back-projection: x = (u - cx) z / fx, y = (v - cy) z / fy.
    x = (cols[keep] - cx) * z / fx
    y = (rows[keep] - cy) * z / fy
    return np.stack((x, y, z), axis=-1).astype(np.float64)
def detect_floor_plane(
    points: PointsNC,
    distance_threshold: float = 0.02,
    ransac_n: int = 3,
    num_iterations: int = 1000,
    seed: Optional[int] = None,
) -> Optional[FloorPlane]:
    """Fit a dominant plane to `points` using Open3D's RANSAC.

    Returns a FloorPlane with a unit-length normal, or None when there are
    fewer points (or resulting inliers) than the RANSAC sample size.
    """
    if points.shape[0] < ransac_n:
        return None

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points)

    # Seeding Open3D's global RNG makes segment_plane deterministic.
    if seed is not None:
        o3d.utility.random.seed(seed)

    (a, b, c, d), inlier_idx = cloud.segment_plane(
        distance_threshold=distance_threshold,
        ransac_n=ransac_n,
        num_iterations=num_iterations,
    )
    if len(inlier_idx) < ransac_n:
        return None

    # Rescale the plane equation so the normal is unit length (Open3D
    # usually returns it normalized already, but be safe).
    normal = np.array([a, b, c], dtype=np.float64)
    length = np.linalg.norm(normal)
    if length > 1e-6:
        normal = normal / length
        d = d / length
    return FloorPlane(normal=normal, d=d, num_inliers=len(inlier_idx))
def compute_consensus_plane(
    planes: List[FloorPlane],
    weights: Optional[List[float]] = None,
    max_angle_deg: float = 15.0,
    max_d_dev_m: float = 0.5,
) -> FloorPlane:
    """
    Compute a consensus plane from multiple plane detections.

    Robust to outliers: component-wise medians define a reference plane,
    detections deviating from it by more than `max_angle_deg` (normal angle)
    or `max_d_dev_m` (offset difference) are dropped, and the survivors are
    combined with a weighted average.

    Args:
        planes: Candidate floor planes (n . p + d = 0).
        weights: Optional per-plane weights; defaults to uniform.
        max_angle_deg: Outlier threshold on angle to the median normal
            (previously hardcoded to 15.0).
        max_d_dev_m: Outlier threshold on |d - median(d)|
            (previously hardcoded to 0.5).

    Returns:
        Consensus FloorPlane with a unit, upper-hemisphere (y >= 0) normal.

    Raises:
        ValueError: If `planes` is empty or `weights` has the wrong length.
    """
    if not planes:
        raise ValueError("No planes provided for consensus.")
    n_planes = len(planes)
    if weights is None:
        weights = [1.0] * n_planes
    if len(weights) != n_planes:
        raise ValueError(
            f"Weights length {len(weights)} must match planes length {n_planes}"
        )
    # 1. Flip all normals into the upper hemisphere (y >= 0) so that
    # averaging nearly-parallel planes cannot cancel them out.
    aligned_planes = []
    for p in planes:
        normal = p.normal.copy()
        d = p.d
        if normal[1] < 0:
            normal = -normal
            d = -d
        aligned_planes.append(FloorPlane(normal=normal, d=d, num_inliers=p.num_inliers))
    # 2. Component-wise median normal and median d form a robust reference.
    normals = np.array([p.normal for p in aligned_planes])
    ds = np.array([p.d for p in aligned_planes])
    median_normal = np.median(normals, axis=0)
    norm = np.linalg.norm(median_normal)
    if norm > 1e-6:
        median_normal /= norm
    else:
        # Degenerate median (normals cancelled out): fall back to Y-up.
        median_normal = np.array([0.0, 1.0, 0.0])
    median_d = float(np.median(ds))
    # 3. Reject planes deviating too far from the reference, by normal
    # angle and by plane offset.
    valid_indices = []
    for i, p in enumerate(aligned_planes):
        dot = np.clip(np.dot(p.normal, median_normal), -1.0, 1.0)
        angle_deg = np.rad2deg(np.arccos(dot))
        dist_diff = abs(p.d - median_d)
        if angle_deg < max_angle_deg and dist_diff < max_d_dev_m:
            valid_indices.append(i)
    if not valid_indices:
        # Everything rejected (should be rare); keep all rather than fail.
        valid_indices = list(range(n_planes))
    # 4. Weighted average of the surviving planes.
    accum_normal = np.zeros(3, dtype=np.float64)
    accum_d = 0.0
    total_weight = 0.0
    for i in valid_indices:
        w = weights[i]
        p = aligned_planes[i]
        accum_normal += p.normal * w
        accum_d += p.d * w
        total_weight += w
    if total_weight <= 0:
        # Non-positive total weight; fall back to the median reference plane.
        return FloorPlane(normal=median_normal, d=median_d)
    avg_normal = accum_normal / total_weight
    avg_d = accum_d / total_weight
    # Re-normalize so the result is a valid unit-normal plane equation.
    norm = np.linalg.norm(avg_normal)
    if norm > 1e-6:
        avg_normal /= norm
        avg_d /= norm
    else:
        avg_normal = np.array([0.0, 1.0, 0.0])
        avg_d = 0.0
    return FloorPlane(normal=avg_normal, d=float(avg_d))
from .alignment import rotation_align_vectors
def compute_floor_correction(
    current_floor_plane: FloorPlane,
    target_floor_y: float = 0.0,
    max_rotation_deg: float = 5.0,
    max_translation_m: float = 0.1,
    target_plane: Optional[FloorPlane] = None,
) -> FloorCorrection:
    """
    Compute the correction transform to align the current floor plane to the target floor height.
    Constrains correction to pitch/roll and vertical translation only.
    If target_plane is provided, aligns current plane to target_plane (relative correction).
    Otherwise, aligns to absolute Y=target_floor_y (absolute correction).

    Returns a FloorCorrection whose transform is identity (valid=False, with
    a reason) when alignment fails or a safety limit is exceeded.
    """
    current_normal = current_floor_plane.normal
    current_d = current_floor_plane.d
    # Target normal is always [0, 1, 0] (Y-up)
    target_normal = np.array([0.0, 1.0, 0.0])
    if target_plane is not None:
        # Relative mode: align to the target plane's normal, forced to point
        # roughly up so both normals live in the same hemisphere.
        align_target_normal = target_plane.normal
        if align_target_normal[1] < 0:
            align_target_normal = -align_target_normal
    else:
        align_target_normal = target_normal
    # 1. Rotation mapping the current normal onto the chosen target normal.
    try:
        R_align = rotation_align_vectors(current_normal, align_target_normal)
    except ValueError as e:
        return FloorCorrection(
            transform=np.eye(4), valid=False, reason=f"Rotation alignment failed: {e}"
        )
    # Reject overly large rotations. The rotation angle of a rotation
    # matrix R is acos((trace(R) - 1) / 2).
    trace = np.trace(R_align)
    # Clip to guard against numerical error pushing cos outside [-1, 1].
    cos_angle = np.clip((trace - 1) / 2, -1.0, 1.0)
    angle_rad = np.arccos(cos_angle)
    angle_deg = np.rad2deg(angle_rad)
    if angle_deg > max_rotation_deg:
        return FloorCorrection(
            transform=np.eye(4),
            valid=False,
            reason=f"Rotation {angle_deg:.1f} deg exceeds limit {max_rotation_deg:.1f} deg",
        )
    # 2. Translation along the (aligned) normal direction.
    if target_plane is not None:
        # Relative correction. Rotating about the origin leaves the plane
        # offset d unchanged (the normal rotates with the plane), so after
        # R_align the residual error is just the difference of d values.
        # If target_plane's stored normal was flipped above, flip its d too
        # so both offsets are expressed w.r.t. align_target_normal.
        target_d = target_plane.d
        if np.dot(target_plane.normal, align_target_normal) < 0:
            target_d = -target_d
        # Translating points by t*n turns n.p + d = 0 into offset (d - t),
        # so t = current_d - target_d moves the current plane onto the target.
        t_mag = current_d - target_d
        trans_dir = align_target_normal
    else:
        # Absolute correction: with normal [0, 1, 0] the floor sits at
        # y = -current_d, so a shift of target_floor_y + current_d puts it
        # at y = target_floor_y.
        t_mag = target_floor_y + current_d
        trans_dir = target_normal
    # Reject overly large translations.
    if abs(t_mag) > max_translation_m:
        return FloorCorrection(
            transform=np.eye(4),
            valid=False,
            reason=f"Translation {t_mag:.3f} m exceeds limit {max_translation_m:.3f} m",
        )
    # Assemble the homogeneous transform: rotate, then translate along the
    # aligned normal (the translation is expressed in the rotated frame).
    T = np.eye(4)
    T[:3, :3] = R_align
    T[:3, 3] = trans_dir * t_mag
    return FloorCorrection(transform=T.astype(np.float64), valid=True)
def refine_ground_from_depth(
    camera_data: Dict[str, Dict[str, Any]],
    extrinsics: Dict[str, Mat44],
    config: Optional[GroundPlaneConfig] = None,
) -> Tuple[Dict[str, Mat44], GroundPlaneMetrics]:
    """
    Orchestrate ground plane refinement across multiple cameras.

    Detects a floor plane per camera (in the world frame), computes a robust
    consensus plane, then applies a per-camera correction validated against
    both per-camera limits and the consensus.

    Args:
        camera_data: Dict mapping serial -> {'depth': np.ndarray, 'K': np.ndarray}
        extrinsics: Dict mapping serial -> world_from_cam matrix (4x4)
        config: Configuration parameters; a fresh GroundPlaneConfig() is
            created per call when omitted. (The previous
            `config=GroundPlaneConfig()` default was evaluated once at
            definition time, sharing one mutable instance across all calls.)

    Returns:
        Tuple of (new_extrinsics, metrics)
    """
    if config is None:
        config = GroundPlaneConfig()
    metrics = GroundPlaneMetrics()
    metrics.num_cameras_total = len(camera_data)
    if not config.enabled:
        metrics.message = "Ground plane refinement disabled in config"
        return extrinsics, metrics
    valid_planes: List[FloorPlane] = []
    valid_serials: List[str] = []
    # 1. Detect a floor plane per camera, in world coordinates.
    for serial, data in camera_data.items():
        if serial not in extrinsics:
            continue
        depth_map = data.get("depth")
        K = data.get("K")
        if depth_map is None or K is None:
            continue
        # Unproject into the camera frame.
        points_cam = unproject_depth_to_points(
            depth_map,
            K,
            stride=config.stride,
            depth_min=config.depth_min,
            depth_max=config.depth_max,
        )
        if len(points_cam) < config.min_inliers:
            continue
        # Camera frame -> world frame: p_w = R p_c + t.
        T_world_cam = extrinsics[serial]
        R = T_world_cam[:3, :3]
        t = T_world_cam[:3, 3]
        points_world = (points_cam @ R.T) + t
        plane = detect_floor_plane(
            points_world,
            distance_threshold=config.ransac_dist_thresh,
            ransac_n=config.ransac_n,
            num_iterations=config.ransac_iters,
            seed=config.seed,
        )
        if plane is not None:
            # Reject weak fits: too few inliers, or too small a ratio.
            if plane.num_inliers < config.min_inliers:
                continue
            if config.min_inlier_ratio > 0:
                ratio = plane.num_inliers / len(points_world)
                if ratio < config.min_inlier_ratio:
                    continue
            # Floor normals must be roughly vertical ([0, +-1, 0]);
            # anything else is likely a wall or other structure.
            if abs(plane.normal[1]) < config.normal_vertical_thresh:
                continue
            metrics.camera_planes[serial] = plane
            valid_planes.append(plane)
            valid_serials.append(serial)
    metrics.num_cameras_valid = len(valid_planes)
    # 2. Require enough cameras with a credible floor plane.
    if len(valid_planes) < config.min_valid_cameras:
        metrics.message = f"Found {len(valid_planes)} valid planes, required {config.min_valid_cameras}"
        return extrinsics, metrics
    # 3. Consensus plane (cross-camera validation target and reporting).
    try:
        consensus = compute_consensus_plane(valid_planes)
        metrics.consensus_plane = consensus
    except ValueError as e:
        metrics.message = f"Consensus computation failed: {e}"
        return extrinsics, metrics
    # 4. Compute and apply per-camera corrections.
    new_extrinsics = extrinsics.copy()
    # Track worst-case rotation/translation for the summary metrics.
    max_rot = 0.0
    max_trans = 0.0
    corrections_count = 0
    for serial, T_old in extrinsics.items():
        # Cameras without a detected plane are left untouched.
        if serial not in metrics.camera_planes:
            metrics.skipped_cameras.append(serial)
            continue
        plane = metrics.camera_planes[serial]
        correction = compute_floor_correction(
            plane,
            target_floor_y=config.target_y,
            max_rotation_deg=config.max_rotation_deg,
            max_translation_m=config.max_translation_m,
            target_plane=metrics.consensus_plane,
        )
        if not correction.valid:
            metrics.skipped_cameras.append(serial)
            continue
        # Cross-check this camera's plane against the consensus so a single
        # bad detection (e.g. a wall) cannot produce a large correction even
        # if it passed the individual checks.
        if metrics.consensus_plane:
            dot = np.clip(
                np.dot(plane.normal, metrics.consensus_plane.normal), -1.0, 1.0
            )
            # Normals are sign-ambiguous; compare the unsigned angle.
            if dot < 0:
                dot = -dot
            angle_deg = np.rad2deg(np.arccos(dot))
            if angle_deg > config.max_consensus_deviation_deg:
                metrics.skipped_cameras.append(serial)
                continue
            # Offset check: compare |d| values (sign may be flipped).
            d_diff = abs(abs(plane.d) - abs(metrics.consensus_plane.d))
            if d_diff > config.max_consensus_deviation_m:
                metrics.skipped_cameras.append(serial)
                continue
        T_corr = correction.transform
        metrics.camera_corrections[serial] = T_corr
        # World-frame correction pre-multiplies the old extrinsic.
        new_extrinsics[serial] = T_corr @ T_old
        # Fold this correction into the worst-case summary stats.
        trace = np.trace(T_corr[:3, :3])
        cos_angle = np.clip((trace - 1) / 2, -1.0, 1.0)
        rot_deg = float(np.rad2deg(np.arccos(cos_angle)))
        trans_m = float(np.linalg.norm(T_corr[:3, 3]))
        max_rot = max(max_rot, rot_deg)
        max_trans = max(max_trans, trans_m)
        corrections_count += 1
    metrics.rotation_deg = max_rot
    metrics.translation_m = max_trans
    metrics.success = True
    metrics.correction_applied = corrections_count > 0
    metrics.message = (
        f"Corrected {corrections_count} cameras, skipped {len(metrics.skipped_cameras)}"
    )
    return new_extrinsics, metrics
def create_ground_diagnostic_plot(
    metrics: GroundPlaneMetrics,
    camera_data: Dict[str, Dict[str, Any]],
    extrinsics_before: Dict[str, Mat44],
    extrinsics_after: Dict[str, Mat44],
) -> go.Figure:
    """
    Create a Plotly diagnostic visualization for ground plane refinement.

    One 3D scene containing: the world axes, the consensus plane (if any),
    each camera's unprojected depth points (using the *before* extrinsics),
    and camera positions before (red) and after (green) correction.
    """
    fig = go.Figure()
    # 1. World origin axes: X red, Y green, Z blue, 0.5 units long.
    axis_scale = 0.5
    for axis, color, name in zip(
        [np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1])],
        ["red", "green", "blue"],
        ["X", "Y", "Z"],
    ):
        fig.add_trace(
            go.Scatter3d(
                x=[0, axis[0] * axis_scale],
                y=[0, axis[1] * axis_scale],
                z=[0, axis[2] * axis_scale],
                mode="lines",
                line=dict(color=color, width=4),
                name=f"World {name}",
                showlegend=True,
            )
        )
    # 2. Consensus plane rendered as a translucent 2x2-vertex surface patch.
    if metrics.consensus_plane:
        plane = metrics.consensus_plane
        size = 5.0
        x = np.linspace(-size, size, 2)
        z = np.linspace(-size, size, 2)
        xx, zz = np.meshgrid(x, z)
        # Solve the plane equation for y:
        # n.p + d = 0 => n0*x + n1*y + n2*z + d = 0 => y = -(n0*x + n2*z + d) / n1
        # Skipped when n1 ~ 0: a near-vertical plane cannot be written y(x, z).
        if abs(plane.normal[1]) > 1e-6:
            yy = (
                -(plane.normal[0] * xx + plane.normal[2] * zz + plane.d)
                / plane.normal[1]
            )
            fig.add_trace(
                go.Surface(
                    x=xx,
                    y=yy,
                    z=zz,
                    showscale=False,
                    opacity=0.3,
                    colorscale=[[0, "lightgray"], [1, "lightgray"]],
                    name="Consensus Plane",
                )
            )
    # 3. Depth points per camera, transformed with the *before* extrinsics.
    for serial, data in camera_data.items():
        if serial not in extrinsics_before:
            continue
        depth_map = data.get("depth")
        K = data.get("K")
        if depth_map is None or K is None:
            continue
        # Use a larger stride for visualization to keep it responsive
        viz_stride = 8
        points_cam = unproject_depth_to_points(depth_map, K, stride=viz_stride)
        if len(points_cam) == 0:
            continue
        # Camera frame -> world frame: p_w = R p_c + t.
        T_before = extrinsics_before[serial]
        R_b = T_before[:3, :3]
        t_b = T_before[:3, 3]
        points_world = (points_cam @ R_b.T) + t_b
        fig.add_trace(
            go.Scatter3d(
                x=points_world[:, 0],
                y=points_world[:, 1],
                z=points_world[:, 2],
                mode="markers",
                marker=dict(size=2, opacity=0.5),
                name=f"Points {serial}",
            )
        )
    # 4. Camera positions before (red) and, when available, after (green).
    for serial in extrinsics_before:
        T_b = extrinsics_before[serial]
        pos_b = T_b[:3, 3]
        fig.add_trace(
            go.Scatter3d(
                x=[pos_b[0]],
                y=[pos_b[1]],
                z=[pos_b[2]],
                mode="markers+text",
                marker=dict(size=5, color="red"),
                text=[f"{serial} (before)"],
                name=f"Cam {serial} (before)",
            )
        )
        if serial in extrinsics_after:
            T_a = extrinsics_after[serial]
            pos_a = T_a[:3, 3]
            fig.add_trace(
                go.Scatter3d(
                    x=[pos_a[0]],
                    y=[pos_a[1]],
                    z=[pos_a[2]],
                    mode="markers+text",
                    marker=dict(size=5, color="green"),
                    text=[f"{serial} (after)"],
                    name=f"Cam {serial} (after)",
                )
            )
    fig.update_layout(
        title="Ground Plane Refinement Diagnostics",
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z",
            aspectmode="data",
            camera=dict(
                up=dict(x=0, y=-1, z=0),  # Y-down convention for visualization
                eye=dict(x=1.5, y=-1.5, z=1.5),
            ),
        ),
        margin=dict(l=0, r=0, b=0, t=40),
    )
    return fig
def save_diagnostic_plot(fig: go.Figure, path: str) -> None:
    """Write the diagnostic figure to `path` as a standalone HTML file.

    Parent directories are created on demand so callers can pass any path.
    """
    import os

    parent_dir = os.path.dirname(os.path.abspath(path))
    os.makedirs(parent_dir, exist_ok=True)
    fig.write_html(path)