Add DRF Scoliosis1K pipeline and optional wandb logging
This commit is contained in:
@@ -0,0 +1,228 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
from jaxtyping import Float
|
||||
from numpy.typing import NDArray
|
||||
from torchvision import transforms as T
|
||||
from tqdm import tqdm
|
||||
|
||||
# When executed directly as a script (rather than imported as part of the
# package), add the repository root to sys.path so the absolute
# `datasets.pretreatment_heatmap` import below resolves.
if __package__ in {None, ""}:
    sys.path.append(str(Path(__file__).resolve().parent.parent))

from datasets import pretreatment_heatmap as heatmap_prep
|
||||
|
||||
|
||||
# Left/right keypoint index pairs used to measure postural asymmetry.
# Indices follow the COCO-17 keypoint ordering (index 0 is the nose),
# which is what COCO18toCOCO17 below converts the raw poses into.
JOINT_PAIRS = (
    (1, 2),  # eyes
    (3, 4),  # ears
    (5, 6),  # shoulders
    (7, 8),  # elbows
    (9, 10),  # wrists
    (11, 12),  # hips
    (13, 14),  # knees
    (15, 16),  # ankles
)
# Small positive floor used to avoid division by (near-)zero in the
# angular metric and in PAV min-max normalization.
EPS = 1e-6
# Alias for the float32 numpy arrays used throughout this script.
FloatArray = NDArray[np.float32]
|
||||
|
||||
|
||||
class SequenceRecord(TypedDict):
    """Per-sequence bookkeeping collected in pass 1 and consumed in pass 2."""

    # Path to the source pose .pkl file for this sequence.
    pose_path: str
    # Subject id — the first of the three trailing directory parts.
    pid: str
    # Trailing directory parts of the pose path (see rel_seq_parts);
    # mirrored under the output root when writing results.
    seq_parts: list[str]
    # Number of frames in the pose sequence (pose.shape[0]).
    frames: int
    # Un-normalized postural-asymmetry metrics, shape (pairs, metrics).
    raw_pav: Float[FloatArray, "pairs metrics"]
|
||||
|
||||
|
||||
def get_args() -> argparse.Namespace:
    """Parse command-line options for the DRF Scoliosis1K generator.

    Returns:
        Parsed namespace with pose_data_path, output_path, heatmap_cfg_path
        and stats_partition attributes.
    """
    parser = argparse.ArgumentParser(
        description="Generate DRF inputs for Scoliosis1K from pose .pkl files."
    )
    parser.add_argument(
        "--pose_data_path",
        type=str,
        required=True,
        help="Root directory containing Scoliosis1K pose .pkl files.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Output root for the DRF runtime dataset.",
    )
    parser.add_argument(
        "--heatmap_cfg_path",
        type=str,
        default="configs/skeletongait/pretreatment_heatmap.yaml",
        help="Heatmap preprocessing config used to build the skeleton map branch.",
    )
    parser.add_argument(
        "--stats_partition",
        type=str,
        default=None,
        help="Optional dataset partition JSON. When set, PAV min/max stats use TRAIN_SET ids only.",
    )
    return parser.parse_args()
|
||||
|
||||
|
||||
def load_heatmap_cfg(cfg_path: str) -> dict[str, Any]:
    """Load the heatmap preprocessing YAML and expand its variable references.

    The config is passed to heatmap_prep.replace_variables with itself as the
    lookup scope, so entries may reference other entries in the same file.
    """
    with open(cfg_path, "r", encoding="utf-8") as stream:
        raw_cfg = yaml.safe_load(stream)
    return dict(heatmap_prep.replace_variables(raw_cfg, raw_cfg))
|
||||
|
||||
|
||||
def build_pose_transform(cfg: dict[str, Any]) -> T.Compose:
    """Build the keypoint pipeline applied before PAV computation.

    Converts COCO-18 poses to COCO-17, pads missing keypoints, then
    center/scale-normalizes the coordinates.
    """
    steps = [
        heatmap_prep.COCO18toCOCO17(**cfg["coco18tococo17_args"]),
        heatmap_prep.PadKeypoints(**cfg["padkeypoints_args"]),
        heatmap_prep.CenterAndScaleNormalizer(**cfg["norm_args"]),
    ]
    return T.Compose(steps)
