Align DRF skeleton preprocessing with upstream heatmap path

This commit is contained in:
2026-03-08 14:50:35 +08:00
parent bbb41e8dd9
commit 295d951206
10 changed files with 174 additions and 21 deletions
+3 -3
View File
@@ -1,6 +1,6 @@
data_cfg: data_cfg:
dataset_name: Scoliosis1K dataset_name: Scoliosis1K
dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-paper dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-aligned
dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json
num_workers: 1 num_workers: 1
remove_no_gallery: false remove_no_gallery: false
@@ -19,7 +19,7 @@ evaluator_cfg:
frames_all_limit: 720 frames_all_limit: 720
metric: euc metric: euc
transform: transform:
- type: BaseSilTransform - type: BaseSilCuttingTransform
- type: NoOperation - type: NoOperation
loss_cfg: loss_cfg:
@@ -102,5 +102,5 @@ trainer_cfg:
sample_type: fixed_unordered sample_type: fixed_unordered
type: TripletSampler type: TripletSampler
transform: transform:
- type: BaseSilTransform - type: BaseSilCuttingTransform
- type: NoOperation - type: NoOperation
+3 -3
View File
@@ -1,6 +1,6 @@
data_cfg: data_cfg:
dataset_name: Scoliosis1K dataset_name: Scoliosis1K
dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-paper dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-aligned
dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json
num_workers: 1 num_workers: 1
remove_no_gallery: false remove_no_gallery: false
@@ -19,7 +19,7 @@ evaluator_cfg:
frames_all_limit: 720 frames_all_limit: 720
metric: euc metric: euc
transform: transform:
- type: BaseSilTransform - type: BaseSilCuttingTransform
- type: NoOperation - type: NoOperation
loss_cfg: loss_cfg:
@@ -102,5 +102,5 @@ trainer_cfg:
sample_type: fixed_unordered sample_type: fixed_unordered
type: TripletSampler type: TripletSampler
transform: transform:
- type: BaseSilTransform - type: BaseSilCuttingTransform
- type: NoOperation - type: NoOperation
+3 -3
View File
@@ -1,6 +1,6 @@
data_cfg: data_cfg:
dataset_name: Scoliosis1K dataset_name: Scoliosis1K
dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-paper dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-aligned
dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json
num_workers: 1 num_workers: 1
remove_no_gallery: false remove_no_gallery: false
@@ -23,7 +23,7 @@ evaluator_cfg:
frames_all_limit: 720 frames_all_limit: 720
metric: euc metric: euc
transform: transform:
- type: BaseSilTransform - type: BaseSilCuttingTransform
- type: NoOperation - type: NoOperation
loss_cfg: loss_cfg:
@@ -109,5 +109,5 @@ trainer_cfg:
sample_type: fixed_unordered sample_type: fixed_unordered
type: TripletSampler type: TripletSampler
transform: transform:
- type: BaseSilTransform - type: BaseSilCuttingTransform
- type: NoOperation - type: NoOperation
@@ -9,7 +9,6 @@ norm_args:
pose_format: coco pose_format: coco
use_conf: ${padkeypoints_args.use_conf} use_conf: ${padkeypoints_args.use_conf}
heatmap_image_height: 128 heatmap_image_height: 128
target_body_height: ${norm_args.heatmap_image_height}
heatmap_generator_args: heatmap_generator_args:
sigma: 8.0 sigma: 8.0
@@ -0,0 +1,105 @@
# ScoNet training/evaluation settings for Scoliosis1K (run name: ScoNet_skeleton_118).
# NOTE(review): the nesting below is reconstructed to follow the layout of the
# sibling Scoliosis1K configs; confirm against the committed file.
data_cfg:
  dataset_name: Scoliosis1K
  dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-aligned
  dataset_partition: ./datasets/Scoliosis1K/Scoliosis1K_118.json
  # One flag per stored data stream; presumably only the first stream is
  # loaded -- TODO confirm against the dataset loader.
  data_in_use:
    - true
    - false
  num_workers: 1
  remove_no_gallery: false
  test_dataset_name: Scoliosis1K
