refactor(cli): extract depth postprocess and add tests
- Extract apply_depth_verify_refine_postprocess() from main() for testability - Add test_depth_cli_postprocess.py using mocks to validate JSON and CSV behavior - Keeps CLI behavior unchanged
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import click
|
import click
|
||||||
import cv2
|
import cv2
|
||||||
import json
|
import json
|
||||||
|
import csv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pyzed.sl as sl
|
import pyzed.sl as sl
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -21,6 +22,132 @@ from aruco.depth_verify import verify_extrinsics_with_depth
|
|||||||
from aruco.depth_refine import refine_extrinsics_with_depth
|
from aruco.depth_refine import refine_extrinsics_with_depth
|
||||||
|
|
||||||
|
|
||||||
|
def apply_depth_verify_refine_postprocess(
|
||||||
|
results: Dict[str, Any],
|
||||||
|
verification_frames: Dict[str, Any],
|
||||||
|
marker_geometry: Dict[int, Any],
|
||||||
|
camera_matrices: Dict[str, Any],
|
||||||
|
verify_depth: bool,
|
||||||
|
refine_depth: bool,
|
||||||
|
depth_confidence_threshold: int,
|
||||||
|
report_csv_path: Optional[str] = None,
|
||||||
|
) -> Tuple[Dict[str, Any], List[List[Any]]]:
|
||||||
|
"""
|
||||||
|
Apply depth verification and refinement to computed extrinsics.
|
||||||
|
Returns updated results and list of CSV rows.
|
||||||
|
"""
|
||||||
|
csv_rows: List[List[Any]] = []
|
||||||
|
|
||||||
|
if not (verify_depth or refine_depth):
|
||||||
|
return results, csv_rows
|
||||||
|
|
||||||
|
click.echo("\nRunning depth verification/refinement on computed extrinsics...")
|
||||||
|
|
||||||
|
for serial, vf in verification_frames.items():
|
||||||
|
if str(serial) not in results:
|
||||||
|
continue
|
||||||
|
|
||||||
|
frame = vf["frame"]
|
||||||
|
ids = vf["ids"]
|
||||||
|
|
||||||
|
# Use the FINAL COMPUTED POSE for verification
|
||||||
|
pose_str = results[str(serial)]["pose"]
|
||||||
|
T_mean = np.fromstring(pose_str, sep=" ").reshape(4, 4)
|
||||||
|
cam_matrix = camera_matrices[serial]
|
||||||
|
|
||||||
|
marker_corners_world = {
|
||||||
|
int(mid): marker_geometry[int(mid)]
|
||||||
|
for mid in ids.flatten()
|
||||||
|
if int(mid) in marker_geometry
|
||||||
|
}
|
||||||
|
|
||||||
|
if marker_corners_world and frame.depth_map is not None:
|
||||||
|
verify_res = verify_extrinsics_with_depth(
|
||||||
|
T_mean,
|
||||||
|
marker_corners_world,
|
||||||
|
frame.depth_map,
|
||||||
|
cam_matrix,
|
||||||
|
confidence_map=frame.confidence_map,
|
||||||
|
confidence_thresh=depth_confidence_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
results[str(serial)]["depth_verify"] = {
|
||||||
|
"rmse": verify_res.rmse,
|
||||||
|
"mean_abs": verify_res.mean_abs,
|
||||||
|
"median": verify_res.median,
|
||||||
|
"depth_normalized_rmse": verify_res.depth_normalized_rmse,
|
||||||
|
"n_valid": verify_res.n_valid,
|
||||||
|
"n_total": verify_res.n_total,
|
||||||
|
}
|
||||||
|
|
||||||
|
click.echo(
|
||||||
|
f"Camera {serial} verification: RMSE={verify_res.rmse:.3f}m, "
|
||||||
|
f"Valid={verify_res.n_valid}/{verify_res.n_total}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if refine_depth:
|
||||||
|
if verify_res.n_valid < 4:
|
||||||
|
click.echo(
|
||||||
|
f"Camera {serial}: Not enough valid depth points for refinement ({verify_res.n_valid}). Skipping."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
click.echo(f"Camera {serial}: Refining extrinsics with depth...")
|
||||||
|
T_refined, refine_stats = refine_extrinsics_with_depth(
|
||||||
|
T_mean,
|
||||||
|
marker_corners_world,
|
||||||
|
frame.depth_map,
|
||||||
|
cam_matrix,
|
||||||
|
)
|
||||||
|
|
||||||
|
verify_res_post = verify_extrinsics_with_depth(
|
||||||
|
T_refined,
|
||||||
|
marker_corners_world,
|
||||||
|
frame.depth_map,
|
||||||
|
cam_matrix,
|
||||||
|
confidence_map=frame.confidence_map,
|
||||||
|
confidence_thresh=depth_confidence_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
pose_str_refined = " ".join(f"{x:.6f}" for x in T_refined.flatten())
|
||||||
|
results[str(serial)]["pose"] = pose_str_refined
|
||||||
|
results[str(serial)]["refine_depth"] = refine_stats
|
||||||
|
results[str(serial)]["depth_verify_post"] = {
|
||||||
|
"rmse": verify_res_post.rmse,
|
||||||
|
"mean_abs": verify_res_post.mean_abs,
|
||||||
|
"median": verify_res_post.median,
|
||||||
|
"depth_normalized_rmse": verify_res_post.depth_normalized_rmse,
|
||||||
|
"n_valid": verify_res_post.n_valid,
|
||||||
|
"n_total": verify_res_post.n_total,
|
||||||
|
}
|
||||||
|
|
||||||
|
improvement = verify_res.rmse - verify_res_post.rmse
|
||||||
|
results[str(serial)]["refine_depth"]["improvement_rmse"] = (
|
||||||
|
improvement
|
||||||
|
)
|
||||||
|
|
||||||
|
click.echo(
|
||||||
|
f"Camera {serial} refined: RMSE={verify_res_post.rmse:.3f}m "
|
||||||
|
f"(Improved by {improvement:.3f}m). "
|
||||||
|
f"Delta Rot={refine_stats['delta_rotation_deg']:.2f}deg, "
|
||||||
|
f"Trans={refine_stats['delta_translation_norm_m']:.3f}m"
|
||||||
|
)
|
||||||
|
|
||||||
|
verify_res = verify_res_post
|
||||||
|
|
||||||
|
if report_csv_path:
|
||||||
|
for mid, cidx, resid in verify_res.residuals:
|
||||||
|
csv_rows.append([serial, mid, cidx, resid])
|
||||||
|
|
||||||
|
if report_csv_path and csv_rows:
|
||||||
|
with open(report_csv_path, "w", newline="") as f:
|
||||||
|
writer = csv.writer(f)
|
||||||
|
writer.writerow(["serial", "marker_id", "corner_idx", "residual"])
|
||||||
|
writer.writerows(csv_rows)
|
||||||
|
click.echo(f"Saved depth verification report to {report_csv_path}")
|
||||||
|
|
||||||
|
return results, csv_rows
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.option("--svo", "-s", multiple=True, required=False, help="Path to SVO files.")
|
@click.option("--svo", "-s", multiple=True, required=False, help="Path to SVO files.")
|
||||||
@click.option("--markers", "-m", required=True, help="Path to markers parquet file.")
|
@click.option("--markers", "-m", required=True, help="Path to markers parquet file.")
|
||||||
@@ -260,119 +387,16 @@ def main(
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 4. Run Depth Verification if requested
|
# 4. Run Depth Verification if requested
|
||||||
csv_rows: List[List[Any]] = []
|
apply_depth_verify_refine_postprocess(
|
||||||
if verify_depth or refine_depth:
|
results,
|
||||||
click.echo("\nRunning depth verification/refinement on computed extrinsics...")
|
verification_frames,
|
||||||
for serial, acc in accumulators.items():
|
marker_geometry,
|
||||||
if serial not in verification_frames or str(serial) not in results:
|
camera_matrices,
|
||||||
continue
|
verify_depth,
|
||||||
|
refine_depth,
|
||||||
# Retrieve stored frame data
|
depth_confidence_threshold,
|
||||||
vf = verification_frames[serial]
|
report_csv,
|
||||||
frame = vf["frame"]
|
)
|
||||||
ids = vf["ids"]
|
|
||||||
|
|
||||||
# Use the FINAL COMPUTED POSE for verification
|
|
||||||
pose_str = results[str(serial)]["pose"]
|
|
||||||
T_mean = np.fromstring(pose_str, sep=" ").reshape(4, 4)
|
|
||||||
cam_matrix = camera_matrices[serial]
|
|
||||||
|
|
||||||
marker_corners_world = {
|
|
||||||
int(mid): marker_geometry[int(mid)]
|
|
||||||
for mid in ids.flatten()
|
|
||||||
if int(mid) in marker_geometry
|
|
||||||
}
|
|
||||||
|
|
||||||
if marker_corners_world and frame.depth_map is not None:
|
|
||||||
verify_res = verify_extrinsics_with_depth(
|
|
||||||
T_mean,
|
|
||||||
marker_corners_world,
|
|
||||||
frame.depth_map,
|
|
||||||
cam_matrix,
|
|
||||||
confidence_map=frame.confidence_map,
|
|
||||||
confidence_thresh=depth_confidence_threshold,
|
|
||||||
)
|
|
||||||
|
|
||||||
results[str(serial)]["depth_verify"] = {
|
|
||||||
"rmse": verify_res.rmse,
|
|
||||||
"mean_abs": verify_res.mean_abs,
|
|
||||||
"median": verify_res.median,
|
|
||||||
"depth_normalized_rmse": verify_res.depth_normalized_rmse,
|
|
||||||
"n_valid": verify_res.n_valid,
|
|
||||||
"n_total": verify_res.n_total,
|
|
||||||
}
|
|
||||||
|
|
||||||
click.echo(
|
|
||||||
f"Camera {serial} verification: RMSE={verify_res.rmse:.3f}m, "
|
|
||||||
f"Valid={verify_res.n_valid}/{verify_res.n_total}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if refine_depth:
|
|
||||||
if verify_res.n_valid < 4:
|
|
||||||
click.echo(
|
|
||||||
f"Camera {serial}: Not enough valid depth points for refinement ({verify_res.n_valid}). Skipping."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
click.echo(
|
|
||||||
f"Camera {serial}: Refining extrinsics with depth..."
|
|
||||||
)
|
|
||||||
T_refined, refine_stats = refine_extrinsics_with_depth(
|
|
||||||
T_mean,
|
|
||||||
marker_corners_world,
|
|
||||||
frame.depth_map,
|
|
||||||
cam_matrix,
|
|
||||||
)
|
|
||||||
|
|
||||||
verify_res_post = verify_extrinsics_with_depth(
|
|
||||||
T_refined,
|
|
||||||
marker_corners_world,
|
|
||||||
frame.depth_map,
|
|
||||||
cam_matrix,
|
|
||||||
confidence_map=frame.confidence_map,
|
|
||||||
confidence_thresh=depth_confidence_threshold,
|
|
||||||
)
|
|
||||||
|
|
||||||
pose_str_refined = " ".join(
|
|
||||||
f"{x:.6f}" for x in T_refined.flatten()
|
|
||||||
)
|
|
||||||
results[str(serial)]["pose"] = pose_str_refined
|
|
||||||
results[str(serial)]["refine_depth"] = refine_stats
|
|
||||||
results[str(serial)]["depth_verify_post"] = {
|
|
||||||
"rmse": verify_res_post.rmse,
|
|
||||||
"mean_abs": verify_res_post.mean_abs,
|
|
||||||
"median": verify_res_post.median,
|
|
||||||
"depth_normalized_rmse": verify_res_post.depth_normalized_rmse,
|
|
||||||
"n_valid": verify_res_post.n_valid,
|
|
||||||
"n_total": verify_res_post.n_total,
|
|
||||||
}
|
|
||||||
|
|
||||||
improvement = verify_res.rmse - verify_res_post.rmse
|
|
||||||
results[str(serial)]["refine_depth"]["improvement_rmse"] = (
|
|
||||||
improvement
|
|
||||||
)
|
|
||||||
|
|
||||||
click.echo(
|
|
||||||
f"Camera {serial} refined: RMSE={verify_res_post.rmse:.3f}m "
|
|
||||||
f"(Improved by {improvement:.3f}m). "
|
|
||||||
f"Delta Rot={refine_stats['delta_rotation_deg']:.2f}deg, "
|
|
||||||
f"Trans={refine_stats['delta_translation_norm_m']:.3f}m"
|
|
||||||
)
|
|
||||||
|
|
||||||
verify_res = verify_res_post
|
|
||||||
|
|
||||||
if report_csv:
|
|
||||||
for mid, cidx, resid in verify_res.residuals:
|
|
||||||
csv_rows.append([serial, mid, cidx, resid])
|
|
||||||
|
|
||||||
# 5. Save CSV Report
|
|
||||||
if report_csv and csv_rows:
|
|
||||||
import csv
|
|
||||||
|
|
||||||
with open(report_csv, "w", newline="") as f:
|
|
||||||
writer = csv.writer(f)
|
|
||||||
writer.writerow(["serial", "marker_id", "corner_idx", "residual"])
|
|
||||||
writer.writerows(csv_rows)
|
|
||||||
click.echo(f"Saved depth verification report to {report_csv}")
|
|
||||||
|
|
||||||
# 6. Save to JSON
|
# 6. Save to JSON
|
||||||
with open(output, "w") as f:
|
with open(output, "w") as f:
|
||||||
|
|||||||
@@ -0,0 +1,184 @@
|
|||||||
|
import pytest
|
||||||
|
import numpy as np
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add py_workspace to path so we can import calibrate_extrinsics
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
# We will import the function after we create it, or we can import the module and patch it
|
||||||
|
# For now, let's assume we will add the function to calibrate_extrinsics.py
|
||||||
|
# Since the file exists but the function doesn't, we can't import it yet.
|
||||||
|
# But for TDD, I will write the test assuming the function exists in the module.
|
||||||
|
# I'll use a dynamic import or just import the module and access the function dynamically if needed,
|
||||||
|
# but standard import is better. I'll write the test file, but I won't run it until I refactor the code.
|
||||||
|
|
||||||
|
from calibrate_extrinsics import apply_depth_verify_refine_postprocess
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_dependencies():
|
||||||
|
with (
|
||||||
|
patch("calibrate_extrinsics.verify_extrinsics_with_depth") as mock_verify,
|
||||||
|
patch("calibrate_extrinsics.refine_extrinsics_with_depth") as mock_refine,
|
||||||
|
patch("calibrate_extrinsics.click.echo") as mock_echo,
|
||||||
|
):
|
||||||
|
# Setup mock return values
|
||||||
|
mock_verify_res = MagicMock()
|
||||||
|
mock_verify_res.rmse = 0.05
|
||||||
|
mock_verify_res.mean_abs = 0.04
|
||||||
|
mock_verify_res.median = 0.03
|
||||||
|
mock_verify_res.depth_normalized_rmse = 0.02
|
||||||
|
mock_verify_res.n_valid = 100
|
||||||
|
mock_verify_res.n_total = 120
|
||||||
|
mock_verify_res.residuals = [(1, 0, 0.01), (1, 1, 0.02)]
|
||||||
|
mock_verify.return_value = mock_verify_res
|
||||||
|
|
||||||
|
mock_refine_res_stats = {
|
||||||
|
"delta_rotation_deg": 1.0,
|
||||||
|
"delta_translation_norm_m": 0.1,
|
||||||
|
}
|
||||||
|
# refine returns (new_pose_matrix, stats)
|
||||||
|
mock_refine.return_value = (np.eye(4), mock_refine_res_stats)
|
||||||
|
|
||||||
|
yield mock_verify, mock_refine, mock_echo
|
||||||
|
|
||||||
|
|
||||||
|
def test_verify_only(mock_dependencies, tmp_path):
|
||||||
|
mock_verify, mock_refine, _ = mock_dependencies
|
||||||
|
|
||||||
|
# Setup inputs
|
||||||
|
serial = "123456"
|
||||||
|
results = {
|
||||||
|
serial: {
|
||||||
|
"pose": "1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1", # Identity matrix flattened
|
||||||
|
"stats": {},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
verification_frames = {
|
||||||
|
serial: {
|
||||||
|
"frame": MagicMock(
|
||||||
|
depth_map=np.zeros((10, 10)), confidence_map=np.zeros((10, 10))
|
||||||
|
),
|
||||||
|
"ids": np.array([[1]]),
|
||||||
|
"corners": np.zeros((1, 4, 2)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
marker_geometry = {1: np.zeros((4, 3))}
|
||||||
|
camera_matrices = {serial: np.eye(3)}
|
||||||
|
|
||||||
|
updated_results, csv_rows = apply_depth_verify_refine_postprocess(
|
||||||
|
results=results,
|
||||||
|
verification_frames=verification_frames,
|
||||||
|
marker_geometry=marker_geometry,
|
||||||
|
camera_matrices=camera_matrices,
|
||||||
|
verify_depth=True,
|
||||||
|
refine_depth=False,
|
||||||
|
depth_confidence_threshold=50,
|
||||||
|
report_csv_path=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "depth_verify" in updated_results[serial]
|
||||||
|
assert updated_results[serial]["depth_verify"]["rmse"] == 0.05
|
||||||
|
assert "refine_depth" not in updated_results[serial]
|
||||||
|
assert (
|
||||||
|
len(csv_rows) == 0
|
||||||
|
) # No CSV path provided, so no rows returned for writing (or empty list)
|
||||||
|
|
||||||
|
mock_verify.assert_called_once()
|
||||||
|
mock_refine.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_refine_depth(mock_dependencies):
|
||||||
|
mock_verify, mock_refine, _ = mock_dependencies
|
||||||
|
|
||||||
|
# Setup inputs
|
||||||
|
serial = "123456"
|
||||||
|
results = {serial: {"pose": "1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1", "stats": {}}}
|
||||||
|
verification_frames = {
|
||||||
|
serial: {
|
||||||
|
"frame": MagicMock(
|
||||||
|
depth_map=np.zeros((10, 10)), confidence_map=np.zeros((10, 10))
|
||||||
|
),
|
||||||
|
"ids": np.array([[1]]),
|
||||||
|
"corners": np.zeros((1, 4, 2)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
marker_geometry = {1: np.zeros((4, 3))}
|
||||||
|
camera_matrices = {serial: np.eye(3)}
|
||||||
|
|
||||||
|
# Mock verify to return different values for pre and post
|
||||||
|
# First call (pre-refine)
|
||||||
|
res_pre = MagicMock()
|
||||||
|
res_pre.rmse = 0.1
|
||||||
|
res_pre.n_valid = 100
|
||||||
|
res_pre.residuals = []
|
||||||
|
|
||||||
|
# Second call (post-refine)
|
||||||
|
res_post = MagicMock()
|
||||||
|
res_post.rmse = 0.05
|
||||||
|
res_post.n_valid = 100
|
||||||
|
res_post.residuals = []
|
||||||
|
|
||||||
|
mock_verify.side_effect = [res_pre, res_post]
|
||||||
|
|
||||||
|
updated_results, _ = apply_depth_verify_refine_postprocess(
|
||||||
|
results=results,
|
||||||
|
verification_frames=verification_frames,
|
||||||
|
marker_geometry=marker_geometry,
|
||||||
|
camera_matrices=camera_matrices,
|
||||||
|
verify_depth=False, # refine implies verify usually, but let's check logic
|
||||||
|
refine_depth=True,
|
||||||
|
depth_confidence_threshold=50,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "refine_depth" in updated_results[serial]
|
||||||
|
assert "depth_verify_post" in updated_results[serial]
|
||||||
|
assert (
|
||||||
|
updated_results[serial]["refine_depth"]["improvement_rmse"] == 0.05
|
||||||
|
) # 0.1 - 0.05
|
||||||
|
|
||||||
|
assert mock_verify.call_count == 2
|
||||||
|
mock_refine.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_csv_output(mock_dependencies, tmp_path):
|
||||||
|
mock_verify, _, _ = mock_dependencies
|
||||||
|
|
||||||
|
csv_path = tmp_path / "report.csv"
|
||||||
|
|
||||||
|
serial = "123456"
|
||||||
|
results = {serial: {"pose": "1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 1", "stats": {}}}
|
||||||
|
verification_frames = {
|
||||||
|
serial: {
|
||||||
|
"frame": MagicMock(
|
||||||
|
depth_map=np.zeros((10, 10)), confidence_map=np.zeros((10, 10))
|
||||||
|
),
|
||||||
|
"ids": np.array([[1]]),
|
||||||
|
"corners": np.zeros((1, 4, 2)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
marker_geometry = {1: np.zeros((4, 3))}
|
||||||
|
camera_matrices = {serial: np.eye(3)}
|
||||||
|
|
||||||
|
updated_results, csv_rows = apply_depth_verify_refine_postprocess(
|
||||||
|
results=results,
|
||||||
|
verification_frames=verification_frames,
|
||||||
|
marker_geometry=marker_geometry,
|
||||||
|
camera_matrices=camera_matrices,
|
||||||
|
verify_depth=True,
|
||||||
|
refine_depth=False,
|
||||||
|
depth_confidence_threshold=50,
|
||||||
|
report_csv_path=str(csv_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(csv_rows) == 2 # From mock_verify_res.residuals
|
||||||
|
assert csv_rows[0] == [serial, 1, 0, 0.01]
|
||||||
|
|
||||||
|
# Verify file content
|
||||||
|
assert csv_path.exists()
|
||||||
|
content = csv_path.read_text().splitlines()
|
||||||
|
assert len(content) == 3 # Header + 2 rows
|
||||||
|
assert content[0] == "serial,marker_id,corner_idx,residual"
|
||||||
|
assert content[1] == f"{serial},1,0,0.01"
|
||||||
Reference in New Issue
Block a user