Align DRF skeleton preprocessing with upstream heatmap path

This commit is contained in:
2026-03-08 14:50:35 +08:00
parent bbb41e8dd9
commit 295d951206
10 changed files with 174 additions and 21 deletions
+3 -3
View File
@@ -1,6 +1,6 @@
data_cfg:
dataset_name: Scoliosis1K
dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-paper
dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-aligned
dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json
num_workers: 1
remove_no_gallery: false
@@ -19,7 +19,7 @@ evaluator_cfg:
frames_all_limit: 720
metric: euc
transform:
- type: BaseSilTransform
- type: BaseSilCuttingTransform
- type: NoOperation
loss_cfg:
@@ -102,5 +102,5 @@ trainer_cfg:
sample_type: fixed_unordered
type: TripletSampler
transform:
- type: BaseSilTransform
- type: BaseSilCuttingTransform
- type: NoOperation
+3 -3
View File
@@ -1,6 +1,6 @@
data_cfg:
dataset_name: Scoliosis1K
dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-paper
dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-aligned
dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json
num_workers: 1
remove_no_gallery: false
@@ -19,7 +19,7 @@ evaluator_cfg:
frames_all_limit: 720
metric: euc
transform:
- type: BaseSilTransform
- type: BaseSilCuttingTransform
- type: NoOperation
loss_cfg:
@@ -102,5 +102,5 @@ trainer_cfg:
sample_type: fixed_unordered
type: TripletSampler
transform:
- type: BaseSilTransform
- type: BaseSilCuttingTransform
- type: NoOperation
+3 -3
View File
@@ -1,6 +1,6 @@
data_cfg:
dataset_name: Scoliosis1K
dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-paper
dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-aligned
dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json
num_workers: 1
remove_no_gallery: false
@@ -23,7 +23,7 @@ evaluator_cfg:
frames_all_limit: 720
metric: euc
transform:
- type: BaseSilTransform
- type: BaseSilCuttingTransform
- type: NoOperation
loss_cfg:
@@ -109,5 +109,5 @@ trainer_cfg:
sample_type: fixed_unordered
type: TripletSampler
transform:
- type: BaseSilTransform
- type: BaseSilCuttingTransform
- type: NoOperation
@@ -9,7 +9,6 @@ norm_args:
pose_format: coco
use_conf: ${padkeypoints_args.use_conf}
heatmap_image_height: 128
target_body_height: ${norm_args.heatmap_image_height}
heatmap_generator_args:
sigma: 8.0
@@ -0,0 +1,105 @@
# Training/evaluation config for the skeleton-branch ScoNet run on the
# aligned 118-subject Scoliosis1K DRF dataset (OpenGait-style layout).
# NOTE(review): indentation reconstructed from the standard OpenGait config
# schema — the rendered diff stripped all leading whitespace.
data_cfg:
  dataset_name: Scoliosis1K
  dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-aligned
  dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json
  # Per-modality toggle; presumably keeps only the first stored modality
  # (skeleton maps) and drops the second — TODO confirm against the loader.
  data_in_use:
    - true
    - false
  num_workers: 1
  remove_no_gallery: false
  test_dataset_name: Scoliosis1K

evaluator_cfg:
  enable_float16: true
  restore_ckpt_strict: true
  # Evaluate the checkpoint saved at iteration 20000 (== trainer total_iter).
  restore_hint: 20000
  save_name: ScoNet_skeleton_118
  eval_func: evaluate_scoliosis
  sampler:
    batch_shuffle: false
    batch_size: 2
    sample_type: all_ordered
    frames_all_limit: 720
  metric: euc
  transform:
    - type: BaseSilCuttingTransform

loss_cfg:
  # Triplet + softmax, both at weight 1.0.
  - loss_term_weight: 1.0
    margin: 0.2
    type: TripletLoss
    log_prefix: triplet
  - loss_term_weight: 1.0
    scale: 16
    type: CrossEntropyLoss
    log_prefix: softmax
    log_accuracy: true

