feat: extract opengait_studio monorepo module
Move demo implementation into opengait_studio, retire Sports2D runtime integration, and align packaging with root-level monorepo dependency management.
This commit is contained in:
@@ -0,0 +1,312 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Protocol, cast, override
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from beartype import beartype
|
||||
from einops import rearrange
|
||||
from jaxtyping import Float
|
||||
import jaxtyping
|
||||
from torch import Tensor
|
||||
|
||||
from opengait.modeling.backbones.resnet import ResNet9
|
||||
from opengait.modeling.modules import (
|
||||
HorizontalPoolingPyramid,
|
||||
PackSequenceWrapper as TemporalPool,
|
||||
SeparateBNNecks,
|
||||
SeparateFCs,
|
||||
)
|
||||
from opengait.utils import common as common_utils
|
||||
|
||||
|
||||
# Typed facades over untyped third-party callables so strict type checkers
# can verify call sites without `Any` leaking through this module.
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
# jaxtyping.jaxtyped is untyped upstream; pin the factory signature we use.
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)

# OpenGait's config_loader is untyped upstream; pin the signature we rely on
# (YAML path in, parsed mapping out).
ConfigLoader = Callable[[str], dict[str, object]]
config_loader = cast(ConfigLoader, common_utils.config_loader)
|
||||
|
||||
|
||||
class TemporalPoolLike(Protocol):
    """Structural interface for sequence-pooling wrappers.

    Matches the call shape of OpenGait's ``PackSequenceWrapper`` (imported
    here as ``TemporalPool``): pools ``seqs`` over dimension ``dim``; ``seqL``
    may carry per-sample sequence lengths (``None`` for fixed-length batches).
    The return type is ``object`` because the wrapped pooling function may
    return a tuple — callers must narrow it (see ``ScoNetDemo.forward``).
    """

    def __call__(
        self,
        seqs: Tensor,
        seqL: object,
        dim: int = 2,
        options: dict[str, int] | None = None,
    ) -> object: ...
|
||||
|
||||
|
||||
class HppLike(Protocol):
    """Structural interface for the Horizontal Pooling Pyramid module."""

    def __call__(self, x: Tensor) -> Tensor: ...
|
||||
|
||||
|
||||
class FCsLike(Protocol):
    """Structural interface for part-wise FC heads (``SeparateFCs``)."""

    def __call__(self, x: Tensor) -> Tensor: ...
|
||||
|
||||
|
||||
class BNNecksLike(Protocol):
    """Structural interface for BNNeck heads returning a 2-tuple of tensors.

    ``ScoNetDemo.forward`` uses the second element as class logits.
    """

    def __call__(self, x: Tensor) -> tuple[Tensor, Tensor]: ...
