"""Unit tests for ScoNetDemo forward pass.

Tests cover:
- Construction from config/checkpoint path handling
- Forward output shape (N, 3, 16) and dtype float
- Predict output (label_str, confidence_float) with valid label/range
- No-DDP leakage check (no torch.distributed calls in unit behavior)
- Checkpoint loading from a synthetic state dict
- LABEL_MAP contents
"""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, cast
from unittest.mock import patch

import pytest
import torch
from torch import Tensor

if TYPE_CHECKING:
    from opengait_studio.sconet_demo import ScoNetDemo

# Constants for test configuration.
# NOTE(review): this path is relative, so these tests assume pytest is invoked
# from the repository root — confirm against the CI invocation.
CONFIG_PATH = Path("configs/sconet/sconet_scoliosis1k.yaml")


# ScoNetDemo is imported inside fixtures/tests (and only under TYPE_CHECKING at
# module level) so that merely collecting this module does not import
# opengait_studio.


@pytest.fixture
def demo() -> "ScoNetDemo":
    """Create ScoNetDemo without loading checkpoint (CPU-only)."""
    from opengait_studio.sconet_demo import ScoNetDemo

    return ScoNetDemo(
        cfg_path=str(CONFIG_PATH),
        checkpoint_path=None,
        device="cpu",
    )


@pytest.fixture
def dummy_sils_batch() -> Tensor:
    """Return dummy silhouette tensor of shape (N, 1, S, 64, 44)."""
    return torch.randn(2, 1, 30, 64, 44)


@pytest.fixture
def dummy_sils_single() -> Tensor:
    """Return dummy silhouette tensor of shape (1, 1, S, 64, 44) for predict."""
    return torch.randn(1, 1, 30, 64, 44)


@pytest.fixture
def synthetic_state_dict() -> dict[str, Tensor]:
    """Return a synthetic state dict compatible with ScoNetDemo structure."""
    return {
        "backbone.conv1.conv.weight": torch.randn(64, 1, 3, 3),
        "backbone.conv1.bn.weight": torch.ones(64),
        "backbone.conv1.bn.bias": torch.zeros(64),
        "backbone.conv1.bn.running_mean": torch.zeros(64),
        "backbone.conv1.bn.running_var": torch.ones(64),
        "fcs.fc_bin": torch.randn(16, 512, 256),
        "bn_necks.fc_bin": torch.randn(16, 256, 3),
        # 4096 == 16 parts * 256 channels
        "bn_necks.bn1d.weight": torch.ones(4096),
        "bn_necks.bn1d.bias": torch.zeros(4096),
        "bn_necks.bn1d.running_mean": torch.zeros(4096),
        "bn_necks.bn1d.running_var": torch.ones(4096),
    }


@pytest.fixture
def checkpoint_file(
    tmp_path: Path, synthetic_state_dict: dict[str, Tensor]
) -> Path:
    """Save the synthetic state dict to a per-test temp file and return its path.

    pytest's ``tmp_path`` owns cleanup. Saving to a plain path (rather than the
    name of a still-open ``NamedTemporaryFile``) also works on Windows, where an
    open temporary file cannot be reopened by name.
    """
    path = tmp_path / "synthetic_checkpoint.pt"
    torch.save(synthetic_state_dict, path)
    return path


class TestScoNetDemoConstruction:
    """Tests for ScoNetDemo construction and path handling."""

    def test_construction_from_config_no_checkpoint(self) -> None:
        """Test construction with config only, no checkpoint."""
        from opengait_studio.sconet_demo import ScoNetDemo

        demo = ScoNetDemo(
            cfg_path=str(CONFIG_PATH),
            checkpoint_path=None,
            device="cpu",
        )

        assert demo.cfg_path.endswith("sconet_scoliosis1k.yaml")
        assert demo.device == torch.device("cpu")
        assert demo.cfg is not None
        assert "model_cfg" in demo.cfg
        assert demo.training is False  # eval mode

    def test_construction_with_relative_path(self) -> None:
        """Test construction handles relative config path correctly."""
        from opengait_studio.sconet_demo import ScoNetDemo

        demo = ScoNetDemo(
            cfg_path="configs/sconet/sconet_scoliosis1k.yaml",
            checkpoint_path=None,
            device="cpu",
        )

        assert demo.cfg is not None
        assert demo.backbone is not None

    def test_construction_invalid_config_raises(self) -> None:
        """Test construction raises with invalid config path."""
        from opengait_studio.sconet_demo import ScoNetDemo

        with pytest.raises((FileNotFoundError, TypeError)):
            _ = ScoNetDemo(
                cfg_path="/nonexistent/path/config.yaml",
                checkpoint_path=None,
                device="cpu",
            )


