404 lines
11 KiB
Python
404 lines
11 KiB
Python
import numpy as np
|
|
from typing import Optional, Tuple, List, Dict, Any
|
|
from jaxtyping import Float
|
|
from typing import TYPE_CHECKING
|
|
import open3d as o3d
|
|
from dataclasses import dataclass, field
|
|
|
|
# Array aliases used in annotations throughout this module.
# Static type checkers see the shape-annotated jaxtyping forms; at runtime
# (where TYPE_CHECKING is False) every alias is plain ``np.ndarray`` and the
# ``Float[...]`` expressions are never evaluated.
if not TYPE_CHECKING:
    Vec3 = np.ndarray
    Mat44 = np.ndarray
    PointsNC = np.ndarray
else:
    Vec3 = Float[np.ndarray, "3"]
    Mat44 = Float[np.ndarray, "4 4"]
    PointsNC = Float[np.ndarray, "N 3"]
|
|
|
|
|
|
@dataclass
class FloorPlane:
    """A plane written as ``normal . p + d = 0``.

    Attributes:
        normal: Plane normal as a length-3 vector.
        d: Offset term of the plane equation.
        num_inliers: Number of points that supported the fit. Stays at the
            default of 0 when the producer does not report inliers.
    """

    normal: Vec3
    d: float
    num_inliers: int = 0
|
|
|
|
|
|
@dataclass
class FloorCorrection:
    """Outcome of a floor-alignment computation.

    Attributes:
        transform: 4x4 homogeneous correction transform.
        valid: Whether ``transform`` should be applied.
        reason: Human-readable explanation, empty by default; producers
            fill it in when ``valid`` is False.
    """

    transform: Mat44
    valid: bool
    reason: str = ""
|
|
|
|
|
|
@dataclass
class GroundPlaneConfig:
    """Tunable parameters for depth-based ground plane refinement.

    Attributes:
        enabled: Master switch; when False the refinement is skipped.
        target_y: Desired floor height along the world Y axis.
        stride: Pixel step used when sub-sampling depth maps.
        depth_min: Depth samples at or below this value are discarded.
        depth_max: Depth samples at or above this value are discarded.
        ransac_dist_thresh: Inlier distance threshold for plane fitting.
        ransac_n: Number of points drawn per RANSAC sample.
        ransac_iters: Number of RANSAC iterations.
        max_rotation_deg: Largest correction rotation that is accepted.
        max_translation_m: Largest vertical correction that is accepted.
        min_inliers: Minimum number of unprojected points needed to attempt
            a fit, and minimum inliers for a camera's plane to be accepted.
        min_valid_cameras: Minimum number of cameras with an accepted plane
            required before a consensus is computed.
    """

    # --- General -----------------------------------------------------------
    enabled: bool = True
    target_y: float = 0.0

    # --- Depth sampling ----------------------------------------------------
    stride: int = 8
    depth_min: float = 0.2
    depth_max: float = 5.0

    # --- RANSAC plane fit --------------------------------------------------
    ransac_dist_thresh: float = 0.02
    ransac_n: int = 3
    ransac_iters: int = 1000

    # --- Safety limits on the resulting correction --------------------------
    max_rotation_deg: float = 5.0
    max_translation_m: float = 0.1

    # --- Acceptance thresholds ----------------------------------------------
    min_inliers: int = 500
    min_valid_cameras: int = 2
