# OpenGait/scripts/analyze_scoliosis_dataset.py
from __future__ import annotations
import json
import pickle
import sys
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable
import click
import numpy as np
from jaxtyping import Float
from numpy.typing import NDArray
# Make the repository root importable so `datasets.*` resolves when this
# script is executed directly (e.g. `python scripts/analyze_scoliosis_dataset.py`).
REPO_ROOT = Path(__file__).resolve().parent.parent
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))
from datasets import pretreatment_scoliosis_drf as drf_prep

# Default input/output locations; each is overridable via a CLI option in `main`.
DEFAULT_POSE_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-pose-pkl")
DEFAULT_HEATMAP_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign")
DEFAULT_PARTITION_PATH = REPO_ROOT / "datasets/Scoliosis1K/Scoliosis1K_118.json"
DEFAULT_HEATMAP_CFG_PATH = REPO_ROOT / "configs/drf/pretreatment_heatmap_drf_sigma15_joint8.yaml"
DEFAULT_REPORT_PATH = REPO_ROOT / "research/scoliosis_dataset_analysis_118_sharedalign.md"
DEFAULT_JSON_PATH = REPO_ROOT / "research/scoliosis_dataset_analysis_118_sharedalign.json"

EPS = 1e-6  # numerical floor used to avoid division by (near-)zero
THRESHOLD = 13.0  # heatmap intensity above which a pixel counts as "active"
SIDE_CUT = 10  # columns measured on each side (the margin noted in the report as removed by BaseSilCuttingTransform)
LABEL_TO_INT = {"negative": 0, "neutral": 1, "positive": 2}

# Convenience alias for float32 arrays used in jaxtyping annotations below.
FloatArray = NDArray[np.float32]
@dataclass(frozen=True)
class SequenceKey:
    """Identity of one sequence, parsed from a path `<root>/<pid>/<label>/<seq>/<file>`.

    Frozen so instances are hashable and safe to use as dict/set keys.
    """

    # See `sequence_key_from_path` for how these map onto path components.
    pid: str
    label: str
    seq: str
@dataclass
class RunningStats:
    """Streaming accumulator for a weighted mean: running sum plus sample count."""

    total: float = 0.0
    count: int = 0

    def update(self, value: float, n: int = 1) -> None:
        """Fold in `value` observed `n` times."""
        self.count = self.count + n
        self.total = self.total + n * value

    @property
    def mean(self) -> float:
        """Mean of all folded-in values; 0.0 before any update."""
        return self.total / max(1, self.count)
@dataclass(frozen=True)
class AnalysisArgs:
    """Resolved CLI configuration consumed by `analyze` and `main`."""

    pose_root: Path  # root of the raw pose pickle tree
    heatmap_root: Path  # root of the exported heatmap/PAV pickle tree
    partition_path: Path  # JSON file with TRAIN_SET / TEST_SET pid lists
    heatmap_cfg_path: Path  # YAML config fed to drf_prep.load_heatmap_cfg
    report_path: Path  # where the Markdown report is written
    json_path: Path  # where the raw JSON results are written
    report_title: str  # heading placed at the top of the Markdown report
def load_partition_ids(partition_path: Path) -> tuple[set[str], set[str]]:
    """Read the split JSON and return (train pids, test pids) as sets."""
    partition = json.loads(partition_path.read_text(encoding="utf-8"))
    train_ids = set(partition["TRAIN_SET"])
    test_ids = set(partition["TEST_SET"])
    return train_ids, test_ids
def sequence_key_from_path(path: Path) -> SequenceKey:
    """Derive the sequence identity from a path ending `.../pid/label/seq/file`."""
    pid, label, seq = path.parts[-4:-1]
    return SequenceKey(pid=pid, label=label, seq=seq)
def iter_pose_paths(pose_root: Path) -> list[Path]:
    """All pose pickles three directory levels deep, in deterministic order."""
    matches = pose_root.glob("*/*/*/*.pkl")
    return sorted(matches)
def iter_heatmap_paths(heatmap_root: Path) -> list[Path]:
    """All `0_heatmap.pkl` files three directory levels deep, sorted."""
    matches = heatmap_root.glob("*/*/*/0_heatmap.pkl")
    return sorted(matches)
