Align DRF skeleton preprocessing with upstream heatmap path
This commit is contained in:
@@ -109,9 +109,17 @@ uv run python datasets/pretreatment_scoliosis_drf.py \
|
||||
--output_path=<path_to_drf_pkl>
|
||||
```
|
||||
|
||||
To reproduce the paper defaults more closely, the script now uses
|
||||
`configs/drf/pretreatment_heatmap_drf.yaml` by default, which enables
|
||||
summed two-channel skeleton maps and a literal 128-pixel height normalization.
|
||||
The script uses `configs/drf/pretreatment_heatmap_drf.yaml` by default.
|
||||
That keeps the upstream OpenGait/SkeletonGait heatmap behavior from
|
||||
commit `f754f6f3831e9f83bb28f4e2f63dd43d8bcf9dc4` for the skeleton-map
|
||||
branch while still building the DRF-specific two-channel output.
|
||||
|
||||
If you explicitly want the more paper-literal summed heatmap ablation, add:
|
||||
|
||||
```bash
|
||||
--heatmap_reduction=sum
|
||||
```
|
||||
|
||||
If you explicitly want train-only PAV min-max statistics, add:
|
||||
|
||||
```bash
|
||||
|
||||
@@ -8,6 +8,7 @@ import pickle
|
||||
import argparse
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
from typing import Literal
|
||||
from tqdm import tqdm
|
||||
import matplotlib.cm as cm
|
||||
import torch.distributed as dist
|
||||
@@ -328,7 +329,7 @@ class HeatmapToImage:
|
||||
class HeatmapReducer:
|
||||
"""Reduce stacked joint/limb heatmaps to a single grayscale channel."""
|
||||
|
||||
def __init__(self, reduction: str = "max") -> None:
|
||||
def __init__(self, reduction: Literal["max", "sum"] = "max") -> None:
|
||||
if reduction not in {"max", "sum"}:
|
||||
raise ValueError(f"Unsupported heatmap reduction: {reduction}")
|
||||
self.reduction = reduction
|
||||
@@ -574,7 +575,7 @@ def GenerateHeatmapTransform(
|
||||
norm_args,
|
||||
heatmap_generator_args,
|
||||
align_args,
|
||||
reduction="max",
|
||||
reduction: Literal["upstream", "max", "sum"] = "upstream",
|
||||
):
|
||||
|
||||
base_transform = T.Compose([
|
||||
@@ -585,17 +586,27 @@ def GenerateHeatmapTransform(
|
||||
|
||||
heatmap_generator_args["with_limb"] = True
|
||||
heatmap_generator_args["with_kp"] = False
|
||||
bone_image_transform = (
|
||||
HeatmapToImage()
|
||||
if reduction == "upstream"
|
||||
else HeatmapReducer(reduction=reduction)
|
||||
)
|
||||
transform_bone = T.Compose([
|
||||
GeneratePoseTarget(**heatmap_generator_args),
|
||||
HeatmapReducer(reduction=reduction),
|
||||
bone_image_transform,
|
||||
HeatmapAlignment(**align_args)
|
||||
])
|
||||
|
||||
heatmap_generator_args["with_limb"] = False
|
||||
heatmap_generator_args["with_kp"] = True
|
||||
joint_image_transform = (
|
||||
HeatmapToImage()
|
||||
if reduction == "upstream"
|
||||
else HeatmapReducer(reduction=reduction)
|
||||
)
|
||||
transform_joint = T.Compose([
|
||||
GeneratePoseTarget(**heatmap_generator_args),
|
||||
HeatmapReducer(reduction=reduction),
|
||||
joint_image_transform,
|
||||
HeatmapAlignment(**align_args)
|
||||
])
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import pickle
|
||||
import sys
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
from typing import Any, TypedDict, cast
|
||||
from typing import Any, Literal, TypedDict, cast
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
@@ -34,6 +34,7 @@ JOINT_PAIRS = (
|
||||
)
|
||||
EPS = 1e-6
|
||||
FloatArray = NDArray[np.float32]
|
||||
HeatmapReduction = Literal["upstream", "max", "sum"]
|
||||
|
||||
|
||||
class SequenceRecord(TypedDict):
|
||||
@@ -66,6 +67,16 @@ def get_args() -> argparse.Namespace:
|
||||
default="configs/drf/pretreatment_heatmap_drf.yaml",
|
||||
help="Heatmap preprocessing config used to build the skeleton map branch.",
|
||||
)
|
||||
_ = parser.add_argument(
|
||||
"--heatmap_reduction",
|
||||
type=str,
|
||||
choices=["upstream", "max", "sum"],
|
||||
default="upstream",
|
||||
help=(
|
||||
"How to collapse joint/limb heatmaps into one channel. "
|
||||
"'upstream' matches OpenGait at f754f6f..., while 'sum' keeps the paper-literal ablation."
|
||||
),
|
||||
)
|
||||
_ = parser.add_argument(
|
||||
"--stats_partition",
|
||||
type=str,
|
||||
@@ -180,7 +191,7 @@ def main() -> None:
|
||||
norm_args=heatmap_cfg["norm_args"],
|
||||
heatmap_generator_args=heatmap_cfg["heatmap_generator_args"],
|
||||
align_args=heatmap_cfg["align_args"],
|
||||
reduction="sum",
|
||||
reduction=cast(HeatmapReduction, args.heatmap_reduction),
|
||||
)
|
||||
|
||||
pose_paths = iter_pose_paths(args.pose_data_path)
|
||||
|
||||
Reference in New Issue
Block a user