|
||||
|
||||
|
||||
def iter_pose_paths(pose_root: str) -> list[str]:
    """Return every pose .pkl three directory levels below pose_root, sorted."""
    pattern = os.path.join(pose_root, "*/*/*/*.pkl")
    matches = glob(pattern)
    matches.sort()
    return matches
|
||||
|
||||
|
||||
def read_pose(pose_path: str) -> Float[FloatArray, "frames joints channels"]:
    """Deserialize a pose .pkl file and coerce its contents to float32.

    NOTE(review): pickle.load on untrusted files can execute arbitrary code;
    only run this on trusted dataset dumps.
    """
    with open(pose_path, "rb") as fh:
        raw = pickle.load(fh)
    return np.asarray(raw, dtype=np.float32)
|
||||
|
||||
|
||||
def compute_raw_pav(
    normalized_pose: Float[FloatArray, "frames joints channels"],
) -> Float[FloatArray, "pairs metrics"]:
    """Compute per-pair [vertical, midline, angular] asymmetry metrics.

    For each left/right joint pair in JOINT_PAIRS, three per-frame metrics
    are measured on the normalized 2D coordinates and then reduced over
    frames with an outlier-robust mean (iqr_mean):
      - vertical: |y_left - y_right|
      - midline:  |pair midpoint x - hip-center x|
      - angular:  |arctan(dy / dx)| of the left-right segment
    """
    xy = normalized_pose[..., :2]
    # Pelvis midline x per frame: mean x of the two hip joints (COCO 11, 12).
    pelvis_x = xy[:, [11, 12], 0].mean(axis=1)

    pav = np.zeros((len(JOINT_PAIRS), 3), dtype=np.float32)
    for row, (l_idx, r_idx) in enumerate(JOINT_PAIRS):
        l_xy = xy[:, l_idx]
        r_xy = xy[:, r_idx]
        delta_y = l_xy[:, 1] - r_xy[:, 1]
        delta_x = l_xy[:, 0] - r_xy[:, 0]
        # Bound |dx| away from zero while preserving its sign so the
        # arctan argument stays finite.
        guarded_dx = np.where(np.abs(delta_x) < EPS, np.where(delta_x < 0, -EPS, EPS), delta_x)

        vertical = np.abs(delta_y)
        midline = np.abs((l_xy[:, 0] + r_xy[:, 0]) / 2.0 - pelvis_x)
        angular = np.abs(np.arctan(delta_y / guarded_dx))

        per_frame = np.stack([vertical, midline, angular], axis=-1)
        pav[row] = iqr_mean(per_frame)

    return pav
|
||||
|
||||
|
||||
def iqr_mean(values: Float[FloatArray, "frames metrics"]) -> Float[FloatArray, "metrics"]:
    """Mean of each metric column after discarding 1.5*IQR outliers.

    Falls back to the plain mean of a column when the IQR fence would
    reject every sample.
    """
    means: list[float] = []
    for column in values.T:
        q1, q3 = np.percentile(column, [25, 75])
        fence = 1.5 * (q3 - q1)
        kept = column[(column >= q1 - fence) & (column <= q3 + fence)]
        if kept.size == 0:
            kept = column
        means.append(float(kept.mean()))
    return np.asarray(means, dtype=np.float32)
|
||||
|
||||
|
||||
def normalize_pav(
    raw_pav: Float[FloatArray, "pairs metrics"],
    pav_min: Float[FloatArray, "pairs metrics"],
    pav_max: Float[FloatArray, "pairs metrics"],
) -> Float[FloatArray, "pairs metrics"]:
    """Min-max scale raw PAV into [0, 1] using the dataset-wide stats.

    Degenerate ranges (max == min) are floored to EPS so the division is
    safe; results are clipped to [0, 1] and cast to float32.
    """
    span = np.maximum(pav_max - pav_min, EPS)
    scaled = (raw_pav - pav_min) / span
    return np.clip(scaled, 0.0, 1.0).astype(np.float32)
|
||||
|
||||
|
||||
def rel_seq_parts(pose_path: str) -> list[str]:
    """Return the three directory names directly above the pose file.

    For .../pid/sequence_type/view/pose.pkl this yields
    [pid, sequence_type, view].
    """
    components = os.path.normpath(pose_path).split(os.sep)
    return components[-4:-1]
