Files
OpenGait/opengait/demo/sconet_demo.py
T
crosstyan b24644f16e feat(demo): implement ScoNet real-time pipeline runtime
Add the full demo runtime stack for single-person scoliosis inference, including input adapters, silhouette preprocessing, temporal windowing, ScoNet wrapper, result publishing, and click-based CLI orchestration. This commit captures the executable pipeline behavior independently from tests and planning artifacts for clearer review and rollback.
2026-02-27 09:59:04 +08:00

318 lines
11 KiB
Python

from __future__ import annotations
from collections.abc import Callable
from pathlib import Path
import sys
from typing import ClassVar, Protocol, cast, override
import torch
import torch.nn as nn
from beartype import beartype
from einops import rearrange
from jaxtyping import Float
import jaxtyping
from torch import Tensor
# Make the OpenGait package importable when this demo runs as a script:
# prepend the package root (two directory levels above this file) to
# sys.path exactly once.
_OPENGAIT_PACKAGE_ROOT = Path(__file__).resolve().parents[1]
_package_root_str = str(_OPENGAIT_PACKAGE_ROOT)
if _package_root_str not in sys.path:
    sys.path.insert(0, _package_root_str)
from opengait.modeling.backbones.resnet import ResNet9
from opengait.modeling.modules import (
HorizontalPoolingPyramid,
PackSequenceWrapper as TemporalPool,
SeparateBNNecks,
SeparateFCs,
)
from opengait.utils import common as common_utils
# Typed views over untyped third-party callables so the rest of the module
# type-checks cleanly.  `jaxtyped` is jaxtyping's decorator factory;
# `config_loader` is OpenGait's YAML-config reader.
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
ConfigLoader = Callable[[str], dict[str, object]]
config_loader = cast(ConfigLoader, common_utils.config_loader)
class TemporalPoolLike(Protocol):
    """Structural type for the temporal pooling wrapper.

    Matches the call signature of OpenGait's ``PackSequenceWrapper``
    (imported here as ``TemporalPool``); the return is left as ``object``
    because the wrapper's output shape varies — callers must narrow it.
    """

    def __call__(
        self,
        seqs: Tensor,
        seqL: object,
        dim: int = 2,
        options: dict[str, int] | None = None,
    ) -> object: ...
class HppLike(Protocol):
    """Structural type for ``HorizontalPoolingPyramid``: tensor in, tensor out."""

    def __call__(self, x: Tensor) -> Tensor: ...
class FCsLike(Protocol):
    """Structural type for ``SeparateFCs``: per-part embedding, tensor in/out."""

    def __call__(self, x: Tensor) -> Tensor: ...
class BNNecksLike(Protocol):
    """Structural type for ``SeparateBNNecks``: returns (features, logits)."""

    def __call__(self, x: Tensor) -> tuple[Tensor, Tensor]: ...
