256 lines
8.0 KiB
Python
256 lines
8.0 KiB
Python
from __future__ import annotations
|
|
|
|
import csv
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Callable, Generic, Protocol, TypeVar
|
|
|
|
import click
|
|
from click.core import ParameterSource
|
|
|
|
|
|
class SegmentScanLike(Protocol):
    """Structural interface for the result of scanning one segment directory.

    Concrete scan types are produced by caller-supplied ``scan_segment_dir``
    callbacks; any object exposing these attributes satisfies the protocol.
    """

    # The directory that was scanned.
    segment_dir: Path
    # Number of files in the directory that matched the scan's criteria;
    # a value > 0 on an invalid scan marks the directory as "partial".
    matched_files: int
    # True when the directory is a complete, usable segment.
    is_valid: bool
|
|
|
|
|
|
# Type variable for the caller-supplied scan result type; bounded so the
# generic helpers below may read ``is_valid`` / ``matched_files`` on it.
ScanT = TypeVar("ScanT", bound=SegmentScanLike)
|
|
|
|
|
|
@dataclass(slots=True, frozen=True)
class SourceResolution(Generic[ScanT]):
    """Outcome of resolving CLI source options into concrete segment dirs."""

    # Which source mode produced the result: "dataset-root", "segments",
    # or "segments-csv".
    mode: str
    # Resolved, deduplicated segment directories to process.
    segment_dirs: tuple[Path, ...]
    # Scans of directories that matched some files but were not valid
    # segments; only populated by dataset-root discovery.
    ignored_partial_dirs: tuple[ScanT, ...]
|
|
|
|
|
|
def dedupe_paths(paths: list[Path]) -> list[Path]:
    """Expand and resolve *paths*, dropping duplicates.

    First-seen order is preserved; later occurrences of a path that resolves
    to the same location are discarded.
    """
    normalized = (candidate.expanduser().resolve() for candidate in paths)
    # dict.fromkeys keeps insertion order, giving an ordered dedupe in one pass.
    return list(dict.fromkeys(normalized))
|
|
|
|
|
|
def parse_segments_csv(csv_path: Path, csv_root: Path | None) -> tuple[Path, ...]:
|
|
csv_path = csv_path.expanduser().resolve()
|
|
if not csv_path.is_file():
|
|
raise click.ClickException(f"CSV not found: {csv_path}")
|
|
|
|
if csv_root is not None:
|
|
base_dir = csv_root.expanduser().resolve()
|
|
if not base_dir.is_dir():
|
|
raise click.ClickException(f"CSV root is not a directory: {base_dir}")
|
|
else:
|
|
base_dir = csv_path.parent
|
|
|
|
segment_dirs: list[Path] = []
|
|
seen: set[Path] = set()
|
|
with csv_path.open(newline="") as stream:
|
|
reader = csv.DictReader(stream)
|
|
if reader.fieldnames is None or "segment_dir" not in reader.fieldnames:
|
|
raise click.ClickException(f"{csv_path} must contain a 'segment_dir' header")
|
|
|
|
for row_number, row in enumerate(reader, start=2):
|
|
raw_segment_dir = (row.get("segment_dir") or "").strip()
|
|
if not raw_segment_dir:
|
|
raise click.ClickException(f"{csv_path}:{row_number} has an empty segment_dir value")
|
|
segment_dir = Path(raw_segment_dir)
|
|
resolved = segment_dir if segment_dir.is_absolute() else base_dir / segment_dir
|
|
resolved = resolved.expanduser().resolve()
|
|
if resolved in seen:
|
|
continue
|
|
seen.add(resolved)
|
|
segment_dirs.append(resolved)
|
|
|
|
if not segment_dirs:
|
|
raise click.ClickException(f"{csv_path} did not contain any segment_dir rows")
|
|
return tuple(segment_dirs)
|
|
|
|
|
|
def discover_segment_dirs(
    root: Path,
    recursive: bool,
    *,
    scan_segment_dir: Callable[[Path], ScanT],
    no_matches_message: Callable[[Path], str],
) -> SourceResolution[ScanT]:
    """Discover valid segment directories under *root*.

    The root itself plus its child directories (all descendants when
    *recursive* is true) are scanned with *scan_segment_dir*.  Valid scans
    become segment dirs; invalid scans that still matched files are kept as
    ignored partial dirs for reporting.

    Raises:
        click.ClickException: if *root* is not a directory, or with the text
            from *no_matches_message* when no valid segment dirs are found.
    """
    dataset_root = root.expanduser().resolve()
    if not dataset_root.is_dir():
        raise click.ClickException(f"dataset root does not exist: {dataset_root}")

    walker = dataset_root.rglob("*") if recursive else dataset_root.iterdir()
    # The root itself is always a candidate alongside discovered subdirs.
    candidates = {dataset_root}
    candidates.update(entry.resolve() for entry in walker if entry.is_dir())

    accepted: list[Path] = []
    partial_scans: list[ScanT] = []
    for candidate in sorted(candidates):
        scan = scan_segment_dir(candidate)
        if scan.is_valid:
            accepted.append(candidate)
        elif scan.matched_files > 0:
            partial_scans.append(scan)

    if not accepted:
        raise click.ClickException(no_matches_message(dataset_root))

    return SourceResolution(
        mode="dataset-root",
        segment_dirs=tuple(accepted),
        ignored_partial_dirs=tuple(partial_scans),
    )
