00fcda4fe3
Move demo implementation into opengait_studio, retire Sports2D runtime integration, and align packaging with root-level monorepo dependency management.
342 lines
12 KiB
Python
342 lines
12 KiB
Python
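"""Unit tests for ScoNetDemo.

Tests cover:

- Construction from a config path (including relative and nonexistent paths),
  without loading a checkpoint
- Forward output: logits of shape (N, 3, 16) with float32 dtype, plus
  per-sample label and confidence tensors
- predict() output: a (label: str, confidence: float) tuple whose label comes
  from LABEL_MAP and whose confidence lies in [0, 1]; batch size > 1 is rejected
- No DDP leakage: no torch.distributed calls and no DistributedDataParallel
  wrapping in unit behavior
- Checkpoint loading from a synthetic state dict (weights change, model is
  left in eval mode, missing files raise)
- LABEL_MAP contents
"""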
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, cast
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
if TYPE_CHECKING:
|
|
from opengait_studio.sconet_demo import ScoNetDemo
|
|
|
|
# Config file consumed by every ScoNetDemo built in this module. The path is
# relative, so it resolves against the current working directory; these tests
# presumably expect to be run from the repository root (TODO confirm).
CONFIG_PATH = Path("configs") / "sconet" / "sconet_scoliosis1k.yaml"
|
|
|
|
|
|
@pytest.fixture
def demo() -> "ScoNetDemo":
    """Build a CPU-only ScoNetDemo from ``CONFIG_PATH`` with no checkpoint.

    The model class is imported inside the fixture body, so the import only
    happens when a test actually requests this fixture.
    """
    from opengait_studio.sconet_demo import ScoNetDemo

    config_file = str(CONFIG_PATH)
    instance = ScoNetDemo(cfg_path=config_file, checkpoint_path=None, device="cpu")
    return instance
|
|
|
|
|
|
@pytest.fixture
def dummy_sils_batch() -> Tensor:
    """Return a dummy silhouette batch of shape (N, 1, S, 64, 44), N=2, S=30.

    Values come from a fixed-seed local generator. Failures are therefore
    reproducible, and the global torch RNG state seen by other tests is left
    untouched.
    """
    generator = torch.Generator().manual_seed(0)
    return torch.randn(2, 1, 30, 64, 44, generator=generator)
|
|
|
|
|
|
@pytest.fixture
def dummy_sils_single() -> Tensor:
    """Return a dummy silhouette tensor of shape (1, 1, S, 64, 44), S=30, for predict.

    Values come from a fixed-seed local generator. Failures are therefore
    reproducible, and the global torch RNG state seen by other tests is left
    untouched.
    """
    generator = torch.Generator().manual_seed(1)
    return torch.randn(1, 1, 30, 64, 44, generator=generator)
|
|
|
|
|
|
@pytest.fixture
def synthetic_state_dict() -> dict[str, Tensor]:
    """Build a fake checkpoint state dict whose keys mimic ScoNetDemo's modules.

    The key names and shapes are hand-written. Tests in this module load the
    dict with ``strict=False``.
    """

    def batchnorm(prefix: str, channels: int) -> dict[str, Tensor]:
        # Identity BatchNorm: unit scale, zero shift, zero mean, unit variance.
        return {
            f"{prefix}.weight": torch.ones(channels),
            f"{prefix}.bias": torch.zeros(channels),
            f"{prefix}.running_mean": torch.zeros(channels),
            f"{prefix}.running_var": torch.ones(channels),
        }

    state: dict[str, Tensor] = {
        "backbone.conv1.conv.weight": torch.randn(64, 1, 3, 3),
    }
    state.update(batchnorm("backbone.conv1.bn", 64))
    state["fcs.fc_bin"] = torch.randn(16, 512, 256)
    state["bn_necks.fc_bin"] = torch.randn(16, 256, 3)
    # 4096 == 16 * 256; presumably the flattened BNNeck width (TODO confirm).
    state.update(batchnorm("bn_necks.bn1d", 4096))
    return state
|
|
|
|
|
|
class TestScoNetDemoConstruction:
    """Construction of ScoNetDemo and handling of its ``cfg_path`` argument."""

    @staticmethod
    def _build(cfg_path: str) -> "ScoNetDemo":
        """Instantiate a CPU ScoNetDemo from *cfg_path* without a checkpoint."""
        from opengait_studio.sconet_demo import ScoNetDemo

        return ScoNetDemo(cfg_path=cfg_path, checkpoint_path=None, device="cpu")

    def test_construction_from_config_no_checkpoint(self) -> None:
        """A config-only build keeps the path, parses the config and is in eval mode."""
        model = self._build(str(CONFIG_PATH))

        assert model.cfg_path.endswith("sconet_scoliosis1k.yaml")
        assert model.device == torch.device("cpu")
        assert model.cfg is not None
        assert "model_cfg" in model.cfg
        assert model.training is False  # eval mode

    def test_construction_with_relative_path(self) -> None:
        """A relative config path is accepted and produces a backbone."""
        model = self._build("configs/sconet/sconet_scoliosis1k.yaml")

        assert model.cfg is not None
        assert model.backbone is not None

    def test_construction_invalid_config_raises(self) -> None:
        """A nonexistent config path makes construction fail."""
        # Imported outside pytest.raises so import errors are never swallowed.
        from opengait_studio.sconet_demo import ScoNetDemo

        missing_cfg = "/nonexistent/path/config.yaml"
        with pytest.raises((FileNotFoundError, TypeError)):
            _ = ScoNetDemo(cfg_path=missing_cfg, checkpoint_path=None, device="cpu")
|
|
|
|
|
|
class TestScoNetDemoForward:
    """Tests for the ScoNetDemo forward pass on random silhouette input."""

    @staticmethod
    def _forward(demo: "ScoNetDemo", sils: Tensor) -> dict[str, Tensor]:
        """Run ``demo.forward`` on *sils* and narrow the result to a tensor dict."""
        return cast(dict[str, Tensor], demo.forward(sils))

    def test_forward_output_shape_and_dtype(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Test forward returns (N, 3, 16) logits and per-sample label/confidence.

        The batch size N is taken from the input rather than hard-coded, so
        the test stays correct if the fixture's batch size changes.
        """
        outputs = self._forward(demo, dummy_sils_batch)
        batch_size = dummy_sils_batch.shape[0]

        assert "logits" in outputs
        logits = outputs["logits"]

        # Expected shape: (batch_size, num_classes, parts_num) = (N, 3, 16)
        assert logits.shape == (batch_size, 3, 16)
        assert logits.dtype == torch.float32
        # One label and one confidence per sample, as in the batch-size-1 test.
        assert outputs["label"].shape == (batch_size,)
        assert outputs["label"].dtype == torch.int64
        assert outputs["confidence"].shape == (batch_size,)
        assert outputs["confidence"].dtype == torch.float32

    def test_forward_returns_required_keys(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Test forward returns at least the logits, label and confidence keys."""
        outputs = self._forward(demo, dummy_sils_batch)

        required_keys = {"logits", "label", "confidence"}
        assert set(outputs.keys()) >= required_keys

    def test_forward_batch_size_one(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """Test forward works with batch size 1."""
        outputs = self._forward(demo, dummy_sils_single)

        assert outputs["logits"].shape == (1, 3, 16)
        assert outputs["label"].shape == (1,)
        assert outputs["confidence"].shape == (1,)

    def test_forward_label_range(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Test forward returns valid label indices (0, 1, or 2)."""
        labels = self._forward(demo, dummy_sils_batch)["label"]

        assert torch.all(labels >= 0)
        assert torch.all(labels <= 2)

    def test_forward_confidence_range(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Test forward returns confidence in [0, 1]."""
        confidence = self._forward(demo, dummy_sils_batch)["confidence"]

        assert torch.all(confidence >= 0.0)
        assert torch.all(confidence <= 1.0)
|
|
|
|
|
|
class TestScoNetDemoPredict:
    """Contract of ``predict``: one (label, confidence) pair for a single sample."""

    def test_predict_returns_tuple_with_valid_types(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """predict yields a 2-tuple of a LABEL_MAP string and a float."""
        from opengait_studio.sconet_demo import ScoNetDemo

        prediction = cast(tuple[str, float], demo.predict(dummy_sils_single))

        assert isinstance(prediction, tuple)
        assert len(prediction) == 2
        label, confidence = prediction
        assert isinstance(label, str)
        assert isinstance(confidence, float)

        assert label in set(ScoNetDemo.LABEL_MAP.values())

    def test_predict_confidence_range(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """The confidence reported by predict lies in the closed interval [0, 1]."""
        prediction = cast(tuple[str, float], demo.predict(dummy_sils_single))

        assert 0.0 <= prediction[1] <= 1.0

    def test_predict_rejects_batch_size_greater_than_one(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """predict refuses multi-sample input with a ValueError."""
        with pytest.raises(ValueError, match="batch size 1"):
            _ = demo.predict(dummy_sils_batch)
|
|
|
|
|
|
class TestScoNetDemoNoDDP:
    """Tests to verify no DDP leakage in unit behavior.

    NOTE: ``patch`` replaces module attributes, so these checks only catch
    calls made through ``torch.distributed.<name>`` attribute lookup at call
    time, not names bound earlier via ``from torch.distributed import ...``.
    """

    def test_no_distributed_init_in_construction(self) -> None:
        """Test that construction neither queries nor initializes a process group."""
        from opengait_studio.sconet_demo import ScoNetDemo

        with patch("torch.distributed.is_initialized") as mock_is_init:
            with patch("torch.distributed.init_process_group") as mock_init_pg:
                _ = ScoNetDemo(
                    cfg_path=str(CONFIG_PATH),
                    checkpoint_path=None,
                    device="cpu",
                )

        mock_init_pg.assert_not_called()
        mock_is_init.assert_not_called()

    def test_forward_no_distributed_calls(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Test forward pass does not call torch.distributed collectives."""
        with patch("torch.distributed.all_reduce") as mock_all_reduce:
            with patch("torch.distributed.broadcast") as mock_broadcast:
                _ = demo.forward(dummy_sils_batch)

        mock_all_reduce.assert_not_called()
        mock_broadcast.assert_not_called()

    def test_predict_no_distributed_calls(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """Test predict does not call torch.distributed collectives."""
        with patch("torch.distributed.all_reduce") as mock_all_reduce:
            with patch("torch.distributed.broadcast") as mock_broadcast:
                _ = demo.predict(dummy_sils_single)

        mock_all_reduce.assert_not_called()
        mock_broadcast.assert_not_called()

    def test_model_not_wrapped_in_ddp(self, demo: "ScoNetDemo") -> None:
        """Test model is not wrapped in DistributedDataParallel."""
        from torch.nn.parallel import DistributedDataParallel as DDP

        assert not isinstance(demo, DDP)
        assert not isinstance(demo.backbone, DDP)

    def test_device_is_cpu(self, demo: "ScoNetDemo") -> None:
        """Test all model tensors stay on CPU when device="cpu" is specified."""
        assert demo.device.type == "cpu"

        # Buffers are module state too (e.g. normalization running statistics,
        # if present), so check them alongside parameters. The non-empty check
        # keeps the loop from passing vacuously.
        named_tensors = [*demo.named_parameters(), *demo.named_buffers()]
        assert named_tensors
        for name, tensor in named_tensors:
            assert tensor.device.type == "cpu", name
|
|
|
|
|
|
class TestScoNetDemoCheckpointLoading:
|
|
"""Tests for checkpoint loading behavior using synthetic state dict."""
|
|
|
|
def test_load_checkpoint_changes_weights(
|
|
self,
|
|
demo: "ScoNetDemo",
|
|
synthetic_state_dict: dict[str, Tensor],
|
|
) -> None:
|
|
"""Test loading checkpoint actually changes model weights."""
|
|
import tempfile
|
|
import os
|
|
|
|
# Get initial weight
|
|
initial_weight = next(iter(demo.parameters())).clone()
|
|
|
|
# Create temp checkpoint file with synthetic state dict
|
|
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
|
|
torch.save(synthetic_state_dict, f.name)
|
|
temp_path = f.name
|
|
|
|
try:
|
|
_ = demo.load_checkpoint(temp_path, strict=False)
|
|
new_weight = next(iter(demo.parameters()))
|
|
assert not torch.equal(initial_weight, new_weight)
|
|
finally:
|
|
os.unlink(temp_path)
|
|
|
|
def test_load_checkpoint_sets_eval_mode(
|
|
self,
|
|
demo: "ScoNetDemo",
|
|
synthetic_state_dict: dict[str, Tensor],
|
|
) -> None:
|
|
"""Test loading checkpoint sets model to eval mode."""
|
|
import tempfile
|
|
import os
|
|
|
|
_ = demo.train()
|
|
assert demo.training is True
|
|
|
|
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
|
|
torch.save(synthetic_state_dict, f.name)
|
|
temp_path = f.name
|
|
|
|
try:
|
|
_ = demo.load_checkpoint(temp_path, strict=False)
|
|
assert demo.training is False
|
|
finally:
|
|
os.unlink(temp_path)
|
|
|
|
def test_load_checkpoint_invalid_path_raises(self, demo: "ScoNetDemo") -> None:
|
|
"""Test loading from invalid checkpoint path raises error."""
|
|
with pytest.raises(FileNotFoundError):
|
|
_ = demo.load_checkpoint("/nonexistent/checkpoint.pt")
|
|
|
|
|
|
class TestScoNetDemoLabelMap:
    """Sanity checks on the ``ScoNetDemo.LABEL_MAP`` class constant."""

    def test_label_map_has_three_classes(self) -> None:
        """LABEL_MAP maps exactly the class indices 0, 1 and 2."""
        from opengait_studio.sconet_demo import ScoNetDemo

        label_map = ScoNetDemo.LABEL_MAP
        assert len(label_map) == 3
        assert set(label_map.keys()) == {0, 1, 2}

    def test_label_map_values_are_valid_strings(self) -> None:
        """Every LABEL_MAP entry names its class with a non-empty string."""
        from opengait_studio.sconet_demo import ScoNetDemo

        for class_name in ScoNetDemo.LABEL_MAP.values():
            assert isinstance(class_name, str)
            assert class_name != ""
|