from __future__ import annotations import json import random from collections import Counter from pathlib import Path from typing import TypedDict, cast import click class Partition(TypedDict): TRAIN_SET: list[str] TEST_SET: list[str] def infer_pid_label(dataset_root: Path, pid: str) -> str: pid_root = dataset_root / pid if not pid_root.exists(): raise FileNotFoundError(f"PID root not found under dataset root: {pid_root}") label_dirs = sorted([entry.name.lower() for entry in pid_root.iterdir() if entry.is_dir()]) if len(label_dirs) != 1: raise ValueError(f"Expected exactly one class dir for pid {pid}, got {label_dirs}") label = label_dirs[0] if label not in {"positive", "neutral", "negative"}: raise ValueError(f"Unexpected label directory for pid {pid}: {label}") return label @click.command() @click.option( "--base-partition", type=click.Path(path_type=Path, exists=True, dir_okay=False), required=True, help="Path to the source partition JSON, e.g. datasets/Scoliosis1K/Scoliosis1K_118.json", ) @click.option( "--dataset-root", type=click.Path(path_type=Path, exists=True, file_okay=False), required=True, help="Dataset root used to infer each pid class label, e.g. /mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl", ) @click.option( "--negative-multiplier", type=int, required=True, help="Target negative count as a multiple of the positive/neutral count, e.g. 2 for 1:1:2", ) @click.option( "--output-path", type=click.Path(path_type=Path, dir_okay=False), required=True, help="Path to write the derived partition JSON.", ) @click.option( "--seed", type=int, default=118, show_default=True, help="Random seed used when downsampling negatives.", ) def main( base_partition: Path, dataset_root: Path, negative_multiplier: int, output_path: Path, seed: int, ) -> None: with base_partition.open("r", encoding="utf-8") as handle: partition = cast(Partition, json.load(handle)) train_ids = list(partition["TRAIN_SET"]) test_ids = list(partition["TEST_SET"]) train_by_label: dict[str, list[str]] = {"positive": [], "neutral": [], "negative": []} for pid in train_ids: label = infer_pid_label(dataset_root, pid) train_by_label[label].append(pid) pos_count = len(train_by_label["positive"]) neu_count = len(train_by_label["neutral"]) neg_count = len(train_by_label["negative"]) if pos_count != neu_count: raise ValueError( "This helper assumes equal positive/neutral train counts so that only " + f"negative downsampling changes the ratio. Got positive={pos_count}, neutral={neu_count}." ) target_negative_count = negative_multiplier * pos_count if target_negative_count > neg_count: raise ValueError( f"Requested {target_negative_count} negatives but only {neg_count} are available " + f"in base partition {base_partition}." ) rng = random.Random(seed) sampled_negatives = sorted(rng.sample(train_by_label["negative"], target_negative_count)) derived_train = ( sorted(train_by_label["positive"]) + sorted(train_by_label["neutral"]) + sampled_negatives ) derived_partition = { "TRAIN_SET": derived_train, "TEST_SET": test_ids, } output_path.parent.mkdir(parents=True, exist_ok=True) with output_path.open("w", encoding="utf-8") as handle: json.dump(derived_partition, handle, indent=2) _ = handle.write("\n") train_counts = Counter(infer_pid_label(dataset_root, pid) for pid in derived_train) test_counts = Counter(infer_pid_label(dataset_root, pid) for pid in test_ids) click.echo(f"wrote {output_path}") click.echo(f"train_counts={dict(train_counts)}") click.echo(f"test_counts={dict(test_counts)}") if __name__ == "__main__": main()