Files
crosstyan 00fcda4fe3 feat: extract opengait_studio monorepo module
Move demo implementation into opengait_studio, retire Sports2D runtime integration, and align packaging with root-level monorepo dependency management.
2026-03-07 18:14:13 +08:00

342 lines
12 KiB
Python

"""Unit tests for ScoNetDemo forward pass.
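Tests cover:
- Construction from config/checkpoint path handling
- Forward output: logits of shape (N, 3, 16) and dtype float32
- Predict output (label_str, confidence_float) with valid label/range
- No-DDP leakage check (no torch.distributed calls in unit behavior)
- Checkpoint loading from a synthetic state dict (weights change, eval mode, invalid path)
- LABEL_MAP contents (three classes keyed 0-2 with non-empty string names)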
"""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, cast
from unittest.mock import patch
import pytest
import torch
from torch import Tensor
if TYPE_CHECKING:
from opengait_studio.sconet_demo import ScoNetDemo
# Constants for test configuration.
# Relative path to the ScoNet config that the tests pass to ScoNetDemo as
# ``cfg_path`` (converted with ``str``).
# NOTE(review): how a relative path is resolved is up to ScoNetDemo — confirm
# the working directory the test suite is expected to run from.
CONFIG_PATH = Path("configs/sconet/sconet_scoliosis1k.yaml")
@pytest.fixture
def demo() -> "ScoNetDemo":
    """Build a ScoNetDemo from CONFIG_PATH on CPU with no checkpoint.

    The class is imported here, at fixture execution time; at module level
    it is only imported under ``TYPE_CHECKING``.
    """
    from opengait_studio.sconet_demo import ScoNetDemo

    model = ScoNetDemo(
        cfg_path=str(CONFIG_PATH),
        checkpoint_path=None,
        device="cpu",
    )
    return model
@pytest.fixture
def dummy_sils_batch() -> Tensor:
    """Return a reproducible dummy silhouette batch of shape (2, 1, 30, 64, 44).

    The layout is (N, 1, S, 64, 44) with N=2 and S=30. Values are standard
    normal noise drawn from a locally seeded generator, so every run sees
    the same input and the global torch RNG state is left untouched.
    """
    rng = torch.Generator().manual_seed(0)
    return torch.randn(2, 1, 30, 64, 44, generator=rng)
@pytest.fixture
def dummy_sils_single() -> Tensor:
    """Return a reproducible single-sample silhouette tensor for predict.

    The shape is (1, 1, 30, 64, 44), i.e. (1, 1, S, 64, 44) with S=30.
    Values are standard normal noise drawn from a locally seeded generator,
    so every run sees the same input and the global torch RNG state is left
    untouched.
    """
    rng = torch.Generator().manual_seed(0)
    return torch.randn(1, 1, 30, 64, 44, generator=rng)
@pytest.fixture
def synthetic_state_dict() -> dict[str, Tensor]:
    """Return a small, reproducible state dict used by the checkpoint tests.

    The keys are meant to mirror ScoNetDemo parameter names; the tests in
    this file load it with ``strict=False``.
    NOTE(review): whether every key/shape matches the real model is not
    visible from this file — confirm against ScoNetDemo.

    Random entries come from a locally seeded generator so the checkpoint
    content is identical on every run.
    """
    rng = torch.Generator().manual_seed(0)

    def batchnorm_entries(prefix: str, width: int) -> dict[str, Tensor]:
        # Identity-like BatchNorm statistics: unit scale/variance, zero shift/mean.
        return {
            f"{prefix}.weight": torch.ones(width),
            f"{prefix}.bias": torch.zeros(width),
            f"{prefix}.running_mean": torch.zeros(width),
            f"{prefix}.running_var": torch.ones(width),
        }

    # Key order is kept identical to the original hand-written literal.
    return {
        "backbone.conv1.conv.weight": torch.randn(64, 1, 3, 3, generator=rng),
        **batchnorm_entries("backbone.conv1.bn", 64),
        "fcs.fc_bin": torch.randn(16, 512, 256, generator=rng),
        "bn_necks.fc_bin": torch.randn(16, 256, 3, generator=rng),
        **batchnorm_entries("bn_necks.bn1d", 4096),
    }
class TestScoNetDemoConstruction:
    """Construction-time checks for ScoNetDemo: config loading and path handling."""

    def test_construction_from_config_no_checkpoint(self) -> None:
        """A config-only build exposes cfg_path, device and cfg, and is in eval mode."""
        from opengait_studio.sconet_demo import ScoNetDemo

        model = ScoNetDemo(
            cfg_path=str(CONFIG_PATH),
            checkpoint_path=None,
            device="cpu",
        )

        expected_device = torch.device("cpu")
        assert model.cfg_path.endswith("sconet_scoliosis1k.yaml")
        assert model.device == expected_device
        assert model.cfg is not None
        assert "model_cfg" in model.cfg
        # The demo is expected to come out of construction already in eval mode.
        assert model.training is False

    def test_construction_with_relative_path(self) -> None:
        """A literal relative config path yields a demo with cfg and backbone set."""
        from opengait_studio.sconet_demo import ScoNetDemo

        relative_cfg = "configs/sconet/sconet_scoliosis1k.yaml"
        model = ScoNetDemo(
            cfg_path=relative_cfg,
            checkpoint_path=None,
            device="cpu",
        )

        assert model.cfg is not None
        assert model.backbone is not None

    def test_construction_invalid_config_raises(self) -> None:
        """A non-existent config path raises FileNotFoundError or TypeError."""
        from opengait_studio.sconet_demo import ScoNetDemo

        missing_cfg = "/nonexistent/path/config.yaml"
        accepted_errors = (FileNotFoundError, TypeError)
        with pytest.raises(accepted_errors):
            _ = ScoNetDemo(
                cfg_path=missing_cfg,
                checkpoint_path=None,
                device="cpu",
            )
class TestScoNetDemoForward:
    """Checks on the dict returned by ``ScoNetDemo.forward``."""

    @staticmethod
    def _forward(demo: "ScoNetDemo", sils: Tensor) -> dict[str, Tensor]:
        """Call ``demo.forward`` and narrow the result to a tensor dict for type checkers."""
        return cast(dict[str, Tensor], demo.forward(sils))

    def test_forward_output_shape_and_dtype(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Logits are (N, 3, 16) float32; label is int64; confidence is float32."""
        result = self._forward(demo, dummy_sils_batch)

        assert "logits" in result
        logits = result["logits"]
        # (batch_size, num_classes, parts_num) for the two-sample batch fixture.
        assert logits.shape == (2, 3, 16)
        assert logits.dtype == torch.float32
        assert result["label"].dtype == torch.int64
        assert result["confidence"].dtype == torch.float32

    def test_forward_returns_required_keys(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """The output contains at least the logits, label and confidence keys."""
        result = self._forward(demo, dummy_sils_batch)

        expected_keys = {"logits", "label", "confidence"}
        assert expected_keys <= set(result.keys())

    def test_forward_batch_size_one(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """A single-sample input produces outputs with a leading dimension of 1."""
        result = self._forward(demo, dummy_sils_single)

        expected_shapes = {
            "logits": (1, 3, 16),
            "label": (1,),
            "confidence": (1,),
        }
        for key, shape in expected_shapes.items():
            assert result[key].shape == shape

    def test_forward_label_range(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Every predicted label index lies in {0, 1, 2}."""
        predicted = self._forward(demo, dummy_sils_batch)["label"]

        assert (predicted >= 0).all()
        assert (predicted <= 2).all()

    def test_forward_confidence_range(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Every confidence value lies in the closed interval [0, 1]."""
        scores = self._forward(demo, dummy_sils_batch)["confidence"]

        assert (scores >= 0.0).all()
        assert (scores <= 1.0).all()
class TestScoNetDemoPredict:
    """Checks on the ``(label, confidence)`` pair returned by ``ScoNetDemo.predict``."""

    def test_predict_returns_tuple_with_valid_types(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """predict returns a 2-tuple (str, float) whose label is a LABEL_MAP value."""
        from opengait_studio.sconet_demo import ScoNetDemo

        prediction = cast(tuple[str, float], demo.predict(dummy_sils_single))

        assert isinstance(prediction, tuple)
        assert len(prediction) == 2

        label, confidence = prediction
        assert isinstance(label, str)
        assert isinstance(confidence, float)

        known_labels = set(ScoNetDemo.LABEL_MAP.values())
        assert label in known_labels

    def test_predict_confidence_range(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """The confidence returned by predict lies in [0, 1]."""
        prediction = cast(tuple[str, float], demo.predict(dummy_sils_single))

        score = prediction[1]
        assert 0.0 <= score <= 1.0

    def test_predict_rejects_batch_size_greater_than_one(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """A multi-sample batch makes predict raise ValueError mentioning batch size 1."""
        with pytest.raises(ValueError, match="batch size 1"):
            _ = demo.predict(dummy_sils_batch)
class TestScoNetDemoNoDDP:
    """Guards against torch.distributed / DDP usage leaking into the demo model."""

    def test_no_distributed_init_in_construction(self) -> None:
        """Construction calls neither is_initialized nor init_process_group."""
        from opengait_studio.sconet_demo import ScoNetDemo

        with patch("torch.distributed.is_initialized") as is_initialized_mock, patch(
            "torch.distributed.init_process_group"
        ) as init_process_group_mock:
            _ = ScoNetDemo(
                cfg_path=str(CONFIG_PATH),
                checkpoint_path=None,
                device="cpu",
            )

        # The mocks keep their call records after the patches are undone.
        init_process_group_mock.assert_not_called()
        is_initialized_mock.assert_not_called()

    def test_forward_no_distributed_calls(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """forward triggers neither all_reduce nor broadcast."""
        with patch("torch.distributed.all_reduce") as all_reduce_mock, patch(
            "torch.distributed.broadcast"
        ) as broadcast_mock:
            _ = demo.forward(dummy_sils_batch)

        all_reduce_mock.assert_not_called()
        broadcast_mock.assert_not_called()

    def test_predict_no_distributed_calls(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """predict triggers neither all_reduce nor broadcast."""
        with patch("torch.distributed.all_reduce") as all_reduce_mock, patch(
            "torch.distributed.broadcast"
        ) as broadcast_mock:
            _ = demo.predict(dummy_sils_single)

        all_reduce_mock.assert_not_called()
        broadcast_mock.assert_not_called()

    def test_model_not_wrapped_in_ddp(self, demo: "ScoNetDemo") -> None:
        """Neither the demo nor its backbone is a DistributedDataParallel instance."""
        from torch.nn.parallel import DistributedDataParallel as DDP

        for module in (demo, demo.backbone):
            assert not isinstance(module, DDP)

    def test_device_is_cpu(self, demo: "ScoNetDemo") -> None:
        """With device="cpu", the demo and all of its parameters live on CPU."""
        assert demo.device.type == "cpu"
        assert all(param.device.type == "cpu" for param in demo.parameters())
class TestScoNetDemoCheckpointLoading:
    """Checkpoint loading behaviour, exercised with a synthetic state dict."""

    @staticmethod
    def _load_from_temp_checkpoint(
        demo: "ScoNetDemo", state_dict: dict[str, Tensor]
    ) -> None:
        """Save *state_dict* to a throwaway file and load it into *demo*.

        The checkpoint is loaded with ``strict=False``. It lives inside a
        ``TemporaryDirectory`` so it is removed even when saving or loading
        raises; the previous hand-rolled ``NamedTemporaryFile(delete=False)``
        plus ``os.unlink`` left the file behind if ``torch.save`` failed.
        """
        import tempfile

        with tempfile.TemporaryDirectory() as tmp_dir:
            checkpoint_path = str(Path(tmp_dir) / "synthetic.pt")
            torch.save(state_dict, checkpoint_path)
            _ = demo.load_checkpoint(checkpoint_path, strict=False)

    def test_load_checkpoint_changes_weights(
        self,
        demo: "ScoNetDemo",
        synthetic_state_dict: dict[str, Tensor],
    ) -> None:
        """Loading the synthetic checkpoint changes the model's first parameter."""
        # Snapshot by value: load_checkpoint may overwrite the parameter in place.
        weight_before = next(iter(demo.parameters())).clone()

        self._load_from_temp_checkpoint(demo, synthetic_state_dict)

        weight_after = next(iter(demo.parameters()))
        assert not torch.equal(weight_before, weight_after)

    def test_load_checkpoint_sets_eval_mode(
        self,
        demo: "ScoNetDemo",
        synthetic_state_dict: dict[str, Tensor],
    ) -> None:
        """Loading a checkpoint switches a model in train mode back to eval mode."""
        _ = demo.train()
        assert demo.training is True

        self._load_from_temp_checkpoint(demo, synthetic_state_dict)

        assert demo.training is False

    def test_load_checkpoint_invalid_path_raises(self, demo: "ScoNetDemo") -> None:
        """Loading from a non-existent checkpoint path raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError):
            _ = demo.load_checkpoint("/nonexistent/checkpoint.pt")
class TestScoNetDemoLabelMap:
    """Checks on the ``ScoNetDemo.LABEL_MAP`` class constant."""

    def test_label_map_has_three_classes(self) -> None:
        """LABEL_MAP has exactly three entries, keyed 0, 1 and 2."""
        from opengait_studio.sconet_demo import ScoNetDemo

        label_map = ScoNetDemo.LABEL_MAP
        assert len(label_map) == 3
        assert set(label_map.keys()) == {0, 1, 2}

    def test_label_map_values_are_valid_strings(self) -> None:
        """Every LABEL_MAP value is a non-empty string."""
        from opengait_studio.sconet_demo import ScoNetDemo

        names = list(ScoNetDemo.LABEL_MAP.values())
        # Type check first so len() is only evaluated on confirmed strings.
        assert all(isinstance(name, str) for name in names)
        assert all(len(name) > 0 for name in names)