Add scoliosis diagnostics and experiment logging

This commit is contained in:
2026-03-10 00:32:39 +08:00
parent 5cf628669e
commit 24381551f4
4 changed files with 1154 additions and 2 deletions
+4 -2
View File
@@ -26,8 +26,9 @@ Use it for:
| 2026-03-08 | `ScoNet_skeleton_118_sigma15_bs12x8` | ScoNet-MT-ske control | `Scoliosis1K-drf-pkl-118-sigma15` | Lowered skeleton-map sigma from `8.0` to `1.5` to tighten the pose rasterization | complete | `46.33 Acc / 68.09 Prec / 51.92 Rec / 44.69 F1` |
| 2026-03-09 | `ScoNet_skeleton_118_sigma15_joint8_sharedalign_2gpu_bs12x8` | ScoNet-MT-ske control | `Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign` | Fixed limb/joint channel misalignment, used mixed sigma `limb=1.5 / joint=8.0`, kept SGD | complete | `50.47 Acc / 69.31 Prec / 54.58 Rec / 48.63 F1` |
| 2026-03-09 | `ScoNet_skeleton_118_sigma15_joint8_limb4_adamw_2gpu_bs12x8` | ScoNet-MT-ske control | `Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign-limb4` | Rebalanced channel intensity with `limb_gain=4.0`; switched optimizer from `SGD` to `AdamW` | complete | `48.60 Acc / 65.97 Prec / 53.19 Rec / 46.41 F1` |
| 2026-03-09 | `ScoNet_skeleton_118_sigma15_joint8_sharedalign_nocut_adamw_1gpu_bs8x8` | ScoNet-MT-ske control | `Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign` | Switched runtime transform from `BaseSilCuttingTransform` to `BaseSilTransform` (`no-cut`), kept `AdamW`, reduced `8x8` due to 5070 Ti OOM at `12x8` | training | no eval yet |
| 2026-03-09 | `ScoNet_skeleton_118_sigma15_joint8_sharedalign_nocut_adamw_proxy_1gpu` | ScoNet-MT-ske proxy | `Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign` | Fast proxy route: `no-cut`, `AdamW`, `8x8`, `total_iter=2000`, `eval_iter=500`, `test_seq_subset_size=128` | training | no eval yet |
| 2026-03-09 | `ScoNet_skeleton_118_sigma15_joint8_sharedalign_nocut_adamw_1gpu_bs8x8` | ScoNet-MT-ske control | `Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign` | Switched runtime transform from `BaseSilCuttingTransform` to `BaseSilTransform` (`no-cut`), kept `AdamW`, reduced `8x8` due to 5070 Ti OOM at `12x8` | interrupted | superseded by proxy route before eval |
| 2026-03-09 | `ScoNet_skeleton_118_sigma15_joint8_sharedalign_nocut_adamw_proxy_1gpu` | ScoNet-MT-ske proxy | `Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign` | Fast proxy route: `no-cut`, `AdamW`, `8x8`, `total_iter=2000`, `eval_iter=500`, `test_seq_subset_size=128` | interrupted | superseded by geometry-fixed proxy before completion |
| 2026-03-10 | `ScoNet_skeleton_118_sigma15_joint8_geomfix_proxy_1gpu` | ScoNet-MT-ske proxy | `Scoliosis1K-drf-pkl-118-sigma15-joint8-geomfix` | Geometry ablation: aspect-ratio-preserving crop+pad instead of square-warp resize; `AdamW`, `no-cut`, `8x8`, `total_iter=2000`, `eval_iter=500`, fixed test subset seed `118` | complete | proxy subset unstable: `500 24.22/8.07/33.33/13.00`, `1000 60.16/68.05/58.13/55.25`, `1500 26.56/58.33/35.64/17.68`, `2000 27.34/63.96/37.02/20.14` (Acc/Prec/Rec/F1) |
## Current best skeleton baseline
@@ -40,3 +41,4 @@ Current best `ScoNet-MT-ske`-style result:
- `ckpt/ScoNet-20000-better.pt` is intentionally not listed here because it is a silhouette checkpoint, not a skeleton-map run.
- `DRF` runs are included because they are part of the same reproduction/debugging loop, but this log should stay focused on train/eval changes, not broader code refactors.
- The long `ScoNet_skeleton_118_sigma15_joint8_sharedalign_nocut_adamw_1gpu_bs8x8` run was intentionally interrupted and superseded by the shorter proxy run once fast-iteration support was added.
- The geometry-fixed proxy run fit the train split quickly but did not produce a stable proxy validation curve, so it should not be promoted to a full 20k run.
@@ -0,0 +1,778 @@
from __future__ import annotations
import pickle
import sys
import types
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal
import click
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float
from numpy.typing import NDArray
REPO_ROOT = Path(__file__).resolve().parents[1]
OPENGAIT_ROOT = REPO_ROOT / "opengait"
if str(OPENGAIT_ROOT) not in sys.path:
sys.path.insert(0, str(OPENGAIT_ROOT))
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from opengait.data.transform import get_transform
from opengait.modeling import models
from opengait.modeling.base_model import BaseModel
from opengait.modeling.models.drf import DRF
from opengait.utils.common import config_loader
FloatArray = NDArray[np.float32]
LABEL_NAMES = ("negative", "neutral", "positive")
@dataclass(frozen=True)
class ModelSpec:
    """Static description of one trained model to visualize/diagnose."""

    key: str  # short identifier; "skeleton" in the key switches loading to heatmap inputs
    title: str  # human-readable name shown in plot titles
    config_path: Path  # OpenGait YAML config used to rebuild the network
    checkpoint_path: Path  # .pt checkpoint holding the trained weights
    data_root: Path  # root directory of the preprocessed dataset for this model
    is_drf: bool  # True when the model is DRF and expects an extra PAV input
@dataclass(frozen=True)
class VisualizationResult:
    """One row of the 'final' view: input, response heatmap, and overlay."""

    title: str  # model display name
    predicted_label: str  # one of LABEL_NAMES
    class_scores: tuple[float, float, float]  # per-class scores (logits averaged over parts)
    input_image: FloatArray  # normalized [H, W] display image
    response_map: FloatArray  # normalized [H, W] response heatmap
    overlay_image: FloatArray  # heatmap blended over the input image
@dataclass(frozen=True)
class StageVisualizationResult:
    """One row of the 'stages' view: input plus per-backbone-stage overlays."""

    title: str  # model display name
    predicted_label: str  # one of LABEL_NAMES
    class_scores: tuple[float, float, float]  # per-class scores
    input_image: FloatArray  # normalized [H, W] display image
    stage_overlays: dict[str, FloatArray]  # stage name ("conv1", "layer1", ...) -> overlay image
@dataclass(frozen=True)
class ForwardDiagnostics:
    """Intermediate tensors captured from a single forward pass (see forward_diagnostics)."""

    feature_map: Float[torch.Tensor, "batch channels height width"]  # temporally pooled backbone output
    pooled: Float[torch.Tensor, "batch channels parts"]  # after horizontal part pooling (HPP)
    embeddings: Float[torch.Tensor, "batch channels parts"]  # after FCs (and PGA attention when present)
    logits: Float[torch.Tensor, "batch classes parts"]  # per-part classifier logits from BNNecks
    class_scores: Float[torch.Tensor, "batch classes"]  # logits averaged over parts
    pga_spatial: Float[torch.Tensor, "batch parts"] | None  # DRF-only PGA spatial attention, else None
