Refine DRF preprocessing and body-prior pipeline

This commit is contained in:
2026-03-08 04:04:15 +08:00
parent fddbf6eeda
commit bbb41e8dd9
10 changed files with 448 additions and 53 deletions
+10 -4
View File
@@ -7,7 +7,7 @@ import pickle
import sys
from glob import glob
from pathlib import Path
from typing import Any, TypedDict
from typing import Any, TypedDict, cast
import numpy as np
import yaml
@@ -63,14 +63,17 @@ def get_args() -> argparse.Namespace:
_ = parser.add_argument(
"--heatmap_cfg_path",
type=str,
default="configs/skeletongait/pretreatment_heatmap.yaml",
default="configs/drf/pretreatment_heatmap_drf.yaml",
help="Heatmap preprocessing config used to build the skeleton map branch.",
)
_ = parser.add_argument(
"--stats_partition",
type=str,
default=None,
help="Optional dataset partition JSON. When set, PAV min/max stats use TRAIN_SET ids only.",
help=(
"Optional dataset partition JSON. When set, PAV min/max stats use TRAIN_SET ids only. "
"Omit it to match the paper's dataset-level min-max normalization."
),
)
return parser.parse_args()
@@ -79,7 +82,9 @@ def load_heatmap_cfg(cfg_path: str) -> dict[str, Any]:
with open(cfg_path, "r", encoding="utf-8") as stream:
cfg = yaml.safe_load(stream)
replaced = heatmap_prep.replace_variables(cfg, cfg)
return dict(replaced)
if not isinstance(replaced, dict):
raise TypeError(f"Expected heatmap config dict from {cfg_path}, got {type(replaced).__name__}")
return cast(dict[str, Any], replaced)
def build_pose_transform(cfg: dict[str, Any]) -> T.Compose:
@@ -175,6 +180,7 @@ def main() -> None:
norm_args=heatmap_cfg["norm_args"],
heatmap_generator_args=heatmap_cfg["heatmap_generator_args"],
align_args=heatmap_cfg["align_args"],
reduction="sum",
)
pose_paths = iter_pose_paths(args.pose_data_path)