# OpenGait/datasets/pretreatment_scoliosis_drf.py
from __future__ import annotations
import argparse
import json
import os
import pickle
import sys
from glob import glob
from pathlib import Path
from typing import Any, Literal, TypedDict, cast
import numpy as np
import yaml
from jaxtyping import Float
from numpy.typing import NDArray
from torchvision import transforms as T
from tqdm import tqdm
if __package__ in {None, ""}:
sys.path.append(str(Path(__file__).resolve().parent.parent))
from datasets import pretreatment_heatmap as heatmap_prep
# Left/right COCO-17 keypoint index pairs compared when measuring asymmetry.
JOINT_PAIRS = (
    (1, 2),  # eyes
    (3, 4),  # ears
    (5, 6),  # shoulders
    (7, 8),  # elbows
    (9, 10),  # wrists
    (11, 12),  # hips
    (13, 14),  # knees
    (15, 16),  # ankles
)
# Guards divisions (near-vertical pairs) and degenerate min/max spans.
EPS = 1e-6
# Convenience aliases for annotations below.
FloatArray = NDArray[np.float32]
HeatmapReduction = Literal["upstream", "max", "sum"]
class SequenceRecord(TypedDict):
    """Per-sequence bookkeeping carried from pass 1 (PAV stats) to pass 2 (writing)."""

    pose_path: str  # path to the source pose .pkl file
    pid: str  # subject id — first of the three trailing path components
    seq_parts: list[str]  # last three directory parts of pose_path (see rel_seq_parts)
    frames: int  # number of frames in the pose array
    raw_pav: Float[FloatArray, "pairs metrics"]  # un-normalized PAV matrix
def get_args() -> argparse.Namespace:
    """Parse the command-line options for the DRF pretreatment script."""
    cli = argparse.ArgumentParser(
        description="Generate DRF inputs for Scoliosis1K from pose .pkl files."
    )
    cli.add_argument(
        "--pose_data_path",
        type=str,
        required=True,
        help="Root directory containing Scoliosis1K pose .pkl files.",
    )
    cli.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output root for the DRF runtime dataset.",
    )
    cli.add_argument(
        "--heatmap_cfg_path",
        type=str,
        default="configs/drf/pretreatment_heatmap_drf.yaml",
        help="Heatmap preprocessing config used to build the skeleton map branch.",
    )
    cli.add_argument(
        "--heatmap_reduction",
        type=str,
        choices=["upstream", "max", "sum"],
        default="upstream",
        help=(
            "How to collapse joint/limb heatmaps into one channel. "
            "'upstream' matches OpenGait at f754f6f..., while 'sum' keeps the paper-literal ablation."
        ),
    )
    cli.add_argument(
        "--stats_partition",
        type=str,
        default=None,
        help=(
            "Optional dataset partition JSON. When set, PAV min/max stats use TRAIN_SET ids only. "
            "Omit it to match the paper's dataset-level min-max normalization."
        ),
    )
    return cli.parse_args()
def load_heatmap_cfg(cfg_path: str) -> dict[str, Any]:
    """Load the heatmap YAML config and resolve its internal variable references."""
    with open(cfg_path, "r", encoding="utf-8") as fh:
        raw_cfg = yaml.safe_load(fh)
    resolved = heatmap_prep.replace_variables(raw_cfg, raw_cfg)
    if isinstance(resolved, dict):
        return cast(dict[str, Any], resolved)
    raise TypeError(f"Expected heatmap config dict from {cfg_path}, got {type(resolved).__name__}")
def optional_cfg_float(cfg: dict[str, Any], key: str) -> float | None:
    """Read an optional numeric entry from *cfg* as a float.

    Returns None when the key is absent or explicitly null.
    Raises TypeError for non-numeric values, including booleans: since
    ``bool`` is a subclass of ``int``, a plain isinstance check would
    silently coerce YAML ``true``/``false`` to 1.0/0.0.
    """
    value = cfg.get(key)
    if value is None:
        return None
    # Reject bool before the numeric check — isinstance(True, int) is True.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise TypeError(f"Expected numeric value for {key}, got {type(value).__name__}")
    return float(value)