@dataclass(frozen=True)
class ModelDiagnosticResult:
    """One row of the 'diagnostics' view: occlusion sensitivity and part logits."""

    title: str  # model display name
    predicted_label: str  # one of LABEL_NAMES
    class_scores: tuple[float, float, float]  # per-class scores
    input_image: FloatArray  # normalized [H, W] display image
    occlusion_map: FloatArray  # normalized occlusion-sensitivity map
    overlay_image: FloatArray  # occlusion map blended over the input image
    part_scores: FloatArray  # predicted-class logit per HPP part
    pga_spatial: FloatArray | None  # DRF-only PGA spatial attention per part, else None
def apply_transforms(
    sequence: np.ndarray,
    transform_cfg: list[dict[str, Any]] | dict[str, Any] | None,
) -> np.ndarray:
    """Apply the configured transform (or transform pipeline) to *sequence*.

    get_transform may return either a single callable or a list of callables;
    a list is applied in order. The result is always coerced to an ndarray.
    """
    transform = get_transform(transform_cfg)
    if not isinstance(transform, list):
        return np.asarray(transform(sequence))
    current = sequence
    for step in transform:
        current = step(current)
    return np.asarray(current)
def build_minimal_model(cfgs: dict[str, Any]) -> nn.Module:
    """Instantiate only the network part of an OpenGait model.

    Bypasses the model's full __init__ (which wires data loaders, loggers,
    distributed state, ...) by allocating the instance raw with __new__ and
    calling build_network directly.
    """
    model_cls = getattr(models, str(cfgs["model_cfg"]["model"]))
    instance = model_cls.__new__(model_cls)
    nn.Module.__init__(instance)
    # build_network expects get_backbone to be available on the instance.
    instance.get_backbone = types.MethodType(BaseModel.get_backbone, instance)
    instance.build_network(cfgs["model_cfg"])
    return instance
def load_checkpoint_model(spec: ModelSpec) -> tuple[nn.Module, dict[str, Any]]:
    """Build the network described by *spec* and load its checkpoint (eval mode).

    Returns the model and the parsed config dict.
    """
    cfgs = config_loader(str(spec.config_path))
    model = build_minimal_model(cfgs)
    # NOTE: weights_only=False unpickles arbitrary objects — only load trusted,
    # locally produced checkpoints.
    checkpoint = torch.load(
        spec.checkpoint_path,
        map_location="cpu",
        weights_only=False,
    )
    # strict=False: the minimal model may omit training-only submodules
    # present in the checkpoint.
    model.load_state_dict(checkpoint["model"], strict=False)
    model.eval()
    return model, cfgs
def load_pickle(path: Path) -> np.ndarray:
    """Read a pickle file and return its payload coerced to a numpy array."""
    raw = path.read_bytes()
    return np.asarray(pickle.loads(raw))
def load_sample_inputs(
    spec: ModelSpec,
    sequence_id: str,
    label: str,
    view: str,
    frame_count: int,
) -> tuple[torch.Tensor, torch.Tensor | None, FloatArray]:
    """Load and preprocess one sample appropriate for *spec*'s input modality.

    Returns (input tensor, optional PAV tensor — DRF only, display image).
    """
    if spec.is_drf:
        # DRF consumes skeleton heatmaps plus a part-appearance vector (PAV).
        heatmap_path = spec.data_root / sequence_id / label / view / "0_heatmap.pkl"
        pav_path = spec.data_root / sequence_id / label / view / "1_pav.pkl"
        heatmaps = load_pickle(heatmap_path).astype(np.float32)
        pav = load_pickle(pav_path).astype(np.float32)
        return prepare_heatmap_inputs(heatmaps, pav, spec.config_path, frame_count)
    if "skeleton" in spec.key:
        # Skeleton-map ScoNet variants: heatmaps only, no PAV.
        heatmap_path = spec.data_root / sequence_id / label / view / "0_heatmap.pkl"
        heatmaps = load_pickle(heatmap_path).astype(np.float32)
        return prepare_heatmap_inputs(heatmaps, None, spec.config_path, frame_count)
    # Default: silhouette pickle named after the view.
    silhouette_path = spec.data_root / sequence_id / label / view / f"{view}.pkl"
    silhouettes = load_pickle(silhouette_path).astype(np.float32)
    return prepare_silhouette_inputs(silhouettes, spec.config_path, frame_count)
def prepare_silhouette_inputs(
    silhouettes: FloatArray,
    config_path: Path,
    frame_count: int,
) -> tuple[torch.Tensor, None, FloatArray]:
    """Transform + subsample silhouettes into a [1, 1, T, H, W] tensor.

    The display image is the normalized temporal max-projection of the
    sampled frames. The middle None keeps the return shape consistent with
    prepare_heatmap_inputs (no PAV for silhouettes).
    """
    cfgs = config_loader(str(config_path))
    transformed = apply_transforms(silhouettes, cfgs["evaluator_cfg"]["transform"]).astype(np.float32)
    # Uniformly subsample frame_count frames across the sequence.
    indices = np.linspace(0, transformed.shape[0] - 1, frame_count, dtype=int)
    sampled = transformed[indices]
    input_tensor = torch.from_numpy(sampled).unsqueeze(0).unsqueeze(0).float()
    projection = normalize_image(sampled.max(axis=0))
    return input_tensor, None, projection
def prepare_heatmap_inputs(
    heatmaps: FloatArray,
    pav: FloatArray | None,
    config_path: Path,
    frame_count: int,
) -> tuple[torch.Tensor, torch.Tensor | None, FloatArray]:
    """Transform + subsample skeleton heatmaps into model-ready tensors.

    Returns (input tensor [1, C, T, H, W], optional PAV tensor [1, D], and a
    normalized 2-D max-projection image for display).
    """
    cfgs = config_loader(str(config_path))
    transform_cfg = cfgs["evaluator_cfg"]["transform"]
    # Multi-stream configs list one transform per input; the heatmap
    # transform comes first.
    heatmap_transform = transform_cfg[0] if isinstance(transform_cfg, list) else transform_cfg
    transformed = apply_transforms(heatmaps, heatmap_transform).astype(np.float32)
    # Uniformly subsample frame_count frames across the sequence.
    indices = np.linspace(0, transformed.shape[0] - 1, frame_count, dtype=int)
    sampled = transformed[indices]
    # [T, C, H, W] -> [1, C, T, H, W]
    input_tensor = torch.from_numpy(sampled).permute(1, 0, 2, 3).unsqueeze(0).float()
    # Max-project over time once (was computed twice), then merge the two
    # heatmap channels (limb = ch 0, joint = ch 1) for display.
    frame_max = sampled.max(axis=0)
    projection = normalize_image(np.maximum(frame_max[0], frame_max[1]))
    pav_tensor = None
    if pav is not None:
        # Average the per-frame PAV into a single vector, keeping a batch dim.
        pav_tensor = torch.from_numpy(pav.mean(axis=0, keepdims=True)).float()
    return input_tensor, pav_tensor, projection
AttentionMethod = Literal["sum", "sum_p", "max_p"]
ResponseMode = Literal["activation", "cam"]
def compute_classification_scores(
    model: nn.Module,
    sequence: torch.Tensor,
    pav: torch.Tensor | None,
) -> tuple[Float[torch.Tensor, "channels height width"], tuple[float, float, float], str]:
    """Run the model once and return (feature map, class scores, label).

    The first element is the detached, temporally pooled backbone feature map
    for the single batch item — a torch.Tensor, not a numpy array (the
    previous FloatArray annotation was wrong; callers pass it straight to
    activation_attention_map, which expects a tensor).
    """
    diagnostics = forward_diagnostics(model, sequence, pav)
    class_scores = diagnostics.class_scores[0]
    pred_index = int(class_scores.argmax().detach().cpu())
    score_tuple = tuple(float(value.detach().cpu()) for value in class_scores)
    return diagnostics.feature_map.detach()[0], score_tuple, LABEL_NAMES[pred_index]
