"""Render comparison figures of Scoliosis1K silhouettes and pose heatmaps.

Loads one sample from four pickle datasets (silhouette, sigma 8 heatmaps,
sigma 1.5 heatmaps, and a mixed limb 1.5 / joint 8 variant) and writes three
PNG grids under ``RESEARCH_ROOT``.
"""

from __future__ import annotations

import pickle
from dataclasses import dataclass
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt

# Non-interactive backend: figures are only ever written to disk.
matplotlib.use("Agg")

# Assumes this script lives one directory below the repository root.
REPO_ROOT = Path(__file__).resolve().parents[1]
RESEARCH_ROOT = REPO_ROOT / "research/sigma_comparison"
SIGMA8_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K_sigma_8.0/pkl")
SIGMA15_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15")
MIXED_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15-joint8-fixed")
SIL_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl")

FloatArray = npt.NDArray[np.float32]


@dataclass(frozen=True)
class SampleSpec:
    """Identifies one sample by sequence id, label and view directory names."""

    sequence_id: str
    label: str
    view: str


def _normalize_image(array: npt.NDArray[np.generic]) -> FloatArray:
    """Convert *array* to float32, rescaling by 1/255 when its max exceeds 1.0.

    Data already in ``[0, 1]`` is returned unscaled. Empty arrays are returned
    as-is rather than failing inside ``max()``.
    """
    image = np.asarray(array, dtype=np.float32)
    # ndarray.max() raises on zero-size arrays, so check the size first.
    if image.size and image.max() > 1.0:
        image = image / 255.0
    return image


def _load_pickle(path: Path) -> FloatArray:
    """Unpickle the array stored at *path* and normalize it with ``_normalize_image``.

    NOTE: ``pickle.load`` can execute arbitrary code; only use this on trusted,
    locally produced dataset files.
    """
    with path.open("rb") as handle:
        return _normalize_image(pickle.load(handle))


def _sample_dir(root: Path, spec: SampleSpec) -> Path:
    """Return ``root / sequence_id / label / view`` for *spec*."""
    return root / spec.sequence_id / spec.label / spec.view


def _load_sigma8(spec: SampleSpec) -> FloatArray:
    """Load the sigma 8 heatmap sequence for *spec*."""
    return _load_pickle(_sample_dir(SIGMA8_ROOT, spec) / f"{spec.view}.pkl")


def _load_sigma15(spec: SampleSpec) -> FloatArray:
    """Load the sigma 1.5 heatmap sequence for *spec*."""
    return _load_pickle(_sample_dir(SIGMA15_ROOT, spec) / "0_heatmap.pkl")


def _load_silhouette(spec: SampleSpec) -> FloatArray:
    """Load the silhouette sequence for *spec*."""
    return _load_pickle(_sample_dir(SIL_ROOT, spec) / f"{spec.view}.pkl")


def _load_mixed(spec: SampleSpec) -> FloatArray:
    """Load the mixed (limb 1.5 / joint 8) heatmap sequence for *spec*."""
    return _load_pickle(_sample_dir(MIXED_ROOT, spec) / "0_heatmap.pkl")


def _frame_indices(num_frames: int, count: int = 6) -> npt.NDArray[np.int_]:
    """Return *count* evenly spaced indices covering ``[0, num_frames - 1]``.

    Indices repeat when ``num_frames < count``.

    Raises:
        ValueError: if ``num_frames`` or ``count`` is not positive (previously a
            zero-length sequence silently produced an index of -1).
    """
    if num_frames <= 0:
        raise ValueError(f"cannot pick frames from a sequence with {num_frames} frames")
    if count <= 0:
        raise ValueError(f"count must be positive, got {count}")
    return np.linspace(0, num_frames - 1, num=count, dtype=int)


def _shared_frame_count(*sequences: FloatArray) -> int:
    """Return the number of leading-axis frames present in every one of *sequences*."""
    return min(len(sequence) for sequence in sequences)


def _combined_map(heatmaps: FloatArray) -> FloatArray:
    """Return the per-pixel maximum of channels 0 and 1 (axis 1) of *heatmaps*.

    Only the first two channels are used; any further channels are ignored.
    """
    return np.maximum(heatmaps[:, 0], heatmaps[:, 1])


