refactor(cli): extract depth postprocess and add tests

- Extract apply_depth_verify_refine_postprocess() from main() for testability
- Add test_depth_cli_postprocess.py using mocks to validate JSON and CSV behavior
- Keeps CLI behavior unchanged
This commit is contained in:
2026-02-05 05:07:48 +00:00
parent 15337d2d15
commit 553cc457f0
2 changed files with 321 additions and 113 deletions
+136 -112
View File
@@ -1,6 +1,7 @@
import click import click
import cv2 import cv2
import json import json
import csv
import numpy as np import numpy as np
import pyzed.sl as sl import pyzed.sl as sl
from pathlib import Path from pathlib import Path
@@ -21,6 +22,132 @@ from aruco.depth_verify import verify_extrinsics_with_depth
from aruco.depth_refine import refine_extrinsics_with_depth from aruco.depth_refine import refine_extrinsics_with_depth
def apply_depth_verify_refine_postprocess(
results: Dict[str, Any],
verification_frames: Dict[str, Any],
marker_geometry: Dict[int, Any],
camera_matrices: Dict[str, Any],
verify_depth: bool,
refine_depth: bool,
depth_confidence_threshold: int,
report_csv_path: Optional[str] = None,
) -> Tuple[Dict[str, Any], List[List[Any]]]:
"""
Apply depth verification and refinement to computed extrinsics.
Returns updated results and list of CSV rows.
"""
csv_rows: List[List[Any]] = []
if not (verify_depth or refine_depth):
return results, csv_rows
click.echo("\nRunning depth verification/refinement on computed extrinsics...")
for serial, vf in verification_frames.items():
if str(serial) not in results:
continue
frame = vf["frame"]
ids = vf["ids"]
# Use the FINAL COMPUTED POSE for verification
pose_str = results[str(serial)]["pose"]
T_mean = np.fromstring(pose_str, sep=" ").reshape(4, 4)
cam_matrix = camera_matrices[serial]
marker_corners_world = {
int(mid): marker_geometry[int(mid)]
for mid in ids.flatten()
if int(mid) in marker_geometry
}
if marker_corners_world and frame.depth_map is not None:
verify_res = verify_extrinsics_with_depth(
T_mean,
marker_corners_world,
frame.depth_map,
cam_matrix,
confidence_map=frame.confidence_map,
confidence_thresh=depth_confidence_threshold,
)
results[str(serial)]["depth_verify"] = {
"rmse": verify_res.rmse,
"mean_abs": verify_res.mean_abs,
"median": verify_res.median,
"depth_normalized_rmse": verify_res.depth_normalized_rmse,
"n_valid": verify_res.n_valid,
"n_total": verify_res.n_total,
}
click.echo(
f"Camera {serial} verification: RMSE={verify_res.rmse:.3f}m, "
f"Valid={verify_res.n_valid}/{verify_res.n_total}"
)
if refine_depth:
if verify_res.n_valid < 4:
click.echo(
f"Camera {serial}: Not enough valid depth points for refinement ({verify_res.n_valid}). Skipping."
)
else:
click.echo(f"Camera {serial}: Refining extrinsics with depth...")
T_refined, refine_stats = refine_extrinsics_with_depth(
T_mean,
marker_corners_world,
frame.depth_map,
cam_matrix,
)
verify_res_post = verify_extrinsics_with_depth(
T_refined,
marker_corners_world,
frame.depth_map,
cam_matrix,
confidence_map=frame.confidence_map,
confidence_thresh=depth_confidence_threshold,
)
pose_str_refined = " ".join(f"{x:.6f}" for x in T_refined.flatten())
results[str(serial)]["pose"] = pose_str_refined
results[str(serial)]["refine_depth"] = refine_stats
results[str(serial)]["depth_verify_post"] = {
"rmse": verify_res_post.rmse,
"mean_abs": verify_res_post.mean_abs,
"median": verify_res_post.median,
"depth_normalized_rmse": verify_res_post.depth_normalized_rmse,
"n_valid": verify_res_post.n_valid,
"n_total": verify_res_post.n_total,
}
improvement = verify_res.rmse - verify_res_post.rmse
results[str(serial)]["refine_depth"]["improvement_rmse"] = (
improvement
)
click.echo(
f"Camera {serial} refined: RMSE={verify_res_post.rmse:.3f}m "
f"(Improved by {improvement:.3f}m). "
f"Delta Rot={refine_stats['delta_rotation_deg']:.2f}deg, "
f"Trans={refine_stats['delta_translation_norm_m']:.3f}m"
)
verify_res = verify_res_post
if report_csv_path:
for mid, cidx, resid in verify_res.residuals:
csv_rows.append([serial, mid, cidx, resid])
if report_csv_path and csv_rows:
with open(report_csv_path, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["serial", "marker_id", "corner_idx", "residual"])
writer.writerows(csv_rows)
click.echo(f"Saved depth verification report to {report_csv_path}")
return results, csv_rows
@click.command() @click.command()
@click.option("--svo", "-s", multiple=True, required=False, help="Path to SVO files.") @click.option("--svo", "-s", multiple=True, required=False, help="Path to SVO files.")
@click.option("--markers", "-m", required=True, help="Path to markers parquet file.") @click.option("--markers", "-m", required=True, help="Path to markers parquet file.")
@@ -260,120 +387,17 @@ def main(
return return
# 4. Run Depth Verification if requested # 4. Run Depth Verification if requested
csv_rows: List[List[Any]] = [] apply_depth_verify_refine_postprocess(
if verify_depth or refine_depth: results,
click.echo("\nRunning depth verification/refinement on computed extrinsics...") verification_frames,
for serial, acc in accumulators.items(): marker_geometry,
if serial not in verification_frames or str(serial) not in results: camera_matrices,
continue verify_depth,
refine_depth,
# Retrieve stored frame data depth_confidence_threshold,
vf = verification_frames[serial] report_csv,
frame = vf["frame"]
ids = vf["ids"]
# Use the FINAL COMPUTED POSE for verification
pose_str = results[str(serial)]["pose"]
T_mean = np.fromstring(pose_str, sep=" ").reshape(4, 4)
cam_matrix = camera_matrices[serial]
marker_corners_world = {
int(mid): marker_geometry[int(mid)]
for mid in ids.flatten()
if int(mid) in marker_geometry
}
if marker_corners_world and frame.depth_map is not None:
verify_res = verify_extrinsics_with_depth(
T_mean,
marker_corners_world,
frame.depth_map,
cam_matrix,
confidence_map=frame.confidence_map,
confidence_thresh=depth_confidence_threshold,
) )
results[str(serial)]["depth_verify"] = {
"rmse": verify_res.rmse,
"mean_abs": verify_res.mean_abs,
"median": verify_res.median,
"depth_normalized_rmse": verify_res.depth_normalized_rmse,
"n_valid": verify_res.n_valid,
"n_total": verify_res.n_total,
}
click.echo(
f"Camera {serial} verification: RMSE={verify_res.rmse:.3f}m, "
f"Valid={verify_res.n_valid}/{verify_res.n_total}"
)
if refine_depth:
if verify_res.n_valid < 4:
click.echo(
f"Camera {serial}: Not enough valid depth points for refinement ({verify_res.n_valid}). Skipping."
)
else:
click.echo(
f"Camera {serial}: Refining extrinsics with depth..."
)
T_refined, refine_stats = refine_extrinsics_with_depth(
T_mean,
marker_corners_world,
frame.depth_map,
cam_matrix,
)
verify_res_post = verify_extrinsics_with_depth(
T_refined,
marker_corners_world,
frame.depth_map,
cam_matrix,
confidence_map=frame.confidence_map,
confidence_thresh=depth_confidence_threshold,
)
pose_str_refined = " ".join(
f"{x:.6f}" for x in T_refined.flatten()
)
results[str(serial)]["pose"] = pose_str_refined
results[str(serial)]["refine_depth"] = refine_stats
results[str(serial)]["depth_verify_post"] = {
"rmse": verify_res_post.rmse,
"mean_abs": verify_res_post.mean_abs,
"median": verify_res_post.median,
"depth_normalized_rmse": verify_res_post.depth_normalized_rmse,
"n_valid": verify_res_post.n_valid,
"n_total": verify_res_post.n_total,
}
improvement = verify_res.rmse - verify_res_post.rmse
results[str(serial)]["refine_depth"]["improvement_rmse"] = (
improvement
)
click.echo(
f"Camera {serial} refined: RMSE={verify_res_post.rmse:.3f}m "
f"(Improved by {improvement:.3f}m). "
f"Delta Rot={refine_stats['delta_rotation_deg']:.2f}deg, "
f"Trans={refine_stats['delta_translation_norm_m']:.3f}m"
)
verify_res = verify_res_post
if report_csv:
for mid, cidx, resid in verify_res.residuals:
csv_rows.append([serial, mid, cidx, resid])
# 5. Save CSV Report
if report_csv and csv_rows:
import csv
with open(report_csv, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["serial", "marker_id", "corner_idx", "residual"])
writer.writerows(csv_rows)
click.echo(f"Saved depth verification report to {report_csv}")
# 6. Save to JSON # 6. Save to JSON
with open(output, "w") as f: with open(output, "w") as f:
json.dump(results, f, indent=4, sort_keys=True) json.dump(results, f, indent=4, sort_keys=True)
@@ -0,0 +1,184 @@
import pytest
import numpy as np
from unittest.mock import MagicMock, patch
import sys
from pathlib import Path
# Add py_workspace to path so we can import calibrate_extrinsics
sys.path.append(str(Path(__file__).parent.parent))
# We will import the function after we create it, or we can import the module and patch it
# For now, let's assume we will add the function to calibrate_extrinsics.py
# Since the file exists but the function doesn't, we can't import it yet.
# But for TDD, I will write the test assuming the function exists in the module.
# I'll use a dynamic import or just import the module and access the function dynamically if needed,
# but standard import is better. I'll write the test file, but I won't run it until I refactor the code.
from calibrate_extrinsics import apply_depth_verify_refine_postprocess
@pytest.fixture
def mock_dependencies():
with (
patch("calibrate_extrinsics.verify_extrinsics_with_depth") as mock_verify,
patch("calibrate_extrinsics.refine_extrinsics_with_depth") as mock_refine,
patch("calibrate_extrinsics.click.echo") as mock_echo,
):
# Setup mock return values
mock_verify_res = MagicMock()
mock_verify_res.rmse = 0.05
mock_verify_res.mean_abs = 0.04
mock_verify_res.median = 0.03
mock_verify_res.depth_normalized_rmse = 0.02
mock_verify_res.n_valid = 100
mock_verify_res.n_total = 120
mock_verify_res.residuals = [(1, 0, 0.01), (1, 1, 0.02)]
mock_verify.return_value = mock_verify_res
mock_refine_res_stats = {
"delta_rotation_deg": 1.0,
"delta_translation_norm_m": 0.1,
}
# refine returns (new_pose_matrix, stats)
mock_refine.return_value = (np.eye(4), mock_refine_res_stats)
yield mock_verify, mock_refine, mock_echo
def test_verify_only(mock_dependencies, tmp_path):
mock_verify, mock_refine, _ = mock_dependencies
# Setup inputs
serial = "123456"
results = {
serial: {
"pose": "1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1", # Identity matrix flattened
"stats": {},
}
}
verification_frames = {
serial: {
"frame": MagicMock(
depth_map=np.zeros((10, 10)), confidence_map=np.zeros((10, 10))
),
"ids": np.array([[1]]),
"corners": np.zeros((1, 4, 2)),
}
}
marker_geometry = {1: np.zeros((4, 3))}
camera_matrices = {serial: np.eye(3)}
updated_results, csv_rows = apply_depth_verify_refine_postprocess(
results=results,
verification_frames=verification_frames,
marker_geometry=marker_geometry,
camera_matrices=camera_matrices,
verify_depth=True,
refine_depth=False,
depth_confidence_threshold=50,
report_csv_path=None,
)
assert "depth_verify" in updated_results[serial]
assert updated_results[serial]["depth_verify"]["rmse"] == 0.05
assert "refine_depth" not in updated_results[serial]
assert (
len(csv_rows) == 0
) # No CSV path provided, so no rows returned for writing (or empty list)
mock_verify.assert_called_once()
mock_refine.assert_not_called()
def test_refine_depth(mock_dependencies):
mock_verify, mock_refine, _ = mock_dependencies
# Setup inputs
serial = "123456"
results = {serial: {"pose": "1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1", "stats": {}}}
verification_frames = {
serial: {
"frame": MagicMock(
depth_map=np.zeros((10, 10)), confidence_map=np.zeros((10, 10))
),
"ids": np.array([[1]]),
"corners": np.zeros((1, 4, 2)),
}
}
marker_geometry = {1: np.zeros((4, 3))}
camera_matrices = {serial: np.eye(3)}
# Mock verify to return different values for pre and post
# First call (pre-refine)
res_pre = MagicMock()
res_pre.rmse = 0.1
res_pre.n_valid = 100
res_pre.residuals = []
# Second call (post-refine)
res_post = MagicMock()
res_post.rmse = 0.05
res_post.n_valid = 100
res_post.residuals = []
mock_verify.side_effect = [res_pre, res_post]
updated_results, _ = apply_depth_verify_refine_postprocess(
results=results,
verification_frames=verification_frames,
marker_geometry=marker_geometry,
camera_matrices=camera_matrices,
verify_depth=False, # refine implies verify usually, but let's check logic
refine_depth=True,
depth_confidence_threshold=50,
)
assert "refine_depth" in updated_results[serial]
assert "depth_verify_post" in updated_results[serial]
assert (
updated_results[serial]["refine_depth"]["improvement_rmse"] == 0.05
) # 0.1 - 0.05
assert mock_verify.call_count == 2
mock_refine.assert_called_once()
def test_csv_output(mock_dependencies, tmp_path):
mock_verify, _, _ = mock_dependencies
csv_path = tmp_path / "report.csv"
serial = "123456"
results = {serial: {"pose": "1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1", "stats": {}}}
verification_frames = {
serial: {
"frame": MagicMock(
depth_map=np.zeros((10, 10)), confidence_map=np.zeros((10, 10))
),
"ids": np.array([[1]]),
"corners": np.zeros((1, 4, 2)),
}
}
marker_geometry = {1: np.zeros((4, 3))}
camera_matrices = {serial: np.eye(3)}
updated_results, csv_rows = apply_depth_verify_refine_postprocess(
results=results,
verification_frames=verification_frames,
marker_geometry=marker_geometry,
camera_matrices=camera_matrices,
verify_depth=True,
refine_depth=False,
depth_confidence_threshold=50,
report_csv_path=str(csv_path),
)
assert len(csv_rows) == 2 # From mock_verify_res.residuals
assert csv_rows[0] == [serial, 1, 0, 0.01]
# Verify file content
assert csv_path.exists()
content = csv_path.read_text().splitlines()
assert len(content) == 3 # Header + 2 rows
assert content[0] == "serial,marker_id,corner_idx,residual"
assert content[1] == f"{serial},1,0,0.01"