200 lines
6.7 KiB
Python
200 lines
6.7 KiB
Python
from __future__ import annotations
|
|
|
|
import pickle
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import matplotlib
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import numpy.typing as npt
|
|
|
|
matplotlib.use("Agg")
|
|
|
|
# Repository root: two levels up from this script's resolved location
# (the script's directory is parents[0], so parents[1] is its parent).
REPO_ROOT: Path = Path(__file__).resolve().parents[1]
# Directory that receives the rendered comparison PNGs.
RESEARCH_ROOT: Path = REPO_ROOT / "research/sigma_comparison"

# Absolute, machine-specific dataset roots. The loader functions in this
# module append <sequence_id>/<label>/<view>/<file>.pkl to each of them.
SIGMA8_ROOT: Path = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K_sigma_8.0/pkl")
SIGMA15_ROOT: Path = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15")
MIXED_ROOT: Path = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15-joint8-fixed")
SIL_ROOT: Path = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl")

# Shorthand for the float32 arrays passed between loaders and plotters.
FloatArray = npt.NDArray[np.float32]
|
|
|
|
|
|
@dataclass(frozen=True)
class SampleSpec:
    """Immutable identifier for one sample directory in the pickled dataset.

    The loader functions join the fields as ``<root>/<sequence_id>/<label>/<view>/``.
    Being a frozen dataclass, instances compare by value and reject attribute
    assignment.
    """

    # Sequence directory name, e.g. "00000".
    sequence_id: str
    # Label directory name, e.g. "positive".
    label: str
    # View directory name, e.g. "000_180"; some loaders also use it as the pickle file stem.
    view: str
|
|
|
|
|
|
def _normalize_image(array: npt.NDArray[np.generic]) -> FloatArray:
|
|
image = np.asarray(array, dtype=np.float32)
|
|
if image.max() > 1.0:
|
|
image = image / 255.0
|
|
return image
|
|
|
|
|
|
def _load_pickle(path: Path) -> FloatArray:
    """Unpickle the object stored at *path* and return it as a normalized float32 array.

    SECURITY NOTE: ``pickle.load`` can execute arbitrary code embedded in the
    file. Only point this at trusted files.
    """
    with open(path, "rb") as stream:
        payload = pickle.load(stream)
    return _normalize_image(payload)
|
|
|
|
|
|
def _load_sigma8(spec: SampleSpec) -> FloatArray:
    """Load the sigma-8 heatmap pickle for *spec*.

    Reads ``SIGMA8_ROOT/<sequence_id>/<label>/<view>/<view>.pkl``.
    """
    sample_dir = SIGMA8_ROOT / spec.sequence_id / spec.label / spec.view
    return _load_pickle(sample_dir / f"{spec.view}.pkl")
|
|
|
|
|
|
def _load_sigma15(spec: SampleSpec) -> FloatArray:
    """Load the sigma-1.5 heatmap pickle for *spec*.

    Reads ``SIGMA15_ROOT/<sequence_id>/<label>/<view>/0_heatmap.pkl``.
    """
    sample_dir = SIGMA15_ROOT / spec.sequence_id / spec.label / spec.view
    return _load_pickle(sample_dir / "0_heatmap.pkl")
|
|
|
|
|
|
def _load_silhouette(spec: SampleSpec) -> FloatArray:
    """Load the silhouette pickle for *spec*.

    Reads ``SIL_ROOT/<sequence_id>/<label>/<view>/<view>.pkl``.
    """
    sample_dir = SIL_ROOT / spec.sequence_id / spec.label / spec.view
    return _load_pickle(sample_dir / f"{spec.view}.pkl")
|
|
|
|
|
|
def _load_mixed(spec: SampleSpec) -> FloatArray:
    """Load the mixed (limb sigma 1.5 / joint sigma 8) heatmap pickle for *spec*.

    Reads ``MIXED_ROOT/<sequence_id>/<label>/<view>/0_heatmap.pkl``.
    """
    sample_dir = MIXED_ROOT / spec.sequence_id / spec.label / spec.view
    return _load_pickle(sample_dir / "0_heatmap.pkl")
