from __future__ import annotations

import csv
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Generic, Protocol, TypeVar

import click
from click.core import ParameterSource


class SegmentScanLike(Protocol):
    """Structural interface for the result of scanning one segment directory."""

    # Directory that was scanned.
    segment_dir: Path
    # Number of recording files matched inside the directory.
    matched_files: int
    # True when the directory is a complete, usable segment.
    is_valid: bool


ScanT = TypeVar("ScanT", bound=SegmentScanLike)


@dataclass(slots=True, frozen=True)
class SourceResolution(Generic[ScanT]):
    """Outcome of resolving CLI source options into concrete segment dirs."""

    # One of: "dataset-root", "segments", "segments-csv".
    mode: str
    segment_dirs: tuple[Path, ...]
    # Scans of directories that contained some files but were not valid
    # segments; only populated in dataset-root discovery mode.
    ignored_partial_dirs: tuple[ScanT, ...]


def dedupe_paths(paths: list[Path]) -> list[Path]:
    """Expand, resolve, and de-duplicate *paths*, preserving first-seen order."""
    ordered: list[Path] = []
    seen: set[Path] = set()
    for path in paths:
        resolved = path.expanduser().resolve()
        if resolved in seen:
            continue
        seen.add(resolved)
        ordered.append(resolved)
    return ordered


def parse_segments_csv(csv_path: Path, csv_root: Path | None) -> tuple[Path, ...]:
    """Read segment directories from a CSV file with a ``segment_dir`` column.

    Relative entries are resolved against *csv_root* (or the CSV file's own
    parent directory when *csv_root* is None).  Duplicate directories are
    dropped while preserving row order.

    Raises:
        click.ClickException: if the CSV is missing, *csv_root* is not a
            directory, the ``segment_dir`` header is absent, a row has an
            empty value, or the CSV yields no rows at all.
    """
    csv_path = csv_path.expanduser().resolve()
    if not csv_path.is_file():
        raise click.ClickException(f"CSV not found: {csv_path}")
    if csv_root is not None:
        base_dir = csv_root.expanduser().resolve()
        if not base_dir.is_dir():
            raise click.ClickException(f"CSV root is not a directory: {base_dir}")
    else:
        base_dir = csv_path.parent
    segment_dirs: list[Path] = []
    seen: set[Path] = set()
    with csv_path.open(newline="") as stream:
        reader = csv.DictReader(stream)
        if reader.fieldnames is None or "segment_dir" not in reader.fieldnames:
            raise click.ClickException(f"{csv_path} must contain a 'segment_dir' header")
        # start=2: row 1 is the header, so the first data row is file line 2.
        for row_number, row in enumerate(reader, start=2):
            raw_segment_dir = (row.get("segment_dir") or "").strip()
            if not raw_segment_dir:
                raise click.ClickException(f"{csv_path}:{row_number} has an empty segment_dir value")
            # Expand '~' BEFORE the absolute-path test: a trailing
            # .expanduser() cannot expand a tilde that ends up in the middle
            # of base_dir / "~/...", which silently produced a wrong path.
            segment_dir = Path(raw_segment_dir).expanduser()
            resolved = segment_dir if segment_dir.is_absolute() else base_dir / segment_dir
            resolved = resolved.resolve()
            if resolved in seen:
                continue
            seen.add(resolved)
            segment_dirs.append(resolved)
    if not segment_dirs:
        raise click.ClickException(f"{csv_path} did not contain any segment_dir rows")
    return tuple(segment_dirs)


def discover_segment_dirs(
    root: Path,
    recursive: bool,
    *,
    scan_segment_dir: Callable[[Path], ScanT],
    no_matches_message: Callable[[Path], str],
) -> SourceResolution[ScanT]:
    """Walk *root* and collect every directory that scans as a valid segment.

    With ``recursive=True`` the whole tree is searched; otherwise only the
    immediate children.  Directories that matched some files without being
    valid are reported via ``ignored_partial_dirs`` rather than processed.

    Raises:
        click.ClickException: if *root* is not a directory or no valid
            segment directory is found (message built by *no_matches_message*).
    """
    resolved_root = root.expanduser().resolve()
    if not resolved_root.is_dir():
        raise click.ClickException(f"dataset root does not exist: {resolved_root}")
    # The root itself may be a segment directory, so it is always a candidate.
    candidate_dirs = {resolved_root}
    iterator = resolved_root.rglob("*") if recursive else resolved_root.iterdir()
    for path in iterator:
        if path.is_dir():
            candidate_dirs.add(path.resolve())
    valid_dirs: list[Path] = []
    ignored_partial_dirs: list[ScanT] = []
    # Sorted for a deterministic processing order independent of the walk.
    for segment_dir in sorted(candidate_dirs):
        scan = scan_segment_dir(segment_dir)
        if scan.is_valid:
            valid_dirs.append(segment_dir)
        elif scan.matched_files > 0:
            ignored_partial_dirs.append(scan)
    if not valid_dirs:
        raise click.ClickException(no_matches_message(resolved_root))
    return SourceResolution(
        mode="dataset-root",
        segment_dirs=tuple(valid_dirs),
        ignored_partial_dirs=tuple(ignored_partial_dirs),
    )


def raise_if_recursive_flag_is_incompatible(
    ctx: click.Context,
    dataset_root: Path | None,
    *,
    dataset_root_flag: str = "--dataset-root",
) -> None:
    """Reject an explicit --recursive/--no-recursive without a dataset root.

    The flag is only meaningful in dataset-root discovery mode; when the user
    left it at its default it is silently ignored.
    """
    if ctx.get_parameter_source("recursive") is ParameterSource.DEFAULT:
        return
    if dataset_root is None:
        raise click.ClickException(f"--recursive/--no-recursive can only be used with {dataset_root_flag}")