def read_pickle(path: Path) -> object:
    """Deserialize one pickle file (trusted, locally generated data only)."""
    with path.open("rb") as stream:
        return pickle.load(stream)
def bbox_from_mask(mask: NDArray[np.bool_]) -> tuple[float, float, float, float] | None:
rows = np.flatnonzero(mask.any(axis=1))
cols = np.flatnonzero(mask.any(axis=0))
if rows.size == 0 or cols.size == 0:
return None
y0 = int(rows[0])
y1 = int(rows[-1])
x0 = int(cols[0])
x1 = int(cols[-1])
width = float(x1 - x0 + 1)
height = float(y1 - y0 + 1)
center_x = float((x0 + x1) / 2.0)
center_y = float((y0 + y1) / 2.0)
return width, height, center_x, center_y
def sequence_bbox_metrics(
    heatmap: Float[FloatArray, "frames channels height width"],
    threshold: float = THRESHOLD,
    side_cut: int = SIDE_CUT,
    eps: float = EPS,
) -> dict[str, float]:
    """Per-sequence geometry statistics of a two-channel skeleton heatmap.

    Args:
        heatmap: (frames, channels, height, width) intensity maps. Channel 0
            is treated as the limb/bone map and channel 1 as the joint map
            (matches how `analyze` consumes `0_heatmap.pkl`) — TODO confirm
            against the pretreatment export.
        threshold: intensity above which a pixel counts as "active".
        side_cut: number of columns on each side whose intensity mass is
            measured; previously hard-coded to the module constant SIDE_CUT.
        eps: total-mass floor below which a frame is skipped for the
            clipped-mass ratio; previously hard-coded to EPS.

    Returns:
        Dict of scalar means/stds over frames. Frames whose support mask is
        empty contribute nothing to the bbox statistics, and every statistic
        falls back to 0.0 when no frame qualifies.
    """
    # "Support" = union of both channels: per-pixel max over the channel axis.
    support = heatmap.max(axis=1)
    bone = heatmap[:, 0]
    joint = heatmap[:, 1]
    widths: list[float] = []
    heights: list[float] = []
    centers_x: list[float] = []
    centers_y: list[float] = []
    active_fractions: list[float] = []
    cut_mass_ratios: list[float] = []
    bone_joint_dx: list[float] = []
    bone_joint_dy: list[float] = []
    for frame_idx in range(support.shape[0]):
        frame = support[frame_idx]
        mask = frame > threshold
        bbox = bbox_from_mask(mask)
        if bbox is not None:
            width, height, center_x, center_y = bbox
            widths.append(width)
            heights.append(height)
            centers_x.append(center_x)
            centers_y.append(center_y)
            active_fractions.append(float(mask.mean()))
        total_mass = float(frame.sum())
        if total_mass > eps:
            # Intensity mass sitting in the side margins that the silhouette
            # cutting transform would trim away.
            clipped_mass = float(frame[:, :side_cut].sum() + frame[:, -side_cut:].sum())
            cut_mass_ratios.append(clipped_mass / total_mass)
        # Residual registration error between the two channels, measured as
        # the offset between their per-frame bbox centers.
        bone_bbox = bbox_from_mask(bone[frame_idx] > threshold)
        joint_bbox = bbox_from_mask(joint[frame_idx] > threshold)
        if bone_bbox is not None and joint_bbox is not None:
            bone_joint_dx.append(abs(bone_bbox[2] - joint_bbox[2]))
            bone_joint_dy.append(abs(bone_bbox[3] - joint_bbox[3]))

    def safe_mean(values: Iterable[float]) -> float:
        # Mean that tolerates an empty sequence (returns 0.0).
        array = np.asarray(list(values), dtype=np.float32)
        return float(array.mean()) if array.size else 0.0

    def safe_std(values: Iterable[float]) -> float:
        # Population std that tolerates an empty sequence (returns 0.0).
        array = np.asarray(list(values), dtype=np.float32)
        return float(array.std()) if array.size else 0.0

    return {
        "width_mean": safe_mean(widths),
        "height_mean": safe_mean(heights),
        "center_x_std": safe_std(centers_x),
        "center_y_std": safe_std(centers_y),
        "width_std": safe_std(widths),
        "height_std": safe_std(heights),
        "active_fraction_mean": safe_mean(active_fractions),
        "cut_mass_ratio_mean": safe_mean(cut_mass_ratios),
        "bone_joint_dx_mean": safe_mean(bone_joint_dx),
        "bone_joint_dy_mean": safe_mean(bone_joint_dy),
    }
