"""Visualize feature-response heatmaps for OpenGait models.

Source file: OpenGait/scripts/visualize_feature_response_heatmaps.py
(779 lines, 28 KiB, Python).
"""
from __future__ import annotations
import pickle
import sys
import types
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal
import click
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float
from numpy.typing import NDArray
# Resolve the repository layout relative to this script so it can be run
# from any working directory.
REPO_ROOT = Path(__file__).resolve().parents[1]
OPENGAIT_ROOT = REPO_ROOT / "opengait"
# Make the package root and the repo root importable (insert each only once).
if str(OPENGAIT_ROOT) not in sys.path:
    sys.path.insert(0, str(OPENGAIT_ROOT))
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
from opengait.data.transform import get_transform
from opengait.modeling import models
from opengait.modeling.base_model import BaseModel
from opengait.modeling.models.drf import DRF
from opengait.utils.common import config_loader
# Alias for the float32 numpy arrays passed between helpers in this script.
FloatArray = NDArray[np.float32]
# Class names ordered by classifier index (index 0 -> "negative", ...);
# indexed by argmax of the class scores below.
LABEL_NAMES = ("negative", "neutral", "positive")
@dataclass(frozen=True)
class ModelSpec:
    """Static description of one checkpointed model to visualize."""

    key: str  # short identifier; "skeleton" in the key selects heatmap inputs
    title: str  # human-readable name used in figure titles
    config_path: Path  # OpenGait YAML config for this model
    checkpoint_path: Path  # .pt checkpoint holding the "model" state dict
    data_root: Path  # root directory of the pickled input sequences
    is_drf: bool  # True when the model is a DRF (heatmap + PAV inputs)
@dataclass(frozen=True)
class VisualizationResult:
    """One row of the final response-heatmap figure."""

    title: str  # model title shown above the input panel
    predicted_label: str  # predicted class name from LABEL_NAMES
    class_scores: tuple[float, float, float]  # part-averaged logits per class
    input_image: FloatArray  # normalized projection of the input frames
    response_map: FloatArray  # normalized heatmap at input resolution
    overlay_image: FloatArray  # RGB blend of input and heatmap
@dataclass(frozen=True)
class StageVisualizationResult:
    """One row of the per-backbone-stage visualization figure."""

    title: str  # model title shown above the input panel
    predicted_label: str  # predicted class name from LABEL_NAMES
    class_scores: tuple[float, float, float]  # part-averaged logits per class
    input_image: FloatArray  # normalized projection of the input frames
    stage_overlays: dict[str, FloatArray]  # stage name -> overlay image
@dataclass(frozen=True)
class ForwardDiagnostics:
    """Intermediate tensors captured during one forward pass."""

    feature_map: Float[torch.Tensor, "batch channels height width"]  # after temporal pooling (TP)
    pooled: Float[torch.Tensor, "batch channels parts"]  # after horizontal pyramid pooling (HPP)
    embeddings: Float[torch.Tensor, "batch channels parts"]  # after FCs (and PGA attention for DRF)
    logits: Float[torch.Tensor, "batch classes parts"]  # per-part classifier logits from BNNecks
    class_scores: Float[torch.Tensor, "batch classes"]  # logits averaged over parts
    pga_spatial: Float[torch.Tensor, "batch parts"] | None  # DRF spatial attention, if computed
@dataclass(frozen=True)
class ModelDiagnosticResult:
    """One row of the occlusion / part-logit diagnostics figure."""

    title: str  # model title shown above the input panel
    predicted_label: str  # predicted class name from LABEL_NAMES
    class_scores: tuple[float, float, float]  # part-averaged logits per class
    input_image: FloatArray  # normalized projection of the input frames
    occlusion_map: FloatArray  # normalized occlusion-sensitivity map
    overlay_image: FloatArray  # occlusion map blended over the input
    part_scores: FloatArray  # predicted-class logit per HPP part
    pga_spatial: FloatArray | None  # DRF PGA spatial attention, if any
def apply_transforms(
    sequence: np.ndarray,
    transform_cfg: list[dict[str, Any]] | dict[str, Any] | None,
) -> np.ndarray:
    """Run the configured OpenGait transform(s) over ``sequence``.

    ``get_transform`` may return a single callable or a list of callables;
    a list is applied left-to-right as a pipeline. The result is always
    coerced to a numpy array.
    """
    transform = get_transform(transform_cfg)
    if not isinstance(transform, list):
        return np.asarray(transform(sequence))
    staged = sequence
    for step in transform:
        staged = step(staged)
    return np.asarray(staged)
