diff --git a/docs/scoliosis_training_change_log.md b/docs/scoliosis_training_change_log.md index 5568234..62cfa30 100644 --- a/docs/scoliosis_training_change_log.md +++ b/docs/scoliosis_training_change_log.md @@ -26,8 +26,9 @@ Use it for: | 2026-03-08 | `ScoNet_skeleton_118_sigma15_bs12x8` | ScoNet-MT-ske control | `Scoliosis1K-drf-pkl-118-sigma15` | Lowered skeleton-map sigma from `8.0` to `1.5` to tighten the pose rasterization | complete | `46.33 Acc / 68.09 Prec / 51.92 Rec / 44.69 F1` | | 2026-03-09 | `ScoNet_skeleton_118_sigma15_joint8_sharedalign_2gpu_bs12x8` | ScoNet-MT-ske control | `Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign` | Fixed limb/joint channel misalignment, used mixed sigma `limb=1.5 / joint=8.0`, kept SGD | complete | `50.47 Acc / 69.31 Prec / 54.58 Rec / 48.63 F1` | | 2026-03-09 | `ScoNet_skeleton_118_sigma15_joint8_limb4_adamw_2gpu_bs12x8` | ScoNet-MT-ske control | `Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign-limb4` | Rebalanced channel intensity with `limb_gain=4.0`; switched optimizer from `SGD` to `AdamW` | complete | `48.60 Acc / 65.97 Prec / 53.19 Rec / 46.41 F1` | -| 2026-03-09 | `ScoNet_skeleton_118_sigma15_joint8_sharedalign_nocut_adamw_1gpu_bs8x8` | ScoNet-MT-ske control | `Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign` | Switched runtime transform from `BaseSilCuttingTransform` to `BaseSilTransform` (`no-cut`), kept `AdamW`, reduced `8x8` due to 5070 Ti OOM at `12x8` | training | no eval yet | -| 2026-03-09 | `ScoNet_skeleton_118_sigma15_joint8_sharedalign_nocut_adamw_proxy_1gpu` | ScoNet-MT-ske proxy | `Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign` | Fast proxy route: `no-cut`, `AdamW`, `8x8`, `total_iter=2000`, `eval_iter=500`, `test_seq_subset_size=128` | training | no eval yet | +| 2026-03-09 | `ScoNet_skeleton_118_sigma15_joint8_sharedalign_nocut_adamw_1gpu_bs8x8` | ScoNet-MT-ske control | `Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign` | Switched runtime transform from `BaseSilCuttingTransform` to 
`BaseSilTransform` (`no-cut`), kept `AdamW`, reduced `8x8` due to 5070 Ti OOM at `12x8` | interrupted | superseded by proxy route before eval | +| 2026-03-09 | `ScoNet_skeleton_118_sigma15_joint8_sharedalign_nocut_adamw_proxy_1gpu` | ScoNet-MT-ske proxy | `Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign` | Fast proxy route: `no-cut`, `AdamW`, `8x8`, `total_iter=2000`, `eval_iter=500`, `test_seq_subset_size=128` | interrupted | superseded by geometry-fixed proxy before completion | +| 2026-03-10 | `ScoNet_skeleton_118_sigma15_joint8_geomfix_proxy_1gpu` | ScoNet-MT-ske proxy | `Scoliosis1K-drf-pkl-118-sigma15-joint8-geomfix` | Geometry ablation: aspect-ratio-preserving crop+pad instead of square-warp resize; `AdamW`, `no-cut`, `8x8`, `total_iter=2000`, `eval_iter=500`, fixed test subset seed `118` | complete | proxy subset unstable: `500 24.22/8.07/33.33/13.00`, `1000 60.16/68.05/58.13/55.25`, `1500 26.56/58.33/35.64/17.68`, `2000 27.34/63.96/37.02/20.14` (Acc/Prec/Rec/F1) | ## Current best skeleton baseline @@ -40,3 +41,4 @@ Current best `ScoNet-MT-ske`-style result: - `ckpt/ScoNet-20000-better.pt` is intentionally not listed here because it is a silhouette checkpoint, not a skeleton-map run. - `DRF` runs are included because they are part of the same reproduction/debugging loop, but this log should stay focused on train/eval changes, not broader code refactors. - The long `ScoNet_skeleton_118_sigma15_joint8_sharedalign_nocut_adamw_1gpu_bs8x8` run was intentionally interrupted and superseded by the shorter proxy run once fast-iteration support was added. +- The geometry-fixed proxy run fit the train split quickly but did not produce a stable proxy validation curve, so it should not be promoted to a full 20k run. 
diff --git a/scripts/visualize_feature_response_heatmaps.py b/scripts/visualize_feature_response_heatmaps.py new file mode 100644 index 0000000..fa12abe --- /dev/null +++ b/scripts/visualize_feature_response_heatmaps.py @@ -0,0 +1,778 @@ +from __future__ import annotations + +import pickle +import sys +import types +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Literal + +import click +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from jaxtyping import Float +from numpy.typing import NDArray + + +REPO_ROOT = Path(__file__).resolve().parents[1] +OPENGAIT_ROOT = REPO_ROOT / "opengait" +if str(OPENGAIT_ROOT) not in sys.path: + sys.path.insert(0, str(OPENGAIT_ROOT)) +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from opengait.data.transform import get_transform +from opengait.modeling import models +from opengait.modeling.base_model import BaseModel +from opengait.modeling.models.drf import DRF +from opengait.utils.common import config_loader + + +FloatArray = NDArray[np.float32] +LABEL_NAMES = ("negative", "neutral", "positive") + + +@dataclass(frozen=True) +class ModelSpec: + key: str + title: str + config_path: Path + checkpoint_path: Path + data_root: Path + is_drf: bool + + +@dataclass(frozen=True) +class VisualizationResult: + title: str + predicted_label: str + class_scores: tuple[float, float, float] + input_image: FloatArray + response_map: FloatArray + overlay_image: FloatArray + + +@dataclass(frozen=True) +class StageVisualizationResult: + title: str + predicted_label: str + class_scores: tuple[float, float, float] + input_image: FloatArray + stage_overlays: dict[str, FloatArray] + + +@dataclass(frozen=True) +class ForwardDiagnostics: + feature_map: Float[torch.Tensor, "batch channels height width"] + pooled: Float[torch.Tensor, "batch channels parts"] + embeddings: Float[torch.Tensor, "batch channels parts"] + 
logits: Float[torch.Tensor, "batch classes parts"] + class_scores: Float[torch.Tensor, "batch classes"] + pga_spatial: Float[torch.Tensor, "batch parts"] | None + + +@dataclass(frozen=True) +class ModelDiagnosticResult: + title: str + predicted_label: str + class_scores: tuple[float, float, float] + input_image: FloatArray + occlusion_map: FloatArray + overlay_image: FloatArray + part_scores: FloatArray + pga_spatial: FloatArray | None + + +def apply_transforms( + sequence: np.ndarray, + transform_cfg: list[dict[str, Any]] | dict[str, Any] | None, +) -> np.ndarray: + transform = get_transform(transform_cfg) + if isinstance(transform, list): + result = sequence + for item in transform: + result = item(result) + return np.asarray(result) + return np.asarray(transform(sequence)) + + +def build_minimal_model(cfgs: dict[str, Any]) -> nn.Module: + model_name = str(cfgs["model_cfg"]["model"]) + model_cls = getattr(models, model_name) + model = model_cls.__new__(model_cls) + nn.Module.__init__(model) + model.get_backbone = types.MethodType(BaseModel.get_backbone, model) + model.build_network(cfgs["model_cfg"]) + return model + + +def load_checkpoint_model(spec: ModelSpec) -> tuple[nn.Module, dict[str, Any]]: + cfgs = config_loader(str(spec.config_path)) + model = build_minimal_model(cfgs) + checkpoint = torch.load( + spec.checkpoint_path, + map_location="cpu", + weights_only=False, + ) + model.load_state_dict(checkpoint["model"], strict=False) + model.eval() + return model, cfgs + + +def load_pickle(path: Path) -> np.ndarray: + with path.open("rb") as handle: + return np.asarray(pickle.load(handle)) + + +def load_sample_inputs( + spec: ModelSpec, + sequence_id: str, + label: str, + view: str, + frame_count: int, +) -> tuple[torch.Tensor, torch.Tensor | None, FloatArray]: + if spec.is_drf: + heatmap_path = spec.data_root / sequence_id / label / view / "0_heatmap.pkl" + pav_path = spec.data_root / sequence_id / label / view / "1_pav.pkl" + heatmaps = 
load_pickle(heatmap_path).astype(np.float32) + pav = load_pickle(pav_path).astype(np.float32) + return prepare_heatmap_inputs(heatmaps, pav, spec.config_path, frame_count) + + if "skeleton" in spec.key: + heatmap_path = spec.data_root / sequence_id / label / view / "0_heatmap.pkl" + heatmaps = load_pickle(heatmap_path).astype(np.float32) + return prepare_heatmap_inputs(heatmaps, None, spec.config_path, frame_count) + + silhouette_path = spec.data_root / sequence_id / label / view / f"{view}.pkl" + silhouettes = load_pickle(silhouette_path).astype(np.float32) + return prepare_silhouette_inputs(silhouettes, spec.config_path, frame_count) + + +def prepare_silhouette_inputs( + silhouettes: FloatArray, + config_path: Path, + frame_count: int, +) -> tuple[torch.Tensor, None, FloatArray]: + cfgs = config_loader(str(config_path)) + transformed = apply_transforms(silhouettes, cfgs["evaluator_cfg"]["transform"]).astype(np.float32) + indices = np.linspace(0, transformed.shape[0] - 1, frame_count, dtype=int) + sampled = transformed[indices] + input_tensor = torch.from_numpy(sampled).unsqueeze(0).unsqueeze(0).float() + projection = normalize_image(sampled.max(axis=0)) + return input_tensor, None, projection + + +def prepare_heatmap_inputs( + heatmaps: FloatArray, + pav: FloatArray | None, + config_path: Path, + frame_count: int, +) -> tuple[torch.Tensor, torch.Tensor | None, FloatArray]: + cfgs = config_loader(str(config_path)) + transform_cfg = cfgs["evaluator_cfg"]["transform"] + heatmap_transform = transform_cfg[0] if isinstance(transform_cfg, list) else transform_cfg + transformed = apply_transforms(heatmaps, heatmap_transform).astype(np.float32) + indices = np.linspace(0, transformed.shape[0] - 1, frame_count, dtype=int) + sampled = transformed[indices] + input_tensor = torch.from_numpy(sampled).permute(1, 0, 2, 3).unsqueeze(0).float() + combined = np.maximum(sampled.max(axis=0)[0], sampled.max(axis=0)[1]) + projection = normalize_image(combined) + pav_tensor = None + if 
pav is not None: + pav_tensor = torch.from_numpy(pav.mean(axis=0, keepdims=True)).float() + return input_tensor, pav_tensor, projection + + +AttentionMethod = Literal["sum", "sum_p", "max_p"] +ResponseMode = Literal["activation", "cam"] + + +def compute_classification_scores( + model: nn.Module, + sequence: torch.Tensor, + pav: torch.Tensor | None, +) -> tuple[FloatArray, tuple[float, float, float], str]: + diagnostics = forward_diagnostics(model, sequence, pav) + class_scores = diagnostics.class_scores[0] + pred_index = int(class_scores.argmax().detach().cpu()) + score_tuple = tuple(float(value.detach().cpu()) for value in class_scores) + return diagnostics.feature_map.detach()[0], score_tuple, LABEL_NAMES[pred_index] + + +def forward_diagnostics( + model: nn.Module, + sequence: torch.Tensor, + pav: torch.Tensor | None, +) -> ForwardDiagnostics: + outputs = model.Backbone(sequence) + feature_map = model.TP(outputs, None, options={"dim": 2})[0] + pooled = model.HPP(feature_map) + embeddings = model.FCs(pooled) + + pga_spatial: Float[torch.Tensor, "batch parts"] | None = None + if pav is not None and isinstance(model, DRF): + pav_flat = pav.flatten(1) + channel_att = model.PGA.channel_att(pav_flat).unsqueeze(-1) + spatial_att = model.PGA.spatial_att(pav_flat.unsqueeze(-1)).transpose(1, 2) + pga_spatial = spatial_att.squeeze(1) + embeddings = embeddings * channel_att * spatial_att + + _bn_embeddings, logits = model.BNNecks(embeddings) + class_scores = logits.mean(-1) + return ForwardDiagnostics( + feature_map=feature_map, + pooled=pooled, + embeddings=embeddings, + logits=logits, + class_scores=class_scores, + pga_spatial=pga_spatial, + ) + + +def compute_response_map( + model: nn.Module, + sequence: torch.Tensor, + pav: torch.Tensor | None, + response_mode: ResponseMode, + method: AttentionMethod, + power: float, +) -> tuple[FloatArray, tuple[float, float, float], str]: + activation, score_tuple, predicted_label = compute_classification_scores(model, sequence, pav) 
+ if response_mode == "cam": + response_map = adapted_cam_map( + model=model, + sequence=sequence, + pav=pav, + predicted_label=predicted_label, + ) + else: + response_map = activation_attention_map(activation, method=method, power=power) + response_map = response_map.unsqueeze(0).unsqueeze(0) + response_map = F.interpolate( + response_map, + size=sequence.shape[-2:], + mode="bilinear", + align_corners=False, + )[0, 0] + response_np = response_map.detach().cpu().numpy().astype(np.float32) + response_np = normalize_image(response_np) + return response_np, score_tuple, predicted_label + + +def activation_attention_map( + activation: Float[torch.Tensor, "channels height width"], + method: AttentionMethod, + power: float, +) -> Float[torch.Tensor, "height width"]: + abs_activation = activation.abs() + if method == "sum": + return abs_activation.sum(dim=0) + powered = abs_activation.pow(power) + if method == "sum_p": + return powered.sum(dim=0) + if method == "max_p": + return powered.max(dim=0).values + raise ValueError(f"Unsupported attention method: {method}") + + +def adapted_cam_map( + model: nn.Module, + sequence: torch.Tensor, + pav: torch.Tensor | None, + predicted_label: str, +) -> Float[torch.Tensor, "height width"]: + outputs = model.Backbone(sequence) + feature_map = model.TP(outputs, None, options={"dim": 2})[0] # [1, c_in, h, w] + pooled = model.HPP(feature_map) # [1, c_in, p] + + pre_attention = model.FCs(pooled) # [1, c_out, p] + attention = torch.ones_like(pre_attention) + if pav is not None and isinstance(model, DRF): + pav_flat = pav.flatten(1) + channel_att = model.PGA.channel_att(pav_flat).unsqueeze(-1) + spatial_att = model.PGA.spatial_att(pav_flat.unsqueeze(-1)).transpose(1, 2) + attention = channel_att * spatial_att + embeddings = pre_attention * attention + + bn_scales = bn_feature_scales(model.BNNecks, embeddings.shape[1], embeddings.shape[2], embeddings.device) + classifier_weights = F.normalize(model.BNNecks.fc_bin.detach(), dim=1) + 
class_index = LABEL_NAMES.index(predicted_label) + class_weights = classifier_weights[:, :, class_index] # [p, c_out] + + normalized_embeddings = batchnorm_forward_eval(model.BNNecks, embeddings)[0].transpose(0, 1) # [p, c_out] + part_norms = normalized_embeddings.norm(dim=-1, keepdim=True).clamp_min(1e-6) + effective_post_bn = class_weights / part_norms + effective_pre_bn = effective_post_bn * bn_scales + + fc_weights = model.FCs.fc_bin.detach() # [p, c_in, c_out] + attention_weights = attention.detach()[0].transpose(0, 1) # [p, c_out] + effective_channel_weights = torch.einsum( + "pkj,pj->pk", + fc_weights, + attention_weights * effective_pre_bn, + ) # [p, c_in] + + feature = feature_map.detach()[0] # [c_in, h, w] + parts_num = effective_channel_weights.shape[0] + height = feature.shape[1] + split_size = max(height // parts_num, 1) + cam = torch.zeros((height, feature.shape[2]), device=feature.device, dtype=feature.dtype) + for part_index in range(parts_num): + row_start = part_index * split_size + row_end = height if part_index == parts_num - 1 else min((part_index + 1) * split_size, height) + if row_start >= height: + break + contribution = (feature[:, row_start:row_end, :] * effective_channel_weights[part_index][:, None, None]).sum(dim=0) + cam[row_start:row_end, :] = contribution + return F.relu(cam) + + +def batchnorm_forward_eval( + bn_necks: nn.Module, + embeddings: Float[torch.Tensor, "batch channels parts"], +) -> Float[torch.Tensor, "batch channels parts"]: + batch, channels, parts = embeddings.shape + if not getattr(bn_necks, "parallel_BN1d", True): + raise ValueError("Adapted CAM currently expects parallel_BN1d=True.") + bn = bn_necks.bn1d + flattened = embeddings.reshape(batch, -1) + normalized = torch.nn.functional.batch_norm( + flattened, + bn.running_mean, + bn.running_var, + bn.weight, + bn.bias, + training=False, + momentum=0.0, + eps=bn.eps, + ) + return normalized.view(batch, channels, parts) + + +def bn_feature_scales( + bn_necks: nn.Module, 
+ channels: int, + parts: int, + device: torch.device, +) -> Float[torch.Tensor, "parts channels"]: + if not getattr(bn_necks, "parallel_BN1d", True): + raise ValueError("Adapted CAM currently expects parallel_BN1d=True.") + bn = bn_necks.bn1d + running_var = bn.running_var.detach().to(device).view(channels, parts).transpose(0, 1) + base_scale = torch.rsqrt(running_var + bn.eps) + if bn.weight is None: + return base_scale + gamma = bn.weight.detach().to(device).view(channels, parts).transpose(0, 1) + return base_scale * gamma + + +def normalize_image(image: np.ndarray) -> FloatArray: + image = image.astype(np.float32) + min_value = float(image.min()) + max_value = float(image.max()) + if max_value <= min_value: + return np.zeros_like(image, dtype=np.float32) + return (image - min_value) / (max_value - min_value) + + +def overlay_response(base_image: FloatArray, response_map: FloatArray) -> FloatArray: + cmap = plt.get_cmap("jet") + heat = cmap(response_map)[..., :3].astype(np.float32) + base_rgb = np.repeat(base_image[..., None], 3, axis=-1) + overlay = base_rgb * 0.45 + heat * 0.55 + return np.clip(overlay, 0.0, 1.0) + + +def compute_occlusion_sensitivity( + model: nn.Module, + sequence: torch.Tensor, + pav: torch.Tensor | None, + patch_size: int, + stride: int, +) -> tuple[FloatArray, tuple[float, float, float], str]: + with torch.no_grad(): + baseline = forward_diagnostics(model, sequence, pav) + class_scores = baseline.class_scores[0] + pred_index = int(class_scores.argmax().detach().cpu()) + baseline_score = float(class_scores[pred_index].detach().cpu()) + score_tuple = tuple(float(value.detach().cpu()) for value in class_scores) + + _, _, _, height, width = sequence.shape + sensitivity = np.zeros((height, width), dtype=np.float32) + counts = np.zeros((height, width), dtype=np.float32) + row_starts = list(range(0, max(height - patch_size + 1, 1), stride)) + col_starts = list(range(0, max(width - patch_size + 1, 1), stride)) + last_row = max(height - 
patch_size, 0) + last_col = max(width - patch_size, 0) + if row_starts[-1] != last_row: + row_starts.append(last_row) + if col_starts[-1] != last_col: + col_starts.append(last_col) + + for top in row_starts: + for left in col_starts: + bottom = min(top + patch_size, height) + right = min(left + patch_size, width) + masked = sequence.clone() + masked[..., top:bottom, left:right] = 0.0 + masked_scores = forward_diagnostics(model, masked, pav).class_scores[0] + score_drop = max(0.0, baseline_score - float(masked_scores[pred_index].detach().cpu())) + sensitivity[top:bottom, left:right] += score_drop + counts[top:bottom, left:right] += 1.0 + + counts = np.maximum(counts, 1.0) + sensitivity /= counts + return normalize_image(sensitivity), score_tuple, LABEL_NAMES[pred_index] + + +def compute_part_diagnostics( + model: nn.Module, + sequence: torch.Tensor, + pav: torch.Tensor | None, +) -> tuple[FloatArray, FloatArray | None, tuple[float, float, float], str]: + with torch.no_grad(): + diagnostics = forward_diagnostics(model, sequence, pav) + class_scores = diagnostics.class_scores[0] + pred_index = int(class_scores.argmax().detach().cpu()) + part_scores = diagnostics.logits[0, pred_index].detach().cpu().numpy().astype(np.float32) + pga_spatial = None + if diagnostics.pga_spatial is not None: + pga_spatial = diagnostics.pga_spatial[0].detach().cpu().numpy().astype(np.float32) + score_tuple = tuple(float(value.detach().cpu()) for value in class_scores) + return part_scores, pga_spatial, score_tuple, LABEL_NAMES[pred_index] + + +def extract_temporally_pooled_stages( + model: nn.Module, + sequence: torch.Tensor, +) -> dict[str, torch.Tensor]: + if not hasattr(model.Backbone, "forward_block"): + raise ValueError("Stage visualization requires SetBlockWrapper-style backbones.") + + backbone = model.Backbone.forward_block + batch, channels, frames, height, width = sequence.shape + x = sequence.transpose(1, 2).reshape(-1, channels, height, width) + + stage_outputs: dict[str, 
torch.Tensor] = {} + x = backbone.conv1(x) + x = backbone.bn1(x) + x = backbone.relu(x) + if getattr(backbone, "maxpool_flag", False): + x = backbone.maxpool(x) + stage_outputs["conv1"] = x + + x = backbone.layer1(x) + stage_outputs["layer1"] = x + x = backbone.layer2(x) + stage_outputs["layer2"] = x + x = backbone.layer3(x) + stage_outputs["layer3"] = x + x = backbone.layer4(x) + stage_outputs["layer4"] = x + + pooled_outputs: dict[str, torch.Tensor] = {} + for stage_name, stage_output in stage_outputs.items(): + reshaped = stage_output.reshape(batch, frames, *stage_output.shape[1:]).transpose(1, 2).contiguous() + pooled_outputs[stage_name] = model.TP(reshaped, None, options={"dim": 2})[0] + return pooled_outputs + + +def render_visualization(results: list[VisualizationResult], output_path: Path) -> None: + fig, axes = plt.subplots(len(results), 3, figsize=(10.5, 3.2 * len(results)), constrained_layout=True) + if len(results) == 1: + axes = np.expand_dims(axes, axis=0) + + for row, result in enumerate(results): + input_axis, heat_axis, overlay_axis = axes[row] + input_axis.imshow(result.input_image, cmap="gray", vmin=0.0, vmax=1.0) + input_axis.set_title(f"{result.title} input") + heat_axis.imshow(result.response_map, cmap="jet", vmin=0.0, vmax=1.0) + heat_axis.set_title("response") + overlay_axis.imshow(result.overlay_image) + overlay_axis.set_title( + f"overlay | pred={result.predicted_label}\n" + f"scores={tuple(round(score, 3) for score in result.class_scores)}" + ) + for axis in (input_axis, heat_axis, overlay_axis): + axis.set_xticks([]) + axis.set_yticks([]) + + fig.savefig(output_path, dpi=180) + plt.close(fig) + + +def render_stage_visualization(results: list[StageVisualizationResult], output_path: Path) -> None: + stage_names = ["input", "conv1", "layer1", "layer2", "layer3", "layer4"] + fig, axes = plt.subplots( + len(results), + len(stage_names), + figsize=(2.2 * len(stage_names), 3.0 * len(results)), + constrained_layout=True, + ) + if len(results) == 
1: + axes = np.expand_dims(axes, axis=0) + + for row, result in enumerate(results): + for col, stage_name in enumerate(stage_names): + axis = axes[row, col] + if stage_name == "input": + axis.imshow(result.input_image, cmap="gray", vmin=0.0, vmax=1.0) + axis.set_title( + f"{result.title}\npred={result.predicted_label}", + fontsize=10, + ) + else: + axis.imshow(result.stage_overlays[stage_name]) + axis.set_title(stage_name, fontsize=10) + axis.set_xticks([]) + axis.set_yticks([]) + axes[row, 0].set_ylabel( + f"scores={tuple(round(score, 3) for score in result.class_scores)}", + fontsize=9, + rotation=90, + ) + + fig.savefig(output_path, dpi=180) + plt.close(fig) + + +def render_diagnostic_visualization(results: list[ModelDiagnosticResult], output_path: Path) -> None: + fig, axes = plt.subplots(len(results), 3, figsize=(11.5, 3.3 * len(results)), constrained_layout=True) + if len(results) == 1: + axes = np.expand_dims(axes, axis=0) + + for row, result in enumerate(results): + input_axis, occlusion_axis, part_axis = axes[row] + input_axis.imshow(result.input_image, cmap="gray", vmin=0.0, vmax=1.0) + input_axis.set_title(f"{result.title} input") + input_axis.set_xticks([]) + input_axis.set_yticks([]) + + occlusion_axis.imshow(result.overlay_image) + occlusion_axis.set_title( + f"occlusion | pred={result.predicted_label}\n" + f"scores={tuple(round(score, 3) for score in result.class_scores)}" + ) + occlusion_axis.set_xticks([]) + occlusion_axis.set_yticks([]) + + parts = np.arange(result.part_scores.shape[0], dtype=np.int32) + part_axis.bar(parts, result.part_scores, color="#4477AA", alpha=0.85, width=0.8) + part_axis.set_title("pred-class part logits") + part_axis.set_xlabel("HPP part") + part_axis.set_ylabel("logit") + part_axis.set_xticks(parts) + if result.pga_spatial is not None: + part_axis_2 = part_axis.twinx() + part_axis_2.plot(parts, result.pga_spatial, color="#CC6677", marker="o", linewidth=1.5) + part_axis_2.set_ylabel("PGA spatial") + 
part_axis_2.set_ylim(0.0, max(1.0, float(result.pga_spatial.max()) * 1.05)) + + fig.savefig(output_path, dpi=180) + plt.close(fig) + + +def default_model_specs() -> list[ModelSpec]: + return [ + ModelSpec( + key="silhouette", + title="ScoNet-MT silhouette", + config_path=REPO_ROOT / "configs/sconet/sconet_scoliosis1k_local_eval_2gpu_better_112.yaml", + checkpoint_path=REPO_ROOT / "ckpt/ScoNet-20000-better.pt", + data_root=Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl"), + is_drf=False, + ), + ModelSpec( + key="skeleton", + title="ScoNet-MT-ske", + config_path=REPO_ROOT / "configs/sconet/sconet_scoliosis1k_skeleton_118_sigma15_joint8_2gpu_bs12x8.yaml", + checkpoint_path=REPO_ROOT + / "output/Scoliosis1K/ScoNet/ScoNet_skeleton_118_sigma15_joint8_sharedalign_2gpu_bs12x8/checkpoints/ScoNet_skeleton_118_sigma15_joint8_sharedalign_2gpu_bs12x8-20000.pt", + data_root=Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign"), + is_drf=False, + ), + ModelSpec( + key="drf", + title="DRF", + config_path=REPO_ROOT / "configs/drf/drf_scoliosis1k_eval_1gpu.yaml", + checkpoint_path=REPO_ROOT / "output/Scoliosis1K/DRF/DRF/checkpoints/DRF-20000.pt", + data_root=Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118"), + is_drf=True, + ), + ] + + +def build_stage_results( + spec: ModelSpec, + response_mode: ResponseMode, + method: AttentionMethod, + power: float, + sequence_id: str, + label: str, + view: str, + frame_count: int, +) -> StageVisualizationResult: + if response_mode != "activation": + raise ValueError("Stage view currently supports activation maps only.") + model, _cfgs = load_checkpoint_model(spec) + sequence, pav, base_image = load_sample_inputs(spec, sequence_id, label, view, frame_count) + _final_feature, score_tuple, predicted_label = compute_classification_scores(model, sequence, pav) + pooled_stages = extract_temporally_pooled_stages(model, sequence) + stage_overlays: dict[str, FloatArray] = {} + for stage_name, 
stage_tensor in pooled_stages.items(): + response_map = activation_attention_map(stage_tensor[0], method=method, power=power) + response_map = F.interpolate( + response_map.unsqueeze(0).unsqueeze(0), + size=sequence.shape[-2:], + mode="bilinear", + align_corners=False, + )[0, 0] + response_np = normalize_image(response_map.detach().cpu().numpy().astype(np.float32)) + stage_overlays[stage_name] = overlay_response(base_image, response_np) + return StageVisualizationResult( + title=spec.title, + predicted_label=predicted_label, + class_scores=score_tuple, + input_image=base_image, + stage_overlays=stage_overlays, + ) + + +def build_diagnostic_results( + spec: ModelSpec, + sequence_id: str, + label: str, + view: str, + frame_count: int, + patch_size: int, + stride: int, +) -> ModelDiagnosticResult: + model, _cfgs = load_checkpoint_model(spec) + sequence, pav, base_image = load_sample_inputs(spec, sequence_id, label, view, frame_count) + occlusion_map, score_tuple, predicted_label = compute_occlusion_sensitivity( + model=model, + sequence=sequence, + pav=pav, + patch_size=patch_size, + stride=stride, + ) + part_scores, pga_spatial, _scores, _predicted = compute_part_diagnostics(model, sequence, pav) + return ModelDiagnosticResult( + title=spec.title, + predicted_label=predicted_label, + class_scores=score_tuple, + input_image=base_image, + occlusion_map=occlusion_map, + overlay_image=overlay_response(base_image, occlusion_map), + part_scores=part_scores, + pga_spatial=pga_spatial, + ) + + +@click.command() +@click.option("--sequence-id", default="00490", show_default=True) +@click.option("--label", default="positive", show_default=True) +@click.option("--view", default="000_180", show_default=True) +@click.option("--frame-count", default=30, show_default=True, type=int) +@click.option( + "--response-mode", + type=click.Choice(["activation", "cam"]), + default="activation", + show_default=True, +) +@click.option( + "--method", + type=click.Choice(["sum", "sum_p", 
"max_p"]), + default="sum", + show_default=True, +) +@click.option("--power", default=2.0, show_default=True, type=float) +@click.option( + "--view-mode", + type=click.Choice(["final", "stages", "diagnostics"]), + default="final", + show_default=True, +) +@click.option("--occlusion-patch-size", default=8, show_default=True, type=int) +@click.option("--occlusion-stride", default=8, show_default=True, type=int) +@click.option( + "--output-path", + default=str(REPO_ROOT / "research/feature_response_heatmaps/00490_positive_000_180_response_heatmaps.png"), + show_default=True, + type=click.Path(path_type=Path), +) +def main( + sequence_id: str, + label: str, + view: str, + frame_count: int, + response_mode: str, + method: str, + power: float, + view_mode: str, + occlusion_patch_size: int, + occlusion_stride: int, + output_path: Path, +) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + if view_mode == "stages": + results = [ + build_stage_results( + spec=spec, + response_mode=response_mode, + method=method, + power=power, + sequence_id=sequence_id, + label=label, + view=view, + frame_count=frame_count, + ) + for spec in default_model_specs() + ] + render_stage_visualization(results, output_path) + elif view_mode == "diagnostics": + results = [ + build_diagnostic_results( + spec=spec, + sequence_id=sequence_id, + label=label, + view=view, + frame_count=frame_count, + patch_size=occlusion_patch_size, + stride=occlusion_stride, + ) + for spec in default_model_specs() + ] + render_diagnostic_visualization(results, output_path) + else: + results_final: list[VisualizationResult] = [] + for spec in default_model_specs(): + model, _cfgs = load_checkpoint_model(spec) + sequence, pav, base_image = load_sample_inputs(spec, sequence_id, label, view, frame_count) + response_map, scores, predicted_label = compute_response_map( + model, + sequence, + pav, + response_mode=response_mode, + method=method, + power=power, + ) + results_final.append( + VisualizationResult( 
+            title=spec.title,
+            predicted_label=predicted_label,
+            class_scores=scores,
+            input_image=base_image,
+            response_map=response_map,
+            overlay_image=overlay_response(base_image, response_map),
+        )
+    )
+    render_visualization(results_final, output_path)
+    click.echo(str(output_path))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/visualize_sharedalign_overlay.py b/scripts/visualize_sharedalign_overlay.py
new file mode 100644
index 0000000..2811562
--- /dev/null
+++ b/scripts/visualize_sharedalign_overlay.py
@@ -0,0 +1,173 @@
+from __future__ import annotations
+
+import pickle
+import sys
+from pathlib import Path
+
+import cv2
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import numpy.typing as npt
+import yaml
+
+matplotlib.use("Agg")
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+if __package__ in {None, ""}:
+    sys.path.append(str(REPO_ROOT))
+
+from datasets import pretreatment_heatmap as prep
+
+FloatArray = npt.NDArray[np.float32]
+
+SAMPLE_ID = "00000"
+LABEL = "positive"
+VIEW = "000_180"
+RESEARCH_ROOT = REPO_ROOT / "research/sigma_comparison"
+POSE_PATH = Path(
+    f"/mnt/public/data/Scoliosis1K/Scoliosis1K-pose-pkl/{SAMPLE_ID}/{LABEL}/{VIEW}/{VIEW}.pkl"
+)
+SIL_PATH = Path(
+    f"/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl/{SAMPLE_ID}/{LABEL}/{VIEW}/{VIEW}.pkl"
+)
+CFG_PATH = REPO_ROOT / "configs/drf/pretreatment_heatmap_drf_sigma15_joint8.yaml"
+
+LIMB_RGB = np.asarray([0.18, 0.85, 0.96], dtype=np.float32)
+JOINT_RGB = np.asarray([1.00, 0.45, 0.12], dtype=np.float32)
+SIL_RGB = np.asarray([0.85, 0.97, 0.78], dtype=np.float32)
+LIMB_ALPHA = 0.8
+JOINT_ALPHA = 0.8
+SIL_ALPHA = 0.3
+
+
+def _normalize(array: npt.NDArray[np.generic]) -> FloatArray:
+    image = np.asarray(array, dtype=np.float32)
+    if image.max() > 1.0:
+        image = image / 255.0
+    return image
+
+
+def _load_pickle(path: Path) -> FloatArray:
+    with path.open("rb") as handle:
+        return _normalize(pickle.load(handle))
+
+
+def _load_mixed_sharedalign() -> FloatArray:
+    with CFG_PATH.open("r", encoding="utf-8") as handle:
+        cfg = yaml.safe_load(handle)
+    cfg = prep.replace_variables(cfg, cfg)
+    transform = prep.GenerateHeatmapTransform(
+        coco18tococo17_args=cfg["coco18tococo17_args"],
+        padkeypoints_args=cfg["padkeypoints_args"],
+        norm_args=cfg["norm_args"],
+        heatmap_generator_args=cfg["heatmap_generator_args"],
+        align_args=cfg["align_args"],
+        reduction="upstream",
+        sigma_limb=float(cfg["sigma_limb"]),
+        sigma_joint=float(cfg["sigma_joint"]),
+    )
+    with POSE_PATH.open("rb") as handle:
+        pose = np.asarray(pickle.load(handle), dtype=np.float32)
+    return _normalize(transform(pose))
+
+
+def _overlay_frame(silhouette: FloatArray, limb: FloatArray, joint: FloatArray) -> FloatArray:
+    canvas = np.zeros((silhouette.shape[0], silhouette.shape[1], 4), dtype=np.float32)
+
+    sil_mask = np.clip(silhouette, 0.0, 1.0)
+    limb_mask = np.clip(limb, 0.0, 1.0)
+    joint_mask = np.clip(joint, 0.0, 1.0)
+
+    canvas[..., :3] += sil_mask[..., None] * SIL_RGB[None, None, :] * SIL_ALPHA
+    canvas[..., 3] = np.maximum(canvas[..., 3], sil_mask * SIL_ALPHA)
+
+    canvas[..., :3] += limb_mask[..., None] * LIMB_RGB[None, None, :] * LIMB_ALPHA
+    canvas[..., 3] = np.maximum(canvas[..., 3], limb_mask * LIMB_ALPHA)
+
+    canvas[..., :3] += joint_mask[..., None] * JOINT_RGB[None, None, :] * JOINT_ALPHA
+    canvas[..., 3] = np.maximum(canvas[..., 3], joint_mask * JOINT_ALPHA)
+
+    canvas[..., :3] = np.clip(canvas[..., :3], 0.0, 1.0)
+    canvas[..., 3] = np.clip(canvas[..., 3], 0.0, 1.0)
+    return canvas
+
+
+def _bbox(mask: FloatArray, threshold: float = 0.05) -> tuple[int, int, int, int] | None:
+    ys, xs = np.where(mask > threshold)
+    if ys.size == 0:
+        return None
+    return int(ys.min()), int(ys.max()), int(xs.min()), int(xs.max())
+
+
+def _align_silhouette_to_heatmap(silhouette: FloatArray, heatmap: FloatArray) -> FloatArray:
+    sil_box = _bbox(silhouette)
+    heat_box = _bbox(heatmap)
+    if sil_box is None or heat_box is None:
+        return silhouette
+
+    sy0, sy1, sx0, sx1 = sil_box
+    hy0, hy1, hx0, hx1 = heat_box
+    sil_crop = silhouette[sy0 : sy1 + 1, sx0 : sx1 + 1]
+    sil_h = max(sy1 - sy0 + 1, 1)
+    sil_w = max(sx1 - sx0 + 1, 1)
+    target_h = max(hy1 - hy0 + 1, 1)
+    scale = target_h / sil_h
+    target_w = max(int(round(sil_w * scale)), 1)
+    resized = cv2.resize(sil_crop, (target_w, target_h), interpolation=cv2.INTER_AREA)
+    canvas = np.zeros_like(silhouette)
+    heat_center_x = (hx0 + hx1) / 2.0
+    x0 = int(round(heat_center_x - target_w / 2.0))
+    x1 = x0 + target_w
+
+    src_x0 = max(0, -x0)
+    src_x1 = target_w - max(0, x1 - silhouette.shape[1])
+    dst_x0 = max(0, x0)
+    dst_x1 = min(silhouette.shape[1], x1)
+
+    if src_x0 < src_x1 and dst_x0 < dst_x1:
+        canvas[hy0 : hy1 + 1, dst_x0:dst_x1] = resized[:, src_x0:src_x1]
+    return canvas
+
+
+def main() -> None:
+    RESEARCH_ROOT.mkdir(parents=True, exist_ok=True)
+
+    silhouette = _load_pickle(SIL_PATH)
+    mixed = _load_mixed_sharedalign()
+    frames = np.linspace(0, silhouette.shape[0] - 1, num=6, dtype=int)
+
+    raw_path = RESEARCH_ROOT / f"{SAMPLE_ID}_{LABEL}_{VIEW}_overlay_sharedalign_preview.png"
+    aligned_path = RESEARCH_ROOT / f"{SAMPLE_ID}_{LABEL}_{VIEW}_overlay_sharedalign_aligned_preview.png"
+
+    for output_path, align_silhouette in ((raw_path, False), (aligned_path, True)):
+        fig, axes = plt.subplots(1, len(frames), figsize=(2.6 * len(frames), 3.6), dpi=220)
+        for axis, frame_idx in zip(axes, frames, strict=True):
+            limb = mixed[frame_idx, 0]
+            joint = mixed[frame_idx, 1]
+            sil_frame = silhouette[frame_idx]
+            if align_silhouette:
+                sil_frame = _align_silhouette_to_heatmap(sil_frame, np.maximum(limb, joint))
+            overlay = _overlay_frame(sil_frame, limb, joint)
+            axis.set_facecolor("black")
+            axis.imshow(overlay)
+            axis.set_xticks([])
+            axis.set_yticks([])
+            axis.set_title(f"Frame {frame_idx}", fontsize=9)
+
+        mode = "height+center aligned silhouette" if align_silhouette else "raw silhouette"
+        fig.suptitle(
+            (
+                f"{SAMPLE_ID}/{LABEL}/{VIEW}: {mode}, "
+                f"silhouette ({SIL_ALPHA}), limb ({LIMB_ALPHA}), joint ({JOINT_ALPHA})"
+            ),
+            fontsize=12,
+        )
+        fig.tight_layout()
+        fig.savefig(output_path, bbox_inches="tight", facecolor="black")
+        plt.close(fig)
+        print(output_path)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/visualize_sigma_samples.py b/scripts/visualize_sigma_samples.py
new file mode 100644
index 0000000..74b16b1
--- /dev/null
+++ b/scripts/visualize_sigma_samples.py
@@ -0,0 +1,199 @@
+from __future__ import annotations
+
+import pickle
+from dataclasses import dataclass
+from pathlib import Path
+
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import numpy.typing as npt
+
+matplotlib.use("Agg")
+
+REPO_ROOT = Path(__file__).resolve().parents[1]
+RESEARCH_ROOT = REPO_ROOT / "research/sigma_comparison"
+SIGMA8_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K_sigma_8.0/pkl")
+SIGMA15_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15")
+MIXED_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15-joint8-fixed")
+SIL_ROOT = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl")
+
+FloatArray = npt.NDArray[np.float32]
+
+
+@dataclass(frozen=True)
+class SampleSpec:
+    sequence_id: str
+    label: str
+    view: str
+
+
+def _normalize_image(array: npt.NDArray[np.generic]) -> FloatArray:
+    image = np.asarray(array, dtype=np.float32)
+    if image.max() > 1.0:
+        image = image / 255.0
+    return image
+
+
+def _load_pickle(path: Path) -> FloatArray:
+    with path.open("rb") as handle:
+        return _normalize_image(pickle.load(handle))
+
+
+def _load_sigma8(spec: SampleSpec) -> FloatArray:
+    return _load_pickle(SIGMA8_ROOT / spec.sequence_id / spec.label / spec.view / f"{spec.view}.pkl")
+
+
+def _load_sigma15(spec: SampleSpec) -> FloatArray:
+    return _load_pickle(SIGMA15_ROOT / spec.sequence_id / spec.label / spec.view / "0_heatmap.pkl")
+
+
+def _load_silhouette(spec: SampleSpec) -> FloatArray:
+    return _load_pickle(SIL_ROOT / spec.sequence_id / spec.label / spec.view / f"{spec.view}.pkl")
+
+
+def _load_mixed(spec: SampleSpec) -> FloatArray:
+    return _load_pickle(MIXED_ROOT / spec.sequence_id / spec.label / spec.view / "0_heatmap.pkl")
+
+
+def _frame_indices(num_frames: int, count: int = 6) -> npt.NDArray[np.int_]:
+    return np.linspace(0, num_frames - 1, num=count, dtype=int)
+
+
+def _combined_map(heatmaps: FloatArray) -> FloatArray:
+    return np.maximum(heatmaps[:, 0], heatmaps[:, 1])
+
+
+def _save_combined_figure(
+    spec: SampleSpec,
+    silhouette: FloatArray,
+    sigma8: FloatArray,
+    sigma15: FloatArray,
+    mixed: FloatArray,
+) -> Path:
+    frames = _frame_indices(silhouette.shape[0], count=6)
+    fig, axes = plt.subplots(4, len(frames), figsize=(2.4 * len(frames), 8.8), dpi=200)
+    row_titles = (
+        "Silhouette",
+        "Sigma 8 Combined",
+        "Sigma 1.5 Combined",
+        "Limb 1.5 / Joint 8 Combined",
+    )
+    images = (silhouette, _combined_map(sigma8), _combined_map(sigma15), _combined_map(mixed))
+
+    for row_idx, (title, row_images) in enumerate(zip(row_titles, images, strict=True)):
+        for col_idx, frame_idx in enumerate(frames):
+            ax = axes[row_idx, col_idx]
+            ax.imshow(row_images[frame_idx], cmap="gray", vmin=0.0, vmax=1.0)
+            ax.set_xticks([])
+            ax.set_yticks([])
+            if row_idx == 0:
+                ax.set_title(f"Frame {frame_idx}", fontsize=9)
+            if col_idx == 0:
+                ax.set_ylabel(title, fontsize=10)
+
+    fig.suptitle(
+        f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: silhouette vs sigma heatmaps",
+        fontsize=12,
+    )
+    fig.tight_layout()
+    output_path = RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_combined.png"
+    fig.savefig(output_path, bbox_inches="tight")
+    plt.close(fig)
+    return output_path
+
+
+def _save_channel_figure(spec: SampleSpec, sigma8: FloatArray, sigma15: FloatArray) -> Path:
+    frames = _frame_indices(sigma8.shape[0], count=4)
+    fig, axes = plt.subplots(4, len(frames), figsize=(2.6 * len(frames), 8.4), dpi=200)
+    row_titles = (
+        "Sigma 8 Limb",
+        "Sigma 8 Joint",
+        "Sigma 1.5 Limb",
+        "Sigma 1.5 Joint",
+    )
+    images = (sigma8[:, 0], sigma8[:, 1], sigma15[:, 0], sigma15[:, 1])
+
+    for row_idx, (title, row_images) in enumerate(zip(row_titles, images, strict=True)):
+        vmax = float(np.max(row_images)) if np.max(row_images) > 0 else 1.0
+        for col_idx, frame_idx in enumerate(frames):
+            ax = axes[row_idx, col_idx]
+            ax.imshow(row_images[frame_idx], cmap="magma", vmin=0.0, vmax=vmax)
+            ax.set_xticks([])
+            ax.set_yticks([])
+            if row_idx == 0:
+                ax.set_title(f"Frame {frame_idx}", fontsize=9)
+            if col_idx == 0:
+                ax.set_ylabel(title, fontsize=10)
+
+    fig.suptitle(
+        f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: per-channel sigma comparison",
+        fontsize=12,
+    )
+    fig.tight_layout()
+    output_path = RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_channels.png"
+    fig.savefig(output_path, bbox_inches="tight")
+    plt.close(fig)
+    return output_path
+
+
+def _save_mixed_channel_figure(
+    spec: SampleSpec,
+    sigma8: FloatArray,
+    sigma15: FloatArray,
+    mixed: FloatArray,
+) -> Path:
+    frames = _frame_indices(sigma8.shape[0], count=4)
+    fig, axes = plt.subplots(6, len(frames), figsize=(2.6 * len(frames), 12.0), dpi=200)
+    row_titles = (
+        "Sigma 8 Limb",
+        "Sigma 8 Joint",
+        "Sigma 1.5 Limb",
+        "Sigma 1.5 Joint",
+        "Mixed Limb 1.5",
+        "Mixed Joint 8",
+    )
+    images = (sigma8[:, 0], sigma8[:, 1], sigma15[:, 0], sigma15[:, 1], mixed[:, 0], mixed[:, 1])
+
+    for row_idx, (title, row_images) in enumerate(zip(row_titles, images, strict=True)):
+        vmax = float(np.max(row_images)) if np.max(row_images) > 0 else 1.0
+        for col_idx, frame_idx in enumerate(frames):
+            ax = axes[row_idx, col_idx]
+            ax.imshow(row_images[frame_idx], cmap="magma", vmin=0.0, vmax=vmax)
+            ax.set_xticks([])
+            ax.set_yticks([])
+            if row_idx == 0:
+                ax.set_title(f"Frame {frame_idx}", fontsize=9)
+            if col_idx == 0:
+                ax.set_ylabel(title, fontsize=10)
+
+    fig.suptitle(
+        f"Sample {spec.sequence_id}/{spec.label}/{spec.view}: sigma and mixed per-channel comparison",
+        fontsize=12,
+    )
+    fig.tight_layout()
+    output_path = RESEARCH_ROOT / f"{spec.sequence_id}_{spec.label}_{spec.view}_channels_mixed.png"
+    fig.savefig(output_path, bbox_inches="tight")
+    plt.close(fig)
+    return output_path
+
+
+def main() -> None:
+    RESEARCH_ROOT.mkdir(parents=True, exist_ok=True)
+    sample = SampleSpec(sequence_id="00000", label="positive", view="000_180")
+    sigma8 = _load_sigma8(sample)
+    sigma15 = _load_sigma15(sample)
+    mixed = _load_mixed(sample)
+    silhouette = _load_silhouette(sample)
+
+    combined_path = _save_combined_figure(sample, silhouette, sigma8, sigma15, mixed)
+    channels_path = _save_channel_figure(sample, sigma8, sigma15)
+    mixed_channels_path = _save_mixed_channel_figure(sample, sigma8, sigma15, mixed)
+
+    print(combined_path)
+    print(channels_path)
+    print(mixed_channels_path)
+
+
+if __name__ == "__main__":
+    main()