from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import pickle
|
|
import sys
|
|
from glob import glob
|
|
from pathlib import Path
|
|
from typing import Any, Literal, TypedDict, cast
|
|
|
|
import numpy as np
|
|
import yaml
|
|
from jaxtyping import Float
|
|
from numpy.typing import NDArray
|
|
from torchvision import transforms as T
|
|
from tqdm import tqdm
|
|
|
|
# When executed directly as a script (no package context), put the repository
# root on sys.path so the local `datasets` package below resolves.
if __package__ in {None, ""}:
    sys.path.append(str(Path(__file__).resolve().parent.parent))

from datasets import pretreatment_heatmap as heatmap_prep
|
|
|
|
|
|
# Left/right keypoint index pairs whose asymmetry is measured (COCO-17 indexing
# after the COCO18toCOCO17 conversion applied in the pose transform).
JOINT_PAIRS = (
    (1, 2),  # eyes
    (3, 4),  # ears
    (5, 6),  # shoulders
    (7, 8),  # elbows
    (9, 10),  # wrists
    (11, 12),  # hips
    (13, 14),  # knees
    (15, 16),  # ankles
)

# Small positive floor: guards the arctan slope denominator and the min-max span.
EPS = 1e-6

# Alias for float32 numpy arrays, used throughout the jaxtyping annotations.
FloatArray = NDArray[np.float32]

# Allowed strategies for collapsing joint/limb heatmaps into one channel.
HeatmapReduction = Literal["upstream", "max", "sum"]
|
|
|
|
|
|
class SequenceRecord(TypedDict):
    """Per-sequence bookkeeping carried from pass 1 (raw PAV) to pass 2 (writing)."""

    # Absolute path of the source pose .pkl file.
    pose_path: str
    # Subject id: the first of the three trailing directory components.
    pid: str
    # The three directory components immediately above the pose file.
    seq_parts: list[str]
    # Number of frames in the pose sequence (pose.shape[0]).
    frames: int
    # Un-normalized postural asymmetry values, one row per joint pair.
    raw_pav: Float[FloatArray, "pairs metrics"]
|
|
|
|
|
|
def get_args() -> argparse.Namespace:
    """Parse the command-line options for the DRF input generator."""
    parser = argparse.ArgumentParser(
        description="Generate DRF inputs for Scoliosis1K from pose .pkl files."
    )
    add = parser.add_argument

    add(
        "--pose_data_path",
        type=str,
        required=True,
        help="Root directory containing Scoliosis1K pose .pkl files.",
    )
    add(
        "--output_path",
        type=str,
        required=True,
        help="Output root for the DRF runtime dataset.",
    )
    add(
        "--heatmap_cfg_path",
        type=str,
        default="configs/drf/pretreatment_heatmap_drf.yaml",
        help="Heatmap preprocessing config used to build the skeleton map branch.",
    )
    add(
        "--heatmap_reduction",
        type=str,
        choices=["upstream", "max", "sum"],
        default="upstream",
        help=(
            "How to collapse joint/limb heatmaps into one channel. "
            "'upstream' matches OpenGait at f754f6f..., while 'sum' keeps the paper-literal ablation."
        ),
    )
    add(
        "--stats_partition",
        type=str,
        default=None,
        help=(
            "Optional dataset partition JSON. When set, PAV min/max stats use TRAIN_SET ids only. "
            "Omit it to match the paper's dataset-level min-max normalization."
        ),
    )
    return parser.parse_args()
|
|
|
|
|
|
def load_heatmap_cfg(cfg_path: str) -> dict[str, Any]:
    """Load the heatmap YAML config and resolve its internal variable references.

    Raises TypeError if variable substitution does not yield a dict.
    """
    with open(cfg_path, "r", encoding="utf-8") as stream:
        raw_cfg = yaml.safe_load(stream)
    resolved = heatmap_prep.replace_variables(raw_cfg, raw_cfg)
    if isinstance(resolved, dict):
        return cast(dict[str, Any], resolved)
    raise TypeError(f"Expected heatmap config dict from {cfg_path}, got {type(resolved).__name__}")
|
|
|
|
|
|
def optional_cfg_float(cfg: dict[str, Any], key: str) -> float | None:
    """Fetch an optional numeric entry from *cfg* as a float.

    Returns:
        ``None`` when *key* is absent or explicitly null, otherwise the value
        coerced to ``float``.

    Raises:
        TypeError: if the value is present but not numeric. Booleans are
            rejected explicitly: ``isinstance(True, int)`` holds in Python, so
            a YAML ``true``/``false`` would otherwise be silently coerced to
            1.0/0.0 instead of surfacing the misconfiguration.
    """
    value = cfg.get(key)
    if value is None:
        return None
    # bool subclasses int; treat it as a config mistake rather than a number.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise TypeError(f"Expected numeric value for {key}, got {type(value).__name__}")
    return float(value)