class TestScoNetDemoForward:
    """Tests for ScoNetDemo forward pass."""

    def test_forward_output_shape_and_dtype(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Test forward returns logits with shape (N, 3, 16) and correct dtypes."""
        outputs_raw = demo.forward(dummy_sils_batch)
        outputs = cast(dict[str, Tensor], outputs_raw)

        assert "logits" in outputs
        logits = outputs["logits"]

        # Expected shape: (batch_size, num_classes, parts_num) = (N, 3, 16)
        assert logits.shape == (2, 3, 16)
        assert logits.dtype == torch.float32
        assert outputs["label"].dtype == torch.int64
        assert outputs["confidence"].dtype == torch.float32

    def test_forward_returns_required_keys(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Test forward returns required output keys."""
        outputs_raw = demo.forward(dummy_sils_batch)
        outputs = cast(dict[str, Tensor], outputs_raw)

        required_keys = {"logits", "label", "confidence"}
        assert set(outputs.keys()) >= required_keys

    def test_forward_batch_size_one(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """Test forward works with batch size 1."""
        outputs_raw = demo.forward(dummy_sils_single)
        outputs = cast(dict[str, Tensor], outputs_raw)

        assert outputs["logits"].shape == (1, 3, 16)
        assert outputs["label"].shape == (1,)
        assert outputs["confidence"].shape == (1,)

    def test_forward_label_range(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Test forward returns valid label indices (0, 1, or 2)."""
        outputs_raw = demo.forward(dummy_sils_batch)
        outputs = cast(dict[str, Tensor], outputs_raw)

        labels = outputs["label"]
        assert torch.all(labels >= 0)
        assert torch.all(labels <= 2)

    def test_forward_confidence_range(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Test forward returns confidence in [0, 1]."""
        outputs_raw = demo.forward(dummy_sils_batch)
        outputs = cast(dict[str, Tensor], outputs_raw)

        confidence = outputs["confidence"]
        assert torch.all(confidence >= 0.0)
        assert torch.all(confidence <= 1.0)


class TestScoNetDemoPredict:
    """Tests for ScoNetDemo predict method."""

    def test_predict_returns_tuple_with_valid_types(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """Test predict returns (str, float) tuple with valid label."""
        from opengait_studio.sconet_demo import ScoNetDemo

        result_raw = demo.predict(dummy_sils_single)
        result = cast(tuple[str, float], result_raw)

        assert isinstance(result, tuple)
        assert len(result) == 2

        label, confidence = result
        assert isinstance(label, str)
        assert isinstance(confidence, float)

        valid_labels = set(ScoNetDemo.LABEL_MAP.values())
        assert label in valid_labels

    def test_predict_confidence_range(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """Test predict returns confidence in valid range [0, 1]."""
        result_raw = demo.predict(dummy_sils_single)
        result = cast(tuple[str, float], result_raw)

        confidence = result[1]
        assert 0.0 <= confidence <= 1.0

    def test_predict_rejects_batch_size_greater_than_one(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Test predict raises ValueError for batch size > 1."""
        with pytest.raises(ValueError, match="batch size 1"):
            _ = demo.predict(dummy_sils_batch)


class TestScoNetDemoNoDDP:
    """Tests to verify no DDP leakage in unit behavior."""

    def test_no_distributed_init_in_construction(self) -> None:
        """Test that construction does not call torch.distributed."""
        from opengait_studio.sconet_demo import ScoNetDemo

        # NOTE(review): patching the torch.distributed module attributes only
        # intercepts calls made through that module; a `from torch.distributed
        # import ...` binding inside sconet_demo would bypass these mocks.
        with patch("torch.distributed.is_initialized") as mock_is_init, patch(
            "torch.distributed.init_process_group"
        ) as mock_init_pg:
            _ = ScoNetDemo(
                cfg_path=str(CONFIG_PATH),
                checkpoint_path=None,
                device="cpu",
            )

        mock_init_pg.assert_not_called()
        mock_is_init.assert_not_called()

    def test_forward_no_distributed_calls(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Test forward pass does not call torch.distributed."""
        with patch("torch.distributed.all_reduce") as mock_all_reduce, patch(
            "torch.distributed.broadcast"
        ) as mock_broadcast:
            _ = demo.forward(dummy_sils_batch)

        mock_all_reduce.assert_not_called()
        mock_broadcast.assert_not_called()

    def test_predict_no_distributed_calls(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """Test predict does not call torch.distributed."""
        with patch("torch.distributed.all_reduce") as mock_all_reduce, patch(
            "torch.distributed.broadcast"
        ) as mock_broadcast:
            _ = demo.predict(dummy_sils_single)

        mock_all_reduce.assert_not_called()
        mock_broadcast.assert_not_called()

    def test_model_not_wrapped_in_ddp(self, demo: "ScoNetDemo") -> None:
        """Test model is not wrapped in DistributedDataParallel."""
        from torch.nn.parallel import DistributedDataParallel as DDP

        assert not isinstance(demo, DDP)
        assert not isinstance(demo.backbone, DDP)

    def test_device_is_cpu(self, demo: "ScoNetDemo") -> None:
        """Test model stays on CPU when device="cpu" specified."""
        assert demo.device.type == "cpu"

        for param in demo.parameters():
            assert param.device.type == "cpu"


class TestScoNetDemoCheckpointLoading:
    """Tests for checkpoint loading behavior using synthetic state dict."""

    def test_load_checkpoint_changes_weights(
        self,
        demo: "ScoNetDemo",
        checkpoint_file: Path,
    ) -> None:
        """Test loading checkpoint actually changes model weights."""
        # Snapshot every parameter rather than only the first one, so the check
        # does not depend on which module happens to register first.
        initial = {
            name: param.detach().clone() for name, param in demo.named_parameters()
        }

        _ = demo.load_checkpoint(str(checkpoint_file), strict=False)

        current = dict(demo.named_parameters())
        assert any(
            not torch.equal(initial[name], current[name]) for name in initial
        )

    def test_load_checkpoint_sets_eval_mode(
        self,
        demo: "ScoNetDemo",
        checkpoint_file: Path,
    ) -> None:
        """Test loading checkpoint sets model to eval mode."""
        _ = demo.train()
        assert demo.training is True

        _ = demo.load_checkpoint(str(checkpoint_file), strict=False)
        assert demo.training is False

    def test_load_checkpoint_invalid_path_raises(self, demo: "ScoNetDemo") -> None:
        """Test loading from invalid checkpoint path raises error."""
        with pytest.raises(FileNotFoundError):
            _ = demo.load_checkpoint("/nonexistent/checkpoint.pt")


class TestScoNetDemoLabelMap:
    """Tests for LABEL_MAP constant."""

    def test_label_map_has_three_classes(self) -> None:
        """Test LABEL_MAP has exactly 3 classes."""
        from opengait_studio.sconet_demo import ScoNetDemo

        assert len(ScoNetDemo.LABEL_MAP) == 3
        assert set(ScoNetDemo.LABEL_MAP.keys()) == {0, 1, 2}

    def test_label_map_values_are_valid_strings(self) -> None:
        """Test LABEL_MAP values are valid non-empty strings."""
        from opengait_studio.sconet_demo import ScoNetDemo

        for value in ScoNetDemo.LABEL_MAP.values():
            assert isinstance(value, str)
            assert len(value) > 0