"""Analyze the Scoliosis1K 1:1:8 shared-align DRF export.

Writes a markdown report and a JSON summary covering split balance, raw
pose quality, PAV label signal, and heatmap geometry statistics.
"""
from __future__ import annotations
|
|
|
|
import json
|
|
import pickle
|
|
import sys
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Iterable
|
|
|
|
import numpy as np
|
|
from jaxtyping import Float
|
|
from numpy.typing import NDArray
|
|
|
|
# Repository root (two directory levels above this file); appended to
# sys.path so the repo-local `datasets` package resolves when this script
# is executed directly rather than as a module.
REPO_ROOT = Path(__file__).resolve().parent.parent
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))

# Imported after the sys.path tweak so the repo-local package is found.
from datasets import pretreatment_scoliosis_drf as drf_prep
|
|
|
|
# Input roots on the shared data mount.
POSE_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-pose-pkl")
HEATMAP_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign")
# Repo-local inputs: subject-id split and heatmap preprocessing config.
PARTITION_PATH = REPO_ROOT / "datasets/Scoliosis1K/Scoliosis1K_118.json"
HEATMAP_CFG_PATH = REPO_ROOT / "configs/drf/pretreatment_heatmap_drf_sigma15_joint8.yaml"
# Output artifacts written by main().
REPORT_PATH = REPO_ROOT / "research/scoliosis_dataset_analysis_118_sharedalign.md"
JSON_PATH = REPO_ROOT / "research/scoliosis_dataset_analysis_118_sharedalign.json"

# Numerical floor used to avoid division by zero / near-zero denominators.
EPS = 1e-6
# Heatmap intensity threshold defining "active" pixels for bbox statistics.
THRESHOLD = 13.0
# Number of columns on each side inspected for clipped intensity mass.
SIDE_CUT = 10
# Class-name -> integer label mapping for the softmax probe.
LABEL_TO_INT = {"negative": 0, "neutral": 1, "positive": 2}
# Convenience alias for float32 numpy arrays.
FloatArray = NDArray[np.float32]
|
|
|
|
@dataclass(frozen=True)
class SequenceKey:
    """Identity of one gait sequence: (subject id, class label, sequence name).

    Frozen so instances are hashable and safe to use as dict/set keys.
    """

    # Subject identifier (4th path component from the end of a data path).
    pid: str
    # Class directory name: "negative", "neutral", or "positive".
    label: str
    # Sequence directory name.
    seq: str
|
|
|
@dataclass
class RunningStats:
    """Weighted running-mean accumulator."""

    # Sum of all observed values, each weighted by its count.
    total: float = 0.0
    # Number of observations folded in so far.
    count: int = 0

    def update(self, value: float, n: int = 1) -> None:
        """Fold in *value* observed *n* times."""
        self.count = self.count + n
        self.total = self.total + n * value

    @property
    def mean(self) -> float:
        """Average of all observations; the count is floored at 1 so an
        empty accumulator yields 0.0 instead of dividing by zero."""
        divisor = self.count if self.count > 0 else 1
        return self.total / divisor
|
|
def load_partition_ids() -> tuple[set[str], set[str]]:
    """Read the subject-id partition file and return (train_ids, test_ids).

    The JSON at PARTITION_PATH is expected to carry "TRAIN_SET" and
    "TEST_SET" lists of subject-id strings.
    """
    partition = json.loads(PARTITION_PATH.read_text(encoding="utf-8"))
    train_ids = set(partition["TRAIN_SET"])
    test_ids = set(partition["TEST_SET"])
    return train_ids, test_ids
|
|
|
def sequence_key_from_path(path: Path) -> SequenceKey:
    """Map a data path laid out as .../<pid>/<label>/<seq>/<file> to its key."""
    pid, label, seq = (path.parts[i] for i in (-4, -3, -2))
    return SequenceKey(pid=pid, label=label, seq=seq)
|
|
|
def iter_pose_paths() -> list[Path]:
    """All per-sequence pose pickles under POSE_ROOT, in sorted order."""
    paths = list(POSE_ROOT.glob("*/*/*/*.pkl"))
    paths.sort()
    return paths
|
|
|
def iter_heatmap_paths() -> list[Path]:
    """All per-sequence heatmap pickles under HEATMAP_ROOT, in sorted order."""
    paths = list(HEATMAP_ROOT.glob("*/*/*/0_heatmap.pkl"))
    paths.sort()
    return paths
