from __future__ import annotations import pickle import sys import types from dataclasses import dataclass from pathlib import Path from typing import Any, Literal import click import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from jaxtyping import Float from numpy.typing import NDArray REPO_ROOT = Path(__file__).resolve().parents[1] OPENGAIT_ROOT = REPO_ROOT / "opengait" if str(OPENGAIT_ROOT) not in sys.path: sys.path.insert(0, str(OPENGAIT_ROOT)) if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from opengait.data.transform import get_transform from opengait.modeling import models from opengait.modeling.base_model import BaseModel from opengait.modeling.models.drf import DRF from opengait.utils.common import config_loader FloatArray = NDArray[np.float32] LABEL_NAMES = ("negative", "neutral", "positive") @dataclass(frozen=True) class ModelSpec: key: str title: str config_path: Path checkpoint_path: Path data_root: Path is_drf: bool @dataclass(frozen=True) class VisualizationResult: title: str predicted_label: str class_scores: tuple[float, float, float] input_image: FloatArray response_map: FloatArray overlay_image: FloatArray @dataclass(frozen=True) class StageVisualizationResult: title: str predicted_label: str class_scores: tuple[float, float, float] input_image: FloatArray stage_overlays: dict[str, FloatArray] @dataclass(frozen=True) class ForwardDiagnostics: feature_map: Float[torch.Tensor, "batch channels height width"] pooled: Float[torch.Tensor, "batch channels parts"] embeddings: Float[torch.Tensor, "batch channels parts"] logits: Float[torch.Tensor, "batch classes parts"] class_scores: Float[torch.Tensor, "batch classes"] pga_spatial: Float[torch.Tensor, "batch parts"] | None @dataclass(frozen=True) class ModelDiagnosticResult: title: str predicted_label: str class_scores: tuple[float, float, float] input_image: FloatArray occlusion_map: FloatArray overlay_image: FloatArray part_scores: FloatArray pga_spatial: FloatArray | None def apply_transforms( sequence: np.ndarray, transform_cfg: list[dict[str, Any]] | dict[str, Any] | None, ) -> np.ndarray: transform = get_transform(transform_cfg) if isinstance(transform, list): result = sequence for item in transform: result = item(result) return np.asarray(result) return np.asarray(transform(sequence)) def build_minimal_model(cfgs: dict[str, Any]) -> nn.Module: model_name = str(cfgs["model_cfg"]["model"]) model_cls = getattr(models, model_name) model = model_cls.__new__(model_cls) nn.Module.__init__(model) model.get_backbone = types.MethodType(BaseModel.get_backbone, model) model.build_network(cfgs["model_cfg"]) return model def load_checkpoint_model(spec: ModelSpec) -> tuple[nn.Module, dict[str, Any]]: cfgs = config_loader(str(spec.config_path)) model = build_minimal_model(cfgs) checkpoint = torch.load( spec.checkpoint_path, map_location="cpu", weights_only=False, ) model.load_state_dict(checkpoint["model"], strict=False) model.eval() return model, cfgs def load_pickle(path: Path) -> np.ndarray: with path.open("rb") as handle: return np.asarray(pickle.load(handle)) def load_sample_inputs( spec: ModelSpec, sequence_id: str, label: str, view: str, frame_count: int, ) -> tuple[torch.Tensor, torch.Tensor | None, FloatArray]: if spec.is_drf: heatmap_path = spec.data_root / sequence_id / label / view / "0_heatmap.pkl" pav_path = spec.data_root / sequence_id / label / view / "1_pav.pkl" heatmaps = load_pickle(heatmap_path).astype(np.float32) pav = load_pickle(pav_path).astype(np.float32) return prepare_heatmap_inputs(heatmaps, pav, spec.config_path, frame_count) if "skeleton" in spec.key: heatmap_path = spec.data_root / sequence_id / label / view / "0_heatmap.pkl" heatmaps = load_pickle(heatmap_path).astype(np.float32) return prepare_heatmap_inputs(heatmaps, None, spec.config_path, frame_count) silhouette_path = spec.data_root / sequence_id / label / view / f"{view}.pkl" silhouettes = load_pickle(silhouette_path).astype(np.float32) return prepare_silhouette_inputs(silhouettes, spec.config_path, frame_count) def prepare_silhouette_inputs( silhouettes: FloatArray, config_path: Path, frame_count: int, ) -> tuple[torch.Tensor, None, FloatArray]: cfgs = config_loader(str(config_path)) transformed = apply_transforms(silhouettes, cfgs["evaluator_cfg"]["transform"]).astype(np.float32) indices = np.linspace(0, transformed.shape[0] - 1, frame_count, dtype=int) sampled = transformed[indices] input_tensor = torch.from_numpy(sampled).unsqueeze(0).unsqueeze(0).float() projection = normalize_image(sampled.max(axis=0)) return input_tensor, None, projection def prepare_heatmap_inputs( heatmaps: FloatArray, pav: FloatArray | None, config_path: Path, frame_count: int, ) -> tuple[torch.Tensor, torch.Tensor | None, FloatArray]: cfgs = config_loader(str(config_path)) transform_cfg = cfgs["evaluator_cfg"]["transform"] heatmap_transform = transform_cfg[0] if isinstance(transform_cfg, list) else transform_cfg transformed = apply_transforms(heatmaps, heatmap_transform).astype(np.float32) indices = np.linspace(0, transformed.shape[0] - 1, frame_count, dtype=int) sampled = transformed[indices] input_tensor = torch.from_numpy(sampled).permute(1, 0, 2, 3).unsqueeze(0).float() combined = np.maximum(sampled.max(axis=0)[0], sampled.max(axis=0)[1]) projection = normalize_image(combined) pav_tensor = None if pav is not None: pav_tensor = torch.from_numpy(pav.mean(axis=0, keepdims=True)).float() return input_tensor, pav_tensor, projection AttentionMethod = Literal["sum", "sum_p", "max_p"] ResponseMode = Literal["activation", "cam"] def compute_classification_scores( model: nn.Module, sequence: torch.Tensor, pav: torch.Tensor | None, ) -> tuple[FloatArray, tuple[float, float, float], str]: diagnostics = forward_diagnostics(model, sequence, pav) class_scores = diagnostics.class_scores[0] pred_index = int(class_scores.argmax().detach().cpu()) score_tuple = tuple(float(value.detach().cpu()) for value in class_scores) return diagnostics.feature_map.detach()[0], score_tuple, LABEL_NAMES[pred_index] def forward_diagnostics( model: nn.Module, sequence: torch.Tensor, pav: torch.Tensor | None, ) -> ForwardDiagnostics: outputs = model.Backbone(sequence) feature_map = model.TP(outputs, None, options={"dim": 2})[0] pooled = model.HPP(feature_map) embeddings = model.FCs(pooled) pga_spatial: Float[torch.Tensor, "batch parts"] | None = None if pav is not None and isinstance(model, DRF): pav_flat = pav.flatten(1) channel_att = model.PGA.channel_att(pav_flat).unsqueeze(-1) spatial_att = model.PGA.spatial_att(pav_flat.unsqueeze(-1)).transpose(1, 2) pga_spatial = spatial_att.squeeze(1) embeddings = embeddings * channel_att * spatial_att _bn_embeddings, logits = model.BNNecks(embeddings) class_scores = logits.mean(-1) return ForwardDiagnostics( feature_map=feature_map, pooled=pooled, embeddings=embeddings, logits=logits, class_scores=class_scores, pga_spatial=pga_spatial, ) def compute_response_map( model: nn.Module, sequence: torch.Tensor, pav: torch.Tensor | None, response_mode: ResponseMode, method: AttentionMethod, power: float, ) -> tuple[FloatArray, tuple[float, float, float], str]: activation, score_tuple, predicted_label = compute_classification_scores(model, sequence, pav) if response_mode == "cam": response_map = adapted_cam_map( model=model, sequence=sequence, pav=pav, predicted_label=predicted_label, ) else: response_map = activation_attention_map(activation, method=method, power=power) response_map = response_map.unsqueeze(0).unsqueeze(0) response_map = F.interpolate( response_map, size=sequence.shape[-2:], mode="bilinear", align_corners=False, )[0, 0] response_np = response_map.detach().cpu().numpy().astype(np.float32) response_np = normalize_image(response_np) return response_np, score_tuple, predicted_label def activation_attention_map( activation: Float[torch.Tensor, "channels height width"], method: AttentionMethod, power: float, ) -> Float[torch.Tensor, "height width"]: abs_activation = activation.abs() if method == "sum": return abs_activation.sum(dim=0) powered = abs_activation.pow(power) if method == "sum_p": return powered.sum(dim=0) if method == "max_p": return powered.max(dim=0).values raise ValueError(f"Unsupported attention method: {method}") def adapted_cam_map( model: nn.Module, sequence: torch.Tensor, pav: torch.Tensor | None, predicted_label: str, ) -> Float[torch.Tensor, "height width"]: outputs = model.Backbone(sequence) feature_map = model.TP(outputs, None, options={"dim": 2})[0] # [1, c_in, h, w] pooled = model.HPP(feature_map) # [1, c_in, p] pre_attention = model.FCs(pooled) # [1, c_out, p] attention = torch.ones_like(pre_attention) if pav is not None and isinstance(model, DRF): pav_flat = pav.flatten(1) channel_att = model.PGA.channel_att(pav_flat).unsqueeze(-1) spatial_att = model.PGA.spatial_att(pav_flat.unsqueeze(-1)).transpose(1, 2) attention = channel_att * spatial_att embeddings = pre_attention * attention bn_scales = bn_feature_scales(model.BNNecks, embeddings.shape[1], embeddings.shape[2], embeddings.device) classifier_weights = F.normalize(model.BNNecks.fc_bin.detach(), dim=1) class_index = LABEL_NAMES.index(predicted_label) class_weights = classifier_weights[:, :, class_index] # [p, c_out] normalized_embeddings = batchnorm_forward_eval(model.BNNecks, embeddings)[0].transpose(0, 1) # [p, c_out] part_norms = normalized_embeddings.norm(dim=-1, keepdim=True).clamp_min(1e-6) effective_post_bn = class_weights / part_norms effective_pre_bn = effective_post_bn * bn_scales fc_weights = model.FCs.fc_bin.detach() # [p, c_in, c_out] attention_weights = attention.detach()[0].transpose(0, 1) # [p, c_out] effective_channel_weights = torch.einsum( "pkj,pj->pk", fc_weights, attention_weights * effective_pre_bn, ) # [p, c_in] feature = feature_map.detach()[0] # [c_in, h, w] parts_num = effective_channel_weights.shape[0] height = feature.shape[1] split_size = max(height // parts_num, 1) cam = torch.zeros((height, feature.shape[2]), device=feature.device, dtype=feature.dtype) for part_index in range(parts_num): row_start = part_index * split_size row_end = height if part_index == parts_num - 1 else min((part_index + 1) * split_size, height) if row_start >= height: break contribution = (feature[:, row_start:row_end, :] * effective_channel_weights[part_index][:, None, None]).sum(dim=0) cam[row_start:row_end, :] = contribution return F.relu(cam) def batchnorm_forward_eval( bn_necks: nn.Module, embeddings: Float[torch.Tensor, "batch channels parts"], ) -> Float[torch.Tensor, "batch channels parts"]: batch, channels, parts = embeddings.shape if not getattr(bn_necks, "parallel_BN1d", True): raise ValueError("Adapted CAM currently expects parallel_BN1d=True.") bn = bn_necks.bn1d flattened = embeddings.reshape(batch, -1) normalized = torch.nn.functional.batch_norm( flattened, bn.running_mean, bn.running_var, bn.weight, bn.bias, training=False, momentum=0.0, eps=bn.eps, ) return normalized.view(batch, channels, parts) def bn_feature_scales( bn_necks: nn.Module, channels: int, parts: int, device: torch.device, ) -> Float[torch.Tensor, "parts channels"]: if not getattr(bn_necks, "parallel_BN1d", True): raise ValueError("Adapted CAM currently expects parallel_BN1d=True.") bn = bn_necks.bn1d running_var = bn.running_var.detach().to(device).view(channels, parts).transpose(0, 1) base_scale = torch.rsqrt(running_var + bn.eps) if bn.weight is None: return base_scale gamma = bn.weight.detach().to(device).view(channels, parts).transpose(0, 1) return base_scale * gamma def normalize_image(image: np.ndarray) -> FloatArray: image = image.astype(np.float32) min_value = float(image.min()) max_value = float(image.max()) if max_value <= min_value: return np.zeros_like(image, dtype=np.float32) return (image - min_value) / (max_value - min_value) def overlay_response(base_image: FloatArray, response_map: FloatArray) -> FloatArray: cmap = plt.get_cmap("jet") heat = cmap(response_map)[..., :3].astype(np.float32) base_rgb = np.repeat(base_image[..., None], 3, axis=-1) overlay = base_rgb * 0.45 + heat * 0.55 return np.clip(overlay, 0.0, 1.0) def compute_occlusion_sensitivity( model: nn.Module, sequence: torch.Tensor, pav: torch.Tensor | None, patch_size: int, stride: int, ) -> tuple[FloatArray, tuple[float, float, float], str]: with torch.no_grad(): baseline = forward_diagnostics(model, sequence, pav) class_scores = baseline.class_scores[0] pred_index = int(class_scores.argmax().detach().cpu()) baseline_score = float(class_scores[pred_index].detach().cpu()) score_tuple = tuple(float(value.detach().cpu()) for value in class_scores) _, _, _, height, width = sequence.shape sensitivity = np.zeros((height, width), dtype=np.float32) counts = np.zeros((height, width), dtype=np.float32) row_starts = list(range(0, max(height - patch_size + 1, 1), stride)) col_starts = list(range(0, max(width - patch_size + 1, 1), stride)) last_row = max(height - patch_size, 0) last_col = max(width - patch_size, 0) if row_starts[-1] != last_row: row_starts.append(last_row) if col_starts[-1] != last_col: col_starts.append(last_col) for top in row_starts: for left in col_starts: bottom = min(top + patch_size, height) right = min(left + patch_size, width) masked = sequence.clone() masked[..., top:bottom, left:right] = 0.0 masked_scores = forward_diagnostics(model, masked, pav).class_scores[0] score_drop = max(0.0, baseline_score - float(masked_scores[pred_index].detach().cpu())) sensitivity[top:bottom, left:right] += score_drop counts[top:bottom, left:right] += 1.0 counts = np.maximum(counts, 1.0) sensitivity /= counts return normalize_image(sensitivity), score_tuple, LABEL_NAMES[pred_index] def compute_part_diagnostics( model: nn.Module, sequence: torch.Tensor, pav: torch.Tensor | None, ) -> tuple[FloatArray, FloatArray | None, tuple[float, float, float], str]: with torch.no_grad(): diagnostics = forward_diagnostics(model, sequence, pav) class_scores = diagnostics.class_scores[0] pred_index = int(class_scores.argmax().detach().cpu()) part_scores = diagnostics.logits[0, pred_index].detach().cpu().numpy().astype(np.float32) pga_spatial = None if diagnostics.pga_spatial is not None: pga_spatial = diagnostics.pga_spatial[0].detach().cpu().numpy().astype(np.float32) score_tuple = tuple(float(value.detach().cpu()) for value in class_scores) return part_scores, pga_spatial, score_tuple, LABEL_NAMES[pred_index] def extract_temporally_pooled_stages( model: nn.Module, sequence: torch.Tensor, ) -> dict[str, torch.Tensor]: if not hasattr(model.Backbone, "forward_block"): raise ValueError("Stage visualization requires SetBlockWrapper-style backbones.") backbone = model.Backbone.forward_block batch, channels, frames, height, width = sequence.shape x = sequence.transpose(1, 2).reshape(-1, channels, height, width) stage_outputs: dict[str, torch.Tensor] = {} x = backbone.conv1(x) x = backbone.bn1(x) x = backbone.relu(x) if getattr(backbone, "maxpool_flag", False): x = backbone.maxpool(x) stage_outputs["conv1"] = x x = backbone.layer1(x) stage_outputs["layer1"] = x x = backbone.layer2(x) stage_outputs["layer2"] = x x = backbone.layer3(x) stage_outputs["layer3"] = x x = backbone.layer4(x) stage_outputs["layer4"] = x pooled_outputs: dict[str, torch.Tensor] = {} for stage_name, stage_output in stage_outputs.items(): reshaped = stage_output.reshape(batch, frames, *stage_output.shape[1:]).transpose(1, 2).contiguous() pooled_outputs[stage_name] = model.TP(reshaped, None, options={"dim": 2})[0] return pooled_outputs def render_visualization(results: list[VisualizationResult], output_path: Path) -> None: fig, axes = plt.subplots(len(results), 3, figsize=(10.5, 3.2 * len(results)), constrained_layout=True) if len(results) == 1: axes = np.expand_dims(axes, axis=0) for row, result in enumerate(results): input_axis, heat_axis, overlay_axis = axes[row] input_axis.imshow(result.input_image, cmap="gray", vmin=0.0, vmax=1.0) input_axis.set_title(f"{result.title} input") heat_axis.imshow(result.response_map, cmap="jet", vmin=0.0, vmax=1.0) heat_axis.set_title("response") overlay_axis.imshow(result.overlay_image) overlay_axis.set_title( f"overlay | pred={result.predicted_label}\n" f"scores={tuple(round(score, 3) for score in result.class_scores)}" ) for axis in (input_axis, heat_axis, overlay_axis): axis.set_xticks([]) axis.set_yticks([]) fig.savefig(output_path, dpi=180) plt.close(fig) def render_stage_visualization(results: list[StageVisualizationResult], output_path: Path) -> None: stage_names = ["input", "conv1", "layer1", "layer2", "layer3", "layer4"] fig, axes = plt.subplots( len(results), len(stage_names), figsize=(2.2 * len(stage_names), 3.0 * len(results)), constrained_layout=True, ) if len(results) == 1: axes = np.expand_dims(axes, axis=0) for row, result in enumerate(results): for col, stage_name in enumerate(stage_names): axis = axes[row, col] if stage_name == "input": axis.imshow(result.input_image, cmap="gray", vmin=0.0, vmax=1.0) axis.set_title( f"{result.title}\npred={result.predicted_label}", fontsize=10, ) else: axis.imshow(result.stage_overlays[stage_name]) axis.set_title(stage_name, fontsize=10) axis.set_xticks([]) axis.set_yticks([]) axes[row, 0].set_ylabel( f"scores={tuple(round(score, 3) for score in result.class_scores)}", fontsize=9, rotation=90, ) fig.savefig(output_path, dpi=180) plt.close(fig) def render_diagnostic_visualization(results: list[ModelDiagnosticResult], output_path: Path) -> None: fig, axes = plt.subplots(len(results), 3, figsize=(11.5, 3.3 * len(results)), constrained_layout=True) if len(results) == 1: axes = np.expand_dims(axes, axis=0) for row, result in enumerate(results): input_axis, occlusion_axis, part_axis = axes[row] input_axis.imshow(result.input_image, cmap="gray", vmin=0.0, vmax=1.0) input_axis.set_title(f"{result.title} input") input_axis.set_xticks([]) input_axis.set_yticks([]) occlusion_axis.imshow(result.overlay_image) occlusion_axis.set_title( f"occlusion | pred={result.predicted_label}\n" f"scores={tuple(round(score, 3) for score in result.class_scores)}" ) occlusion_axis.set_xticks([]) occlusion_axis.set_yticks([]) parts = np.arange(result.part_scores.shape[0], dtype=np.int32) part_axis.bar(parts, result.part_scores, color="#4477AA", alpha=0.85, width=0.8) part_axis.set_title("pred-class part logits") part_axis.set_xlabel("HPP part") part_axis.set_ylabel("logit") part_axis.set_xticks(parts) if result.pga_spatial is not None: part_axis_2 = part_axis.twinx() part_axis_2.plot(parts, result.pga_spatial, color="#CC6677", marker="o", linewidth=1.5) part_axis_2.set_ylabel("PGA spatial") part_axis_2.set_ylim(0.0, max(1.0, float(result.pga_spatial.max()) * 1.05)) fig.savefig(output_path, dpi=180) plt.close(fig) def default_model_specs() -> list[ModelSpec]: return [ ModelSpec( key="silhouette", title="ScoNet-MT silhouette", config_path=REPO_ROOT / "configs/sconet/sconet_scoliosis1k_local_eval_2gpu_better_112.yaml", checkpoint_path=REPO_ROOT / "ckpt/ScoNet-20000-better.pt", data_root=Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl"), is_drf=False, ), ModelSpec( key="skeleton", title="ScoNet-MT-ske", config_path=REPO_ROOT / "configs/sconet/sconet_scoliosis1k_skeleton_118_sigma15_joint8_2gpu_bs12x8.yaml", checkpoint_path=REPO_ROOT / "output/Scoliosis1K/ScoNet/ScoNet_skeleton_118_sigma15_joint8_sharedalign_2gpu_bs12x8/checkpoints/ScoNet_skeleton_118_sigma15_joint8_sharedalign_2gpu_bs12x8-20000.pt", data_root=Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign"), is_drf=False, ), ModelSpec( key="drf", title="DRF", config_path=REPO_ROOT / "configs/drf/drf_scoliosis1k_eval_1gpu.yaml", checkpoint_path=REPO_ROOT / "output/Scoliosis1K/DRF/DRF/checkpoints/DRF-20000.pt", data_root=Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118"), is_drf=True, ), ] def build_stage_results( spec: ModelSpec, response_mode: ResponseMode, method: AttentionMethod, power: float, sequence_id: str, label: str, view: str, frame_count: int, ) -> StageVisualizationResult: if response_mode != "activation": raise ValueError("Stage view currently supports activation maps only.") model, _cfgs = load_checkpoint_model(spec) sequence, pav, base_image = load_sample_inputs(spec, sequence_id, label, view, frame_count) _final_feature, score_tuple, predicted_label = compute_classification_scores(model, sequence, pav) pooled_stages = extract_temporally_pooled_stages(model, sequence) stage_overlays: dict[str, FloatArray] = {} for stage_name, stage_tensor in pooled_stages.items(): response_map = activation_attention_map(stage_tensor[0], method=method, power=power) response_map = F.interpolate( response_map.unsqueeze(0).unsqueeze(0), size=sequence.shape[-2:], mode="bilinear", align_corners=False, )[0, 0] response_np = normalize_image(response_map.detach().cpu().numpy().astype(np.float32)) stage_overlays[stage_name] = overlay_response(base_image, response_np) return StageVisualizationResult( title=spec.title, predicted_label=predicted_label, class_scores=score_tuple, input_image=base_image, stage_overlays=stage_overlays, ) def build_diagnostic_results( spec: ModelSpec, sequence_id: str, label: str, view: str, frame_count: int, patch_size: int, stride: int, ) -> ModelDiagnosticResult: model, _cfgs = load_checkpoint_model(spec) sequence, pav, base_image = load_sample_inputs(spec, sequence_id, label, view, frame_count) occlusion_map, score_tuple, predicted_label = compute_occlusion_sensitivity( model=model, sequence=sequence, pav=pav, patch_size=patch_size, stride=stride, ) part_scores, pga_spatial, _scores, _predicted = compute_part_diagnostics(model, sequence, pav) return ModelDiagnosticResult( title=spec.title, predicted_label=predicted_label, class_scores=score_tuple, input_image=base_image, occlusion_map=occlusion_map, overlay_image=overlay_response(base_image, occlusion_map), part_scores=part_scores, pga_spatial=pga_spatial, ) @click.command() @click.option("--sequence-id", default="00490", show_default=True) @click.option("--label", default="positive", show_default=True) @click.option("--view", default="000_180", show_default=True) @click.option("--frame-count", default=30, show_default=True, type=int) @click.option( "--response-mode", type=click.Choice(["activation", "cam"]), default="activation", show_default=True, ) @click.option( "--method", type=click.Choice(["sum", "sum_p", "max_p"]), default="sum", show_default=True, ) @click.option("--power", default=2.0, show_default=True, type=float) @click.option( "--view-mode", type=click.Choice(["final", "stages", "diagnostics"]), default="final", show_default=True, ) @click.option("--occlusion-patch-size", default=8, show_default=True, type=int) @click.option("--occlusion-stride", default=8, show_default=True, type=int) @click.option( "--output-path", default=str(REPO_ROOT / "research/feature_response_heatmaps/00490_positive_000_180_response_heatmaps.png"), show_default=True, type=click.Path(path_type=Path), ) def main( sequence_id: str, label: str, view: str, frame_count: int, response_mode: str, method: str, power: float, view_mode: str, occlusion_patch_size: int, occlusion_stride: int, output_path: Path, ) -> None: output_path.parent.mkdir(parents=True, exist_ok=True) if view_mode == "stages": results = [ build_stage_results( spec=spec, response_mode=response_mode, method=method, power=power, sequence_id=sequence_id, label=label, view=view, frame_count=frame_count, ) for spec in default_model_specs() ] render_stage_visualization(results, output_path) elif view_mode == "diagnostics": results = [ build_diagnostic_results( spec=spec, sequence_id=sequence_id, label=label, view=view, frame_count=frame_count, patch_size=occlusion_patch_size, stride=occlusion_stride, ) for spec in default_model_specs() ] render_diagnostic_visualization(results, output_path) else: results_final: list[VisualizationResult] = [] for spec in default_model_specs(): model, _cfgs = load_checkpoint_model(spec) sequence, pav, base_image = load_sample_inputs(spec, sequence_id, label, view, frame_count) response_map, scores, predicted_label = compute_response_map( model, sequence, pav, response_mode=response_mode, method=method, power=power, ) results_final.append( VisualizationResult( title=spec.title, predicted_label=predicted_label, class_scores=scores, input_image=base_image, response_map=response_map, overlay_image=overlay_response(base_image, response_map), ) ) render_visualization(results_final, output_path) click.echo(str(output_path)) if __name__ == "__main__": main()