def build_minimal_model(cfgs: dict[str, Any]) -> nn.Module:
    """Instantiate only the network of the configured model class.

    ``__new__`` plus a bare ``nn.Module.__init__`` deliberately bypass the
    model's own ``__init__`` (which sets up the full training stack); only
    ``build_network`` is run, with ``get_backbone`` bound from ``BaseModel``
    so the network construction can call it.
    """
    model_cls = getattr(models, str(cfgs["model_cfg"]["model"]))
    instance = model_cls.__new__(model_cls)
    nn.Module.__init__(instance)
    instance.get_backbone = types.MethodType(BaseModel.get_backbone, instance)
    instance.build_network(cfgs["model_cfg"])
    return instance
def load_checkpoint_model(spec: ModelSpec) -> tuple[nn.Module, dict[str, Any]]:
    """Build the model described by ``spec`` and restore its weights.

    Returns the model in eval mode together with the parsed config dict.
    NOTE: ``weights_only=False`` unpickles arbitrary objects from the
    checkpoint — only load checkpoints from trusted sources.
    """
    cfgs = config_loader(str(spec.config_path))
    model = build_minimal_model(cfgs)
    state = torch.load(spec.checkpoint_path, map_location="cpu", weights_only=False)
    # strict=False tolerates keys the minimal build does not define.
    model.load_state_dict(state["model"], strict=False)
    model.eval()
    return model, cfgs
def load_pickle(path: Path) -> np.ndarray:
    """Unpickle ``path`` and coerce the payload to a numpy array."""
    with path.open("rb") as stream:
        payload = pickle.load(stream)
    return np.asarray(payload)
def load_sample_inputs(
    spec: ModelSpec,
    sequence_id: str,
    label: str,
    view: str,
    frame_count: int,
) -> tuple[torch.Tensor, torch.Tensor | None, FloatArray]:
    """Load and preprocess one sample for ``spec``.

    Returns ``(input_tensor, pav_tensor_or_None, projection_image)``.
    DRF models read heatmaps plus PAV, skeleton models heatmaps only,
    and every other model silhouettes.
    """
    sample_dir = spec.data_root / sequence_id / label / view
    if spec.is_drf:
        heatmaps = load_pickle(sample_dir / "0_heatmap.pkl").astype(np.float32)
        pav = load_pickle(sample_dir / "1_pav.pkl").astype(np.float32)
        return prepare_heatmap_inputs(heatmaps, pav, spec.config_path, frame_count)
    if "skeleton" in spec.key:
        heatmaps = load_pickle(sample_dir / "0_heatmap.pkl").astype(np.float32)
        return prepare_heatmap_inputs(heatmaps, None, spec.config_path, frame_count)
    silhouettes = load_pickle(sample_dir / f"{view}.pkl").astype(np.float32)
    return prepare_silhouette_inputs(silhouettes, spec.config_path, frame_count)
def prepare_silhouette_inputs(
    silhouettes: FloatArray,
    config_path: Path,
    frame_count: int,
) -> tuple[torch.Tensor, None, FloatArray]:
    """Transform, temporally subsample and batch a silhouette sequence.

    Returns a ``[1, 1, frame_count, H, W]`` float tensor, ``None`` (no PAV),
    and the normalized per-pixel max projection over the sampled frames.
    """
    cfgs = config_loader(str(config_path))
    frames = apply_transforms(silhouettes, cfgs["evaluator_cfg"]["transform"]).astype(np.float32)
    # Evenly spaced subsampling down to frame_count frames.
    picks = np.linspace(0, frames.shape[0] - 1, frame_count, dtype=int)
    chosen = frames[picks]
    tensor = torch.from_numpy(chosen).float().unsqueeze(0).unsqueeze(0)
    return tensor, None, normalize_image(chosen.max(axis=0))
def prepare_heatmap_inputs(
    heatmaps: FloatArray,
    pav: FloatArray | None,
    config_path: Path,
    frame_count: int,
) -> tuple[torch.Tensor, torch.Tensor | None, FloatArray]:
    """Transform, temporally subsample and batch a pose-heatmap sequence.

    Parameters
    ----------
    heatmaps:
        Per-frame heatmaps; after the transform, indexing assumes shape
        ``[frames, channels, H, W]`` — TODO confirm against the dataset.
    pav:
        Optional per-frame PAV features (DRF only); averaged over time.
    config_path:
        Model YAML whose evaluator transform is applied.
    frame_count:
        Number of evenly spaced frames to keep.

    Returns ``(tensor [1, C, frame_count, H, W], pav tensor or None,
    normalized projection image)``.
    """
    cfgs = config_loader(str(config_path))
    transform_cfg = cfgs["evaluator_cfg"]["transform"]
    # A list config bundles per-modality transforms; the heatmap one is first.
    heatmap_transform = transform_cfg[0] if isinstance(transform_cfg, list) else transform_cfg
    transformed = apply_transforms(heatmaps, heatmap_transform).astype(np.float32)
    indices = np.linspace(0, transformed.shape[0] - 1, frame_count, dtype=int)
    sampled = transformed[indices]
    input_tensor = torch.from_numpy(sampled).permute(1, 0, 2, 3).unsqueeze(0).float()
    # Fix: compute the temporal max once (was computed twice), then fuse the
    # first two channels into a single projection image.
    temporal_max = sampled.max(axis=0)
    projection = normalize_image(np.maximum(temporal_max[0], temporal_max[1]))
    pav_tensor = None
    if pav is not None:
        pav_tensor = torch.from_numpy(pav.mean(axis=0, keepdims=True)).float()
    return input_tensor, pav_tensor, projection
