# Source: zed-playground/py_workspace/aruco/icp_registration.py (479 lines, 15 KiB, Python)
import numpy as np
import open3d as o3d
from typing import Dict, List, Optional, Tuple, Any, TYPE_CHECKING
from dataclasses import dataclass, field
from jaxtyping import Float
from scipy.spatial.transform import Rotation
from loguru import logger
from .pose_math import invert_transform
# Shaped-array aliases: precise jaxtyping annotations for static checkers,
# plain np.ndarray at runtime (the shape strings are documentation only).
if TYPE_CHECKING:
    Vec3 = Float[np.ndarray, "3"]        # 3-vector (e.g. a plane normal)
    Mat44 = Float[np.ndarray, "4 4"]     # homogeneous 4x4 transform
    PointsNC = Float[np.ndarray, "N 3"]  # N points with xyz columns
else:
    Vec3 = np.ndarray
    Mat44 = np.ndarray
    PointsNC = np.ndarray
@dataclass
class ICPConfig:
    """Configuration for ICP registration.

    Distances are in meters, angles in degrees unless noted otherwise.
    """

    voxel_size: float = 0.02  # Base voxel size in meters
    max_iterations: list[int] = field(default_factory=lambda: [50, 30, 14])  # Per-scale iteration caps, coarse -> fine
    method: str = "point_to_plane"  # "point_to_plane" or "gicp"
    band_height: float = 0.3  # Near-floor band height in meters
    min_fitness: float = 0.3  # Min ICP fitness to accept pair
    min_overlap_area: float = 1.0  # Min XZ overlap area in m^2
    overlap_margin: float = 0.5  # Inflate bboxes by this margin (m)
    gravity_penalty_weight: float = 10.0  # Soft constraint on pitch/roll
    max_correspondence_distance_factor: float = 1.4  # Correspondence distance = voxel_size * this factor
    max_rotation_deg: float = 5.0  # Safety bound on ICP delta
    max_translation_m: float = 0.1  # Safety bound on ICP delta
@dataclass
class ICPResult:
    """Result of a pairwise ICP registration."""

    transformation: Mat44  # 4x4 source->target transform after refinement
    fitness: float  # Open3D registration fitness score
    inlier_rmse: float  # RMSE of inlier correspondences (meters)
    information_matrix: np.ndarray  # 6x6 information matrix for pose-graph edges
    converged: bool  # True when fitness exceeded the configured minimum
@dataclass
class ICPMetrics:
    """Metrics for the global ICP refinement process."""

    success: bool = False  # True when more than one camera pose was accepted
    num_pairs_attempted: int = 0  # Pairs that passed the XZ-overlap gate
    num_pairs_converged: int = 0  # Pairs whose ICP met the fitness threshold
    num_cameras_optimized: int = 0  # Cameras whose new pose passed safety bounds
    num_disconnected: int = 0  # Valid cameras unreachable from the reference
    per_pair_results: dict[tuple[str, str], ICPResult] = field(default_factory=dict)  # Keyed by (serial_a, serial_b)
    reference_camera: str = ""  # Serial of the fixed reference camera
    message: str = ""  # Human-readable status summary
def extract_near_floor_band(
    points_world: PointsNC,
    floor_y: float,
    band_height: float,
    floor_normal: Vec3,
) -> PointsNC:
    """
    Keep only the points lying within a band above the floor plane.

    Each world-frame point is projected onto ``floor_normal``; points whose
    signed height falls inside ``[floor_y, floor_y + band_height]``
    (inclusive on both ends) are returned. An empty input is returned as-is.
    """
    if len(points_world) == 0:
        return points_world
    lower = floor_y
    upper = floor_y + band_height
    # Signed height of every point along the floor normal.
    heights = points_world.dot(floor_normal)
    keep = (heights >= lower) & (heights <= upper)
    return points_world[keep]
