b24644f16e
Add the full demo runtime stack for single-person scoliosis inference, including input adapters, silhouette preprocessing, temporal windowing, ScoNet wrapper, result publishing, and click-based CLI orchestration. This commit captures the executable pipeline behavior independently from tests and planning artifacts for clearer review and rollback.
318 lines · 11 KiB · Python
from __future__ import annotations

from collections.abc import Callable
from pathlib import Path
import sys
from typing import ClassVar, Protocol, cast, override

import torch
import torch.nn as nn
from beartype import beartype
from einops import rearrange
from jaxtyping import Float
import jaxtyping
from torch import Tensor

# Make the OpenGait package importable when this demo runs as a loose script:
# the package root sits one directory above this file.
_OPENGAIT_PACKAGE_ROOT = Path(__file__).resolve().parents[1]

if str(_OPENGAIT_PACKAGE_ROOT) not in sys.path:
    sys.path.insert(0, str(_OPENGAIT_PACKAGE_ROOT))

# NOTE: these imports depend on the sys.path insertion above and therefore
# must stay below it.
from opengait.modeling.backbones.resnet import ResNet9
from opengait.modeling.modules import (
    HorizontalPoolingPyramid,
    PackSequenceWrapper as TemporalPool,
    SeparateBNNecks,
    SeparateFCs,
)
from opengait.utils import common as common_utils


# Typed facades over untyped third-party callables so the rest of the module
# passes strict type checking without sprinkling per-call-site casts.
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
ConfigLoader = Callable[[str], dict[str, object]]
config_loader = cast(ConfigLoader, common_utils.config_loader)
|
|
|
|
|
|
class TemporalPoolLike(Protocol):
    """Structural type for the temporal pooling wrapper (PackSequenceWrapper).

    Mirrors the call signature used in ``ScoNetDemo.forward``: pools ``seqs``
    over the dimension named in ``options`` (``seqL`` carries per-sequence
    lengths; ``None`` here since the demo feeds fixed-length batches).
    """

    def __call__(
        self,
        seqs: Tensor,
        seqL: object,
        dim: int = 2,
        options: dict[str, int] | None = None,
    ) -> object: ...
|
|
|
|
|
|
class HppLike(Protocol):
    """Structural type for HorizontalPoolingPyramid: Tensor in, Tensor out."""

    def __call__(self, x: Tensor) -> Tensor: ...
|
|
|
|
|
|
class FCsLike(Protocol):
    """Structural type for SeparateFCs: Tensor in, Tensor out."""

    def __call__(self, x: Tensor) -> Tensor: ...
|
|
|
|
|
|
class BNNecksLike(Protocol):
    """Structural type for SeparateBNNecks: returns (embeddings, logits)."""

    def __call__(self, x: Tensor) -> tuple[Tensor, Tensor]: ...