model_cfg:
  model: ScoNet
  backbone_cfg:
    type: ResNet9
    block: BasicBlock
    # Two input channels — matches the two-channel (joint + limb) skeleton map.
    in_channel: 2
    channels:
      - 64
      - 128
      - 256
      - 512
    layers:
      - 1
      - 1
      - 1
      - 1
    strides:
      - 1
      - 2
      - 2
      - 1
    maxpool: false
  SeparateFCs:
    in_channels: 512
    out_channels: 256
    parts_num: 16
  SeparateBNNecks:
    # Three classes: presumably the negative/neutral/positive labels used
    # elsewhere in this commit — TODO confirm against LABEL_MAP.
    class_num: 3
    in_channels: 256
    parts_num: 16
  bin_num:
    - 16

optimizer_cfg:
  lr: 0.1
  momentum: 0.9
  solver: SGD
  weight_decay: 0.0005

scheduler_cfg:
  # Step the LR down by 10x at each milestone.
  gamma: 0.1
  milestones:
    - 10000
    - 14000
    - 18000
  scheduler: MultiStepLR

trainer_cfg:
  enable_float16: true
  fix_BN: false
  with_test: false
  log_iter: 100
  restore_ckpt_strict: true
  restore_hint: 0
  save_iter: 20000
  save_name: ScoNet_skeleton_118
  sync_BN: true
  total_iter: 20000
  sampler:
    batch_shuffle: true
    # [identities per batch, sequences per identity] for the triplet sampler.
    batch_size:
      - 8
      - 8
    frames_num_fixed: 30
    sample_type: fixed_unordered
    type: TripletSampler
  transform:
    - type: BaseSilCuttingTransform
