Files
OpenGait/scripts/build_scoliosis_fixedpool_partition.py
T

122 lines
3.9 KiB
Python

from __future__ import annotations
import json
import random
from collections import Counter
from pathlib import Path
from typing import TypedDict, cast
import click
class Partition(TypedDict):
    """Typed view of a partition JSON file as loaded by this script.

    Each split is a list of pid strings; pids are joined onto the dataset
    root to locate per-pid directories.
    """

    # pids used for training; the script downsamples the negatives in this list.
    TRAIN_SET: list[str]
    # pids used for evaluation; copied through to the derived partition unchanged.
    TEST_SET: list[str]
def infer_pid_label(dataset_root: Path, pid: str) -> str:
    """Return the class label of ``pid`` from its single label sub-directory.

    The directory ``dataset_root / pid`` must contain exactly one
    sub-directory. Its name, lower-cased, must be one of ``positive``,
    ``neutral`` or ``negative``. Regular files inside the pid directory are
    ignored.

    Args:
        dataset_root: Root directory containing one directory per pid.
        pid: Identifier of the subject whose label is wanted.

    Returns:
        The lower-cased label name.

    Raises:
        FileNotFoundError: If ``dataset_root / pid`` does not exist.
        ValueError: If there is not exactly one sub-directory, or its name is
            not a recognised label.
    """
    allowed_labels = ("positive", "neutral", "negative")
    pid_dir = dataset_root / pid
    if not pid_dir.exists():
        raise FileNotFoundError(f"PID root not found under dataset root: {pid_dir}")

    # Collect sub-directory names only, normalised to lower case.
    candidates: list[str] = []
    for child in pid_dir.iterdir():
        if child.is_dir():
            candidates.append(child.name.lower())
    candidates.sort()

    if len(candidates) != 1:
        raise ValueError(f"Expected exactly one class dir for pid {pid}, got {candidates}")

    (label,) = candidates
    if label in allowed_labels:
        return label
    raise ValueError(f"Unexpected label directory for pid {pid}: {label}")
@click.command()
@click.option(
    "--base-partition",
    type=click.Path(path_type=Path, exists=True, dir_okay=False),
    required=True,
    help="Path to the source partition JSON, e.g. datasets/Scoliosis1K/Scoliosis1K_118.json",
)
@click.option(
    "--dataset-root",
    type=click.Path(path_type=Path, exists=True, file_okay=False),
    required=True,
    help="Dataset root used to infer each pid class label, e.g. /mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl",
)
@click.option(
    "--negative-multiplier",
    # A negative multiplier would yield a negative sample size; reject it at parse time
    # with a clear usage error instead of an opaque failure inside random.sample.
    type=click.IntRange(min=0),
    required=True,
    help="Target negative count as a multiple of the positive/neutral count, e.g. 2 for 1:1:2",
)
@click.option(
    "--output-path",
    type=click.Path(path_type=Path, dir_okay=False),
    required=True,
    help="Path to write the derived partition JSON.",
)
@click.option(
    "--seed",
    type=int,
    default=118,
    show_default=True,
    help="Random seed used when downsampling negatives.",
)
def main(
    base_partition: Path,
    dataset_root: Path,
    negative_multiplier: int,
    output_path: Path,
    seed: int,
) -> None:
    """Derive a partition with a fixed positive:neutral:negative train ratio.

    All positive and neutral train pids are kept. Negative train pids are
    randomly downsampled to NEGATIVE_MULTIPLIER times the positive count.
    TEST_SET is copied unchanged.
    \f
    Steps:
      1. Load ``base_partition`` and label every TRAIN_SET pid via
         ``infer_pid_label``.
      2. Require equal positive/neutral counts and enough negatives.
      3. Sample negatives with ``random.Random(seed)``. The sample is sorted,
         so output order does not depend on sampling order.
      4. Label every TEST_SET pid *before* writing, so an invalid pid aborts
         without producing an output file.
      5. Write the JSON (indent=2, trailing newline) and echo per-label counts.

    Raises:
        ValueError: If positive and neutral train counts differ, or if fewer
            negatives are available than requested. Also propagated from
            ``infer_pid_label`` for malformed pid directories.
        FileNotFoundError: Propagated from ``infer_pid_label`` for a pid whose
            directory does not exist.
    """
    with base_partition.open("r", encoding="utf-8") as handle:
        partition = cast(Partition, json.load(handle))
    train_ids = list(partition["TRAIN_SET"])
    test_ids = list(partition["TEST_SET"])

    # Label each train pid once and remember the result, so the summary at the
    # end does not need to rescan the dataset directory.
    train_by_label: dict[str, list[str]] = {"positive": [], "neutral": [], "negative": []}
    label_by_pid: dict[str, str] = {}
    for pid in train_ids:
        label = infer_pid_label(dataset_root, pid)
        train_by_label[label].append(pid)
        label_by_pid[pid] = label

    pos_count = len(train_by_label["positive"])
    neu_count = len(train_by_label["neutral"])
    neg_count = len(train_by_label["negative"])
    if pos_count != neu_count:
        raise ValueError(
            "This helper assumes equal positive/neutral train counts so that only "
            + f"negative downsampling changes the ratio. Got positive={pos_count}, neutral={neu_count}."
        )
    target_negative_count = negative_multiplier * pos_count
    if target_negative_count > neg_count:
        raise ValueError(
            f"Requested {target_negative_count} negatives but only {neg_count} are available "
            + f"in base partition {base_partition}."
        )

    rng = random.Random(seed)
    sampled_negatives = sorted(rng.sample(train_by_label["negative"], target_negative_count))
    derived_train = (
        sorted(train_by_label["positive"])
        + sorted(train_by_label["neutral"])
        + sampled_negatives
    )
    derived_partition = {
        "TRAIN_SET": derived_train,
        "TEST_SET": test_ids,
    }

    # Validate/label test pids before touching the filesystem so a bad pid
    # cannot leave a freshly written output file behind an exception.
    test_counts = Counter(infer_pid_label(dataset_root, pid) for pid in test_ids)
    train_counts = Counter(label_by_pid[pid] for pid in derived_train)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as handle:
        json.dump(derived_partition, handle, indent=2)
        _ = handle.write("\n")

    click.echo(f"wrote {output_path}")
    click.echo(f"train_counts={dict(train_counts)}")
    click.echo(f"test_counts={dict(test_counts)}")
# Run the click command defined above when this file is executed directly.
if __name__ == "__main__":
    main()