from __future__ import annotations

import dataclasses
import tempfile
import unittest
from pathlib import Path

import click
from click.testing import CliRunner, Result

from scripts import zed_batch_segment_sources as segment_sources
from scripts.zed_batch_svo_grid_to_mp4 import main as grid_main
from scripts.zed_batch_svo_to_mcap import main as mcap_main

# Segment directory name used by every fixture in this module.
_SEGMENT_NAME = "2026-04-08T11-50-32"

# Trailing arguments shared by the CLI smoke tests: dry-run only, with a stand-in
# ZED binary. NOTE(review): assumes /bin/true exists on the test host.
_DRY_RUN_ARGS = ("--dry-run", "--zed-bin", "/bin/true")


@dataclasses.dataclass(slots=True, frozen=True)
class FakeScan:
    """Stand-in scan result returned by ``fake_scan``."""

    segment_dir: Path
    matched_files: int
    is_valid: bool
    reason: str | None = None


def fake_scan(segment_dir: Path) -> FakeScan:
    """Classify *segment_dir* from marker files instead of real camera files.

    - directory missing            -> invalid, reason "missing directory"
    - contains ``valid.segment``   -> valid, 2 matched files
    - contains ``partial.segment`` -> invalid, 1 matched file, reason "partial segment"
    - anything else                -> invalid, reason "no camera files"
    """
    if not segment_dir.is_dir():
        return FakeScan(segment_dir=segment_dir, matched_files=0, is_valid=False, reason="missing directory")
    if (segment_dir / "valid.segment").is_file():
        return FakeScan(segment_dir=segment_dir, matched_files=2, is_valid=True)
    if (segment_dir / "partial.segment").is_file():
        return FakeScan(segment_dir=segment_dir, matched_files=1, is_valid=False, reason="partial segment")
    return FakeScan(segment_dir=segment_dir, matched_files=0, is_valid=False, reason="no camera files")


def _no_matches_message(root: object) -> str:
    """Error text the tests expect when no segment is found under *root*."""
    return f"no segments under {root}"


def _resolve_with_fake_scan(*source_args: object):
    """Call ``segment_sources.resolve_sources`` with fake scanning.

    *source_args* are forwarded positionally, unchanged. The tests pass them in
    this order: dataset root, explicit segments, segments CSV, a fourth argument
    that is always ``None`` here, and the recursive flag.
    """
    return segment_sources.resolve_sources(
        *source_args,
        scan_segment_dir=fake_scan,
        no_matches_message=_no_matches_message,
    )


def _make_temp_dir(test: unittest.TestCase) -> Path:
    """Create a temporary directory that is removed when *test* finishes."""
    tmp = tempfile.TemporaryDirectory()
    test.addCleanup(tmp.cleanup)
    return Path(tmp.name)


def create_multicamera_segment(parent: Path, segment_name: str) -> Path:
    """Create ``parent/segment_name`` with four empty ``<segment_name>_zed{1..4}.svo2`` files.

    Returns the created segment directory.
    """
    segment_dir = parent / segment_name
    segment_dir.mkdir(parents=True)
    for camera_index in range(1, 5):
        (segment_dir / f"{segment_name}_zed{camera_index}.svo2").write_bytes(b"")
    return segment_dir


class SharedSourceResolutionTests(unittest.TestCase):
    """Tests for ``resolve_sources`` driven by the marker-file based ``fake_scan``."""

    def setUp(self) -> None:
        self.tmp_path = _make_temp_dir(self)

    @staticmethod
    def _make_valid_segment(segment_dir: Path) -> Path:
        """Create *segment_dir* (and its parents) containing a ``valid.segment`` marker."""
        segment_dir.mkdir(parents=True)
        (segment_dir / "valid.segment").write_text("", encoding="utf-8")
        return segment_dir

    def test_dataset_root_recursive_discovers_nested_segments(self) -> None:
        dataset_root = self.tmp_path / "dataset"
        segment_dir = self._make_valid_segment(dataset_root / "run" / _SEGMENT_NAME)

        sources = _resolve_with_fake_scan(dataset_root, (), None, None, True)

        self.assertEqual(sources.mode, "dataset-root")
        self.assertEqual(sources.segment_dirs, (segment_dir.resolve(),))

    def test_dataset_root_without_recursive_does_not_descend(self) -> None:
        dataset_root = self.tmp_path / "dataset"
        self._make_valid_segment(dataset_root / "run" / _SEGMENT_NAME)

        with self.assertRaises(click.ClickException) as error:
            _resolve_with_fake_scan(dataset_root, (), None, None, False)

        self.assertIn("no segments under", str(error.exception))

    def test_explicit_segments_are_deduped(self) -> None:
        segment_dir = self._make_valid_segment(self.tmp_path / _SEGMENT_NAME)

        sources = _resolve_with_fake_scan(None, (segment_dir, segment_dir), None, None, True)

        self.assertEqual(sources.mode, "segments")
        self.assertEqual(sources.segment_dirs, (segment_dir.resolve(),))

    def test_segments_csv_uses_segment_dir_column(self) -> None:
        segment_dir = self._make_valid_segment(self.tmp_path / "segments" / _SEGMENT_NAME)
        csv_path = self.tmp_path / "segments.csv"
        # The row is a relative path, and the working directory is not the temp
        # dir, so this exercises resolution relative to the CSV's location.
        csv_path.write_text(f"segment_dir\nsegments/{_SEGMENT_NAME}\n", encoding="utf-8")

        sources = _resolve_with_fake_scan(None, (), csv_path, None, True)

        self.assertEqual(sources.mode, "segments-csv")
        self.assertEqual(sources.segment_dirs, (segment_dir.resolve(),))

    def test_segment_path_like_dataset_root_has_hint(self) -> None:
        dataset_root = self.tmp_path / "dataset"
        self._make_valid_segment(dataset_root / "run" / _SEGMENT_NAME)

        with self.assertRaises(click.ClickException) as error:
            _resolve_with_fake_scan(None, (dataset_root,), None, None, True)

        message = str(error.exception)
        self.assertIn("looks like a dataset root", message)
        self.assertIn("--dataset-root", message)


