Add scoliosis diagnostics and experiment logging
This commit is contained in:
@@ -26,8 +26,9 @@ Use it for:
|
||||
| 2026-03-08 | `ScoNet_skeleton_118_sigma15_bs12x8` | ScoNet-MT-ske control | `Scoliosis1K-drf-pkl-118-sigma15` | Lowered skeleton-map sigma from `8.0` to `1.5` to tighten the pose rasterization | complete | `46.33 Acc / 68.09 Prec / 51.92 Rec / 44.69 F1` |
|
||||
| 2026-03-09 | `ScoNet_skeleton_118_sigma15_joint8_sharedalign_2gpu_bs12x8` | ScoNet-MT-ske control | `Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign` | Fixed limb/joint channel misalignment, used mixed sigma `limb=1.5 / joint=8.0`, kept SGD | complete | `50.47 Acc / 69.31 Prec / 54.58 Rec / 48.63 F1` |
|
||||
| 2026-03-09 | `ScoNet_skeleton_118_sigma15_joint8_limb4_adamw_2gpu_bs12x8` | ScoNet-MT-ske control | `Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign-limb4` | Rebalanced channel intensity with `limb_gain=4.0`; switched optimizer from `SGD` to `AdamW` | complete | `48.60 Acc / 65.97 Prec / 53.19 Rec / 46.41 F1` |
|
||||
| 2026-03-09 | `ScoNet_skeleton_118_sigma15_joint8_sharedalign_nocut_adamw_1gpu_bs8x8` | ScoNet-MT-ske control | `Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign` | Switched runtime transform from `BaseSilCuttingTransform` to `BaseSilTransform` (`no-cut`), kept `AdamW`, reduced `8x8` due to 5070 Ti OOM at `12x8` | training | no eval yet |
|
||||
| 2026-03-09 | `ScoNet_skeleton_118_sigma15_joint8_sharedalign_nocut_adamw_proxy_1gpu` | ScoNet-MT-ske proxy | `Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign` | Fast proxy route: `no-cut`, `AdamW`, `8x8`, `total_iter=2000`, `eval_iter=500`, `test_seq_subset_size=128` | training | no eval yet |
|
||||
| 2026-03-09 | `ScoNet_skeleton_118_sigma15_joint8_sharedalign_nocut_adamw_1gpu_bs8x8` | ScoNet-MT-ske control | `Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign` | Switched runtime transform from `BaseSilCuttingTransform` to `BaseSilTransform` (`no-cut`), kept `AdamW`, reduced `8x8` due to 5070 Ti OOM at `12x8` | interrupted | superseded by proxy route before eval |
|
||||
| 2026-03-09 | `ScoNet_skeleton_118_sigma15_joint8_sharedalign_nocut_adamw_proxy_1gpu` | ScoNet-MT-ske proxy | `Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign` | Fast proxy route: `no-cut`, `AdamW`, `8x8`, `total_iter=2000`, `eval_iter=500`, `test_seq_subset_size=128` | interrupted | superseded by geometry-fixed proxy before completion |
|
||||
| 2026-03-10 | `ScoNet_skeleton_118_sigma15_joint8_geomfix_proxy_1gpu` | ScoNet-MT-ske proxy | `Scoliosis1K-drf-pkl-118-sigma15-joint8-geomfix` | Geometry ablation: aspect-ratio-preserving crop+pad instead of square-warp resize; `AdamW`, `no-cut`, `8x8`, `total_iter=2000`, `eval_iter=500`, fixed test subset seed `118` | complete | proxy subset unstable: `500 24.22/8.07/33.33/13.00`, `1000 60.16/68.05/58.13/55.25`, `1500 26.56/58.33/35.64/17.68`, `2000 27.34/63.96/37.02/20.14` (Acc/Prec/Rec/F1) |
|
||||
|
||||
## Current best skeleton baseline
|
||||
|
||||
@@ -40,3 +41,4 @@ Current best `ScoNet-MT-ske`-style result:
|
||||
- `ckpt/ScoNet-20000-better.pt` is intentionally not listed here because it is a silhouette checkpoint, not a skeleton-map run.
|
||||
- `DRF` runs are included because they are part of the same reproduction/debugging loop, but this log should stay focused on train/eval changes, not broader code refactors.
|
||||
- The long `ScoNet_skeleton_118_sigma15_joint8_sharedalign_nocut_adamw_1gpu_bs8x8` run was intentionally interrupted and superseded by the shorter proxy run once fast-iteration support was added.
|
||||
- The geometry-fixed proxy run fit the train split quickly but did not produce a stable proxy validation curve, so it should not be promoted to a full 20k run.
|
||||
|
||||
@@ -0,0 +1,778 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
import sys
|
||||
import types
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
import click
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from jaxtyping import Float
|
||||
from numpy.typing import NDArray
|
||||
|
||||
|
||||
# Repository root is two directory levels above this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
OPENGAIT_ROOT = REPO_ROOT / "opengait"
# Make the OpenGait package and the repo itself importable when run as a script.
if str(OPENGAIT_ROOT) not in sys.path:
    sys.path.insert(0, str(OPENGAIT_ROOT))
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from opengait.data.transform import get_transform
|
||||
from opengait.modeling import models
|
||||
from opengait.modeling.base_model import BaseModel
|
||||
from opengait.modeling.models.drf import DRF
|
||||
from opengait.utils.common import config_loader
|
||||
|
||||
|
||||
# Alias for float32 numpy arrays used throughout the visualization helpers.
FloatArray = NDArray[np.float32]
# Class index -> name; assumed to match the checkpoints' label encoding — TODO confirm.
LABEL_NAMES = ("negative", "neutral", "positive")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ModelSpec:
    """Where to find one model's config, checkpoint, and evaluation data."""

    key: str  # short identifier; "skeleton" in the key selects heatmap loading
    title: str  # human-readable name used in figure titles
    config_path: Path  # OpenGait YAML config for this run
    checkpoint_path: Path  # trained checkpoint (.pt) to load
    data_root: Path  # dataset root holding per-sequence pickle files
    is_drf: bool  # True for DRF models, which additionally consume a PAV input
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class VisualizationResult:
    """One row of the final-layer response figure."""

    title: str  # model title shown above the input panel
    predicted_label: str  # one of LABEL_NAMES
    class_scores: tuple[float, float, float]  # per-class logits averaged over parts
    input_image: FloatArray  # normalized projection of the input sequence
    response_map: FloatArray  # normalized response heatmap, same HxW as the input
    overlay_image: FloatArray  # RGB blend of input and response
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class StageVisualizationResult:
    """One row of the per-backbone-stage activation figure."""

    title: str  # model title
    predicted_label: str  # one of LABEL_NAMES
    class_scores: tuple[float, float, float]  # per-class logits averaged over parts
    input_image: FloatArray  # normalized projection of the input sequence
    stage_overlays: dict[str, FloatArray]  # stage name ("conv1".."layer4") -> RGB overlay
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ForwardDiagnostics:
    """Intermediate tensors captured from one forward pass (shapes via jaxtyping)."""

    feature_map: Float[torch.Tensor, "batch channels height width"]  # after temporal pooling
    pooled: Float[torch.Tensor, "batch channels parts"]  # after horizontal part pooling
    embeddings: Float[torch.Tensor, "batch channels parts"]  # after FCs (and PGA, for DRF)
    logits: Float[torch.Tensor, "batch classes parts"]  # BNNeck classifier output
    class_scores: Float[torch.Tensor, "batch classes"]  # logits averaged over parts
    pga_spatial: Float[torch.Tensor, "batch parts"] | None  # DRF spatial attention, else None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ModelDiagnosticResult:
    """One row of the occlusion/part-logit diagnostics figure."""

    title: str  # model title
    predicted_label: str  # one of LABEL_NAMES
    class_scores: tuple[float, float, float]  # per-class logits averaged over parts
    input_image: FloatArray  # normalized display image
    occlusion_map: FloatArray  # normalized occlusion-sensitivity map
    overlay_image: FloatArray  # input blended with the occlusion map
    part_scores: FloatArray  # predicted-class logit per HPP part
    pga_spatial: FloatArray | None  # DRF PGA spatial attention per part, if available
