"""Offline sanity analysis of a Scoliosis1K DRF export.

Walks the raw pose pickles and the preprocessed heatmap/PAV pickles, computes
per-split label counts, pose-confidence statistics, heatmap bounding-box
geometry/jitter metrics, and a linear softmax probe over sequence-level PAV
vectors, then writes a markdown report and a JSON dump of the raw numbers.

Run as a CLI (see ``main``); all inputs default to the 1:1:8 shared-align
export paths below.
"""

from __future__ import annotations

import json
import pickle
import sys
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable

import click
import numpy as np
from jaxtyping import Float
from numpy.typing import NDArray

# Make the repository root importable so `datasets.pretreatment_scoliosis_drf`
# resolves when this script is run directly (it lives two levels down).
REPO_ROOT = Path(__file__).resolve().parent.parent
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))

from datasets import pretreatment_scoliosis_drf as drf_prep

# Default input/output locations for the 1:1:8, shared-align export.
DEFAULT_POSE_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-pose-pkl")
DEFAULT_HEATMAP_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign")
DEFAULT_PARTITION_PATH = REPO_ROOT / "datasets/Scoliosis1K/Scoliosis1K_118.json"
DEFAULT_HEATMAP_CFG_PATH = REPO_ROOT / "configs/drf/pretreatment_heatmap_drf_sigma15_joint8.yaml"
DEFAULT_REPORT_PATH = REPO_ROOT / "research/scoliosis_dataset_analysis_118_sharedalign.md"
DEFAULT_JSON_PATH = REPO_ROOT / "research/scoliosis_dataset_analysis_118_sharedalign.json"

EPS = 1e-6  # guard against division by zero in ratios / normalization
# Intensity cutoff for "active" heatmap pixels.
# NOTE(review): assumes heatmaps are stored on a 0-255-ish intensity scale —
# confirm against the pretreatment export.
THRESHOLD = 13.0
# Number of side columns assumed to be removed by `BaseSilCuttingTransform`;
# used to estimate how much heatmap mass that cut would discard.
SIDE_CUT = 10
LABEL_TO_INT = {"negative": 0, "neutral": 1, "positive": 2}

FloatArray = NDArray[np.float32]


@dataclass(frozen=True)
class SequenceKey:
    """Identity of one sequence: subject id, class label, sequence name."""

    pid: str
    label: str
    seq: str


@dataclass
class RunningStats:
    """Weighted running mean accumulator (no variance tracking)."""

    total: float = 0.0
    count: int = 0

    def update(self, value: float, n: int = 1) -> None:
        """Fold in `value` with weight `n`."""
        self.total += value * n
        self.count += n

    @property
    def mean(self) -> float:
        # max(count, 1) keeps an empty accumulator at 0.0 instead of dividing
        # by zero.
        return self.total / max(self.count, 1)


@dataclass(frozen=True)
class AnalysisArgs:
    """Resolved CLI arguments for one analysis run."""

    pose_root: Path
    heatmap_root: Path
    partition_path: Path
    heatmap_cfg_path: Path
    report_path: Path
    json_path: Path
    report_title: str


def load_partition_ids(partition_path: Path) -> tuple[set[str], set[str]]:
    """Return (train_pids, test_pids) from the partition JSON file."""
    with partition_path.open("r", encoding="utf-8") as handle:
        partition = json.load(handle)
    return set(partition["TRAIN_SET"]), set(partition["TEST_SET"])


def sequence_key_from_path(path: Path) -> SequenceKey:
    """Derive a SequenceKey from a `<pid>/<label>/<seq>/<file>` path tail."""
    parts = path.parts
    # Index from the end so any root prefix depth works.
    return SequenceKey(pid=parts[-4], label=parts[-3], seq=parts[-2])


def iter_pose_paths(pose_root: Path) -> list[Path]:
    """All raw pose pickles under `pose_root`, in deterministic order."""
    return sorted(pose_root.glob("*/*/*/*.pkl"))


def iter_heatmap_paths(heatmap_root: Path) -> list[Path]:
    """All exported heatmap pickles (`0_heatmap.pkl`), in deterministic order."""
    return sorted(heatmap_root.glob("*/*/*/0_heatmap.pkl"))