|
||||
|
||||
|
||||
def build_stats_mask(records: list[SequenceRecord], stats_partition: str | None) -> NDArray[np.bool_]:
    """Boolean mask of the records that contribute to PAV min/max stats.

    Without a partition file every record participates; with one, only
    records whose pid appears in the partition's TRAIN_SET list do.
    """
    if stats_partition is None:
        return np.ones(len(records), dtype=bool)

    with open(stats_partition, "r", encoding="utf-8") as handle:
        train_ids = set(json.load(handle)["TRAIN_SET"])
    mask = [record["pid"] in train_ids for record in records]
    return np.asarray(mask, dtype=bool)
|
||||
|
||||
|
||||
def main() -> None:
    """Build the DRF runtime dataset in two passes.

    Pass 1 computes raw PAV metrics per sequence; the min/max over the
    (optionally partition-restricted) sequences is then saved. Pass 2
    re-reads each pose, writes its heatmap and its normalized,
    per-frame-repeated PAV under an output tree that mirrors the input
    pid/type/view layout.
    """
    args = get_args()
    os.makedirs(args.output_path, exist_ok=True)

    # Both branches are driven by the same YAML config: a keypoint
    # normalizer for PAV and the full heatmap transform for the
    # skeleton-map branch.
    heatmap_cfg = load_heatmap_cfg(args.heatmap_cfg_path)
    pose_transform = build_pose_transform(heatmap_cfg)
    heatmap_transform = heatmap_prep.GenerateHeatmapTransform(
        coco18tococo17_args=heatmap_cfg["coco18tococo17_args"],
        padkeypoints_args=heatmap_cfg["padkeypoints_args"],
        norm_args=heatmap_cfg["norm_args"],
        heatmap_generator_args=heatmap_cfg["heatmap_generator_args"],
        align_args=heatmap_cfg["align_args"],
    )

    pose_paths = iter_pose_paths(args.pose_data_path)
    if not pose_paths:
        raise FileNotFoundError(f"No pose .pkl files found under {args.pose_data_path}")

    # Pass 1: raw (un-normalized) PAV per sequence, plus the metadata
    # needed to write outputs later without re-deriving paths.
    records: list[SequenceRecord] = []
    for pose_path in tqdm(pose_paths, desc="Pass 1/2: computing raw PAV"):
        pose = read_pose(pose_path)
        normalized_pose = pose_transform(pose)
        records.append({
            "pose_path": pose_path,
            "pid": rel_seq_parts(pose_path)[0],
            "seq_parts": rel_seq_parts(pose_path),
            "frames": pose.shape[0],
            "raw_pav": compute_raw_pav(normalized_pose),
        })

    # PAV min/max stats come only from the masked (e.g. TRAIN_SET) records
    # so evaluation splits cannot leak into the normalization constants.
    stats_mask = build_stats_mask(records, args.stats_partition)
    if not stats_mask.any():
        raise ValueError("No sequences matched the requested stats partition.")

    pav_stack = np.stack([record["raw_pav"] for record, use in zip(records, stats_mask) if use], axis=0)
    pav_min = pav_stack.min(axis=0).astype(np.float32)
    pav_max = pav_stack.max(axis=0).astype(np.float32)

    # Persist the stats so downstream consumers can reproduce (or invert)
    # the normalization.
    stats_path = os.path.join(args.output_path, "pav_stats.pkl")
    with open(stats_path, "wb") as handle:
        pickle.dump({
            "joint_pairs": JOINT_PAIRS,
            "pav_min": pav_min,
            "pav_max": pav_max,
            "stats_partition": args.stats_partition,
        }, handle)

    # Pass 2: write one heatmap pickle and one per-frame PAV pickle per
    # sequence, mirroring the input directory layout under output_path.
    for record in tqdm(records, desc="Pass 2/2: writing DRF dataset"):
        pose = read_pose(record["pose_path"])
        heatmap = heatmap_transform(pose)
        pav = normalize_pav(record["raw_pav"], pav_min, pav_max)
        # The sequence-level PAV is constant across frames; repeat it so the
        # loader can index it frame-by-frame alongside the heatmaps.
        pav_seq = np.repeat(pav[np.newaxis, ...], record["frames"], axis=0)

        save_dir = os.path.join(args.output_path, *record["seq_parts"])
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "0_heatmap.pkl"), "wb") as handle:
            pickle.dump(heatmap, handle)
        with open(os.path.join(save_dir, "1_pav.pkl"), "wb") as handle:
            pickle.dump(pav_seq, handle)
|
||||
|
||||
|
||||
# Script entry point: run the full two-pass pipeline when executed directly.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user