|
||||
|
||||
|
||||
def apply_transforms(
    sequence: np.ndarray,
    transform_cfg: list[dict[str, Any]] | dict[str, Any] | None,
) -> np.ndarray:
    """Apply the configured OpenGait transform(s) to *sequence*.

    When the config resolves to a list of transforms they are applied in
    order; the result is always coerced to an ndarray.
    """
    transform = get_transform(transform_cfg)
    if not isinstance(transform, list):
        return np.asarray(transform(sequence))
    output = sequence
    for stage in transform:
        output = stage(output)
    return np.asarray(output)
|
||||
|
||||
|
||||
def build_minimal_model(cfgs: dict[str, Any]) -> nn.Module:
    """Instantiate only the network portion of the configured model class.

    Deliberately bypasses the model class's own __init__ (presumably to avoid
    its training/runtime setup — verify against BaseModel) by allocating with
    __new__ and calling only nn.Module.__init__ plus build_network.
    """
    model_name = str(cfgs["model_cfg"]["model"])
    model_cls = getattr(models, model_name)
    model = model_cls.__new__(model_cls)  # skip the subclass __init__ entirely
    nn.Module.__init__(model)
    # build_network needs self.get_backbone; re-bind it onto the bare instance.
    model.get_backbone = types.MethodType(BaseModel.get_backbone, model)
    model.build_network(cfgs["model_cfg"])
    return model
|
||||
|
||||
|
||||
def load_checkpoint_model(spec: ModelSpec) -> tuple[nn.Module, dict[str, Any]]:
    """Build the network for *spec*, load its checkpoint on CPU, and set eval mode.

    Returns the model together with the parsed config dict.
    """
    cfgs = config_loader(str(spec.config_path))
    model = build_minimal_model(cfgs)
    checkpoint = torch.load(
        spec.checkpoint_path,
        map_location="cpu",
        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # load checkpoints from trusted sources.
        weights_only=False,
    )
    # strict=False silently tolerates missing/unexpected keys relative to the
    # minimal build — mismatches will not raise here.
    model.load_state_dict(checkpoint["model"], strict=False)
    model.eval()
    return model, cfgs
|
||||
|
||||
|
||||
def load_pickle(path: Path) -> np.ndarray:
    """Deserialize *path* with pickle and coerce the payload to an ndarray."""
    payload = pickle.loads(path.read_bytes())
    return np.asarray(payload)
|
||||
|
||||
|
||||
def load_sample_inputs(
    spec: ModelSpec,
    sequence_id: str,
    label: str,
    view: str,
    frame_count: int,
) -> tuple[torch.Tensor, torch.Tensor | None, FloatArray]:
    """Load one sequence for *spec*: (input tensor, PAV tensor or None, display image).

    DRF specs read heatmap + PAV pickles; skeleton specs read heatmaps only;
    everything else is treated as silhouettes stored as `<view>.pkl`.
    """
    if spec.is_drf:
        heatmap_path = spec.data_root / sequence_id / label / view / "0_heatmap.pkl"
        pav_path = spec.data_root / sequence_id / label / view / "1_pav.pkl"
        heatmaps = load_pickle(heatmap_path).astype(np.float32)
        pav = load_pickle(pav_path).astype(np.float32)
        return prepare_heatmap_inputs(heatmaps, pav, spec.config_path, frame_count)

    # Skeleton runs are identified by their spec key, not a dedicated flag.
    if "skeleton" in spec.key:
        heatmap_path = spec.data_root / sequence_id / label / view / "0_heatmap.pkl"
        heatmaps = load_pickle(heatmap_path).astype(np.float32)
        return prepare_heatmap_inputs(heatmaps, None, spec.config_path, frame_count)

    silhouette_path = spec.data_root / sequence_id / label / view / f"{view}.pkl"
    silhouettes = load_pickle(silhouette_path).astype(np.float32)
    return prepare_silhouette_inputs(silhouettes, spec.config_path, frame_count)
|
||||
|
||||
|
||||
def prepare_silhouette_inputs(
    silhouettes: FloatArray,
    config_path: Path,
    frame_count: int,
) -> tuple[torch.Tensor, None, FloatArray]:
    """Transform, uniformly subsample, and tensorize a silhouette sequence.

    Returns (sequence tensor, None placeholder for PAV, normalized temporal
    max-projection for display).
    """
    cfgs = config_loader(str(config_path))
    transformed = apply_transforms(silhouettes, cfgs["evaluator_cfg"]["transform"]).astype(np.float32)
    # Uniformly spaced frame indices spanning the whole sequence.
    indices = np.linspace(0, transformed.shape[0] - 1, frame_count, dtype=int)
    sampled = transformed[indices]
    # assumes silhouettes are [t, h, w] -> tensor [1, 1, t, h, w] — TODO confirm
    input_tensor = torch.from_numpy(sampled).unsqueeze(0).unsqueeze(0).float()
    # Max over time gives a single image for overlays.
    projection = normalize_image(sampled.max(axis=0))
    return input_tensor, None, projection
|
||||
|
||||
|
||||
def prepare_heatmap_inputs(
    heatmaps: FloatArray,
    pav: FloatArray | None,
    config_path: Path,
    frame_count: int,
) -> tuple[torch.Tensor, torch.Tensor | None, FloatArray]:
    """Transform, uniformly subsample, and tensorize skeleton-map inputs.

    Returns (sequence tensor, PAV tensor or None, normalized display image
    combining the temporal maxima of the first two channels).
    """
    cfgs = config_loader(str(config_path))
    transform_cfg = cfgs["evaluator_cfg"]["transform"]
    # Multi-input configs list one transform per input; the heatmap transform comes first.
    heatmap_transform = transform_cfg[0] if isinstance(transform_cfg, list) else transform_cfg
    transformed = apply_transforms(heatmaps, heatmap_transform).astype(np.float32)
    # Uniformly spaced frame indices spanning the whole sequence.
    indices = np.linspace(0, transformed.shape[0] - 1, frame_count, dtype=int)
    sampled = transformed[indices]
    # assumes [t, c, h, w] -> tensor [1, c, t, h, w] — TODO confirm
    input_tensor = torch.from_numpy(sampled).permute(1, 0, 2, 3).unsqueeze(0).float()
    # Hoisted: compute the temporal max once instead of twice (was
    # `sampled.max(axis=0)` evaluated separately per channel).
    frame_max = sampled.max(axis=0)
    combined = np.maximum(frame_max[0], frame_max[1])
    projection = normalize_image(combined)
    pav_tensor = None
    if pav is not None:
        # Average PAV over time; keepdims preserves a leading axis of length 1.
        pav_tensor = torch.from_numpy(pav.mean(axis=0, keepdims=True)).float()
    return input_tensor, pav_tensor, projection
|
||||
|
||||
|
||||
# Allowed channel-aggregation modes for activation_attention_map.
AttentionMethod = Literal["sum", "sum_p", "max_p"]
# "activation": aggregated feature magnitudes; "cam": adapted class activation map.
ResponseMode = Literal["activation", "cam"]
|
||||
|
||||
|
||||
def compute_classification_scores(
    model: nn.Module,
    sequence: torch.Tensor,
    pav: torch.Tensor | None,
) -> tuple[torch.Tensor, tuple[float, float, float], str]:
    """Run one forward pass; return (detached feature map [c, h, w], scores, label).

    The first element is the temporally pooled feature map as a torch tensor
    (the previous FloatArray annotation was incorrect — no numpy conversion
    happens here).
    """
    diagnostics = forward_diagnostics(model, sequence, pav)
    class_scores = diagnostics.class_scores[0]
    pred_index = int(class_scores.argmax().detach().cpu())
    score_tuple = tuple(float(value.detach().cpu()) for value in class_scores)
    return diagnostics.feature_map.detach()[0], score_tuple, LABEL_NAMES[pred_index]
