Add scoliosis diagnostics and experiment logging
This commit is contained in:
@@ -0,0 +1,199 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
matplotlib.use("Agg")
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
RESEARCH_ROOT = REPO_ROOT / "research/sigma_comparison"
|
||||
SIGMA8_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K_sigma_8.0/pkl")
|
||||
SIGMA15_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15")
|
||||
MIXED_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15-joint8-fixed")
|
||||
SIL_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl")
|
||||
|
||||
FloatArray = npt.NDArray[np.float32]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SampleSpec:
|
||||
sequence_id: str
|
||||
label: str
|
||||
view: str
|
||||
|
||||
|
||||
def _normalize_image(array: npt.NDArray[np.generic]) -> FloatArray:
|
||||
image = np.asarray(array, dtype=np.float32)
|
||||
if image.max() > 1.0:
|
||||
image = image / 255.0
|
||||
return image
|
||||
|
||||
|
||||
def _load_pickle(path: Path) -> FloatArray:
|
||||
with path.open("rb") as handle:
|
||||
return _normalize_image(pickle.load(handle))
|
||||
|
||||
|
||||
def _load_sigma8(spec: SampleSpec) -> FloatArray:
|
||||
return _load_pickle(SIGMA8_ROOT / spec.sequence_id / spec.label / spec.view / f"{spec.view}.pkl")
|
||||
|
||||
|
||||
def _load_sigma15(spec: SampleSpec) -> FloatArray:
|
||||
return _load_pickle(SIGMA15_ROOT / spec.sequence_id / spec.label / spec.view / "0_heatmap.pkl")
|
||||
|
||||
|
||||
def _load_silhouette(spec: SampleSpec) -> FloatArray:
|
||||
return _load_pickle(SIL_ROOT / spec.sequence_id / spec.label / spec.view / f"{spec.view}.pkl")
|
||||
|
||||
|
||||
def _load_mixed(spec: SampleSpec) -> FloatArray:
|
||||
return _load_pickle(MIXED_ROOT / spec.sequence_id / spec.label / spec.view / "0_heatmap.pkl")
|
||||
|
||||
|
||||
def _frame_indices(num_frames: int, count: int = 6) -> npt.NDArray[np.int_]:
|
||||
return np.linspace(0, num_frames - 1, num=count, dtype=int)
|
||||
|
||||
|
||||
def _combined_map(heatmaps: FloatArray) -> FloatArray:
|
||||
return np.maximum(heatmaps[:, 0], heatmaps[:, 1])
|
||||
|
||||
|
||||
def _save_combined_figure(
|
||||
spec: SampleSpec,
|
||||
silhouette: FloatArray,
|
||||
sigma8: FloatArray,
|
||||
sigma15: FloatArray,
|
||||
mixed: FloatArray,
|
||||
) -> Path:
|
||||
frames = _frame_indices(silhouette.shape[0], count=6)
|
||||
fig, axes = plt.subplots(4, len(frames), figsize=(2.4 * len(frames), 8.8), dpi=200)
|
||||
row_titles = (
|
||||
"Silhouette",
|
||||
"Sigma 8 Combined",
|
||||
"Sigma 1.5 Combined",
|
||||
"Limb 1.5 / Joint 8 Combined",
|
||||
)
|
||||
images = (silhouette, _combined_map(sigma8), _combined_map(sigma15), _combined_map(mixed))
|
||||
|
||||
for row_idx, (title, row_images) in enumerate(zip(row_titles, images, strict=True)):
|
||||
for col_idx, frame_idx in enumerate(frames):
|
||||
ax = axes[row_idx, col_idx]
|
||||
ax.imshow(row_images[frame_idx], cmap="gray", vmin=0.0, vmax=1.0)
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
if row_idx == 0:
|
||||
ax.set_title(f"Frame {frame_idx}", fontsize=9)
|
||||
if col_idx == 0:
|
||||
ax.set_ylabel(title, fontsize=10)
|
||||
|
||||
fig.suptitle(
|
||||
f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: silhouette vs sigma heatmaps",
|
||||
fontsize=12,
|
||||
)
|
||||
fig.tight_layout()
|
||||
output_path = RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_combined.png"
|
||||
fig.savefig(output_path, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
return output_path
|
||||
|
||||
|
||||
def _save_channel_figure(spec: SampleSpec, sigma8: FloatArray, sigma15: FloatArray) -> Path:
|
||||
frames = _frame_indices(sigma8.shape[0], count=4)
|
||||
fig, axes = plt.subplots(4, len(frames), figsize=(2.6 * len(frames), 8.4), dpi=200)
|
||||
row_titles = (
|
||||
"Sigma 8 Limb",
|
||||
"Sigma 8 Joint",
|
||||
"Sigma 1.5 Limb",
|
||||
"Sigma 1.5 Joint",
|
||||
)
|
||||
images = (sigma8[:, 0], sigma8[:, 1], sigma15[:, 0], sigma15[:, 1])
|
||||
|
||||
for row_idx, (title, row_images) in enumerate(zip(row_titles, images, strict=True)):
|
||||
vmax = float(np.max(row_images)) if np.max(row_images) > 0 else 1.0
|
||||
for col_idx, frame_idx in enumerate(frames):
|
||||
ax = axes[row_idx, col_idx]
|
||||
ax.imshow(row_images[frame_idx], cmap="magma", vmin=0.0, vmax=vmax)
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
if row_idx == 0:
|
||||
ax.set_title(f"Frame {frame_idx}", fontsize=9)
|
||||
if col_idx == 0:
|
||||
ax.set_ylabel(title, fontsize=10)
|
||||
|
||||
fig.suptitle(
|
||||
f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: per-channel sigma comparison",
|
||||
fontsize=12,
|
||||
)
|
||||
fig.tight_layout()
|
||||
output_path = RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_channels.png"
|
||||
fig.savefig(output_path, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
return output_path
|
||||
|
||||
|
||||
def _save_mixed_channel_figure(
|
||||
spec: SampleSpec,
|
||||
sigma8: FloatArray,
|
||||
sigma15: FloatArray,
|
||||
mixed: FloatArray,
|
||||
) -> Path:
|
||||
frames = _frame_indices(sigma8.shape[0], count=4)
|
||||
fig, axes = plt.subplots(6, len(frames), figsize=(2.6 * len(frames), 12.0), dpi=200)
|
||||
row_titles = (
|
||||
"Sigma 8 Limb",
|
||||
"Sigma 8 Joint",
|
||||
"Sigma 1.5 Limb",
|
||||
"Sigma 1.5 Joint",
|
||||
"Mixed Limb 1.5",
|
||||
"Mixed Joint 8",
|
||||
)
|
||||
images = (sigma8[:, 0], sigma8[:, 1], sigma15[:, 0], sigma15[:, 1], mixed[:, 0], mixed[:, 1])
|
||||
|
||||
for row_idx, (title, row_images) in enumerate(zip(row_titles, images, strict=True)):
|
||||
vmax = float(np.max(row_images)) if np.max(row_images) > 0 else 1.0
|
||||
for col_idx, frame_idx in enumerate(frames):
|
||||
ax = axes[row_idx, col_idx]
|
||||
ax.imshow(row_images[frame_idx], cmap="magma", vmin=0.0, vmax=vmax)
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
if row_idx == 0:
|
||||
ax.set_title(f"Frame {frame_idx}", fontsize=9)
|
||||
if col_idx == 0:
|
||||
ax.set_ylabel(title, fontsize=10)
|
||||
|
||||
fig.suptitle(
|
||||
f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: sigma and mixed per-channel comparison",
|
||||
fontsize=12,
|
||||
)
|
||||
fig.tight_layout()
|
||||
output_path = RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_channels_mixed.png"
|
||||
fig.savefig(output_path, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
return output_path
|
||||
|
||||
|
||||
def main() -> None:
|
||||
RESEARCH_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
sample = SampleSpec(sequence_id="00000", label="positive", view="000_180")
|
||||
sigma8 = _load_sigma8(sample)
|
||||
sigma15 = _load_sigma15(sample)
|
||||
mixed = _load_mixed(sample)
|
||||
silhouette = _load_silhouette(sample)
|
||||
|
||||
combined_path = _save_combined_figure(sample, silhouette, sigma8, sigma15, mixed)
|
||||
channels_path = _save_channel_figure(sample, sigma8, sigma15)
|
||||
mixed_channels_path = _save_mixed_channel_figure(sample, sigma8, sigma15, mixed)
|
||||
|
||||
print(combined_path)
|
||||
print(channels_path)
|
||||
print(mixed_channels_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user