|
|
|
|
|
|
class ScoNetDemo(nn.Module):
    """Inference-only wrapper around the ScoNet scoliosis classifier.

    Reconstructs the ScoNet architecture (ResNet9 backbone -> temporal max
    pooling -> horizontal pooling pyramid -> separate FCs -> separate
    BNNecks) from an OpenGait YAML config, optionally restores a checkpoint
    with key normalization, and exposes ``forward``/``predict`` over
    silhouette sequences.
    """

    # Class-id -> human-readable label returned by ``predict``.
    LABEL_MAP: ClassVar[dict[int, str]] = {0: "negative", 1: "neutral", 2: "positive"}

    # Resolved config path and parsed YAML contents.
    cfg_path: str
    cfg: dict[str, object]
    # Architecture components; typed via structural protocols so the
    # untyped OpenGait modules can be used under strict type checking.
    backbone: ResNet9
    temporal_pool: TemporalPoolLike
    hpp: HppLike
    fcs: FCsLike
    bn_necks: BNNecksLike
    device: torch.device

    @jaxtyped(typechecker=beartype)
    def __init__(
        self,
        cfg_path: str | Path = "configs/sconet/sconet_scoliosis1k.yaml",
        checkpoint_path: str | Path | None = None,
        device: str | torch.device | None = None,
    ) -> None:
        """Build the model from ``cfg_path`` and optionally load weights.

        Args:
            cfg_path: YAML config file; relative paths resolve against the
                repository root (see ``_resolve_path``).
            checkpoint_path: Optional checkpoint restored via
                ``load_checkpoint`` after construction.
            device: Target device; defaults to CPU when ``None``.

        Raises:
            ValueError: If the configured backbone type is not ``ResNet9``.
            TypeError: If a required config entry has an unexpected type.
        """
        super().__init__()
        resolved_cfg = self._resolve_path(cfg_path)
        self.cfg_path = str(resolved_cfg)

        self.cfg = config_loader(self.cfg_path)

        model_cfg = self._extract_model_cfg(self.cfg)
        backbone_cfg = self._extract_dict(model_cfg, "backbone_cfg")

        # The demo hard-wires the ResNet9 constructor below, so refuse any
        # other backbone type rather than mis-instantiating it.
        if backbone_cfg.get("type") != "ResNet9":
            raise ValueError(
                "ScoNetDemo currently supports backbone type ResNet9 only."
            )

        self.backbone = ResNet9(
            block=self._extract_str(backbone_cfg, "block"),
            channels=self._extract_int_list(backbone_cfg, "channels"),
            in_channel=self._extract_int(backbone_cfg, "in_channel", default=1),
            layers=self._extract_int_list(backbone_cfg, "layers"),
            strides=self._extract_int_list(backbone_cfg, "strides"),
            maxpool=self._extract_bool(backbone_cfg, "maxpool", default=True),
        )

        fcs_cfg = self._extract_dict(model_cfg, "SeparateFCs")
        bn_cfg = self._extract_dict(model_cfg, "SeparateBNNecks")
        bin_num = self._extract_int_list(model_cfg, "bin_num")

        # Temporal pooling reduces the sequence dimension with max, matching
        # the training-time ScoNet head.
        self.temporal_pool = cast(TemporalPoolLike, TemporalPool(torch.max))
        self.hpp = cast(HppLike, HorizontalPoolingPyramid(bin_num=bin_num))
        self.fcs = cast(
            FCsLike,
            SeparateFCs(
                parts_num=self._extract_int(fcs_cfg, "parts_num"),
                in_channels=self._extract_int(fcs_cfg, "in_channels"),
                out_channels=self._extract_int(fcs_cfg, "out_channels"),
                norm=self._extract_bool(fcs_cfg, "norm", default=False),
            ),
        )
        self.bn_necks = cast(
            BNNecksLike,
            SeparateBNNecks(
                parts_num=self._extract_int(bn_cfg, "parts_num"),
                in_channels=self._extract_int(bn_cfg, "in_channels"),
                class_num=self._extract_int(bn_cfg, "class_num"),
                norm=self._extract_bool(bn_cfg, "norm", default=True),
                parallel_BN1d=self._extract_bool(bn_cfg, "parallel_BN1d", default=True),
            ),
        )

        self.device = (
            torch.device(device) if device is not None else torch.device("cpu")
        )
        _ = self.to(self.device)

        if checkpoint_path is not None:
            _ = self.load_checkpoint(checkpoint_path)

        # Demo is inference-only; freeze BN/dropout behavior immediately.
        _ = self.eval()

    @staticmethod
    def _resolve_path(path: str | Path) -> Path:
        """Resolve ``path`` to an on-disk location.

        Existing files and absolute paths are returned untouched; any other
        relative path is joined onto the repository root (two directories
        above this file).
        """
        candidate = Path(path)
        if candidate.is_file():
            return candidate
        if candidate.is_absolute():
            return candidate
        repo_root = Path(__file__).resolve().parents[2]
        return repo_root / candidate

    @staticmethod
    def _extract_model_cfg(cfg: dict[str, object]) -> dict[str, object]:
        """Return the ``model_cfg`` section, or raise ``TypeError``."""
        model_cfg_obj = cfg.get("model_cfg")
        if not isinstance(model_cfg_obj, dict):
            raise TypeError("model_cfg must be a dictionary.")
        return cast(dict[str, object], model_cfg_obj)

    @staticmethod
    def _extract_dict(container: dict[str, object], key: str) -> dict[str, object]:
        """Return ``container[key]`` as a dict, or raise ``TypeError``."""
        value = container.get(key)
        if not isinstance(value, dict):
            raise TypeError(f"{key} must be a dictionary.")
        return cast(dict[str, object], value)

    @staticmethod
    def _extract_str(container: dict[str, object], key: str) -> str:
        """Return ``container[key]`` as a str, or raise ``TypeError``."""
        value = container.get(key)
        if not isinstance(value, str):
            raise TypeError(f"{key} must be a string.")
        return value

    @staticmethod
    def _extract_int(
        container: dict[str, object], key: str, default: int | None = None
    ) -> int:
        """Return ``container[key]`` (or ``default``) as a strict int.

        Explicitly rejects ``bool``: it subclasses ``int``, so a YAML
        ``true``/``false`` would otherwise pass the isinstance check and leak
        ``True``/``False`` into numeric model parameters. A missing key with
        ``default=None`` also raises ``TypeError``.
        """
        value = container.get(key, default)
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError(f"{key} must be an int.")
        return value

    @staticmethod
    def _extract_bool(
        container: dict[str, object], key: str, default: bool | None = None
    ) -> bool:
        """Return ``container[key]`` (or ``default``) as a strict bool.

        ``isinstance(..., bool)`` already rejects 0/1 ints; a missing key
        with ``default=None`` raises ``TypeError``.
        """
        value = container.get(key, default)
        if not isinstance(value, bool):
            raise TypeError(f"{key} must be a bool.")
        return value

    @staticmethod
    def _extract_int_list(container: dict[str, object], key: str) -> list[int]:
        """Return ``container[key]`` as a list of strict ints.

        Bools are rejected element-wise for the same subclassing reason as
        ``_extract_int``.
        """
        value = container.get(key)
        if not isinstance(value, list):
            raise TypeError(f"{key} must be a list[int].")
        values = cast(list[object], value)
        if not all(
            isinstance(v, int) and not isinstance(v, bool) for v in values
        ):
            raise TypeError(f"{key} must be a list[int].")
        return cast(list[int], values)

    @staticmethod
    def _normalize_state_dict(
        state_dict_obj: dict[object, object],
    ) -> dict[str, Tensor]:
        """Remap training-time checkpoint keys onto this module's layout.

        Strips an optional DataParallel ``module.`` prefix, then rewrites the
        first matching training-time prefix (``Backbone.forward_block.``,
        ``FCs.``, ``BNNecks.``) to the corresponding attribute name here.

        Raises:
            TypeError: If keys are not strings or values are not tensors.
            RuntimeError: If two source keys normalize to the same target.
        """
        prefix_remap: tuple[tuple[str, str], ...] = (
            ("Backbone.forward_block.", "backbone."),
            ("FCs.", "fcs."),
            ("BNNecks.", "bn_necks."),
        )
        cleaned_state_dict: dict[str, Tensor] = {}
        for key_obj, value_obj in state_dict_obj.items():
            if not isinstance(key_obj, str):
                raise TypeError("Checkpoint state_dict keys must be strings.")
            if not isinstance(value_obj, Tensor):
                raise TypeError("Checkpoint state_dict values must be torch.Tensor.")
            key = key_obj[7:] if key_obj.startswith("module.") else key_obj
            for source_prefix, target_prefix in prefix_remap:
                if key.startswith(source_prefix):
                    key = f"{target_prefix}{key[len(source_prefix) :]}"
                    break
            # Collisions would silently drop a tensor; fail loudly instead.
            if key in cleaned_state_dict:
                raise RuntimeError(
                    f"Checkpoint key normalization collision detected for key '{key}'."
                )
            cleaned_state_dict[key] = value_obj
        return cleaned_state_dict

    @jaxtyped(typechecker=beartype)
    def load_checkpoint(
        self,
        checkpoint_path: str | Path,
        map_location: str | torch.device | None = None,
        strict: bool = True,
    ) -> None:
        """Load and normalize a checkpoint, then switch to eval mode.

        Accepts either a raw state dict or an OpenGait-style dict with a
        ``"model"`` entry.

        Args:
            checkpoint_path: Checkpoint file; resolved via ``_resolve_path``.
            map_location: Device remap for ``torch.load``; defaults to
                ``self.device``.
            strict: Passed through to ``load_state_dict``.

        Raises:
            TypeError: If the checkpoint payload is not a dict.
            RuntimeError: If loading fails after key normalization.
        """
        resolved_ckpt = self._resolve_path(checkpoint_path)
        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects — only load trusted checkpoints here.
        checkpoint_obj = cast(
            object,
            torch.load(
                str(resolved_ckpt),
                map_location=map_location if map_location is not None else self.device,
            ),
        )

        state_dict_obj: object = checkpoint_obj
        if isinstance(checkpoint_obj, dict) and "model" in checkpoint_obj:
            state_dict_obj = cast(dict[str, object], checkpoint_obj)["model"]

        if not isinstance(state_dict_obj, dict):
            raise TypeError("Unsupported checkpoint format.")

        cleaned_state_dict = self._normalize_state_dict(
            cast(dict[object, object], state_dict_obj)
        )
        try:
            _ = self.load_state_dict(cleaned_state_dict, strict=strict)
        except RuntimeError as exc:
            # Re-raise with the resolved path so the failing file is obvious.
            raise RuntimeError(
                f"Failed to load ScoNetDemo checkpoint after key normalization from '{resolved_ckpt}'."
            ) from exc
        _ = self.eval()

    def _prepare_sils(self, sils: Tensor) -> Tensor:
        """Coerce silhouettes to float [B, 1, S, H, W] on ``self.device``.

        Accepts [B, S, H, W] (channel dim inserted) or [B, S, 1, H, W]
        (channel/sequence axes swapped via einops).
        """
        if sils.ndim == 4:
            sils = sils.unsqueeze(1)
        elif sils.ndim == 5 and sils.shape[1] != 1 and sils.shape[2] == 1:
            sils = rearrange(sils, "b s c h w -> b c s h w")

        if sils.ndim != 5 or sils.shape[1] != 1:
            raise ValueError("Expected sils shape [B, 1, S, H, W] or [B, S, H, W].")

        return sils.float().to(self.device)

    def _forward_backbone(self, sils: Tensor) -> Tensor:
        """Run the backbone frame-wise and restore [B, C_out, S, H', W'].

        Frames are folded into the batch dimension so the 2D backbone sees
        [B*S, C, H, W], then unfolded and transposed back.
        """
        batch, channels, seq, height, width = sils.shape
        framewise = sils.transpose(1, 2).reshape(batch * seq, channels, height, width)
        frame_feats = cast(Tensor, self.backbone(framewise))
        _, out_channels, out_h, out_w = frame_feats.shape
        return (
            frame_feats.reshape(batch, seq, out_channels, out_h, out_w)
            .transpose(1, 2)
            .contiguous()
        )

    @override
    @jaxtyped(typechecker=beartype)
    def forward(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> dict[str, Tensor]:
        """Classify silhouette sequences under ``torch.inference_mode``.

        Args:
            sils: Silhouettes shaped [B, 1, S, 64, 44] (enforced by jaxtyping).

        Returns:
            Dict with ``"logits"`` (per-part class logits), ``"label"``
            (argmax class ids per sample), and ``"confidence"`` (softmax
            probability of the predicted class).
        """
        with torch.inference_mode():
            prepared_sils = self._prepare_sils(sils)
            outs = self._forward_backbone(prepared_sils)

            # PackSequenceWrapper returns a tuple; the pooled tensor is the
            # first element. Validate defensively since the wrapper is untyped.
            pooled_obj = self.temporal_pool(outs, None, options={"dim": 2})
            if (
                not isinstance(pooled_obj, tuple)
                or not pooled_obj
                or not isinstance(pooled_obj[0], Tensor)
            ):
                raise TypeError("TemporalPool output is invalid.")
            pooled = pooled_obj[0]

            feat = self.hpp(pooled)
            embed_1 = self.fcs(feat)
            _, logits = self.bn_necks(embed_1)

            # Average logits over the last axis (presumably the parts
            # dimension of SeparateBNNecks — TODO confirm) before argmax.
            mean_logits = logits.mean(dim=-1)
            pred_ids = torch.argmax(mean_logits, dim=-1)
            probs = torch.softmax(mean_logits, dim=-1)
            confidence = torch.gather(
                probs, dim=-1, index=pred_ids.unsqueeze(-1)
            ).squeeze(-1)

            return {"logits": logits, "label": pred_ids, "confidence": confidence}

    @jaxtyped(typechecker=beartype)
    def predict(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> tuple[str, float]:
        """Return ``(label_name, confidence)`` for a single-sample batch.

        Raises:
            ValueError: If the batch contains more than one sample.
        """
        outputs = cast(dict[str, Tensor], self.forward(sils))
        labels = outputs["label"]
        confidence = outputs["confidence"]

        if labels.numel() != 1:
            raise ValueError("predict expects batch size 1.")

        label_id = int(labels.item())
        return self.LABEL_MAP[label_id], float(confidence.item())