def forward_diagnostics(
    model: nn.Module,
    sequence: torch.Tensor,
    pav: torch.Tensor | None,
) -> ForwardDiagnostics:
    """Replay the model head step-by-step and capture every intermediate.

    Path: Backbone -> temporal pooling (TP) -> horizontal part pooling (HPP)
    -> part-wise FCs -> optional PGA attention (DRF with a PAV) -> BNNecks.
    Class scores are the per-part logits averaged over parts.
    """
    outputs = model.Backbone(sequence)
    feature_map = model.TP(outputs, None, options={"dim": 2})[0]
    pooled = model.HPP(feature_map)
    embeddings = model.FCs(pooled)
    pga_spatial: Float[torch.Tensor, "batch parts"] | None = None
    if pav is not None and isinstance(model, DRF):
        # DRF modulates the embeddings with PAV-derived channel and spatial
        # attention; keep the spatial weights for later diagnostics.
        pav_flat = pav.flatten(1)
        channel_att = model.PGA.channel_att(pav_flat).unsqueeze(-1)
        spatial_att = model.PGA.spatial_att(pav_flat.unsqueeze(-1)).transpose(1, 2)
        pga_spatial = spatial_att.squeeze(1)
        embeddings = embeddings * channel_att * spatial_att
    _bn_embeddings, logits = model.BNNecks(embeddings)
    class_scores = logits.mean(-1)
    return ForwardDiagnostics(
        feature_map=feature_map,
        pooled=pooled,
        embeddings=embeddings,
        logits=logits,
        class_scores=class_scores,
        pga_spatial=pga_spatial,
    )
def compute_response_map(
    model: nn.Module,
    sequence: torch.Tensor,
    pav: torch.Tensor | None,
    response_mode: ResponseMode,
    method: AttentionMethod,
    power: float,
) -> tuple[FloatArray, tuple[float, float, float], str]:
    """Compute a normalized [H, W] response map at the input resolution.

    "cam" uses the adapted class-activation map for the predicted class;
    "activation" collapses the final feature map's channels with *method*.
    Also returns the per-class scores and predicted label.
    """
    activation, score_tuple, predicted_label = compute_classification_scores(model, sequence, pav)
    if response_mode == "cam":
        response_map = adapted_cam_map(
            model=model,
            sequence=sequence,
            pav=pav,
            predicted_label=predicted_label,
        )
    else:
        response_map = activation_attention_map(activation, method=method, power=power)
    # Upsample to the input spatial size so it can be overlaid on the image.
    response_map = response_map.unsqueeze(0).unsqueeze(0)
    response_map = F.interpolate(
        response_map,
        size=sequence.shape[-2:],
        mode="bilinear",
        align_corners=False,
    )[0, 0]
    response_np = response_map.detach().cpu().numpy().astype(np.float32)
    response_np = normalize_image(response_np)
    return response_np, score_tuple, predicted_label
def activation_attention_map(
    activation: Float[torch.Tensor, "channels height width"],
    method: AttentionMethod,
    power: float,
) -> Float[torch.Tensor, "height width"]:
    """Collapse a [C, H, W] activation into a [H, W] attention map.

    "sum" sums absolute activations over channels; "sum_p"/"max_p" first
    raise absolute activations to *power*, then sum or take the channel max.

    Raises:
        ValueError: for an unknown *method*.
    """
    magnitude = activation.abs()
    if method == "sum":
        return magnitude.sum(dim=0)
    if method in ("sum_p", "max_p"):
        raised = magnitude.pow(power)
        return raised.sum(dim=0) if method == "sum_p" else raised.max(dim=0).values
    raise ValueError(f"Unsupported attention method: {method}")