def read_pickle(path: Path) -> object:
    """Unpickle one file. Only used on our own exported data (trusted input)."""
    with path.open("rb") as handle:
        return pickle.load(handle)


def bbox_from_mask(mask: NDArray[np.bool_]) -> tuple[float, float, float, float] | None:
    """Tight bounding box of the True pixels in a 2-D mask.

    Returns (width, height, center_x, center_y) in pixels, or None when the
    mask is empty.
    """
    rows = np.flatnonzero(mask.any(axis=1))
    cols = np.flatnonzero(mask.any(axis=0))
    if rows.size == 0 or cols.size == 0:
        return None
    y0 = int(rows[0])
    y1 = int(rows[-1])
    x0 = int(cols[0])
    x1 = int(cols[-1])
    # +1: bbox edges are inclusive pixel indices.
    width = float(x1 - x0 + 1)
    height = float(y1 - y0 + 1)
    center_x = float((x0 + x1) / 2.0)
    center_y = float((y0 + y1) / 2.0)
    return width, height, center_x, center_y


def sequence_bbox_metrics(
    heatmap: Float[FloatArray, "frames channels height width"],
    threshold: float = THRESHOLD,
) -> dict[str, float]:
    """Per-sequence geometry statistics of a (frames, channels, H, W) heatmap.

    Thresholds each frame's channel-max "support" map at `threshold`, then
    aggregates bbox size/center, active-pixel fraction, the intensity mass in
    the SIDE_CUT outermost columns, and the bbox-center mismatch between
    channel 0 and channel 1.

    NOTE(review): channel 0 is treated as limbs/bones and channel 1 as joints
    here — confirm against the heatmap export's channel order.
    """
    # Combined support over all channels; per-channel views for the mismatch.
    support = heatmap.max(axis=1)
    bone = heatmap[:, 0]
    joint = heatmap[:, 1]
    widths: list[float] = []
    heights: list[float] = []
    centers_x: list[float] = []
    centers_y: list[float] = []
    active_fractions: list[float] = []
    cut_mass_ratios: list[float] = []
    bone_joint_dx: list[float] = []
    bone_joint_dy: list[float] = []
    for frame_idx in range(support.shape[0]):
        frame = support[frame_idx]
        mask = frame > threshold
        bbox = bbox_from_mask(mask)
        # bbox stats only exist for frames with at least one active pixel.
        if bbox is not None:
            width, height, center_x, center_y = bbox
            widths.append(width)
            heights.append(height)
            centers_x.append(center_x)
            centers_y.append(center_y)
        active_fractions.append(float(mask.mean()))
        total_mass = float(frame.sum())
        if total_mass > EPS:
            # Mass that would fall in the SIDE_CUT leftmost/rightmost columns
            # removed by `BaseSilCuttingTransform`.
            clipped_mass = float(frame[:, :SIDE_CUT].sum() + frame[:, -SIDE_CUT:].sum())
            cut_mass_ratios.append(clipped_mass / total_mass)
        bone_bbox = bbox_from_mask(bone[frame_idx] > threshold)
        joint_bbox = bbox_from_mask(joint[frame_idx] > threshold)
        if bone_bbox is not None and joint_bbox is not None:
            # Indices 2/3 of the bbox tuple are center_x/center_y.
            bone_joint_dx.append(abs(bone_bbox[2] - joint_bbox[2]))
            bone_joint_dy.append(abs(bone_bbox[3] - joint_bbox[3]))

    # Local helpers: mean/std that degrade to 0.0 on empty input so a fully
    # blank sequence still yields a complete metrics dict.
    def safe_mean(values: Iterable[float]) -> float:
        array = np.asarray(list(values), dtype=np.float32)
        return float(array.mean()) if array.size else 0.0

    def safe_std(values: Iterable[float]) -> float:
        array = np.asarray(list(values), dtype=np.float32)
        return float(array.std()) if array.size else 0.0

    return {
        "width_mean": safe_mean(widths),
        "height_mean": safe_mean(heights),
        "center_x_std": safe_std(centers_x),
        "center_y_std": safe_std(centers_y),
        "width_std": safe_std(widths),
        "height_std": safe_std(heights),
        "active_fraction_mean": safe_mean(active_fractions),
        "cut_mass_ratio_mean": safe_mean(cut_mass_ratios),
        "bone_joint_dx_mean": safe_mean(bone_joint_dx),
        "bone_joint_dy_mean": safe_mean(bone_joint_dy),
    }