|
|
|
|
|
|
@dataclass
class GroundPlaneMetrics:
    """Diagnostics describing one ground plane refinement run.

    Attributes:
        success: True when the refinement ran to completion.
        correction_applied: True when the extrinsics were actually changed.
        num_cameras_total: Number of cameras supplied to the refinement.
        num_cameras_valid: Number of cameras that yielded an accepted plane.
        correction_transform: 4x4 correction transform; a fresh identity
            matrix per instance by default.
        rotation_deg: Rotation angle of the correction, in degrees.
        translation_m: Norm of the correction's translation component.
        camera_planes: Accepted per-camera planes keyed by camera serial;
            a fresh empty dict per instance by default.
        consensus_plane: Plane combined from the per-camera planes, or None
            when no consensus was computed.
        message: Human-readable status or failure reason.
    """

    # --- Outcome flags -------------------------------------------------------
    success: bool = False
    correction_applied: bool = False

    # --- Camera counts -------------------------------------------------------
    num_cameras_total: int = 0
    num_cameras_valid: int = 0

    # --- Applied correction --------------------------------------------------
    # default_factory gives every instance its own array rather than a shared one.
    correction_transform: Mat44 = field(default_factory=lambda: np.eye(4))
    rotation_deg: float = 0.0
    translation_m: float = 0.0

    # --- Detected planes -----------------------------------------------------
    camera_planes: Dict[str, FloorPlane] = field(default_factory=dict)
    consensus_plane: Optional[FloorPlane] = None

    # --- Status --------------------------------------------------------------
    message: str = ""
|
|
|
|
|
|
def unproject_depth_to_points(
    depth_map: np.ndarray,
    K: np.ndarray,
    stride: int = 1,
    depth_min: float = 0.1,
    depth_max: float = 10.0,
) -> PointsNC:
    """Back-project a 2-D depth image into 3-D points with a pinhole model.

    The image is sub-sampled every ``stride`` rows and columns. Samples that
    are non-finite, or not strictly between ``depth_min`` and ``depth_max``,
    are dropped. Each kept pixel ``(u, v)`` with depth ``z`` becomes
    ``((u - cx) * z / fx, (v - cy) * z / fy, z)``.

    Args:
        depth_map: 2-D array of depth values (unpacked as ``h, w``).
        K: Intrinsics matrix; ``fx, fy`` are read from the diagonal and
            ``cx, cy`` from the third column.
        stride: Step between sampled rows/columns.
        depth_min: Exclusive lower depth bound.
        depth_max: Exclusive upper depth bound.

    Returns:
        ``float64`` array of shape (N, 3) holding ``(x, y, z)`` rows, in
        row-major order over the sampled grid.
    """
    height, width = depth_map.shape
    focal_x, focal_y = K[0, 0], K[1, 1]
    center_x, center_y = K[0, 2], K[1, 2]

    # Pixel coordinate of every sampled column / row.
    col_px = np.arange(0, width, stride)
    row_px = np.arange(0, height, stride)

    # Sub-sampled depth image aligned with (row_px, col_px).
    sampled = depth_map[0:height:stride, 0:width:stride]

    # Keep finite samples that lie strictly inside the depth range.
    keep = (sampled > depth_min) & (sampled < depth_max) & np.isfinite(sampled)

    # Grid positions of the kept samples (row-major, same order as sampled[keep]).
    row_idx, col_idx = np.nonzero(keep)
    depth = sampled[keep]
    u_px = col_px[col_idx]
    v_px = row_px[row_idx]

    # Pinhole back-projection.
    x_cam = (u_px - center_x) * depth / focal_x
    y_cam = (v_px - center_y) * depth / focal_y

    cloud = np.column_stack((x_cam, y_cam, depth))
    return cloud.astype(np.float64)