def compute_overlap_xz(
    points_a: PointsNC,
    points_b: PointsNC,
    margin: float = 0.0,
) -> float:
    """
    Area (m^2) where the XZ axis-aligned bounding boxes of two point sets
    intersect, after inflating each box by ``margin`` on every side.

    Returns 0.0 when either set is empty or the boxes do not overlap.
    """
    if len(points_a) == 0 or len(points_b) == 0:
        return 0.0
    # Work in the XZ plane only (columns 0 and 2).
    xz_a = points_a[:, [0, 2]]
    xz_b = points_b[:, [0, 2]]
    lo = np.maximum(xz_a.min(axis=0) - margin, xz_b.min(axis=0) - margin)
    hi = np.minimum(xz_a.max(axis=0) + margin, xz_b.max(axis=0) + margin)
    # Negative extents mean no overlap on that axis.
    extent = np.clip(hi - lo, 0.0, None)
    return float(extent[0] * extent[1])
def apply_gravity_constraint(
    T_icp: Mat44,
    T_original: Mat44,
    penalty_weight: float = 10.0,
) -> Mat44:
    """
    Pull the ICP rotation's pitch and roll back toward the original pose
    while keeping ICP's yaw and translation untouched.

    Pitch (x) and roll (z) deltas are shrunk by 1 / (1 + penalty_weight),
    preserving the gravity alignment established upstream while still
    permitting yaw + XZ + height refinement.
    """
    euler_new = Rotation.from_matrix(T_icp[:3, :3]).as_euler("xyz")
    euler_ref = Rotation.from_matrix(T_original[:3, :3]).as_euler("xyz")
    # Wrap each angular delta into [-pi, pi) before damping it.
    delta = np.mod(euler_new - euler_ref + np.pi, 2.0 * np.pi) - np.pi
    damped = euler_ref + delta / (1.0 + penalty_weight)
    damped[1] = euler_new[1]  # yaw (y axis) comes straight from ICP
    result = T_icp.copy()
    result[:3, :3] = Rotation.from_euler("xyz", damped).as_matrix()
    return result
def pairwise_icp(
    source_pcd: o3d.geometry.PointCloud,
    target_pcd: o3d.geometry.PointCloud,
    config: ICPConfig,
    init_transform: Mat44,
) -> ICPResult:
    """
    Multi-scale (coarse-to-fine) ICP aligning source onto target.

    Runs ICP at voxel scales 4x, 2x, 1x of ``config.voxel_size``, feeding
    each scale's result into the next as the initial guess. The estimation
    method is chosen by ``config.method``: "gicp" uses Generalized ICP,
    "point_to_plane" uses point-to-plane, anything else falls back to
    point-to-point.

    Returns an ICPResult whose ``converged`` flag is True when the final
    fitness exceeded ``config.min_fitness``.
    """
    reg = o3d.pipelines.registration
    current_transform = init_transform
    voxel_scales = [4, 2, 1]  # coarse -> fine
    reg_result = None
    for i, scale in enumerate(voxel_scales):
        voxel_size = config.voxel_size * scale
        # Clamp the index so a max_iterations list shorter than the number
        # of scales reuses its last entry instead of raising IndexError.
        max_iter = config.max_iterations[min(i, len(config.max_iterations) - 1)]
        source_down = source_pcd.voxel_down_sample(voxel_size)
        target_down = target_pcd.voxel_down_sample(voxel_size)
        # Normals are required by the point-to-plane and GICP estimators.
        source_down.estimate_normals(
            o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size * 2, max_nn=30)
        )
        target_down.estimate_normals(
            o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size * 2, max_nn=30)
        )
        dist_thresh = voxel_size * config.max_correspondence_distance_factor
        criteria = reg.ICPConvergenceCriteria(max_iteration=max_iter)
        # Dispatch on method once; all three branches share the same call shape.
        if config.method == "gicp":
            estimation = reg.TransformationEstimationForGeneralizedICP()
            register = reg.registration_generalized_icp
        elif config.method == "point_to_plane":
            estimation = reg.TransformationEstimationPointToPlane()
            register = reg.registration_icp
        else:
            # Fallback: plain point-to-point ICP.
            estimation = reg.TransformationEstimationPointToPoint()
            register = reg.registration_icp
        reg_result = register(
            source_down,
            target_down,
            dist_thresh,
            current_transform,
            estimation,
            criteria,
        )
        current_transform = reg_result.transformation
    if reg_result is None:
        # Defensive: only reachable if voxel_scales were ever made empty.
        return ICPResult(
            transformation=init_transform,
            fitness=0.0,
            inlier_rmse=0.0,
            information_matrix=np.eye(6),
            converged=False,
        )
    # Information matrix computed on the full-resolution clouds, for use as
    # the edge weight in the pose graph.
    info_matrix = reg.get_information_matrix_from_point_clouds(
        source_pcd,
        target_pcd,
        config.voxel_size * config.max_correspondence_distance_factor,
        current_transform,
    )
    return ICPResult(
        transformation=current_transform,
        fitness=reg_result.fitness,
        inlier_rmse=reg_result.inlier_rmse,
        information_matrix=info_matrix,
        converged=reg_result.fitness > config.min_fitness,
    )