def softmax_rows(logits: NDArray[np.float64]) -> NDArray[np.float64]:
    """Row-wise softmax, shifted by the row max for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def fit_softmax_regression(
    x: NDArray[np.float64],
    y: NDArray[np.int64],
    num_classes: int,
    steps: int = 4000,
    lr: float = 0.05,
    reg: float = 1e-4,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Fit multinomial logistic regression by full-batch gradient descent.

    Args:
        x: (samples, features) design matrix; expected to be standardized.
        y: (samples,) integer class labels in [0, num_classes).
        num_classes: number of output classes.
        steps: fixed number of gradient steps (no early stopping).
        lr: learning rate.
        reg: L2 penalty applied to the weights (not the bias).

    Returns:
        (weights, bias) with shapes (features, num_classes) and (num_classes,).
    """
    weights = np.zeros((x.shape[1], num_classes), dtype=np.float64)
    bias = np.zeros(num_classes, dtype=np.float64)
    one_hot = np.eye(num_classes, dtype=np.float64)[y]
    for _ in range(steps):
        logits = x @ weights + bias
        probs = softmax_rows(logits)
        # Gradient of cross-entropy w.r.t. logits is simply probs - one_hot.
        error = probs - one_hot
        grad_w = (x.T @ error) / x.shape[0] + reg * weights
        grad_b = error.mean(axis=0)
        weights -= lr * grad_w
        bias -= lr * grad_b
    return weights, bias


def evaluate_predictions(
    y_true: NDArray[np.int64],
    y_pred: NDArray[np.int64],
    num_classes: int,
) -> dict[str, float]:
    """Accuracy and macro precision/recall/F1, all as percentages.

    Degenerate classes (no predictions or no ground-truth members) contribute
    0 to the macro averages via the max(..., 1) denominators.
    """
    accuracy = float((y_true == y_pred).mean())
    precisions: list[float] = []
    recalls: list[float] = []
    f1s: list[float] = []
    for class_id in range(num_classes):
        tp = int(((y_true == class_id) & (y_pred == class_id)).sum())
        fp = int(((y_true != class_id) & (y_pred == class_id)).sum())
        fn = int(((y_true == class_id) & (y_pred != class_id)).sum())
        precision = tp / max(tp + fp, 1)
        recall = tp / max(tp + fn, 1)
        f1 = 2 * precision * recall / max(precision + recall, EPS)
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    return {
        "accuracy": 100.0 * accuracy,
        "macro_precision": 100.0 * float(np.mean(precisions)),
        "macro_recall": 100.0 * float(np.mean(recalls)),
        "macro_f1": 100.0 * float(np.mean(f1s)),
    }