|
|
|
|
|
|
def detect_floor_plane(
    points: PointsNC,
    distance_threshold: float = 0.02,
    ransac_n: int = 3,
    num_iterations: int = 1000,
    seed: Optional[int] = None,
) -> Optional[FloorPlane]:
    """Fit the dominant plane of a point cloud with Open3D RANSAC.

    Args:
        points: (N, 3) array of points.
        distance_threshold: Maximum point-to-plane distance for an inlier.
        ransac_n: Number of points drawn per RANSAC sample.
        num_iterations: Number of RANSAC iterations.
        seed: When given, seeds Open3D's global random generator before
            segmentation so the fit is repeatable.

    Returns:
        A ``FloorPlane`` with a unit-length normal and the inlier count, or
        None when there are fewer than ``ransac_n`` points or inliers.
    """
    # Not enough points to draw even a single RANSAC sample.
    if points.shape[0] < ransac_n:
        return None

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points)

    # NOTE: this seeds Open3D globally, not just for this call.
    if seed is not None:
        o3d.utility.random.seed(seed)

    coefficients, inlier_indices = cloud.segment_plane(
        distance_threshold=distance_threshold,
        ransac_n=ransac_n,
        num_iterations=num_iterations,
    )

    inlier_count = len(inlier_indices)
    if inlier_count < ransac_n:
        return None

    # Coefficients are (a, b, c, d) of the plane a*x + b*y + c*z + d = 0.
    a, b, c, offset = coefficients
    direction = np.array([a, b, c], dtype=np.float64)

    # Rescale to a unit normal; d is scaled alongside so the plane itself is
    # unchanged. A near-zero normal is passed through as-is.
    length = np.linalg.norm(direction)
    if length > 1e-6:
        direction = direction / length
        offset = offset / length

    return FloorPlane(normal=direction, d=offset, num_inliers=inlier_count)
|
|
|
|
|
|
def compute_consensus_plane(
    planes: List[FloorPlane],
    weights: Optional[List[float]] = None,
) -> FloorPlane:
    """Merge several plane estimates into one weighted-average plane.

    Each plane is first oriented to agree with the first plane's normal
    (negating both normal and ``d`` when needed). Normals and offsets are
    then averaged with the given weights, and the result is rescaled so the
    normal has unit length.

    Args:
        planes: Planes to combine; must be non-empty.
        weights: One weight per plane; all weights are 1.0 when omitted.

    Returns:
        The consensus ``FloorPlane``. Its ``num_inliers`` is left at the
        default.

    Raises:
        ValueError: If ``planes`` is empty, the number of weights differs
            from the number of planes, or the weights do not sum to a
            positive value.
    """
    if not planes:
        raise ValueError("No planes provided for consensus.")

    n_planes = len(planes)
    if weights is None:
        weights = [1.0] * n_planes
    if len(weights) != n_planes:
        raise ValueError(
            f"Weights length {len(weights)} must match planes length {n_planes}"
        )

    # The first plane fixes which of the two equivalent signs is used.
    reference = planes[0].normal

    normal_sum = np.zeros(3, dtype=np.float64)
    offset_sum = 0.0
    weight_sum = 0.0

    for plane, weight in zip(planes, weights):
        # (n, d) and (-n, -d) are the same plane; choose the representation
        # whose normal agrees with the reference before accumulating.
        opposed = np.dot(plane.normal, reference) < 0
        oriented_normal = -plane.normal if opposed else plane.normal
        oriented_d = -plane.d if opposed else plane.d

        normal_sum += oriented_normal * weight
        offset_sum += oriented_d * weight
        weight_sum += weight

    if total_weight_invalid := (weight_sum <= 0):
        raise ValueError("Total weight must be positive.")

    mean_normal = normal_sum / weight_sum
    mean_d = offset_sum / weight_sum

    length = np.linalg.norm(mean_normal)
    if length > 1e-6:
        # Dividing both terms by the same factor keeps the plane unchanged.
        mean_normal = mean_normal / length
        mean_d = mean_d / length
    else:
        # Normals cancelled out; fall back to a horizontal plane through 0.
        mean_normal = np.array([0.0, 1.0, 0.0])
        mean_d = 0.0

    return FloorPlane(normal=mean_normal, d=float(mean_d))
