from __future__ import annotations

from collections.abc import Callable
from pathlib import Path
from typing import ClassVar, Protocol, cast, override

import torch
import torch.nn as nn
from beartype import beartype
from einops import rearrange
from jaxtyping import Float
import jaxtyping
from torch import Tensor

from opengait.modeling.backbones.resnet import ResNet9
from opengait.modeling.modules import (
    HorizontalPoolingPyramid,
    PackSequenceWrapper as TemporalPool,
    SeparateBNNecks,
    SeparateFCs,
)
from opengait.utils import common as common_utils

# Typed views over untyped third-party callables, so the rest of this module
# can stay fully annotated without `type: ignore` noise.
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)

# OpenGait's config loader takes a path string and returns a nested dict.
ConfigLoader = Callable[[str], dict[str, object]]
config_loader = cast(ConfigLoader, common_utils.config_loader)


class TemporalPoolLike(Protocol):
    """Structural type of OpenGait's PackSequenceWrapper callable.

    The return type is `object` because the wrapper's output shape depends on
    the wrapped pooling function (e.g. ``torch.max`` yields a (values, indices)
    tuple) — callers narrow it at the call site.
    """

    def __call__(
        self,
        seqs: Tensor,
        seqL: object,
        dim: int = 2,
        options: dict[str, int] | None = None,
    ) -> object: ...


class HppLike(Protocol):
    """Structural type of OpenGait's HorizontalPoolingPyramid callable."""

    def __call__(self, x: Tensor) -> Tensor: ...


class FCsLike(Protocol):
    """Structural type of OpenGait's SeparateFCs callable."""

    def __call__(self, x: Tensor) -> Tensor: ...


class BNNecksLike(Protocol):
    """Structural type of OpenGait's SeparateBNNecks callable.

    Returns a (features, logits) pair — see the unpacking in the model's
    forward pass.
    """

    def __call__(self, x: Tensor) -> tuple[Tensor, Tensor]: ...