def analyze(args: AnalysisArgs) -> dict[str, object]:
    """Run the full dataset analysis and return a JSON-serializable dict.

    Two passes: (1) raw pose pickles -> label counts and confidence stats;
    (2) exported heatmap/PAV pickles -> geometry metrics and a softmax probe
    on the first-frame PAV vector of each sequence.
    """
    train_ids, test_ids = load_partition_ids(args.partition_path)
    heatmap_cfg = drf_prep.load_heatmap_cfg(str(args.heatmap_cfg_path))
    # NOTE(review): `pose_transform` is built but never used below — either
    # dead code or kept for its config-validation side effect; confirm.
    pose_transform = drf_prep.build_pose_transform(heatmap_cfg)
    # Per split, per label accumulators for the raw-pose pass.
    split_label_counts: dict[str, dict[str, int]] = {
        "train": defaultdict(int),
        "test": defaultdict(int),
    }
    pose_quality: dict[str, dict[str, RunningStats]] = {
        "train": defaultdict(RunningStats),
        "test": defaultdict(RunningStats),
    }
    valid_ratio: dict[str, dict[str, RunningStats]] = {
        "train": defaultdict(RunningStats),
        "test": defaultdict(RunningStats),
    }
    for pose_path in iter_pose_paths(args.pose_root):
        key = sequence_key_from_path(pose_path)
        # Any pid not in TRAIN_SET is treated as test.
        split = "train" if key.pid in train_ids else "test"
        split_label_counts[split][key.label] += 1
        pose = drf_prep.read_pose(str(pose_path))
        # Third component is keypoint confidence when present; otherwise
        # treat every joint as fully confident.
        conf = pose[..., 2] if pose.shape[-1] >= 3 else np.ones(pose.shape[:-1], dtype=np.float32)
        pose_quality[split][key.label].update(float(conf.mean()))
        valid_ratio[split][key.label].update(float((conf > 0.05).mean()))
    heatmap_metrics: dict[str, list[float]] = defaultdict(list)
    pav_vectors_train: list[NDArray[np.float64]] = []
    pav_vectors_test: list[NDArray[np.float64]] = []
    labels_train: list[int] = []
    labels_test: list[int] = []
    pav_means: dict[str, list[float]] = defaultdict(list)
    for heatmap_path in iter_heatmap_paths(args.heatmap_root):
        key = sequence_key_from_path(heatmap_path)
        split = "train" if key.pid in train_ids else "test"
        heatmap = np.asarray(read_pickle(heatmap_path), dtype=np.float32)
        metrics = sequence_bbox_metrics(heatmap)
        # Each metric is recorded under its split and under "all".
        for metric_name, metric_value in metrics.items():
            heatmap_metrics[f"{split}.{metric_name}"].append(metric_value)
            heatmap_metrics[f"all.{metric_name}"].append(metric_value)
        # PAV export lives beside the heatmap file.
        pav_path = heatmap_path.with_name("1_pav.pkl")
        pav_seq = np.asarray(read_pickle(pav_path), dtype=np.float32)
        # NOTE(review): only the first PAV frame is used as the sequence-level
        # feature — presumably the PAV is constant per sequence; confirm.
        pav_vector = pav_seq[0].reshape(-1).astype(np.float64)
        pav_means[key.label].append(float(pav_vector.mean()))
        if split == "train":
            pav_vectors_train.append(pav_vector)
            labels_train.append(LABEL_TO_INT[key.label])
        else:
            pav_vectors_test.append(pav_vector)
            labels_test.append(LABEL_TO_INT[key.label])
    # Probe: standardize with train statistics only, fit on train, score test.
    x_train = np.stack(pav_vectors_train, axis=0)
    x_test = np.stack(pav_vectors_test, axis=0)
    y_train = np.asarray(labels_train, dtype=np.int64)
    y_test = np.asarray(labels_test, dtype=np.int64)
    mean = x_train.mean(axis=0, keepdims=True)
    std = np.maximum(x_train.std(axis=0, keepdims=True), EPS)
    x_train_std = (x_train - mean) / std
    x_test_std = (x_test - mean) / std
    weights, bias = fit_softmax_regression(x_train_std, y_train, num_classes=3)
    y_pred = np.argmax(x_test_std @ weights + bias, axis=1).astype(np.int64)
    pav_classifier = evaluate_predictions(y_test, y_pred, num_classes=3)
    results: dict[str, object] = {
        "report_title": args.report_title,
        "pose_root": str(args.pose_root),
        "heatmap_root": str(args.heatmap_root),
        "partition_path": str(args.partition_path),
        "heatmap_cfg_path": str(args.heatmap_cfg_path),
        "split_label_counts": split_label_counts,
        "pose_confidence_mean": {
            split: {label: stats.mean for label, stats in per_label.items()}
            for split, per_label in pose_quality.items()
        },
        "pose_valid_ratio_mean": {
            split: {label: stats.mean for label, stats in per_label.items()}
            for split, per_label in valid_ratio.items()
        },
        "pav_label_means": {
            label: float(np.mean(values)) for label, values in pav_means.items()
        },
        "pav_softmax_probe": pav_classifier,
        "heatmap_metrics": {
            key: {
                "mean": float(np.mean(values)),
                "p95": float(np.percentile(values, 95)),
            }
            for key, values in heatmap_metrics.items()
        },
    }
    return results