class ScoNetDemo(nn.Module):
    """Standalone inference wrapper for ScoNet scoliosis classification.

    Rebuilds the ScoNet graph from an OpenGait YAML config — a ResNet9
    backbone, temporal max pooling, a horizontal pooling pyramid, separate
    per-part FCs and separate BN necks — optionally loads a checkpoint, and
    classifies silhouette sequences shaped ``[B, 1, S, 64, 44]`` into the
    labels listed in ``LABEL_MAP``.
    """

    # Class index -> human-readable label returned by ``predict``.
    LABEL_MAP: ClassVar[dict[int, str]] = {0: "negative", 1: "neutral", 2: "positive"}

    cfg_path: str
    cfg: dict[str, object]
    backbone: ResNet9
    temporal_pool: TemporalPoolLike
    hpp: HppLike
    fcs: FCsLike
    bn_necks: BNNecksLike
    device: torch.device

    @jaxtyped(typechecker=beartype)
    def __init__(
        self,
        cfg_path: str | Path = "configs/sconet/sconet_scoliosis1k.yaml",
        checkpoint_path: str | Path | None = None,
        device: str | torch.device | None = None,
    ) -> None:
        """Build the network from ``cfg_path``, move to ``device``, load weights.

        Args:
            cfg_path: OpenGait YAML config path; resolved against the repo
                root when not found as given (see ``_resolve_path``).
            checkpoint_path: Optional checkpoint handed to ``load_checkpoint``.
            device: Target device; defaults to CPU when ``None``.

        Raises:
            ValueError: If the config's backbone type is not ``ResNet9``.
            TypeError: If a required config entry is missing or mistyped.
        """
        super().__init__()
        resolved_cfg = self._resolve_path(cfg_path)
        self.cfg_path = str(resolved_cfg)
        self.cfg = config_loader(self.cfg_path)
        model_cfg = self._extract_model_cfg(self.cfg)
        backbone_cfg = self._extract_dict(model_cfg, "backbone_cfg")
        # Only the ResNet9 layout is wired up below; fail fast otherwise.
        if backbone_cfg.get("type") != "ResNet9":
            raise ValueError(
                "ScoNetDemo currently supports backbone type ResNet9 only."
            )
        self.backbone = ResNet9(
            block=self._extract_str(backbone_cfg, "block"),
            channels=self._extract_int_list(backbone_cfg, "channels"),
            in_channel=self._extract_int(backbone_cfg, "in_channel", default=1),
            layers=self._extract_int_list(backbone_cfg, "layers"),
            strides=self._extract_int_list(backbone_cfg, "strides"),
            maxpool=self._extract_bool(backbone_cfg, "maxpool", default=True),
        )
        fcs_cfg = self._extract_dict(model_cfg, "SeparateFCs")
        bn_cfg = self._extract_dict(model_cfg, "SeparateBNNecks")
        bin_num = self._extract_int_list(model_cfg, "bin_num")
        # Max over the temporal axis; the axis is selected at call time via
        # options={"dim": 2} in ``forward``.
        self.temporal_pool = cast(TemporalPoolLike, TemporalPool(torch.max))
        self.hpp = cast(HppLike, HorizontalPoolingPyramid(bin_num=bin_num))
        self.fcs = cast(
            FCsLike,
            SeparateFCs(
                parts_num=self._extract_int(fcs_cfg, "parts_num"),
                in_channels=self._extract_int(fcs_cfg, "in_channels"),
                out_channels=self._extract_int(fcs_cfg, "out_channels"),
                norm=self._extract_bool(fcs_cfg, "norm", default=False),
            ),
        )
        self.bn_necks = cast(
            BNNecksLike,
            SeparateBNNecks(
                parts_num=self._extract_int(bn_cfg, "parts_num"),
                in_channels=self._extract_int(bn_cfg, "in_channels"),
                class_num=self._extract_int(bn_cfg, "class_num"),
                norm=self._extract_bool(bn_cfg, "norm", default=True),
                parallel_BN1d=self._extract_bool(bn_cfg, "parallel_BN1d", default=True),
            ),
        )
        self.device = (
            torch.device(device) if device is not None else torch.device("cpu")
        )
        _ = self.to(self.device)
        if checkpoint_path is not None:
            _ = self.load_checkpoint(checkpoint_path)
        # Demo model is inference-only; keep BN/dropout in eval mode.
        _ = self.eval()

    @staticmethod
    def _resolve_path(path: str | Path) -> Path:
        """Resolve ``path``: keep it if it is an existing file or absolute,
        otherwise anchor it at the repo root (two levels above this file)."""
        candidate = Path(path)
        if candidate.is_file():
            return candidate
        if candidate.is_absolute():
            return candidate
        repo_root = Path(__file__).resolve().parents[2]
        return repo_root / candidate

    @staticmethod
    def _extract_model_cfg(cfg: dict[str, object]) -> dict[str, object]:
        """Return the ``model_cfg`` dict from a loaded config, or raise."""
        model_cfg_obj = cfg.get("model_cfg")
        if not isinstance(model_cfg_obj, dict):
            raise TypeError("model_cfg must be a dictionary.")
        return cast(dict[str, object], model_cfg_obj)

    @staticmethod
    def _extract_dict(container: dict[str, object], key: str) -> dict[str, object]:
        """Return ``container[key]`` as a dict, or raise ``TypeError``."""
        value = container.get(key)
        if not isinstance(value, dict):
            raise TypeError(f"{key} must be a dictionary.")
        return cast(dict[str, object], value)

    @staticmethod
    def _extract_str(container: dict[str, object], key: str) -> str:
        """Return ``container[key]`` as a str, or raise ``TypeError``."""
        value = container.get(key)
        if not isinstance(value, str):
            raise TypeError(f"{key} must be a string.")
        return value

    @staticmethod
    def _extract_int(
        container: dict[str, object], key: str, default: int | None = None
    ) -> int:
        """Return ``container[key]`` (or ``default``) as an int, or raise.

        ``bool`` is a subclass of ``int`` in Python, so a YAML ``true``/``false``
        would otherwise silently pass as 1/0 — reject it explicitly.
        """
        value = container.get(key, default)
        if not isinstance(value, int) or isinstance(value, bool):
            raise TypeError(f"{key} must be an int.")
        return value

    @staticmethod
    def _extract_bool(
        container: dict[str, object], key: str, default: bool | None = None
    ) -> bool:
        """Return ``container[key]`` (or ``default``) as a bool, or raise."""
        value = container.get(key, default)
        if not isinstance(value, bool):
            raise TypeError(f"{key} must be a bool.")
        return value

    @staticmethod
    def _extract_int_list(container: dict[str, object], key: str) -> list[int]:
        """Return ``container[key]`` as a list[int], or raise ``TypeError``.

        Bools are rejected element-wise for the same reason as ``_extract_int``.
        """
        value = container.get(key)
        if not isinstance(value, list):
            raise TypeError(f"{key} must be a list[int].")
        values = cast(list[object], value)
        if not all(isinstance(v, int) and not isinstance(v, bool) for v in values):
            raise TypeError(f"{key} must be a list[int].")
        return cast(list[int], values)

    @staticmethod
    def _normalize_state_dict(
        state_dict_obj: dict[object, object],
    ) -> dict[str, Tensor]:
        """Map OpenGait training-time parameter names onto this module's names.

        Strips a leading ``module.`` (DataParallel) prefix, then remaps the
        first matching source prefix to its demo-module attribute.  Raises if
        two source keys normalize to the same target key, or if the state
        dict contains non-string keys / non-tensor values.
        """
        prefix_remap: tuple[tuple[str, str], ...] = (
            ("Backbone.forward_block.", "backbone."),
            ("FCs.", "fcs."),
            ("BNNecks.", "bn_necks."),
        )
        cleaned_state_dict: dict[str, Tensor] = {}
        for key_obj, value_obj in state_dict_obj.items():
            if not isinstance(key_obj, str):
                raise TypeError("Checkpoint state_dict keys must be strings.")
            if not isinstance(value_obj, Tensor):
                raise TypeError("Checkpoint state_dict values must be torch.Tensor.")
            key = key_obj[7:] if key_obj.startswith("module.") else key_obj
            for source_prefix, target_prefix in prefix_remap:
                if key.startswith(source_prefix):
                    key = f"{target_prefix}{key[len(source_prefix) :]}"
                    break
            if key in cleaned_state_dict:
                raise RuntimeError(
                    f"Checkpoint key normalization collision detected for key '{key}'."
                )
            cleaned_state_dict[key] = value_obj
        return cleaned_state_dict

    @jaxtyped(typechecker=beartype)
    def load_checkpoint(
        self,
        checkpoint_path: str | Path,
        map_location: str | torch.device | None = None,
        strict: bool = True,
    ) -> None:
        """Load a checkpoint (raw state dict or OpenGait-style ``{"model": ...}``).

        Args:
            checkpoint_path: Checkpoint file; resolved like ``cfg_path``.
            map_location: Device remap for ``torch.load``; defaults to
                ``self.device``.
            strict: Forwarded to ``load_state_dict``.

        Raises:
            TypeError: If the checkpoint is neither format above.
            RuntimeError: If loading fails after key normalization.
        """
        resolved_ckpt = self._resolve_path(checkpoint_path)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints here; consider weights_only=True if the torch
        # version in use supports it.
        checkpoint_obj = cast(
            object,
            torch.load(
                str(resolved_ckpt),
                map_location=map_location if map_location is not None else self.device,
            ),
        )
        state_dict_obj: object = checkpoint_obj
        if isinstance(checkpoint_obj, dict) and "model" in checkpoint_obj:
            state_dict_obj = cast(dict[str, object], checkpoint_obj)["model"]
        if not isinstance(state_dict_obj, dict):
            raise TypeError("Unsupported checkpoint format.")
        cleaned_state_dict = self._normalize_state_dict(
            cast(dict[object, object], state_dict_obj)
        )
        try:
            _ = self.load_state_dict(cleaned_state_dict, strict=strict)
        except RuntimeError as exc:
            raise RuntimeError(
                f"Failed to load ScoNetDemo checkpoint after key normalization from '{resolved_ckpt}'."
            ) from exc
        _ = self.eval()

    def _prepare_sils(self, sils: Tensor) -> Tensor:
        """Coerce silhouettes to float ``[B, 1, S, H, W]`` on ``self.device``.

        Accepts ``[B, S, H, W]`` (channel axis inserted) and
        ``[B, S, 1, H, W]`` (axes swapped to channel-first).
        """
        if sils.ndim == 4:
            sils = sils.unsqueeze(1)
        elif sils.ndim == 5 and sils.shape[1] != 1 and sils.shape[2] == 1:
            sils = rearrange(sils, "b s c h w -> b c s h w")
        if sils.ndim != 5 or sils.shape[1] != 1:
            raise ValueError("Expected sils shape [B, 1, S, H, W] or [B, S, H, W].")
        return sils.float().to(self.device)

    def _forward_backbone(self, sils: Tensor) -> Tensor:
        """Run the backbone per frame and restore the temporal layout.

        Folds time into the batch axis, applies the backbone to every frame,
        then returns a contiguous ``[B, C_out, S, H_out, W_out]`` tensor.
        """
        batch, channels, seq, height, width = sils.shape
        framewise = sils.transpose(1, 2).reshape(batch * seq, channels, height, width)
        frame_feats = cast(Tensor, self.backbone(framewise))
        _, out_channels, out_h, out_w = frame_feats.shape
        return (
            frame_feats.reshape(batch, seq, out_channels, out_h, out_w)
            .transpose(1, 2)
            .contiguous()
        )

    @override
    @jaxtyped(typechecker=beartype)
    def forward(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> dict[str, Tensor]:
        """Classify a silhouette sequence.

        Returns a dict with ``logits`` (per-part class logits), ``label``
        (argmax class id per sample), and ``confidence`` (softmax probability
        of the predicted class).  Runs under ``torch.inference_mode``.
        """
        with torch.inference_mode():
            prepared_sils = self._prepare_sils(sils)
            outs = self._forward_backbone(prepared_sils)
            pooled_obj = self.temporal_pool(outs, None, options={"dim": 2})
            if (
                not isinstance(pooled_obj, tuple)
                or not pooled_obj
                or not isinstance(pooled_obj[0], Tensor)
            ):
                raise TypeError("TemporalPool output is invalid.")
            pooled = pooled_obj[0]
            feat = self.hpp(pooled)
            embed_1 = self.fcs(feat)
            _, logits = self.bn_necks(embed_1)
            # Average the per-part logits before classification.
            mean_logits = logits.mean(dim=-1)
            pred_ids = torch.argmax(mean_logits, dim=-1)
            probs = torch.softmax(mean_logits, dim=-1)
            confidence = torch.gather(
                probs, dim=-1, index=pred_ids.unsqueeze(-1)
            ).squeeze(-1)
            return {"logits": logits, "label": pred_ids, "confidence": confidence}

    @jaxtyped(typechecker=beartype)
    def predict(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> tuple[str, float]:
        """Return ``(label_name, confidence)`` for a single-sample batch.

        Raises:
            ValueError: If the batch contains more than one sample.
        """
        outputs = cast(dict[str, Tensor], self.forward(sils))
        labels = outputs["label"]
        confidence = outputs["confidence"]
        if labels.numel() != 1:
            raise ValueError("predict expects batch size 1.")
        label_id = int(labels.item())
        return self.LABEL_MAP[label_id], float(confidence.item())