class BatchCliSmokeTests(unittest.TestCase):
    """End-to-end dry-run checks of the batch CLIs' source-selection flags."""

    def setUp(self) -> None:
        self.runner = CliRunner()
        self.tmp_path = _make_temp_dir(self)

    def _invoke(self, command: click.Command, *args: str) -> Result:
        """Run *command* with *args* followed by the shared dry-run arguments."""
        return self.runner.invoke(command, [*args, *_DRY_RUN_ARGS])

    def _assert_succeeded(self, result: Result) -> None:
        """Assert exit code 0.

        CliRunner swallows unexpected exceptions, so the failure message also
        includes ``result.exception``; otherwise a crash would show up only as
        exit code 1 with no traceback.
        """
        self.assertEqual(result.exit_code, 0, f"{result.output}\nexception: {result.exception!r}")

    def test_mcap_dataset_root_flag_discovers_segments(self) -> None:
        dataset_root = self.tmp_path / "dataset"
        create_multicamera_segment(dataset_root / "run", _SEGMENT_NAME)

        result = self._invoke(mcap_main, "--dataset-root", str(dataset_root), "--recursive")

        self._assert_succeeded(result)
        self.assertIn("source=dataset-root matched=1 pending=1", result.output)

    def test_mcap_segment_flag_rejects_dataset_root_with_hint(self) -> None:
        dataset_root = self.tmp_path / "dataset"
        create_multicamera_segment(dataset_root / "run", _SEGMENT_NAME)

        result = self._invoke(mcap_main, "--segment", str(dataset_root))

        self.assertNotEqual(result.exit_code, 0)
        self.assertIn("looks like a dataset root", result.output)
        self.assertIn("--dataset-root", result.output)

    def test_mcap_rejects_legacy_positional_dataset_root(self) -> None:
        dataset_root = self.tmp_path / "dataset"
        create_multicamera_segment(dataset_root / "run", _SEGMENT_NAME)

        result = self._invoke(mcap_main, str(dataset_root))

        self.assertNotEqual(result.exit_code, 0)
        self.assertIn("positional dataset paths are no longer supported", result.output)
        self.assertIn("--dataset-root", result.output)

    def test_mcap_rejects_recursive_without_dataset_root(self) -> None:
        segment_dir = create_multicamera_segment(self.tmp_path, _SEGMENT_NAME)

        result = self._invoke(mcap_main, "--segment", str(segment_dir), "--no-recursive")

        self.assertNotEqual(result.exit_code, 0)
        self.assertIn("--recursive/--no-recursive can only be used with --dataset-root", result.output)

    def test_grid_segment_flag_discovers_one_segment(self) -> None:
        segment_dir = create_multicamera_segment(self.tmp_path, _SEGMENT_NAME)

        result = self._invoke(grid_main, "--segment", str(segment_dir))

        self._assert_succeeded(result)
        self.assertIn("source=segments matched=1 pending=1", result.output)

    def test_grid_rejects_legacy_segment_dir_flag(self) -> None:
        segment_dir = create_multicamera_segment(self.tmp_path, _SEGMENT_NAME)

        result = self._invoke(grid_main, "--segment-dir", str(segment_dir))

        self.assertNotEqual(result.exit_code, 0)
        self.assertIn("--segment-dir is no longer supported", result.output)
        # "--segment-dir" itself contains "--segment". Strip it first, otherwise
        # this check passes even when the replacement flag is never suggested.
        self.assertIn("--segment", result.output.replace("--segment-dir", ""))


if __name__ == "__main__":
    unittest.main()