|
|
|
|
|
|
def raise_if_recursive_flag_is_incompatible(
    ctx: click.Context,
    dataset_root: Path | None,
    *,
    dataset_root_flag: str = "--dataset-root",
) -> None:
    """Reject an explicit --recursive/--no-recursive without a dataset root.

    A recursive flag left at its click default is always fine; one set
    explicitly is only meaningful alongside *dataset_root_flag*.
    """
    set_explicitly = ctx.get_parameter_source("recursive") is not ParameterSource.DEFAULT
    if set_explicitly and dataset_root is None:
        raise click.ClickException(f"--recursive/--no-recursive can only be used with {dataset_root_flag}")
|
|
|
|
|
|
def raise_for_legacy_source_args(
|
|
legacy_input_dir: Path | None,
|
|
legacy_segment_dirs: tuple[Path, ...],
|
|
*,
|
|
dataset_root_flag: str = "--dataset-root",
|
|
segment_flag: str = "--segment",
|
|
) -> None:
|
|
if legacy_input_dir is not None:
|
|
resolved = legacy_input_dir.expanduser().resolve()
|
|
raise click.ClickException(
|
|
f"positional dataset paths are no longer supported; use {dataset_root_flag} {resolved}"
|
|
)
|
|
|
|
if legacy_segment_dirs:
|
|
resolved = legacy_segment_dirs[0].expanduser().resolve()
|
|
raise click.ClickException(
|
|
f"--segment-dir is no longer supported in this batch wrapper; use {segment_flag} {resolved} "
|
|
f"for an explicit segment directory, or {dataset_root_flag} <DATASET_ROOT> --recursive for discovery"
|
|
)
|
|
|
|
|
|
def raise_for_legacy_extra_args(
    extra_args: list[str],
    *,
    dataset_root_flag: str = "--dataset-root",
) -> None:
    """Reject leftover positional arguments with a migration hint.

    No-op when *extra_args* is empty.  A leading option-like token gets a
    generic rejection; otherwise the first token is treated as a legacy
    dataset path and the message suggests the replacement flag.
    """
    if not extra_args:
        return

    head = extra_args[0]
    if head.startswith("-"):
        extras_text = " ".join(extra_args)
        raise click.ClickException(f"unexpected extra arguments: {extras_text}")

    hint_path = Path(head).expanduser().resolve()
    raise click.ClickException(
        f"positional dataset paths are no longer supported; use {dataset_root_flag} {hint_path}"
    )
