Compare commits

3 commits (`5f98844aff...master`): `7b98e066e4`, `4a12bd64b9`, `4d916e71c1`
@@ -0,0 +1,25 @@

```yaml
coco18tococo17_args:
  transfer_to_coco17: False

padkeypoints_args:
  pad_method: knn
  use_conf: True

norm_args:
  pose_format: coco
  use_conf: ${padkeypoints_args.use_conf}
  heatmap_image_height: 128

heatmap_generator_args:
  sigma: 8.0
  use_score: ${padkeypoints_args.use_conf}
  img_h: ${norm_args.heatmap_image_height}
  img_w: ${norm_args.heatmap_image_height}
  with_limb: null
  with_kp: null

align_args:
  align: True
  final_img_size: 64
  offset: 0
  heatmap_image_size: ${norm_args.heatmap_image_height}
```
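
The `heatmap_generator_args` above describe per-keypoint Gaussian heatmaps: `sigma: 8.0`, square `128×128` maps (`img_w` reuses `heatmap_image_height`), optionally scaled by keypoint confidence via `use_score`. A minimal pure-Python sketch of that kind of generator, assuming the usual `exp(-d²/2σ²)` form — the in-repo generator may differ in normalization and limb handling:

```python
import math

def keypoint_heatmap(x, y, score=1.0, img_h=128, img_w=128, sigma=8.0, use_score=True):
    """One Gaussian heatmap centered on a keypoint; rows index y, columns index x."""
    peak = score if use_score else 1.0
    return [
        [peak * math.exp(-((cx - x) ** 2 + (cy - y) ** 2) / (2.0 * sigma ** 2))
         for cx in range(img_w)]
        for cy in range(img_h)
    ]

hm = keypoint_heatmap(64, 32, score=0.9)
print(len(hm), len(hm[0]))   # 128 128
print(round(hm[32][64], 3))  # peak value at the keypoint: 0.9
```

With `use_score: True` the peak carries the detector confidence, which is how a downstream model can see keypoint reliability.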
@@ -69,6 +69,33 @@

```bash
python -m torch.distributed.launch --nproc_per_node=4 \
    opengait/main.py --cfgs configs/sconet/sconet_scoliosis1k.yaml --phase test --log_to_file
```
### Fixed-pool ratio comparison

If you want to compare `1:1:2` against `1:1:8` without changing the evaluation
pool, do not compare `Scoliosis1K_112.json` against `Scoliosis1K_118.json`
directly. Those two files differ substantially in train/test membership.

For a cleaner same-pool comparison, use:

* `datasets/Scoliosis1K/Scoliosis1K_118.json`
  * original `1:1:8` split
* `datasets/Scoliosis1K/Scoliosis1K_118_fixedpool_train112.json`
  * same `TEST_SET` as `118`
  * same positive/neutral `TRAIN_SET` ids as `118`
  * `TRAIN_SET` negatives downsampled to `148`, giving train counts of
    `74 positive / 74 neutral / 148 negative`

The helper used to generate that derived partition is:

```bash
uv run python scripts/build_scoliosis_fixedpool_partition.py \
    --base-partition datasets/Scoliosis1K/Scoliosis1K_118.json \
    --dataset-root /mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl \
    --negative-multiplier 2 \
    --output-path datasets/Scoliosis1K/Scoliosis1K_118_fixedpool_train112.json \
    --seed 118
```
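
The core of that helper is a seeded downsample of the negative train ids while the positive/neutral ids are kept whole. A minimal sketch of the idea (toy ids are hypothetical; the real ids come from the partition JSON):

```python
import random

def downsample_negatives(pos, neu, neg, multiplier, seed=118):
    """Keep all positive/neutral train ids; sample multiplier * len(pos) negatives."""
    rng = random.Random(seed)  # fixed seed makes the derived partition reproducible
    kept_neg = sorted(rng.sample(neg, multiplier * len(pos)))
    return sorted(pos) + sorted(neu) + kept_neg

# toy example: 2 positives, 2 neutrals, 8 negatives -> 1:1:2 keeps 4 negatives
train = downsample_negatives(["p1", "p2"], ["n1", "n2"],
                             [f"g{i}" for i in range(8)], multiplier=2)
print(len(train))  # 8 ids total: 2 + 2 + 4
```

On the real `118` base partition this is exactly the `74 / 74 / 148` arithmetic: `--negative-multiplier 2` times 74 positives gives 148 negatives.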
### Modality sanity check

The silhouette and skeleton-map pipelines are different experiments and should not be mixed when you interpret results.
(One file diff suppressed because it is too large.)
@@ -96,6 +96,79 @@ Result:

This is the strongest recovered path so far.
### Verified provenance of `Scoliosis1K-drf-pkl-118-aligned`

The `118-aligned` root is no longer just an informed guess. It was verified
directly against the raw pose source:

- `/mnt/public/data/Scoliosis1K/Scoliosis1K-pose-pkl`

The matching preprocessing path is:

- `datasets/pretreatment_scoliosis_drf.py`
- default heatmap config:
  - `configs/drf/pretreatment_heatmap_drf.yaml`
- archived equivalent config:
  - `configs/drf/pretreatment_heatmap_drf_118_aligned.yaml`

That means the aligned root was produced with:

- shared `sigma: 8.0`
- `align: True`
- `final_img_size: 64`
- default `heatmap_reduction=upstream`
- no `--stats_partition`, i.e. dataset-level PAV min-max stats

Equivalent command:

```bash
uv run python datasets/pretreatment_scoliosis_drf.py \
    --pose_data_path /mnt/public/data/Scoliosis1K/Scoliosis1K-pose-pkl \
    --output_path /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-aligned
```
Verification evidence:

- a regenerated `0_heatmap.pkl` sample from the raw pose input matched the stored
  `Scoliosis1K-drf-pkl-118-aligned` sample exactly (`array_equal == True`)
- a full recomputation of `pav_stats.pkl` from the raw pose input matched the
  stored `pav_min`, `pav_max`, and `stats_partition=None` exactly

So `118-aligned` is the old default OpenGait-style DRF export, not the later:

- `118-paper` paper-literal summed-heatmap export
- `118` train-only-stats splitroot export
- `sigma15` / `sigma15_joint8` exports
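
The sample check above amounts to loading both pickles and asserting exact elementwise equality. A minimal sketch of that pattern (the actual check used `numpy.array_equal` on the stored heatmap arrays; this sketch uses plain Python equality, which is already elementwise for nested lists, and toy temp files instead of the real dataset roots):

```python
import pickle
import tempfile
from pathlib import Path

def pickles_match(path_a: Path, path_b: Path) -> bool:
    """Load two pickled objects and compare them for exact equality."""
    with path_a.open("rb") as fa, path_b.open("rb") as fb:
        return pickle.load(fa) == pickle.load(fb)

# toy round-trip: write the same object twice and confirm the check passes
with tempfile.TemporaryDirectory() as tmp:
    a, b = Path(tmp, "a.pkl"), Path(tmp, "b.pkl")
    sample = [[0.0, 0.5], [0.9, 0.0]]
    a.write_bytes(pickle.dumps(sample))
    b.write_bytes(pickle.dumps(sample))
    print(pickles_match(a, b))  # True
```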
### Targeted preprocessing ablations around the recovered path

After verifying the aligned root provenance, a few focused runtime/data ablations
were tested against the author checkpoint to see which part of the contract still
mattered most.

Baseline:

- `118-aligned`
- `BaseSilCuttingTransform`
- result:
  - `80.24 Acc / 76.73 Prec / 76.40 Rec / 76.56 F1`

Hybrid 1:

- aligned heatmap + splitroot PAV
- result:
  - `77.30 Acc / 73.70 Prec / 73.04 Rec / 73.28 F1`

Hybrid 2:

- splitroot heatmap + aligned PAV
- result:
  - `80.37 Acc / 77.16 Prec / 76.48 Rec / 76.80 F1`

Runtime ablation:

- `118-aligned` + `BaseSilTransform` (`no-cut`)
- result:
  - `49.93 Acc / 50.49 Prec / 51.58 Rec / 47.75 F1`

What these ablations suggest:

- `BaseSilCuttingTransform` is necessary; `no-cut` breaks the checkpoint badly
- dataset-level PAV stats (`stats_partition=None`) matter more than the exact
  aligned-vs-splitroot heatmap writer
- the heatmap export is still part of the contract, but it is no longer the
  dominant remaining mismatch
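
These four-number rows read as accuracy plus macro-averaged precision/recall/F1 over the three classes. For reference, a minimal macro-averaging computation — a sketch under that assumption, not the repo's evaluator:

```python
def macro_prf(y_true, y_pred, classes):
    """Macro precision/recall/F1: per-class scores averaged with equal class weight."""
    precs, recs, f1s = [], [], []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        precs.append(prec); recs.append(rec); f1s.append(f1)
    n = len(classes)
    return sum(precs) / n, sum(recs) / n, sum(f1s) / n

# toy 4-sample example over the three Scoliosis1K classes
y_true = ["pos", "neu", "neg", "neg"]
y_pred = ["pos", "neg", "neg", "neg"]
print(macro_prf(y_true, y_pred, ["pos", "neu", "neg"]))
```

Macro averaging is why a checkpoint can post a decent accuracy on the negative-heavy test pool while its F1 stays well below it: each class counts equally regardless of size.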
### Other tested paths

`configs/drf/drf_author_eval_118_splitroot_1gpu.yaml`
@@ -123,6 +196,8 @@ What these results mean:

- the original “very bad” local eval was mostly a compatibility failure
- the largest single hidden bug was the class-order mismatch
- the author checkpoint is also sensitive to which local DRF dataset root is used
- the recovered runtime is now good enough to make the checkpoint believable, but
  preprocessing alone did not recover the paper DRF headline row

What they do **not** mean:
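
A class-order mismatch of the kind flagged above is easy to reproduce: if the checkpoint's output indices and the evaluator's label list disagree, every metric collapses even with perfect logits. A toy illustration (the orderings are hypothetical, not the repo's actual label lists):

```python
# checkpoint output index -> class, vs. what the evaluator assumed
ckpt_order = ["positive", "neutral", "negative"]  # order the model was trained with
eval_order = ["negative", "neutral", "positive"]  # order the evaluator assumed

def accuracy(y_true, pred_indices, index_to_class):
    preds = [index_to_class[i] for i in pred_indices]
    return sum(p == t for p, t in zip(preds, y_true)) / len(y_true)

# perfect predictions expressed in checkpoint index space
y_true = ["positive", "negative", "neutral", "positive"]
pred_indices = [ckpt_order.index(t) for t in y_true]

print(accuracy(y_true, pred_indices, ckpt_order))  # 1.0 with the right mapping
print(accuracy(y_true, pred_indices, eval_order))  # 0.25 with the wrong one
```

Only the classes that happen to sit at the same index in both orderings (here `neutral`) survive the mismatch, which is why this bug looks like a badly broken checkpoint rather than a plumbing error.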
@@ -130,6 +205,20 @@ What they do **not** mean:

- the provided YAML is trustworthy as-is
- the paper’s full DRF claim is fully reproduced here

One practical caveat on `1:1:2` vs `1:1:8` comparisons in this repo:

- local `Scoliosis1K_112.json` and `Scoliosis1K_118.json` are not the same train/test
  split with only a different class ratio
- they differ substantially in membership
- so local `112` vs `118` results should not be overinterpreted as a pure
  class-balance ablation unless the train/test pool is explicitly held fixed

To support a clean same-pool comparison, the repo now also includes:

- `datasets/Scoliosis1K/Scoliosis1K_118_fixedpool_train112.json`

That partition keeps the full `118` `TEST_SET` unchanged and keeps the same
positive/neutral `TRAIN_SET` ids as `118`, but downsamples `TRAIN_SET` negatives
to `148` so the train ratio becomes `74 / 74 / 148` (`1:1:2`).

The strongest recovered result:

- `80.24 / 76.73 / 76.40 / 76.56`
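
Whether two partition files really share an evaluation pool is cheap to check directly. A sketch comparing `TEST_SET` and `TRAIN_SET` membership (the toy partitions below are illustrative; on the real files you would `json.loads` the two paths under `datasets/Scoliosis1K/`):

```python
import json
from pathlib import Path

def compare_partitions(a: dict, b: dict) -> dict:
    """Summarize how two {TRAIN_SET, TEST_SET} partitions relate."""
    ta, tb = set(a["TEST_SET"]), set(b["TEST_SET"])
    ra, rb = set(a["TRAIN_SET"]), set(b["TRAIN_SET"])
    return {
        "same_test_set": ta == tb,            # fixed evaluation pool?
        "train_b_subset_of_a": rb <= ra,      # pure downsample of a's train set?
        "train_overlap": len(ra & rb),
    }

# on the real files: part = json.loads(Path("datasets/Scoliosis1K/...").read_text())
base = {"TRAIN_SET": ["a", "b", "c", "d"], "TEST_SET": ["x", "y"]}
derived = {"TRAIN_SET": ["a", "b"], "TEST_SET": ["x", "y"]}
print(compare_partitions(base, derived))
```

A fixed-pool derivative should report `same_test_set: True` and a train set that is a strict subset of the base; `112` vs `118` fails both checks, which is the whole caveat.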
@@ -1,8 +1,12 @@

# ScoNet and DRF: Status, Architecture, and Reproduction Notes

This note is the high-level status page for Scoliosis1K work in this repo.
It records what is implemented, what currently works best in practice, and
how to interpret the local DRF/ScoNet results.

For the stricter paper-vs-local breakdown, see [scoliosis_reproducibility_audit.md](scoliosis_reproducibility_audit.md).
For the concrete experiment queue, see [scoliosis_next_experiments.md](scoliosis_next_experiments.md).
For the author-checkpoint compatibility recovery, see [drf_author_checkpoint_compat.md](drf_author_checkpoint_compat.md).
For the recommended long-running local launch workflow, see [systemd-run-training.md](systemd-run-training.md).

## Current status
@@ -12,6 +16,22 @@ For the recommended long-running local launch workflow, see [systemd-run-trainin

- `opengait/modeling/models/drf.py` is now implemented as a standalone DRF model in this repo.
- Logging supports TensorBoard and optional Weights & Biases through `opengait/utils/msg_manager.py`.

## Current bottom line

- The current practical winner is the skeleton-map ScoNet path, not DRF.
- The best verified local checkpoint is:
  - `ScoNet_skeleton_112_sigma15_joint8_bodyonly_plaince_adamw_cosine_finetune_1gpu_80k`
  - retained best checkpoint at `27000`
  - verified full-test result: `92.38 Acc / 90.30 Prec / 87.39 Rec / 88.70 F1`
- The strongest practical recipe behind that checkpoint is:
  - split: `1:1:2`
  - representation: `body-only`
  - losses: plain CE + triplet
  - baseline training: `SGD`
  - later finetune: `AdamW` + cosine decay
- A local DRF run trained from scratch on the same practical recipe did not improve over the plain skeleton baseline.
- The author-provided DRF checkpoint is now usable in-tree after compatibility fixes, but only under the recovered `118-aligned` runtime contract.
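
The `AdamW` + cosine decay finetune in that recipe uses the standard cosine annealing shape. A minimal sketch of the schedule as a pure function (the base LR here is a hypothetical placeholder, not the repo's trainer config):

```python
import math

def cosine_lr(step, total_steps, base_lr, min_lr=0.0):
    """Cosine-annealed learning rate from base_lr down to min_lr over total_steps."""
    progress = min(step / total_steps, 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# e.g. an 80k-step schedule: full LR at step 0, half at the midpoint, 0 at the end
print(cosine_lr(0, 80_000, 1e-4))       # 1e-4
print(cosine_lr(40_000, 80_000, 1e-4))  # 5e-05
print(cosine_lr(80_000, 80_000, 1e-4))  # 0.0
```

The smooth decay to (near) zero is what makes cosine schedules popular for finetunes like this one: the late steps take very small, stabilizing updates.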
## Naming clarification

The name `ScoNet` is overloaded across the paper, config files, and checkpoints. Use the mapping below when reading this repo:
@@ -73,20 +93,47 @@ The main findings so far are:

- a later full-test rerun confirmed the `body-only + plain CE` `7000` result exactly
- an `AdamW` cosine finetune from that same plain-CE checkpoint improved the practical best further; the retained `27000` checkpoint reproduced at `92.38%` accuracy and `88.70%` macro-F1 on the full test set
- a `head-lite + plain CE` variant looked promising on the fixed proxy subset but underperformed on the full test set at `7000` (`78.07%` accuracy, `62.08%` macro-F1)
- The first practical DRF bridge on that same winning `1:1:2` recipe did not improve on the plain skeleton baseline:
  - best retained DRF checkpoint (`2000`) on the full test set: `80.21 Acc / 58.92 Prec / 59.23 Rec / 57.84 F1`
  - practical plain skeleton checkpoint (`7000`) on the full test set: `83.16 Acc / 68.24 Prec / 80.02 Rec / 68.47 F1`
- The author-provided DRF checkpoint initially looked unusable in this fork, but that turned out to be a compatibility problem, not a pure weight problem.
  - after recovering the legacy runtime contract, the best compatible path was `Scoliosis1K-drf-pkl-118-aligned`
  - recovered author-checkpoint result: `80.24 Acc / 76.73 Prec / 76.40 Rec / 76.56 F1`

The current working conclusion is:

- the core ScoNet trainer is not the problem
- the strong silhouette checkpoint is not evidence that the skeleton-map path works
- the biggest historical problem was the skeleton-map/runtime contract, not just the optimizer
- for practical model development, `1:1:2` is currently the better working split than `1:1:8`
- for practical model development, the current best skeleton recipe is `body-only + plain CE`, and the current best retained checkpoint comes from a later `AdamW` cosine finetune on `1:1:2`
- for practical use, DRF is still behind the local ScoNet skeleton winner
- for paper-compatibility analysis, the author checkpoint demonstrates that our earlier DRF failure was partly caused by contract mismatch

For readability in this repo's docs, `ScoNet-MT-ske` refers to the skeleton-map variant that the DRF paper writes as `ScoNet-MT^{ske}`.
## DRF compatibility note

There are now two different DRF stories in this repo:

1. The local-from-scratch DRF branch.
   - This is the branch trained directly in our fork on the current practical recipe.
   - It did not beat the plain skeleton baseline.

2. The author-checkpoint compatibility branch.
   - This uses the author-supplied checkpoint plus in-tree compatibility fixes.
   - The main recovered issues were:
     - legacy module naming drift: `attention_layer.*` vs `PGA.*`
     - class-order mismatch between the author stub and our evaluator assumptions
     - stale/internally inconsistent author YAML
     - preprocessing/runtime mismatch, where `118-aligned` matched much better than the paper-literal export

That distinction matters. It means:

- "our DRF training branch underperformed" is true
- "the author DRF checkpoint is unusable" is false
- "the author result was drop-in reproducible from the handed-over YAML" is also false
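
The `attention_layer.*` vs `PGA.*` naming drift is the classic checkpoint-loading fixup: remap state-dict keys by prefix before loading. A framework-agnostic sketch of that remap (the prefixes come from the bullet above; the actual fix lives in the in-tree compatibility code):

```python
def remap_state_dict_keys(state_dict, prefix_map):
    """Rewrite legacy parameter-name prefixes so old checkpoints load into new modules."""
    remapped = {}
    for key, value in state_dict.items():
        new_key = key
        for old_prefix, new_prefix in prefix_map.items():
            if key.startswith(old_prefix + "."):
                new_key = new_prefix + key[len(old_prefix):]
                break
        remapped[new_key] = value
    return remapped

legacy = {"attention_layer.qkv.weight": "w1", "backbone.conv1.weight": "w2"}
fixed = remap_state_dict_keys(legacy, {"attention_layer": "PGA"})
print(sorted(fixed))  # ['PGA.qkv.weight', 'backbone.conv1.weight']
```

Matching on `old_prefix + "."` keeps the remap from touching unrelated keys that merely share a name fragment.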
## Architecture mapping

`ScoNet` in this repo maps to the paper as follows:
@@ -115,12 +162,6 @@ The standard Scoliosis1K ScoNet recipe is:

The skeleton-map control used the same recipe, except for the modality-specific changes listed above.

(This commit removes the old "## Recommended next checks" section: train a clean silhouette `1:1:8` baseline from the upstream ScoNet config, treat skeleton-map preprocessing as the primary debugging target, and only then treat DRF/PAV-specific conclusions as decisive.)
## Practical conclusion

For practical use in this repo, the current winning path is:
@@ -143,12 +184,15 @@ So the current local recommendation is:

- keep `1:1:2` as the main practical split
- treat DRF as an optional research branch, not the mainline model

If the goal is practical deployment/use, use the retained best skeleton checkpoint family first.
If the goal is paper audit or author-checkpoint verification, use the dedicated DRF compatibility configs instead.

## Remaining useful experiments

At this point, there are only a few experiments that still look high-value:

1. one clean `full-body` finetune under the same successful `1:1:2` recipe, just to confirm that `body-only` is really the best practical representation
2. one DRF warm-start rerun on top of the now-stronger practical baseline recipe, only if the goal is to test whether DRF can add value once the skeleton branch is already strong
3. a final packaging/evaluation pass around the retained best checkpoints, rather than more broad preprocessing churn

Everything else looks lower value than simply using the retained best `27000` checkpoint.
@@ -0,0 +1,121 @@

```python
from __future__ import annotations

import json
import random
from collections import Counter
from pathlib import Path
from typing import TypedDict, cast

import click


class Partition(TypedDict):
    TRAIN_SET: list[str]
    TEST_SET: list[str]


def infer_pid_label(dataset_root: Path, pid: str) -> str:
    pid_root = dataset_root / pid
    if not pid_root.exists():
        raise FileNotFoundError(f"PID root not found under dataset root: {pid_root}")
    label_dirs = sorted([entry.name.lower() for entry in pid_root.iterdir() if entry.is_dir()])
    if len(label_dirs) != 1:
        raise ValueError(f"Expected exactly one class dir for pid {pid}, got {label_dirs}")
    label = label_dirs[0]
    if label not in {"positive", "neutral", "negative"}:
        raise ValueError(f"Unexpected label directory for pid {pid}: {label}")
    return label


@click.command()
@click.option(
    "--base-partition",
    type=click.Path(path_type=Path, exists=True, dir_okay=False),
    required=True,
    help="Path to the source partition JSON, e.g. datasets/Scoliosis1K/Scoliosis1K_118.json",
)
@click.option(
    "--dataset-root",
    type=click.Path(path_type=Path, exists=True, file_okay=False),
    required=True,
    help="Dataset root used to infer each pid class label, e.g. /mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl",
)
@click.option(
    "--negative-multiplier",
    type=int,
    required=True,
    help="Target negative count as a multiple of the positive/neutral count, e.g. 2 for 1:1:2",
)
@click.option(
    "--output-path",
    type=click.Path(path_type=Path, dir_okay=False),
    required=True,
    help="Path to write the derived partition JSON.",
)
@click.option(
    "--seed",
    type=int,
    default=118,
    show_default=True,
    help="Random seed used when downsampling negatives.",
)
def main(
    base_partition: Path,
    dataset_root: Path,
    negative_multiplier: int,
    output_path: Path,
    seed: int,
) -> None:
    with base_partition.open("r", encoding="utf-8") as handle:
        partition = cast(Partition, json.load(handle))

    train_ids = list(partition["TRAIN_SET"])
    test_ids = list(partition["TEST_SET"])

    train_by_label: dict[str, list[str]] = {"positive": [], "neutral": [], "negative": []}
    for pid in train_ids:
        label = infer_pid_label(dataset_root, pid)
        train_by_label[label].append(pid)

    pos_count = len(train_by_label["positive"])
    neu_count = len(train_by_label["neutral"])
    neg_count = len(train_by_label["negative"])
    if pos_count != neu_count:
        raise ValueError(
            "This helper assumes equal positive/neutral train counts so that only "
            + f"negative downsampling changes the ratio. Got positive={pos_count}, neutral={neu_count}."
        )

    target_negative_count = negative_multiplier * pos_count
    if target_negative_count > neg_count:
        raise ValueError(
            f"Requested {target_negative_count} negatives but only {neg_count} are available "
            + f"in base partition {base_partition}."
        )

    rng = random.Random(seed)
    sampled_negatives = sorted(rng.sample(train_by_label["negative"], target_negative_count))
    derived_train = (
        sorted(train_by_label["positive"])
        + sorted(train_by_label["neutral"])
        + sampled_negatives
    )

    derived_partition = {
        "TRAIN_SET": derived_train,
        "TEST_SET": test_ids,
    }
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as handle:
        json.dump(derived_partition, handle, indent=2)
        _ = handle.write("\n")

    train_counts = Counter(infer_pid_label(dataset_root, pid) for pid in derived_train)
    test_counts = Counter(infer_pid_label(dataset_root, pid) for pid in test_ids)
    click.echo(f"wrote {output_path}")
    click.echo(f"train_counts={dict(train_counts)}")
    click.echo(f"test_counts={dict(test_counts)}")


if __name__ == "__main__":
    main()
```