|
||||
|
||||
|
||||
class ScoNetDemo(nn.Module):
    """Inference-only wrapper around the ScoNet scoliosis-screening network.

    The architecture is rebuilt from an OpenGait YAML config:
    ResNet9 backbone -> temporal max pooling -> horizontal pooling pyramid
    -> SeparateFCs embedding -> SeparateBNNecks classifier. ``forward``
    returns raw logits plus an argmax label id and a softmax confidence;
    ``predict`` maps a single sequence to a human-readable label.
    """

    # Argmax class id -> human-readable screening outcome.
    LABEL_MAP: ClassVar[dict[int, str]] = {0: "negative", 1: "neutral", 2: "positive"}

    cfg_path: str  # resolved path of the YAML config actually loaded
    cfg: dict[str, object]  # parsed config contents
    backbone: ResNet9
    temporal_pool: TemporalPoolLike
    hpp: HppLike
    fcs: FCsLike
    bn_necks: BNNecksLike
    device: torch.device

    @jaxtyped(typechecker=beartype)
    def __init__(
        self,
        cfg_path: str | Path = "configs/sconet/sconet_scoliosis1k.yaml",
        checkpoint_path: str | Path | None = None,
        device: str | torch.device | None = None,
    ) -> None:
        """Build the network from *cfg_path* and optionally load weights.

        Args:
            cfg_path: OpenGait YAML config; when relative and not found
                as-is, it is resolved against the repository root.
            checkpoint_path: Optional checkpoint forwarded to
                :meth:`load_checkpoint`.
            device: Target device; defaults to CPU.

        Raises:
            ValueError: If the config requests a backbone other than ResNet9.
            TypeError: If required config entries are missing or mistyped.
        """
        super().__init__()
        resolved_cfg = self._resolve_path(cfg_path)
        self.cfg_path = str(resolved_cfg)

        self.cfg = config_loader(self.cfg_path)

        model_cfg = self._extract_model_cfg(self.cfg)
        backbone_cfg = self._extract_dict(model_cfg, "backbone_cfg")

        # Only the ResNet9 variant used by ScoNet is wired up here.
        if backbone_cfg.get("type") != "ResNet9":
            raise ValueError(
                "ScoNetDemo currently supports backbone type ResNet9 only."
            )

        self.backbone = ResNet9(
            block=self._extract_str(backbone_cfg, "block"),
            channels=self._extract_int_list(backbone_cfg, "channels"),
            in_channel=self._extract_int(backbone_cfg, "in_channel", default=1),
            layers=self._extract_int_list(backbone_cfg, "layers"),
            strides=self._extract_int_list(backbone_cfg, "strides"),
            maxpool=self._extract_bool(backbone_cfg, "maxpool", default=True),
        )

        fcs_cfg = self._extract_dict(model_cfg, "SeparateFCs")
        bn_cfg = self._extract_dict(model_cfg, "SeparateBNNecks")
        bin_num = self._extract_int_list(model_cfg, "bin_num")

        # Set-level temporal aggregation: elementwise max over the sequence.
        self.temporal_pool = cast(TemporalPoolLike, TemporalPool(torch.max))
        self.hpp = cast(HppLike, HorizontalPoolingPyramid(bin_num=bin_num))
        self.fcs = cast(
            FCsLike,
            SeparateFCs(
                parts_num=self._extract_int(fcs_cfg, "parts_num"),
                in_channels=self._extract_int(fcs_cfg, "in_channels"),
                out_channels=self._extract_int(fcs_cfg, "out_channels"),
                norm=self._extract_bool(fcs_cfg, "norm", default=False),
            ),
        )
        self.bn_necks = cast(
            BNNecksLike,
            SeparateBNNecks(
                parts_num=self._extract_int(bn_cfg, "parts_num"),
                in_channels=self._extract_int(bn_cfg, "in_channels"),
                class_num=self._extract_int(bn_cfg, "class_num"),
                norm=self._extract_bool(bn_cfg, "norm", default=True),
                parallel_BN1d=self._extract_bool(bn_cfg, "parallel_BN1d", default=True),
            ),
        )

        self.device = (
            torch.device(device) if device is not None else torch.device("cpu")
        )
        _ = self.to(self.device)

        if checkpoint_path is not None:
            _ = self.load_checkpoint(checkpoint_path)

        # Demo model is inference-only; keep BN/dropout in eval mode.
        _ = self.eval()

    @staticmethod
    def _resolve_path(path: str | Path) -> Path:
        """Resolve *path*; anchor missing relative paths at the repo root.

        An existing file or an absolute path is returned unchanged.
        """
        candidate = Path(path)
        if candidate.is_file():
            return candidate
        if candidate.is_absolute():
            return candidate
        # parents[2] assumes a <repo>/<pkg>/<subpkg>/module.py layout —
        # revisit if this module is moved.
        repo_root = Path(__file__).resolve().parents[2]
        return repo_root / candidate

    @staticmethod
    def _extract_model_cfg(cfg: dict[str, object]) -> dict[str, object]:
        """Return the mandatory ``model_cfg`` mapping from a parsed config."""
        model_cfg_obj = cfg.get("model_cfg")
        if not isinstance(model_cfg_obj, dict):
            raise TypeError("model_cfg must be a dictionary.")
        return cast(dict[str, object], model_cfg_obj)

    @staticmethod
    def _extract_dict(container: dict[str, object], key: str) -> dict[str, object]:
        """Return ``container[key]`` ensuring it is a mapping."""
        value = container.get(key)
        if not isinstance(value, dict):
            raise TypeError(f"{key} must be a dictionary.")
        return cast(dict[str, object], value)

    @staticmethod
    def _extract_str(container: dict[str, object], key: str) -> str:
        """Return ``container[key]`` ensuring it is a string."""
        value = container.get(key)
        if not isinstance(value, str):
            raise TypeError(f"{key} must be a string.")
        return value

    @staticmethod
    def _extract_int(
        container: dict[str, object], key: str, default: int | None = None
    ) -> int:
        """Return ``container[key]`` (or *default*) ensuring it is an int.

        ``bool`` is explicitly rejected even though it subclasses ``int``,
        so a YAML ``true``/``false`` cannot silently pass as a number.
        """
        value = container.get(key, default)
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError(f"{key} must be an int.")
        return value

    @staticmethod
    def _extract_bool(
        container: dict[str, object], key: str, default: bool | None = None
    ) -> bool:
        """Return ``container[key]`` (or *default*) ensuring it is a bool."""
        value = container.get(key, default)
        if not isinstance(value, bool):
            raise TypeError(f"{key} must be a bool.")
        return value

    @staticmethod
    def _extract_int_list(container: dict[str, object], key: str) -> list[int]:
        """Return ``container[key]`` ensuring it is a list of true ints.

        Rejects ``bool`` elements (a subclass of ``int``) for the same
        reason as ``_extract_int``.
        """
        value = container.get(key)
        if not isinstance(value, list):
            raise TypeError(f"{key} must be a list[int].")
        values = cast(list[object], value)
        if not all(
            isinstance(v, int) and not isinstance(v, bool) for v in values
        ):
            raise TypeError(f"{key} must be a list[int].")
        return cast(list[int], values)

    @staticmethod
    def _normalize_state_dict(
        state_dict_obj: dict[object, object],
    ) -> dict[str, Tensor]:
        """Map OpenGait checkpoint keys onto this module's attribute names.

        Strips a leading ``module.`` (DataParallel) prefix, then remaps
        OpenGait submodule prefixes to the local attribute names. Raises
        ``TypeError`` if keys/values are not ``str``/``Tensor`` and
        ``RuntimeError`` if two keys collide after normalization.
        """
        prefix_remap: tuple[tuple[str, str], ...] = (
            ("Backbone.forward_block.", "backbone."),
            ("FCs.", "fcs."),
            ("BNNecks.", "bn_necks."),
        )
        cleaned_state_dict: dict[str, Tensor] = {}
        for key_obj, value_obj in state_dict_obj.items():
            if not isinstance(key_obj, str):
                raise TypeError("Checkpoint state_dict keys must be strings.")
            if not isinstance(value_obj, Tensor):
                raise TypeError("Checkpoint state_dict values must be torch.Tensor.")
            key = key_obj[7:] if key_obj.startswith("module.") else key_obj
            for source_prefix, target_prefix in prefix_remap:
                if key.startswith(source_prefix):
                    key = f"{target_prefix}{key[len(source_prefix) :]}"
                    break
            if key in cleaned_state_dict:
                raise RuntimeError(
                    f"Checkpoint key normalization collision detected for key '{key}'."
                )
            cleaned_state_dict[key] = value_obj
        return cleaned_state_dict

    @jaxtyped(typechecker=beartype)
    def load_checkpoint(
        self,
        checkpoint_path: str | Path,
        map_location: str | torch.device | None = None,
        strict: bool = True,
    ) -> None:
        """Load weights, normalizing OpenGait key prefixes first.

        Args:
            checkpoint_path: Checkpoint file; resolved like the config path.
            map_location: Overrides the model device for deserialization.
            strict: Forwarded to ``nn.Module.load_state_dict``.

        Raises:
            TypeError: On an unsupported checkpoint layout.
            RuntimeError: When loading fails even after key normalization.
        """
        resolved_ckpt = self._resolve_path(checkpoint_path)
        # SECURITY: torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (consider weights_only=True once
        # all supported checkpoints are known to be tensor-only).
        checkpoint_obj = cast(
            object,
            torch.load(
                str(resolved_ckpt),
                map_location=map_location if map_location is not None else self.device,
            ),
        )

        # OpenGait checkpoints wrap the state dict under a "model" key.
        state_dict_obj: object = checkpoint_obj
        if isinstance(checkpoint_obj, dict) and "model" in checkpoint_obj:
            state_dict_obj = cast(dict[str, object], checkpoint_obj)["model"]

        if not isinstance(state_dict_obj, dict):
            raise TypeError("Unsupported checkpoint format.")

        cleaned_state_dict = self._normalize_state_dict(
            cast(dict[object, object], state_dict_obj)
        )
        try:
            _ = self.load_state_dict(cleaned_state_dict, strict=strict)
        except RuntimeError as exc:
            raise RuntimeError(
                f"Failed to load ScoNetDemo checkpoint after key normalization from '{resolved_ckpt}'."
            ) from exc
        _ = self.eval()

    def _prepare_sils(self, sils: Tensor) -> Tensor:
        """Normalize silhouettes to float ``[B, 1, S, H, W]`` on the device.

        Accepts ``[B, S, H, W]`` (channel dim inserted) or ``[B, S, 1, H, W]``
        (channel moved ahead of the sequence dim).
        """
        if sils.ndim == 4:
            sils = sils.unsqueeze(1)
        elif sils.ndim == 5 and sils.shape[1] != 1 and sils.shape[2] == 1:
            # [B, S, 1, H, W] -> [B, 1, S, H, W]
            sils = rearrange(sils, "b s c h w -> b c s h w")

        if sils.ndim != 5 or sils.shape[1] != 1:
            raise ValueError("Expected sils shape [B, 1, S, H, W] or [B, S, H, W].")

        return sils.float().to(self.device)

    def _forward_backbone(self, sils: Tensor) -> Tensor:
        """Run the 2D backbone frame-by-frame; return ``[B, C', S, H', W']``."""
        batch, channels, seq, height, width = sils.shape
        # Fold the sequence into the batch so the 2D CNN sees single frames.
        framewise = sils.transpose(1, 2).reshape(batch * seq, channels, height, width)
        frame_feats = cast(Tensor, self.backbone(framewise))
        _, out_channels, out_h, out_w = frame_feats.shape
        return (
            frame_feats.reshape(batch, seq, out_channels, out_h, out_w)
            .transpose(1, 2)
            .contiguous()
        )

    @override
    @jaxtyped(typechecker=beartype)
    def forward(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> dict[str, Tensor]:
        """Classify silhouette sequences under ``torch.inference_mode``.

        Returns:
            Dict with raw ``logits``, argmax ``label`` ids, and softmax
            ``confidence`` of the predicted class.
        """
        with torch.inference_mode():
            prepared_sils = self._prepare_sils(sils)
            outs = self._forward_backbone(prepared_sils)

            pooled_obj = self.temporal_pool(outs, None, options={"dim": 2})
            # TemporalPool wraps torch.max, which returns a (values, indices)
            # tuple; keep only the pooled values after validating the shape.
            if (
                not isinstance(pooled_obj, tuple)
                or not pooled_obj
                or not isinstance(pooled_obj[0], Tensor)
            ):
                raise TypeError("TemporalPool output is invalid.")
            pooled = pooled_obj[0]

            feat = self.hpp(pooled)
            embed_1 = self.fcs(feat)
            _, logits = self.bn_necks(embed_1)

            # Average logits over the last dim before softmax/argmax — assumes
            # logits are [B, class, parts]; confirm against SeparateBNNecks.
            mean_logits = logits.mean(dim=-1)
            pred_ids = torch.argmax(mean_logits, dim=-1)
            probs = torch.softmax(mean_logits, dim=-1)
            confidence = torch.gather(
                probs, dim=-1, index=pred_ids.unsqueeze(-1)
            ).squeeze(-1)

        return {"logits": logits, "label": pred_ids, "confidence": confidence}

    @jaxtyped(typechecker=beartype)
    def predict(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> tuple[str, float]:
        """Return ``(label, confidence)`` for a single sequence.

        Raises:
            ValueError: If the batch contains more than one sequence.
        """
        outputs = cast(dict[str, Tensor], self.forward(sils))
        labels = outputs["label"]
        confidence = outputs["confidence"]

        if labels.numel() != 1:
            raise ValueError("predict expects batch size 1.")

        label_id = int(labels.item())
        return self.LABEL_MAP[label_id], float(confidence.item())
|
||||
Reference in New Issue
Block a user