# How a per-channel activation tensor is reduced to a single spatial map.
AttentionMethod = Literal["sum", "sum_p", "max_p"]
# Whether to visualize raw activations or an adapted class-activation map.
ResponseMode = Literal["activation", "cam"]
def compute_classification_scores(
    model: nn.Module,
    sequence: torch.Tensor,
    pav: torch.Tensor | None,
) -> tuple[Float[torch.Tensor, "channels height width"], tuple[float, float, float], str]:
    """Run one forward pass and return (feature map, class scores, label).

    The first element is the detached temporally-pooled feature map of the
    single batch item (a torch tensor — the previous FloatArray annotation
    was wrong). The label is LABEL_NAMES[argmax of the class scores].
    """
    diagnostics = forward_diagnostics(model, sequence, pav)
    class_scores = diagnostics.class_scores[0]
    pred_index = int(class_scores.argmax().detach().cpu())
    score_tuple = tuple(float(value.detach().cpu()) for value in class_scores)
    return diagnostics.feature_map.detach()[0], score_tuple, LABEL_NAMES[pred_index]
def forward_diagnostics(
    model: nn.Module,
    sequence: torch.Tensor,
    pav: torch.Tensor | None,
) -> ForwardDiagnostics:
    """Replay the model's inference pipeline, capturing intermediates.

    Mirrors Backbone -> temporal pooling (TP) -> HPP -> FCs
    (-> PGA attention for DRF models) -> BNNecks; the class scores are the
    per-part logits averaged over parts.
    """
    outputs = model.Backbone(sequence)
    feature_map = model.TP(outputs, None, options={"dim": 2})[0]
    pooled = model.HPP(feature_map)
    embeddings = model.FCs(pooled)
    pga_spatial: Float[torch.Tensor, "batch parts"] | None = None
    if pav is not None and isinstance(model, DRF):
        pav_flat = pav.flatten(1)
        # PGA: channel attention over embedding channels and spatial
        # attention over parts, both driven by the flattened PAV features.
        channel_att = model.PGA.channel_att(pav_flat).unsqueeze(-1)
        spatial_att = model.PGA.spatial_att(pav_flat.unsqueeze(-1)).transpose(1, 2)
        pga_spatial = spatial_att.squeeze(1)
        embeddings = embeddings * channel_att * spatial_att
    _bn_embeddings, logits = model.BNNecks(embeddings)
    # Average the per-part logits into one score per class.
    class_scores = logits.mean(-1)
    return ForwardDiagnostics(
        feature_map=feature_map,
        pooled=pooled,
        embeddings=embeddings,
        logits=logits,
        class_scores=class_scores,
        pga_spatial=pga_spatial,
    )
def compute_response_map(
    model: nn.Module,
    sequence: torch.Tensor,
    pav: torch.Tensor | None,
    response_mode: ResponseMode,
    method: AttentionMethod,
    power: float,
) -> tuple[FloatArray, tuple[float, float, float], str]:
    """Compute a normalized [H, W] response heatmap at input resolution.

    ``response_mode == "cam"`` uses the adapted class-activation map;
    otherwise the raw activation map is reduced with ``method``/``power``.
    Returns the normalized heatmap, the class scores, and the predicted label.
    """
    activation, score_tuple, predicted_label = compute_classification_scores(model, sequence, pav)
    if response_mode == "cam":
        response_map = adapted_cam_map(
            model=model,
            sequence=sequence,
            pav=pav,
            predicted_label=predicted_label,
        )
    else:
        response_map = activation_attention_map(activation, method=method, power=power)
    # Upsample the feature-resolution map to the input's spatial size.
    response_map = response_map.unsqueeze(0).unsqueeze(0)
    response_map = F.interpolate(
        response_map,
        size=sequence.shape[-2:],
        mode="bilinear",
        align_corners=False,
    )[0, 0]
    response_np = response_map.detach().cpu().numpy().astype(np.float32)
    response_np = normalize_image(response_np)
    return response_np, score_tuple, predicted_label
