d6fd6c03e6
Introduce focused unit, integration, and NATS-path tests for demo modules, and align assertions with final schema and temporal contracts (window int, seq=30, fill-level ratio). This commit isolates validation logic from runtime changes and provides reproducible QA for pipeline behavior and failure modes.
342 lines
12 KiB
Python
342 lines
12 KiB
Python
"""Unit tests for ScoNetDemo forward pass.
|
|
|
|
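Tests cover:
- Construction from config/checkpoint path handling
- Forward output shape (N, 3, 16) and dtypes (float32 logits/confidence, int64 label)
- Predict output (label_str, confidence_float) with valid label/range
- No-DDP leakage check (no torch.distributed calls in unit behavior)
- Checkpoint loading from a synthetic state dict (weights change, eval mode, bad path)
- LABEL_MAP contents (three classes keyed 0-2 with non-empty string names)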
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, cast
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
if TYPE_CHECKING:
|
|
from opengait.demo.sconet_demo import ScoNetDemo
|
|
|
|
# Constants for test configuration
|
|
CONFIG_PATH = Path("configs/sconet/sconet_scoliosis1k.yaml")
|
|
|
|
|
|
@pytest.fixture
def demo() -> "ScoNetDemo":
    """Build a ScoNetDemo from CONFIG_PATH on the CPU with no checkpoint."""
    # Imported inside the fixture; the module top level only imports the
    # class under TYPE_CHECKING.
    from opengait.demo.sconet_demo import ScoNetDemo

    model = ScoNetDemo(
        cfg_path=str(CONFIG_PATH), checkpoint_path=None, device="cpu"
    )
    return model
|
|
|
|
|
|
@pytest.fixture
def dummy_sils_batch() -> Tensor:
    """Provide a random silhouette-like tensor laid out as (N, 1, S, 64, 44).

    Here N=2 and S=30. Values are drawn with ``torch.randn`` and are not
    seeded by this fixture.
    """
    batch_shape = (2, 1, 30, 64, 44)
    return torch.randn(*batch_shape)
|
|
|
|
|
|
@pytest.fixture
def dummy_sils_single() -> Tensor:
    """Provide a single-sample random tensor laid out as (1, 1, S, 64, 44).

    S is 30. This is the input used by the ``predict`` tests; values are
    drawn with ``torch.randn`` and are not seeded by this fixture.
    """
    single_shape = (1, 1, 30, 64, 44)
    return torch.randn(*single_shape)
|
|
|
|
|
|
@pytest.fixture
def synthetic_state_dict() -> dict[str, Tensor]:
    """Build a hand-made state dict intended to resemble ScoNetDemo's layout.

    Convolution and ``fc_bin`` entries are random; batch-norm entries use
    ones for weight/running_var and zeros for bias/running_mean.
    """
    conv_bn_width = 64
    neck_bn_width = 4096

    state: dict[str, Tensor] = {}

    # First backbone convolution and its batch-norm statistics.
    state["backbone.conv1.conv.weight"] = torch.randn(64, 1, 3, 3)
    state["backbone.conv1.bn.weight"] = torch.ones(conv_bn_width)
    state["backbone.conv1.bn.bias"] = torch.zeros(conv_bn_width)
    state["backbone.conv1.bn.running_mean"] = torch.zeros(conv_bn_width)
    state["backbone.conv1.bn.running_var"] = torch.ones(conv_bn_width)

    # Entries under the "fcs" and "bn_necks" prefixes.
    state["fcs.fc_bin"] = torch.randn(16, 512, 256)
    state["bn_necks.fc_bin"] = torch.randn(16, 256, 3)
    state["bn_necks.bn1d.weight"] = torch.ones(neck_bn_width)
    state["bn_necks.bn1d.bias"] = torch.zeros(neck_bn_width)
    state["bn_necks.bn1d.running_mean"] = torch.zeros(neck_bn_width)
    state["bn_necks.bn1d.running_var"] = torch.ones(neck_bn_width)

    return state
|
|
|
|
|
|
class TestScoNetDemoConstruction:
    """Construction-time behaviour of ScoNetDemo, including config paths."""

    def test_construction_from_config_no_checkpoint(self) -> None:
        """A config-only build exposes cfg, a CPU device and eval mode."""
        from opengait.demo.sconet_demo import ScoNetDemo

        model = ScoNetDemo(
            cfg_path=str(CONFIG_PATH), checkpoint_path=None, device="cpu"
        )

        assert model.cfg_path.endswith("sconet_scoliosis1k.yaml")
        assert model.device == torch.device("cpu")
        assert model.cfg is not None
        assert "model_cfg" in model.cfg
        # ``training is False`` means the module is in eval mode.
        assert model.training is False

    def test_construction_with_relative_path(self) -> None:
        """A relative config path given as a plain string is accepted."""
        from opengait.demo.sconet_demo import ScoNetDemo

        relative_cfg = "configs/sconet/sconet_scoliosis1k.yaml"
        model = ScoNetDemo(
            cfg_path=relative_cfg, checkpoint_path=None, device="cpu"
        )

        assert model.cfg is not None
        assert model.backbone is not None

    def test_construction_invalid_config_raises(self) -> None:
        """A missing config file raises FileNotFoundError or TypeError."""
        from opengait.demo.sconet_demo import ScoNetDemo

        accepted_errors = (FileNotFoundError, TypeError)
        with pytest.raises(accepted_errors):
            _ = ScoNetDemo(
                cfg_path="/nonexistent/path/config.yaml",
                checkpoint_path=None,
                device="cpu",
            )
|
|
|
|
|
|
class TestScoNetDemoForward:
    """Checks on the dictionary returned by ``ScoNetDemo.forward``."""

    @staticmethod
    def _forward_outputs(demo: "ScoNetDemo", sils: Tensor) -> dict[str, Tensor]:
        """Run ``demo.forward`` and narrow the result type for the checker."""
        # ``cast`` only affects static typing; the value is returned unchanged.
        return cast(dict[str, Tensor], demo.forward(sils))

    def test_forward_output_shape_and_dtype(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Logits are (N, 3, 16) float32; label is int64; confidence float32."""
        result = self._forward_outputs(demo, dummy_sils_batch)

        assert "logits" in result
        scores = result["logits"]

        # (batch_size, num_classes, parts_num) for the two-sample batch.
        assert scores.shape == (2, 3, 16)
        assert scores.dtype == torch.float32
        assert result["label"].dtype == torch.int64
        assert result["confidence"].dtype == torch.float32

    def test_forward_returns_required_keys(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """The output contains at least logits, label and confidence."""
        result = self._forward_outputs(demo, dummy_sils_batch)

        expected = {"logits", "label", "confidence"}
        assert expected.issubset(result.keys())

    def test_forward_batch_size_one(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """A single-sample input yields leading dimension 1 on every output."""
        result = self._forward_outputs(demo, dummy_sils_single)

        assert result["logits"].shape == (1, 3, 16)
        assert result["label"].shape == (1,)
        assert result["confidence"].shape == (1,)

    def test_forward_label_range(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Every predicted label index is one of 0, 1 or 2."""
        predicted = self._forward_outputs(demo, dummy_sils_batch)["label"]

        assert (predicted >= 0).all()
        assert (predicted <= 2).all()

    def test_forward_confidence_range(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Every confidence value lies within the closed interval [0, 1]."""
        scores = self._forward_outputs(demo, dummy_sils_batch)["confidence"]

        assert (scores >= 0.0).all()
        assert (scores <= 1.0).all()
|
|
|
|
|
|
class TestScoNetDemoPredict:
    """Checks on the ``(label, confidence)`` pair from ``ScoNetDemo.predict``."""

    def test_predict_returns_tuple_with_valid_types(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """predict returns a 2-tuple of (str, float) with a LABEL_MAP label."""
        from opengait.demo.sconet_demo import ScoNetDemo

        # ``cast`` is a static-typing aid only; the runtime value is untouched.
        prediction = cast(tuple[str, float], demo.predict(dummy_sils_single))

        assert isinstance(prediction, tuple)
        assert len(prediction) == 2

        name, score = prediction
        assert isinstance(name, str)
        assert isinstance(score, float)

        known_names = set(ScoNetDemo.LABEL_MAP.values())
        assert name in known_names

    def test_predict_confidence_range(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """The confidence returned by predict lies within [0, 1]."""
        prediction = cast(tuple[str, float], demo.predict(dummy_sils_single))
        score = prediction[1]

        assert score >= 0.0
        assert score <= 1.0

    def test_predict_rejects_batch_size_greater_than_one(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """A two-sample batch makes predict raise ValueError."""
        # The error message is expected to mention "batch size 1".
        with pytest.raises(ValueError, match="batch size 1"):
            _ = demo.predict(dummy_sils_batch)
|
|
|
|
|
|
class TestScoNetDemoNoDDP:
    """Guards against torch.distributed usage leaking into unit behaviour."""

    def test_no_distributed_init_in_construction(self) -> None:
        """Construction touches neither is_initialized nor init_process_group."""
        from opengait.demo.sconet_demo import ScoNetDemo

        with patch("torch.distributed.is_initialized") as is_init_spy, patch(
            "torch.distributed.init_process_group"
        ) as init_pg_spy:
            _ = ScoNetDemo(
                cfg_path=str(CONFIG_PATH), checkpoint_path=None, device="cpu"
            )

            init_pg_spy.assert_not_called()
            is_init_spy.assert_not_called()

    def test_forward_no_distributed_calls(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """forward triggers neither all_reduce nor broadcast."""
        with patch("torch.distributed.all_reduce") as all_reduce_spy, patch(
            "torch.distributed.broadcast"
        ) as broadcast_spy:
            _ = demo.forward(dummy_sils_batch)

            all_reduce_spy.assert_not_called()
            broadcast_spy.assert_not_called()

    def test_predict_no_distributed_calls(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """predict triggers neither all_reduce nor broadcast."""
        with patch("torch.distributed.all_reduce") as all_reduce_spy, patch(
            "torch.distributed.broadcast"
        ) as broadcast_spy:
            _ = demo.predict(dummy_sils_single)

            all_reduce_spy.assert_not_called()
            broadcast_spy.assert_not_called()

    def test_model_not_wrapped_in_ddp(self, demo: "ScoNetDemo") -> None:
        """Neither the demo nor its backbone is a DistributedDataParallel."""
        from torch.nn.parallel import DistributedDataParallel

        assert not isinstance(demo, DistributedDataParallel)
        assert not isinstance(demo.backbone, DistributedDataParallel)

    def test_device_is_cpu(self, demo: "ScoNetDemo") -> None:
        """With device="cpu", the demo and all its parameters are on the CPU."""
        assert demo.device.type == "cpu"
        assert all(p.device.type == "cpu" for p in demo.parameters())
|
|
|
|
|
|
class TestScoNetDemoCheckpointLoading:
    """Tests for checkpoint loading behavior using synthetic state dict."""

    @staticmethod
    def _load_synthetic_checkpoint(
        demo: "ScoNetDemo", state_dict: dict[str, Tensor]
    ) -> None:
        """Save ``state_dict`` to a throwaway ``.pt`` file and load it into ``demo``.

        The file is written inside a ``TemporaryDirectory`` so that it is
        removed even when ``torch.save`` or ``load_checkpoint`` raises.
        Loading is done with ``strict=False``.
        """
        import tempfile

        with tempfile.TemporaryDirectory() as scratch_dir:
            checkpoint_file = Path(scratch_dir) / "synthetic_checkpoint.pt"
            torch.save(state_dict, str(checkpoint_file))
            _ = demo.load_checkpoint(str(checkpoint_file), strict=False)

    def test_load_checkpoint_changes_weights(
        self,
        demo: "ScoNetDemo",
        synthetic_state_dict: dict[str, Tensor],
    ) -> None:
        """Test loading checkpoint actually changes model weights."""
        # Clone first: parameters() yields the live tensor, which would
        # otherwise compare equal to itself after loading.
        weight_before = next(iter(demo.parameters())).clone()

        self._load_synthetic_checkpoint(demo, synthetic_state_dict)

        weight_after = next(iter(demo.parameters()))
        assert not torch.equal(weight_before, weight_after)

    def test_load_checkpoint_sets_eval_mode(
        self,
        demo: "ScoNetDemo",
        synthetic_state_dict: dict[str, Tensor],
    ) -> None:
        """Test loading checkpoint sets model to eval mode."""
        # Switch to training mode so the eval-mode assertion is meaningful.
        _ = demo.train()
        assert demo.training is True

        self._load_synthetic_checkpoint(demo, synthetic_state_dict)

        assert demo.training is False

    def test_load_checkpoint_invalid_path_raises(self, demo: "ScoNetDemo") -> None:
        """Test loading from invalid checkpoint path raises error."""
        with pytest.raises(FileNotFoundError):
            _ = demo.load_checkpoint("/nonexistent/checkpoint.pt")
|
|
|
|
|
|
class TestScoNetDemoLabelMap:
    """Sanity checks on the ``ScoNetDemo.LABEL_MAP`` class constant."""

    def test_label_map_has_three_classes(self) -> None:
        """LABEL_MAP holds exactly three entries keyed 0, 1 and 2."""
        from opengait.demo.sconet_demo import ScoNetDemo

        label_map = ScoNetDemo.LABEL_MAP
        assert len(label_map) == 3
        assert set(label_map.keys()) == {0, 1, 2}

    def test_label_map_values_are_valid_strings(self) -> None:
        """Every LABEL_MAP value is a non-empty string."""
        from opengait.demo.sconet_demo import ScoNetDemo

        for name in ScoNetDemo.LABEL_MAP.values():
            assert isinstance(name, str)
            assert len(name) >= 1
|