from __future__ import annotations import argparse import json import os import pickle import sys from glob import glob from pathlib import Path from typing import Any, Literal, TypedDict, cast import numpy as np import yaml from jaxtyping import Float from numpy.typing import NDArray from torchvision import transforms as T from tqdm import tqdm if __package__ in {None, ""}: sys.path.append(str(Path(__file__).resolve().parent.parent)) from datasets import pretreatment_heatmap as heatmap_prep JOINT_PAIRS = ( (1, 2), # eyes (3, 4), # ears (5, 6), # shoulders (7, 8), # elbows (9, 10), # wrists (11, 12), # hips (13, 14), # knees (15, 16), # ankles ) EPS = 1e-6 FloatArray = NDArray[np.float32] HeatmapReduction = Literal["upstream", "max", "sum"] class SequenceRecord(TypedDict): pose_path: str pid: str seq_parts: list[str] frames: int raw_pav: Float[FloatArray, "pairs metrics"] def get_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Generate DRF inputs for Scoliosis1K from pose .pkl files." ) _ = parser.add_argument( "--pose_data_path", type=str, required=True, help="Root directory containing Scoliosis1K pose .pkl files.", ) _ = parser.add_argument( "--output_path", type=str, required=True, help="Output root for the DRF runtime dataset.", ) _ = parser.add_argument( "--heatmap_cfg_path", type=str, default="configs/drf/pretreatment_heatmap_drf.yaml", help="Heatmap preprocessing config used to build the skeleton map branch.", ) _ = parser.add_argument( "--heatmap_reduction", type=str, choices=["upstream", "max", "sum"], default="upstream", help=( "How to collapse joint/limb heatmaps into one channel. " "'upstream' matches OpenGait at f754f6f..., while 'sum' keeps the paper-literal ablation." ), ) _ = parser.add_argument( "--stats_partition", type=str, default=None, help=( "Optional dataset partition JSON. When set, PAV min/max stats use TRAIN_SET ids only. " "Omit it to match the paper's dataset-level min-max normalization." ), ) return parser.parse_args() def load_heatmap_cfg(cfg_path: str) -> dict[str, Any]: with open(cfg_path, "r", encoding="utf-8") as stream: cfg = yaml.safe_load(stream) replaced = heatmap_prep.replace_variables(cfg, cfg) if not isinstance(replaced, dict): raise TypeError(f"Expected heatmap config dict from {cfg_path}, got {type(replaced).__name__}") return cast(dict[str, Any], replaced) def optional_cfg_float(cfg: dict[str, Any], key: str) -> float | None: value = cfg.get(key) if value is None: return None if not isinstance(value, (int, float)): raise TypeError(f"Expected numeric value for {key}, got {type(value).__name__}") return float(value) def build_pose_transform(cfg: dict[str, Any]) -> T.Compose: return T.Compose([ heatmap_prep.COCO18toCOCO17(**cfg["coco18tococo17_args"]), heatmap_prep.PadKeypoints(**cfg["padkeypoints_args"]), heatmap_prep.CenterAndScaleNormalizer(**cfg["norm_args"]), ]) def iter_pose_paths(pose_root: str) -> list[str]: return sorted(glob(os.path.join(pose_root, "*/*/*/*.pkl"))) def read_pose(pose_path: str) -> Float[FloatArray, "frames joints channels"]: with open(pose_path, "rb") as handle: pose = pickle.load(handle) return np.asarray(pose, dtype=np.float32) def compute_raw_pav( normalized_pose: Float[FloatArray, "frames joints channels"], ) -> Float[FloatArray, "pairs metrics"]: coords = normalized_pose[..., :2] hip_center_x = coords[:, [11, 12], 0].mean(axis=1) raw_pav = np.zeros((len(JOINT_PAIRS), 3), dtype=np.float32) for pair_idx, (left_idx, right_idx) in enumerate(JOINT_PAIRS): left = coords[:, left_idx] right = coords[:, right_idx] dy = left[:, 1] - right[:, 1] dx = left[:, 0] - right[:, 0] safe_dx = np.where(np.abs(dx) < EPS, np.where(dx < 0, -EPS, EPS), dx) vertical = np.abs(dy) midline = np.abs((left[:, 0] + right[:, 0]) / 2.0 - hip_center_x) angular = np.abs(np.arctan(dy / safe_dx)) metrics = np.stack([vertical, midline, angular], axis=-1) raw_pav[pair_idx] = iqr_mean(metrics) return raw_pav def iqr_mean(values: Float[FloatArray, "frames metrics"]) -> Float[FloatArray, "metrics"]: refined = np.zeros(values.shape[1], dtype=np.float32) for metric_idx in range(values.shape[1]): metric_values = values[:, metric_idx] q1, q3 = np.percentile(metric_values, [25, 75]) iqr = q3 - q1 lower = q1 - 1.5 * iqr upper = q3 + 1.5 * iqr filtered = metric_values[(metric_values >= lower) & (metric_values <= upper)] if filtered.size == 0: filtered = metric_values refined[metric_idx] = float(filtered.mean()) return refined def normalize_pav( raw_pav: Float[FloatArray, "pairs metrics"], pav_min: Float[FloatArray, "pairs metrics"], pav_max: Float[FloatArray, "pairs metrics"], ) -> Float[FloatArray, "pairs metrics"]: denom = np.maximum(pav_max - pav_min, EPS) return np.clip((raw_pav - pav_min) / denom, 0.0, 1.0).astype(np.float32) def rel_seq_parts(pose_path: str) -> list[str]: norm_path = os.path.normpath(pose_path) return norm_path.split(os.sep)[-4:-1] def build_stats_mask(records: list[SequenceRecord], stats_partition: str | None) -> NDArray[np.bool_]: if stats_partition is None: return np.ones(len(records), dtype=bool) with open(stats_partition, "r", encoding="utf-8") as handle: partition = json.load(handle) train_ids = set(partition["TRAIN_SET"]) return np.asarray([record["pid"] in train_ids for record in records], dtype=bool) def main() -> None: args = get_args() os.makedirs(args.output_path, exist_ok=True) heatmap_cfg = load_heatmap_cfg(args.heatmap_cfg_path) pose_transform = build_pose_transform(heatmap_cfg) heatmap_transform = heatmap_prep.GenerateHeatmapTransform( coco18tococo17_args=heatmap_cfg["coco18tococo17_args"], padkeypoints_args=heatmap_cfg["padkeypoints_args"], norm_args=heatmap_cfg["norm_args"], heatmap_generator_args=heatmap_cfg["heatmap_generator_args"], align_args=heatmap_cfg["align_args"], reduction=cast(HeatmapReduction, args.heatmap_reduction), sigma_limb=optional_cfg_float(heatmap_cfg, "sigma_limb"), sigma_joint=optional_cfg_float(heatmap_cfg, "sigma_joint"), channel_gain_limb=optional_cfg_float(heatmap_cfg, "channel_gain_limb"), channel_gain_joint=optional_cfg_float(heatmap_cfg, "channel_gain_joint"), ) pose_paths = iter_pose_paths(args.pose_data_path) if not pose_paths: raise FileNotFoundError(f"No pose .pkl files found under {args.pose_data_path}") records: list[SequenceRecord] = [] for pose_path in tqdm(pose_paths, desc="Pass 1/2: computing raw PAV"): pose = read_pose(pose_path) normalized_pose = pose_transform(pose) records.append({ "pose_path": pose_path, "pid": rel_seq_parts(pose_path)[0], "seq_parts": rel_seq_parts(pose_path), "frames": pose.shape[0], "raw_pav": compute_raw_pav(normalized_pose), }) stats_mask = build_stats_mask(records, args.stats_partition) if not stats_mask.any(): raise ValueError("No sequences matched the requested stats partition.") pav_stack = np.stack([record["raw_pav"] for record, use in zip(records, stats_mask) if use], axis=0) pav_min = pav_stack.min(axis=0).astype(np.float32) pav_max = pav_stack.max(axis=0).astype(np.float32) stats_path = os.path.join(args.output_path, "pav_stats.pkl") with open(stats_path, "wb") as handle: pickle.dump({ "joint_pairs": JOINT_PAIRS, "pav_min": pav_min, "pav_max": pav_max, "stats_partition": args.stats_partition, }, handle) for record in tqdm(records, desc="Pass 2/2: writing DRF dataset"): pose = read_pose(record["pose_path"]) heatmap = heatmap_transform(pose) pav = normalize_pav(record["raw_pav"], pav_min, pav_max) pav_seq = np.repeat(pav[np.newaxis, ...], record["frames"], axis=0) save_dir = os.path.join(args.output_path, *record["seq_parts"]) os.makedirs(save_dir, exist_ok=True) with open(os.path.join(save_dir, "0_heatmap.pkl"), "wb") as handle: pickle.dump(heatmap, handle) with open(os.path.join(save_dir, "1_pav.pkl"), "wb") as handle: pickle.dump(pav_seq, handle) if __name__ == "__main__": main()