|
|
|
|
|
|
from .alignment import rotation_align_vectors
|
|
|
|
|
|
def compute_floor_correction(
    current_floor_plane: FloorPlane,
    target_floor_y: float = 0.0,
    max_rotation_deg: float = 5.0,
    max_translation_m: float = 0.1,
) -> FloorCorrection:
    """Compute the transform that levels the floor and moves it to a target height.

    The correction rotates the detected floor normal onto world +Y and then
    shifts along Y so the floor lies at ``y = target_floor_y``. Corrections
    larger than the given limits are rejected rather than applied.

    The plane is treated as unoriented: ``(n, d)`` and ``(-n, -d)`` describe
    the same surface, so a normal pointing into the -Y half-space is negated
    (together with ``d``) before the rotation is computed. Without this, a
    level floor whose fit came back with a downward normal would need a
    ~180 degree rotation and be rejected by the rotation limit.

    Args:
        current_floor_plane: Detected floor plane ``n . p + d = 0``. The
            normal is expected to be unit length; it is not renormalised.
        target_floor_y: Desired floor height along world Y.
        max_rotation_deg: Largest accepted rotation angle, in degrees.
        max_translation_m: Largest accepted absolute vertical shift.

    Returns:
        A ``FloorCorrection``. When ``valid`` is False the transform is the
        identity and ``reason`` explains the rejection.
    """
    # Target normal is always world +Y (Y-up convention).
    target_normal = np.array([0.0, 1.0, 0.0])

    current_normal = np.asarray(current_floor_plane.normal, dtype=np.float64)
    current_d = float(current_floor_plane.d)

    # Canonicalise the sign so the normal points into the +Y half-space.
    # A new array is created, so the caller's plane is not mutated.
    if np.dot(current_normal, target_normal) < 0:
        current_normal = -current_normal
        current_d = -current_d

    # 1. Rotation that aligns the floor normal with the target normal.
    try:
        R_align = rotation_align_vectors(current_normal, target_normal)
    except ValueError as e:
        return FloorCorrection(
            transform=np.eye(4), valid=False, reason=f"Rotation alignment failed: {e}"
        )

    # Rotation angle from the trace: angle = acos((trace(R) - 1) / 2).
    # Clip guards against round-off pushing the cosine outside [-1, 1].
    trace = np.trace(R_align)
    cos_angle = np.clip((trace - 1) / 2, -1.0, 1.0)
    angle_deg = np.rad2deg(np.arccos(cos_angle))

    if angle_deg > max_rotation_deg:
        return FloorCorrection(
            transform=np.eye(4),
            valid=False,
            reason=f"Rotation {angle_deg:.1f} deg exceeds limit {max_rotation_deg:.1f} deg",
        )

    # 2. Vertical translation.
    # After rotation the plane is y + d = 0, i.e. the floor sits at y = -d.
    # Moving it to target_floor_y needs a shift of target_floor_y - (-d).
    t_y = target_floor_y + current_d

    if abs(t_y) > max_translation_m:
        return FloorCorrection(
            transform=np.eye(4),
            valid=False,
            reason=f"Translation {t_y:.3f} m exceeds limit {max_translation_m:.3f} m",
        )

    # 3. Assemble the homogeneous transform p' = R p + t.
    T = np.eye(4)
    T[:3, :3] = R_align
    # The shift is applied after rotation, along the (now vertical) target normal.
    T[:3, 3] = target_normal * t_y

    return FloorCorrection(transform=T.astype(np.float64), valid=True)
