Files
cvmmap-streamer/scripts/zed_batch_segment_sources.py
T

256 lines
8.0 KiB
Python

from __future__ import annotations

import csv
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Generic, Protocol, TypeVar

import click
from click.core import ParameterSource
class SegmentScanLike(Protocol):
    """Structural type for the result returned by a ``scan_segment_dir`` callback.

    Members are declared as read-only properties because this module only
    reads them. That way frozen dataclasses, NamedTuples and plain classes
    with ordinary attributes all satisfy the protocol under static type
    checking (a settable-attribute protocol member rejects read-only
    implementations).
    """

    @property
    def segment_dir(self) -> Path:
        """Directory that was scanned."""
        ...

    @property
    def matched_files(self) -> int:
        """Count of matching files the scanner found (compared against 0 here)."""
        ...

    @property
    def is_valid(self) -> bool:
        """Whether the scanner considers the directory a usable segment."""
        ...


# Concrete scan-result type produced by the caller-supplied scanner; it is
# carried through unchanged into SourceResolution.ignored_partial_dirs.
ScanT = TypeVar("ScanT", bound=SegmentScanLike)
@dataclass(frozen=True, slots=True)
class SourceResolution(Generic[ScanT]):
    """Immutable outcome of resolving the user's segment-source options."""

    # Source mode that produced this result; the resolvers in this module
    # use "dataset-root", "segments" and "segments-csv".
    mode: str
    # Resolved, de-duplicated segment directories.
    segment_dirs: tuple[Path, ...]
    # Scan results for directories that matched some files but were not
    # valid segments; only filled by dataset-root discovery, empty otherwise.
    ignored_partial_dirs: tuple[ScanT, ...]
def dedupe_paths(paths: list[Path]) -> list[Path]:
    """Resolve every path and drop repeats, keeping first-occurrence order.

    Each entry goes through ``expanduser()`` and ``resolve()`` before the
    comparison, so different spellings of one location collapse to a single
    resolved path.
    """
    # dicts keep insertion order, so fromkeys() retains the first occurrence.
    unique = dict.fromkeys(candidate.expanduser().resolve() for candidate in paths)
    return list(unique)
def parse_segments_csv(csv_path: Path, csv_root: Path | None) -> tuple[Path, ...]:
    """Read segment directories from the ``segment_dir`` column of a CSV file.

    A leading ``~`` in an entry is expanded first. Entries that are still
    relative are resolved against ``csv_root`` when given, otherwise against
    the CSV file's own directory. Directories that resolve to the same path
    are kept once, at their first occurrence.

    Args:
        csv_path: CSV file; its header must contain ``segment_dir``.
        csv_root: Optional base directory for relative ``segment_dir`` values.

    Returns:
        The resolved, de-duplicated segment directories in file order.

    Raises:
        click.ClickException: If the CSV file or ``csv_root`` is missing, the
            header lacks ``segment_dir``, a row has an empty ``segment_dir``,
            or the file has no data rows.
    """
    csv_path = csv_path.expanduser().resolve()
    if not csv_path.is_file():
        raise click.ClickException(f"CSV not found: {csv_path}")
    if csv_root is not None:
        base_dir = csv_root.expanduser().resolve()
        if not base_dir.is_dir():
            raise click.ClickException(f"CSV root is not a directory: {base_dir}")
    else:
        base_dir = csv_path.parent

    segment_dirs: list[Path] = []
    seen: set[Path] = set()
    # utf-8-sig decodes independently of the locale and strips the BOM that
    # spreadsheet tools often prepend; with a BOM the first header would read
    # "\ufeffsegment_dir" and the header check below would fail.
    with csv_path.open(newline="", encoding="utf-8-sig") as stream:
        reader = csv.DictReader(stream)
        if reader.fieldnames is None or "segment_dir" not in reader.fieldnames:
            raise click.ClickException(f"{csv_path} must contain a 'segment_dir' header")
        for row in reader:
            raw_segment_dir = (row.get("segment_dir") or "").strip()
            if not raw_segment_dir:
                # reader.line_num stays accurate across skipped blank lines and
                # quoted multi-line fields, unlike a simple row counter.
                raise click.ClickException(
                    f"{csv_path}:{reader.line_num} has an empty segment_dir value"
                )
            # Expand "~" before the absolute check: expanding after joining with
            # base_dir would leave "base_dir/~/..." untouched. os.path.expanduser
            # returns unknown "~user" prefixes unchanged instead of raising.
            segment_dir = Path(os.path.expanduser(raw_segment_dir))
            if not segment_dir.is_absolute():
                segment_dir = base_dir / segment_dir
            resolved = segment_dir.resolve()
            if resolved in seen:
                continue
            seen.add(resolved)
            segment_dirs.append(resolved)
    if not segment_dirs:
        raise click.ClickException(f"{csv_path} did not contain any segment_dir rows")
    return tuple(segment_dirs)