|
|
|
def read_pickle(path: Path) -> object:
    """Deserialize one pickle file and return the contained object.

    NOTE: pickle is unsafe on untrusted data; this is only ever applied to
    locally generated dataset exports.
    """
    return pickle.loads(path.read_bytes())
|
|
|
def bbox_from_mask(mask: NDArray[np.bool_]) -> tuple[float, float, float, float] | None:
    """Tight bounding box of the True pixels in *mask*.

    Returns (width, height, center_x, center_y) in pixel units, or None
    when the mask contains no active pixel.
    """
    row_hits = np.flatnonzero(mask.any(axis=1))
    col_hits = np.flatnonzero(mask.any(axis=0))
    if row_hits.size == 0 or col_hits.size == 0:
        return None
    y0, y1 = int(row_hits[0]), int(row_hits[-1])
    x0, x1 = int(col_hits[0]), int(col_hits[-1])
    # Width/height are inclusive pixel counts; centers are midpoints.
    return (
        float(x1 - x0 + 1),
        float(y1 - y0 + 1),
        (x0 + x1) / 2.0,
        (y0 + y1) / 2.0,
    )
|
|
|
def sequence_bbox_metrics(
    heatmap: Float[FloatArray, "frames channels height width"],
    threshold: float = THRESHOLD,
) -> dict[str, float]:
    """Per-sequence geometry statistics of the thresholded heatmap support.

    For each frame this measures the bbox of pixels above *threshold* in
    the channel-max "support" map, the fraction of active pixels, the
    share of intensity mass in the SIDE_CUT left/right column margins,
    and the bbox-center offset between channel 0 and channel 1.
    Frames where a quantity is undefined (empty bbox, near-zero mass)
    simply contribute nothing to that quantity's statistic.
    """
    support = heatmap.max(axis=1)
    bone = heatmap[:, 0]
    joint = heatmap[:, 1]

    widths: list[float] = []
    heights: list[float] = []
    centers_x: list[float] = []
    centers_y: list[float] = []
    active_fractions: list[float] = []
    cut_mass_ratios: list[float] = []
    bone_joint_dx: list[float] = []
    bone_joint_dy: list[float] = []

    for frame, bone_frame, joint_frame in zip(support, bone, joint):
        mask = frame > threshold
        bbox = bbox_from_mask(mask)
        if bbox is not None:
            w, h, cx, cy = bbox
            widths.append(w)
            heights.append(h)
            centers_x.append(cx)
            centers_y.append(cy)
            active_fractions.append(float(mask.mean()))

        total_mass = float(frame.sum())
        if total_mass > EPS:
            # Intensity that falls in the side margins a cutting transform
            # would remove.
            margin_mass = float(frame[:, :SIDE_CUT].sum() + frame[:, -SIDE_CUT:].sum())
            cut_mass_ratios.append(margin_mass / total_mass)

        bone_bbox = bbox_from_mask(bone_frame > threshold)
        joint_bbox = bbox_from_mask(joint_frame > threshold)
        if bone_bbox is not None and joint_bbox is not None:
            bone_joint_dx.append(abs(bone_bbox[2] - joint_bbox[2]))
            bone_joint_dy.append(abs(bone_bbox[3] - joint_bbox[3]))

    def _reduce(values: Iterable[float], op) -> float:
        # Reduce over float32 (matching accumulation precision); empty -> 0.0.
        arr = np.asarray(list(values), dtype=np.float32)
        return float(op(arr)) if arr.size else 0.0

    return {
        "width_mean": _reduce(widths, np.mean),
        "height_mean": _reduce(heights, np.mean),
        "center_x_std": _reduce(centers_x, np.std),
        "center_y_std": _reduce(centers_y, np.std),
        "width_std": _reduce(widths, np.std),
        "height_std": _reduce(heights, np.std),
        "active_fraction_mean": _reduce(active_fractions, np.mean),
        "cut_mass_ratio_mean": _reduce(cut_mass_ratios, np.mean),
        "bone_joint_dx_mean": _reduce(bone_joint_dx, np.mean),
        "bone_joint_dy_mean": _reduce(bone_joint_dy, np.mean),
    }
|
|
|
def softmax_rows(logits: NDArray[np.float64]) -> NDArray[np.float64]:
    """Row-wise softmax, shifted by each row's max for numerical stability."""
    stable = np.exp(logits - logits.max(axis=1, keepdims=True))
    return stable / stable.sum(axis=1, keepdims=True)