|
|
|
|
|
|
def refine_ground_from_depth(
    camera_data: Dict[str, Dict[str, Any]],
    extrinsics: Dict[str, Mat44],
    config: Optional[GroundPlaneConfig] = None,
) -> Tuple[Dict[str, Mat44], GroundPlaneMetrics]:
    """Orchestrate ground plane refinement across multiple cameras.

    For every camera that has depth, intrinsics and extrinsics, the depth
    map is unprojected, moved into the world frame and a plane is fitted.
    The accepted planes are merged into a consensus plane, from which one
    correction transform is derived and left-multiplied onto every entry of
    ``extrinsics`` (including cameras that contributed no plane).

    Args:
        camera_data: Dict mapping serial -> {'depth': np.ndarray, 'K': np.ndarray}.
            Cameras missing either entry, or absent from ``extrinsics``,
            are skipped.
        extrinsics: Dict mapping serial -> world_from_cam matrix (4x4).
        config: Configuration parameters. When None, a fresh
            ``GroundPlaneConfig()`` with default values is created for this
            call, so no instance is shared between calls.

    Returns:
        Tuple of (new_extrinsics, metrics). When refinement is disabled,
        fails or is rejected, the input ``extrinsics`` object is returned
        unchanged and ``metrics.message`` gives the reason.
    """
    # Build the default per call; a default instance in the signature would
    # be created once and shared (and mutable) across all calls.
    if config is None:
        config = GroundPlaneConfig()

    metrics = GroundPlaneMetrics()
    metrics.num_cameras_total = len(camera_data)

    if not config.enabled:
        metrics.message = "Ground plane refinement disabled in config"
        return extrinsics, metrics

    valid_planes: List[FloorPlane] = []

    # 1. Detect planes in each camera
    for serial, data in camera_data.items():
        if serial not in extrinsics:
            continue

        depth_map = data.get("depth")
        K = data.get("K")
        if depth_map is None or K is None:
            continue

        # Unproject to camera frame
        points_cam = unproject_depth_to_points(
            depth_map,
            K,
            stride=config.stride,
            depth_min=config.depth_min,
            depth_max=config.depth_max,
        )

        # A plane cannot reach min_inliers with fewer points than that.
        if len(points_cam) < config.min_inliers:
            continue

        # Transform to world frame: p_world = R @ p_cam + t, written for
        # row-vector points of shape (N, 3).
        T_world_cam = extrinsics[serial]
        R = T_world_cam[:3, :3]
        t = T_world_cam[:3, 3]
        points_world = (points_cam @ R.T) + t

        # Detect plane
        plane = detect_floor_plane(
            points_world,
            distance_threshold=config.ransac_dist_thresh,
            ransac_n=config.ransac_n,
            num_iterations=config.ransac_iters,
        )

        if plane is not None and plane.num_inliers >= config.min_inliers:
            metrics.camera_planes[serial] = plane
            valid_planes.append(plane)

    metrics.num_cameras_valid = len(valid_planes)

    # 2. Check minimum requirements
    if len(valid_planes) < config.min_valid_cameras:
        metrics.message = f"Found {len(valid_planes)} valid planes, required {config.min_valid_cameras}"
        return extrinsics, metrics

    # 3. Compute consensus
    try:
        consensus = compute_consensus_plane(valid_planes)
        metrics.consensus_plane = consensus
    except ValueError as e:
        metrics.message = f"Consensus computation failed: {e}"
        return extrinsics, metrics

    # 4. Compute correction
    correction = compute_floor_correction(
        consensus,
        target_floor_y=config.target_y,
        max_rotation_deg=config.max_rotation_deg,
        max_translation_m=config.max_translation_m,
    )

    # Recorded even when rejected (it is the identity in that case).
    metrics.correction_transform = correction.transform

    if not correction.valid:
        metrics.message = f"Correction invalid: {correction.reason}"
        return extrinsics, metrics

    # 5. Apply correction
    # T_corr moves world points: P' = T_corr * P, with P = T_world_cam * P_cam.
    # The new extrinsics must satisfy P' = T'_world_cam * P_cam, hence
    # T'_world_cam = T_corr * T_world_cam.
    T_corr = correction.transform
    new_extrinsics = {
        serial: T_corr @ T_old for serial, T_old in extrinsics.items()
    }

    # Summary metrics: rotation angle from the trace, translation magnitude.
    trace = np.trace(T_corr[:3, :3])
    cos_angle = np.clip((trace - 1) / 2, -1.0, 1.0)
    metrics.rotation_deg = float(np.rad2deg(np.arccos(cos_angle)))
    metrics.translation_m = float(np.linalg.norm(T_corr[:3, 3]))

    metrics.success = True
    metrics.correction_applied = True
    metrics.message = "Success"

    return new_extrinsics, metrics
|