def discover_segment_dirs(
    root: Path,
    recursive: bool,
    *,
    scan_segment_dir: Callable[[Path], ScanT],
    no_matches_message: Callable[[Path], str],
) -> SourceResolution[ScanT]:
    """Find segment directories under a dataset root.

    The root itself and its subdirectories are each passed to
    ``scan_segment_dir`` in sorted order. Only immediate children are
    considered unless ``recursive`` is true, in which case the whole tree is.
    Directories whose scan is valid become ``segment_dirs``. Invalid ones that
    still matched at least one file are reported in ``ignored_partial_dirs``.

    Raises:
        click.ClickException: If ``root`` is not a directory, or if no valid
            segment directory was found (text from ``no_matches_message``).
    """
    dataset_root = root.expanduser().resolve()
    if not dataset_root.is_dir():
        raise click.ClickException(f"dataset root does not exist: {dataset_root}")

    entries = dataset_root.rglob("*") if recursive else dataset_root.iterdir()
    # A set collapses entries (e.g. symlinks) that resolve to the same directory.
    candidates = {dataset_root} | {entry.resolve() for entry in entries if entry.is_dir()}

    found: list[Path] = []
    partial_scans: list[ScanT] = []
    for candidate in sorted(candidates):
        result = scan_segment_dir(candidate)
        if result.is_valid:
            found.append(candidate)
        elif result.matched_files > 0:
            partial_scans.append(result)

    if not found:
        raise click.ClickException(no_matches_message(dataset_root))

    return SourceResolution(
        mode="dataset-root",
        segment_dirs=tuple(found),
        ignored_partial_dirs=tuple(partial_scans),
    )
def raise_if_recursive_flag_is_incompatible(
    ctx: click.Context,
    dataset_root: Path | None,
    *,
    dataset_root_flag: str = "--dataset-root",
) -> None:
    """Reject an explicitly given ``--recursive/--no-recursive`` without a dataset root.

    The flag counts as given whenever click reports a parameter source for
    ``recursive`` other than ``ParameterSource.DEFAULT``.

    Raises:
        click.ClickException: If the flag was given and ``dataset_root`` is None.
    """
    # NOTE(review): a value coming from ctx.default_map (ParameterSource.DEFAULT_MAP)
    # also counts as "given" here — confirm that is intended.
    flag_was_given = ctx.get_parameter_source("recursive") is not ParameterSource.DEFAULT
    if flag_was_given and dataset_root is None:
        raise click.ClickException(f"--recursive/--no-recursive can only be used with {dataset_root_flag}")
def raise_for_legacy_source_args(
    legacy_input_dir: Path | None,
    legacy_segment_dirs: tuple[Path, ...],
    *,
    dataset_root_flag: str = "--dataset-root",
    segment_flag: str = "--segment",
) -> None:
    """Fail with a migration hint when a removed source argument was used.

    A positional dataset path takes precedence. Otherwise the first legacy
    ``--segment-dir`` value is quoted in the hint. Returns normally when
    neither legacy argument was supplied.

    Raises:
        click.ClickException: If either legacy argument is present.
    """
    if legacy_input_dir is not None:
        resolved = legacy_input_dir.expanduser().resolve()
        raise click.ClickException(
            f"positional dataset paths are no longer supported; use {dataset_root_flag} {resolved}"
        )
    if not legacy_segment_dirs:
        return
    # Only the first directory is echoed back; one example is enough for the hint.
    resolved = legacy_segment_dirs[0].expanduser().resolve()
    raise click.ClickException(
        f"--segment-dir is no longer supported in this batch wrapper; use {segment_flag} {resolved} "
        f"for an explicit segment directory, or {dataset_root_flag} <DATASET_ROOT> --recursive for discovery"
    )
