Files
OpenGait/scripts/visualize_sigma_samples.py

200 lines
6.7 KiB
Python

from __future__ import annotations
import pickle
from dataclasses import dataclass
from pathlib import Path
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
# Non-interactive Agg backend: this script only writes figures to disk.
matplotlib.use("Agg")

# Output location: the parent of this script's directory is treated as the repo root.
REPO_ROOT = Path(__file__).resolve().parent.parent
RESEARCH_ROOT = REPO_ROOT / "research/sigma_comparison"

# Input dataset roots (absolute, machine-specific mount points).
SIGMA8_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K_sigma_8.0/pkl")
SIGMA15_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15")
MIXED_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15-joint8-fixed")
SIL_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl")

# Annotation alias for the float32 arrays produced by `_normalize_image`.
FloatArray = npt.NDArray[np.float32]
@dataclass(frozen=True, eq=True)
class SampleSpec:
    """Identify one sample in the Scoliosis1K pickle layouts.

    Instances are immutable and hashable. The loaders in this script join the
    fields as ``<sequence_id>/<label>/<view>`` below each dataset root.
    """

    # Sequence directory name, e.g. "00000".
    sequence_id: str
    # Label directory name, e.g. "positive".
    label: str
    # View directory name, e.g. "000_180".
    view: str
def _normalize_image(array: npt.NDArray[np.generic]) -> FloatArray:
image = np.asarray(array, dtype=np.float32)
if image.max() > 1.0:
image = image / 255.0
return image
def _load_pickle(path: Path) -> FloatArray:
    """Unpickle the array stored at ``path`` and normalize it with ``_normalize_image``.

    Security: ``pickle.load`` can execute arbitrary code embedded in the file, so
    this must only be pointed at trusted dataset files.
    """
    with path.open("rb") as stream:
        payload = pickle.load(stream)
    return _normalize_image(payload)
def _load_sigma8(spec: SampleSpec) -> FloatArray:
    """Load the sigma-8 heatmaps from ``SIGMA8_ROOT/<seq>/<label>/<view>/<view>.pkl``."""
    sample_dir = SIGMA8_ROOT / spec.sequence_id / spec.label / spec.view
    return _load_pickle(sample_dir / f"{spec.view}.pkl")
def _load_sigma15(spec: SampleSpec) -> FloatArray:
    """Load the sigma-1.5 heatmaps from ``SIGMA15_ROOT/<seq>/<label>/<view>/0_heatmap.pkl``."""
    sample_dir = SIGMA15_ROOT / spec.sequence_id / spec.label / spec.view
    return _load_pickle(sample_dir / "0_heatmap.pkl")
def _load_silhouette(spec: SampleSpec) -> FloatArray:
    """Load the silhouette sequence from ``SIL_ROOT/<seq>/<label>/<view>/<view>.pkl``."""
    sample_dir = SIL_ROOT / spec.sequence_id / spec.label / spec.view
    return _load_pickle(sample_dir / f"{spec.view}.pkl")
def _load_mixed(spec: SampleSpec) -> FloatArray:
    """Load the mixed (limb 1.5 / joint 8) heatmaps from ``MIXED_ROOT/<seq>/<label>/<view>/0_heatmap.pkl``."""
    sample_dir = MIXED_ROOT / spec.sequence_id / spec.label / spec.view
    return _load_pickle(sample_dir / "0_heatmap.pkl")