|
|
|
|
|
|
def _frame_indices(num_frames: int, count: int = 6) -> npt.NDArray[np.int_]:
|
|
return np.linspace(0, num_frames - 1, num=count, dtype=int)
|
|
|
|
|
|
def _combined_map(heatmaps: FloatArray) -> FloatArray:
|
|
return np.maximum(heatmaps[:, 0], heatmaps[:, 1])
|
|
|
|
|
|
def _save_combined_figure(
    spec: SampleSpec,
    silhouette: FloatArray,
    sigma8: FloatArray,
    sigma15: FloatArray,
    mixed: FloatArray,
) -> Path:
    """Plot silhouettes above the channel-combined heatmaps and save the figure as a PNG.

    Rows are the silhouette, then the pixel-wise max of channels 0 and 1 for
    the sigma-8, sigma-1.5 and mixed heatmaps. Columns are up to six evenly
    spaced frames. Every panel is drawn in grayscale on a fixed [0, 1] scale.

    Args:
        spec: Sample whose identifiers go into the title and the file name.
        silhouette: Silhouette stack, indexed by frame first.
        sigma8: Heatmap stack indexed as ``[frame, channel, ...]``.
        sigma15: Heatmap stack indexed as ``[frame, channel, ...]``.
        mixed: Heatmap stack indexed as ``[frame, channel, ...]``.

    Returns:
        Path of the written ``*_combined.png`` file under ``RESEARCH_ROOT``.
    """
    row_titles = (
        "Silhouette",
        "Sigma 8 Combined",
        "Sigma 1.5 Combined",
        "Limb 1.5 / Joint 8 Combined",
    )
    images = (silhouette, _combined_map(sigma8), _combined_map(sigma15), _combined_map(mixed))
    # The stacks come from separate pickle files and nothing here guarantees
    # equal lengths. Sample only frames that every row actually has.
    num_frames = min(stack.shape[0] for stack in images)
    frames = _frame_indices(num_frames, count=6)
    # squeeze=False keeps ``axes`` 2-D even when only one frame column is drawn.
    fig, axes = plt.subplots(
        4, len(frames), figsize=(2.4 * len(frames), 8.8), dpi=200, squeeze=False
    )
    try:
        for row_idx, (title, row_images) in enumerate(zip(row_titles, images, strict=True)):
            for col_idx, frame_idx in enumerate(frames):
                ax = axes[row_idx, col_idx]
                ax.imshow(row_images[frame_idx], cmap="gray", vmin=0.0, vmax=1.0)
                ax.set_xticks([])
                ax.set_yticks([])
                if row_idx == 0:
                    ax.set_title(f"Frame {frame_idx}", fontsize=9)
                if col_idx == 0:
                    ax.set_ylabel(title, fontsize=10)

        fig.suptitle(
            f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: silhouette vs sigma heatmaps",
            fontsize=12,
        )
        fig.tight_layout()
        output_path = RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_combined.png"
        fig.savefig(output_path, bbox_inches="tight")
    finally:
        # Release the figure from pyplot even if drawing or saving fails.
        plt.close(fig)
    return output_path
|
|
|
|
|
|
def _save_channel_figure(spec: SampleSpec, sigma8: FloatArray, sigma15: FloatArray) -> Path:
    """Plot each heatmap channel of the sigma-8 and sigma-1.5 stacks separately and save a PNG.

    Rows are channel 0 ("Limb") and channel 1 ("Joint") of *sigma8*, then of
    *sigma15*. Columns are up to four evenly spaced frames. Panels use the
    ``magma`` colormap, scaled per row.

    Args:
        spec: Sample whose identifiers go into the title and the file name.
        sigma8: Heatmap stack indexed as ``[frame, channel, ...]``.
        sigma15: Heatmap stack indexed as ``[frame, channel, ...]``.

    Returns:
        Path of the written ``*_channels.png`` file under ``RESEARCH_ROOT``.
    """
    row_titles = (
        "Sigma 8 Limb",
        "Sigma 8 Joint",
        "Sigma 1.5 Limb",
        "Sigma 1.5 Joint",
    )
    images = (sigma8[:, 0], sigma8[:, 1], sigma15[:, 0], sigma15[:, 1])
    # The two stacks come from separate files. Sample only frames both have.
    num_frames = min(sigma8.shape[0], sigma15.shape[0])
    frames = _frame_indices(num_frames, count=4)
    # squeeze=False keeps ``axes`` 2-D even when only one frame column is drawn.
    fig, axes = plt.subplots(
        4, len(frames), figsize=(2.6 * len(frames), 8.4), dpi=200, squeeze=False
    )
    try:
        for row_idx, (title, row_images) in enumerate(zip(row_titles, images, strict=True)):
            # One colour scale per row, from the peak over all frames of that channel,
            # so the plotted frames are comparable. A non-positive (or NaN) peak
            # falls back to 1.0 to keep a valid colour range.
            peak = float(np.max(row_images))
            vmax = peak if peak > 0 else 1.0
            for col_idx, frame_idx in enumerate(frames):
                ax = axes[row_idx, col_idx]
                ax.imshow(row_images[frame_idx], cmap="magma", vmin=0.0, vmax=vmax)
                ax.set_xticks([])
                ax.set_yticks([])
                if row_idx == 0:
                    ax.set_title(f"Frame {frame_idx}", fontsize=9)
                if col_idx == 0:
                    ax.set_ylabel(title, fontsize=10)

        fig.suptitle(
            f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: per-channel sigma comparison",
            fontsize=12,
        )
        fig.tight_layout()
        output_path = RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_channels.png"
        fig.savefig(output_path, bbox_inches="tight")
    finally:
        # Release the figure from pyplot even if drawing or saving fails.
        plt.close(fig)
    return output_path
