feat: implement depth bias estimation and correction in ICP pipeline

This commit is contained in:
2026-02-11 14:11:40 +00:00
parent 29eec81ea0
commit 8c6087683f
11 changed files with 1506 additions and 30 deletions
+290 -3
View File
@@ -45,6 +45,7 @@ class ICPConfig:
robust_kernel: str = "none" # "none" or "tukey"
robust_kernel_k: float = 0.1
global_init: bool = False # Enable FPFH+RANSAC global pre-alignment
depth_bias: bool = True # Enable per-camera depth bias pre-correction
@dataclass
@@ -67,6 +68,7 @@ class ICPMetrics:
num_pairs_converged: int = 0
num_cameras_optimized: int = 0
num_disconnected: int = 0
depth_biases: Dict[str, float] = field(default_factory=dict)
per_pair_results: dict[tuple[str, str], ICPResult] = field(default_factory=dict)
reference_camera: str = ""
message: str = ""
@@ -515,6 +517,272 @@ def optimize_pose_graph(
return pose_graph
def estimate_depth_biases(
    camera_data: Dict[str, Dict[str, Any]],
    extrinsics: Dict[str, Mat44],
    floor_planes: Dict[str, Any],  # Dict[str, FloorPlane]
    config: ICPConfig,
    reference_serial: Optional[str] = None,
) -> Dict[str, float]:
    """
    Estimate per-camera depth offsets (beta, in meters) from overlap correspondences.

    The model enforces pairwise constraints:
        beta_j - beta_i ~= b_ij
    where b_ij is the median, over nearest-neighbour correspondences of the
    pair (i, j), of the target-minus-source displacement projected onto the
    source camera's viewing ray. The reference camera is held at exactly 0.0
    (gauge fixing) and the remaining offsets are obtained from a weighted
    linear least-squares solve.

    Args:
        camera_data: Per-serial dict; the "depth" and "K" entries are read.
        extrinsics: Per-serial 4x4 transform applied to camera-frame points
            to obtain world-frame points.
        floor_planes: Per-serial floor plane (``.d`` and ``.normal`` are
            read). Only consulted when ``config.region != "full"``.
        config: ICP configuration. Reads ``region``, ``overlap_mode``,
            ``band_height``, ``voxel_size``,
            ``max_correspondence_distance_factor``, ``overlap_margin``,
            ``min_overlap_area`` and, if present, ``max_abs_bias``
            (falls back to 0.3).
        reference_serial: Camera whose bias is pinned to 0.0. When ``None``
            or not a key of ``camera_data``, the first serial in sorted
            order is used.

    Returns:
        A dict with one entry per key of ``camera_data``. Cameras that lack
        usable data, or that are not linked to the reference through valid
        pair constraints, keep a bias of 0.0. An empty ``camera_data``
        yields an empty dict.
    """
    from collections import deque

    from .ground_plane import unproject_depth_to_points

    # Minimum number of points / correspondences for a camera or a pair to
    # be considered usable.
    min_samples = 100

    all_serials = sorted(camera_data)
    if not all_serials:
        return {}

    if reference_serial is None or reference_serial not in all_serials:
        reference = all_serials[0]
    else:
        reference = reference_serial

    # Default/fallback output: 0.0 for all cameras
    betas: Dict[str, float] = {serial: 0.0 for serial in all_serials}

    def _finish() -> Dict[str, float]:
        # Single exit point: the gauge is fixed exactly and the result logged.
        betas[reference] = 0.0
        logger.info(f"Estimated depth biases (m): {betas}")
        return betas

    # Auto-set overlap mode based on region, matching refine_with_icp behavior.
    overlap_mode = config.overlap_mode
    if config.region == "floor":
        overlap_mode = "xz"
    elif config.region in ["hybrid", "full"]:
        overlap_mode = "3d"

    camera_points: Dict[str, PointsNC] = {}
    camera_rays_world: Dict[str, PointsNC] = {}
    camera_pcds: Dict[str, o3d.geometry.PointCloud] = {}

    # 1) Build per-camera world points/rays in the configured scene region.
    for serial in all_serials:
        if serial not in extrinsics:
            continue

        data = camera_data[serial]
        depth_map = data.get("depth")
        K = data.get("K")
        if depth_map is None or K is None:
            continue

        points_cam = unproject_depth_to_points(depth_map, K, stride=4)
        if len(points_cam) < min_samples:
            continue

        T_wc = extrinsics[serial]
        points_world = (points_cam @ T_wc[:3, :3].T) + T_wc[:3, 3]

        scene_points = points_world
        if config.region != "full":
            plane = floor_planes.get(serial)
            if plane is None:
                continue
            floor_y = -plane.d
            scene_points = extract_scene_points(
                points_world,
                floor_y,
                plane.normal,
                mode=config.region,
                band_height=config.band_height,
            )
        if len(scene_points) < min_samples:
            continue

        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(scene_points)
        pcd = preprocess_point_cloud(pcd, config.voxel_size)
        if len(pcd.points) < min_samples:
            continue

        pts = np.asarray(pcd.points)
        cam_pos = T_wc[:3, 3]
        rays_world = pts - cam_pos[None, :]
        ray_norm = np.linalg.norm(rays_world, axis=1)
        # Points that coincide with the camera centre have no ray direction.
        valid_mask = ray_norm > 1e-9
        if np.count_nonzero(valid_mask) < min_samples:
            continue

        pts = pts[valid_mask]
        camera_points[serial] = pts
        camera_rays_world[serial] = rays_world[valid_mask] / ray_norm[valid_mask, None]

        # The KD-tree is built from the filtered points so that neighbour
        # indices line up with camera_points[serial].
        pcd_valid = o3d.geometry.PointCloud()
        pcd_valid.points = o3d.utility.Vector3dVector(pts)
        camera_pcds[serial] = pcd_valid

    valid_serials = sorted(camera_points)
    if reference not in valid_serials:
        logger.warning(
            f"Reference camera {reference} has no valid data for bias estimation; all biases set to 0.0"
        )
        return _finish()

    # 2) Build robust pairwise median constraints b_ij.
    max_corr_dist = config.voxel_size * config.max_correspondence_distance_factor
    max_corr_dist_sq = max_corr_dist * max_corr_dist
    pair_constraints: List[Tuple[str, str, float, int]] = []

    for i, s1 in enumerate(valid_serials):
        for s2 in valid_serials[i + 1:]:
            if overlap_mode == "3d":
                overlap = compute_overlap_3d(
                    camera_points[s1], camera_points[s2], config.overlap_margin
                )
                unit = "m^3"
            else:
                overlap = compute_overlap_xz(
                    camera_points[s1], camera_points[s2], config.overlap_margin
                )
                unit = "m^2"

            if overlap < config.min_overlap_area:
                logger.debug(
                    f"Bias pair ({s1}, {s2}) skipped: overlap {overlap:.2f} {unit} < {config.min_overlap_area:.2f} {unit}"
                )
                continue

            src_pts = camera_points[s1]
            src_rays = camera_rays_world[s1]
            tgt_pts = camera_points[s2]
            kdtree = o3d.geometry.KDTreeFlann(camera_pcds[s2])

            # Collect the accepted (source, target) index pairs first; the
            # projection onto the source rays is then done in one shot.
            src_hits: List[int] = []
            tgt_hits: List[int] = []
            for idx, query in enumerate(src_pts):
                k, nn_indices, nn_dist_sq = kdtree.search_knn_vector_3d(query, 1)
                if k < 1 or nn_dist_sq[0] > max_corr_dist_sq:
                    continue
                src_hits.append(idx)
                tgt_hits.append(nn_indices[0])

            if src_hits:
                deltas = tgt_pts[tgt_hits] - src_pts[src_hits]
                residuals = np.einsum("ij,ij->i", deltas, src_rays[src_hits])
                residuals = np.asarray(residuals, dtype=np.float64)
                residuals = residuals[np.isfinite(residuals)]
            else:
                residuals = np.empty(0, dtype=np.float64)

            n_corr = int(residuals.size)
            if n_corr < min_samples:
                logger.debug(
                    f"Bias pair ({s1}, {s2}) skipped: insufficient correspondences {n_corr} < {min_samples}"
                )
                continue

            # NOTE(review): the residual is measured along the *source* ray,
            # so a positive b_ij means the matched target point lies farther
            # from the source camera than the source point. Confirm this
            # agrees with the sign used when the biases are applied to depth.
            b_ij = float(np.median(residuals))
            pair_constraints.append((s1, s2, b_ij, n_corr))
            logger.debug(
                f"Bias pair ({s1}, {s2}): corr={n_corr}, median={b_ij:.4f} m"
            )

    if not pair_constraints:
        logger.warning(
            "No valid pair constraints for depth bias estimation; returning zeros"
        )
        return _finish()

    # 3) Keep only cameras connected to reference; disconnected cameras remain 0.0.
    adjacency: Dict[str, set[str]] = {s: set() for s in valid_serials}
    for s1, s2, _, _ in pair_constraints:
        adjacency[s1].add(s2)
        adjacency[s2].add(s1)

    # Breadth-first search from the reference over the constraint graph.
    connected = {reference}
    frontier = deque([reference])
    while frontier:
        curr = frontier.popleft()
        for nbr in adjacency.get(curr, set()):
            if nbr not in connected:
                connected.add(nbr)
                frontier.append(nbr)

    # Unknowns are all connected nodes except the fixed-gauge reference.
    unknown_serials = sorted(connected - {reference})
    if not unknown_serials:
        return _finish()

    serial_to_col = {s: i for i, s in enumerate(unknown_serials)}

    rows: List[np.ndarray] = []
    rhs: List[float] = []
    weights: List[float] = []

    for s1, s2, b_ij, n_corr in pair_constraints:
        if s1 not in connected or s2 not in connected:
            continue

        # Row encodes beta_s2 - beta_s1; the reference has no column because
        # its beta is fixed at zero.
        row = np.zeros(len(unknown_serials), dtype=np.float64)
        if s1 != reference:
            row[serial_to_col[s1]] = -1.0
        if s2 != reference:
            row[serial_to_col[s2]] = 1.0
        if not np.any(row):
            continue

        rows.append(row)
        rhs.append(b_ij)
        # Mild confidence weighting by correspondence count.
        weights.append(float(np.sqrt(max(n_corr, 1))))

    if not rows:
        return _finish()

    A = np.vstack(rows)
    y = np.asarray(rhs, dtype=np.float64)
    w = np.asarray(weights, dtype=np.float64)

    # Weighted least squares: scale each equation by its weight.
    x, *_ = np.linalg.lstsq(A * w[:, None], y * w, rcond=None)

    for serial, col in serial_to_col.items():
        betas[serial] = float(x[col])

    # 4) Cap implausible values.
    max_abs_bias = float(getattr(config, "max_abs_bias", 0.3))
    for serial in betas:
        if serial == reference:
            continue
        if abs(betas[serial]) > max_abs_bias:
            clipped = float(np.clip(betas[serial], -max_abs_bias, max_abs_bias))
            logger.warning(
                f"Depth bias for camera {serial} clipped from {betas[serial]:.4f} m "
                f"to {clipped:.4f} m"
            )
            betas[serial] = clipped

    if len(connected) < len(valid_serials):
        disconnected_valid = sorted(set(valid_serials) - connected)
        logger.warning(
            f"Depth-bias estimation disconnected cameras (set to 0.0): {disconnected_valid}"
        )

    return _finish()
def refine_with_icp(
camera_data: Dict[str, Dict[str, Any]],
extrinsics: Dict[str, Mat44],
@@ -543,6 +811,18 @@ def refine_with_icp(
elif config.region in ["hybrid", "full"]:
config.overlap_mode = "3d"
if config.depth_bias:
biases = estimate_depth_biases(
camera_data,
extrinsics,
floor_planes,
config,
reference_serial=metrics.reference_camera,
)
else:
biases = {}
metrics.depth_biases = biases
for serial in serials:
if serial not in floor_planes or serial not in extrinsics:
continue
@@ -550,7 +830,11 @@ def refine_with_icp(
data = camera_data[serial]
plane = floor_planes[serial]
points_cam = unproject_depth_to_points(data["depth"], data["K"], stride=4)
depth_corrected = data["depth"].copy()
depth_corrected += biases.get(serial, 0.0)
depth_corrected[depth_corrected <= 0] = np.nan
points_cam = unproject_depth_to_points(depth_corrected, data["K"], stride=4)
T_wc = extrinsics[serial]
points_world = (points_cam @ T_wc[:3, :3].T) + T_wc[:3, 3]
@@ -787,9 +1071,12 @@ def refine_with_icp(
continue
new_extrinsics[serial] = T_optimized
metrics.num_cameras_optimized += 1
if serial != metrics.reference_camera:
metrics.num_cameras_optimized += 1
metrics.success = metrics.num_cameras_optimized > 0
metrics.message = f"Optimized {metrics.num_cameras_optimized} cameras"
metrics.message = (
f"Optimized {metrics.num_cameras_optimized} cameras (excluding reference)"
)
return new_extrinsics, metrics