00fcda4fe3
Move demo implementation into opengait_studio, retire Sports2D runtime integration, and align packaging with root-level monorepo dependency management.
313 lines
11 KiB
Python
313 lines
11 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from pathlib import Path
|
|
from typing import ClassVar, Protocol, cast, override
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from beartype import beartype
|
|
from einops import rearrange
|
|
from jaxtyping import Float
|
|
import jaxtyping
|
|
from torch import Tensor
|
|
|
|
from opengait.modeling.backbones.resnet import ResNet9
|
|
from opengait.modeling.modules import (
|
|
HorizontalPoolingPyramid,
|
|
PackSequenceWrapper as TemporalPool,
|
|
SeparateBNNecks,
|
|
SeparateFCs,
|
|
)
|
|
from opengait.utils import common as common_utils
|
|
|
|
|
|
# The upstream `jaxtyping.jaxtyped` and `opengait.utils.common.config_loader`
# callables are untyped; cast them to precise signatures once here so strict
# type checkers accept every later use in this module.
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)

# Loads a YAML config path into a plain dict (OpenGait helper).
ConfigLoader = Callable[[str], dict[str, object]]
config_loader = cast(ConfigLoader, common_utils.config_loader)
|
|
|
|
|
|
class TemporalPoolLike(Protocol):
    """Structural type for the wrapped temporal pooling module.

    Matches the call shape used in ``ScoNetDemo.forward``: pools ``seqs``
    over a dimension chosen via ``dim``/``options``. The return is left as
    ``object`` because the wrapper's output is validated at the call site.
    """

    def __call__(
        self,
        seqs: Tensor,
        seqL: object,
        dim: int = 2,
        options: dict[str, int] | None = None,
    ) -> object: ...
|
|
|
|
|
|
class HppLike(Protocol):
    """Structural type for the horizontal pooling pyramid: Tensor in, Tensor out."""

    def __call__(self, x: Tensor) -> Tensor: ...
|
|
|
|
|
|
class FCsLike(Protocol):
    """Structural type for the separate fully-connected head: Tensor in, Tensor out."""

    def __call__(self, x: Tensor) -> Tensor: ...
|
|
|
|
|
|
class BNNecksLike(Protocol):
    """Structural type for the BNNeck head: returns an (embeddings, logits) pair."""

    def __call__(self, x: Tensor) -> tuple[Tensor, Tensor]: ...
