Refine DRF preprocessing and body-prior pipeline
This commit is contained in:
@@ -7,7 +7,7 @@ import pickle
|
||||
import sys
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
@@ -63,14 +63,17 @@ def get_args() -> argparse.Namespace:
|
||||
_ = parser.add_argument(
|
||||
"--heatmap_cfg_path",
|
||||
type=str,
|
||||
default="configs/skeletongait/pretreatment_heatmap.yaml",
|
||||
default="configs/drf/pretreatment_heatmap_drf.yaml",
|
||||
help="Heatmap preprocessing config used to build the skeleton map branch.",
|
||||
)
|
||||
_ = parser.add_argument(
|
||||
"--stats_partition",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Optional dataset partition JSON. When set, PAV min/max stats use TRAIN_SET ids only.",
|
||||
help=(
|
||||
"Optional dataset partition JSON. When set, PAV min/max stats use TRAIN_SET ids only. "
|
||||
"Omit it to match the paper's dataset-level min-max normalization."
|
||||
),
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -79,7 +82,9 @@ def load_heatmap_cfg(cfg_path: str) -> dict[str, Any]:
|
||||
with open(cfg_path, "r", encoding="utf-8") as stream:
|
||||
cfg = yaml.safe_load(stream)
|
||||
replaced = heatmap_prep.replace_variables(cfg, cfg)
|
||||
return dict(replaced)
|
||||
if not isinstance(replaced, dict):
|
||||
raise TypeError(f"Expected heatmap config dict from {cfg_path}, got {type(replaced).__name__}")
|
||||
return cast(dict[str, Any], replaced)
|
||||
|
||||
|
||||
def build_pose_transform(cfg: dict[str, Any]) -> T.Compose:
|
||||
@@ -175,6 +180,7 @@ def main() -> None:
|
||||
norm_args=heatmap_cfg["norm_args"],
|
||||
heatmap_generator_args=heatmap_cfg["heatmap_generator_args"],
|
||||
align_args=heatmap_cfg["align_args"],
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
pose_paths = iter_pose_paths(args.pose_data_path)
|
||||
|
||||
Reference in New Issue
Block a user