evaluator_cfg:
  enable_float16: true
  restore_ckpt_strict: true
  # Same value as trainer_cfg.save_iter / total_iter below; presumably the
  # iteration of the checkpoint to evaluate -- verify.
  restore_hint: 20000
  save_name: ScoNet_skeleton_118
  eval_func: evaluate_scoliosis
  sampler:
    batch_shuffle: false
    batch_size: 2
    sample_type: all_ordered
    frames_all_limit: 720
  metric: euc
  # A single transform entry, in line with the single enabled data stream.
  transform:
    - type: BaseSilCuttingTransform
# Two loss terms with equal weight: a triplet term and a cross-entropy term.
loss_cfg:
  - loss_term_weight: 1.0
    margin: 0.2
    type: TripletLoss
    log_prefix: triplet
  - loss_term_weight: 1.0
    scale: 16
    type: CrossEntropyLoss
    log_prefix: softmax
    log_accuracy: true
model_cfg:
  model: ScoNet
  backbone_cfg:
    type: ResNet9
    block: BasicBlock
    # Two input channels; presumably the two-channel skeleton map produced by
    # the DRF pretreatment script -- TODO confirm.
    in_channel: 2
    # channels / layers / strides each list four values, one per backbone stage.
    channels:
      - 64
      - 128
      - 256
      - 512
    layers:
      - 1
      - 1
      - 1
      - 1
    strides:
      - 1
      - 2
      - 2
      - 1
    maxpool: false
  # in_channels equals the last backbone channel width (512); parts_num equals
  # the single bin_num entry (16).
  SeparateFCs:
    in_channels: 512
    out_channels: 256
    parts_num: 16
  SeparateBNNecks:
    class_num: 3
    in_channels: 256
    parts_num: 16
  bin_num:
    - 16
optimizer_cfg:
  lr: 0.1
  momentum: 0.9
  solver: SGD
  weight_decay: 0.0005
# MultiStepLR: the learning rate is multiplied by gamma at each milestone.
scheduler_cfg:
  gamma: 0.1
  milestones:
    - 10000
    - 14000
    - 18000
  scheduler: MultiStepLR
trainer_cfg:
  enable_float16: true
  fix_BN: false
  with_test: false
  log_iter: 100
  restore_ckpt_strict: true
  restore_hint: 0
  # save_iter equals total_iter, so only one checkpoint interval fits in the run.
  save_iter: 20000
  save_name: ScoNet_skeleton_118
  sync_BN: true
  total_iter: 20000
  sampler:
    batch_shuffle: true
    # Two-element batch size; presumably (identities, sequences per identity)
    # for TripletSampler -- verify against the sampler implementation.
    batch_size:
      - 8
      - 8
    frames_num_fixed: 30
    sample_type: fixed_unordered
    type: TripletSampler
  transform:
    - type: BaseSilCuttingTransform