|
||||
|
||||
|
||||
def forward_diagnostics(
    model: nn.Module,
    sequence: torch.Tensor,
    pav: torch.Tensor | None,
) -> ForwardDiagnostics:
    """Replicate the model's head pipeline and capture every intermediate tensor.

    Pipeline: backbone -> temporal pooling (TP) -> horizontal part pooling
    (HPP) -> separate FCs, with DRF's PAV-guided attention applied when a PAV
    tensor is supplied, then the BNNeck classifier.
    """
    outputs = model.Backbone(sequence)
    # TP returns a tuple/list; the pooled tensor is the first element.
    feature_map = model.TP(outputs, None, options={"dim": 2})[0]
    pooled = model.HPP(feature_map)
    embeddings = model.FCs(pooled)

    pga_spatial: Float[torch.Tensor, "batch parts"] | None = None
    if pav is not None and isinstance(model, DRF):
        # DRF only: PAV drives channel- and part-wise attention over embeddings.
        pav_flat = pav.flatten(1)
        channel_att = model.PGA.channel_att(pav_flat).unsqueeze(-1)
        spatial_att = model.PGA.spatial_att(pav_flat.unsqueeze(-1)).transpose(1, 2)
        pga_spatial = spatial_att.squeeze(1)
        embeddings = embeddings * channel_att * spatial_att

    _bn_embeddings, logits = model.BNNecks(embeddings)
    # The per-class score is the mean logit across HPP parts.
    class_scores = logits.mean(-1)
    return ForwardDiagnostics(
        feature_map=feature_map,
        pooled=pooled,
        embeddings=embeddings,
        logits=logits,
        class_scores=class_scores,
        pga_spatial=pga_spatial,
    )
|
||||
|
||||
|
||||
def compute_response_map(
    model: nn.Module,
    sequence: torch.Tensor,
    pav: torch.Tensor | None,
    response_mode: ResponseMode,
    method: AttentionMethod,
    power: float,
) -> tuple[FloatArray, tuple[float, float, float], str]:
    """Produce a normalized HxW response map plus scores/label for one sample.

    "cam" uses the adapted class activation map; otherwise channel-aggregated
    activation magnitudes are used. The map is upsampled to input resolution.
    """
    activation, score_tuple, predicted_label = compute_classification_scores(model, sequence, pav)
    if response_mode == "cam":
        response_map = adapted_cam_map(
            model=model,
            sequence=sequence,
            pav=pav,
            predicted_label=predicted_label,
        )
    else:
        response_map = activation_attention_map(activation, method=method, power=power)
    # [h, w] -> [1, 1, h, w] so bilinear interpolation can run, then back.
    response_map = response_map.unsqueeze(0).unsqueeze(0)
    response_map = F.interpolate(
        response_map,
        size=sequence.shape[-2:],
        mode="bilinear",
        align_corners=False,
    )[0, 0]
    response_np = response_map.detach().cpu().numpy().astype(np.float32)
    response_np = normalize_image(response_np)
    return response_np, score_tuple, predicted_label
|
||||
|
||||
|
||||
def activation_attention_map(
    activation: Float[torch.Tensor, "channels height width"],
    method: AttentionMethod,
    power: float,
) -> Float[torch.Tensor, "height width"]:
    """Collapse a [c, h, w] activation into an [h, w] attention map.

    "sum" adds absolute magnitudes over channels; "sum_p"/"max_p" first raise
    the magnitudes to *power*, then sum or take the channel-wise maximum.
    """
    magnitude = activation.abs()
    if method == "sum":
        return magnitude.sum(dim=0)
    if method in ("sum_p", "max_p"):
        scaled = magnitude ** power
        if method == "sum_p":
            return scaled.sum(dim=0)
        return scaled.max(dim=0).values
    raise ValueError(f"Unsupported attention method: {method}")