+11 -3
View File
@@ -109,9 +109,17 @@ uv run python datasets/pretreatment_scoliosis_drf.py \
--output_path=<path_to_drf_pkl>
```
To reproduce the paper defaults more closely, the script now uses
`configs/drf/pretreatment_heatmap_drf.yaml` by default, which enables
summed two-channel skeleton maps and a literal 128-pixel height normalization.
The script uses `configs/drf/pretreatment_heatmap_drf.yaml` by default.
That keeps the upstream OpenGait/SkeletonGait heatmap behavior from
commit `f754f6f3831e9f83bb28f4e2f63dd43d8bcf9dc4` for the skeleton-map
branch while still building the DRF-specific two-channel output.
If you explicitly want the more paper-literal summed heatmap ablation, add:
```bash
--heatmap_reduction=sum
```
If you explicitly want train-only PAV min-max statistics, add:
```bash
+15 -4
View File
@@ -8,6 +8,7 @@ import pickle
import argparse
import numpy as np
from glob import glob
from typing import Literal
from tqdm import tqdm
import matplotlib.cm as cm
import torch.distributed as dist
@@ -328,7 +329,7 @@ class HeatmapToImage:
class HeatmapReducer:
"""Reduce stacked joint/limb heatmaps to a single grayscale channel."""
def __init__(self, reduction: str = "max") -> None:
def __init__(self, reduction: Literal["max", "sum"] = "max") -> None:
if reduction not in {"max", "sum"}:
raise ValueError(f"Unsupported heatmap reduction: {reduction}")
self.reduction = reduction
@@ -574,7 +575,7 @@ def GenerateHeatmapTransform(
norm_args,
heatmap_generator_args,
align_args,
reduction="max",
reduction: Literal["upstream", "max", "sum"] = "upstream",
):
base_transform = T.Compose([
@@ -585,17 +586,27 @@ def GenerateHeatmapTransform(
heatmap_generator_args["with_limb"] = True
heatmap_generator_args["with_kp"] = False
bone_image_transform = (
HeatmapToImage()
if reduction == "upstream"
else HeatmapReducer(reduction=reduction)
)
transform_bone = T.Compose([
GeneratePoseTarget(**heatmap_generator_args),
HeatmapReducer(reduction=reduction),
bone_image_transform,
HeatmapAlignment(**align_args)
])
heatmap_generator_args["with_limb"] = False
heatmap_generator_args["with_kp"] = True
joint_image_transform = (
HeatmapToImage()
if reduction == "upstream"
else HeatmapReducer(reduction=reduction)
)
transform_joint = T.Compose([
GeneratePoseTarget(**heatmap_generator_args),
HeatmapReducer(reduction=reduction),
joint_image_transform,
HeatmapAlignment(**align_args)
])
+13 -2
View File
@@ -7,7 +7,7 @@ import pickle
import sys
from glob import glob
from pathlib import Path
from typing import Any, TypedDict, cast
from typing import Any, Literal, TypedDict, cast
import numpy as np
import yaml
@@ -34,6 +34,7 @@ JOINT_PAIRS = (
)
EPS = 1e-6
FloatArray = NDArray[np.float32]
HeatmapReduction = Literal["upstream", "max", "sum"]
class SequenceRecord(TypedDict):
@@ -66,6 +67,16 @@ def get_args() -> argparse.Namespace:
default="configs/drf/pretreatment_heatmap_drf.yaml",
help="Heatmap preprocessing config used to build the skeleton map branch.",
)
_ = parser.add_argument(
"--heatmap_reduction",
type=str,
choices=["upstream", "max", "sum"],
default="upstream",
help=(
"How to collapse joint/limb heatmaps into one channel. "
"'upstream' matches OpenGait at f754f6f..., while 'sum' keeps the paper-literal ablation."
),
)
_ = parser.add_argument(
"--stats_partition",
type=str,
@@ -180,7 +191,7 @@ def main() -> None:
norm_args=heatmap_cfg["norm_args"],
heatmap_generator_args=heatmap_cfg["heatmap_generator_args"],
align_args=heatmap_cfg["align_args"],
reduction="sum",
reduction=cast(HeatmapReduction, args.heatmap_reduction),
)
pose_paths = iter_pose_paths(args.pose_data_path)
+7 -1
View File
@@ -47,7 +47,9 @@ class BaseModelBody(BaseModel):
labs = list2var(labs_batch).long()
seqL = np2var(seqL_batch).int() if seqL_batch is not None else None
body_features = aggregate_body_features(body_seq, seqL)
# Preserve a singleton modality axis so DRF can mirror the author stub's
# `squeeze(1)` behavior while still accepting the same sequence-level prior.
body_features = aggregate_body_features(body_seq, seqL).unsqueeze(1)
if seqL is not None:
seqL_sum = int(seqL.sum().data.cpu().numpy())
@@ -80,3 +82,7 @@ def aggregate_body_features(
aggregated.append(flattened[start:end].mean(dim=0))
start = end
return torch.stack(aggregated, dim=0)
# Match the symbol name used by the author-provided DRF stub.
BaseModel = BaseModelBody
+14 -1
View File
@@ -43,7 +43,7 @@ class DRF(BaseModelBody):
list[str],
list[str],
Int[torch.Tensor, "1 batch"] | None,
Float[torch.Tensor, "batch pairs metrics"],
Float[torch.Tensor, "batch _ pairs metrics"] | Float[torch.Tensor, "batch pairs metrics"],
],
) -> dict[str, dict[str, Any]]:
ipts, pids, labels, _, seqL, key_features = inputs
@@ -64,6 +64,7 @@ class DRF(BaseModelBody):
feat = self.HPP(outs)
embed_1 = self.FCs(feat)
key_features = canonicalize_pav(key_features)
embed_1 = self.PGA(embed_1, key_features)
embed_2, logits = self.BNNecks(embed_1)
@@ -123,3 +124,15 @@ LABEL_MAP: dict[str, int] = {
"neutral": 1,
"positive": 2,
}
def canonicalize_pav(
    pav: Float[torch.Tensor, "batch _ pairs metrics"] | Float[torch.Tensor, "batch pairs metrics"],
) -> Float[torch.Tensor, "batch pairs metrics"]:
    """Collapse an optional singleton modality axis from a PAV tensor.

    Accepts either a 3-D ``(batch, pairs, metrics)`` tensor, which is
    returned unchanged, or a 4-D ``(batch, 1, pairs, metrics)`` tensor,
    whose singleton second axis is squeezed away.

    Raises:
        ValueError: if a 4-D tensor's second axis is not singleton, or if
            the tensor has any rank other than 3 or 4.
    """
    rank = pav.ndim
    if rank == 3:
        return pav
    if rank == 4:
        if pav.shape[1] == 1:
            return pav.squeeze(1)
        raise ValueError(f"Expected singleton PAV axis, got shape {tuple(pav.shape)}")
    raise ValueError(f"Expected PAV with 3 or 4 dims, got shape {tuple(pav.shape)}")