|
|
|
|
|
|
def _save_mixed_channel_figure(
    spec: SampleSpec,
    sigma8: FloatArray,
    sigma15: FloatArray,
    mixed: FloatArray,
) -> Path:
    """Plot each channel of the sigma-8, sigma-1.5 and mixed stacks separately and save a PNG.

    Rows are channels 0 and 1 of *sigma8*, *sigma15* and *mixed*, in that
    order. Columns are up to four evenly spaced frames. Panels use the
    ``magma`` colormap, scaled per row.

    Args:
        spec: Sample whose identifiers go into the title and the file name.
        sigma8: Heatmap stack indexed as ``[frame, channel, ...]``.
        sigma15: Heatmap stack indexed as ``[frame, channel, ...]``.
        mixed: Heatmap stack indexed as ``[frame, channel, ...]``.

    Returns:
        Path of the written ``*_channels_mixed.png`` file under ``RESEARCH_ROOT``.
    """
    row_titles = (
        "Sigma 8 Limb",
        "Sigma 8 Joint",
        "Sigma 1.5 Limb",
        "Sigma 1.5 Joint",
        "Mixed Limb 1.5",
        "Mixed Joint 8",
    )
    images = (sigma8[:, 0], sigma8[:, 1], sigma15[:, 0], sigma15[:, 1], mixed[:, 0], mixed[:, 1])
    # The three stacks come from separate files. Sample only frames all of them have.
    num_frames = min(sigma8.shape[0], sigma15.shape[0], mixed.shape[0])
    frames = _frame_indices(num_frames, count=4)
    # squeeze=False keeps ``axes`` 2-D even when only one frame column is drawn.
    fig, axes = plt.subplots(
        6, len(frames), figsize=(2.6 * len(frames), 12.0), dpi=200, squeeze=False
    )
    try:
        for row_idx, (title, row_images) in enumerate(zip(row_titles, images, strict=True)):
            # One colour scale per row, from the peak over all frames of that channel.
            # A non-positive (or NaN) peak falls back to 1.0.
            peak = float(np.max(row_images))
            vmax = peak if peak > 0 else 1.0
            for col_idx, frame_idx in enumerate(frames):
                ax = axes[row_idx, col_idx]
                ax.imshow(row_images[frame_idx], cmap="magma", vmin=0.0, vmax=vmax)
                ax.set_xticks([])
                ax.set_yticks([])
                if row_idx == 0:
                    ax.set_title(f"Frame {frame_idx}", fontsize=9)
                if col_idx == 0:
                    ax.set_ylabel(title, fontsize=10)

        fig.suptitle(
            f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: sigma and mixed per-channel comparison",
            fontsize=12,
        )
        fig.tight_layout()
        output_path = RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_channels_mixed.png"
        fig.savefig(output_path, bbox_inches="tight")
    finally:
        # Release the figure from pyplot even if drawing or saving fails.
        plt.close(fig)
    return output_path
|
|
|
|
|
|
def main() -> None:
    """Render the three comparison figures for one hard-coded sample.

    Steps:
    - Ensure ``RESEARCH_ROOT`` exists.
    - Load the sigma-8, sigma-1.5, mixed and silhouette pickles for
      sequence "00000", label "positive", view "000_180".
    - Write the combined, per-channel and mixed per-channel PNGs.
    - Print each written path, one per line.
    """
    RESEARCH_ROOT.mkdir(parents=True, exist_ok=True)
    spec = SampleSpec(sequence_id="00000", label="positive", view="000_180")

    sigma8 = _load_sigma8(spec)
    sigma15 = _load_sigma15(spec)
    mixed = _load_mixed(spec)
    silhouette = _load_silhouette(spec)

    written = (
        _save_combined_figure(spec, silhouette, sigma8, sigma15, mixed),
        _save_channel_figure(spec, sigma8, sigma15),
        _save_mixed_channel_figure(spec, sigma8, sigma15, mixed),
    )
    for output_path in written:
        print(output_path)


if __name__ == "__main__":
    main()
|