def build_pose_transform(cfg: dict[str, Any]) -> T.Compose:
    """Compose the keypoint conversion, padding, and normalization steps."""
    steps = [
        heatmap_prep.COCO18toCOCO17(**cfg["coco18tococo17_args"]),
        heatmap_prep.PadKeypoints(**cfg["padkeypoints_args"]),
        heatmap_prep.CenterAndScaleNormalizer(**cfg["norm_args"]),
    ]
    return T.Compose(steps)
def iter_pose_paths(pose_root: str) -> list[str]:
    """Return every pose .pkl three directory levels below *pose_root*, sorted."""
    pattern = os.path.join(pose_root, "*/*/*/*.pkl")
    matches = glob(pattern)
    matches.sort()
    return matches
def read_pose(pose_path: str) -> Float[FloatArray, "frames joints channels"]:
    """Deserialize a pose .pkl file into a float32 array.

    NOTE(review): pickle.load is only safe because dataset files are trusted.
    """
    with open(pose_path, "rb") as src:
        raw = pickle.load(src)
    return np.asarray(raw, dtype=np.float32)
def compute_raw_pav(
    normalized_pose: Float[FloatArray, "frames joints channels"],
) -> Float[FloatArray, "pairs metrics"]:
    """Compute the raw per-pair PAV matrix (vertical, midline, angular metrics).

    Each row is the outlier-robust mean (via iqr_mean) of the three asymmetry
    metrics across frames for one left/right joint pair.
    """
    xy = normalized_pose[..., :2]
    # Midline reference: per-frame mean x of the two hip joints (indices 11, 12).
    hip_mid_x = xy[:, [11, 12], 0].mean(axis=1)
    pav = np.zeros((len(JOINT_PAIRS), 3), dtype=np.float32)
    for row, (li, ri) in enumerate(JOINT_PAIRS):
        lpt = xy[:, li]
        rpt = xy[:, ri]
        delta_y = lpt[:, 1] - rpt[:, 1]
        delta_x = lpt[:, 0] - rpt[:, 0]
        # Clamp |dx| away from zero while preserving its sign for the arctan.
        guarded_dx = np.where(np.abs(delta_x) < EPS, np.where(delta_x < 0, -EPS, EPS), delta_x)
        vertical = np.abs(delta_y)
        midline = np.abs((lpt[:, 0] + rpt[:, 0]) / 2.0 - hip_mid_x)
        angular = np.abs(np.arctan(delta_y / guarded_dx))
        pav[row] = iqr_mean(np.stack([vertical, midline, angular], axis=-1))
    return pav
def iqr_mean(values: Float[FloatArray, "frames metrics"]) -> Float[FloatArray, "metrics"]:
    """Mean of each metric column after discarding 1.5*IQR fence outliers."""

    def _robust_mean(col: Float[FloatArray, "frames"]) -> float:
        q1, q3 = np.percentile(col, [25, 75])
        spread = q3 - q1
        keep = col[(col >= q1 - 1.5 * spread) & (col <= q3 + 1.5 * spread)]
        # Fall back to the plain mean if the fence somehow removed everything.
        return float(keep.mean()) if keep.size else float(col.mean())

    columns = [_robust_mean(values[:, j]) for j in range(values.shape[1])]
    return np.asarray(columns, dtype=np.float32)
def normalize_pav(
    raw_pav: Float[FloatArray, "pairs metrics"],
    pav_min: Float[FloatArray, "pairs metrics"],
    pav_max: Float[FloatArray, "pairs metrics"],
) -> Float[FloatArray, "pairs metrics"]:
    """Min-max normalize raw PAV into [0, 1], guarding degenerate (min==max) spans."""
    span = np.maximum(pav_max - pav_min, EPS)
    scaled = (raw_pav - pav_min) / span
    return np.clip(scaled, 0.0, 1.0).astype(np.float32)
def rel_seq_parts(pose_path: str) -> list[str]:
    """Return the three directory components directly preceding the filename."""
    components = os.path.normpath(pose_path).split(os.sep)
    return components[-4:-1]