def adapted_cam_map(
    model: nn.Module,
    sequence: torch.Tensor,
    pav: torch.Tensor | None,
    predicted_label: str,
) -> Float[torch.Tensor, "height width"]:
    """Class-activation map adapted to the HPP/FCs/BNNecks head.

    Folds the eval-mode BatchNorm scaling, the (normalized) classifier
    weights for the predicted class, and the optional PGA attention back into
    per-channel weights over the backbone feature map, then accumulates each
    horizontal part's strip of rows separately. Returns a ReLU'd [h, w] map.
    """
    outputs = model.Backbone(sequence)
    feature_map = model.TP(outputs, None, options={"dim": 2})[0]  # [1, c_in, h, w]
    pooled = model.HPP(feature_map)  # [1, c_in, p]
    pre_attention = model.FCs(pooled)  # [1, c_out, p]
    attention = torch.ones_like(pre_attention)
    if pav is not None and isinstance(model, DRF):
        # Same PGA attention as forward_diagnostics, kept for re-weighting.
        pav_flat = pav.flatten(1)
        channel_att = model.PGA.channel_att(pav_flat).unsqueeze(-1)
        spatial_att = model.PGA.spatial_att(pav_flat.unsqueeze(-1)).transpose(1, 2)
        attention = channel_att * spatial_att
    embeddings = pre_attention * attention
    bn_scales = bn_feature_scales(model.BNNecks, embeddings.shape[1], embeddings.shape[2], embeddings.device)
    # NOTE(review): normalizing the classifier weights and dividing by the
    # embedding norm below looks like it mirrors a cosine-style classifier in
    # BNNecks — confirm against the BNNecks implementation.
    classifier_weights = F.normalize(model.BNNecks.fc_bin.detach(), dim=1)
    class_index = LABEL_NAMES.index(predicted_label)
    class_weights = classifier_weights[:, :, class_index]  # [p, c_out]
    normalized_embeddings = batchnorm_forward_eval(model.BNNecks, embeddings)[0].transpose(0, 1)  # [p, c_out]
    # Clamp to avoid division by ~0 for degenerate parts.
    part_norms = normalized_embeddings.norm(dim=-1, keepdim=True).clamp_min(1e-6)
    effective_post_bn = class_weights / part_norms
    # Push the weights back through the BatchNorm scale so they apply to the
    # pre-BN embeddings.
    effective_pre_bn = effective_post_bn * bn_scales
    fc_weights = model.FCs.fc_bin.detach()  # [p, c_in, c_out]
    attention_weights = attention.detach()[0].transpose(0, 1)  # [p, c_out]
    # Collapse FCs + attention + classifier into one weight per input channel.
    effective_channel_weights = torch.einsum(
        "pkj,pj->pk",
        fc_weights,
        attention_weights * effective_pre_bn,
    )  # [p, c_in]
    feature = feature_map.detach()[0]  # [c_in, h, w]
    parts_num = effective_channel_weights.shape[0]
    height = feature.shape[1]
    split_size = max(height // parts_num, 1)
    cam = torch.zeros((height, feature.shape[2]), device=feature.device, dtype=feature.dtype)
    # Mirror HPP's horizontal strip pooling: each part weights only its own
    # rows; the last part absorbs any remainder rows.
    for part_index in range(parts_num):
        row_start = part_index * split_size
        row_end = height if part_index == parts_num - 1 else min((part_index + 1) * split_size, height)
        if row_start >= height:
            break
        contribution = (feature[:, row_start:row_end, :] * effective_channel_weights[part_index][:, None, None]).sum(dim=0)
        cam[row_start:row_end, :] = contribution
    return F.relu(cam)
def batchnorm_forward_eval(
    bn_necks: nn.Module,
    embeddings: Float[torch.Tensor, "batch channels parts"],
) -> Float[torch.Tensor, "batch channels parts"]:
    """Apply the BNNeck's shared BatchNorm1d in eval mode (running stats only).

    Raises:
        ValueError: when the BNNeck is not using a single parallel BatchNorm1d.
    """
    batch, channels, parts = embeddings.shape
    if not getattr(bn_necks, "parallel_BN1d", True):
        raise ValueError("Adapted CAM currently expects parallel_BN1d=True.")
    bn = bn_necks.bn1d
    # Flatten (channels, parts) into the feature dimension, normalize with the
    # stored running statistics, then restore the original shape.
    result = F.batch_norm(
        embeddings.reshape(batch, -1),
        bn.running_mean,
        bn.running_var,
        bn.weight,
        bn.bias,
        training=False,
        momentum=0.0,
        eps=bn.eps,
    )
    return result.view(batch, channels, parts)
def bn_feature_scales(
    bn_necks: nn.Module,
    channels: int,
    parts: int,
    device: torch.device,
) -> Float[torch.Tensor, "parts channels"]:
    """Per-(part, channel) multiplicative factor of the BNNeck's eval BatchNorm.

    In eval mode BatchNorm multiplies each feature by gamma / sqrt(var + eps);
    this returns that factor reshaped to [parts, channels].

    Raises:
        ValueError: when the BNNeck is not using a single parallel BatchNorm1d.
    """
    if not getattr(bn_necks, "parallel_BN1d", True):
        raise ValueError("Adapted CAM currently expects parallel_BN1d=True.")
    bn = bn_necks.bn1d
    variance = bn.running_var.detach().to(device).view(channels, parts).transpose(0, 1)
    scale = torch.rsqrt(variance + bn.eps)
    if bn.weight is None:
        # Affine disabled: the running-variance scale is the whole factor.
        return scale
    gamma = bn.weight.detach().to(device).view(channels, parts).transpose(0, 1)
    return scale * gamma
def normalize_image(image: np.ndarray) -> FloatArray:
    """Min-max scale *image* into [0, 1]; constant images become all zeros."""
    image = image.astype(np.float32)
    lo = float(image.min())
    hi = float(image.max())
    if hi <= lo:
        return np.zeros_like(image, dtype=np.float32)
    return (image - lo) / (hi - lo)
def overlay_response(base_image: FloatArray, response_map: FloatArray) -> FloatArray:
    """Blend a jet-colored response map over a grayscale base image (45/55 mix)."""
    heat_rgb = plt.get_cmap("jet")(response_map)[..., :3].astype(np.float32)
    base_rgb = np.repeat(base_image[..., None], 3, axis=-1)
    blended = base_rgb * 0.45 + heat_rgb * 0.55
    return np.clip(blended, 0.0, 1.0)
def compute_occlusion_sensitivity(
    model: nn.Module,
    sequence: torch.Tensor,
    pav: torch.Tensor | None,
    patch_size: int,
    stride: int,
) -> tuple[FloatArray, tuple[float, float, float], str]:
    """Slide a zero patch over the input and record the predicted-score drop.

    Returns (normalized [H, W] sensitivity map, baseline class scores,
    baseline predicted label). Higher values mean occluding that region hurts
    the baseline prediction more.
    """
    with torch.no_grad():
        baseline = forward_diagnostics(model, sequence, pav)
        class_scores = baseline.class_scores[0]
        pred_index = int(class_scores.argmax().detach().cpu())
        baseline_score = float(class_scores[pred_index].detach().cpu())
        score_tuple = tuple(float(value.detach().cpu()) for value in class_scores)
        _, _, _, height, width = sequence.shape
        sensitivity = np.zeros((height, width), dtype=np.float32)
        counts = np.zeros((height, width), dtype=np.float32)
        # Stride over the image, then force-append the last valid window so
        # the right/bottom borders are always covered.
        row_starts = list(range(0, max(height - patch_size + 1, 1), stride))
        col_starts = list(range(0, max(width - patch_size + 1, 1), stride))
        last_row = max(height - patch_size, 0)
        last_col = max(width - patch_size, 0)
        if row_starts[-1] != last_row:
            row_starts.append(last_row)
        if col_starts[-1] != last_col:
            col_starts.append(last_col)
        for top in row_starts:
            for left in col_starts:
                bottom = min(top + patch_size, height)
                right = min(left + patch_size, width)
                # Zero the patch across all frames/channels (broadcast on
                # the leading dims) and re-run the forward pass.
                masked = sequence.clone()
                masked[..., top:bottom, left:right] = 0.0
                masked_scores = forward_diagnostics(model, masked, pav).class_scores[0]
                # Only count score decreases; improvements clamp to zero.
                score_drop = max(0.0, baseline_score - float(masked_scores[pred_index].detach().cpu()))
                sensitivity[top:bottom, left:right] += score_drop
                counts[top:bottom, left:right] += 1.0
    # Average overlapping windows; counts is >= 1 everywhere that was visited.
    counts = np.maximum(counts, 1.0)
    sensitivity /= counts
    return normalize_image(sensitivity), score_tuple, LABEL_NAMES[pred_index]
def compute_part_diagnostics(
    model: nn.Module,
    sequence: torch.Tensor,
    pav: torch.Tensor | None,
) -> tuple[FloatArray, FloatArray | None, tuple[float, float, float], str]:
    """Return per-part logits for the predicted class plus PGA spatial weights.

    Returns (part logits [P], PGA spatial weights [P] or None, class scores,
    predicted label).
    """
    with torch.no_grad():
        diagnostics = forward_diagnostics(model, sequence, pav)
        scores = diagnostics.class_scores[0]
        winner = int(scores.argmax().detach().cpu())
        part_scores = diagnostics.logits[0, winner].detach().cpu().numpy().astype(np.float32)
        if diagnostics.pga_spatial is None:
            pga_spatial = None
        else:
            pga_spatial = diagnostics.pga_spatial[0].detach().cpu().numpy().astype(np.float32)
        score_tuple = tuple(float(value.detach().cpu()) for value in scores)
    return part_scores, pga_spatial, score_tuple, LABEL_NAMES[winner]
def extract_temporally_pooled_stages(
    model: nn.Module,
    sequence: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Run the backbone stage-by-stage and temporally pool each stage output.

    Returns a dict mapping stage name ("conv1", "layer1"..."layer4") to the
    [B, C, H, W] temporally pooled activation.

    Raises:
        ValueError: when the backbone does not expose ``forward_block``
            (i.e. is not a SetBlockWrapper-style module).
    """
    if not hasattr(model.Backbone, "forward_block"):
        raise ValueError("Stage visualization requires SetBlockWrapper-style backbones.")
    backbone = model.Backbone.forward_block
    batch, channels, frames, height, width = sequence.shape
    # Fold frames into the batch dim: [b, c, t, h, w] -> [b*t, c, h, w].
    x = sequence.transpose(1, 2).reshape(-1, channels, height, width)
    stage_outputs: dict[str, torch.Tensor] = {}
    # Manually replay the stem + residual stages so each one can be captured.
    x = backbone.conv1(x)
    x = backbone.bn1(x)
    x = backbone.relu(x)
    if getattr(backbone, "maxpool_flag", False):
        x = backbone.maxpool(x)
    stage_outputs["conv1"] = x
    x = backbone.layer1(x)
    stage_outputs["layer1"] = x
    x = backbone.layer2(x)
    stage_outputs["layer2"] = x
    x = backbone.layer3(x)
    stage_outputs["layer3"] = x
    x = backbone.layer4(x)
    stage_outputs["layer4"] = x
    pooled_outputs: dict[str, torch.Tensor] = {}
    for stage_name, stage_output in stage_outputs.items():
        # Restore the time dim ([b, t, c', h', w'] -> [b, c', t, h', w'])
        # and temporally pool to [b, c', h', w'].
        reshaped = stage_output.reshape(batch, frames, *stage_output.shape[1:]).transpose(1, 2).contiguous()
        pooled_outputs[stage_name] = model.TP(reshaped, None, options={"dim": 2})[0]
    return pooled_outputs
def render_visualization(results: list[VisualizationResult], output_path: Path) -> None:
    """Save a figure with one row per model: input, response heatmap, overlay."""
    fig, axes = plt.subplots(len(results), 3, figsize=(10.5, 3.2 * len(results)), constrained_layout=True)
    if len(results) == 1:
        # plt.subplots squeezes a single row to 1-D; keep indexing uniform.
        axes = np.expand_dims(axes, axis=0)
    for row, result in enumerate(results):
        input_axis, heat_axis, overlay_axis = axes[row]
        input_axis.imshow(result.input_image, cmap="gray", vmin=0.0, vmax=1.0)
        input_axis.set_title(f"{result.title} input")
        heat_axis.imshow(result.response_map, cmap="jet", vmin=0.0, vmax=1.0)
        heat_axis.set_title("response")
        overlay_axis.imshow(result.overlay_image)
        overlay_axis.set_title(
            f"overlay | pred={result.predicted_label}\n"
            f"scores={tuple(round(score, 3) for score in result.class_scores)}"
        )
        for axis in (input_axis, heat_axis, overlay_axis):
            axis.set_xticks([])
            axis.set_yticks([])
    fig.savefig(output_path, dpi=180)
    plt.close(fig)
def render_stage_visualization(results: list[StageVisualizationResult], output_path: Path) -> None:
    """Save a figure with one row per model: input plus per-stage overlays."""
    stage_names = ["input", "conv1", "layer1", "layer2", "layer3", "layer4"]
    fig, axes = plt.subplots(
        len(results),
        len(stage_names),
        figsize=(2.2 * len(stage_names), 3.0 * len(results)),
        constrained_layout=True,
    )
    if len(results) == 1:
        # Keep 2-D indexing when subplots squeezes a single row.
        axes = np.expand_dims(axes, axis=0)
    for row, result in enumerate(results):
        for col, stage_name in enumerate(stage_names):
            axis = axes[row, col]
            if stage_name == "input":
                axis.imshow(result.input_image, cmap="gray", vmin=0.0, vmax=1.0)
                axis.set_title(
                    f"{result.title}\npred={result.predicted_label}",
                    fontsize=10,
                )
            else:
                axis.imshow(result.stage_overlays[stage_name])
                axis.set_title(stage_name, fontsize=10)
            axis.set_xticks([])
            axis.set_yticks([])
        # Class scores go on the row's leftmost axis as a rotated y-label.
        axes[row, 0].set_ylabel(
            f"scores={tuple(round(score, 3) for score in result.class_scores)}",
            fontsize=9,
            rotation=90,
        )
    fig.savefig(output_path, dpi=180)
    plt.close(fig)
def render_diagnostic_visualization(results: list[ModelDiagnosticResult], output_path: Path) -> None:
    """Save a figure with one row per model: input, occlusion overlay, part logits."""
    fig, axes = plt.subplots(len(results), 3, figsize=(11.5, 3.3 * len(results)), constrained_layout=True)
    if len(results) == 1:
        # plt.subplots squeezes a single row to 1-D; keep indexing uniform.
        axes = np.expand_dims(axes, axis=0)
    for row, result in enumerate(results):
        input_axis, occlusion_axis, part_axis = axes[row]
        input_axis.imshow(result.input_image, cmap="gray", vmin=0.0, vmax=1.0)
        input_axis.set_title(f"{result.title} input")
        input_axis.set_xticks([])
        input_axis.set_yticks([])
        occlusion_axis.imshow(result.overlay_image)
        occlusion_axis.set_title(
            f"occlusion | pred={result.predicted_label}\n"
            f"scores={tuple(round(score, 3) for score in result.class_scores)}"
        )
        occlusion_axis.set_xticks([])
        occlusion_axis.set_yticks([])
        # Bar chart of the predicted-class logit per HPP part.
        parts = np.arange(result.part_scores.shape[0], dtype=np.int32)
        part_axis.bar(parts, result.part_scores, color="#4477AA", alpha=0.85, width=0.8)
        part_axis.set_title("pred-class part logits")
        part_axis.set_xlabel("HPP part")
        part_axis.set_ylabel("logit")
        part_axis.set_xticks(parts)
        if result.pga_spatial is not None:
            # DRF only: overlay PGA spatial attention on a secondary y-axis.
            part_axis_2 = part_axis.twinx()
            part_axis_2.plot(parts, result.pga_spatial, color="#CC6677", marker="o", linewidth=1.5)
            part_axis_2.set_ylabel("PGA spatial")
            part_axis_2.set_ylim(0.0, max(1.0, float(result.pga_spatial.max()) * 1.05))
    fig.savefig(output_path, dpi=180)
    plt.close(fig)
def default_model_specs() -> list[ModelSpec]:
    """The three checkpoints this script compares side by side.

    NOTE: checkpoint and dataset paths are machine-specific (absolute /mnt
    data roots plus repo-relative outputs); adjust here to run elsewhere.
    """
    return [
        ModelSpec(
            key="silhouette",
            title="ScoNet-MT silhouette",
            config_path=REPO_ROOT / "configs/sconet/sconet_scoliosis1k_local_eval_2gpu_better_112.yaml",
            checkpoint_path=REPO_ROOT / "ckpt/ScoNet-20000-better.pt",
            data_root=Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl"),
            is_drf=False,
        ),
        ModelSpec(
            key="skeleton",
            title="ScoNet-MT-ske",
            config_path=REPO_ROOT / "configs/sconet/sconet_scoliosis1k_skeleton_118_sigma15_joint8_2gpu_bs12x8.yaml",
            checkpoint_path=REPO_ROOT
            / "output/Scoliosis1K/ScoNet/ScoNet_skeleton_118_sigma15_joint8_sharedalign_2gpu_bs12x8/checkpoints/ScoNet_skeleton_118_sigma15_joint8_sharedalign_2gpu_bs12x8-20000.pt",
            data_root=Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign"),
            is_drf=False,
        ),
        ModelSpec(
            key="drf",
            title="DRF",
            config_path=REPO_ROOT / "configs/drf/drf_scoliosis1k_eval_1gpu.yaml",
            checkpoint_path=REPO_ROOT / "output/Scoliosis1K/DRF/DRF/checkpoints/DRF-20000.pt",
            data_root=Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118"),
            is_drf=True,
        ),
    ]
def build_stage_results(
    spec: ModelSpec,
    response_mode: ResponseMode,
    method: AttentionMethod,
    power: float,
    sequence_id: str,
    label: str,
    view: str,
    frame_count: int,
) -> StageVisualizationResult:
    """Compute per-backbone-stage activation overlays for one model spec.

    Raises:
        ValueError: for response_mode "cam" — only activation maps are
            supported per stage.
    """
    if response_mode != "activation":
        raise ValueError("Stage view currently supports activation maps only.")
    model, _cfgs = load_checkpoint_model(spec)
    sequence, pav, base_image = load_sample_inputs(spec, sequence_id, label, view, frame_count)
    _final_feature, score_tuple, predicted_label = compute_classification_scores(model, sequence, pav)
    pooled_stages = extract_temporally_pooled_stages(model, sequence)
    stage_overlays: dict[str, FloatArray] = {}
    for stage_name, stage_tensor in pooled_stages.items():
        response_map = activation_attention_map(stage_tensor[0], method=method, power=power)
        # Upsample each stage's map to the input resolution before overlaying.
        response_map = F.interpolate(
            response_map.unsqueeze(0).unsqueeze(0),
            size=sequence.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )[0, 0]
        response_np = normalize_image(response_map.detach().cpu().numpy().astype(np.float32))
        stage_overlays[stage_name] = overlay_response(base_image, response_np)
    return StageVisualizationResult(
        title=spec.title,
        predicted_label=predicted_label,
        class_scores=score_tuple,
        input_image=base_image,
        stage_overlays=stage_overlays,
    )
def build_diagnostic_results(
    spec: ModelSpec,
    sequence_id: str,
    label: str,
    view: str,
    frame_count: int,
    patch_size: int,
    stride: int,
) -> ModelDiagnosticResult:
    """Run occlusion sensitivity and part diagnostics for one model spec."""
    model, _cfgs = load_checkpoint_model(spec)
    sequence, pav, base_image = load_sample_inputs(spec, sequence_id, label, view, frame_count)
    occlusion_map, score_tuple, predicted_label = compute_occlusion_sensitivity(
        model=model,
        sequence=sequence,
        pav=pav,
        patch_size=patch_size,
        stride=stride,
    )
    part_scores, pga_spatial, _scores, _predicted = compute_part_diagnostics(model, sequence, pav)
    return ModelDiagnosticResult(
        title=spec.title,
        predicted_label=predicted_label,
        class_scores=score_tuple,
        input_image=base_image,
        occlusion_map=occlusion_map,
        overlay_image=overlay_response(base_image, occlusion_map),
        part_scores=part_scores,
        pga_spatial=pga_spatial,
    )
@click.command()
@click.option("--sequence-id", default="00490", show_default=True)
@click.option("--label", default="positive", show_default=True)
@click.option("--view", default="000_180", show_default=True)
@click.option("--frame-count", default=30, show_default=True, type=int)
@click.option(
    "--response-mode",
    type=click.Choice(["activation", "cam"]),
    default="activation",
    show_default=True,
)
@click.option(
    "--method",
    type=click.Choice(["sum", "sum_p", "max_p"]),
    default="sum",
    show_default=True,
)
@click.option("--power", default=2.0, show_default=True, type=float)
@click.option(
    "--view-mode",
    type=click.Choice(["final", "stages", "diagnostics"]),
    default="final",
    show_default=True,
)
@click.option("--occlusion-patch-size", default=8, show_default=True, type=int)
@click.option("--occlusion-stride", default=8, show_default=True, type=int)
@click.option(
    "--output-path",
    default=str(REPO_ROOT / "research/feature_response_heatmaps/00490_positive_000_180_response_heatmaps.png"),
    show_default=True,
    type=click.Path(path_type=Path),
)
def main(
    sequence_id: str,
    label: str,
    view: str,
    frame_count: int,
    response_mode: str,
    method: str,
    power: float,
    view_mode: str,
    occlusion_patch_size: int,
    occlusion_stride: int,
    output_path: Path,
) -> None:
    """Render the requested visualization for every default model spec.

    --view-mode selects the figure type: "final" (one response heatmap per
    model), "stages" (per-backbone-stage activation overlays), or
    "diagnostics" (occlusion sensitivity + per-part logits). The rendered
    figure path is printed on success.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    if view_mode == "stages":
        results = [
            build_stage_results(
                spec=spec,
                response_mode=response_mode,
                method=method,
                power=power,
                sequence_id=sequence_id,
                label=label,
                view=view,
                frame_count=frame_count,
            )
            for spec in default_model_specs()
        ]
        render_stage_visualization(results, output_path)
    elif view_mode == "diagnostics":
        results = [
            build_diagnostic_results(
                spec=spec,
                sequence_id=sequence_id,
                label=label,
                view=view,
                frame_count=frame_count,
                patch_size=occlusion_patch_size,
                stride=occlusion_stride,
            )
            for spec in default_model_specs()
        ]
        render_diagnostic_visualization(results, output_path)
    else:
        # "final": one response-map row per model.
        results_final: list[VisualizationResult] = []
        for spec in default_model_specs():
            model, _cfgs = load_checkpoint_model(spec)
            sequence, pav, base_image = load_sample_inputs(spec, sequence_id, label, view, frame_count)
            response_map, scores, predicted_label = compute_response_map(
                model,
                sequence,
                pav,
                response_mode=response_mode,
                method=method,
                power=power,
            )
            results_final.append(
                VisualizationResult(
                    title=spec.title,
                    predicted_label=predicted_label,
                    class_scores=scores,
                    input_image=base_image,
                    response_map=response_map,
                    overlay_image=overlay_response(base_image, response_map),
                )
            )
        render_visualization(results_final, output_path)
    click.echo(str(output_path))
if __name__ == "__main__":
main()
+173
View File
@@ -0,0 +1,173 @@
from __future__ import annotations
import pickle
import sys
from pathlib import Path
import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import yaml
matplotlib.use("Agg")
REPO_ROOT = Path(__file__).resolve().parents[1]
if __package__ in {None, ""}:
sys.path.append(str(REPO_ROOT))
from datasets import pretreatment_heatmap as prep
FloatArray = npt.NDArray[np.float32]
SAMPLE_ID = "00000"
LABEL = "positive"
VIEW = "000_180"
RESEARCH_ROOT = REPO_ROOT / "research/sigma_comparison"
POSE_PATH = Path(
f"/mnt/public/data/Scoliosis1K/Scoliosis1K-pose-pkl/{SAMPLE_ID}/{LABEL}/{VIEW}/{VIEW}.pkl"
)
SIL_PATH = Path(
f"/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl/{SAMPLE_ID}/{LABEL}/{VIEW}/{VIEW}.pkl"
)
CFG_PATH = REPO_ROOT / "configs/drf/pretreatment_heatmap_drf_sigma15_joint8.yaml"
LIMB_RGB = np.asarray([0.18, 0.85, 0.96], dtype=np.float32)
JOINT_RGB = np.asarray([1.00, 0.45, 0.12], dtype=np.float32)
SIL_RGB = np.asarray([0.85, 0.97, 0.78], dtype=np.float32)
LIMB_ALPHA = 0.8
JOINT_ALPHA = 0.8
SIL_ALPHA = 0.3
def _normalize(array: npt.NDArray[np.generic]) -> FloatArray:
image = np.asarray(array, dtype=np.float32)
if image.max() > 1.0:
image = image / 255.0
return image
def _load_pickle(path: Path) -> FloatArray:
    """Unpickle *path* and normalize the payload into float32 [0, 1]."""
    with path.open("rb") as handle:
        return _normalize(pickle.load(handle))
def _load_mixed_sharedalign() -> FloatArray:
    """Generate the mixed-sigma shared-align heatmaps for the fixed sample.

    Rebuilds the pretreatment transform from the DRF config and applies it to
    the raw pose pickle, returning normalized float32 heatmaps.
    """
    with CFG_PATH.open("r", encoding="utf-8") as handle:
        cfg = yaml.safe_load(handle)
    # Presumably resolves variable references within the config against
    # itself — see pretreatment_heatmap.replace_variables.
    cfg = prep.replace_variables(cfg, cfg)
    transform = prep.GenerateHeatmapTransform(
        coco18tococo17_args=cfg["coco18tococo17_args"],
        padkeypoints_args=cfg["padkeypoints_args"],
        norm_args=cfg["norm_args"],
        heatmap_generator_args=cfg["heatmap_generator_args"],
        align_args=cfg["align_args"],
        reduction="upstream",
        # Mixed sigmas: separate Gaussian widths for limb and joint channels.
        sigma_limb=float(cfg["sigma_limb"]),
        sigma_joint=float(cfg["sigma_joint"]),
    )
    with POSE_PATH.open("rb") as handle:
        pose = np.asarray(pickle.load(handle), dtype=np.float32)
    return _normalize(transform(pose))
def _overlay_frame(silhouette: FloatArray, limb: FloatArray, joint: FloatArray) -> FloatArray:
    """Compose silhouette, limb, and joint masks into one RGBA frame.

    Each layer contributes alpha-weighted additive color to RGB; the alpha
    channel keeps the per-pixel maximum layer opacity. Layer order matches
    the original: silhouette, then limb, then joint.
    """
    height, width = silhouette.shape[0], silhouette.shape[1]
    canvas = np.zeros((height, width, 4), dtype=np.float32)
    layers = (
        (np.clip(silhouette, 0.0, 1.0), SIL_RGB, SIL_ALPHA),
        (np.clip(limb, 0.0, 1.0), LIMB_RGB, LIMB_ALPHA),
        (np.clip(joint, 0.0, 1.0), JOINT_RGB, JOINT_ALPHA),
    )
    for mask, color, alpha in layers:
        canvas[..., :3] += mask[..., None] * color[None, None, :] * alpha
        canvas[..., 3] = np.maximum(canvas[..., 3], mask * alpha)
    canvas[..., :3] = np.clip(canvas[..., :3], 0.0, 1.0)
    canvas[..., 3] = np.clip(canvas[..., 3], 0.0, 1.0)
    return canvas
def _bbox(mask: FloatArray, threshold: float = 0.05) -> tuple[int, int, int, int] | None:
ys, xs = np.where(mask > threshold)
if ys.size == 0:
return None
return int(ys.min()), int(ys.max()), int(xs.min()), int(xs.max())
def _align_silhouette_to_heatmap(silhouette: FloatArray, heatmap: FloatArray) -> FloatArray:
    """Rescale and recenter the silhouette so its bbox matches the heatmap's.

    The silhouette crop is scaled to the heatmap bbox height (aspect ratio
    preserved) and horizontally centered on the heatmap bbox midpoint.
    Returns the input unchanged when either bbox is empty.
    """
    sil_box = _bbox(silhouette)
    heat_box = _bbox(heatmap)
    if sil_box is None or heat_box is None:
        return silhouette
    sy0, sy1, sx0, sx1 = sil_box
    hy0, hy1, hx0, hx1 = heat_box
    sil_crop = silhouette[sy0 : sy1 + 1, sx0 : sx1 + 1]
    sil_h = max(sy1 - sy0 + 1, 1)
    sil_w = max(sx1 - sx0 + 1, 1)
    target_h = max(hy1 - hy0 + 1, 1)
    # Scale by height only so the silhouette keeps its aspect ratio.
    scale = target_h / sil_h
    target_w = max(int(round(sil_w * scale)), 1)
    resized = cv2.resize(sil_crop, (target_w, target_h), interpolation=cv2.INTER_AREA)
    canvas = np.zeros_like(silhouette)
    # Center the resized crop on the heatmap bbox's horizontal midpoint,
    # clipping any overhang at the image borders.
    heat_center_x = (hx0 + hx1) / 2.0
    x0 = int(round(heat_center_x - target_w / 2.0))
    x1 = x0 + target_w
    src_x0 = max(0, -x0)
    src_x1 = target_w - max(0, x1 - silhouette.shape[1])
    dst_x0 = max(0, x0)
    dst_x1 = min(silhouette.shape[1], x1)
    if src_x0 < src_x1 and dst_x0 < dst_x1:
        canvas[hy0 : hy1 + 1, dst_x0:dst_x1] = resized[:, src_x0:src_x1]
    return canvas
def main() -> None:
    """Render raw and height/center-aligned overlay previews for one sample.

    Writes two PNG figures (raw and aligned silhouette variants) under
    RESEARCH_ROOT and prints each output path.
    """
    RESEARCH_ROOT.mkdir(parents=True, exist_ok=True)
    silhouette = _load_pickle(SIL_PATH)
    mixed = _load_mixed_sharedalign()
    # Six evenly spaced frame indices across the sequence.
    frames = np.linspace(0, silhouette.shape[0] - 1, num=6, dtype=int)
    raw_path = RESEARCH_ROOT / f"{SAMPLE_ID}_{LABEL}_{VIEW}_overlay_sharedalign_preview.png"
    aligned_path = RESEARCH_ROOT / f"{SAMPLE_ID}_{LABEL}_{VIEW}_overlay_sharedalign_aligned_preview.png"
    for output_path, use_alignment in ((raw_path, False), (aligned_path, True)):
        fig, axes = plt.subplots(1, len(frames), figsize=(2.6 * len(frames), 3.6), dpi=220)
        for axis, frame_idx in zip(axes, frames, strict=True):
            limb, joint = mixed[frame_idx, 0], mixed[frame_idx, 1]
            sil_frame = silhouette[frame_idx]
            if use_alignment:
                # Align against the union of both heatmap channels.
                sil_frame = _align_silhouette_to_heatmap(sil_frame, np.maximum(limb, joint))
            axis.set_facecolor("black")
            axis.imshow(_overlay_frame(sil_frame, limb, joint))
            axis.set_xticks([])
            axis.set_yticks([])
            axis.set_title(f"Frame {frame_idx}", fontsize=9)
        mode = "height+center aligned silhouette" if use_alignment else "raw silhouette"
        fig.suptitle(
            (
                f"{SAMPLE_ID}/{LABEL}/{VIEW}: {mode}, "
                f"silhouette ({SIL_ALPHA}), limb ({LIMB_ALPHA}), joint ({JOINT_ALPHA})"
            ),
            fontsize=12,
        )
        fig.tight_layout()
        fig.savefig(output_path, bbox_inches="tight", facecolor="black")
        plt.close(fig)
        print(output_path)


if __name__ == "__main__":
    main()
+199
View File
@@ -0,0 +1,199 @@
from __future__ import annotations
import pickle
from dataclasses import dataclass
from pathlib import Path
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
matplotlib.use("Agg")
REPO_ROOT = Path(__file__).resolve().parents[1]
RESEARCH_ROOT = REPO_ROOT / "research/sigma_comparison"
SIGMA8_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K_sigma_8.0/pkl")
SIGMA15_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15")
MIXED_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15-joint8-fixed")
SIL_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl")
FloatArray = npt.NDArray[np.float32]
@dataclass(frozen=True)
class SampleSpec:
    """Identifies one Scoliosis1K sequence by its on-disk directory triplet."""

    # Zero-padded sequence directory name, e.g. "00000".
    sequence_id: str
    # Class label directory, e.g. "positive".
    label: str
    # Camera-view directory, e.g. "000_180".
    view: str
def _normalize_image(array: npt.NDArray[np.generic]) -> FloatArray:
image = np.asarray(array, dtype=np.float32)
if image.max() > 1.0:
image = image / 255.0
return image
def _load_pickle(path: Path) -> FloatArray:
    """Load a pickled array from *path*, normalized to float32 in [0, 1].

    NOTE(review): ``pickle.load`` executes arbitrary code from the file;
    only use this on trusted, locally generated dataset files.
    """
    with path.open("rb") as handle:
        return _normalize_image(pickle.load(handle))
def _load_sigma8(spec: SampleSpec) -> FloatArray:
    """Load the sigma=8.0 heatmap sequence for *spec* (stored as <view>.pkl)."""
    return _load_pickle(SIGMA8_ROOT / spec.sequence_id / spec.label / spec.view / f"{spec.view}.pkl")
def _load_sigma15(spec: SampleSpec) -> FloatArray:
    """Load the sigma=1.5 heatmap sequence for *spec* (stored as 0_heatmap.pkl)."""
    return _load_pickle(SIGMA15_ROOT / spec.sequence_id / spec.label / spec.view / "0_heatmap.pkl")
def _load_silhouette(spec: SampleSpec) -> FloatArray:
    """Load the silhouette sequence for *spec* (stored as <view>.pkl)."""
    return _load_pickle(SIL_ROOT / spec.sequence_id / spec.label / spec.view / f"{spec.view}.pkl")
def _load_mixed(spec: SampleSpec) -> FloatArray:
    """Load the mixed limb-1.5 / joint-8 heatmap sequence for *spec*."""
    return _load_pickle(MIXED_ROOT / spec.sequence_id / spec.label / spec.view / "0_heatmap.pkl")
def _frame_indices(num_frames: int, count: int = 6) -> npt.NDArray[np.int_]:
return np.linspace(0, num_frames - 1, num=count, dtype=int)
def _combined_map(heatmaps: FloatArray) -> FloatArray:
return np.maximum(heatmaps[:, 0], heatmaps[:, 1])
def _save_combined_figure(
    spec: SampleSpec,
    silhouette: FloatArray,
    sigma8: FloatArray,
    sigma15: FloatArray,
    mixed: FloatArray,
) -> Path:
    """Render a 4-row grid (silhouette plus three combined heatmaps) and save it.

    Columns are six evenly spaced frames; all panels share a fixed [0, 1]
    grayscale range. Returns the path of the written PNG.
    """
    frames = _frame_indices(silhouette.shape[0], count=6)
    n_cols = len(frames)
    fig, axes = plt.subplots(4, n_cols, figsize=(2.4 * n_cols, 8.8), dpi=200)
    # Pair each row title with its image sequence explicitly.
    rows = (
        ("Silhouette", silhouette),
        ("Sigma 8 Combined", _combined_map(sigma8)),
        ("Sigma 1.5 Combined", _combined_map(sigma15)),
        ("Limb 1.5 / Joint 8 Combined", _combined_map(mixed)),
    )
    for row_idx, (title, sequence) in enumerate(rows):
        for col_idx, frame_idx in enumerate(frames):
            panel = axes[row_idx, col_idx]
            panel.imshow(sequence[frame_idx], cmap="gray", vmin=0.0, vmax=1.0)
            panel.set_xticks([])
            panel.set_yticks([])
            if row_idx == 0:
                panel.set_title(f"Frame {frame_idx}", fontsize=9)
            if col_idx == 0:
                panel.set_ylabel(title, fontsize=10)
    fig.suptitle(
        f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: silhouette vs sigma heatmaps",
        fontsize=12,
    )
    fig.tight_layout()
    output_path = RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_combined.png"
    fig.savefig(output_path, bbox_inches="tight")
    plt.close(fig)
    return output_path
def _save_channel_figure(spec: SampleSpec, sigma8: FloatArray, sigma15: FloatArray) -> Path:
    """Save a 4xN grid comparing limb/joint channels at sigma 8 vs sigma 1.5.

    Rows: sigma-8 limb, sigma-8 joint, sigma-1.5 limb, sigma-1.5 joint;
    columns are four evenly spaced frames. Each row is scaled to its own
    maximum so faint channels remain visible. Returns the written PNG path.
    """
    frames = _frame_indices(sigma8.shape[0], count=4)
    fig, axes = plt.subplots(4, len(frames), figsize=(2.6 * len(frames), 8.4), dpi=200)
    row_titles = (
        "Sigma 8 Limb",
        "Sigma 8 Joint",
        "Sigma 1.5 Limb",
        "Sigma 1.5 Joint",
    )
    images = (sigma8[:, 0], sigma8[:, 1], sigma15[:, 0], sigma15[:, 1])
    for row_idx, (title, row_images) in enumerate(zip(row_titles, images, strict=True)):
        # Hoisted: the original reduced np.max(row_images) twice per row.
        # Fall back to 1.0 for an all-zero row so imshow keeps a valid range.
        row_max = float(np.max(row_images))
        vmax = row_max if row_max > 0 else 1.0
        for col_idx, frame_idx in enumerate(frames):
            ax = axes[row_idx, col_idx]
            ax.imshow(row_images[frame_idx], cmap="magma", vmin=0.0, vmax=vmax)
            ax.set_xticks([])
            ax.set_yticks([])
            if row_idx == 0:
                ax.set_title(f"Frame {frame_idx}", fontsize=9)
            if col_idx == 0:
                ax.set_ylabel(title, fontsize=10)
    fig.suptitle(
        f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: per-channel sigma comparison",
        fontsize=12,
    )
    fig.tight_layout()
    output_path = RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_channels.png"
    fig.savefig(output_path, bbox_inches="tight")
    plt.close(fig)
    return output_path
def _save_mixed_channel_figure(
    spec: SampleSpec,
    sigma8: FloatArray,
    sigma15: FloatArray,
    mixed: FloatArray,
) -> Path:
    """Save a 6xN grid of sigma-8, sigma-1.5, and mixed limb/joint channels.

    Rows cover both channels of each of the three heatmap variants; columns
    are four evenly spaced frames. Each row is scaled to its own maximum so
    faint channels remain visible. Returns the written PNG path.
    """
    frames = _frame_indices(sigma8.shape[0], count=4)
    fig, axes = plt.subplots(6, len(frames), figsize=(2.6 * len(frames), 12.0), dpi=200)
    row_titles = (
        "Sigma 8 Limb",
        "Sigma 8 Joint",
        "Sigma 1.5 Limb",
        "Sigma 1.5 Joint",
        "Mixed Limb 1.5",
        "Mixed Joint 8",
    )
    images = (sigma8[:, 0], sigma8[:, 1], sigma15[:, 0], sigma15[:, 1], mixed[:, 0], mixed[:, 1])
    for row_idx, (title, row_images) in enumerate(zip(row_titles, images, strict=True)):
        # Hoisted: the original reduced np.max(row_images) twice per row.
        # Fall back to 1.0 for an all-zero row so imshow keeps a valid range.
        row_max = float(np.max(row_images))
        vmax = row_max if row_max > 0 else 1.0
        for col_idx, frame_idx in enumerate(frames):
            ax = axes[row_idx, col_idx]
            ax.imshow(row_images[frame_idx], cmap="magma", vmin=0.0, vmax=vmax)
            ax.set_xticks([])
            ax.set_yticks([])
            if row_idx == 0:
                ax.set_title(f"Frame {frame_idx}", fontsize=9)
            if col_idx == 0:
                ax.set_ylabel(title, fontsize=10)
    fig.suptitle(
        f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: sigma and mixed per-channel comparison",
        fontsize=12,
    )
    fig.tight_layout()
    output_path = RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_channels_mixed.png"
    fig.savefig(output_path, bbox_inches="tight")
    plt.close(fig)
    return output_path
def main() -> None:
    """Generate all sigma-comparison figures for one hand-picked sample.

    Loads the silhouette plus the three heatmap variants for a fixed
    sequence, writes the combined / per-channel / mixed-channel figures
    under RESEARCH_ROOT, and prints each output path.
    """
    RESEARCH_ROOT.mkdir(parents=True, exist_ok=True)
    sample = SampleSpec(sequence_id="00000", label="positive", view="000_180")
    sigma8 = _load_sigma8(sample)
    sigma15 = _load_sigma15(sample)
    mixed = _load_mixed(sample)
    silhouette = _load_silhouette(sample)
    outputs = (
        _save_combined_figure(sample, silhouette, sigma8, sigma15, mixed),
        _save_channel_figure(sample, sigma8, sigma15),
        _save_mixed_channel_figure(sample, sigma8, sigma15, mixed),
    )
    for path in outputs:
        print(path)


if __name__ == "__main__":
    main()