|
||||
|
||||
|
||||
def adapted_cam_map(
    model: nn.Module,
    sequence: torch.Tensor,
    pav: torch.Tensor | None,
    predicted_label: str,
) -> Float[torch.Tensor, "height width"]:
    """Class activation map adapted to the part-based BNNeck head.

    Folds the per-part FC, optional PGA attention, and the BNNeck's BatchNorm
    scaling into one effective input-channel weighting per HPP part, then
    projects each part's weighting back onto its horizontal strip of the
    feature map. Result is ReLU'd, unnormalized.
    """
    outputs = model.Backbone(sequence)
    feature_map = model.TP(outputs, None, options={"dim": 2})[0]  # [1, c_in, h, w]
    pooled = model.HPP(feature_map)  # [1, c_in, p]

    pre_attention = model.FCs(pooled)  # [1, c_out, p]
    attention = torch.ones_like(pre_attention)  # identity attention for non-DRF models
    if pav is not None and isinstance(model, DRF):
        pav_flat = pav.flatten(1)
        channel_att = model.PGA.channel_att(pav_flat).unsqueeze(-1)
        spatial_att = model.PGA.spatial_att(pav_flat.unsqueeze(-1)).transpose(1, 2)
        attention = channel_att * spatial_att
    embeddings = pre_attention * attention

    # BN scale factors and normalized classifier weights for the predicted class.
    bn_scales = bn_feature_scales(model.BNNecks, embeddings.shape[1], embeddings.shape[2], embeddings.device)
    classifier_weights = F.normalize(model.BNNecks.fc_bin.detach(), dim=1)
    class_index = LABEL_NAMES.index(predicted_label)
    class_weights = classifier_weights[:, :, class_index]  # [p, c_out]

    normalized_embeddings = batchnorm_forward_eval(model.BNNecks, embeddings)[0].transpose(0, 1)  # [p, c_out]
    # Fold the embedding norm into the weights (classifier appears to operate on
    # normalized embeddings — confirm against the BNNecks implementation).
    part_norms = normalized_embeddings.norm(dim=-1, keepdim=True).clamp_min(1e-6)
    effective_post_bn = class_weights / part_norms
    effective_pre_bn = effective_post_bn * bn_scales

    # Pull the class weighting back through the per-part FC into input-channel space.
    fc_weights = model.FCs.fc_bin.detach()  # [p, c_in, c_out]
    attention_weights = attention.detach()[0].transpose(0, 1)  # [p, c_out]
    effective_channel_weights = torch.einsum(
        "pkj,pj->pk",
        fc_weights,
        attention_weights * effective_pre_bn,
    )  # [p, c_in]

    feature = feature_map.detach()[0]  # [c_in, h, w]
    parts_num = effective_channel_weights.shape[0]
    height = feature.shape[1]
    split_size = max(height // parts_num, 1)
    cam = torch.zeros((height, feature.shape[2]), device=feature.device, dtype=feature.dtype)
    for part_index in range(parts_num):
        # Each part's weights apply only to its horizontal strip (mirrors HPP geometry);
        # the last part absorbs any remainder rows.
        row_start = part_index * split_size
        row_end = height if part_index == parts_num - 1 else min((part_index + 1) * split_size, height)
        if row_start >= height:
            break
        contribution = (feature[:, row_start:row_end, :] * effective_channel_weights[part_index][:, None, None]).sum(dim=0)
        cam[row_start:row_end, :] = contribution
    return F.relu(cam)
|
||||
|
||||
|
||||
def batchnorm_forward_eval(
    bn_necks: nn.Module,
    embeddings: Float[torch.Tensor, "batch channels parts"],
) -> Float[torch.Tensor, "batch channels parts"]:
    """Apply the BNNeck's shared BatchNorm1d in eval mode (running stats only).

    Raises ValueError unless the neck uses the parallel BN layout this helper
    assumes.
    """
    if not getattr(bn_necks, "parallel_BN1d", True):
        raise ValueError("Adapted CAM currently expects parallel_BN1d=True.")
    batch_size, channel_count, part_count = embeddings.shape
    bn_layer = bn_necks.bn1d
    # Parts are flattened into the feature axis, matching how the neck trains.
    flat = embeddings.reshape(batch_size, -1)
    out = torch.nn.functional.batch_norm(
        flat,
        bn_layer.running_mean,
        bn_layer.running_var,
        bn_layer.weight,
        bn_layer.bias,
        training=False,
        momentum=0.0,
        eps=bn_layer.eps,
    )
    return out.view(batch_size, channel_count, part_count)
|
||||
|
||||
|
||||
def bn_feature_scales(
    bn_necks: nn.Module,
    channels: int,
    parts: int,
    device: torch.device,
) -> Float[torch.Tensor, "parts channels"]:
    """Per-part, per-channel multiplicative scale of the BNNeck's BatchNorm.

    Scale = gamma / sqrt(running_var + eps), laid out as [parts, channels];
    gamma is skipped for affine-less layers. Raises ValueError unless the neck
    uses the parallel BN layout.
    """
    if not getattr(bn_necks, "parallel_BN1d", True):
        raise ValueError("Adapted CAM currently expects parallel_BN1d=True.")
    bn_layer = bn_necks.bn1d

    def as_parts_by_channels(vec: torch.Tensor) -> torch.Tensor:
        # Flat (channels*parts,) buffer -> [parts, channels] view on *device*.
        return vec.detach().to(device).view(channels, parts).transpose(0, 1)

    scale = torch.rsqrt(as_parts_by_channels(bn_layer.running_var) + bn_layer.eps)
    if bn_layer.weight is None:
        return scale
    return scale * as_parts_by_channels(bn_layer.weight)
|
||||
|
||||
|
||||
def normalize_image(image: np.ndarray) -> FloatArray:
    """Min-max scale *image* to [0, 1]; constant images become all zeros."""
    scaled = image.astype(np.float32)
    lo = float(scaled.min())
    hi = float(scaled.max())
    if hi > lo:
        return (scaled - lo) / (hi - lo)
    return np.zeros_like(scaled, dtype=np.float32)
|
||||
|
||||
|
||||
def overlay_response(base_image: FloatArray, response_map: FloatArray) -> FloatArray:
    """Blend a grayscale base with a jet-colored response map (45/55 mix).

    Output is an RGB float array clipped to [0, 1].
    """
    jet = plt.get_cmap("jet")
    heat_rgb = jet(response_map)[..., :3].astype(np.float32)
    base_rgb = np.repeat(base_image[..., None], 3, axis=-1)
    blended = 0.45 * base_rgb + 0.55 * heat_rgb
    return np.clip(blended, 0.0, 1.0)
|
||||
|
||||
|
||||
def compute_occlusion_sensitivity(
    model: nn.Module,
    sequence: torch.Tensor,
    pav: torch.Tensor | None,
    patch_size: int,
    stride: int,
) -> tuple[FloatArray, tuple[float, float, float], str]:
    """Slide a zero patch over the input and measure predicted-class score drops.

    Returns (normalized HxW sensitivity map, class-score tuple, predicted
    label). All forward passes run without gradients.
    """
    with torch.no_grad():
        baseline = forward_diagnostics(model, sequence, pav)
        class_scores = baseline.class_scores[0]
        pred_index = int(class_scores.argmax().detach().cpu())
        baseline_score = float(class_scores[pred_index].detach().cpu())
        score_tuple = tuple(float(value.detach().cpu()) for value in class_scores)

        _, _, _, height, width = sequence.shape
        sensitivity = np.zeros((height, width), dtype=np.float32)
        counts = np.zeros((height, width), dtype=np.float32)
        # Patch origins; a trailing start is appended so the final rows/cols
        # are always covered even when stride does not divide evenly.
        row_starts = list(range(0, max(height - patch_size + 1, 1), stride))
        col_starts = list(range(0, max(width - patch_size + 1, 1), stride))
        last_row = max(height - patch_size, 0)
        last_col = max(width - patch_size, 0)
        if row_starts[-1] != last_row:
            row_starts.append(last_row)
        if col_starts[-1] != last_col:
            col_starts.append(last_col)

        for top in row_starts:
            for left in col_starts:
                bottom = min(top + patch_size, height)
                right = min(left + patch_size, width)
                masked = sequence.clone()
                masked[..., top:bottom, left:right] = 0.0
                masked_scores = forward_diagnostics(model, masked, pav).class_scores[0]
                # Only score decreases count as evidence of importance.
                score_drop = max(0.0, baseline_score - float(masked_scores[pred_index].detach().cpu()))
                sensitivity[top:bottom, left:right] += score_drop
                counts[top:bottom, left:right] += 1.0

        # Average over overlapping patches; clamp avoids division by zero.
        counts = np.maximum(counts, 1.0)
        sensitivity /= counts
        return normalize_image(sensitivity), score_tuple, LABEL_NAMES[pred_index]
|
||||
|
||||
|
||||
def compute_part_diagnostics(
    model: nn.Module,
    sequence: torch.Tensor,
    pav: torch.Tensor | None,
) -> tuple[FloatArray, FloatArray | None, tuple[float, float, float], str]:
    """Per-part logits for the predicted class plus optional PGA attention.

    Returns (part logits, PGA spatial attention or None, class scores, label).
    """
    with torch.no_grad():
        diagnostics = forward_diagnostics(model, sequence, pav)
        class_scores = diagnostics.class_scores[0]
        pred_index = int(class_scores.argmax().detach().cpu())
        # Logits of the winning class, one value per HPP part.
        part_scores = diagnostics.logits[0, pred_index].detach().cpu().numpy().astype(np.float32)
        pga_spatial = None
        if diagnostics.pga_spatial is not None:
            pga_spatial = diagnostics.pga_spatial[0].detach().cpu().numpy().astype(np.float32)
        score_tuple = tuple(float(value.detach().cpu()) for value in class_scores)
        return part_scores, pga_spatial, score_tuple, LABEL_NAMES[pred_index]
|
||||
|
||||
|
||||
def extract_temporally_pooled_stages(
    model: nn.Module,
    sequence: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Run the backbone stage by stage, then temporally pool each stage output.

    Returns {stage name: pooled tensor} for "conv1" and "layer1".."layer4".
    Raises ValueError for backbones not wrapped in a SetBlockWrapper (i.e.
    lacking a `forward_block` attribute).
    """
    if not hasattr(model.Backbone, "forward_block"):
        raise ValueError("Stage visualization requires SetBlockWrapper-style backbones.")

    backbone = model.Backbone.forward_block
    batch, channels, frames, height, width = sequence.shape
    # Fold frames into the batch axis: [b, c, t, h, w] -> [b*t, c, h, w].
    x = sequence.transpose(1, 2).reshape(-1, channels, height, width)

    stage_outputs: dict[str, torch.Tensor] = {}
    x = backbone.conv1(x)
    x = backbone.bn1(x)
    x = backbone.relu(x)
    # Some backbone variants skip the stem max-pool; mirror that here.
    if getattr(backbone, "maxpool_flag", False):
        x = backbone.maxpool(x)
    stage_outputs["conv1"] = x

    x = backbone.layer1(x)
    stage_outputs["layer1"] = x
    x = backbone.layer2(x)
    stage_outputs["layer2"] = x
    x = backbone.layer3(x)
    stage_outputs["layer3"] = x
    x = backbone.layer4(x)
    stage_outputs["layer4"] = x

    pooled_outputs: dict[str, torch.Tensor] = {}
    for stage_name, stage_output in stage_outputs.items():
        # Unfold frames back out of the batch axis, then apply the model's
        # temporal pooling exactly as the normal forward pass would.
        reshaped = stage_output.reshape(batch, frames, *stage_output.shape[1:]).transpose(1, 2).contiguous()
        pooled_outputs[stage_name] = model.TP(reshaped, None, options={"dim": 2})[0]
    return pooled_outputs
|
||||
|
||||
|
||||
def render_visualization(results: list[VisualizationResult], output_path: Path) -> None:
    """Save a grid figure: one row per model; columns = input / response / overlay."""
    fig, axes = plt.subplots(len(results), 3, figsize=(10.5, 3.2 * len(results)), constrained_layout=True)
    if len(results) == 1:
        # plt.subplots returns a 1-D axes array for a single row; keep indexing uniform.
        axes = np.expand_dims(axes, axis=0)

    for row, result in enumerate(results):
        input_axis, heat_axis, overlay_axis = axes[row]
        input_axis.imshow(result.input_image, cmap="gray", vmin=0.0, vmax=1.0)
        input_axis.set_title(f"{result.title} input")
        heat_axis.imshow(result.response_map, cmap="jet", vmin=0.0, vmax=1.0)
        heat_axis.set_title("response")
        overlay_axis.imshow(result.overlay_image)
        overlay_axis.set_title(
            f"overlay | pred={result.predicted_label}\n"
            f"scores={tuple(round(score, 3) for score in result.class_scores)}"
        )
        for axis in (input_axis, heat_axis, overlay_axis):
            axis.set_xticks([])
            axis.set_yticks([])

    fig.savefig(output_path, dpi=180)
    plt.close(fig)
|
||||
|
||||
|
||||
def render_stage_visualization(results: list[StageVisualizationResult], output_path: Path) -> None:
    """Save a grid figure: one row per model; columns = input then each stage overlay."""
    stage_names = ["input", "conv1", "layer1", "layer2", "layer3", "layer4"]
    fig, axes = plt.subplots(
        len(results),
        len(stage_names),
        figsize=(2.2 * len(stage_names), 3.0 * len(results)),
        constrained_layout=True,
    )
    if len(results) == 1:
        # Keep axes 2-D even for a single row.
        axes = np.expand_dims(axes, axis=0)

    for row, result in enumerate(results):
        for col, stage_name in enumerate(stage_names):
            axis = axes[row, col]
            if stage_name == "input":
                axis.imshow(result.input_image, cmap="gray", vmin=0.0, vmax=1.0)
                axis.set_title(
                    f"{result.title}\npred={result.predicted_label}",
                    fontsize=10,
                )
            else:
                axis.imshow(result.stage_overlays[stage_name])
                axis.set_title(stage_name, fontsize=10)
            axis.set_xticks([])
            axis.set_yticks([])
        # Class scores annotate the leftmost panel of each row.
        axes[row, 0].set_ylabel(
            f"scores={tuple(round(score, 3) for score in result.class_scores)}",
            fontsize=9,
            rotation=90,
        )

    fig.savefig(output_path, dpi=180)
    plt.close(fig)
|
||||
|
||||
|
||||
def render_diagnostic_visualization(results: list[ModelDiagnosticResult], output_path: Path) -> None:
    """Save a grid figure: input, occlusion overlay, and per-part logit bars per model."""
    fig, axes = plt.subplots(len(results), 3, figsize=(11.5, 3.3 * len(results)), constrained_layout=True)
    if len(results) == 1:
        # Keep axes 2-D even for a single row.
        axes = np.expand_dims(axes, axis=0)

    for row, result in enumerate(results):
        input_axis, occlusion_axis, part_axis = axes[row]
        input_axis.imshow(result.input_image, cmap="gray", vmin=0.0, vmax=1.0)
        input_axis.set_title(f"{result.title} input")
        input_axis.set_xticks([])
        input_axis.set_yticks([])

        occlusion_axis.imshow(result.overlay_image)
        occlusion_axis.set_title(
            f"occlusion | pred={result.predicted_label}\n"
            f"scores={tuple(round(score, 3) for score in result.class_scores)}"
        )
        occlusion_axis.set_xticks([])
        occlusion_axis.set_yticks([])

        parts = np.arange(result.part_scores.shape[0], dtype=np.int32)
        part_axis.bar(parts, result.part_scores, color="#4477AA", alpha=0.85, width=0.8)
        part_axis.set_title("pred-class part logits")
        part_axis.set_xlabel("HPP part")
        part_axis.set_ylabel("logit")
        part_axis.set_xticks(parts)
        if result.pga_spatial is not None:
            # Second y-axis overlays the PGA spatial attention across the same parts.
            part_axis_2 = part_axis.twinx()
            part_axis_2.plot(parts, result.pga_spatial, color="#CC6677", marker="o", linewidth=1.5)
            part_axis_2.set_ylabel("PGA spatial")
            part_axis_2.set_ylim(0.0, max(1.0, float(result.pga_spatial.max()) * 1.05))

    fig.savefig(output_path, dpi=180)
    plt.close(fig)
|
||||
|
||||
|
||||
def default_model_specs() -> list[ModelSpec]:
    """The three checkpoints this script compares side by side.

    NOTE(review): config/checkpoint/data paths are hard-coded to one machine's
    layout; adjust here when running elsewhere.
    """
    return [
        ModelSpec(
            key="silhouette",
            title="ScoNet-MT silhouette",
            config_path=REPO_ROOT / "configs/sconet/sconet_scoliosis1k_local_eval_2gpu_better_112.yaml",
            checkpoint_path=REPO_ROOT / "ckpt/ScoNet-20000-better.pt",
            data_root=Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl"),
            is_drf=False,
        ),
        ModelSpec(
            key="skeleton",
            title="ScoNet-MT-ske",
            config_path=REPO_ROOT / "configs/sconet/sconet_scoliosis1k_skeleton_118_sigma15_joint8_2gpu_bs12x8.yaml",
            checkpoint_path=REPO_ROOT
            / "output/Scoliosis1K/ScoNet/ScoNet_skeleton_118_sigma15_joint8_sharedalign_2gpu_bs12x8/checkpoints/ScoNet_skeleton_118_sigma15_joint8_sharedalign_2gpu_bs12x8-20000.pt",
            data_root=Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign"),
            is_drf=False,
        ),
        ModelSpec(
            key="drf",
            title="DRF",
            config_path=REPO_ROOT / "configs/drf/drf_scoliosis1k_eval_1gpu.yaml",
            checkpoint_path=REPO_ROOT / "output/Scoliosis1K/DRF/DRF/checkpoints/DRF-20000.pt",
            data_root=Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118"),
            is_drf=True,
        ),
    ]
|
||||
|
||||
|
||||
def build_stage_results(
    spec: ModelSpec,
    response_mode: ResponseMode,
    method: AttentionMethod,
    power: float,
    sequence_id: str,
    label: str,
    view: str,
    frame_count: int,
) -> StageVisualizationResult:
    """Compute per-backbone-stage activation overlays for one model/sample.

    Raises ValueError for response modes other than "activation" (CAM is not
    defined at intermediate stages).
    """
    if response_mode != "activation":
        raise ValueError("Stage view currently supports activation maps only.")
    model, _cfgs = load_checkpoint_model(spec)
    sequence, pav, base_image = load_sample_inputs(spec, sequence_id, label, view, frame_count)
    _final_feature, score_tuple, predicted_label = compute_classification_scores(model, sequence, pav)
    pooled_stages = extract_temporally_pooled_stages(model, sequence)
    stage_overlays: dict[str, FloatArray] = {}
    for stage_name, stage_tensor in pooled_stages.items():
        response_map = activation_attention_map(stage_tensor[0], method=method, power=power)
        # Upsample each stage map to input resolution before overlaying.
        response_map = F.interpolate(
            response_map.unsqueeze(0).unsqueeze(0),
            size=sequence.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )[0, 0]
        response_np = normalize_image(response_map.detach().cpu().numpy().astype(np.float32))
        stage_overlays[stage_name] = overlay_response(base_image, response_np)
    return StageVisualizationResult(
        title=spec.title,
        predicted_label=predicted_label,
        class_scores=score_tuple,
        input_image=base_image,
        stage_overlays=stage_overlays,
    )
|
||||
|
||||
|
||||
def build_diagnostic_results(
    spec: ModelSpec,
    sequence_id: str,
    label: str,
    view: str,
    frame_count: int,
    patch_size: int,
    stride: int,
) -> ModelDiagnosticResult:
    """Run occlusion sensitivity and part-logit diagnostics for one model/sample."""
    model, _cfgs = load_checkpoint_model(spec)
    sequence, pav, base_image = load_sample_inputs(spec, sequence_id, label, view, frame_count)
    occlusion_map, score_tuple, predicted_label = compute_occlusion_sensitivity(
        model=model,
        sequence=sequence,
        pav=pav,
        patch_size=patch_size,
        stride=stride,
    )
    # Scores/label were already computed above, so the duplicates are discarded here.
    part_scores, pga_spatial, _scores, _predicted = compute_part_diagnostics(model, sequence, pav)
    return ModelDiagnosticResult(
        title=spec.title,
        predicted_label=predicted_label,
        class_scores=score_tuple,
        input_image=base_image,
        occlusion_map=occlusion_map,
        overlay_image=overlay_response(base_image, occlusion_map),
        part_scores=part_scores,
        pga_spatial=pga_spatial,
    )
|
||||
|
||||
|
||||
@click.command()
@click.option("--sequence-id", default="00490", show_default=True)
@click.option("--label", default="positive", show_default=True)
@click.option("--view", default="000_180", show_default=True)
@click.option("--frame-count", default=30, show_default=True, type=int)
@click.option(
    "--response-mode",
    type=click.Choice(["activation", "cam"]),
    default="activation",
    show_default=True,
)
@click.option(
    "--method",
    type=click.Choice(["sum", "sum_p", "max_p"]),
    default="sum",
    show_default=True,
)
@click.option("--power", default=2.0, show_default=True, type=float)
@click.option(
    "--view-mode",
    type=click.Choice(["final", "stages", "diagnostics"]),
    default="final",
    show_default=True,
)
@click.option("--occlusion-patch-size", default=8, show_default=True, type=int)
@click.option("--occlusion-stride", default=8, show_default=True, type=int)
@click.option(
    "--output-path",
    default=str(REPO_ROOT / "research/feature_response_heatmaps/00490_positive_000_180_response_heatmaps.png"),
    show_default=True,
    type=click.Path(path_type=Path),
)
def main(
    sequence_id: str,
    label: str,
    view: str,
    frame_count: int,
    response_mode: str,
    method: str,
    power: float,
    view_mode: str,
    occlusion_patch_size: int,
    occlusion_stride: int,
    output_path: Path,
) -> None:
    """Render response-heatmap figures for every default model spec.

    Dispatches on ``--view-mode``:

    * ``stages`` — per-stage response maps via ``build_stage_results``.
    * ``diagnostics`` — occlusion-sensitivity and part diagnostics via
      ``build_diagnostic_results``.
    * ``final`` (default) — a single response map per model via
      ``compute_response_map``.

    The rendered figure is written to ``--output-path`` (parent directories
    are created as needed) and the path is echoed to stdout.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    if view_mode == "stages":
        stage_results = []
        for model_spec in default_model_specs():
            stage_results.append(
                build_stage_results(
                    spec=model_spec,
                    response_mode=response_mode,
                    method=method,
                    power=power,
                    sequence_id=sequence_id,
                    label=label,
                    view=view,
                    frame_count=frame_count,
                )
            )
        render_stage_visualization(stage_results, output_path)
    elif view_mode == "diagnostics":
        diagnostic_results = []
        for model_spec in default_model_specs():
            diagnostic_results.append(
                build_diagnostic_results(
                    spec=model_spec,
                    sequence_id=sequence_id,
                    label=label,
                    view=view,
                    frame_count=frame_count,
                    patch_size=occlusion_patch_size,
                    stride=occlusion_stride,
                )
            )
        render_diagnostic_visualization(diagnostic_results, output_path)
    else:
        # "final" view: run each model once and collect its response map.
        final_results: list[VisualizationResult] = []
        for model_spec in default_model_specs():
            model, _cfgs = load_checkpoint_model(model_spec)
            sequence, pav, base_image = load_sample_inputs(model_spec, sequence_id, label, view, frame_count)
            response_map, scores, predicted_label = compute_response_map(
                model,
                sequence,
                pav,
                response_mode=response_mode,
                method=method,
                power=power,
            )
            final_results.append(
                VisualizationResult(
                    title=model_spec.title,
                    predicted_label=predicted_label,
                    class_scores=scores,
                    input_image=base_image,
                    response_map=response_map,
                    overlay_image=overlay_response(base_image, response_map),
                )
            )
        render_visualization(final_results, output_path)
    click.echo(str(output_path))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,173 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import yaml
|
||||
|
||||
matplotlib.use("Agg")
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
if __package__ in {None, ""}:
|
||||
sys.path.append(str(REPO_ROOT))
|
||||
|
||||
from datasets import pretreatment_heatmap as prep
|
||||
|
||||
FloatArray = npt.NDArray[np.float32]
|
||||
|
||||
SAMPLE_ID = "00000"
|
||||
LABEL = "positive"
|
||||
VIEW = "000_180"
|
||||
RESEARCH_ROOT = REPO_ROOT / "research/sigma_comparison"
|
||||
POSE_PATH = Path(
|
||||
f"/mnt/public/data/Scoliosis1K/Scoliosis1K-pose-pkl/{SAMPLE_ID}/{LABEL}/{VIEW}/{VIEW}.pkl"
|
||||
)
|
||||
SIL_PATH = Path(
|
||||
f"/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl/{SAMPLE_ID}/{LABEL}/{VIEW}/{VIEW}.pkl"
|
||||
)
|
||||
CFG_PATH = REPO_ROOT / "configs/drf/pretreatment_heatmap_drf_sigma15_joint8.yaml"
|
||||
|
||||
LIMB_RGB = np.asarray([0.18, 0.85, 0.96], dtype=np.float32)
|
||||
JOINT_RGB = np.asarray([1.00, 0.45, 0.12], dtype=np.float32)
|
||||
SIL_RGB = np.asarray([0.85, 0.97, 0.78], dtype=np.float32)
|
||||
LIMB_ALPHA = 0.8
|
||||
JOINT_ALPHA = 0.8
|
||||
SIL_ALPHA = 0.3
|
||||
|
||||
|
||||
def _normalize(array: npt.NDArray[np.generic]) -> FloatArray:
|
||||
image = np.asarray(array, dtype=np.float32)
|
||||
if image.max() > 1.0:
|
||||
image = image / 255.0
|
||||
return image
|
||||
|
||||
|
||||
def _load_pickle(path: Path) -> FloatArray:
    """Unpickle *path* and return its payload normalized to float32 [0, 1]."""
    with path.open("rb") as stream:
        payload = pickle.load(stream)
    return _normalize(payload)
|
||||
|
||||
|
||||
def _load_mixed_sharedalign() -> FloatArray:
    """Build the mixed-sigma (limb 1.5 / joint 8) heatmap sequence for the sample pose.

    Reads the pretreatment YAML config at CFG_PATH, runs the project's variable
    expansion over it, constructs a GenerateHeatmapTransform from its fields,
    and applies the transform to the pickled pose sequence at POSE_PATH.
    """
    with CFG_PATH.open("r", encoding="utf-8") as handle:
        cfg = yaml.safe_load(handle)
    # NOTE(review): presumably expands placeholder references within the config
    # against itself — confirm against datasets.pretreatment_heatmap.
    cfg = prep.replace_variables(cfg, cfg)
    transform = prep.GenerateHeatmapTransform(
        coco18tococo17_args=cfg["coco18tococo17_args"],
        padkeypoints_args=cfg["padkeypoints_args"],
        norm_args=cfg["norm_args"],
        heatmap_generator_args=cfg["heatmap_generator_args"],
        align_args=cfg["align_args"],
        # "upstream" reduction plus separate limb/joint sigmas — this is the
        # mixed-sigma configuration the experiment log calls "sharedalign".
        reduction="upstream",
        sigma_limb=float(cfg["sigma_limb"]),
        sigma_joint=float(cfg["sigma_joint"]),
    )
    with POSE_PATH.open("rb") as handle:
        pose = np.asarray(pickle.load(handle), dtype=np.float32)
    # Normalize the transform output the same way as the pickled inputs.
    return _normalize(transform(pose))
|
||||
|
||||
|
||||
def _overlay_frame(silhouette: FloatArray, limb: FloatArray, joint: FloatArray) -> FloatArray:
    """Composite silhouette, limb, and joint maps into one RGBA frame.

    Each layer adds its colour (weighted by the layer mask and alpha) to the
    RGB channels, while the alpha channel keeps the per-pixel maximum coverage
    across layers. All channels are clipped to [0, 1] before returning.
    """
    height, width = silhouette.shape[0], silhouette.shape[1]
    canvas = np.zeros((height, width, 4), dtype=np.float32)

    # Layer order matters only for readability here: RGB contributions are
    # additive and alpha is a running maximum, both order-independent.
    layers = (
        (silhouette, SIL_RGB, SIL_ALPHA),
        (limb, LIMB_RGB, LIMB_ALPHA),
        (joint, JOINT_RGB, JOINT_ALPHA),
    )
    for raw_mask, rgb, alpha in layers:
        mask = np.clip(raw_mask, 0.0, 1.0)
        canvas[..., :3] += mask[..., None] * rgb[None, None, :] * alpha
        canvas[..., 3] = np.maximum(canvas[..., 3], mask * alpha)

    canvas[..., :3] = np.clip(canvas[..., :3], 0.0, 1.0)
    canvas[..., 3] = np.clip(canvas[..., 3], 0.0, 1.0)
    return canvas
|
||||
|
||||
|
||||
def _bbox(mask: FloatArray, threshold: float = 0.05) -> tuple[int, int, int, int] | None:
|
||||
ys, xs = np.where(mask > threshold)
|
||||
if ys.size == 0:
|
||||
return None
|
||||
return int(ys.min()), int(ys.max()), int(xs.min()), int(xs.max())
|
||||
|
||||
|
||||
def _align_silhouette_to_heatmap(silhouette: FloatArray, heatmap: FloatArray) -> FloatArray:
    """Rescale and recenter *silhouette* so its content box matches *heatmap*'s.

    The silhouette crop is scaled uniformly so its height equals the heatmap
    bounding-box height (width follows, preserving aspect ratio), then pasted
    onto a blank canvas horizontally centered on the heatmap's box. Returns the
    input unchanged when either image has no pixels above the _bbox threshold.
    """
    sil_box = _bbox(silhouette)
    heat_box = _bbox(heatmap)
    if sil_box is None or heat_box is None:
        # Nothing to align against — keep the original frame.
        return silhouette

    sy0, sy1, sx0, sx1 = sil_box
    hy0, hy1, hx0, hx1 = heat_box
    sil_crop = silhouette[sy0 : sy1 + 1, sx0 : sx1 + 1]
    # max(..., 1) guards degenerate one-pixel boxes against zero-size resize.
    sil_h = max(sy1 - sy0 + 1, 1)
    sil_w = max(sx1 - sx0 + 1, 1)
    target_h = max(hy1 - hy0 + 1, 1)
    # Single scale factor chosen from heights; width derives from it.
    scale = target_h / sil_h
    target_w = max(int(round(sil_w * scale)), 1)
    resized = cv2.resize(sil_crop, (target_w, target_h), interpolation=cv2.INTER_AREA)
    canvas = np.zeros_like(silhouette)
    # Horizontal placement: center the resized crop on the heatmap box center.
    heat_center_x = (hx0 + hx1) / 2.0
    x0 = int(round(heat_center_x - target_w / 2.0))
    x1 = x0 + target_w

    # Clip the paste interval against the canvas edges on both sides:
    # src_* indexes into the resized crop, dst_* into the output canvas.
    src_x0 = max(0, -x0)
    src_x1 = target_w - max(0, x1 - silhouette.shape[1])
    dst_x0 = max(0, x0)
    dst_x1 = min(silhouette.shape[1], x1)

    if src_x0 < src_x1 and dst_x0 < dst_x1:
        # Vertical placement reuses the heatmap's row span directly.
        canvas[hy0 : hy1 + 1, dst_x0:dst_x1] = resized[:, src_x0:src_x1]
    return canvas
|
||||
|
||||
|
||||
def main() -> None:
    """Render silhouette + mixed-sigma heatmap overlay previews.

    Produces two PNGs under RESEARCH_ROOT for the hard-coded sample: one with
    the raw silhouette, one with the silhouette height/center-aligned to the
    heatmap bounding box. Each figure shows six evenly spaced frames. Output
    paths are printed as they are written.
    """
    RESEARCH_ROOT.mkdir(parents=True, exist_ok=True)

    silhouette = _load_pickle(SIL_PATH)
    mixed = _load_mixed_sharedalign()
    # Six frame indices spread across the silhouette sequence. Assumes the
    # heatmap sequence has at least as many frames — TODO confirm upstream.
    frames = np.linspace(0, silhouette.shape[0] - 1, num=6, dtype=int)

    raw_path = RESEARCH_ROOT / f"{SAMPLE_ID}_{LABEL}_{VIEW}_overlay_sharedalign_preview.png"
    aligned_path = RESEARCH_ROOT / f"{SAMPLE_ID}_{LABEL}_{VIEW}_overlay_sharedalign_aligned_preview.png"

    # Render twice: raw silhouette first, then the aligned variant.
    for output_path, align_silhouette in ((raw_path, False), (aligned_path, True)):
        fig, axes = plt.subplots(1, len(frames), figsize=(2.6 * len(frames), 3.6), dpi=220)
        for axis, frame_idx in zip(axes, frames, strict=True):
            # Channel 0 is the limb map, channel 1 the joint map.
            limb = mixed[frame_idx, 0]
            joint = mixed[frame_idx, 1]
            sil_frame = silhouette[frame_idx]
            if align_silhouette:
                sil_frame = _align_silhouette_to_heatmap(sil_frame, np.maximum(limb, joint))
            overlay = _overlay_frame(sil_frame, limb, joint)
            axis.set_facecolor("black")
            axis.imshow(overlay)
            axis.set_xticks([])
            axis.set_yticks([])
            axis.set_title(f"Frame {frame_idx}", fontsize=9)

        mode = "height+center aligned silhouette" if align_silhouette else "raw silhouette"
        fig.suptitle(
            (
                f"{SAMPLE_ID}/{LABEL}/{VIEW}: {mode}, "
                f"silhouette ({SIL_ALPHA}), limb ({LIMB_ALPHA}), joint ({JOINT_ALPHA})"
            ),
            fontsize=12,
        )
        fig.tight_layout()
        fig.savefig(output_path, bbox_inches="tight", facecolor="black")
        plt.close(fig)
        print(output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,199 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
matplotlib.use("Agg")
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
RESEARCH_ROOT = REPO_ROOT / "research/sigma_comparison"
|
||||
SIGMA8_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K_sigma_8.0/pkl")
|
||||
SIGMA15_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15")
|
||||
MIXED_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15-joint8-fixed")
|
||||
SIL_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl")
|
||||
|
||||
FloatArray = npt.NDArray[np.float32]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SampleSpec:
    """Identifies one Scoliosis1K sample by its on-disk directory components."""

    # Zero-padded sequence identifier, e.g. "00000".
    sequence_id: str
    # Class-label directory name, e.g. "positive".
    label: str
    # Camera-view directory name, e.g. "000_180".
    view: str
|
||||
|
||||
|
||||
def _normalize_image(array: npt.NDArray[np.generic]) -> FloatArray:
|
||||
image = np.asarray(array, dtype=np.float32)
|
||||
if image.max() > 1.0:
|
||||
image = image / 255.0
|
||||
return image
|
||||
|
||||
|
||||
def _load_pickle(path: Path) -> FloatArray:
    """Unpickle *path* and return its payload normalized to float32 [0, 1]."""
    with path.open("rb") as stream:
        payload = pickle.load(stream)
    return _normalize_image(payload)
|
||||
|
||||
|
||||
def _load_sigma8(spec: SampleSpec) -> FloatArray:
    """Load the sigma-8.0 heatmap sequence for *spec*."""
    sample_dir = SIGMA8_ROOT / spec.sequence_id / spec.label / spec.view
    return _load_pickle(sample_dir / f"{spec.view}.pkl")
|
||||
|
||||
|
||||
def _load_sigma15(spec: SampleSpec) -> FloatArray:
    """Load the sigma-1.5 heatmap sequence for *spec*."""
    sample_dir = SIGMA15_ROOT / spec.sequence_id / spec.label / spec.view
    return _load_pickle(sample_dir / "0_heatmap.pkl")
|
||||
|
||||
|
||||
def _load_silhouette(spec: SampleSpec) -> FloatArray:
    """Load the silhouette sequence for *spec*."""
    sample_dir = SIL_ROOT / spec.sequence_id / spec.label / spec.view
    return _load_pickle(sample_dir / f"{spec.view}.pkl")
|
||||
|
||||
|
||||
def _load_mixed(spec: SampleSpec) -> FloatArray:
    """Load the mixed limb-1.5 / joint-8 heatmap sequence for *spec*."""
    sample_dir = MIXED_ROOT / spec.sequence_id / spec.label / spec.view
    return _load_pickle(sample_dir / "0_heatmap.pkl")
|
||||
|
||||
|
||||
def _frame_indices(num_frames: int, count: int = 6) -> npt.NDArray[np.int_]:
|
||||
return np.linspace(0, num_frames - 1, num=count, dtype=int)
|
||||
|
||||
|
||||
def _combined_map(heatmaps: FloatArray) -> FloatArray:
|
||||
return np.maximum(heatmaps[:, 0], heatmaps[:, 1])
|
||||
|
||||
|
||||
def _save_combined_figure(
    spec: SampleSpec,
    silhouette: FloatArray,
    sigma8: FloatArray,
    sigma15: FloatArray,
    mixed: FloatArray,
) -> Path:
    """Save a 4-row grid comparing the silhouette against combined heatmaps.

    Rows: raw silhouette, sigma-8 combined map, sigma-1.5 combined map, and the
    mixed limb-1.5/joint-8 combined map; columns are six frames evenly sampled
    from the silhouette sequence. Returns the path of the saved PNG.
    """
    # Frame indices come from the silhouette length; assumes the heatmap
    # sequences have at least as many frames — TODO confirm upstream.
    frames = _frame_indices(silhouette.shape[0], count=6)
    fig, axes = plt.subplots(4, len(frames), figsize=(2.4 * len(frames), 8.8), dpi=200)
    row_titles = (
        "Silhouette",
        "Sigma 8 Combined",
        "Sigma 1.5 Combined",
        "Limb 1.5 / Joint 8 Combined",
    )
    # One frame stack per row, aligned 1:1 with row_titles (zip strict=True).
    images = (silhouette, _combined_map(sigma8), _combined_map(sigma15), _combined_map(mixed))

    for row_idx, (title, row_images) in enumerate(zip(row_titles, images, strict=True)):
        for col_idx, frame_idx in enumerate(frames):
            ax = axes[row_idx, col_idx]
            # Fixed [0, 1] gray scale so intensities are comparable across rows.
            ax.imshow(row_images[frame_idx], cmap="gray", vmin=0.0, vmax=1.0)
            ax.set_xticks([])
            ax.set_yticks([])
            if row_idx == 0:
                ax.set_title(f"Frame {frame_idx}", fontsize=9)
            if col_idx == 0:
                ax.set_ylabel(title, fontsize=10)

    fig.suptitle(
        f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: silhouette vs sigma heatmaps",
        fontsize=12,
    )
    fig.tight_layout()
    output_path = RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_combined.png"
    fig.savefig(output_path, bbox_inches="tight")
    plt.close(fig)
    return output_path
|
||||
|
||||
|
||||
def _save_channel_figure(spec: SampleSpec, sigma8: FloatArray, sigma15: FloatArray) -> Path:
    """Save a 4-row grid of per-channel (limb/joint) maps for both sigmas.

    Rows: sigma-8 limb, sigma-8 joint, sigma-1.5 limb, sigma-1.5 joint;
    columns are four frames evenly sampled from the sigma-8 sequence.
    Returns the path of the saved PNG.
    """
    frames = _frame_indices(sigma8.shape[0], count=4)
    fig, axes = plt.subplots(4, len(frames), figsize=(2.6 * len(frames), 8.4), dpi=200)
    row_titles = (
        "Sigma 8 Limb",
        "Sigma 8 Joint",
        "Sigma 1.5 Limb",
        "Sigma 1.5 Joint",
    )
    # Channel 0 is the limb map, channel 1 the joint map.
    images = (sigma8[:, 0], sigma8[:, 1], sigma15[:, 0], sigma15[:, 1])

    for row_idx, (title, row_images) in enumerate(zip(row_titles, images, strict=True)):
        # Per-row color ceiling so a faint channel is not washed out by a
        # stronger one; fall back to 1.0 for an all-zero row.
        vmax = float(np.max(row_images)) if np.max(row_images) > 0 else 1.0
        for col_idx, frame_idx in enumerate(frames):
            ax = axes[row_idx, col_idx]
            ax.imshow(row_images[frame_idx], cmap="magma", vmin=0.0, vmax=vmax)
            ax.set_xticks([])
            ax.set_yticks([])
            if row_idx == 0:
                ax.set_title(f"Frame {frame_idx}", fontsize=9)
            if col_idx == 0:
                ax.set_ylabel(title, fontsize=10)

    fig.suptitle(
        f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: per-channel sigma comparison",
        fontsize=12,
    )
    fig.tight_layout()
    output_path = RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_channels.png"
    fig.savefig(output_path, bbox_inches="tight")
    plt.close(fig)
    return output_path
|
||||
|
||||
|
||||
def _save_mixed_channel_figure(
    spec: SampleSpec,
    sigma8: FloatArray,
    sigma15: FloatArray,
    mixed: FloatArray,
) -> Path:
    """Save a 6-row grid of per-channel maps for both sigmas plus the mixed set.

    Rows: sigma-8 limb/joint, sigma-1.5 limb/joint, mixed limb-1.5/joint-8;
    columns are four frames evenly sampled from the sigma-8 sequence.
    Returns the path of the saved PNG.
    """
    frames = _frame_indices(sigma8.shape[0], count=4)
    fig, axes = plt.subplots(6, len(frames), figsize=(2.6 * len(frames), 12.0), dpi=200)
    row_titles = (
        "Sigma 8 Limb",
        "Sigma 8 Joint",
        "Sigma 1.5 Limb",
        "Sigma 1.5 Joint",
        "Mixed Limb 1.5",
        "Mixed Joint 8",
    )
    # Channel 0 is the limb map, channel 1 the joint map, for every variant.
    images = (sigma8[:, 0], sigma8[:, 1], sigma15[:, 0], sigma15[:, 1], mixed[:, 0], mixed[:, 1])

    for row_idx, (title, row_images) in enumerate(zip(row_titles, images, strict=True)):
        # Per-row color ceiling so a faint channel is not washed out by a
        # stronger one; fall back to 1.0 for an all-zero row.
        vmax = float(np.max(row_images)) if np.max(row_images) > 0 else 1.0
        for col_idx, frame_idx in enumerate(frames):
            ax = axes[row_idx, col_idx]
            ax.imshow(row_images[frame_idx], cmap="magma", vmin=0.0, vmax=vmax)
            ax.set_xticks([])
            ax.set_yticks([])
            if row_idx == 0:
                ax.set_title(f"Frame {frame_idx}", fontsize=9)
            if col_idx == 0:
                ax.set_ylabel(title, fontsize=10)

    fig.suptitle(
        f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: sigma and mixed per-channel comparison",
        fontsize=12,
    )
    fig.tight_layout()
    output_path = RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_channels_mixed.png"
    fig.savefig(output_path, bbox_inches="tight")
    plt.close(fig)
    return output_path
|
||||
|
||||
|
||||
def main() -> None:
    """Render all sigma-comparison figures for the hard-coded sample and print their paths."""
    RESEARCH_ROOT.mkdir(parents=True, exist_ok=True)
    sample = SampleSpec(sequence_id="00000", label="positive", view="000_180")

    # Load every representation once up front.
    sigma8 = _load_sigma8(sample)
    sigma15 = _load_sigma15(sample)
    mixed = _load_mixed(sample)
    silhouette = _load_silhouette(sample)

    saved_paths = (
        _save_combined_figure(sample, silhouette, sigma8, sigma15, mixed),
        _save_channel_figure(sample, sigma8, sigma15),
        _save_mixed_channel_figure(sample, sigma8, sigma15, mixed),
    )
    for saved in saved_paths:
        print(saved)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user