def build_pose_graph(
    serials: List[str],
    extrinsics: Dict[str, Mat44],
    pair_results: Dict[Tuple[str, str], ICPResult],
    reference_serial: str,
) -> o3d.pipelines.registration.PoseGraph:
    """
    Build a PoseGraph from pairwise results.
    Only includes cameras reachable from the reference camera.

    Node i corresponds to the i-th serial in the ordering
    "reference first, remaining connected serials sorted"; callers that map
    optimized nodes back to serials must reconstruct this exact ordering.
    """
    # 1. Detect connected component from reference
    # (BFS over the graph whose edges are the converged ICP pairs).
    connected = {reference_serial}
    queue = [reference_serial]
    while queue:
        curr = queue.pop(0)
        for (s1, s2), result in pair_results.items():
            if not result.converged:
                continue
            if s1 == curr and s2 not in connected:
                connected.add(s2)
                queue.append(s2)
            elif s2 == curr and s1 not in connected:
                connected.add(s1)
                queue.append(s1)
    # 2. Filter serials to only include connected ones
    # Keep reference_serial at index 0
    optimized_serials = [reference_serial] + sorted(
        list(connected - {reference_serial})
    )
    serial_to_idx = {s: i for i, s in enumerate(optimized_serials)}
    # Log disconnected cameras
    disconnected = set(serials) - connected
    if disconnected:
        logger.warning(
            f"Cameras disconnected from reference {reference_serial}: {disconnected}"
        )
    pose_graph = o3d.pipelines.registration.PoseGraph()
    # One node per connected camera, initialized with its current extrinsic.
    for serial in optimized_serials:
        T_wc = extrinsics[serial]
        pose_graph.nodes.append(o3d.pipelines.registration.PoseGraphNode(T_wc))
    # One edge per converged pair whose both endpoints survived the filter.
    for (s1, s2), result in pair_results.items():
        if not result.converged:
            continue
        if s1 not in serial_to_idx or s2 not in serial_to_idx:
            continue
        idx1 = serial_to_idx[s1]
        idx2 = serial_to_idx[s2]
        # Edge from idx2 to idx1 (transformation maps 1 to 2)
        # Open3D PoseGraphEdge(source, target, T) means P_source = T * P_target
        # Here P_2 = T_21 * P_1, so source=2, target=1
        edge = o3d.pipelines.registration.PoseGraphEdge(
            idx2, idx1, result.transformation, result.information_matrix, uncertain=True
        )
        pose_graph.edges.append(edge)
    return pose_graph