|
|
|
|
|
|
def build_pose_transform(cfg: dict[str, Any]) -> T.Compose:
    """Compose the keypoint-space preprocessing pipeline from config sections.

    Stages: COCO-18 -> COCO-17 conversion, keypoint padding, then
    center-and-scale normalization.
    """
    stages = [
        heatmap_prep.COCO18toCOCO17(**cfg["coco18tococo17_args"]),
        heatmap_prep.PadKeypoints(**cfg["padkeypoints_args"]),
        heatmap_prep.CenterAndScaleNormalizer(**cfg["norm_args"]),
    ]
    return T.Compose(stages)
|
|
|
|
|
|
def iter_pose_paths(pose_root: str) -> list[str]:
    """Return every pose pickle three directory levels below *pose_root*, sorted."""
    pattern = os.path.join(pose_root, "*/*/*/*.pkl")
    return sorted(glob(pattern))
|
|
|
|
|
|
def read_pose(pose_path: str) -> Float[FloatArray, "frames joints channels"]:
    """Deserialize a pose sequence pickle and return it as a float32 array."""
    with open(pose_path, "rb") as fh:
        payload = pickle.load(fh)
    return np.asarray(payload, dtype=np.float32)
|
|
|
|
|
|
def compute_raw_pav(
    normalized_pose: Float[FloatArray, "frames joints channels"],
) -> Float[FloatArray, "pairs metrics"]:
    """Compute raw postural-asymmetry values: one (vertical, midline, angular)
    row per joint pair, each reduced over frames by an IQR-filtered mean.
    """
    xy = normalized_pose[..., :2]
    # Per-frame x midpoint of the two hips; reference line for midline offset.
    hip_mid_x = xy[:, [11, 12], 0].mean(axis=1)
    pav = np.zeros((len(JOINT_PAIRS), 3), dtype=np.float32)

    for row, (li, ri) in enumerate(JOINT_PAIRS):
        lhs = xy[:, li]
        rhs = xy[:, ri]
        delta_y = lhs[:, 1] - rhs[:, 1]
        delta_x = lhs[:, 0] - rhs[:, 0]
        # Clamp tiny |dx| away from zero (sign-preserving) so the slope is finite.
        guarded_dx = np.where(np.abs(delta_x) < EPS, np.where(delta_x < 0, -EPS, EPS), delta_x)

        vertical = np.abs(delta_y)
        midline = np.abs((lhs[:, 0] + rhs[:, 0]) / 2.0 - hip_mid_x)
        angular = np.abs(np.arctan(delta_y / guarded_dx))

        per_frame = np.stack([vertical, midline, angular], axis=-1)
        pav[row] = iqr_mean(per_frame)

    return pav
|
|
|
|
|
|
def iqr_mean(values: Float[FloatArray, "frames metrics"]) -> Float[FloatArray, "metrics"]:
    """Mean of each column after discarding outliers beyond the Tukey fences
    (1.5 * IQR past the quartiles).

    Falls back to the plain column mean when every frame is fenced out.
    """
    means: list[float] = []
    for column in values.T:
        q1, q3 = np.percentile(column, [25, 75])
        spread = 1.5 * (q3 - q1)
        kept = column[(column >= q1 - spread) & (column <= q3 + spread)]
        means.append(float(kept.mean() if kept.size else column.mean()))
    return np.asarray(means, dtype=np.float32)
|
|
|
|
|
|
def normalize_pav(
    raw_pav: Float[FloatArray, "pairs metrics"],
    pav_min: Float[FloatArray, "pairs metrics"],
    pav_max: Float[FloatArray, "pairs metrics"],
) -> Float[FloatArray, "pairs metrics"]:
    """Min-max scale raw PAV into [0, 1] using the provided per-cell stats.

    The span is floored at EPS so degenerate (min == max) cells do not divide
    by zero; out-of-range values are clipped rather than extrapolated.
    """
    span = np.maximum(pav_max - pav_min, EPS)
    scaled = (raw_pav - pav_min) / span
    return np.clip(scaled, 0.0, 1.0).astype(np.float32)
|
|
|
|
|
|
def rel_seq_parts(pose_path: str) -> list[str]:
    """Return the three directory components immediately above the pose file."""
    parts = os.path.normpath(pose_path).split(os.sep)
    return parts[-4:-1]