def raise_for_legacy_source_args(
    legacy_input_dir: Path | None,
    legacy_segment_dirs: tuple[Path, ...],
    *,
    dataset_root_flag: str = "--dataset-root",
    segment_flag: str = "--segment",
) -> None:
    """Fail fast on removed legacy CLI source arguments.

    The error messages include the resolved path so the user can copy-paste
    the suggested replacement invocation.
    """
    if legacy_input_dir is not None:
        resolved = legacy_input_dir.expanduser().resolve()
        raise click.ClickException(
            f"positional dataset paths are no longer supported; use {dataset_root_flag} {resolved}"
        )
    if legacy_segment_dirs:
        resolved = legacy_segment_dirs[0].expanduser().resolve()
        raise click.ClickException(
            f"--segment-dir is no longer supported in this batch wrapper; use {segment_flag} {resolved} "
            f"for an explicit segment directory, or {dataset_root_flag} --recursive for discovery"
        )


def raise_for_legacy_extra_args(
    extra_args: list[str],
    *,
    dataset_root_flag: str = "--dataset-root",
) -> None:
    """Reject leftover positional arguments from the legacy CLI.

    A leading path-like argument gets a targeted "use --dataset-root" hint;
    anything starting with '-' is reported as unexpected options.
    """
    if not extra_args:
        return
    first = extra_args[0]
    if first.startswith("-"):
        extras_text = " ".join(extra_args)
        raise click.ClickException(f"unexpected extra arguments: {extras_text}")
    resolved = Path(first).expanduser().resolve()
    raise click.ClickException(
        f"positional dataset paths are no longer supported; use {dataset_root_flag} {resolved}"
    )


def raise_if_segment_path_looks_like_dataset_root(
    segment_dir: Path,
    *,
    scan_segment_dir: Callable[[Path], ScanT],
    dataset_root_flag: str = "--dataset-root",
    segment_flag: str = "--segment",
) -> None:
    """Detect a dataset root mistakenly passed as a segment directory.

    A directory with no matched files of its own but with valid segment
    directories nested below it is almost certainly a dataset root; raise a
    ClickException that names one nested segment as an example.  Anything
    else (missing dir, valid segment, partial segment) is left for later
    validation to handle.
    """
    resolved = segment_dir.expanduser().resolve()
    if not resolved.is_dir():
        return
    scan = scan_segment_dir(resolved)
    if scan.is_valid or scan.matched_files > 0:
        return
    nested_segments = _find_nested_valid_segment_dirs(resolved, scan_segment_dir=scan_segment_dir)
    if not nested_segments:
        return
    example = nested_segments[0]
    raise click.ClickException(
        f"{resolved} looks like a dataset root, not a segment directory. "
        f"{segment_flag} expects a directory that directly contains *_zedN.svo or *_zedN.svo2 files. "
        f"Use {dataset_root_flag} {resolved} to discover nested segments such as {example}"
    )


def resolve_sources(
    dataset_root: Path | None,
    segment_dirs: tuple[Path, ...],
    segments_csv: Path | None,
    csv_root: Path | None,
    recursive: bool,
    *,
    scan_segment_dir: Callable[[Path], ScanT],
    no_matches_message: Callable[[Path], str],
) -> SourceResolution[ScanT]:
    """Resolve exactly one of the three source modes into segment directories.

    Modes: dataset-root discovery, explicit --segment directories, or a
    segments CSV.  Supplying zero or more than one mode is an error.

    Raises:
        click.ClickException: on mode-count violations or any error raised
            by the selected mode's resolution path.
    """
    source_count = sum(
        (
            1 if dataset_root is not None else 0,
            1 if segment_dirs else 0,
            1 if segments_csv is not None else 0,
        )
    )
    if source_count != 1:
        raise click.ClickException(
            "provide exactly one source mode: --dataset-root, --segment, or --segments-csv"
        )
    if dataset_root is not None:
        return discover_segment_dirs(
            dataset_root,
            recursive,
            scan_segment_dir=scan_segment_dir,
            no_matches_message=no_matches_message,
        )
    if segment_dirs:
        ordered_dirs = dedupe_paths(list(segment_dirs))
        for segment_dir in ordered_dirs:
            raise_if_segment_path_looks_like_dataset_root(
                segment_dir,
                scan_segment_dir=scan_segment_dir,
            )
        return SourceResolution(mode="segments", segment_dirs=tuple(ordered_dirs), ignored_partial_dirs=())
    # The source_count check above guarantees segments_csv is not None here.
    return SourceResolution(
        mode="segments-csv",
        segment_dirs=parse_segments_csv(segments_csv, csv_root),
        ignored_partial_dirs=(),
    )


def _find_nested_valid_segment_dirs(
    root: Path,
    *,
    scan_segment_dir: Callable[[Path], ScanT],
    limit: int = 3,
) -> tuple[Path, ...]:
    """Return up to *limit* valid segment directories nested under *root*.

    *root* itself is excluded; the walk stops as soon as *limit* matches are
    found, so this stays cheap on large trees.
    """
    matches: list[Path] = []
    for path in sorted(root.rglob("*")):
        if not path.is_dir():
            continue
        resolved = path.resolve()
        if resolved == root:
            continue
        scan = scan_segment_dir(resolved)
        if scan.is_valid:
            matches.append(resolved)
            if len(matches) >= limit:
                break
    return tuple(matches)