def build_stats_mask(records: list[SequenceRecord], stats_partition: str | None) -> NDArray[np.bool_]:
    """Boolean mask selecting records whose pid belongs to the partition's TRAIN_SET.

    Without a partition file, every record is selected (dataset-level stats).
    """
    if stats_partition is None:
        return np.ones(len(records), dtype=bool)
    with open(stats_partition, "r", encoding="utf-8") as fh:
        train_ids = set(json.load(fh)["TRAIN_SET"])
    selected = [rec["pid"] in train_ids for rec in records]
    return np.asarray(selected, dtype=bool)
def main() -> None:
    """Run the two-pass DRF pretreatment pipeline.

    Pass 1 computes un-normalized PAV per sequence; the min/max over the
    selected stats subset is saved, then pass 2 re-reads each pose and writes
    the heatmap branch and the per-frame normalized PAV branch to disk.
    """
    args = get_args()
    os.makedirs(args.output_path, exist_ok=True)

    heatmap_cfg = load_heatmap_cfg(args.heatmap_cfg_path)
    pose_transform = build_pose_transform(heatmap_cfg)
    heatmap_transform = heatmap_prep.GenerateHeatmapTransform(
        coco18tococo17_args=heatmap_cfg["coco18tococo17_args"],
        padkeypoints_args=heatmap_cfg["padkeypoints_args"],
        norm_args=heatmap_cfg["norm_args"],
        heatmap_generator_args=heatmap_cfg["heatmap_generator_args"],
        align_args=heatmap_cfg["align_args"],
        reduction=cast(HeatmapReduction, args.heatmap_reduction),
        sigma_limb=optional_cfg_float(heatmap_cfg, "sigma_limb"),
        sigma_joint=optional_cfg_float(heatmap_cfg, "sigma_joint"),
        channel_gain_limb=optional_cfg_float(heatmap_cfg, "channel_gain_limb"),
        channel_gain_joint=optional_cfg_float(heatmap_cfg, "channel_gain_joint"),
    )

    pose_paths = iter_pose_paths(args.pose_data_path)
    if not pose_paths:
        raise FileNotFoundError(f"No pose .pkl files found under {args.pose_data_path}")

    # Pass 1: compute the raw (un-normalized) PAV of every sequence.
    records: list[SequenceRecord] = []
    for pose_path in tqdm(pose_paths, desc="Pass 1/2: computing raw PAV"):
        pose = read_pose(pose_path)
        parts = rel_seq_parts(pose_path)
        records.append({
            "pose_path": pose_path,
            "pid": parts[0],
            "seq_parts": parts,
            "frames": pose.shape[0],
            "raw_pav": compute_raw_pav(pose_transform(pose)),
        })

    # Min/max over the whole dataset, or TRAIN_SET only when a partition is given.
    stats_mask = build_stats_mask(records, args.stats_partition)
    if not stats_mask.any():
        raise ValueError("No sequences matched the requested stats partition.")
    selected_pavs = [rec["raw_pav"] for rec, keep in zip(records, stats_mask) if keep]
    pav_stack = np.stack(selected_pavs, axis=0)
    pav_min = pav_stack.min(axis=0).astype(np.float32)
    pav_max = pav_stack.max(axis=0).astype(np.float32)

    stats = {
        "joint_pairs": JOINT_PAIRS,
        "pav_min": pav_min,
        "pav_max": pav_max,
        "stats_partition": args.stats_partition,
    }
    with open(os.path.join(args.output_path, "pav_stats.pkl"), "wb") as handle:
        pickle.dump(stats, handle)

    # Pass 2: re-read poses (keeps memory flat) and write both DRF branches.
    for record in tqdm(records, desc="Pass 2/2: writing DRF dataset"):
        pose = read_pose(record["pose_path"])
        heatmap = heatmap_transform(pose)
        pav = normalize_pav(record["raw_pav"], pav_min, pav_max)
        pav_seq = np.repeat(pav[np.newaxis, ...], record["frames"], axis=0)
        save_dir = os.path.join(args.output_path, *record["seq_parts"])
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "0_heatmap.pkl"), "wb") as handle:
            pickle.dump(heatmap, handle)
        with open(os.path.join(save_dir, "1_pav.pkl"), "wb") as handle:
            pickle.dump(pav_seq, handle)
# Script entry point: build the DRF runtime dataset from CLI arguments.
if __name__ == "__main__":
    main()