class ScoNetDemo(nn.Module):
    """Inference-only wrapper around a ScoNet silhouette classifier.

    Builds a ResNet9 backbone plus OpenGait heads (temporal max-pooling,
    horizontal pooling pyramid, SeparateFCs, SeparateBNNecks) from a YAML
    config, optionally loads a checkpoint, and classifies silhouette
    sequences shaped ``[B, 1, S, 64, 44]`` into the three labels of
    :attr:`LABEL_MAP`.
    """

    # Class index -> human-readable label returned by `predict`.
    LABEL_MAP: ClassVar[dict[int, str]] = {0: "negative", 1: "neutral", 2: "positive"}

    cfg_path: str
    cfg: dict[str, object]
    backbone: ResNet9
    temporal_pool: TemporalPoolLike
    hpp: HppLike
    fcs: FCsLike
    bn_necks: BNNecksLike
    device: torch.device

    @jaxtyped(typechecker=beartype)
    def __init__(
        self,
        cfg_path: str | Path = "configs/sconet/sconet_scoliosis1k.yaml",
        checkpoint_path: str | Path | None = None,
        device: str | torch.device | None = None,
    ) -> None:
        """Build the model from *cfg_path* and optionally load *checkpoint_path*.

        Args:
            cfg_path: YAML config; relative paths are resolved against the
                repo root (see ``_resolve_path``).
            checkpoint_path: Optional checkpoint to load after construction.
            device: Target device; defaults to CPU when ``None``.

        Raises:
            ValueError: If the config requests a backbone other than ResNet9.
            TypeError: If a required config entry is missing or mistyped.
        """
        super().__init__()
        resolved_cfg = self._resolve_path(cfg_path)
        self.cfg_path = str(resolved_cfg)
        self.cfg = config_loader(self.cfg_path)
        model_cfg = self._extract_model_cfg(self.cfg)

        backbone_cfg = self._extract_dict(model_cfg, "backbone_cfg")
        if backbone_cfg.get("type") != "ResNet9":
            raise ValueError(
                "ScoNetDemo currently supports backbone type ResNet9 only."
            )
        self.backbone = ResNet9(
            block=self._extract_str(backbone_cfg, "block"),
            channels=self._extract_int_list(backbone_cfg, "channels"),
            in_channel=self._extract_int(backbone_cfg, "in_channel", default=1),
            layers=self._extract_int_list(backbone_cfg, "layers"),
            strides=self._extract_int_list(backbone_cfg, "strides"),
            maxpool=self._extract_bool(backbone_cfg, "maxpool", default=True),
        )

        fcs_cfg = self._extract_dict(model_cfg, "SeparateFCs")
        bn_cfg = self._extract_dict(model_cfg, "SeparateBNNecks")
        bin_num = self._extract_int_list(model_cfg, "bin_num")

        # Temporal pooling collapses the sequence dim via torch.max.
        self.temporal_pool = cast(TemporalPoolLike, TemporalPool(torch.max))
        self.hpp = cast(HppLike, HorizontalPoolingPyramid(bin_num=bin_num))
        self.fcs = cast(
            FCsLike,
            SeparateFCs(
                parts_num=self._extract_int(fcs_cfg, "parts_num"),
                in_channels=self._extract_int(fcs_cfg, "in_channels"),
                out_channels=self._extract_int(fcs_cfg, "out_channels"),
                norm=self._extract_bool(fcs_cfg, "norm", default=False),
            ),
        )
        self.bn_necks = cast(
            BNNecksLike,
            SeparateBNNecks(
                parts_num=self._extract_int(bn_cfg, "parts_num"),
                in_channels=self._extract_int(bn_cfg, "in_channels"),
                class_num=self._extract_int(bn_cfg, "class_num"),
                norm=self._extract_bool(bn_cfg, "norm", default=True),
                parallel_BN1d=self._extract_bool(bn_cfg, "parallel_BN1d", default=True),
            ),
        )

        self.device = (
            torch.device(device) if device is not None else torch.device("cpu")
        )
        _ = self.to(self.device)
        if checkpoint_path is not None:
            _ = self.load_checkpoint(checkpoint_path)
        _ = self.eval()

    @staticmethod
    def _resolve_path(path: str | Path) -> Path:
        """Return *path* as-is when it exists or is absolute; else repo-root-relative.

        This file is assumed to sit two directory levels below the repo root —
        TODO confirm if the module is ever relocated.
        """
        candidate = Path(path)
        if candidate.is_file():
            return candidate
        if candidate.is_absolute():
            return candidate
        repo_root = Path(__file__).resolve().parents[2]
        return repo_root / candidate

    @staticmethod
    def _extract_model_cfg(cfg: dict[str, object]) -> dict[str, object]:
        """Return the mandatory ``model_cfg`` sub-dict of a loaded config."""
        model_cfg_obj = cfg.get("model_cfg")
        if not isinstance(model_cfg_obj, dict):
            raise TypeError("model_cfg must be a dictionary.")
        return cast(dict[str, object], model_cfg_obj)

    @staticmethod
    def _extract_dict(container: dict[str, object], key: str) -> dict[str, object]:
        """Return ``container[key]`` as a dict or raise ``TypeError``."""
        value = container.get(key)
        if not isinstance(value, dict):
            raise TypeError(f"{key} must be a dictionary.")
        return cast(dict[str, object], value)

    @staticmethod
    def _extract_str(container: dict[str, object], key: str) -> str:
        """Return ``container[key]`` as a str or raise ``TypeError``."""
        value = container.get(key)
        if not isinstance(value, str):
            raise TypeError(f"{key} must be a string.")
        return value

    @staticmethod
    def _extract_int(
        container: dict[str, object], key: str, default: int | None = None
    ) -> int:
        """Return ``container[key]`` (or *default*) as an int or raise ``TypeError``."""
        value = container.get(key, default)
        # bool is a subclass of int; reject it so a YAML `true`/`false` cannot
        # silently masquerade as an integer hyperparameter.
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError(f"{key} must be an int.")
        return value

    @staticmethod
    def _extract_bool(
        container: dict[str, object], key: str, default: bool | None = None
    ) -> bool:
        """Return ``container[key]`` (or *default*) as a bool or raise ``TypeError``."""
        value = container.get(key, default)
        if not isinstance(value, bool):
            raise TypeError(f"{key} must be a bool.")
        return value

    @staticmethod
    def _extract_int_list(container: dict[str, object], key: str) -> list[int]:
        """Return ``container[key]`` as a list of ints or raise ``TypeError``."""
        value = container.get(key)
        if not isinstance(value, list):
            raise TypeError(f"{key} must be a list[int].")
        values = cast(list[object], value)
        # As in _extract_int, exclude bools even though they pass isinstance(int).
        if not all(isinstance(v, int) and not isinstance(v, bool) for v in values):
            raise TypeError(f"{key} must be a list[int].")
        return cast(list[int], values)

    @staticmethod
    def _normalize_state_dict(
        state_dict_obj: dict[object, object],
    ) -> dict[str, Tensor]:
        """Map OpenGait training checkpoint keys onto this module's attribute names.

        Strips an optional ``module.`` DataParallel prefix, then remaps the
        first matching OpenGait prefix to the corresponding submodule name.

        Raises:
            TypeError: If any key is not a str or any value is not a Tensor.
            RuntimeError: If two source keys normalize to the same target key.
        """
        prefix_remap: tuple[tuple[str, str], ...] = (
            ("Backbone.forward_block.", "backbone."),
            ("FCs.", "fcs."),
            ("BNNecks.", "bn_necks."),
        )
        cleaned_state_dict: dict[str, Tensor] = {}
        for key_obj, value_obj in state_dict_obj.items():
            if not isinstance(key_obj, str):
                raise TypeError("Checkpoint state_dict keys must be strings.")
            if not isinstance(value_obj, Tensor):
                raise TypeError("Checkpoint state_dict values must be torch.Tensor.")
            key = key_obj[7:] if key_obj.startswith("module.") else key_obj
            for source_prefix, target_prefix in prefix_remap:
                if key.startswith(source_prefix):
                    key = f"{target_prefix}{key[len(source_prefix) :]}"
                    break
            if key in cleaned_state_dict:
                raise RuntimeError(
                    f"Checkpoint key normalization collision detected for key '{key}'."
                )
            cleaned_state_dict[key] = value_obj
        return cleaned_state_dict

    @jaxtyped(typechecker=beartype)
    def load_checkpoint(
        self,
        checkpoint_path: str | Path,
        map_location: str | torch.device | None = None,
        strict: bool = True,
    ) -> None:
        """Load and normalize a checkpoint, then switch to eval mode.

        Accepts either a bare state dict or an OpenGait-style dict with a
        ``"model"`` entry.

        Raises:
            TypeError: If the checkpoint is not a dict-like state dict.
            RuntimeError: If ``load_state_dict`` rejects the normalized keys.
        """
        resolved_ckpt = self._resolve_path(checkpoint_path)
        # NOTE(security): torch.load deserializes via pickle and can execute
        # arbitrary code — only load checkpoints from trusted sources.
        # `weights_only` is deliberately not set to stay compatible with
        # existing OpenGait checkpoints; consider enabling it once verified.
        checkpoint_obj = cast(
            object,
            torch.load(
                str(resolved_ckpt),
                map_location=map_location if map_location is not None else self.device,
            ),
        )
        state_dict_obj: object = checkpoint_obj
        if isinstance(checkpoint_obj, dict) and "model" in checkpoint_obj:
            state_dict_obj = cast(dict[str, object], checkpoint_obj)["model"]
        if not isinstance(state_dict_obj, dict):
            raise TypeError("Unsupported checkpoint format.")
        cleaned_state_dict = self._normalize_state_dict(
            cast(dict[object, object], state_dict_obj)
        )
        try:
            _ = self.load_state_dict(cleaned_state_dict, strict=strict)
        except RuntimeError as exc:
            raise RuntimeError(
                f"Failed to load ScoNetDemo checkpoint after key normalization from '{resolved_ckpt}'."
            ) from exc
        _ = self.eval()

    def _prepare_sils(self, sils: Tensor) -> Tensor:
        """Coerce silhouettes to float ``[B, 1, S, H, W]`` on the model device.

        Accepts ``[B, S, H, W]`` (channel dim inserted) and the transposed
        ``[B, S, 1, H, W]`` layout (channel moved ahead of the sequence dim).
        """
        if sils.ndim == 4:
            sils = sils.unsqueeze(1)
        elif sils.ndim == 5 and sils.shape[1] != 1 and sils.shape[2] == 1:
            sils = rearrange(sils, "b s c h w -> b c s h w")
        if sils.ndim != 5 or sils.shape[1] != 1:
            raise ValueError("Expected sils shape [B, 1, S, H, W] or [B, S, H, W].")
        return sils.float().to(self.device)

    def _forward_backbone(self, sils: Tensor) -> Tensor:
        """Run the 2D backbone frame-wise and restore the sequence layout.

        Flattens ``[B, C, S, H, W]`` to ``[B*S, C, H, W]`` for the backbone,
        then reshapes the features back to ``[B, C', S, H', W']``.
        """
        batch, channels, seq, height, width = sils.shape
        framewise = sils.transpose(1, 2).reshape(batch * seq, channels, height, width)
        frame_feats = cast(Tensor, self.backbone(framewise))
        _, out_channels, out_h, out_w = frame_feats.shape
        return (
            frame_feats.reshape(batch, seq, out_channels, out_h, out_w)
            .transpose(1, 2)
            .contiguous()
        )

    @override
    @jaxtyped(typechecker=beartype)
    def forward(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> dict[str, Tensor]:
        """Classify silhouette sequences.

        Returns:
            Dict with ``"logits"`` (per-part logits), ``"label"`` (argmax class
            ids over part-averaged logits), and ``"confidence"`` (softmax
            probability of the predicted class).
        """
        with torch.inference_mode():
            prepared_sils = self._prepare_sils(sils)
            outs = self._forward_backbone(prepared_sils)
            # torch.max over the sequence dim returns (values, indices).
            pooled_obj = self.temporal_pool(outs, None, options={"dim": 2})
            if (
                not isinstance(pooled_obj, tuple)
                or not pooled_obj
                or not isinstance(pooled_obj[0], Tensor)
            ):
                raise TypeError("TemporalPool output is invalid.")
            pooled = pooled_obj[0]
            feat = self.hpp(pooled)
            embed_1 = self.fcs(feat)
            _, logits = self.bn_necks(embed_1)
            # Average logits across the parts dim before the argmax.
            mean_logits = logits.mean(dim=-1)
            pred_ids = torch.argmax(mean_logits, dim=-1)
            probs = torch.softmax(mean_logits, dim=-1)
            confidence = torch.gather(
                probs, dim=-1, index=pred_ids.unsqueeze(-1)
            ).squeeze(-1)
            return {"logits": logits, "label": pred_ids, "confidence": confidence}

    @jaxtyped(typechecker=beartype)
    def predict(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> tuple[str, float]:
        """Return ``(label, confidence)`` for a single sequence (batch size 1).

        Raises:
            ValueError: If the batch holds more than one sequence.
        """
        # Go through __call__ (not self.forward) so registered hooks run.
        outputs = cast(dict[str, Tensor], self(sils))
        labels = outputs["label"]
        confidence = outputs["confidence"]
        if labels.numel() != 1:
            raise ValueError("predict expects batch size 1.")
        label_id = int(labels.item())
        return self.LABEL_MAP[label_id], float(confidence.item())