fix: complete ground_plane.py implementation and tests

This commit is contained in:
2026-02-09 07:16:14 +00:00
parent 1d3266ec60
commit 43a441f2d4
3 changed files with 297 additions and 19 deletions
+48 -19
View File
@@ -3,6 +3,7 @@ from typing import Optional, Tuple, List
from jaxtyping import Float
from typing import TYPE_CHECKING
import open3d as o3d
from dataclasses import dataclass
if TYPE_CHECKING:
Vec3 = Float[np.ndarray, "3"]
@@ -14,6 +15,20 @@ else:
PointsNC = np.ndarray
@dataclass
class FloorPlane:
    """Floor plane expressed as ``normal.dot(p) + d = 0``.

    Produced by ``detect_floor_plane`` and ``compute_consensus_plane`` and
    consumed by ``compute_floor_correction``.
    """

    # Plane normal. detect_floor_plane rescales it (together with ``d``) by
    # the same ``norm`` factor before building the plane.
    normal: Vec3
    # Plane offset term in ``normal.dot(p) + d = 0``.
    d: float
    # Number of RANSAC inliers behind the detection (``len(inliers)`` in
    # detect_floor_plane). compute_consensus_plane does not set it, so
    # consensus planes keep the default of 0.
    num_inliers: int = 0
@dataclass
class FloorCorrection:
    """Outcome of ``compute_floor_correction``: a transform plus a validity flag."""

    # Correction transform. compute_floor_correction returns the identity
    # (``np.eye(4)``) here whenever it rejects the correction.
    transform: Mat44
    # True when the correction passed the rotation/translation limit checks;
    # False when alignment failed or a limit was exceeded.
    valid: bool
    # Explanation of why the correction was rejected; stays empty on success.
    reason: str = ""
def unproject_depth_to_points(
depth_map: np.ndarray,
K: np.ndarray,
@@ -63,13 +78,13 @@ def detect_floor_plane(
ransac_n: int = 3,
num_iterations: int = 1000,
seed: Optional[int] = None,
) -> Tuple[Optional[Vec3], float, int]:
) -> Optional[FloorPlane]:
"""
Detect the floor plane from a point cloud using RANSAC.
Returns (normal, d, num_inliers) where plane is normal.dot(p) + d = 0.
Returns FloorPlane or None if detection fails.
"""
if points.shape[0] < ransac_n:
return None, 0.0, 0
return None
# Convert to Open3D PointCloud
pcd = o3d.geometry.PointCloud()
@@ -86,8 +101,9 @@ def detect_floor_plane(
num_iterations=num_iterations,
)
if not plane_model:
return None, 0.0, 0
# Check if we found enough inliers
if len(inliers) < ransac_n:
return None
# plane_model is [a, b, c, d]
a, b, c, d = plane_model
@@ -99,13 +115,13 @@ def detect_floor_plane(
normal /= norm
d /= norm
return normal, d, len(inliers)
return FloorPlane(normal=normal, d=d, num_inliers=len(inliers))
def compute_consensus_plane(
planes: List[Tuple[Vec3, float]],
planes: List[FloorPlane],
weights: Optional[List[float]] = None,
) -> Tuple[Vec3, float]:
) -> FloorPlane:
"""
Compute a consensus plane from multiple plane detections.
"""
@@ -122,14 +138,16 @@ def compute_consensus_plane(
)
# Use the first plane as reference for orientation
ref_normal = planes[0][0]
ref_normal = planes[0].normal
accum_normal = np.zeros(3, dtype=np.float64)
accum_d = 0.0
total_weight = 0.0
for i, (normal, d) in enumerate(planes):
for i, plane in enumerate(planes):
w = weights[i]
normal = plane.normal
d = plane.d
# Check orientation against reference
if np.dot(normal, ref_normal) < 0:
@@ -158,23 +176,24 @@ def compute_consensus_plane(
avg_normal = np.array([0.0, 1.0, 0.0])
avg_d = 0.0
return avg_normal, float(avg_d)
return FloorPlane(normal=avg_normal, d=float(avg_d))
from .alignment import rotation_align_vectors
def compute_floor_correction(
current_floor_plane: Tuple[Vec3, float],
current_floor_plane: FloorPlane,
target_floor_y: float = 0.0,
max_rotation_deg: float = 5.0,
max_translation_m: float = 0.1,
) -> Optional[Mat44]:
) -> FloorCorrection:
"""
Compute the correction transform to align the current floor plane to the target floor height.
Constrains correction to pitch/roll and vertical translation only.
"""
current_normal, current_d = current_floor_plane
current_normal = current_floor_plane.normal
current_d = current_floor_plane.d
# Target normal is always [0, 1, 0] (Y-up)
target_normal = np.array([0.0, 1.0, 0.0])
@@ -182,8 +201,10 @@ def compute_floor_correction(
# 1. Compute rotation to align normals
try:
R_align = rotation_align_vectors(current_normal, target_normal)
except ValueError:
return None
except ValueError as e:
return FloorCorrection(
transform=np.eye(4), valid=False, reason=f"Rotation alignment failed: {e}"
)
# Check rotation magnitude
# Angle of rotation is acos((trace(R) - 1) / 2)
@@ -194,7 +215,11 @@ def compute_floor_correction(
angle_deg = np.rad2deg(angle_rad)
if angle_deg > max_rotation_deg:
return None
return FloorCorrection(
transform=np.eye(4),
valid=False,
reason=f"Rotation {angle_deg:.1f} deg exceeds limit {max_rotation_deg:.1f} deg",
)
# 2. Compute translation
# We want to move points such that the floor is at y = target_floor_y
@@ -207,7 +232,11 @@ def compute_floor_correction(
# Check translation magnitude
if abs(t_y) > max_translation_m:
return None
return FloorCorrection(
transform=np.eye(4),
valid=False,
reason=f"Translation {t_y:.3f} m exceeds limit {max_translation_m:.3f} m",
)
# Construct T
T = np.eye(4)
@@ -215,4 +244,4 @@ def compute_floor_correction(
# Translation is applied in the rotated frame (aligned to target normal)
T[:3, 3] = target_normal * t_y
return T.astype(np.float64)
return FloorCorrection(transform=T.astype(np.float64), valid=True)