def _frame_indices(num_frames: int, count: int = 6) -> npt.NDArray[np.int_]:
return np.linspace(0, num_frames - 1, num=count, dtype=int)
def _combined_map(heatmaps: FloatArray) -> FloatArray:
return np.maximum(heatmaps[:, 0], heatmaps[:, 1])
def _save_combined_figure(
    spec: SampleSpec,
    silhouette: FloatArray,
    sigma8: FloatArray,
    sigma15: FloatArray,
    mixed: FloatArray,
) -> Path:
    """Save a grid comparing silhouettes with the combined (channel-max) heatmaps.

    Rows are the silhouette, then ``_combined_map`` of ``sigma8``, ``sigma15`` and
    ``mixed``. Columns are six evenly spaced frames. Every panel is drawn in
    grayscale on a fixed ``[0, 1]`` range.

    Args:
        spec: Sample being visualized; used in the title and the output file name.
        silhouette: Silhouette sequence, indexed by frame along axis 0.
        sigma8, sigma15, mixed: Heatmap sequences indexed by frame on axis 0, with
            channels on axis 1 (channels 0 and 1 are merged by ``_combined_map``).

    Returns:
        Path of the PNG written under ``RESEARCH_ROOT``.
    """
    row_titles = (
        "Silhouette",
        "Sigma 8 Combined",
        "Sigma 1.5 Combined",
        "Limb 1.5 / Joint 8 Combined",
    )
    images = (silhouette, _combined_map(sigma8), _combined_map(sigma15), _combined_map(mixed))
    # Sample from the shortest sequence so every chosen index exists in every row;
    # otherwise the last index of a longer array overruns a shorter one.
    frames = _frame_indices(min(image.shape[0] for image in images), count=6)

    # squeeze=False keeps ``axes`` 2-D even when only one column is requested.
    fig, axes = plt.subplots(
        4, len(frames), figsize=(2.4 * len(frames), 8.8), dpi=200, squeeze=False
    )
    try:
        for row_idx, (title, row_images) in enumerate(zip(row_titles, images, strict=True)):
            for col_idx, frame_idx in enumerate(frames):
                ax = axes[row_idx, col_idx]
                ax.imshow(row_images[frame_idx], cmap="gray", vmin=0.0, vmax=1.0)
                ax.set_xticks([])
                ax.set_yticks([])
                if row_idx == 0:
                    ax.set_title(f"Frame {frame_idx}", fontsize=9)
                if col_idx == 0:
                    ax.set_ylabel(title, fontsize=10)
        fig.suptitle(
            f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: silhouette vs sigma heatmaps",
            fontsize=12,
        )
        fig.tight_layout()
        output_path = RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_combined.png"
        fig.savefig(output_path, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    return output_path
def _save_channel_figure(spec: SampleSpec, sigma8: FloatArray, sigma15: FloatArray) -> Path:
    """Save a grid showing channel 0 (limb) and channel 1 (joint) of both sigma variants.

    Rows are sigma8 limb/joint followed by sigma15 limb/joint. Columns are four
    evenly spaced frames. Each row uses the ``magma`` colormap with its own colour
    scale ``[0, row max]``.

    Args:
        spec: Sample being visualized; used in the title and the output file name.
        sigma8, sigma15: Heatmap sequences indexed by frame on axis 0, with
            channels on axis 1.

    Returns:
        Path of the PNG written under ``RESEARCH_ROOT``.
    """
    row_titles = (
        "Sigma 8 Limb",
        "Sigma 8 Joint",
        "Sigma 1.5 Limb",
        "Sigma 1.5 Joint",
    )
    images = (sigma8[:, 0], sigma8[:, 1], sigma15[:, 0], sigma15[:, 1])
    # Sample from the shorter sequence so every chosen index exists in every row.
    frames = _frame_indices(min(sigma8.shape[0], sigma15.shape[0]), count=4)

    # squeeze=False keeps ``axes`` 2-D even when only one column is requested.
    fig, axes = plt.subplots(
        4, len(frames), figsize=(2.6 * len(frames), 8.4), dpi=200, squeeze=False
    )
    try:
        for row_idx, (title, row_images) in enumerate(zip(row_titles, images, strict=True)):
            # One colour scale per row, taken over the whole sequence; all-zero
            # (or NaN) rows fall back to 1.0 so imshow gets a valid range.
            peak = float(np.max(row_images))
            vmax = peak if peak > 0 else 1.0
            for col_idx, frame_idx in enumerate(frames):
                ax = axes[row_idx, col_idx]
                ax.imshow(row_images[frame_idx], cmap="magma", vmin=0.0, vmax=vmax)
                ax.set_xticks([])
                ax.set_yticks([])
                if row_idx == 0:
                    ax.set_title(f"Frame {frame_idx}", fontsize=9)
                if col_idx == 0:
                    ax.set_ylabel(title, fontsize=10)
        fig.suptitle(
            f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: per-channel sigma comparison",
            fontsize=12,
        )
        fig.tight_layout()
        output_path = RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_channels.png"
        fig.savefig(output_path, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    return output_path
def _save_mixed_channel_figure(
    spec: SampleSpec,
    sigma8: FloatArray,
    sigma15: FloatArray,
    mixed: FloatArray,
) -> Path:
    """Save a per-channel grid for sigma8, sigma15 and the mixed heatmaps.

    Rows are limb (channel 0) and joint (channel 1) for each of ``sigma8``,
    ``sigma15`` and ``mixed``, in that order. Columns are four evenly spaced
    frames. Each row uses the ``magma`` colormap with its own colour scale
    ``[0, row max]``.

    Args:
        spec: Sample being visualized; used in the title and the output file name.
        sigma8, sigma15, mixed: Heatmap sequences indexed by frame on axis 0,
            with channels on axis 1.

    Returns:
        Path of the PNG written under ``RESEARCH_ROOT``.
    """
    row_titles = (
        "Sigma 8 Limb",
        "Sigma 8 Joint",
        "Sigma 1.5 Limb",
        "Sigma 1.5 Joint",
        "Mixed Limb 1.5",
        "Mixed Joint 8",
    )
    images = (sigma8[:, 0], sigma8[:, 1], sigma15[:, 0], sigma15[:, 1], mixed[:, 0], mixed[:, 1])
    # Sample from the shortest sequence so every chosen index exists in every row.
    frames = _frame_indices(min(sigma8.shape[0], sigma15.shape[0], mixed.shape[0]), count=4)

    # squeeze=False keeps ``axes`` 2-D even when only one column is requested.
    fig, axes = plt.subplots(
        6, len(frames), figsize=(2.6 * len(frames), 12.0), dpi=200, squeeze=False
    )
    try:
        for row_idx, (title, row_images) in enumerate(zip(row_titles, images, strict=True)):
            # One colour scale per row, taken over the whole sequence; all-zero
            # (or NaN) rows fall back to 1.0 so imshow gets a valid range.
            peak = float(np.max(row_images))
            vmax = peak if peak > 0 else 1.0
            for col_idx, frame_idx in enumerate(frames):
                ax = axes[row_idx, col_idx]
                ax.imshow(row_images[frame_idx], cmap="magma", vmin=0.0, vmax=vmax)
                ax.set_xticks([])
                ax.set_yticks([])
                if row_idx == 0:
                    ax.set_title(f"Frame {frame_idx}", fontsize=9)
                if col_idx == 0:
                    ax.set_ylabel(title, fontsize=10)
        fig.suptitle(
            f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: sigma and mixed per-channel comparison",
            fontsize=12,
        )
        fig.tight_layout()
        output_path = RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_channels_mixed.png"
        fig.savefig(output_path, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    return output_path
def main() -> None:
    """Render the three comparison figures for one fixed sample and print their paths."""
    RESEARCH_ROOT.mkdir(parents=True, exist_ok=True)
    spec = SampleSpec(sequence_id="00000", label="positive", view="000_180")

    # Load every representation of the sample before any plotting starts.
    sigma8 = _load_sigma8(spec)
    sigma15 = _load_sigma15(spec)
    mixed = _load_mixed(spec)
    silhouette = _load_silhouette(spec)

    written = (
        _save_combined_figure(spec, silhouette, sigma8, sigma15, mixed),
        _save_channel_figure(spec, sigma8, sigma15),
        _save_mixed_channel_figure(spec, sigma8, sigma15, mixed),
    )
    for output_path in written:
        print(output_path)


if __name__ == "__main__":
    main()