def format_report(results: dict[str, object]) -> str:
    """Render the `analyze` results dict as a markdown report string."""
    report_title = str(results["report_title"])
    pose_root = str(results["pose_root"])
    heatmap_root = str(results["heatmap_root"])
    partition_path = str(results["partition_path"])
    heatmap_cfg_path = str(results["heatmap_cfg_path"])
    split_counts = results["split_label_counts"]
    pose_conf = results["pose_confidence_mean"]
    pose_valid = results["pose_valid_ratio_mean"]
    heat = results["heatmap_metrics"]
    pav_probe = results["pav_softmax_probe"]
    pav_means = results["pav_label_means"]

    # Convenience accessor for the dataset-wide ("all.") aggregates.
    def heat_stat(name: str) -> tuple[float, float]:
        entry = heat[f"all.{name}"]
        return entry["mean"], entry["p95"]

    center_x_std_mean, center_x_std_p95 = heat_stat("center_x_std")
    center_y_std_mean, center_y_std_p95 = heat_stat("center_y_std")
    width_std_mean, width_std_p95 = heat_stat("width_std")
    height_std_mean, height_std_p95 = heat_stat("height_std")
    cut_ratio_mean, cut_ratio_p95 = heat_stat("cut_mass_ratio_mean")
    bone_joint_dx_mean, bone_joint_dx_p95 = heat_stat("bone_joint_dx_mean")
    bone_joint_dy_mean, bone_joint_dy_p95 = heat_stat("bone_joint_dy_mean")
    width_mean, width_p95 = heat_stat("width_mean")
    height_mean, height_p95 = heat_stat("height_mean")
    active_fraction_mean, active_fraction_p95 = heat_stat("active_fraction_mean")
    return f"""# {report_title}

Inputs:

- pose root: `{pose_root}`
- heatmap root: `{heatmap_root}`
- partition: `{partition_path}`
- heatmap cfg: `{heatmap_cfg_path}`

## Split

Train counts:

- negative: {split_counts["train"]["negative"]}
- neutral: {split_counts["train"]["neutral"]}
- positive: {split_counts["train"]["positive"]}

Test counts:

- negative: {split_counts["test"]["negative"]}
- neutral: {split_counts["test"]["neutral"]}
- positive: {split_counts["test"]["positive"]}

## Raw pose quality

Mean keypoint confidence by split/class:

- train negative: {pose_conf["train"]["negative"]:.4f}
- train neutral: {pose_conf["train"]["neutral"]:.4f}
- train positive: {pose_conf["train"]["positive"]:.4f}
- test negative: {pose_conf["test"]["negative"]:.4f}
- test neutral: {pose_conf["test"]["neutral"]:.4f}
- test positive: {pose_conf["test"]["positive"]:.4f}

Mean valid-joint ratio (`conf > 0.05`) by split/class:

- train negative: {pose_valid["train"]["negative"]:.4f}
- train neutral: {pose_valid["train"]["neutral"]:.4f}
- train positive: {pose_valid["train"]["positive"]:.4f}
- test negative: {pose_valid["test"]["negative"]:.4f}
- test neutral: {pose_valid["test"]["neutral"]:.4f}
- test positive: {pose_valid["test"]["positive"]:.4f}

## PAV signal

Mean normalized PAV value by label:

- negative: {pav_means["negative"]:.4f}
- neutral: {pav_means["neutral"]:.4f}
- positive: {pav_means["positive"]:.4f}

Train-on-train / test-on-test linear softmax probe over sequence-level PAV:

- accuracy: {pav_probe["accuracy"]:.2f}%
- macro precision: {pav_probe["macro_precision"]:.2f}%
- macro recall: {pav_probe["macro_recall"]:.2f}%
- macro F1: {pav_probe["macro_f1"]:.2f}%

## Heatmap geometry

Combined support bbox stats over all sequences:

- width mean / p95: {width_mean:.2f} / {width_p95:.2f}
- height mean / p95: {height_mean:.2f} / {height_p95:.2f}
- active fraction mean / p95: {active_fraction_mean:.4f} / {active_fraction_p95:.4f}

Per-sequence temporal jitter (std over frames):

- center-x std mean / p95: {center_x_std_mean:.3f} / {center_x_std_p95:.3f}
- center-y std mean / p95: {center_y_std_mean:.3f} / {center_y_std_p95:.3f}
- width std mean / p95: {width_std_mean:.3f} / {width_std_p95:.3f}
- height std mean / p95: {height_std_mean:.3f} / {height_std_p95:.3f}

Residual limb-vs-joint bbox-center mismatch after shared alignment:

- dx mean / p95: {bone_joint_dx_mean:.3f} / {bone_joint_dx_p95:.3f}
- dy mean / p95: {bone_joint_dy_mean:.3f} / {bone_joint_dy_p95:.3f}

Estimated intensity mass in the columns removed by `BaseSilCuttingTransform`:

- mean clipped-mass ratio: {cut_ratio_mean:.4f}
- p95 clipped-mass ratio: {cut_ratio_p95:.4f}

## Reading

- The raw pose data does not look broken. Confidence and valid-joint ratios are high and similar across classes.
- The sequence-level PAV still carries useful label signal, so the dataset is not devoid of scoliosis information.
- The limb/joint alignment fix removed the old registration bug; residual channel-center mismatch is now small.
- The remaining suspicious area is the visual branch: the skeleton map still has frame-to-frame bbox jitter, and the support bbox is almost full-height (`~61.5 / 64`) and fairly dense (`~36%` active pixels), which may be washing out subtle asymmetry cues.
- `BaseSilCuttingTransform` does not appear to be the main failure source for this export; the measured mass in the removed side margins is near zero.
- The dataset itself looks usable; the bigger issue still appears to be how the current skeleton-map preprocessing/runtime path presents that data to ScoNet.
"""