def softmax_rows(logits: NDArray[np.float64]) -> NDArray[np.float64]:
    """Numerically stable row-wise softmax over the last axis of a 2-D array."""
    # Subtracting the row max keeps exp() from overflowing for large logits.
    scaled = np.exp(logits - np.max(logits, axis=1, keepdims=True))
    return scaled / np.sum(scaled, axis=1, keepdims=True)
def fit_softmax_regression(
    x: NDArray[np.float64],
    y: NDArray[np.int64],
    num_classes: int,
    steps: int = 4000,
    lr: float = 0.05,
    reg: float = 1e-4,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Full-batch gradient descent for L2-regularized multinomial logistic regression.

    Starts from zero parameters, so the fit is deterministic. Returns
    (weights, bias) with shapes (features, num_classes) and (num_classes,).
    """
    n_samples = x.shape[0]
    weights = np.zeros((x.shape[1], num_classes), dtype=np.float64)
    bias = np.zeros(num_classes, dtype=np.float64)
    targets = np.eye(num_classes, dtype=np.float64)[y]
    for _ in range(steps):
        logits = x @ weights + bias
        # Numerically stable row-wise softmax (same computation as softmax_rows).
        shifted = logits - logits.max(axis=1, keepdims=True)
        exp = np.exp(shifted)
        probs = exp / exp.sum(axis=1, keepdims=True)
        residual = probs - targets
        # L2 penalty applies to weights only, not the bias (standard practice).
        weights -= lr * ((x.T @ residual) / n_samples + reg * weights)
        bias -= lr * residual.mean(axis=0)
    return weights, bias
def evaluate_predictions(
    y_true: NDArray[np.int64],
    y_pred: NDArray[np.int64],
    num_classes: int,
) -> dict[str, float]:
    """Accuracy plus macro-averaged precision/recall/F1, all as percentages.

    A class with an empty denominator (no predictions / no ground-truth
    instances) scores 0 for that metric rather than raising.
    """
    per_class_precision: list[float] = []
    per_class_recall: list[float] = []
    per_class_f1: list[float] = []
    for cls in range(num_classes):
        actual = y_true == cls
        predicted = y_pred == cls
        tp = int((actual & predicted).sum())
        fp = int((~actual & predicted).sum())
        fn = int((actual & ~predicted).sum())
        precision = tp / max(tp + fp, 1)
        recall = tp / max(tp + fn, 1)
        per_class_precision.append(precision)
        per_class_recall.append(recall)
        # 1e-6 floor (the module-level EPS value) keeps the harmonic mean
        # finite when precision == recall == 0.
        per_class_f1.append(2 * precision * recall / max(precision + recall, 1e-6))
    return {
        "accuracy": 100.0 * float((y_true == y_pred).mean()),
        "macro_precision": 100.0 * float(np.mean(per_class_precision)),
        "macro_recall": 100.0 * float(np.mean(per_class_recall)),
        "macro_f1": 100.0 * float(np.mean(per_class_f1)),
    }
def analyze(args: AnalysisArgs) -> dict[str, object]:
    """Run the dataset audit and return a JSON-serializable results dict.

    Three passes over the data:
      1. pose pickles    -> per-split/class counts and keypoint-confidence stats,
      2. heatmap pickles -> per-sequence bbox/geometry statistics,
      3. PAV pickles     -> per-label means plus a linear softmax probe
                            trained on the train split, evaluated on test.
    """
    train_ids, test_ids = load_partition_ids(args.partition_path)
    heatmap_cfg = drf_prep.load_heatmap_cfg(str(args.heatmap_cfg_path))
    # NOTE(review): pose_transform and test_ids are computed but never used
    # below — consider removing them or wiring them into a check.
    pose_transform = drf_prep.build_pose_transform(heatmap_cfg)
    split_label_counts: dict[str, dict[str, int]] = {
        "train": defaultdict(int),
        "test": defaultdict(int),
    }
    pose_quality: dict[str, dict[str, RunningStats]] = {
        "train": defaultdict(RunningStats),
        "test": defaultdict(RunningStats),
    }
    valid_ratio: dict[str, dict[str, RunningStats]] = {
        "train": defaultdict(RunningStats),
        "test": defaultdict(RunningStats),
    }
    # Pass 1: raw pose quality. Any pid not listed in TRAIN_SET is treated
    # as test (membership in TEST_SET is not verified here).
    for pose_path in iter_pose_paths(args.pose_root):
        key = sequence_key_from_path(pose_path)
        split = "train" if key.pid in train_ids else "test"
        split_label_counts[split][key.label] += 1
        pose = drf_prep.read_pose(str(pose_path))
        # Assumes the third pose component is per-keypoint confidence — TODO
        # confirm against drf_prep.read_pose. Poses without it count as
        # fully confident.
        conf = pose[..., 2] if pose.shape[-1] >= 3 else np.ones(pose.shape[:-1], dtype=np.float32)
        pose_quality[split][key.label].update(float(conf.mean()))
        valid_ratio[split][key.label].update(float((conf > 0.05).mean()))
    # Pass 2 + 3: heatmap geometry and sequence-level PAV features.
    heatmap_metrics: dict[str, list[float]] = defaultdict(list)
    pav_vectors_train: list[NDArray[np.float64]] = []
    pav_vectors_test: list[NDArray[np.float64]] = []
    labels_train: list[int] = []
    labels_test: list[int] = []
    pav_means: dict[str, list[float]] = defaultdict(list)
    for heatmap_path in iter_heatmap_paths(args.heatmap_root):
        key = sequence_key_from_path(heatmap_path)
        split = "train" if key.pid in train_ids else "test"
        heatmap = np.asarray(read_pickle(heatmap_path), dtype=np.float32)
        metrics = sequence_bbox_metrics(heatmap)
        # Each metric is collected both per-split and globally ("all.*").
        for metric_name, metric_value in metrics.items():
            heatmap_metrics[f"{split}.{metric_name}"].append(metric_value)
            heatmap_metrics[f"all.{metric_name}"].append(metric_value)
        # The PAV pickle is assumed to live next to the heatmap pickle and to
        # hold a per-frame stack whose first frame is representative — TODO
        # confirm against the pretreatment export.
        pav_path = heatmap_path.with_name("1_pav.pkl")
        pav_seq = np.asarray(read_pickle(pav_path), dtype=np.float32)
        pav_vector = pav_seq[0].reshape(-1).astype(np.float64)
        pav_means[key.label].append(float(pav_vector.mean()))
        if split == "train":
            pav_vectors_train.append(pav_vector)
            labels_train.append(LABEL_TO_INT[key.label])
        else:
            pav_vectors_test.append(pav_vector)
            labels_test.append(LABEL_TO_INT[key.label])
    # NOTE(review): assumes both splits are non-empty; np.stack raises
    # ValueError on an empty list.
    x_train = np.stack(pav_vectors_train, axis=0)
    x_test = np.stack(pav_vectors_test, axis=0)
    y_train = np.asarray(labels_train, dtype=np.int64)
    y_test = np.asarray(labels_test, dtype=np.int64)
    # Standardize with train statistics only; EPS guards constant features.
    mean = x_train.mean(axis=0, keepdims=True)
    std = np.maximum(x_train.std(axis=0, keepdims=True), EPS)
    x_train_std = (x_train - mean) / std
    x_test_std = (x_test - mean) / std
    weights, bias = fit_softmax_regression(x_train_std, y_train, num_classes=3)
    y_pred = np.argmax(x_test_std @ weights + bias, axis=1).astype(np.int64)
    pav_classifier = evaluate_predictions(y_test, y_pred, num_classes=3)
    # Assemble the JSON-serializable summary consumed by format_report.
    results: dict[str, object] = {
        "report_title": args.report_title,
        "pose_root": str(args.pose_root),
        "heatmap_root": str(args.heatmap_root),
        "partition_path": str(args.partition_path),
        "heatmap_cfg_path": str(args.heatmap_cfg_path),
        "split_label_counts": split_label_counts,
        "pose_confidence_mean": {
            split: {label: stats.mean for label, stats in per_label.items()}
            for split, per_label in pose_quality.items()
        },
        "pose_valid_ratio_mean": {
            split: {label: stats.mean for label, stats in per_label.items()}
            for split, per_label in valid_ratio.items()
        },
        "pav_label_means": {
            label: float(np.mean(values))
            for label, values in pav_means.items()
        },
        "pav_softmax_probe": pav_classifier,
        "heatmap_metrics": {
            key: {
                "mean": float(np.mean(values)),
                "p95": float(np.percentile(values, 95)),
            }
            for key, values in heatmap_metrics.items()
        },
    }
    return results
def format_report(results: dict[str, object]) -> str:
    """Render the `analyze` results dict as a Markdown report string."""
    report_title = str(results["report_title"])
    pose_root = str(results["pose_root"])
    heatmap_root = str(results["heatmap_root"])
    partition_path = str(results["partition_path"])
    heatmap_cfg_path = str(results["heatmap_cfg_path"])
    split_counts = results["split_label_counts"]
    pose_conf = results["pose_confidence_mean"]
    pose_valid = results["pose_valid_ratio_mean"]
    heat = results["heatmap_metrics"]
    pav_probe = results["pav_softmax_probe"]
    pav_means = results["pav_label_means"]

    def heat_stat(name: str) -> tuple[float, float]:
        # (mean, p95) of a metric aggregated over all sequences ("all.*" key).
        entry = heat[f"all.{name}"]
        return entry["mean"], entry["p95"]

    center_x_std_mean, center_x_std_p95 = heat_stat("center_x_std")
    center_y_std_mean, center_y_std_p95 = heat_stat("center_y_std")
    width_std_mean, width_std_p95 = heat_stat("width_std")
    height_std_mean, height_std_p95 = heat_stat("height_std")
    cut_ratio_mean, cut_ratio_p95 = heat_stat("cut_mass_ratio_mean")
    bone_joint_dx_mean, bone_joint_dx_p95 = heat_stat("bone_joint_dx_mean")
    bone_joint_dy_mean, bone_joint_dy_p95 = heat_stat("bone_joint_dy_mean")
    width_mean, width_p95 = heat_stat("width_mean")
    height_mean, height_p95 = heat_stat("height_mean")
    active_fraction_mean, active_fraction_p95 = heat_stat("active_fraction_mean")
    # The template below is emitted verbatim into the report; keep its lines
    # at column 0.
    return f"""# {report_title}
Inputs:
- pose root: `{pose_root}`
- heatmap root: `{heatmap_root}`
- partition: `{partition_path}`
- heatmap cfg: `{heatmap_cfg_path}`
## Split
Train counts:
- negative: {split_counts["train"]["negative"]}
- neutral: {split_counts["train"]["neutral"]}
- positive: {split_counts["train"]["positive"]}
Test counts:
- negative: {split_counts["test"]["negative"]}
- neutral: {split_counts["test"]["neutral"]}
- positive: {split_counts["test"]["positive"]}
## Raw pose quality
Mean keypoint confidence by split/class:
- train negative: {pose_conf["train"]["negative"]:.4f}
- train neutral: {pose_conf["train"]["neutral"]:.4f}
- train positive: {pose_conf["train"]["positive"]:.4f}
- test negative: {pose_conf["test"]["negative"]:.4f}
- test neutral: {pose_conf["test"]["neutral"]:.4f}
- test positive: {pose_conf["test"]["positive"]:.4f}
Mean valid-joint ratio (`conf > 0.05`) by split/class:
- train negative: {pose_valid["train"]["negative"]:.4f}
- train neutral: {pose_valid["train"]["neutral"]:.4f}
- train positive: {pose_valid["train"]["positive"]:.4f}
- test negative: {pose_valid["test"]["negative"]:.4f}
- test neutral: {pose_valid["test"]["neutral"]:.4f}
- test positive: {pose_valid["test"]["positive"]:.4f}
## PAV signal
Mean normalized PAV value by label:
- negative: {pav_means["negative"]:.4f}
- neutral: {pav_means["neutral"]:.4f}
- positive: {pav_means["positive"]:.4f}
Train-on-train / test-on-test linear softmax probe over sequence-level PAV:
- accuracy: {pav_probe["accuracy"]:.2f}%
- macro precision: {pav_probe["macro_precision"]:.2f}%
- macro recall: {pav_probe["macro_recall"]:.2f}%
- macro F1: {pav_probe["macro_f1"]:.2f}%
## Heatmap geometry
Combined support bbox stats over all sequences:
- width mean / p95: {width_mean:.2f} / {width_p95:.2f}
- height mean / p95: {height_mean:.2f} / {height_p95:.2f}
- active fraction mean / p95: {active_fraction_mean:.4f} / {active_fraction_p95:.4f}
Per-sequence temporal jitter (std over frames):
- center-x std mean / p95: {center_x_std_mean:.3f} / {center_x_std_p95:.3f}
- center-y std mean / p95: {center_y_std_mean:.3f} / {center_y_std_p95:.3f}
- width std mean / p95: {width_std_mean:.3f} / {width_std_p95:.3f}
- height std mean / p95: {height_std_mean:.3f} / {height_std_p95:.3f}
Residual limb-vs-joint bbox-center mismatch after shared alignment:
- dx mean / p95: {bone_joint_dx_mean:.3f} / {bone_joint_dx_p95:.3f}
- dy mean / p95: {bone_joint_dy_mean:.3f} / {bone_joint_dy_p95:.3f}
Estimated intensity mass in the columns removed by `BaseSilCuttingTransform`:
- mean clipped-mass ratio: {cut_ratio_mean:.4f}
- p95 clipped-mass ratio: {cut_ratio_p95:.4f}
## Reading
- The raw pose data does not look broken. Confidence and valid-joint ratios are high and similar across classes.
- The sequence-level PAV still carries useful label signal, so the dataset is not devoid of scoliosis information.
- The limb/joint alignment fix removed the old registration bug; residual channel-center mismatch is now small.
- The remaining suspicious area is the visual branch: the skeleton map still has frame-to-frame bbox jitter, and the support bbox is almost full-height (`~61.5 / 64`) and fairly dense (`~36%` active pixels), which may be washing out subtle asymmetry cues.
- `BaseSilCuttingTransform` does not appear to be the main failure source for this export; the measured mass in the removed side margins is near zero.
- The dataset itself looks usable; the bigger issue still appears to be how the current skeleton-map preprocessing/runtime path presents that data to ScoNet.
"""
@click.command()
@click.option(
    "--pose-root",
    type=click.Path(path_type=Path, file_okay=False),
    default=DEFAULT_POSE_ROOT,
    show_default=True,
)
@click.option(
    "--heatmap-root",
    type=click.Path(path_type=Path, file_okay=False),
    default=DEFAULT_HEATMAP_ROOT,
    show_default=True,
)
@click.option(
    "--partition-path",
    type=click.Path(path_type=Path, dir_okay=False),
    default=DEFAULT_PARTITION_PATH,
    show_default=True,
)
@click.option(
    "--heatmap-cfg-path",
    type=click.Path(path_type=Path, dir_okay=False),
    default=DEFAULT_HEATMAP_CFG_PATH,
    show_default=True,
)
@click.option(
    "--report-path",
    type=click.Path(path_type=Path, dir_okay=False),
    default=DEFAULT_REPORT_PATH,
    show_default=True,
)
@click.option(
    "--json-path",
    type=click.Path(path_type=Path, dir_okay=False),
    default=DEFAULT_JSON_PATH,
    show_default=True,
)
@click.option(
    "--report-title",
    type=str,
    default="Scoliosis1K Dataset Analysis (1:1:8, shared-align skeleton maps)",
    show_default=True,
)
def main(
    pose_root: Path,
    heatmap_root: Path,
    partition_path: Path,
    heatmap_cfg_path: Path,
    report_path: Path,
    json_path: Path,
    report_title: str,
) -> None:
    """Run the Scoliosis1K audit and write the Markdown and JSON reports."""
    # Bundle the CLI options into the frozen config object `analyze` expects.
    args = AnalysisArgs(
        pose_root=pose_root,
        heatmap_root=heatmap_root,
        partition_path=partition_path,
        heatmap_cfg_path=heatmap_cfg_path,
        report_path=report_path,
        json_path=json_path,
        report_title=report_title,
    )
    results = analyze(args)
    args.report_path.write_text(format_report(results), encoding="utf-8")
    args.json_path.write_text(json.dumps(results, indent=2, sort_keys=True), encoding="utf-8")
    print(f"Wrote {args.report_path}")
    print(f"Wrote {args.json_path}")
if __name__ == "__main__":
    main()  # Click parses sys.argv and injects the options into `main`.