Files
zed-playground/py_workspace/tests/test_depth_save.py
T
crosstyan 83a74d293b feat: add HDF5 depth map persistence module
- Implement aruco/depth_save.py with save_depth_data/load_depth_data
- Add tests/test_depth_save.py covering roundtrip and edge cases
- Ensure type safety with basedpyright
- Support compression and metadata handling
2026-02-09 07:18:00 +00:00

133 lines
4.6 KiB
Python

import numpy as np
import pytest
import h5py
import json
from pathlib import Path
from aruco.depth_save import save_depth_data, load_depth_data
@pytest.fixture
def sample_camera_data():
    """Create sample camera data for testing.

    Returns a mapping of camera serial -> per-camera payload, in the form the
    tests pass to ``save_depth_data``. Camera ``"12345678"`` populates every
    optional field (confidence map, pool metadata, one raw frame). Camera
    ``"87654321"`` has no confidence map, ``pool_metadata=None`` and an empty
    ``raw_frames`` list, so the "absent" branches are exercised too.

    Random maps are drawn from a fixed-seed generator so that a failing test
    reproduces with identical data on every run.
    """
    rng = np.random.default_rng(0)
    width, height = 1280, 720
    # 3x3 camera matrix: fx = fy = 1000, principal point at the image centre.
    intrinsics = np.array([[1000, 0, 640], [0, 1000, 360], [0, 0, 1]])

    def random_depth():
        # Maps are indexed (rows, cols) = (height, width).
        return rng.random((height, width), dtype=np.float32)

    def random_confidence():
        # Confidence values in [0, 100).
        return rng.integers(0, 100, (height, width), dtype=np.uint8)

    return {
        "12345678": {
            # Each camera gets its own array so one cannot alias the other.
            "intrinsics": intrinsics.copy(),
            "resolution": (width, height),
            "pooled_depth": random_depth(),
            "pooled_confidence": random_confidence(),
            "pool_metadata": {
                "pool_size_requested": 5,
                "pool_size_actual": 3,
                "pooled": True,
                "pooled_rmse": 0.05,
            },
            "raw_frames": [
                {
                    "frame_index": 10,
                    "score": 95.5,
                    "depth_map": random_depth(),
                    "confidence_map": random_confidence(),
                }
            ],
        },
        "87654321": {
            "intrinsics": intrinsics.copy(),
            "resolution": (width, height),
            "pooled_depth": random_depth(),
            # No confidence map for this camera.
            "pool_metadata": None,
            "raw_frames": [],
        },
    }
def test_save_depth_data_creates_file(tmp_path, sample_camera_data):
    """Saving the sample payload must leave behind a file h5py recognises as HDF5."""
    target = tmp_path / "test_depth.h5"

    save_depth_data(target, sample_camera_data)

    # The file must exist first; then its signature must be valid HDF5.
    assert target.exists()
    assert h5py.is_hdf5(target)
def test_save_depth_data_metadata(tmp_path, sample_camera_data):
    """The ``meta`` group must carry the expected file-level attributes."""
    target = tmp_path / "test_depth.h5"
    save_depth_data(target, sample_camera_data)

    # Attribute name -> exact value written by save_depth_data.
    expected_attrs = {
        "schema_version": 1,
        "units": "meters",
        "coordinate_frame": "world_from_cam",
    }

    with h5py.File(target, "r") as h5:
        assert "meta" in h5
        meta_attrs = h5["meta"].attrs
        for attr_name, attr_value in expected_attrs.items():
            assert meta_attrs[attr_name] == attr_value
        # The timestamp varies per run, so only its presence is checked.
        assert "created_at" in meta_attrs
def test_save_load_roundtrip(tmp_path, sample_camera_data):
    """Test that data can be saved and loaded back accurately.

    Every camera is compared field by field. Present optional fields must be
    restored (depth within ``assert_allclose`` tolerance, everything else
    exactly), and absent/``None`` optional fields must stay absent/``None``.
    Each assertion carries the camera serial so a failure says which camera
    broke.
    """
    output_path = tmp_path / "test_depth.h5"
    save_depth_data(output_path, sample_camera_data)
    loaded_data = load_depth_data(output_path)

    assert set(loaded_data) == set(sample_camera_data)

    for serial, original in sample_camera_data.items():
        loaded = loaded_data[serial]
        context = f"camera {serial}"

        np.testing.assert_array_equal(
            loaded["intrinsics"], original["intrinsics"], err_msg=context
        )
        assert tuple(loaded["resolution"]) == tuple(original["resolution"]), context
        np.testing.assert_allclose(
            loaded["pooled_depth"], original["pooled_depth"], err_msg=context
        )

        if "pooled_confidence" in original:
            np.testing.assert_array_equal(
                loaded["pooled_confidence"],
                original["pooled_confidence"],
                err_msg=context,
            )
        else:
            assert "pooled_confidence" not in loaded, context

        # Compare against None explicitly: an empty-but-present metadata dict
        # is falsy and must not be mistaken for "no metadata".
        if original["pool_metadata"] is not None:
            assert loaded["pool_metadata"] == original["pool_metadata"], context
        else:
            assert loaded["pool_metadata"] is None, context
def test_save_raw_frames(tmp_path, sample_camera_data):
    """Per-frame entries of camera ``12345678`` must survive a save/load roundtrip."""
    target = tmp_path / "test_depth.h5"
    save_depth_data(target, sample_camera_data)
    restored = load_depth_data(target)

    # In the fixture, this is the camera that carries raw frames.
    camera_id = "12345678"
    expected_frames = sample_camera_data[camera_id]["raw_frames"]
    actual_frames = restored[camera_id]["raw_frames"]

    assert len(actual_frames) == len(expected_frames)
    for expected, actual in zip(expected_frames, actual_frames):
        assert actual["frame_index"] == expected["frame_index"]
        assert actual["score"] == expected["score"]
        np.testing.assert_allclose(actual["depth_map"], expected["depth_map"])
        np.testing.assert_array_equal(
            actual["confidence_map"], expected["confidence_map"]
        )
def test_invalid_path_handling():
    """Saving into a directory that does not exist must fail with an OS error.

    ``OSError`` is used instead of bare ``Exception``. The broad form also
    passes when ``save_depth_data`` crashes for an unrelated reason
    (``TypeError``, ``KeyError``, ...), which would hide real bugs.
    ``FileNotFoundError`` and ``PermissionError`` are both ``OSError``
    subclasses, so either way of failing on a missing directory is accepted.

    NOTE(review): this relies on ``/nonexistent`` being absent and not
    creatable by the test process (e.g. not running as root if
    ``save_depth_data`` creates parent directories) — confirm for CI.
    """
    with pytest.raises(OSError):
        save_depth_data("/nonexistent/directory/file.h5", {})
def test_empty_data_handling(tmp_path):
    """An empty camera mapping can be saved and loads back with no cameras."""
    target = tmp_path / "empty.h5"

    save_depth_data(target, {})
    cameras = load_depth_data(target)

    assert not len(cameras)