+11 -3
View File
@@ -109,9 +109,17 @@ uv run python datasets/pretreatment_scoliosis_drf.py \
--output_path=<path_to_drf_pkl> --output_path=<path_to_drf_pkl>
``` ```
To reproduce the paper defaults more closely, the script now uses The script uses `configs/drf/pretreatment_heatmap_drf.yaml` by default.
`configs/drf/pretreatment_heatmap_drf.yaml` by default, which enables That keeps the upstream OpenGait/SkeletonGait heatmap behavior from
summed two-channel skeleton maps and a literal 128-pixel height normalization. commit `f754f6f3831e9f83bb28f4e2f63dd43d8bcf9dc4` for the skeleton-map
branch while still building the DRF-specific two-channel output.
If you explicitly want the more paper-literal summed heatmap ablation, add:
```bash
--heatmap_reduction=sum
```
If you explicitly want train-only PAV min-max statistics, add: If you explicitly want train-only PAV min-max statistics, add:
```bash ```bash
+15 -4
View File
@@ -8,6 +8,7 @@ import pickle
import argparse import argparse
import numpy as np import numpy as np
from glob import glob from glob import glob
from typing import Literal
from tqdm import tqdm from tqdm import tqdm
import matplotlib.cm as cm import matplotlib.cm as cm
import torch.distributed as dist import torch.distributed as dist
@@ -328,7 +329,7 @@ class HeatmapToImage:
class HeatmapReducer: class HeatmapReducer:
"""Reduce stacked joint/limb heatmaps to a single grayscale channel.""" """Reduce stacked joint/limb heatmaps to a single grayscale channel."""
def __init__(self, reduction: str = "max") -> None: def __init__(self, reduction: Literal["max", "sum"] = "max") -> None:
if reduction not in {"max", "sum"}: if reduction not in {"max", "sum"}:
raise ValueError(f"Unsupported heatmap reduction: {reduction}") raise ValueError(f"Unsupported heatmap reduction: {reduction}")
self.reduction = reduction self.reduction = reduction
@@ -574,7 +575,7 @@ def GenerateHeatmapTransform(
norm_args, norm_args,
heatmap_generator_args, heatmap_generator_args,
align_args, align_args,
reduction="max", reduction: Literal["upstream", "max", "sum"] = "upstream",
): ):
base_transform = T.Compose([ base_transform = T.Compose([
@@ -585,17 +586,27 @@ def GenerateHeatmapTransform(
heatmap_generator_args["with_limb"] = True heatmap_generator_args["with_limb"] = True
heatmap_generator_args["with_kp"] = False heatmap_generator_args["with_kp"] = False
bone_image_transform = (
HeatmapToImage()
if reduction == "upstream"
else HeatmapReducer(reduction=reduction)
)
transform_bone = T.Compose([ transform_bone = T.Compose([
GeneratePoseTarget(**heatmap_generator_args), GeneratePoseTarget(**heatmap_generator_args),
HeatmapReducer(reduction=reduction), bone_image_transform,
HeatmapAlignment(**align_args) HeatmapAlignment(**align_args)
]) ])
heatmap_generator_args["with_limb"] = False heatmap_generator_args["with_limb"] = False
heatmap_generator_args["with_kp"] = True heatmap_generator_args["with_kp"] = True
joint_image_transform = (
HeatmapToImage()
if reduction == "upstream"
else HeatmapReducer(reduction=reduction)
)
transform_joint = T.Compose([ transform_joint = T.Compose([
GeneratePoseTarget(**heatmap_generator_args), GeneratePoseTarget(**heatmap_generator_args),
HeatmapReducer(reduction=reduction), joint_image_transform,
HeatmapAlignment(**align_args) HeatmapAlignment(**align_args)
]) ])
+13 -2
View File
@@ -7,7 +7,7 @@ import pickle
import sys import sys
from glob import glob from glob import glob
from pathlib import Path from pathlib import Path
from typing import Any, TypedDict, cast from typing import Any, Literal, TypedDict, cast
import numpy as np import numpy as np
import yaml import yaml
@@ -34,6 +34,7 @@ JOINT_PAIRS = (
) )
EPS = 1e-6 EPS = 1e-6
FloatArray = NDArray[np.float32] FloatArray = NDArray[np.float32]
HeatmapReduction = Literal["upstream", "max", "sum"]
class SequenceRecord(TypedDict): class SequenceRecord(TypedDict):
@@ -66,6 +67,16 @@ def get_args() -> argparse.Namespace:
default="configs/drf/pretreatment_heatmap_drf.yaml", default="configs/drf/pretreatment_heatmap_drf.yaml",
help="Heatmap preprocessing config used to build the skeleton map branch.", help="Heatmap preprocessing config used to build the skeleton map branch.",
) )
_ = parser.add_argument(
"--heatmap_reduction",
type=str,
choices=["upstream", "max", "sum"],
default="upstream",
help=(
"How to collapse joint/limb heatmaps into one channel. "
"'upstream' matches OpenGait at f754f6f..., while 'sum' keeps the paper-literal ablation."
),
)
_ = parser.add_argument( _ = parser.add_argument(
"--stats_partition", "--stats_partition",
type=str, type=str,
@@ -180,7 +191,7 @@ def main() -> None:
norm_args=heatmap_cfg["norm_args"], norm_args=heatmap_cfg["norm_args"],
heatmap_generator_args=heatmap_cfg["heatmap_generator_args"], heatmap_generator_args=heatmap_cfg["heatmap_generator_args"],
align_args=heatmap_cfg["align_args"], align_args=heatmap_cfg["align_args"],
reduction="sum", reduction=cast(HeatmapReduction, args.heatmap_reduction),
) )
pose_paths = iter_pose_paths(args.pose_data_path) pose_paths = iter_pose_paths(args.pose_data_path)
+7 -1
View File
@@ -47,7 +47,9 @@ class BaseModelBody(BaseModel):
labs = list2var(labs_batch).long() labs = list2var(labs_batch).long()
seqL = np2var(seqL_batch).int() if seqL_batch is not None else None seqL = np2var(seqL_batch).int() if seqL_batch is not None else None
body_features = aggregate_body_features(body_seq, seqL) # Preserve a singleton modality axis so DRF can mirror the author stub's
# `squeeze(1)` behavior while still accepting the same sequence-level prior.
body_features = aggregate_body_features(body_seq, seqL).unsqueeze(1)
if seqL is not None: if seqL is not None:
seqL_sum = int(seqL.sum().data.cpu().numpy()) seqL_sum = int(seqL.sum().data.cpu().numpy())
@@ -80,3 +82,7 @@ def aggregate_body_features(
aggregated.append(flattened[start:end].mean(dim=0)) aggregated.append(flattened[start:end].mean(dim=0))
start = end start = end
return torch.stack(aggregated, dim=0) return torch.stack(aggregated, dim=0)
# Match the symbol name used by the author-provided DRF stub.
# NOTE(review): this rebinds the module-level name `BaseModel`, which
# `BaseModelBody` also names as its base class. The class itself is unaffected
# (its bases were resolved when it was defined), but any later reference to
# `BaseModel` in this module now resolves to `BaseModelBody`.
BaseModel = BaseModelBody
+14 -1
View File
@@ -43,7 +43,7 @@ class DRF(BaseModelBody):
list[str], list[str],
list[str], list[str],
Int[torch.Tensor, "1 batch"] | None, Int[torch.Tensor, "1 batch"] | None,
Float[torch.Tensor, "batch pairs metrics"], Float[torch.Tensor, "batch _ pairs metrics"] | Float[torch.Tensor, "batch pairs metrics"],
], ],
) -> dict[str, dict[str, Any]]: ) -> dict[str, dict[str, Any]]:
ipts, pids, labels, _, seqL, key_features = inputs ipts, pids, labels, _, seqL, key_features = inputs
@@ -64,6 +64,7 @@ class DRF(BaseModelBody):
feat = self.HPP(outs) feat = self.HPP(outs)
embed_1 = self.FCs(feat) embed_1 = self.FCs(feat)
key_features = canonicalize_pav(key_features)
embed_1 = self.PGA(embed_1, key_features) embed_1 = self.PGA(embed_1, key_features)
embed_2, logits = self.BNNecks(embed_1) embed_2, logits = self.BNNecks(embed_1)
@@ -123,3 +124,15 @@ LABEL_MAP: dict[str, int] = {
"neutral": 1, "neutral": 1,
"positive": 2, "positive": 2,
} }
def canonicalize_pav(
    pav: Float[torch.Tensor, "batch _ pairs metrics"] | Float[torch.Tensor, "batch pairs metrics"],
) -> Float[torch.Tensor, "batch pairs metrics"]:
    """Return ``pav`` as a rank-3 tensor, dropping a singleton axis 1 if present.

    Args:
        pav: Either a rank-3 tensor, or a rank-4 tensor whose second axis has
            length 1.

    Returns:
        The input itself when it is already rank 3; otherwise the input with
        axis 1 squeezed out.

    Raises:
        ValueError: If ``pav`` is rank 4 with a non-singleton second axis, or
            if its rank is neither 3 nor 4.
    """
    rank = pav.ndim
    # Already in the canonical layout: hand back the very same tensor.
    if rank == 3:
        return pav
    if rank != 4:
        raise ValueError(f"Expected PAV with 3 or 4 dims, got shape {tuple(pav.shape)}")
    # Only a length-1 extra axis can be removed without discarding data.
    if pav.shape[1] != 1:
        raise ValueError(f"Expected singleton PAV axis, got shape {tuple(pav.shape)}")
    return pav.squeeze(1)