|
|
|
|
|
|
def build_stats_mask(records: list[SequenceRecord], stats_partition: str | None) -> NDArray[np.bool_]:
    """Boolean mask selecting which records contribute to the PAV min/max stats.

    Without a partition file every sequence is selected; with one, only
    sequences whose pid appears in the partition's TRAIN_SET list.
    """
    if stats_partition is None:
        # No partition: dataset-level statistics over every sequence.
        return np.ones(len(records), dtype=bool)

    with open(stats_partition, "r", encoding="utf-8") as handle:
        train_ids = set(json.load(handle)["TRAIN_SET"])
    selected = [record["pid"] in train_ids for record in records]
    return np.asarray(selected, dtype=bool)
|
|
|
|
|
|
def main() -> None:
    """Two-pass driver: compute dataset PAV statistics, then write per-sequence
    heatmap and normalized-PAV pickles under the output root.
    """
    args = get_args()
    os.makedirs(args.output_path, exist_ok=True)

    heatmap_cfg = load_heatmap_cfg(args.heatmap_cfg_path)
    pose_transform = build_pose_transform(heatmap_cfg)
    # Separate transform that renders full skeleton heatmaps; the optional
    # sigma/gain knobs fall back to None when absent from the config.
    heatmap_transform = heatmap_prep.GenerateHeatmapTransform(
        coco18tococo17_args=heatmap_cfg["coco18tococo17_args"],
        padkeypoints_args=heatmap_cfg["padkeypoints_args"],
        norm_args=heatmap_cfg["norm_args"],
        heatmap_generator_args=heatmap_cfg["heatmap_generator_args"],
        align_args=heatmap_cfg["align_args"],
        reduction=cast(HeatmapReduction, args.heatmap_reduction),
        sigma_limb=optional_cfg_float(heatmap_cfg, "sigma_limb"),
        sigma_joint=optional_cfg_float(heatmap_cfg, "sigma_joint"),
        channel_gain_limb=optional_cfg_float(heatmap_cfg, "channel_gain_limb"),
        channel_gain_joint=optional_cfg_float(heatmap_cfg, "channel_gain_joint"),
    )

    pose_paths = iter_pose_paths(args.pose_data_path)
    if not pose_paths:
        raise FileNotFoundError(f"No pose .pkl files found under {args.pose_data_path}")

    # Pass 1: compute raw PAV for every sequence so min/max stats can be taken
    # over the whole (or train-only) dataset before anything is written.
    records: list[SequenceRecord] = []
    for pose_path in tqdm(pose_paths, desc="Pass 1/2: computing raw PAV"):
        pose = read_pose(pose_path)
        normalized_pose = pose_transform(pose)
        records.append({
            "pose_path": pose_path,
            "pid": rel_seq_parts(pose_path)[0],
            "seq_parts": rel_seq_parts(pose_path),
            "frames": pose.shape[0],
            "raw_pav": compute_raw_pav(normalized_pose),
        })

    stats_mask = build_stats_mask(records, args.stats_partition)
    if not stats_mask.any():
        raise ValueError("No sequences matched the requested stats partition.")

    # Per-cell min/max over the selected sequences only.
    pav_stack = np.stack([record["raw_pav"] for record, use in zip(records, stats_mask) if use], axis=0)
    pav_min = pav_stack.min(axis=0).astype(np.float32)
    pav_max = pav_stack.max(axis=0).astype(np.float32)

    # Persist the stats (and their provenance) so downstream consumers can
    # reproduce or invert the normalization.
    stats_path = os.path.join(args.output_path, "pav_stats.pkl")
    with open(stats_path, "wb") as handle:
        pickle.dump({
            "joint_pairs": JOINT_PAIRS,
            "pav_min": pav_min,
            "pav_max": pav_max,
            "stats_partition": args.stats_partition,
        }, handle)

    # Pass 2: re-read each pose, render its heatmap, normalize its PAV with the
    # dataset stats, and mirror the source directory layout under output_path.
    for record in tqdm(records, desc="Pass 2/2: writing DRF dataset"):
        pose = read_pose(record["pose_path"])
        heatmap = heatmap_transform(pose)
        pav = normalize_pav(record["raw_pav"], pav_min, pav_max)
        # Broadcast the per-sequence PAV to one copy per frame.
        pav_seq = np.repeat(pav[np.newaxis, ...], record["frames"], axis=0)

        save_dir = os.path.join(args.output_path, *record["seq_parts"])
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "0_heatmap.pkl"), "wb") as handle:
            pickle.dump(heatmap, handle)
        with open(os.path.join(save_dir, "1_pav.pkl"), "wb") as handle:
            pickle.dump(pav_seq, handle)
|
|
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|