feat: implement ground plane orchestration

This commit is contained in:
2026-02-09 07:27:36 +00:00
parent 6f34cd48fe
commit 94d9a27724
2 changed files with 318 additions and 2 deletions
+158 -2
View File
@@ -1,9 +1,9 @@
import numpy as np
from typing import Optional, Tuple, List
from typing import Optional, Tuple, List, Dict, Any
from jaxtyping import Float
from typing import TYPE_CHECKING
import open3d as o3d
from dataclasses import dataclass
from dataclasses import dataclass, field
if TYPE_CHECKING:
Vec3 = Float[np.ndarray, "3"]
@@ -29,6 +29,36 @@ class FloorCorrection:
reason: str = ""
@dataclass
class GroundPlaneConfig:
enabled: bool = True
target_y: float = 0.0
stride: int = 8
depth_min: float = 0.2
depth_max: float = 5.0
ransac_dist_thresh: float = 0.02
ransac_n: int = 3
ransac_iters: int = 1000
max_rotation_deg: float = 5.0
max_translation_m: float = 0.1
min_inliers: int = 500
min_valid_cameras: int = 2
@dataclass
class GroundPlaneMetrics:
success: bool = False
correction_applied: bool = False
num_cameras_total: int = 0
num_cameras_valid: int = 0
correction_transform: Mat44 = field(default_factory=lambda: np.eye(4))
rotation_deg: float = 0.0
translation_m: float = 0.0
camera_planes: Dict[str, FloorPlane] = field(default_factory=dict)
consensus_plane: Optional[FloorPlane] = None
message: str = ""
def unproject_depth_to_points(
depth_map: np.ndarray,
K: np.ndarray,
@@ -245,3 +275,129 @@ def compute_floor_correction(
T[:3, 3] = target_normal * t_y
return FloorCorrection(transform=T.astype(np.float64), valid=True)
def refine_ground_from_depth(
camera_data: Dict[str, Dict[str, Any]],
extrinsics: Dict[str, Mat44],
config: GroundPlaneConfig = GroundPlaneConfig(),
) -> Tuple[Dict[str, Mat44], GroundPlaneMetrics]:
"""
Orchestrate ground plane refinement across multiple cameras.
Args:
camera_data: Dict mapping serial -> {'depth': np.ndarray, 'K': np.ndarray}
extrinsics: Dict mapping serial -> world_from_cam matrix (4x4)
config: Configuration parameters
Returns:
Tuple of (new_extrinsics, metrics)
"""
metrics = GroundPlaneMetrics()
metrics.num_cameras_total = len(camera_data)
if not config.enabled:
metrics.message = "Ground plane refinement disabled in config"
return extrinsics, metrics
valid_planes: List[FloorPlane] = []
valid_serials: List[str] = []
# 1. Detect planes in each camera
for serial, data in camera_data.items():
if serial not in extrinsics:
continue
depth_map = data.get("depth")
K = data.get("K")
if depth_map is None or K is None:
continue
# Unproject to camera frame
points_cam = unproject_depth_to_points(
depth_map,
K,
stride=config.stride,
depth_min=config.depth_min,
depth_max=config.depth_max,
)
if len(points_cam) < config.min_inliers:
continue
# Transform to world frame
T_world_cam = extrinsics[serial]
# points_cam is (N, 3)
# Apply rotation and translation
R = T_world_cam[:3, :3]
t = T_world_cam[:3, 3]
points_world = (points_cam @ R.T) + t
# Detect plane
plane = detect_floor_plane(
points_world,
distance_threshold=config.ransac_dist_thresh,
ransac_n=config.ransac_n,
num_iterations=config.ransac_iters,
)
if plane is not None and plane.num_inliers >= config.min_inliers:
metrics.camera_planes[serial] = plane
valid_planes.append(plane)
valid_serials.append(serial)
metrics.num_cameras_valid = len(valid_planes)
# 2. Check minimum requirements
if len(valid_planes) < config.min_valid_cameras:
metrics.message = f"Found {len(valid_planes)} valid planes, required {config.min_valid_cameras}"
return extrinsics, metrics
# 3. Compute consensus
try:
consensus = compute_consensus_plane(valid_planes)
metrics.consensus_plane = consensus
except ValueError as e:
metrics.message = f"Consensus computation failed: {e}"
return extrinsics, metrics
# 4. Compute correction
correction = compute_floor_correction(
consensus,
target_floor_y=config.target_y,
max_rotation_deg=config.max_rotation_deg,
max_translation_m=config.max_translation_m,
)
metrics.correction_transform = correction.transform
if not correction.valid:
metrics.message = f"Correction invalid: {correction.reason}"
return extrinsics, metrics
# 5. Apply correction
# T_corr is the transform that moves the world frame.
# New world points P' = T_corr * P
# We want new extrinsics T'_world_cam such that P' = T'_world_cam * P_cam
# T'_world_cam * P_cam = T_corr * (T_world_cam * P_cam)
# So T'_world_cam = T_corr * T_world_cam
new_extrinsics = {}
T_corr = correction.transform
for serial, T_old in extrinsics.items():
new_extrinsics[serial] = T_corr @ T_old
# Calculate metrics
# Rotation angle of T_corr
trace = np.trace(T_corr[:3, :3])
cos_angle = np.clip((trace - 1) / 2, -1.0, 1.0)
metrics.rotation_deg = float(np.rad2deg(np.arccos(cos_angle)))
metrics.translation_m = float(np.linalg.norm(T_corr[:3, 3]))
metrics.success = True
metrics.correction_applied = True
metrics.message = "Success"
return new_extrinsics, metrics