def raise_for_legacy_extra_args(
    extra_args: list[str],
    *,
    dataset_root_flag: str = "--dataset-root",
) -> None:
    """Reject leftover command-line arguments with a helpful message.

    If the first leftover argument does not start with ``-``, it is assumed to
    be an old-style positional dataset path and the user is pointed to
    ``dataset_root_flag``. Otherwise all leftovers are listed as unexpected.
    Returns normally when ``extra_args`` is empty.

    Raises:
        click.ClickException: If ``extra_args`` is non-empty.
    """
    if extra_args:
        first = extra_args[0]
        if not first.startswith("-"):
            resolved = Path(first).expanduser().resolve()
            raise click.ClickException(
                f"positional dataset paths are no longer supported; use {dataset_root_flag} {resolved}"
            )
        extras_text = " ".join(extra_args)
        raise click.ClickException(f"unexpected extra arguments: {extras_text}")
def raise_if_segment_path_looks_like_dataset_root(
    segment_dir: Path,
    *,
    scan_segment_dir: Callable[[Path], ScanT],
    dataset_root_flag: str = "--dataset-root",
    segment_flag: str = "--segment",
) -> None:
    """Reject a ``--segment`` argument that is really a dataset root.

    The error is raised only when all of the following hold:

    - the path is an existing directory;
    - its own scan is neither valid nor has any matched files;
    - at least one nested directory scans as valid.

    Every other case returns silently and is left to later stages.

    Raises:
        click.ClickException: If the directory looks like a dataset root.
    """
    resolved = segment_dir.expanduser().resolve()
    if not resolved.is_dir():
        return

    own_scan = scan_segment_dir(resolved)
    holds_segment_files = own_scan.is_valid or own_scan.matched_files > 0
    if holds_segment_files:
        return

    nested = _find_nested_valid_segment_dirs(resolved, scan_segment_dir=scan_segment_dir)
    if nested:
        example = nested[0]
        raise click.ClickException(
            f"{resolved} looks like a dataset root, not a segment directory. "
            f"{segment_flag} expects a directory that directly contains *_zedN.svo or *_zedN.svo2 files. "
            f"Use {dataset_root_flag} {resolved} to discover nested segments such as {example}"
        )
def resolve_sources(
    dataset_root: Path | None,
    segment_dirs: tuple[Path, ...],
    segments_csv: Path | None,
    csv_root: Path | None,
    recursive: bool,
    *,
    scan_segment_dir: Callable[[Path], ScanT],
    no_matches_message: Callable[[Path], str],
) -> SourceResolution[ScanT]:
    """Resolve exactly one segment source option into a SourceResolution.

    Args:
        dataset_root: Root to discover segments under (``--dataset-root``).
        segment_dirs: Explicit segment directories (``--segment``); an empty
            tuple means the option was not used.
        segments_csv: CSV file listing segment directories (``--segments-csv``).
        csv_root: Base directory for relative CSV entries; only consulted in
            CSV mode.
        recursive: Whether dataset-root discovery walks the whole tree.
        scan_segment_dir: Callback that classifies a directory.
        no_matches_message: Builds the error text when discovery finds nothing.

    Raises:
        click.ClickException: If zero or several source options are given, or
            if the selected resolver rejects its input.
    """
    provided = (
        dataset_root is not None,
        bool(segment_dirs),
        segments_csv is not None,
    )
    if provided.count(True) != 1:
        raise click.ClickException(
            "provide exactly one source mode: --dataset-root, --segment, or --segments-csv"
        )

    if dataset_root is not None:
        return discover_segment_dirs(
            dataset_root,
            recursive,
            scan_segment_dir=scan_segment_dir,
            no_matches_message=no_matches_message,
        )

    # Exactly one option is set at this point, so testing the CSV option before
    # the explicit segments does not change which branch runs.
    if segments_csv is not None:
        return SourceResolution(
            mode="segments-csv",
            segment_dirs=parse_segments_csv(segments_csv, csv_root),
            ignored_partial_dirs=(),
        )

    unique_dirs = dedupe_paths(list(segment_dirs))
    for candidate in unique_dirs:
        # Catch the common mistake of passing a dataset root to --segment.
        raise_if_segment_path_looks_like_dataset_root(candidate, scan_segment_dir=scan_segment_dir)
    return SourceResolution(mode="segments", segment_dirs=tuple(unique_dirs), ignored_partial_dirs=())
def _find_nested_valid_segment_dirs(
root: Path,
*,
scan_segment_dir: Callable[[Path], ScanT],
limit: int = 3,
) -> tuple[Path, ...]:
matches: list[Path] = []
for path in sorted(root.rglob("*")):
if not path.is_dir():
continue
resolved = path.resolve()
if resolved == root:
continue
scan = scan_segment_dir(resolved)
if scan.is_valid:
matches.append(resolved)
if len(matches) >= limit:
break
return tuple(matches)