@click.command()
@click.option(
    "--pose-root",
    type=click.Path(path_type=Path, file_okay=False),
    default=DEFAULT_POSE_ROOT,
    show_default=True,
)
@click.option(
    "--heatmap-root",
    type=click.Path(path_type=Path, file_okay=False),
    default=DEFAULT_HEATMAP_ROOT,
    show_default=True,
)
@click.option(
    "--partition-path",
    type=click.Path(path_type=Path, dir_okay=False),
    default=DEFAULT_PARTITION_PATH,
    show_default=True,
)
@click.option(
    "--heatmap-cfg-path",
    type=click.Path(path_type=Path, dir_okay=False),
    default=DEFAULT_HEATMAP_CFG_PATH,
    show_default=True,
)
@click.option(
    "--report-path",
    type=click.Path(path_type=Path, dir_okay=False),
    default=DEFAULT_REPORT_PATH,
    show_default=True,
)
@click.option(
    "--json-path",
    type=click.Path(path_type=Path, dir_okay=False),
    default=DEFAULT_JSON_PATH,
    show_default=True,
)
@click.option(
    "--report-title",
    type=str,
    default="Scoliosis1K Dataset Analysis (1:1:8, shared-align skeleton maps)",
    show_default=True,
)
def main(
    pose_root: Path,
    heatmap_root: Path,
    partition_path: Path,
    heatmap_cfg_path: Path,
    report_path: Path,
    json_path: Path,
    report_title: str,
) -> None:
    """CLI entry point: run the analysis and write the report + JSON dump."""
    args = AnalysisArgs(
        pose_root=pose_root,
        heatmap_root=heatmap_root,
        partition_path=partition_path,
        heatmap_cfg_path=heatmap_cfg_path,
        report_path=report_path,
        json_path=json_path,
        report_title=report_title,
    )
    results = analyze(args)
    args.report_path.write_text(format_report(results), encoding="utf-8")
    args.json_path.write_text(json.dumps(results, indent=2, sort_keys=True), encoding="utf-8")
    print(f"Wrote {args.report_path}")
    print(f"Wrote {args.json_path}")


if __name__ == "__main__":
    main()