def activation_attention_map(
    activation: Float[torch.Tensor, "channels height width"],
    method: AttentionMethod,
    power: float,
) -> Float[torch.Tensor, "height width"]:
    """Collapse a [C, H, W] activation into a [H, W] attention map.

    ``sum`` adds absolute activations over channels; ``sum_p`` / ``max_p``
    raise |activation| to ``power`` first, then sum / max over channels.
    Raises ValueError for any other method.
    """
    magnitude = activation.abs()
    if method == "sum":
        return magnitude.sum(dim=0)
    if method in ("sum_p", "max_p"):
        boosted = magnitude.pow(power)
        return boosted.sum(dim=0) if method == "sum_p" else boosted.max(dim=0).values
    raise ValueError(f"Unsupported attention method: {method}")
def adapted_cam_map(
    model: nn.Module,
    sequence: torch.Tensor,
    pav: torch.Tensor | None,
    predicted_label: str,
) -> Float[torch.Tensor, "height width"]:
    """Adapted class-activation map for part-based gait classifiers.

    Pulls the predicted class's classifier weights backwards through the
    BNNecks batch norm and the per-part FC projections to obtain per-part
    channel weights over the backbone feature map, then paints each
    horizontal part strip with its channel-weighted feature sum.
    Returns a ReLU'd [H, W] map at feature-map resolution.
    """
    outputs = model.Backbone(sequence)
    feature_map = model.TP(outputs, None, options={"dim": 2})[0]  # [1, c_in, h, w]
    pooled = model.HPP(feature_map)  # [1, c_in, p]
    pre_attention = model.FCs(pooled)  # [1, c_out, p]
    attention = torch.ones_like(pre_attention)
    if pav is not None and isinstance(model, DRF):
        pav_flat = pav.flatten(1)
        channel_att = model.PGA.channel_att(pav_flat).unsqueeze(-1)
        spatial_att = model.PGA.spatial_att(pav_flat.unsqueeze(-1)).transpose(1, 2)
        attention = channel_att * spatial_att
    embeddings = pre_attention * attention
    # Fold the eval-mode batch norm into equivalent per-feature scales.
    bn_scales = bn_feature_scales(model.BNNecks, embeddings.shape[1], embeddings.shape[2], embeddings.device)
    classifier_weights = F.normalize(model.BNNecks.fc_bin.detach(), dim=1)
    class_index = LABEL_NAMES.index(predicted_label)
    class_weights = classifier_weights[:, :, class_index]  # [p, c_out]
    normalized_embeddings = batchnorm_forward_eval(model.BNNecks, embeddings)[0].transpose(0, 1)  # [p, c_out]
    # NOTE(review): dividing by each part's embedding norm presumably mirrors
    # BNNecks' norm-based logit scaling — confirm against the BNNecks code.
    part_norms = normalized_embeddings.norm(dim=-1, keepdim=True).clamp_min(1e-6)
    effective_post_bn = class_weights / part_norms
    effective_pre_bn = effective_post_bn * bn_scales
    fc_weights = model.FCs.fc_bin.detach()  # [p, c_in, c_out]
    attention_weights = attention.detach()[0].transpose(0, 1)  # [p, c_out]
    # Pull the class weights back through the per-part FC projections.
    effective_channel_weights = torch.einsum(
        "pkj,pj->pk",
        fc_weights,
        attention_weights * effective_pre_bn,
    )  # [p, c_in]
    feature = feature_map.detach()[0]  # [c_in, h, w]
    parts_num = effective_channel_weights.shape[0]
    height = feature.shape[1]
    split_size = max(height // parts_num, 1)
    cam = torch.zeros((height, feature.shape[2]), device=feature.device, dtype=feature.dtype)
    # Each HPP part covers a horizontal strip; the last strip absorbs any
    # remainder rows when height is not divisible by parts_num.
    for part_index in range(parts_num):
        row_start = part_index * split_size
        row_end = height if part_index == parts_num - 1 else min((part_index + 1) * split_size, height)
        if row_start >= height:
            break
        contribution = (feature[:, row_start:row_end, :] * effective_channel_weights[part_index][:, None, None]).sum(dim=0)
        cam[row_start:row_end, :] = contribution
    return F.relu(cam)
def batchnorm_forward_eval(
    bn_necks: nn.Module,
    embeddings: Float[torch.Tensor, "batch channels parts"],
) -> Float[torch.Tensor, "batch channels parts"]:
    """Apply the BNNecks 1-D batch norm in eval mode (running stats only).

    Raises ValueError unless the necks use one parallel BatchNorm1d over the
    flattened ``channels * parts`` features.
    """
    batch, channels, parts = embeddings.shape
    if not getattr(bn_necks, "parallel_BN1d", True):
        raise ValueError("Adapted CAM currently expects parallel_BN1d=True.")
    bn = bn_necks.bn1d
    flat = torch.nn.functional.batch_norm(
        embeddings.reshape(batch, -1),
        bn.running_mean,
        bn.running_var,
        bn.weight,
        bn.bias,
        training=False,
        momentum=0.0,
        eps=bn.eps,
    )
    return flat.view(batch, channels, parts)
def bn_feature_scales(
    bn_necks: nn.Module,
    channels: int,
    parts: int,
    device: torch.device,
) -> Float[torch.Tensor, "parts channels"]:
    """Per-feature multiplicative scale of the eval-mode batch norm.

    Equals ``gamma / sqrt(running_var + eps)`` (just the rsqrt term when the
    norm has no affine weight), reshaped to [parts, channels].
    """
    if not getattr(bn_necks, "parallel_BN1d", True):
        raise ValueError("Adapted CAM currently expects parallel_BN1d=True.")
    bn = bn_necks.bn1d

    def as_parts_by_channels(vector: torch.Tensor) -> torch.Tensor:
        # Flattened BN features are laid out channel-major: view as [c, p].
        return vector.detach().to(device).view(channels, parts).transpose(0, 1)

    scales = torch.rsqrt(as_parts_by_channels(bn.running_var) + bn.eps)
    if bn.weight is None:
        return scales
    return scales * as_parts_by_channels(bn.weight)
def normalize_image(image: np.ndarray) -> FloatArray:
    """Min-max scale ``image`` into [0, 1]; constant images become all zeros."""
    image = image.astype(np.float32)
    lo, hi = float(image.min()), float(image.max())
    if hi <= lo:
        return np.zeros_like(image, dtype=np.float32)
    return (image - lo) / (hi - lo)
def overlay_response(base_image: FloatArray, response_map: FloatArray) -> FloatArray:
    """Blend a grayscale base image with a jet-colored response map (RGB)."""
    colorized = plt.get_cmap("jet")(response_map)[..., :3].astype(np.float32)
    grayscale_rgb = np.repeat(base_image[..., None], 3, axis=-1)
    blended = grayscale_rgb * 0.45 + colorized * 0.55
    return np.clip(blended, 0.0, 1.0)
def compute_occlusion_sensitivity(
    model: nn.Module,
    sequence: torch.Tensor,
    pav: torch.Tensor | None,
    patch_size: int,
    stride: int,
) -> tuple[FloatArray, tuple[float, float, float], str]:
    """Occlusion-sensitivity map: score drop when masking input patches.

    Slides a ``patch_size`` x ``patch_size`` zero mask over the spatial
    dimensions (the same patch in all frames/channels) and records how much
    the predicted class's score drops. Returns the normalized [H, W]
    sensitivity map, the baseline class scores, and the predicted label.
    """
    with torch.no_grad():
        baseline = forward_diagnostics(model, sequence, pav)
        class_scores = baseline.class_scores[0]
        pred_index = int(class_scores.argmax().detach().cpu())
        baseline_score = float(class_scores[pred_index].detach().cpu())
        score_tuple = tuple(float(value.detach().cpu()) for value in class_scores)
        _, _, _, height, width = sequence.shape
        sensitivity = np.zeros((height, width), dtype=np.float32)
        counts = np.zeros((height, width), dtype=np.float32)
        row_starts = list(range(0, max(height - patch_size + 1, 1), stride))
        col_starts = list(range(0, max(width - patch_size + 1, 1), stride))
        # Always include the last valid window so image borders are covered.
        last_row = max(height - patch_size, 0)
        last_col = max(width - patch_size, 0)
        if row_starts[-1] != last_row:
            row_starts.append(last_row)
        if col_starts[-1] != last_col:
            col_starts.append(last_col)
        for top in row_starts:
            for left in col_starts:
                bottom = min(top + patch_size, height)
                right = min(left + patch_size, width)
                masked = sequence.clone()
                masked[..., top:bottom, left:right] = 0.0
                masked_scores = forward_diagnostics(model, masked, pav).class_scores[0]
                # Only score *drops* count; improvements are clamped to 0.
                score_drop = max(0.0, baseline_score - float(masked_scores[pred_index].detach().cpu()))
                sensitivity[top:bottom, left:right] += score_drop
                counts[top:bottom, left:right] += 1.0
        # Average overlapping windows; guard against division by zero.
        counts = np.maximum(counts, 1.0)
        sensitivity /= counts
        return normalize_image(sensitivity), score_tuple, LABEL_NAMES[pred_index]
def compute_part_diagnostics(
    model: nn.Module,
    sequence: torch.Tensor,
    pav: torch.Tensor | None,
) -> tuple[FloatArray, FloatArray | None, tuple[float, float, float], str]:
    """Per-part logits for the predicted class, plus PGA attention if any.

    Returns ``(part_scores, pga_spatial_or_None, class_scores, label)``.
    """
    with torch.no_grad():
        diagnostics = forward_diagnostics(model, sequence, pav)
        class_scores = diagnostics.class_scores[0]
        pred_index = int(class_scores.argmax().detach().cpu())
        part_scores = diagnostics.logits[0, pred_index].detach().cpu().numpy().astype(np.float32)
        pga_spatial = (
            None
            if diagnostics.pga_spatial is None
            else diagnostics.pga_spatial[0].detach().cpu().numpy().astype(np.float32)
        )
        score_tuple = tuple(float(value.detach().cpu()) for value in class_scores)
        return part_scores, pga_spatial, score_tuple, LABEL_NAMES[pred_index]
def extract_temporally_pooled_stages(
    model: nn.Module,
    sequence: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Run the backbone stage by stage and temporally pool each stage output.

    Returns a dict mapping stage name ("conv1", "layer1".."layer4") to its
    temporally pooled feature tensor. Raises ValueError when the backbone
    does not expose ``forward_block`` (SetBlockWrapper convention).
    """
    if not hasattr(model.Backbone, "forward_block"):
        raise ValueError("Stage visualization requires SetBlockWrapper-style backbones.")
    backbone = model.Backbone.forward_block
    batch, channels, frames, height, width = sequence.shape
    # Merge batch and time so all frames pass through the 2-D backbone at once.
    x = sequence.transpose(1, 2).reshape(-1, channels, height, width)
    stage_outputs: dict[str, torch.Tensor] = {}
    x = backbone.conv1(x)
    x = backbone.bn1(x)
    x = backbone.relu(x)
    if getattr(backbone, "maxpool_flag", False):
        x = backbone.maxpool(x)
    stage_outputs["conv1"] = x
    x = backbone.layer1(x)
    stage_outputs["layer1"] = x
    x = backbone.layer2(x)
    stage_outputs["layer2"] = x
    x = backbone.layer3(x)
    stage_outputs["layer3"] = x
    x = backbone.layer4(x)
    stage_outputs["layer4"] = x
    pooled_outputs: dict[str, torch.Tensor] = {}
    for stage_name, stage_output in stage_outputs.items():
        # Restore [batch, C, frames, H', W'] before temporal pooling.
        reshaped = stage_output.reshape(batch, frames, *stage_output.shape[1:]).transpose(1, 2).contiguous()
        pooled_outputs[stage_name] = model.TP(reshaped, None, options={"dim": 2})[0]
    return pooled_outputs
def render_visualization(results: list[VisualizationResult], output_path: Path) -> None:
    """Write a figure with one row per model: input, heatmap, overlay."""
    fig, axes = plt.subplots(len(results), 3, figsize=(10.5, 3.2 * len(results)), constrained_layout=True)
    if len(results) == 1:
        # plt.subplots collapses a single row to 1-D; keep indexing uniform.
        axes = np.expand_dims(axes, axis=0)
    for axis_row, result in zip(axes, results):
        input_axis, heat_axis, overlay_axis = axis_row
        input_axis.imshow(result.input_image, cmap="gray", vmin=0.0, vmax=1.0)
        input_axis.set_title(f"{result.title} input")
        heat_axis.imshow(result.response_map, cmap="jet", vmin=0.0, vmax=1.0)
        heat_axis.set_title("response")
        overlay_axis.imshow(result.overlay_image)
        rounded_scores = tuple(round(score, 3) for score in result.class_scores)
        overlay_axis.set_title(
            f"overlay | pred={result.predicted_label}\n"
            f"scores={rounded_scores}"
        )
        for axis in (input_axis, heat_axis, overlay_axis):
            axis.set_xticks([])
            axis.set_yticks([])
    fig.savefig(output_path, dpi=180)
    plt.close(fig)
def render_stage_visualization(results: list[StageVisualizationResult], output_path: Path) -> None:
    """Write a grid figure: rows are models, columns are input + backbone stages."""
    stage_names = ["input", "conv1", "layer1", "layer2", "layer3", "layer4"]
    fig, axes = plt.subplots(
        len(results),
        len(stage_names),
        figsize=(2.2 * len(stage_names), 3.0 * len(results)),
        constrained_layout=True,
    )
    if len(results) == 1:
        # Keep the axes array 2-D even for a single row.
        axes = np.expand_dims(axes, axis=0)
    for row, result in enumerate(results):
        for col, stage_name in enumerate(stage_names):
            axis = axes[row, col]
            if stage_name == "input":
                axis.imshow(result.input_image, cmap="gray", vmin=0.0, vmax=1.0)
                axis.set_title(
                    f"{result.title}\npred={result.predicted_label}",
                    fontsize=10,
                )
            else:
                axis.imshow(result.stage_overlays[stage_name])
                axis.set_title(stage_name, fontsize=10)
            axis.set_xticks([])
            axis.set_yticks([])
        # Class scores annotate the first column of each row.
        axes[row, 0].set_ylabel(
            f"scores={tuple(round(score, 3) for score in result.class_scores)}",
            fontsize=9,
            rotation=90,
        )
    fig.savefig(output_path, dpi=180)
    plt.close(fig)
def render_diagnostic_visualization(results: list[ModelDiagnosticResult], output_path: Path) -> None:
    """Write a figure: input, occlusion overlay and part-logit bars per model."""
    fig, axes = plt.subplots(len(results), 3, figsize=(11.5, 3.3 * len(results)), constrained_layout=True)
    if len(results) == 1:
        # Keep the axes array 2-D even for a single row.
        axes = np.expand_dims(axes, axis=0)
    for row, result in enumerate(results):
        input_axis, occlusion_axis, part_axis = axes[row]
        input_axis.imshow(result.input_image, cmap="gray", vmin=0.0, vmax=1.0)
        input_axis.set_title(f"{result.title} input")
        input_axis.set_xticks([])
        input_axis.set_yticks([])
        occlusion_axis.imshow(result.overlay_image)
        occlusion_axis.set_title(
            f"occlusion | pred={result.predicted_label}\n"
            f"scores={tuple(round(score, 3) for score in result.class_scores)}"
        )
        occlusion_axis.set_xticks([])
        occlusion_axis.set_yticks([])
        parts = np.arange(result.part_scores.shape[0], dtype=np.int32)
        part_axis.bar(parts, result.part_scores, color="#4477AA", alpha=0.85, width=0.8)
        part_axis.set_title("pred-class part logits")
        part_axis.set_xlabel("HPP part")
        part_axis.set_ylabel("logit")
        part_axis.set_xticks(parts)
        if result.pga_spatial is not None:
            # Secondary y-axis for the PGA spatial attention curve (DRF only).
            part_axis_2 = part_axis.twinx()
            part_axis_2.plot(parts, result.pga_spatial, color="#CC6677", marker="o", linewidth=1.5)
            part_axis_2.set_ylabel("PGA spatial")
            part_axis_2.set_ylim(0.0, max(1.0, float(result.pga_spatial.max()) * 1.05))
    fig.savefig(output_path, dpi=180)
    plt.close(fig)
def default_model_specs() -> list[ModelSpec]:
    """Hard-coded specs for the three models this script compares.

    NOTE: config/checkpoint/data paths are machine-specific; adjust them
    before running on a different host.
    """
    return [
        ModelSpec(
            key="silhouette",
            title="ScoNet-MT silhouette",
            config_path=REPO_ROOT / "configs/sconet/sconet_scoliosis1k_local_eval_2gpu_better_112.yaml",
            checkpoint_path=REPO_ROOT / "ckpt/ScoNet-20000-better.pt",
            data_root=Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl"),
            is_drf=False,
        ),
        ModelSpec(
            key="skeleton",
            title="ScoNet-MT-ske",
            config_path=REPO_ROOT / "configs/sconet/sconet_scoliosis1k_skeleton_118_sigma15_joint8_2gpu_bs12x8.yaml",
            checkpoint_path=REPO_ROOT
            / "output/Scoliosis1K/ScoNet/ScoNet_skeleton_118_sigma15_joint8_sharedalign_2gpu_bs12x8/checkpoints/ScoNet_skeleton_118_sigma15_joint8_sharedalign_2gpu_bs12x8-20000.pt",
            data_root=Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118-sigma15-joint8-sharedalign"),
            is_drf=False,
        ),
        ModelSpec(
            key="drf",
            title="DRF",
            config_path=REPO_ROOT / "configs/drf/drf_scoliosis1k_eval_1gpu.yaml",
            checkpoint_path=REPO_ROOT / "output/Scoliosis1K/DRF/DRF/checkpoints/DRF-20000.pt",
            data_root=Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-drf-pkl-118"),
            is_drf=True,
        ),
    ]
def build_stage_results(
    spec: ModelSpec,
    response_mode: ResponseMode,
    method: AttentionMethod,
    power: float,
    sequence_id: str,
    label: str,
    view: str,
    frame_count: int,
) -> StageVisualizationResult:
    """Compute per-backbone-stage activation overlays for one model spec.

    Raises ValueError for ``response_mode != "activation"`` — the stage view
    only supports activation attention maps.
    """
    if response_mode != "activation":
        raise ValueError("Stage view currently supports activation maps only.")
    model, _cfgs = load_checkpoint_model(spec)
    sequence, pav, base_image = load_sample_inputs(spec, sequence_id, label, view, frame_count)
    _final_feature, score_tuple, predicted_label = compute_classification_scores(model, sequence, pav)
    pooled_stages = extract_temporally_pooled_stages(model, sequence)
    stage_overlays: dict[str, FloatArray] = {}
    for stage_name, stage_tensor in pooled_stages.items():
        response_map = activation_attention_map(stage_tensor[0], method=method, power=power)
        # Upsample each stage's map to the input resolution for overlay.
        response_map = F.interpolate(
            response_map.unsqueeze(0).unsqueeze(0),
            size=sequence.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )[0, 0]
        response_np = normalize_image(response_map.detach().cpu().numpy().astype(np.float32))
        stage_overlays[stage_name] = overlay_response(base_image, response_np)
    return StageVisualizationResult(
        title=spec.title,
        predicted_label=predicted_label,
        class_scores=score_tuple,
        input_image=base_image,
        stage_overlays=stage_overlays,
    )
def build_diagnostic_results(
    spec: ModelSpec,
    sequence_id: str,
    label: str,
    view: str,
    frame_count: int,
    patch_size: int,
    stride: int,
) -> ModelDiagnosticResult:
    """Run occlusion sensitivity plus part diagnostics for one model spec."""
    model, _ = load_checkpoint_model(spec)
    sequence, pav, base_image = load_sample_inputs(spec, sequence_id, label, view, frame_count)
    occlusion_map, scores, predicted = compute_occlusion_sensitivity(
        model=model,
        sequence=sequence,
        pav=pav,
        patch_size=patch_size,
        stride=stride,
    )
    part_scores, pga_spatial, _, _ = compute_part_diagnostics(model, sequence, pav)
    overlay = overlay_response(base_image, occlusion_map)
    return ModelDiagnosticResult(
        title=spec.title,
        predicted_label=predicted,
        class_scores=scores,
        input_image=base_image,
        occlusion_map=occlusion_map,
        overlay_image=overlay,
        part_scores=part_scores,
        pga_spatial=pga_spatial,
    )
@click.command()
@click.option("--sequence-id", default="00490", show_default=True)
@click.option("--label", default="positive", show_default=True)
@click.option("--view", default="000_180", show_default=True)
@click.option("--frame-count", default=30, show_default=True, type=int)
@click.option(
    "--response-mode",
    type=click.Choice(["activation", "cam"]),
    default="activation",
    show_default=True,
)
@click.option(
    "--method",
    type=click.Choice(["sum", "sum_p", "max_p"]),
    default="sum",
    show_default=True,
)
@click.option("--power", default=2.0, show_default=True, type=float)
@click.option(
    "--view-mode",
    type=click.Choice(["final", "stages", "diagnostics"]),
    default="final",
    show_default=True,
)
@click.option("--occlusion-patch-size", default=8, show_default=True, type=int)
@click.option("--occlusion-stride", default=8, show_default=True, type=int)
@click.option(
    "--output-path",
    default=str(REPO_ROOT / "research/feature_response_heatmaps/00490_positive_000_180_response_heatmaps.png"),
    show_default=True,
    type=click.Path(path_type=Path),
)
def main(
    sequence_id: str,
    label: str,
    view: str,
    frame_count: int,
    response_mode: str,
    method: str,
    power: float,
    view_mode: str,
    occlusion_patch_size: int,
    occlusion_stride: int,
    output_path: Path,
) -> None:
    """Render the requested visualization for every default model spec.

    ``--view-mode`` selects the figure type: "final" (response heatmaps),
    "stages" (per-backbone-stage overlays) or "diagnostics" (occlusion +
    part logits). The figure path is echoed on completion.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    if view_mode == "stages":
        results = [
            build_stage_results(
                spec=spec,
                response_mode=response_mode,
                method=method,
                power=power,
                sequence_id=sequence_id,
                label=label,
                view=view,
                frame_count=frame_count,
            )
            for spec in default_model_specs()
        ]
        render_stage_visualization(results, output_path)
    elif view_mode == "diagnostics":
        results = [
            build_diagnostic_results(
                spec=spec,
                sequence_id=sequence_id,
                label=label,
                view=view,
                frame_count=frame_count,
                patch_size=occlusion_patch_size,
                stride=occlusion_stride,
            )
            for spec in default_model_specs()
        ]
        render_diagnostic_visualization(results, output_path)
    else:
        # Default "final" view: one response heatmap per model.
        results_final: list[VisualizationResult] = []
        for spec in default_model_specs():
            model, _cfgs = load_checkpoint_model(spec)
            sequence, pav, base_image = load_sample_inputs(spec, sequence_id, label, view, frame_count)
            response_map, scores, predicted_label = compute_response_map(
                model,
                sequence,
                pav,
                response_mode=response_mode,
                method=method,
                power=power,
            )
            results_final.append(
                VisualizationResult(
                    title=spec.title,
                    predicted_label=predicted_label,
                    class_scores=scores,
                    input_image=base_image,
                    response_map=response_map,
                    overlay_image=overlay_response(base_image, response_map),
                )
            )
        render_visualization(results_final, output_path)
    click.echo(str(output_path))


if __name__ == "__main__":
    main()