|
|
|
def fit_softmax_regression(
    x: NDArray[np.float64],
    y: NDArray[np.int64],
    num_classes: int,
    steps: int = 4000,
    lr: float = 0.05,
    reg: float = 1e-4,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Fit a multinomial logistic regression with full-batch gradient descent.

    Args:
        x: (samples, features) design matrix.
        y: (samples,) integer class labels in [0, num_classes).
        num_classes: number of output classes.
        steps: number of gradient-descent iterations.
        lr: learning rate.
        reg: L2 penalty applied to the weights (not the bias).

    Returns:
        (weights, bias) with shapes (features, num_classes) and (num_classes,).
    """
    n_samples = x.shape[0]
    weights = np.zeros((x.shape[1], num_classes), dtype=np.float64)
    bias = np.zeros(num_classes, dtype=np.float64)
    targets = np.eye(num_classes, dtype=np.float64)[y]

    for _ in range(steps):
        error = softmax_rows(x @ weights + bias) - targets
        weights = weights - lr * ((x.T @ error) / n_samples + reg * weights)
        bias = bias - lr * error.mean(axis=0)

    return weights, bias
|
|
|
def evaluate_predictions(
    y_true: NDArray[np.int64],
    y_pred: NDArray[np.int64],
    num_classes: int,
) -> dict[str, float]:
    """Accuracy plus macro-averaged precision/recall/F1, all in percent.

    Per-class precision and recall use a denominator floor of 1, so a class
    absent from both y_true and y_pred contributes zeros rather than NaNs.
    """
    precisions: list[float] = []
    recalls: list[float] = []
    f1s: list[float] = []

    for class_id in range(num_classes):
        is_true = y_true == class_id
        is_pred = y_pred == class_id
        tp = int((is_true & is_pred).sum())
        fp = int((~is_true & is_pred).sum())
        fn = int((is_true & ~is_pred).sum())
        precision = tp / max(tp + fp, 1)
        recall = tp / max(tp + fn, 1)
        # EPS floor keeps the harmonic mean defined when both terms are 0.
        f1 = 2 * precision * recall / max(precision + recall, EPS)
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)

    return {
        "accuracy": 100.0 * float((y_true == y_pred).mean()),
        "macro_precision": 100.0 * float(np.mean(precisions)),
        "macro_recall": 100.0 * float(np.mean(recalls)),
        "macro_f1": 100.0 * float(np.mean(f1s)),
    }
|
|
|
def analyze() -> dict[str, object]:
    """Scan the pose and heatmap exports and aggregate dataset statistics.

    Returns a JSON-serializable dict with split/label counts, pose-quality
    means, PAV label means, softmax-probe metrics over sequence-level PAV
    features, and mean/p95 summaries of the heatmap geometry metrics.
    """
    train_ids, test_ids = load_partition_ids()
    # NOTE(review): test_ids is unused below — membership in train_ids alone
    # decides the split, so any pid outside TRAIN_SET counts as "test".

    heatmap_cfg = drf_prep.load_heatmap_cfg(str(HEATMAP_CFG_PATH))
    pose_transform = drf_prep.build_pose_transform(heatmap_cfg)
    # NOTE(review): pose_transform is built but never applied here — confirm
    # whether the raw-pose stats were meant to be measured post-transform.

    # Per-split, per-label accumulators.
    split_label_counts: dict[str, dict[str, int]] = {
        "train": defaultdict(int),
        "test": defaultdict(int),
    }
    pose_quality: dict[str, dict[str, RunningStats]] = {
        "train": defaultdict(RunningStats),
        "test": defaultdict(RunningStats),
    }
    valid_ratio: dict[str, dict[str, RunningStats]] = {
        "train": defaultdict(RunningStats),
        "test": defaultdict(RunningStats),
    }

    # Pass 1: raw pose pickles -> split counts and keypoint-confidence stats.
    for pose_path in iter_pose_paths():
        key = sequence_key_from_path(pose_path)
        split = "train" if key.pid in train_ids else "test"
        split_label_counts[split][key.label] += 1

        pose = drf_prep.read_pose(str(pose_path))
        # Third component is treated as keypoint confidence when present;
        # otherwise every joint is assumed fully confident.
        conf = pose[..., 2] if pose.shape[-1] >= 3 else np.ones(pose.shape[:-1], dtype=np.float32)
        pose_quality[split][key.label].update(float(conf.mean()))
        valid_ratio[split][key.label].update(float((conf > 0.05).mean()))

    heatmap_metrics: dict[str, list[float]] = defaultdict(list)
    pav_vectors_train: list[NDArray[np.float64]] = []
    pav_vectors_test: list[NDArray[np.float64]] = []
    labels_train: list[int] = []
    labels_test: list[int] = []
    pav_means: dict[str, list[float]] = defaultdict(list)

    # Pass 2: heatmap pickles -> geometry metrics plus PAV feature vectors.
    for heatmap_path in iter_heatmap_paths():
        key = sequence_key_from_path(heatmap_path)
        split = "train" if key.pid in train_ids else "test"
        heatmap = np.asarray(read_pickle(heatmap_path), dtype=np.float32)
        metrics = sequence_bbox_metrics(heatmap)
        # Record each metric under both its split-specific and "all" keys.
        for metric_name, metric_value in metrics.items():
            heatmap_metrics[f"{split}.{metric_name}"].append(metric_value)
            heatmap_metrics[f"all.{metric_name}"].append(metric_value)

        # Companion PAV pickle lives next to the heatmap file.
        pav_path = heatmap_path.with_name("1_pav.pkl")
        pav_seq = np.asarray(read_pickle(pav_path), dtype=np.float32)
        # Sequence-level feature: first PAV entry flattened to a vector.
        pav_vector = pav_seq[0].reshape(-1).astype(np.float64)
        pav_means[key.label].append(float(pav_vector.mean()))
        if split == "train":
            pav_vectors_train.append(pav_vector)
            labels_train.append(LABEL_TO_INT[key.label])
        else:
            pav_vectors_test.append(pav_vector)
            labels_test.append(LABEL_TO_INT[key.label])

    x_train = np.stack(pav_vectors_train, axis=0)
    x_test = np.stack(pav_vectors_test, axis=0)
    y_train = np.asarray(labels_train, dtype=np.int64)
    y_test = np.asarray(labels_test, dtype=np.int64)

    # Standardize features using train statistics only (no test leakage).
    mean = x_train.mean(axis=0, keepdims=True)
    std = np.maximum(x_train.std(axis=0, keepdims=True), EPS)
    x_train_std = (x_train - mean) / std
    x_test_std = (x_test - mean) / std
    weights, bias = fit_softmax_regression(x_train_std, y_train, num_classes=3)
    y_pred = np.argmax(x_test_std @ weights + bias, axis=1).astype(np.int64)
    pav_classifier = evaluate_predictions(y_test, y_pred, num_classes=3)

    results: dict[str, object] = {
        "split_label_counts": split_label_counts,
        "pose_confidence_mean": {
            split: {label: stats.mean for label, stats in per_label.items()}
            for split, per_label in pose_quality.items()
        },
        "pose_valid_ratio_mean": {
            split: {label: stats.mean for label, stats in per_label.items()}
            for split, per_label in valid_ratio.items()
        },
        "pav_label_means": {
            label: float(np.mean(values))
            for label, values in pav_means.items()
        },
        "pav_softmax_probe": pav_classifier,
        "heatmap_metrics": {
            key: {
                "mean": float(np.mean(values)),
                "p95": float(np.percentile(values, 95)),
            }
            for key, values in heatmap_metrics.items()
        },
    }
    return results
|
|
|
def format_report(results: dict[str, object]) -> str:
    """Render the analysis results dict as the markdown report text."""
    split_counts = results["split_label_counts"]
    pose_conf = results["pose_confidence_mean"]
    pose_valid = results["pose_valid_ratio_mean"]
    heat = results["heatmap_metrics"]
    pav_probe = results["pav_softmax_probe"]
    pav_means = results["pav_label_means"]

    def heat_stat(name: str) -> tuple[float, float]:
        # (mean, p95) of a metric aggregated over ALL sequences.
        entry = heat[f"all.{name}"]
        return entry["mean"], entry["p95"]

    center_x_std_mean, center_x_std_p95 = heat_stat("center_x_std")
    center_y_std_mean, center_y_std_p95 = heat_stat("center_y_std")
    width_std_mean, width_std_p95 = heat_stat("width_std")
    height_std_mean, height_std_p95 = heat_stat("height_std")
    cut_ratio_mean, cut_ratio_p95 = heat_stat("cut_mass_ratio_mean")
    bone_joint_dx_mean, bone_joint_dx_p95 = heat_stat("bone_joint_dx_mean")
    bone_joint_dy_mean, bone_joint_dy_p95 = heat_stat("bone_joint_dy_mean")
    width_mean, width_p95 = heat_stat("width_mean")
    height_mean, height_p95 = heat_stat("height_mean")
    active_fraction_mean, active_fraction_p95 = heat_stat("active_fraction_mean")

    return f"""# Scoliosis1K Dataset Analysis (1:1:8, shared-align skeleton maps)

## Split

Train counts:
- negative: {split_counts["train"]["negative"]}
- neutral: {split_counts["train"]["neutral"]}
- positive: {split_counts["train"]["positive"]}

Test counts:
- negative: {split_counts["test"]["negative"]}
- neutral: {split_counts["test"]["neutral"]}
- positive: {split_counts["test"]["positive"]}

## Raw pose quality

Mean keypoint confidence by split/class:
- train negative: {pose_conf["train"]["negative"]:.4f}
- train neutral: {pose_conf["train"]["neutral"]:.4f}
- train positive: {pose_conf["train"]["positive"]:.4f}
- test negative: {pose_conf["test"]["negative"]:.4f}
- test neutral: {pose_conf["test"]["neutral"]:.4f}
- test positive: {pose_conf["test"]["positive"]:.4f}

Mean valid-joint ratio (`conf > 0.05`) by split/class:
- train negative: {pose_valid["train"]["negative"]:.4f}
- train neutral: {pose_valid["train"]["neutral"]:.4f}
- train positive: {pose_valid["train"]["positive"]:.4f}
- test negative: {pose_valid["test"]["negative"]:.4f}
- test neutral: {pose_valid["test"]["neutral"]:.4f}
- test positive: {pose_valid["test"]["positive"]:.4f}

## PAV signal

Mean normalized PAV value by label:
- negative: {pav_means["negative"]:.4f}
- neutral: {pav_means["neutral"]:.4f}
- positive: {pav_means["positive"]:.4f}

Train-on-train / test-on-test linear softmax probe over sequence-level PAV:
- accuracy: {pav_probe["accuracy"]:.2f}%
- macro precision: {pav_probe["macro_precision"]:.2f}%
- macro recall: {pav_probe["macro_recall"]:.2f}%
- macro F1: {pav_probe["macro_f1"]:.2f}%

## Shared-align heatmap geometry

Combined support bbox stats over all sequences:
- width mean / p95: {width_mean:.2f} / {width_p95:.2f}
- height mean / p95: {height_mean:.2f} / {height_p95:.2f}
- active fraction mean / p95: {active_fraction_mean:.4f} / {active_fraction_p95:.4f}

Per-sequence temporal jitter (std over frames):
- center-x std mean / p95: {center_x_std_mean:.3f} / {center_x_std_p95:.3f}
- center-y std mean / p95: {center_y_std_mean:.3f} / {center_y_std_p95:.3f}
- width std mean / p95: {width_std_mean:.3f} / {width_std_p95:.3f}
- height std mean / p95: {height_std_mean:.3f} / {height_std_p95:.3f}

Residual limb-vs-joint bbox-center mismatch after shared alignment:
- dx mean / p95: {bone_joint_dx_mean:.3f} / {bone_joint_dx_p95:.3f}
- dy mean / p95: {bone_joint_dy_mean:.3f} / {bone_joint_dy_p95:.3f}

Estimated intensity mass in the columns removed by `BaseSilCuttingTransform`:
- mean clipped-mass ratio: {cut_ratio_mean:.4f}
- p95 clipped-mass ratio: {cut_ratio_p95:.4f}

## Reading

- The raw pose data does not look broken. Confidence and valid-joint ratios are high and similar across classes.
- The sequence-level PAV still carries useful label signal, so the dataset is not devoid of scoliosis information.
- Shared alignment removed the old limb-vs-joint registration bug; residual channel-center mismatch is now small.
- The remaining suspicious area is the visual branch: the skeleton map still has frame-to-frame bbox jitter, and the support bbox is almost full-height (`~61.5 / 64`) and fairly dense (`~36%` active pixels), which may be washing out subtle asymmetry cues.
- `BaseSilCuttingTransform` does not appear to be the main failure source for this shared-align export; the measured mass in the removed side margins is near zero.
- The dataset itself looks usable; the bigger issue still appears to be how the current skeleton-map preprocessing/runtime path presents that data to ScoNet.
"""
|
|
|
def main() -> None:
    """Run the analysis once and persist the markdown report and JSON dump."""
    results = analyze()
    outputs = {
        REPORT_PATH: format_report(results),
        JSON_PATH: json.dumps(results, indent=2, sort_keys=True),
    }
    for out_path, text in outputs.items():
        out_path.write_text(text, encoding="utf-8")
        print(f"Wrote {out_path}")
|
|
|
# Script entry point: run the full analysis when executed directly.
if __name__ == "__main__":
    main()
|