|
|
|
|
|
|
def raise_if_segment_path_looks_like_dataset_root(
|
|
segment_dir: Path,
|
|
*,
|
|
scan_segment_dir: Callable[[Path], ScanT],
|
|
dataset_root_flag: str = "--dataset-root",
|
|
segment_flag: str = "--segment",
|
|
) -> None:
|
|
resolved = segment_dir.expanduser().resolve()
|
|
if not resolved.is_dir():
|
|
return
|
|
|
|
scan = scan_segment_dir(resolved)
|
|
if scan.is_valid or scan.matched_files > 0:
|
|
return
|
|
|
|
nested_segments = _find_nested_valid_segment_dirs(resolved, scan_segment_dir=scan_segment_dir)
|
|
if not nested_segments:
|
|
return
|
|
|
|
example = nested_segments[0]
|
|
raise click.ClickException(
|
|
f"{resolved} looks like a dataset root, not a segment directory. "
|
|
f"{segment_flag} expects a directory that directly contains *_zedN.svo or *_zedN.svo2 files. "
|
|
f"Use {dataset_root_flag} {resolved} to discover nested segments such as {example}"
|
|
)
|
|
|
|
|
|
def resolve_sources(
    dataset_root: Path | None,
    segment_dirs: tuple[Path, ...],
    segments_csv: Path | None,
    csv_root: Path | None,
    recursive: bool,
    *,
    scan_segment_dir: Callable[[Path], ScanT],
    no_matches_message: Callable[[Path], str],
) -> SourceResolution[ScanT]:
    """Resolve exactly one of the three CLI source modes into segment dirs.

    Exactly one of *dataset_root*, *segment_dirs*, or *segments_csv* must be
    supplied; otherwise a click error is raised.  The three modes map to
    discovery, explicit directories, and a CSV manifest respectively.
    """
    chosen_modes = sum(
        (
            dataset_root is not None,
            bool(segment_dirs),
            segments_csv is not None,
        )
    )
    if chosen_modes != 1:
        raise click.ClickException(
            "provide exactly one source mode: --dataset-root, --segment, or --segments-csv"
        )

    if dataset_root is not None:
        # Discovery mode: walk the dataset root for valid segment dirs.
        return discover_segment_dirs(
            dataset_root,
            recursive,
            scan_segment_dir=scan_segment_dir,
            no_matches_message=no_matches_message,
        )

    if segment_dirs:
        # Explicit mode: normalize, dedupe, and reject paths that look like
        # dataset roots rather than individual segment directories.
        explicit_dirs = dedupe_paths(list(segment_dirs))
        for explicit_dir in explicit_dirs:
            raise_if_segment_path_looks_like_dataset_root(
                explicit_dir,
                scan_segment_dir=scan_segment_dir,
            )
        return SourceResolution(mode="segments", segment_dirs=tuple(explicit_dirs), ignored_partial_dirs=())

    # Remaining mode: segment dirs listed in a CSV manifest.
    return SourceResolution(
        mode="segments-csv",
        segment_dirs=parse_segments_csv(segments_csv, csv_root),
        ignored_partial_dirs=(),
    )
|
|
|
|
|
|
def _find_nested_valid_segment_dirs(
|
|
root: Path,
|
|
*,
|
|
scan_segment_dir: Callable[[Path], ScanT],
|
|
limit: int = 3,
|
|
) -> tuple[Path, ...]:
|
|
matches: list[Path] = []
|
|
for path in sorted(root.rglob("*")):
|
|
if not path.is_dir():
|
|
continue
|
|
resolved = path.resolve()
|
|
if resolved == root:
|
|
continue
|
|
scan = scan_segment_dir(resolved)
|
|
if scan.is_valid:
|
|
matches.append(resolved)
|
|
if len(matches) >= limit:
|
|
break
|
|
return tuple(matches)
|