def optimize_pose_graph(
    pose_graph: o3d.pipelines.registration.PoseGraph,
) -> o3d.pipelines.registration.PoseGraph:
    """
    Run global optimization.

    Open3D's global_optimization updates the graph's node poses in place;
    the same object is returned for convenience. Node 0 is held fixed as
    the reference.
    """
    option = o3d.pipelines.registration.GlobalOptimizationOption(
        max_correspondence_distance=0.1,  # NOTE(review): hard-coded; presumably matched to the band voxel scale — confirm
        edge_prune_threshold=0.25,
        reference_node=0,  # reference camera is always node 0 (see build_pose_graph)
    )
    o3d.pipelines.registration.global_optimization(
        pose_graph,
        o3d.pipelines.registration.GlobalOptimizationLevenbergMarquardt(),
        o3d.pipelines.registration.GlobalOptimizationConvergenceCriteria(),
        option,
    )
    return pose_graph
def refine_with_icp(
    camera_data: Dict[str, Dict[str, Any]],
    extrinsics: Dict[str, Mat44],
    floor_planes: Dict[str, Any],  # Dict[str, FloorPlane]
    config: ICPConfig,
) -> Tuple[Dict[str, Mat44], ICPMetrics]:
    """
    Main orchestrator for ICP refinement.

    Pipeline:
      1. Unproject each camera's depth map and keep only a near-floor band
         of world-frame points for registration.
      2. Run pairwise multi-scale ICP between every camera pair whose floor
         bands overlap in XZ, then damp pitch/roll via the gravity constraint.
      3. Build and globally optimize a pose graph over converged pairs.
      4. Accept per-camera corrections only within configured safety bounds.

    Returns the (possibly partially) updated extrinsics and metrics.
    Cameras that are disconnected or whose correction exceeds the safety
    bounds keep their input extrinsics.
    """
    # Function-scope import; presumably avoids a circular import — verify.
    from .ground_plane import unproject_depth_to_points
    metrics = ICPMetrics()
    serials = sorted(list(camera_data.keys()))
    if not serials:
        return extrinsics, metrics
    # Lowest serial acts as the fixed reference frame for optimization.
    metrics.reference_camera = serials[0]
    # 1. Extract near-floor bands
    camera_pcds: Dict[str, o3d.geometry.PointCloud] = {}
    camera_points: Dict[str, PointsNC] = {}
    for serial in serials:
        # Skip cameras lacking a detected floor plane or a current pose.
        if serial not in floor_planes or serial not in extrinsics:
            continue
        data = camera_data[serial]
        plane = floor_planes[serial]
        # NOTE(review): assumes data carries "depth" and intrinsics "K";
        # stride=4 subsamples the depth map for speed.
        points_cam = unproject_depth_to_points(data["depth"], data["K"], stride=4)
        T_wc = extrinsics[serial]
        # Camera -> world: p_w = R p_c + t (row-vector form uses R^T on the right).
        points_world = (points_cam @ T_wc[:3, :3].T) + T_wc[:3, 3]
        # floor_y = -plane.d (distance to origin along normal)
        # assumes plane satisfies n.p + d = 0 — TODO confirm FloorPlane convention
        floor_y = -plane.d
        band_points = extract_near_floor_band(
            points_world, floor_y, config.band_height, plane.normal
        )
        # Too few band points makes registration unreliable; skip this camera.
        if len(band_points) < 100:
            continue
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(band_points)
        camera_pcds[serial] = pcd
        camera_points[serial] = band_points
    # 2. Pairwise ICP over all unordered camera pairs (s1 < s2).
    valid_serials = sorted(list(camera_pcds.keys()))
    pair_results: Dict[Tuple[str, str], ICPResult] = {}
    for i, s1 in enumerate(valid_serials):
        for j in range(i + 1, len(valid_serials)):
            s2 = valid_serials[j]
            # Cheap XZ bounding-box overlap gate before running ICP.
            area = compute_overlap_xz(
                camera_points[s1], camera_points[s2], config.overlap_margin
            )
            if area < config.min_overlap_area:
                continue
            metrics.num_pairs_attempted += 1
            # Initial relative transform from current extrinsics
            # T_21 = T_w2^-1 * T_w1
            T_w1 = extrinsics[s1]
            T_w2 = extrinsics[s2]
            init_T = invert_transform(T_w2) @ T_w1
            # pairwise_icp aligns source_pcd to target_pcd.
            # We pass camera-frame points to pairwise_icp to use init_T meaningfully.
            # World -> camera: p_c = R^T (p_w - t), written in row-vector form.
            pcd1_cam = o3d.geometry.PointCloud()
            pcd1_cam.points = o3d.utility.Vector3dVector(
                (np.asarray(camera_pcds[s1].points) - T_w1[:3, 3]) @ T_w1[:3, :3]
            )
            pcd2_cam = o3d.geometry.PointCloud()
            pcd2_cam.points = o3d.utility.Vector3dVector(
                (np.asarray(camera_pcds[s2].points) - T_w2[:3, 3]) @ T_w2[:3, :3]
            )
            result = pairwise_icp(pcd1_cam, pcd2_cam, config, init_T)
            # Apply gravity constraint to the result relative to original transform
            result.transformation = apply_gravity_constraint(
                result.transformation, init_T, config.gravity_penalty_weight
            )
            # Only converged pairs feed the pose graph; all attempts are
            # recorded in the metrics for diagnostics.
            if result.converged:
                pair_results[(s1, s2)] = result
                metrics.num_pairs_converged += 1
            metrics.per_pair_results[(s1, s2)] = result
    if not pair_results:
        metrics.message = "No converged ICP pairs"
        return extrinsics, metrics
    # 3. Pose Graph
    pose_graph = build_pose_graph(
        valid_serials, extrinsics, pair_results, metrics.reference_camera
    )
    # 4. Optimize
    optimize_pose_graph(pose_graph)
    # 5. Extract and Validate
    new_extrinsics = extrinsics.copy()
    # Re-derive optimized_serials to match build_pose_graph logic for node-to-serial mapping
    # (same BFS over converged pairs; must stay in lockstep with
    # build_pose_graph so pose_graph.nodes[i] maps to optimized_serials[i]).
    connected = {metrics.reference_camera}
    queue = [metrics.reference_camera]
    while queue:
        curr = queue.pop(0)
        for (s1, s2), result in pair_results.items():
            if not result.converged:
                continue
            if s1 == curr and s2 not in connected:
                connected.add(s2)
                queue.append(s2)
            elif s2 == curr and s1 not in connected:
                connected.add(s1)
                queue.append(s1)
    optimized_serials = [metrics.reference_camera] + sorted(
        list(connected - {metrics.reference_camera})
    )
    metrics.num_disconnected = len(valid_serials) - len(optimized_serials)
    metrics.num_cameras_optimized = 0
    for i, serial in enumerate(optimized_serials):
        T_optimized = pose_graph.nodes[i].pose
        T_old = extrinsics[serial]
        # Validate delta
        T_delta = T_optimized @ invert_transform(T_old)
        rot_delta = Rotation.from_matrix(T_delta[:3, :3]).as_euler("xyz", degrees=True)
        rot_mag = np.linalg.norm(rot_delta)
        trans_mag = np.linalg.norm(T_delta[:3, 3])
        # Reject implausibly large corrections (ICP likely diverged).
        if rot_mag > config.max_rotation_deg or trans_mag > config.max_translation_m:
            logger.warning(
                f"Camera {serial} ICP correction exceeds bounds: rot={rot_mag:.2f} deg, trans={trans_mag:.3f} m. Rejecting."
            )
            continue
        new_extrinsics[serial] = T_optimized
        metrics.num_cameras_optimized += 1
    # The reference camera counts itself, so success requires at least one
    # additional camera to have been optimized.
    metrics.success = metrics.num_cameras_optimized > 1
    metrics.message = f"Optimized {metrics.num_cameras_optimized} cameras"
    return new_extrinsics, metrics