|
|
|
|
|
|
class ScoNetDemo(nn.Module):
    """Inference-only ScoNet demo model built from an OpenGait-style config.

    Wires a ResNet9 backbone, max temporal pooling, a horizontal pooling
    pyramid, and SeparateFCs / SeparateBNNecks heads, then exposes
    :meth:`forward` and :meth:`predict` over silhouette sequences shaped
    ``[B, 1, S, 64, 44]``. Checkpoints trained with OpenGait are supported
    via key normalization in :meth:`load_checkpoint`.
    """

    # Maps the classifier's argmax class id to a human-readable label.
    LABEL_MAP: ClassVar[dict[int, str]] = {0: "negative", 1: "neutral", 2: "positive"}

    cfg_path: str
    cfg: dict[str, object]
    backbone: ResNet9
    temporal_pool: TemporalPoolLike
    hpp: HppLike
    fcs: FCsLike
    bn_necks: BNNecksLike
    device: torch.device

    @jaxtyped(typechecker=beartype)
    def __init__(
        self,
        cfg_path: str | Path = "configs/sconet/sconet_scoliosis1k.yaml",
        checkpoint_path: str | Path | None = None,
        device: str | torch.device | None = None,
    ) -> None:
        """Build the model from *cfg_path* and optionally load a checkpoint.

        Args:
            cfg_path: YAML config path; resolved against the repo root when
                not found as given (see :meth:`_resolve_path`).
            checkpoint_path: Optional weights file loaded after construction.
            device: Target device; defaults to CPU when ``None``.

        Raises:
            ValueError: If the configured backbone type is not ``ResNet9``.
            TypeError: If required config entries are missing or mistyped.
        """
        super().__init__()
        resolved_cfg = self._resolve_path(cfg_path)
        self.cfg_path = str(resolved_cfg)

        self.cfg = config_loader(self.cfg_path)

        model_cfg = self._extract_model_cfg(self.cfg)
        backbone_cfg = self._extract_dict(model_cfg, "backbone_cfg")

        if backbone_cfg.get("type") != "ResNet9":
            raise ValueError(
                "ScoNetDemo currently supports backbone type ResNet9 only."
            )

        self.backbone = ResNet9(
            block=self._extract_str(backbone_cfg, "block"),
            channels=self._extract_int_list(backbone_cfg, "channels"),
            in_channel=self._extract_int(backbone_cfg, "in_channel", default=1),
            layers=self._extract_int_list(backbone_cfg, "layers"),
            strides=self._extract_int_list(backbone_cfg, "strides"),
            maxpool=self._extract_bool(backbone_cfg, "maxpool", default=True),
        )

        fcs_cfg = self._extract_dict(model_cfg, "SeparateFCs")
        bn_cfg = self._extract_dict(model_cfg, "SeparateBNNecks")
        bin_num = self._extract_int_list(model_cfg, "bin_num")

        # torch.max pools features over the sequence axis (axis set per call
        # via ``options`` in forward).
        self.temporal_pool = cast(TemporalPoolLike, TemporalPool(torch.max))
        self.hpp = cast(HppLike, HorizontalPoolingPyramid(bin_num=bin_num))
        self.fcs = cast(
            FCsLike,
            SeparateFCs(
                parts_num=self._extract_int(fcs_cfg, "parts_num"),
                in_channels=self._extract_int(fcs_cfg, "in_channels"),
                out_channels=self._extract_int(fcs_cfg, "out_channels"),
                norm=self._extract_bool(fcs_cfg, "norm", default=False),
            ),
        )
        self.bn_necks = cast(
            BNNecksLike,
            SeparateBNNecks(
                parts_num=self._extract_int(bn_cfg, "parts_num"),
                in_channels=self._extract_int(bn_cfg, "in_channels"),
                class_num=self._extract_int(bn_cfg, "class_num"),
                norm=self._extract_bool(bn_cfg, "norm", default=True),
                parallel_BN1d=self._extract_bool(bn_cfg, "parallel_BN1d", default=True),
            ),
        )

        self.device = (
            torch.device(device) if device is not None else torch.device("cpu")
        )
        _ = self.to(self.device)

        if checkpoint_path is not None:
            _ = self.load_checkpoint(checkpoint_path)

        # Demo usage is inference-only; start in eval mode.
        _ = self.eval()

    @staticmethod
    def _resolve_path(path: str | Path) -> Path:
        """Resolve *path* as given, or relative to the repository root.

        Existing files and absolute paths are returned unchanged; anything
        else is joined onto the repo root inferred from this file's location.
        """
        candidate = Path(path)
        if candidate.is_file():
            return candidate
        if candidate.is_absolute():
            return candidate
        # NOTE(review): assumes this module lives exactly two directories
        # below the repo root — re-check if the file is moved.
        repo_root = Path(__file__).resolve().parents[2]
        return repo_root / candidate

    @staticmethod
    def _extract_model_cfg(cfg: dict[str, object]) -> dict[str, object]:
        """Return the required ``model_cfg`` section of the loaded config."""
        model_cfg_obj = cfg.get("model_cfg")
        if not isinstance(model_cfg_obj, dict):
            raise TypeError("model_cfg must be a dictionary.")
        return cast(dict[str, object], model_cfg_obj)

    @staticmethod
    def _extract_dict(container: dict[str, object], key: str) -> dict[str, object]:
        """Return ``container[key]``, requiring it to be a dict."""
        value = container.get(key)
        if not isinstance(value, dict):
            raise TypeError(f"{key} must be a dictionary.")
        return cast(dict[str, object], value)

    @staticmethod
    def _extract_str(container: dict[str, object], key: str) -> str:
        """Return ``container[key]``, requiring it to be a string."""
        value = container.get(key)
        if not isinstance(value, str):
            raise TypeError(f"{key} must be a string.")
        return value

    @staticmethod
    def _extract_int(
        container: dict[str, object], key: str, default: int | None = None
    ) -> int:
        """Return ``container[key]`` (or *default*), requiring a true int.

        ``bool`` is a subclass of ``int``, so a plain ``isinstance`` check
        would silently accept YAML booleans where an integer is required;
        reject them explicitly.
        """
        value = container.get(key, default)
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError(f"{key} must be an int.")
        return value

    @staticmethod
    def _extract_bool(
        container: dict[str, object], key: str, default: bool | None = None
    ) -> bool:
        """Return ``container[key]`` (or *default*), requiring a bool."""
        value = container.get(key, default)
        if not isinstance(value, bool):
            raise TypeError(f"{key} must be a bool.")
        return value

    @staticmethod
    def _extract_int_list(container: dict[str, object], key: str) -> list[int]:
        """Return ``container[key]``, requiring a list of true ints (no bools)."""
        value = container.get(key)
        if not isinstance(value, list):
            raise TypeError(f"{key} must be a list[int].")
        values = cast(list[object], value)
        # Exclude bools: isinstance(True, int) is True, but a YAML boolean is
        # never a valid channel/layer/stride count.
        if not all(
            isinstance(v, int) and not isinstance(v, bool) for v in values
        ):
            raise TypeError(f"{key} must be a list[int].")
        return cast(list[int], values)

    @staticmethod
    def _normalize_state_dict(
        state_dict_obj: dict[object, object],
    ) -> dict[str, Tensor]:
        """Map OpenGait checkpoint keys onto this module's attribute names.

        Strips a leading ``module.`` prefix and remaps OpenGait's layer
        prefixes to ``backbone.`` / ``fcs.`` / ``bn_necks.``.

        Raises:
            TypeError: On non-string keys or non-tensor values.
            RuntimeError: If two source keys normalize to the same target.
        """
        prefix_remap: tuple[tuple[str, str], ...] = (
            ("Backbone.forward_block.", "backbone."),
            ("FCs.", "fcs."),
            ("BNNecks.", "bn_necks."),
        )
        cleaned_state_dict: dict[str, Tensor] = {}
        for key_obj, value_obj in state_dict_obj.items():
            if not isinstance(key_obj, str):
                raise TypeError("Checkpoint state_dict keys must be strings.")
            if not isinstance(value_obj, Tensor):
                raise TypeError("Checkpoint state_dict values must be torch.Tensor.")
            key = key_obj[7:] if key_obj.startswith("module.") else key_obj
            # Only the first matching prefix is rewritten; later remaps must
            # not re-trigger on the already-rewritten key.
            for source_prefix, target_prefix in prefix_remap:
                if key.startswith(source_prefix):
                    key = f"{target_prefix}{key[len(source_prefix) :]}"
                    break
            if key in cleaned_state_dict:
                raise RuntimeError(
                    f"Checkpoint key normalization collision detected for key '{key}'."
                )
            cleaned_state_dict[key] = value_obj
        return cleaned_state_dict

    @jaxtyped(typechecker=beartype)
    def load_checkpoint(
        self,
        checkpoint_path: str | Path,
        map_location: str | torch.device | None = None,
        strict: bool = True,
    ) -> None:
        """Load weights from *checkpoint_path*, normalizing OpenGait keys.

        Accepts either a raw state dict or a checkpoint dict carrying the
        state dict under a ``"model"`` entry. Leaves the module in eval mode.

        Raises:
            TypeError: If the checkpoint is not a (nested) state dict.
            RuntimeError: If the normalized keys do not match this module.
        """
        resolved_ckpt = self._resolve_path(checkpoint_path)
        # SECURITY: torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        checkpoint_obj = cast(
            object,
            torch.load(
                str(resolved_ckpt),
                map_location=map_location if map_location is not None else self.device,
            ),
        )

        state_dict_obj: object = checkpoint_obj
        if isinstance(checkpoint_obj, dict) and "model" in checkpoint_obj:
            state_dict_obj = cast(dict[str, object], checkpoint_obj)["model"]

        if not isinstance(state_dict_obj, dict):
            raise TypeError("Unsupported checkpoint format.")

        cleaned_state_dict = self._normalize_state_dict(
            cast(dict[object, object], state_dict_obj)
        )
        try:
            _ = self.load_state_dict(cleaned_state_dict, strict=strict)
        except RuntimeError as exc:
            raise RuntimeError(
                f"Failed to load ScoNetDemo checkpoint after key normalization from '{resolved_ckpt}'."
            ) from exc
        _ = self.eval()

    def _prepare_sils(self, sils: Tensor) -> Tensor:
        """Coerce silhouettes to float ``[B, 1, S, H, W]`` on ``self.device``.

        Accepts ``[B, S, H, W]`` (channel axis inserted) or a
        ``[B, S, 1, H, W]`` layout (sequence/channel axes swapped).
        """
        if sils.ndim == 4:
            sils = sils.unsqueeze(1)
        elif sils.ndim == 5 and sils.shape[1] != 1 and sils.shape[2] == 1:
            sils = rearrange(sils, "b s c h w -> b c s h w")

        if sils.ndim != 5 or sils.shape[1] != 1:
            raise ValueError("Expected sils shape [B, 1, S, H, W] or [B, S, H, W].")

        return sils.float().to(self.device)

    def _forward_backbone(self, sils: Tensor) -> Tensor:
        """Run the 2D backbone frame-by-frame and restore the sequence layout.

        Flattens ``[B, C, S, H, W]`` to ``[B*S, C, H, W]`` for the CNN, then
        reshapes the output back to ``[B, C', S, H', W']``.
        """
        batch, channels, seq, height, width = sils.shape
        framewise = sils.transpose(1, 2).reshape(batch * seq, channels, height, width)
        frame_feats = cast(Tensor, self.backbone(framewise))
        _, out_channels, out_h, out_w = frame_feats.shape
        return (
            frame_feats.reshape(batch, seq, out_channels, out_h, out_w)
            .transpose(1, 2)
            .contiguous()
        )

    @override
    @jaxtyped(typechecker=beartype)
    def forward(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> dict[str, Tensor]:
        """Classify silhouette sequences.

        Returns a dict with raw ``logits``, argmax ``label`` ids, and the
        softmax ``confidence`` of the chosen label; label/confidence are
        computed from logits averaged over the last axis (presumably the
        part axis — confirm against SeparateBNNecks' output layout). Runs
        under ``torch.inference_mode``.
        """
        with torch.inference_mode():
            prepared_sils = self._prepare_sils(sils)
            outs = self._forward_backbone(prepared_sils)

            # Max-pool over the temporal axis; the wrapper returns a tuple
            # whose first element is the pooled tensor.
            pooled_obj = self.temporal_pool(outs, None, options={"dim": 2})
            if (
                not isinstance(pooled_obj, tuple)
                or not pooled_obj
                or not isinstance(pooled_obj[0], Tensor)
            ):
                raise TypeError("TemporalPool output is invalid.")
            pooled = pooled_obj[0]

            feat = self.hpp(pooled)
            embed_1 = self.fcs(feat)
            _, logits = self.bn_necks(embed_1)

            # Average logits over the last axis before taking the prediction.
            mean_logits = logits.mean(dim=-1)
            pred_ids = torch.argmax(mean_logits, dim=-1)
            probs = torch.softmax(mean_logits, dim=-1)
            confidence = torch.gather(
                probs, dim=-1, index=pred_ids.unsqueeze(-1)
            ).squeeze(-1)

            return {"logits": logits, "label": pred_ids, "confidence": confidence}

    @jaxtyped(typechecker=beartype)
    def predict(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> tuple[str, float]:
        """Return ``(label_name, confidence)`` for a single-item batch.

        Raises:
            ValueError: If the batch contains more than one sequence.
        """
        outputs = cast(dict[str, Tensor], self.forward(sils))
        labels = outputs["label"]
        confidence = outputs["confidence"]

        if labels.numel() != 1:
            raise ValueError("predict expects batch size 1.")

        label_id = int(labels.item())
        return self.LABEL_MAP[label_id], float(confidence.item())
|