feat: implement depth bias estimation and correction in ICP pipeline
This commit is contained in:
@@ -45,6 +45,7 @@ class ICPConfig:
|
||||
robust_kernel: str = "none" # "none" or "tukey"
|
||||
robust_kernel_k: float = 0.1
|
||||
global_init: bool = False # Enable FPFH+RANSAC global pre-alignment
|
||||
depth_bias: bool = True # Enable per-camera depth bias pre-correction
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -67,6 +68,7 @@ class ICPMetrics:
|
||||
num_pairs_converged: int = 0
|
||||
num_cameras_optimized: int = 0
|
||||
num_disconnected: int = 0
|
||||
depth_biases: Dict[str, float] = field(default_factory=dict)
|
||||
per_pair_results: dict[tuple[str, str], ICPResult] = field(default_factory=dict)
|
||||
reference_camera: str = ""
|
||||
message: str = ""
|
||||
@@ -515,6 +517,272 @@ def optimize_pose_graph(
|
||||
return pose_graph
|
||||
|
||||
|
||||
def estimate_depth_biases(
|
||||
camera_data: Dict[str, Dict[str, Any]],
|
||||
extrinsics: Dict[str, Mat44],
|
||||
floor_planes: Dict[str, Any], # Dict[str, FloorPlane]
|
||||
config: ICPConfig,
|
||||
reference_serial: Optional[str] = None,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Estimate per-camera depth offsets (beta, in meters) from overlap correspondences.
|
||||
|
||||
The model enforces pairwise constraints:
|
||||
beta_j - beta_i ~= b_ij
|
||||
where b_ij is the robust median of source-ray residuals over pair correspondences.
|
||||
"""
|
||||
from .ground_plane import unproject_depth_to_points
|
||||
|
||||
all_serials = sorted(list(camera_data.keys()))
|
||||
if not all_serials:
|
||||
return {}
|
||||
|
||||
if reference_serial is None or reference_serial not in all_serials:
|
||||
reference = all_serials[0]
|
||||
else:
|
||||
reference = reference_serial
|
||||
|
||||
# Default/fallback output: 0.0 for all cameras
|
||||
betas: Dict[str, float] = {serial: 0.0 for serial in all_serials}
|
||||
|
||||
# Auto-set overlap mode based on region, matching refine_with_icp behavior.
|
||||
overlap_mode = config.overlap_mode
|
||||
if config.region == "floor":
|
||||
overlap_mode = "xz"
|
||||
elif config.region in ["hybrid", "full"]:
|
||||
overlap_mode = "3d"
|
||||
|
||||
camera_points: Dict[str, PointsNC] = {}
|
||||
camera_rays_world: Dict[str, PointsNC] = {}
|
||||
camera_pcds: Dict[str, o3d.geometry.PointCloud] = {}
|
||||
|
||||
# 1) Build per-camera world points/rays in the configured scene region.
|
||||
for serial in all_serials:
|
||||
if serial not in extrinsics:
|
||||
continue
|
||||
|
||||
data = camera_data[serial]
|
||||
depth_map = data.get("depth")
|
||||
K = data.get("K")
|
||||
if depth_map is None or K is None:
|
||||
continue
|
||||
|
||||
points_cam = unproject_depth_to_points(depth_map, K, stride=4)
|
||||
if len(points_cam) < 100:
|
||||
continue
|
||||
|
||||
T_wc = extrinsics[serial]
|
||||
points_world = (points_cam @ T_wc[:3, :3].T) + T_wc[:3, 3]
|
||||
|
||||
scene_points = points_world
|
||||
if config.region != "full":
|
||||
plane = floor_planes.get(serial)
|
||||
if plane is None:
|
||||
continue
|
||||
|
||||
floor_y = -plane.d
|
||||
scene_points = extract_scene_points(
|
||||
points_world,
|
||||
floor_y,
|
||||
plane.normal,
|
||||
mode=config.region,
|
||||
band_height=config.band_height,
|
||||
)
|
||||
|
||||
if len(scene_points) < 100:
|
||||
continue
|
||||
|
||||
pcd = o3d.geometry.PointCloud()
|
||||
pcd.points = o3d.utility.Vector3dVector(scene_points)
|
||||
pcd = preprocess_point_cloud(pcd, config.voxel_size)
|
||||
if len(pcd.points) < 100:
|
||||
continue
|
||||
|
||||
pts = np.asarray(pcd.points)
|
||||
cam_pos = T_wc[:3, 3]
|
||||
rays_world = pts - cam_pos[None, :]
|
||||
ray_norm = np.linalg.norm(rays_world, axis=1)
|
||||
valid_mask = ray_norm > 1e-9
|
||||
if np.count_nonzero(valid_mask) < 100:
|
||||
continue
|
||||
|
||||
pts = pts[valid_mask]
|
||||
rays_world = rays_world[valid_mask] / ray_norm[valid_mask, None]
|
||||
|
||||
camera_points[serial] = pts
|
||||
camera_rays_world[serial] = rays_world
|
||||
|
||||
pcd_valid = o3d.geometry.PointCloud()
|
||||
pcd_valid.points = o3d.utility.Vector3dVector(pts)
|
||||
camera_pcds[serial] = pcd_valid
|
||||
|
||||
valid_serials = sorted(list(camera_points.keys()))
|
||||
if reference not in valid_serials:
|
||||
logger.warning(
|
||||
f"Reference camera {reference} has no valid data for bias estimation; all biases set to 0.0"
|
||||
)
|
||||
betas[reference] = 0.0
|
||||
logger.info(f"Estimated depth biases (m): {betas}")
|
||||
return betas
|
||||
|
||||
# 2) Build robust pairwise median constraints b_ij.
|
||||
max_corr_dist = config.voxel_size * config.max_correspondence_distance_factor
|
||||
max_corr_dist_sq = max_corr_dist * max_corr_dist
|
||||
|
||||
pair_constraints: List[Tuple[str, str, float, int]] = []
|
||||
|
||||
for i, s1 in enumerate(valid_serials):
|
||||
for j in range(i + 1, len(valid_serials)):
|
||||
s2 = valid_serials[j]
|
||||
|
||||
if overlap_mode == "3d":
|
||||
overlap = compute_overlap_3d(
|
||||
camera_points[s1], camera_points[s2], config.overlap_margin
|
||||
)
|
||||
unit = "m^3"
|
||||
else:
|
||||
overlap = compute_overlap_xz(
|
||||
camera_points[s1], camera_points[s2], config.overlap_margin
|
||||
)
|
||||
unit = "m^2"
|
||||
|
||||
if overlap < config.min_overlap_area:
|
||||
logger.debug(
|
||||
f"Bias pair ({s1}, {s2}) skipped: overlap {overlap:.2f} {unit} < {config.min_overlap_area:.2f} {unit}"
|
||||
)
|
||||
continue
|
||||
|
||||
src_pts = camera_points[s1]
|
||||
src_rays = camera_rays_world[s1]
|
||||
tgt_pts = camera_points[s2]
|
||||
|
||||
kdtree = o3d.geometry.KDTreeFlann(camera_pcds[s2])
|
||||
residuals: List[float] = []
|
||||
|
||||
for idx in range(len(src_pts)):
|
||||
query = src_pts[idx]
|
||||
k, nn_indices, nn_dist_sq = kdtree.search_knn_vector_3d(query, 1)
|
||||
if k < 1:
|
||||
continue
|
||||
if nn_dist_sq[0] > max_corr_dist_sq:
|
||||
continue
|
||||
|
||||
target_point = tgt_pts[nn_indices[0]]
|
||||
source_point = query
|
||||
source_ray_world = src_rays[idx]
|
||||
|
||||
residual = float(np.dot(target_point - source_point, source_ray_world))
|
||||
if np.isfinite(residual):
|
||||
residuals.append(residual)
|
||||
|
||||
if len(residuals) < 100:
|
||||
logger.debug(
|
||||
f"Bias pair ({s1}, {s2}) skipped: insufficient correspondences {len(residuals)} < 100"
|
||||
)
|
||||
continue
|
||||
|
||||
b_ij = float(np.median(np.asarray(residuals, dtype=np.float64)))
|
||||
pair_constraints.append((s1, s2, b_ij, len(residuals)))
|
||||
logger.debug(
|
||||
f"Bias pair ({s1}, {s2}): corr={len(residuals)}, median={b_ij:.4f} m"
|
||||
)
|
||||
|
||||
if not pair_constraints:
|
||||
logger.warning(
|
||||
"No valid pair constraints for depth bias estimation; returning zeros"
|
||||
)
|
||||
betas[reference] = 0.0
|
||||
logger.info(f"Estimated depth biases (m): {betas}")
|
||||
return betas
|
||||
|
||||
# 3) Keep only cameras connected to reference; disconnected cameras remain 0.0.
|
||||
adjacency: Dict[str, set[str]] = {s: set() for s in valid_serials}
|
||||
for s1, s2, _, _ in pair_constraints:
|
||||
adjacency[s1].add(s2)
|
||||
adjacency[s2].add(s1)
|
||||
|
||||
connected = {reference}
|
||||
queue = [reference]
|
||||
while queue:
|
||||
curr = queue.pop(0)
|
||||
for nbr in adjacency.get(curr, set()):
|
||||
if nbr not in connected:
|
||||
connected.add(nbr)
|
||||
queue.append(nbr)
|
||||
|
||||
# Unknowns are all connected nodes except the fixed-gauge reference.
|
||||
unknown_serials = sorted(list(connected - {reference}))
|
||||
if not unknown_serials:
|
||||
betas[reference] = 0.0
|
||||
logger.info(f"Estimated depth biases (m): {betas}")
|
||||
return betas
|
||||
|
||||
serial_to_col = {s: i for i, s in enumerate(unknown_serials)}
|
||||
rows: List[np.ndarray] = []
|
||||
rhs: List[float] = []
|
||||
weights: List[float] = []
|
||||
|
||||
for s1, s2, b_ij, n_corr in pair_constraints:
|
||||
if s1 not in connected or s2 not in connected:
|
||||
continue
|
||||
|
||||
row = np.zeros(len(unknown_serials), dtype=np.float64)
|
||||
if s1 != reference:
|
||||
row[serial_to_col[s1]] = -1.0
|
||||
if s2 != reference:
|
||||
row[serial_to_col[s2]] = 1.0
|
||||
|
||||
if not np.any(row):
|
||||
continue
|
||||
|
||||
rows.append(row)
|
||||
rhs.append(b_ij)
|
||||
# Mild confidence weighting by correspondence count.
|
||||
weights.append(float(np.sqrt(max(n_corr, 1))))
|
||||
|
||||
if not rows:
|
||||
betas[reference] = 0.0
|
||||
logger.info(f"Estimated depth biases (m): {betas}")
|
||||
return betas
|
||||
|
||||
A = np.vstack(rows)
|
||||
y = np.asarray(rhs, dtype=np.float64)
|
||||
w = np.asarray(weights, dtype=np.float64)
|
||||
|
||||
Aw = A * w[:, None]
|
||||
yw = y * w
|
||||
|
||||
x, *_ = np.linalg.lstsq(Aw, yw, rcond=None)
|
||||
|
||||
betas[reference] = 0.0
|
||||
for serial, col in serial_to_col.items():
|
||||
betas[serial] = float(x[col])
|
||||
|
||||
# 4) Cap implausible values.
|
||||
max_abs_bias = float(getattr(config, "max_abs_bias", 0.3))
|
||||
for serial in betas:
|
||||
if serial == reference:
|
||||
continue
|
||||
if abs(betas[serial]) > max_abs_bias:
|
||||
logger.warning(
|
||||
f"Depth bias for camera {serial} clipped from {betas[serial]:.4f} m "
|
||||
f"to {np.clip(betas[serial], -max_abs_bias, max_abs_bias):.4f} m"
|
||||
)
|
||||
betas[serial] = float(np.clip(betas[serial], -max_abs_bias, max_abs_bias))
|
||||
|
||||
# Gauge is fixed exactly.
|
||||
betas[reference] = 0.0
|
||||
|
||||
if len(connected) < len(valid_serials):
|
||||
disconnected_valid = sorted(list(set(valid_serials) - connected))
|
||||
logger.warning(
|
||||
f"Depth-bias estimation disconnected cameras (set to 0.0): {disconnected_valid}"
|
||||
)
|
||||
|
||||
logger.info(f"Estimated depth biases (m): {betas}")
|
||||
return betas
|
||||
|
||||
|
||||
def refine_with_icp(
|
||||
camera_data: Dict[str, Dict[str, Any]],
|
||||
extrinsics: Dict[str, Mat44],
|
||||
@@ -543,6 +811,18 @@ def refine_with_icp(
|
||||
elif config.region in ["hybrid", "full"]:
|
||||
config.overlap_mode = "3d"
|
||||
|
||||
if config.depth_bias:
|
||||
biases = estimate_depth_biases(
|
||||
camera_data,
|
||||
extrinsics,
|
||||
floor_planes,
|
||||
config,
|
||||
reference_serial=metrics.reference_camera,
|
||||
)
|
||||
else:
|
||||
biases = {}
|
||||
metrics.depth_biases = biases
|
||||
|
||||
for serial in serials:
|
||||
if serial not in floor_planes or serial not in extrinsics:
|
||||
continue
|
||||
@@ -550,7 +830,11 @@ def refine_with_icp(
|
||||
data = camera_data[serial]
|
||||
plane = floor_planes[serial]
|
||||
|
||||
points_cam = unproject_depth_to_points(data["depth"], data["K"], stride=4)
|
||||
depth_corrected = data["depth"].copy()
|
||||
depth_corrected += biases.get(serial, 0.0)
|
||||
depth_corrected[depth_corrected <= 0] = np.nan
|
||||
|
||||
points_cam = unproject_depth_to_points(depth_corrected, data["K"], stride=4)
|
||||
T_wc = extrinsics[serial]
|
||||
points_world = (points_cam @ T_wc[:3, :3].T) + T_wc[:3, 3]
|
||||
|
||||
@@ -787,9 +1071,12 @@ def refine_with_icp(
|
||||
continue
|
||||
|
||||
new_extrinsics[serial] = T_optimized
|
||||
metrics.num_cameras_optimized += 1
|
||||
if serial != metrics.reference_camera:
|
||||
metrics.num_cameras_optimized += 1
|
||||
|
||||
metrics.success = metrics.num_cameras_optimized > 0
|
||||
metrics.message = f"Optimized {metrics.num_cameras_optimized} cameras"
|
||||
metrics.message = (
|
||||
f"Optimized {metrics.num_cameras_optimized} cameras (excluding reference)"
|
||||
)
|
||||
|
||||
return new_extrinsics, metrics
|
||||
|
||||
Reference in New Issue
Block a user