def _save_grid_figure(
    row_titles: tuple[str, ...],
    rows: tuple[FloatArray, ...],
    frames: npt.NDArray[np.int_],
    *,
    cmap: str,
    figsize: tuple[float, float],
    title: str,
    output_path: Path,
    fixed_range: tuple[float, float] | None = None,
) -> Path:
    """Draw ``rows[r][frames[c]]`` in each grid cell and save the figure to *output_path*.

    If *fixed_range* is None, each row is scaled from 0 to that row's own
    maximum, falling back to 1.0 when the maximum is not positive. This keeps
    faint channels visible.

    Returns:
        *output_path*, after the PNG has been written.
    """
    # squeeze=False keeps ``axes`` 2-D even for a single row or column.
    fig, axes = plt.subplots(len(rows), len(frames), figsize=figsize, dpi=200, squeeze=False)
    try:
        for row_idx, (row_title, row_images) in enumerate(zip(row_titles, rows, strict=True)):
            if fixed_range is None:
                row_max = float(np.max(row_images))
                vmin, vmax = 0.0, (row_max if row_max > 0 else 1.0)
            else:
                vmin, vmax = fixed_range
            for col_idx, frame_idx in enumerate(frames):
                ax = axes[row_idx, col_idx]
                ax.imshow(row_images[frame_idx], cmap=cmap, vmin=vmin, vmax=vmax)
                ax.set_xticks([])
                ax.set_yticks([])
                if row_idx == 0:
                    ax.set_title(f"Frame {frame_idx}", fontsize=9)
                if col_idx == 0:
                    ax.set_ylabel(row_title, fontsize=10)
        fig.suptitle(title, fontsize=12)
        fig.tight_layout()
        fig.savefig(output_path, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving raises.
        plt.close(fig)
    return output_path


def _save_combined_figure(
    spec: SampleSpec,
    silhouette: FloatArray,
    sigma8: FloatArray,
    sigma15: FloatArray,
    mixed: FloatArray,
) -> Path:
    """Save a grid of the silhouette plus the channel-max map of each heatmap set.

    All rows use a fixed 0-1 grayscale range. Six frames are sampled from the
    frame range available in every input, so a shorter heatmap sequence no
    longer causes an out-of-range index.

    Returns:
        Path of the written ``*_combined.png`` file.
    """
    frames = _frame_indices(_shared_frame_count(silhouette, sigma8, sigma15, mixed), count=6)
    return _save_grid_figure(
        row_titles=(
            "Silhouette",
            "Sigma 8 Combined",
            "Sigma 1.5 Combined",
            "Limb 1.5 / Joint 8 Combined",
        ),
        rows=(silhouette, _combined_map(sigma8), _combined_map(sigma15), _combined_map(mixed)),
        frames=frames,
        cmap="gray",
        figsize=(2.4 * len(frames), 8.8),
        title=f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: silhouette vs sigma heatmaps",
        output_path=RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_combined.png",
        fixed_range=(0.0, 1.0),
    )


def _save_channel_figure(spec: SampleSpec, sigma8: FloatArray, sigma15: FloatArray) -> Path:
    """Save a grid showing channels 0 (limb) and 1 (joint) of the sigma 8 and sigma 1.5 heatmaps.

    Each row is scaled to its own maximum. Four frames are sampled from the
    frame range shared by both inputs.

    Returns:
        Path of the written ``*_channels.png`` file.
    """
    frames = _frame_indices(_shared_frame_count(sigma8, sigma15), count=4)
    return _save_grid_figure(
        row_titles=(
            "Sigma 8 Limb",
            "Sigma 8 Joint",
            "Sigma 1.5 Limb",
            "Sigma 1.5 Joint",
        ),
        rows=(sigma8[:, 0], sigma8[:, 1], sigma15[:, 0], sigma15[:, 1]),
        frames=frames,
        cmap="magma",
        figsize=(2.6 * len(frames), 8.4),
        title=f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: per-channel sigma comparison",
        output_path=RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_channels.png",
    )


def _save_mixed_channel_figure(
    spec: SampleSpec,
    sigma8: FloatArray,
    sigma15: FloatArray,
    mixed: FloatArray,
) -> Path:
    """Save the per-channel grid with two extra rows for the mixed heatmap channels.

    Each row is scaled to its own maximum. Four frames are sampled from the
    frame range shared by all three inputs.

    Returns:
        Path of the written ``*_channels_mixed.png`` file.
    """
    frames = _frame_indices(_shared_frame_count(sigma8, sigma15, mixed), count=4)
    return _save_grid_figure(
        row_titles=(
            "Sigma 8 Limb",
            "Sigma 8 Joint",
            "Sigma 1.5 Limb",
            "Sigma 1.5 Joint",
            "Mixed Limb 1.5",
            "Mixed Joint 8",
        ),
        rows=(sigma8[:, 0], sigma8[:, 1], sigma15[:, 0], sigma15[:, 1], mixed[:, 0], mixed[:, 1]),
        frames=frames,
        cmap="magma",
        figsize=(2.6 * len(frames), 12.0),
        title=f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: sigma and mixed per-channel comparison",
        output_path=RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_channels_mixed.png",
    )


def main() -> None:
    """Render all three comparison figures for one hard-coded sample and print their paths."""
    RESEARCH_ROOT.mkdir(parents=True, exist_ok=True)
    sample = SampleSpec(sequence_id="00000", label="positive", view="000_180")
    sigma8 = _load_sigma8(sample)
    sigma15 = _load_sigma15(sample)
    mixed = _load_mixed(sample)
    silhouette = _load_silhouette(sample)
    combined_path = _save_combined_figure(sample, silhouette, sigma8, sigma15, mixed)
    channels_path = _save_channel_figure(sample, sigma8, sigma15)
    mixed_channels_path = _save_mixed_channel_figure(sample, sigma8, sigma15, mixed)
    print(combined_path)
    print(channels_path)
    print(mixed_channels_path)


if __name__ == "__main__":
    main()