test(demo): add unit and integration coverage for pipeline
Introduce focused unit, integration, and NATS-path tests for demo modules, and align assertions with final schema and temporal contracts (window int, seq=30, fill-level ratio). This commit isolates validation logic from runtime changes and provides reproducible QA for pipeline behavior and failure modes.
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Tests for demo package."""
|
||||
@@ -0,0 +1,23 @@
|
||||
"""Test fixtures for demo package."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
@pytest.fixture
def mock_frame_tensor():
    """Return a mock video frame tensor (C, H, W) = (3, 224, 224).

    Values are standard-normal (torch.randn), not valid pixel intensities;
    tests must not rely on a [0, 255] or [0, 1] range.
    """
    return torch.randn(3, 224, 224)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_frame_array():
    """Return a mock video frame as numpy array (H, W, C) = (224, 224, 3).

    float32 standard-normal values — channel-last layout, mirroring the
    channel-first tensor fixture above but as a numpy array.
    """
    return np.random.randn(224, 224, 3).astype(np.float32)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_video_sequence():
    """Return a mock video sequence tensor (T, C, H, W) = (16, 3, 224, 224).

    T=16 frames of standard-normal data; suitable only for shape/flow
    tests, not for content-dependent assertions.
    """
    return torch.randn(16, 3, 224, 224)
|
||||
@@ -0,0 +1,525 @@
|
||||
"""Integration tests for NATS publisher functionality.
|
||||
|
||||
Tests cover:
|
||||
- NATS message receipt with JSON schema validation
|
||||
- Skip behavior when Docker/NATS unavailable
|
||||
- Container lifecycle management (cleanup)
|
||||
- JSON schema field/type validation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import socket
|
||||
import subprocess
|
||||
import time
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import pytest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
|
||||
def _find_open_port() -> int:
|
||||
"""Find an available TCP port on localhost."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
sock.listen(1)
|
||||
addr = cast(tuple[str, int], sock.getsockname())
|
||||
port: int = addr[1]
|
||||
return port
|
||||
|
||||
|
||||
# Constants for test configuration.
# NATS_SUBJECT is the subject these tests subscribe on — presumably the
# publisher's default subject; verify against opengait.demo.output.
# CONTAINER_NAME names the disposable Docker container managed below.
NATS_SUBJECT = "scoliosis.result"
CONTAINER_NAME = "opengait-nats-test"
|
||||
|
||||
|
||||
def _docker_available() -> bool:
|
||||
"""Check if Docker is available and running."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["docker", "info"],
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
check=False,
|
||||
)
|
||||
return result.returncode == 0
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
|
||||
return False
|
||||
|
||||
|
||||
def _nats_container_running() -> bool:
    """Check if NATS container is already running.

    Uses `docker ps` filtered by name; any subprocess failure is treated
    as "not running".
    """
    ps_command = [
        "docker",
        "ps",
        "--filter",
        f"name={CONTAINER_NAME}",
        "--format",
        "{{.Names}}",
    ]
    try:
        ps_result = subprocess.run(
            ps_command,
            timeout=5,
            check=False,
            capture_output=True,
        )
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
        return False
    return CONTAINER_NAME in ps_result.stdout.decode()
|
||||
|
||||
|
||||
def _start_nats_container(port: int) -> bool:
    """Start a disposable NATS container for testing.

    Publishes host *port* to the container's default NATS client port
    (4222) and waits until the server accepts TCP connections.

    Args:
        port: Free host port the NATS client URL will target.

    Returns:
        True when the container is up and reachable, False otherwise.
    """
    try:
        # Remove existing container if present
        _ = subprocess.run(
            ["docker", "rm", "-f", CONTAINER_NAME],
            capture_output=True,
            timeout=10,
            check=False,
        )

        # Start new NATS container.
        # BUGFIX: nats-server listens on 4222 inside the container, so the
        # publish mapping must target 4222; mirroring the random host port
        # (f"{port}:{port}") exposed a container port nothing listens on.
        result = subprocess.run(
            [
                "docker",
                "run",
                "-d",
                "--name",
                CONTAINER_NAME,
                "-p",
                f"{port}:4222",
                "nats:latest",
            ],
            capture_output=True,
            timeout=30,
            check=False,
        )

        if result.returncode != 0:
            return False

        # Wait for NATS to be ready (max 10 seconds).
        # BUGFIX: the official `nats` image ships only nats-server, not the
        # `nats` CLI, so `docker exec ... nats server check server` could
        # never succeed. Probe the published port from the host instead.
        for _ in range(20):
            time.sleep(0.5)
            try:
                with socket.create_connection(("127.0.0.1", port), timeout=1):
                    return True
            except OSError:
                continue

        return False
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
        return False
|
||||
|
||||
|
||||
def _stop_nats_container() -> None:
    """Stop and remove the test NATS container.

    Best-effort cleanup: every failure mode is swallowed so teardown can
    never fail the test session.
    """
    try:
        _ = subprocess.run(
            ["docker", "rm", "-f", CONTAINER_NAME],
            timeout=10,
            check=False,
            capture_output=True,
        )
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
        # Nothing to do — container is either gone or Docker is unusable.
        pass
|
||||
|
||||
|
||||
def _validate_result_schema(data: dict[str, object]) -> tuple[bool, str]:
|
||||
"""Validate JSON result schema.
|
||||
|
||||
Expected schema:
|
||||
{
|
||||
"frame": int,
|
||||
"track_id": int,
|
||||
"label": str (one of: "negative", "neutral", "positive"),
|
||||
"confidence": float in [0, 1],
|
||||
"window": list[int] (start, end),
|
||||
"timestamp_ns": int
|
||||
}
|
||||
"""
|
||||
required_fields: set[str] = {
|
||||
"frame",
|
||||
"track_id",
|
||||
"label",
|
||||
"confidence",
|
||||
"window",
|
||||
"timestamp_ns",
|
||||
}
|
||||
missing = required_fields - set(data.keys())
|
||||
if missing:
|
||||
return False, f"Missing required fields: {missing}"
|
||||
|
||||
# Validate frame (int)
|
||||
frame = data["frame"]
|
||||
if not isinstance(frame, int):
|
||||
return False, f"frame must be int, got {type(frame)}"
|
||||
|
||||
# Validate track_id (int)
|
||||
track_id = data["track_id"]
|
||||
if not isinstance(track_id, int):
|
||||
return False, f"track_id must be int, got {type(track_id)}"
|
||||
|
||||
# Validate label (str in specific set)
|
||||
label = data["label"]
|
||||
valid_labels = {"negative", "neutral", "positive"}
|
||||
if not isinstance(label, str) or label not in valid_labels:
|
||||
return False, f"label must be one of {valid_labels}, got {label}"
|
||||
|
||||
# Validate confidence (float in [0, 1])
|
||||
confidence = data["confidence"]
|
||||
if not isinstance(confidence, (int, float)):
|
||||
return False, f"confidence must be numeric, got {type(confidence)}"
|
||||
if not 0.0 <= float(confidence) <= 1.0:
|
||||
return False, f"confidence must be in [0, 1], got {confidence}"
|
||||
|
||||
# Validate window (list of 2 ints)
|
||||
window = data["window"]
|
||||
if not isinstance(window, list) or len(cast(list[object], window)) != 2:
|
||||
return False, f"window must be list of 2 ints, got {window}"
|
||||
window_list = cast(list[object], window)
|
||||
if not all(isinstance(x, int) for x in window_list):
|
||||
return False, f"window elements must be ints, got {window}"
|
||||
|
||||
# Validate timestamp_ns (int)
|
||||
timestamp_ns = data["timestamp_ns"]
|
||||
if not isinstance(timestamp_ns, int):
|
||||
return False, f"timestamp_ns must be int, got {type(timestamp_ns)}"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def nats_server() -> "Generator[tuple[bool, int], None, None]":
    """Fixture to manage NATS container lifecycle.

    Yields (True, port) if NATS is available, (False, 0) otherwise.
    Cleans up container after tests.

    Module-scoped so the container is started once per test module, not
    once per test. Unavailability is signalled via the yielded flag
    (rather than skipping here) so each test decides how to react.
    """
    # No Docker daemon -> report "unavailable" instead of raising.
    if not _docker_available():
        yield False, 0
        return

    port = _find_open_port()
    if not _start_nats_container(port):
        yield False, 0
        return

    try:
        yield True, port
    finally:
        # Teardown runs even when a dependent test fails.
        _stop_nats_container()
|
||||
|
||||
|
||||
class TestNatsPublisherIntegration:
    """Integration tests for NATS publisher with live server.

    NOTE(review): the skipif conditions below call _docker_available() at
    collection time, so `docker info` runs once while pytest gathers
    tests even if nothing in this class is selected.
    """

    @pytest.mark.skipif(not _docker_available(), reason="Docker not available")
    def test_nats_message_receipt_and_schema_validation(
        self, nats_server: tuple[bool, int]
    ) -> None:
        """Test that messages are received and match expected schema."""
        server_available, port = nats_server
        if not server_available:
            pytest.skip("NATS server not available")

        nats_url = f"nats://127.0.0.1:{port}"

        # Import here to avoid issues if nats-py not installed
        try:
            import asyncio

            import nats  # type: ignore[import-untyped]

        except ImportError:
            pytest.skip("nats-py not installed")

        from opengait.demo.output import NatsPublisher, create_result

        # Create publisher — presumably connects (or lazily connects) to
        # nats_url and publishes synchronously; verify against
        # opengait.demo.output.
        publisher = NatsPublisher(nats_url, subject=NATS_SUBJECT)

        # Create test results spanning two tracks and all three labels.
        test_results = [
            create_result(
                frame=100,
                track_id=1,
                label="positive",
                confidence=0.85,
                window=(70, 100),
                timestamp_ns=1234567890000,
            ),
            create_result(
                frame=130,
                track_id=1,
                label="negative",
                confidence=0.92,
                window=(100, 130),
                timestamp_ns=1234567890030,
            ),
            create_result(
                frame=160,
                track_id=2,
                label="neutral",
                confidence=0.78,
                window=(130, 160),
                timestamp_ns=1234567890060,
            ),
        ]

        # Collect received messages (closed over by the coroutine below)
        received_messages: list[dict[str, object]] = []

        async def subscribe_and_publish():
            """Subscribe to subject, publish messages, collect results."""
            nc = await nats.connect(nats_url)  # pyright: ignore[reportUnknownMemberType]

            # Subscribe BEFORE publishing so no message is missed
            # (core NATS does not retain messages for late subscribers).
            sub = await nc.subscribe(NATS_SUBJECT)  # pyright: ignore[reportUnknownMemberType]

            # Publish all test results
            for result in test_results:
                publisher.publish(result)

            # Wait for messages with timeout; stop early if one is lost.
            for _ in range(len(test_results)):
                try:
                    msg = await asyncio.wait_for(sub.next_msg(), timeout=5.0)
                    data = json.loads(msg.data.decode("utf-8"))  # pyright: ignore[reportAny]
                    received_messages.append(cast(dict[str, object], data))
                except asyncio.TimeoutError:
                    break

            await sub.unsubscribe()
            await nc.close()

        # Run async subscriber
        asyncio.run(subscribe_and_publish())

        # Cleanup publisher
        publisher.close()

        # Verify all messages received
        assert len(received_messages) == len(test_results), (
            f"Expected {len(test_results)} messages, got {len(received_messages)}"
        )

        # Validate schema for each message
        for i, msg in enumerate(received_messages):
            is_valid, error = _validate_result_schema(msg)
            assert is_valid, f"Message {i} schema validation failed: {error}"

        # Verify specific values (relies on per-subject ordering, which
        # NATS guarantees for a single publisher).
        assert received_messages[0]["frame"] == 100
        assert received_messages[0]["label"] == "positive"
        assert received_messages[0]["track_id"] == 1
        _conf = received_messages[0]["confidence"]
        assert isinstance(_conf, (int, float))
        assert 0.0 <= float(_conf) <= 1.0

        assert received_messages[1]["label"] == "negative"
        assert received_messages[2]["track_id"] == 2

    @pytest.mark.skipif(not _docker_available(), reason="Docker not available")
    def test_nats_publisher_graceful_when_server_unavailable(self) -> None:
        """Test that publisher handles missing server gracefully."""
        try:
            from opengait.demo.output import NatsPublisher
        except ImportError:
            pytest.skip("output module not available")

        # Use wrong port where no server is running
        bad_url = "nats://127.0.0.1:14222"
        publisher = NatsPublisher(bad_url, subject=NATS_SUBJECT)

        # Should not raise when publishing without server
        test_result: dict[str, object] = {
            "frame": 1,
            "track_id": 1,
            "label": "positive",
            "confidence": 0.85,
            "window": [0, 30],
            "timestamp_ns": 1234567890,
        }

        # Should not raise — the contract under test is best-effort
        # delivery, not guaranteed delivery.
        publisher.publish(test_result)

        # Cleanup should also not raise
        publisher.close()

    @pytest.mark.skipif(not _docker_available(), reason="Docker not available")
    def test_nats_publisher_context_manager(
        self, nats_server: tuple[bool, int]
    ) -> None:
        """Test that publisher works as context manager."""
        server_available, port = nats_server
        if not server_available:
            pytest.skip("NATS server not available")

        nats_url = f"nats://127.0.0.1:{port}"

        try:
            import asyncio

            import nats  # type: ignore[import-untyped]
            from opengait.demo.output import NatsPublisher, create_result
        except ImportError as e:
            pytest.skip(f"Required module not available: {e}")

        received_messages: list[dict[str, object]] = []

        async def subscribe_and_test():
            nc = await nats.connect(nats_url)  # pyright: ignore[reportUnknownMemberType]
            sub = await nc.subscribe(NATS_SUBJECT)  # pyright: ignore[reportUnknownMemberType]

            # Use context manager — __exit__ must flush/close cleanly.
            with NatsPublisher(nats_url, subject=NATS_SUBJECT) as publisher:
                result = create_result(
                    frame=200,
                    track_id=5,
                    label="neutral",
                    confidence=0.65,
                    window=(170, 200),
                    timestamp_ns=9999999999,
                )
                publisher.publish(result)

            # Wait for message; a timeout leaves received_messages short
            # and is caught by the final length assertion.
            try:
                msg = await asyncio.wait_for(sub.next_msg(), timeout=5.0)
                data = json.loads(msg.data.decode("utf-8"))  # pyright: ignore[reportAny]
                received_messages.append(cast(dict[str, object], data))
            except asyncio.TimeoutError:
                pass

            await sub.unsubscribe()
            await nc.close()

        asyncio.run(subscribe_and_test())

        assert len(received_messages) == 1
        assert received_messages[0]["frame"] == 200
        assert received_messages[0]["track_id"] == 5
||||
|
||||
|
||||
class TestNatsSchemaValidation:
    """Tests for JSON schema validation without requiring NATS server.

    Exercises _validate_result_schema directly: one fully-valid payload,
    one negative case per constraint, and a sweep over the label set.
    """

    def test_validate_result_schema_valid(self) -> None:
        """Test schema validation with valid data."""
        valid_data: dict[str, object] = {
            "frame": 1234,
            "track_id": 42,
            "label": "positive",
            "confidence": 0.85,
            "window": [1200, 1230],
            "timestamp_ns": 1234567890000,
        }

        is_valid, error = _validate_result_schema(valid_data)
        assert is_valid, f"Valid data rejected: {error}"

    def test_validate_result_schema_invalid_label(self) -> None:
        """Test schema validation rejects invalid label."""
        invalid_data: dict[str, object] = {
            "frame": 1234,
            "track_id": 42,
            "label": "invalid_label",
            "confidence": 0.85,
            "window": [1200, 1230],
            "timestamp_ns": 1234567890000,
        }

        is_valid, error = _validate_result_schema(invalid_data)
        assert not is_valid
        # Error message must name the offending field.
        assert "label" in error.lower()

    def test_validate_result_schema_confidence_out_of_range(self) -> None:
        """Test schema validation rejects confidence outside [0, 1]."""
        invalid_data: dict[str, object] = {
            "frame": 1234,
            "track_id": 42,
            "label": "positive",
            "confidence": 1.5,
            "window": [1200, 1230],
            "timestamp_ns": 1234567890000,
        }

        is_valid, error = _validate_result_schema(invalid_data)
        assert not is_valid
        assert "confidence" in error.lower()

    def test_validate_result_schema_missing_fields(self) -> None:
        """Test schema validation detects missing fields."""
        # Only 2 of 6 required fields present.
        incomplete_data: dict[str, object] = {
            "frame": 1234,
            "label": "positive",
        }

        is_valid, error = _validate_result_schema(incomplete_data)
        assert not is_valid
        assert "missing" in error.lower()

    def test_validate_result_schema_wrong_types(self) -> None:
        """Test schema validation rejects wrong types."""
        wrong_types: dict[str, object] = {
            "frame": "not_an_int",
            "track_id": 42,
            "label": "positive",
            "confidence": 0.85,
            "window": [1200, 1230],
            "timestamp_ns": 1234567890000,
        }

        is_valid, error = _validate_result_schema(wrong_types)
        assert not is_valid
        assert "frame" in error.lower()

    def test_all_valid_labels_accepted(self) -> None:
        """Test that all valid labels are accepted."""
        for label_str in ["negative", "neutral", "positive"]:
            data: dict[str, object] = {
                "frame": 100,
                "track_id": 1,
                "label": label_str,
                "confidence": 0.5,
                "window": [70, 100],
                "timestamp_ns": 1234567890,
            }
            is_valid, error = _validate_result_schema(data)
            assert is_valid, f"Valid label '{label_str}' rejected: {error}"
|
||||
|
||||
|
||||
class TestDockerAvailability:
    """Tests for Docker availability detection.

    These only assert that the probes are exception-safe and return a
    bool, regardless of whether Docker is installed.
    """

    def test_docker_available_check(self) -> None:
        """Test Docker availability check doesn't crash."""
        availability = _docker_available()
        assert isinstance(availability, bool)

    def test_nats_container_running_check(self) -> None:
        """Test container running check doesn't crash."""
        running = _nats_container_running()
        assert isinstance(running, bool)
|
||||
@@ -0,0 +1,279 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import Final, cast
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from opengait.demo.sconet_demo import ScoNetDemo
|
||||
|
||||
# Repo-relative asset locations used by the CLI integration tests.
# parents[2] assumes this file sits two directories below the repo root
# (e.g. <root>/tests/<subdir>/this_file.py) — TODO confirm.
REPO_ROOT: Final[Path] = Path(__file__).resolve().parents[2]
SAMPLE_VIDEO_PATH: Final[Path] = REPO_ROOT / "assets" / "sample.mp4"
CHECKPOINT_PATH: Final[Path] = REPO_ROOT / "ckpt" / "ScoNet-20000.pt"
CONFIG_PATH: Final[Path] = REPO_ROOT / "configs" / "sconet" / "sconet_scoliosis1k.yaml"
YOLO_MODEL_PATH: Final[Path] = REPO_ROOT / "yolo11n-seg.pt"
|
||||
|
||||
|
||||
def _device_for_runtime() -> str:
|
||||
return "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def _run_pipeline_cli(
    *args: str, timeout_seconds: int = 120
) -> subprocess.CompletedProcess[str]:
    """Run `python -m opengait.demo` with *args* and capture its output.

    check=False: callers assert on returncode themselves, so a non-zero
    exit is data rather than an exception.
    """
    invocation = [sys.executable, "-m", "opengait.demo"]
    invocation.extend(args)
    return subprocess.run(
        invocation,
        cwd=REPO_ROOT,
        text=True,
        capture_output=True,
        timeout=timeout_seconds,
        check=False,
    )
|
||||
|
||||
|
||||
def _require_integration_assets() -> None:
    """Skip the calling test unless every integration asset exists on disk."""
    required_assets = (
        (SAMPLE_VIDEO_PATH, "Missing sample video"),
        (CONFIG_PATH, "Missing config"),
        (YOLO_MODEL_PATH, "Missing YOLO model file"),
    )
    for asset_path, reason in required_assets:
        if not asset_path.is_file():
            pytest.skip(f"{reason}: {asset_path}")
|
||||
|
||||
|
||||
@pytest.fixture
def compatible_checkpoint_path(tmp_path: Path) -> Path:
    """Build a structurally compatible (untrained) checkpoint on the fly.

    Instantiates ScoNetDemo without weights and saves its state dict, so
    CLI runs can load a checkpoint matching the current architecture
    without depending on a shipped ckpt file. Weights are random — these
    tests validate plumbing and schema, not prediction quality.
    """
    if not CONFIG_PATH.is_file():
        pytest.skip(f"Missing config: {CONFIG_PATH}")

    checkpoint_file = tmp_path / "sconet-compatible.pt"
    model = ScoNetDemo(cfg_path=str(CONFIG_PATH), checkpoint_path=None, device="cpu")
    torch.save(model.state_dict(), checkpoint_file)
    return checkpoint_file
|
||||
|
||||
|
||||
def _extract_prediction_json_lines(stdout: str) -> list[dict[str, object]]:
|
||||
required_keys = {
|
||||
"frame",
|
||||
"track_id",
|
||||
"label",
|
||||
"confidence",
|
||||
"window",
|
||||
"timestamp_ns",
|
||||
}
|
||||
predictions: list[dict[str, object]] = []
|
||||
|
||||
for line in stdout.splitlines():
|
||||
stripped = line.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
try:
|
||||
payload_obj = cast(object, json.loads(stripped))
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if not isinstance(payload_obj, dict):
|
||||
continue
|
||||
payload = cast(dict[str, object], payload_obj)
|
||||
if required_keys.issubset(payload.keys()):
|
||||
predictions.append(payload)
|
||||
|
||||
return predictions
|
||||
|
||||
|
||||
def _assert_prediction_schema(prediction: dict[str, object]) -> None:
|
||||
assert isinstance(prediction["frame"], int)
|
||||
assert isinstance(prediction["track_id"], int)
|
||||
|
||||
label = prediction["label"]
|
||||
assert isinstance(label, str)
|
||||
assert label in {"negative", "neutral", "positive"}
|
||||
|
||||
confidence = prediction["confidence"]
|
||||
assert isinstance(confidence, (int, float))
|
||||
confidence_value = float(confidence)
|
||||
assert 0.0 <= confidence_value <= 1.0
|
||||
|
||||
window_obj = prediction["window"]
|
||||
assert isinstance(window_obj, int)
|
||||
assert window_obj >= 0
|
||||
|
||||
assert isinstance(prediction["timestamp_ns"], int)
|
||||
|
||||
|
||||
def test_pipeline_cli_fps_benchmark_smoke(
    compatible_checkpoint_path: Path,
) -> None:
    """Smoke-check end-to-end throughput of the demo CLI.

    Runs the pipeline on a short clip (90 frames, window 5, stride 1) and
    asserts a very conservative FPS floor. Skips instead of failing when
    the environment cannot produce a stable measurement.
    """
    _require_integration_assets()

    max_frames = 90
    started_at = time.perf_counter()
    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "5",
        "--stride",
        "1",
        "--max-frames",
        str(max_frames),
        timeout_seconds=180,
    )
    # Wall clock includes interpreter/model startup, hence the very low
    # FPS floor asserted below.
    elapsed_seconds = time.perf_counter() - started_at

    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )
    predictions = _extract_prediction_json_lines(result.stdout)
    assert predictions, "Expected prediction output for FPS benchmark run"

    for prediction in predictions:
        _assert_prediction_schema(prediction)

    # Distinct frame indices stand in for "frames processed"; a set
    # avoids double-counting when several tracks report the same frame.
    observed_frames = {
        frame_obj
        for prediction in predictions
        for frame_obj in [prediction["frame"]]
        if isinstance(frame_obj, int)
    }
    observed_units = len(observed_frames)
    if observed_units < 5:
        pytest.skip(
            "Insufficient observed frame samples for stable FPS benchmark in this environment"
        )
    if elapsed_seconds <= 0:
        pytest.skip("Non-positive elapsed time; cannot compute FPS benchmark")

    fps = observed_units / elapsed_seconds
    min_expected_fps = 0.2
    assert fps >= min_expected_fps, (
        "Observed FPS below conservative CI threshold: "
        f"{fps:.3f} < {min_expected_fps:.3f} "
        f"(observed_units={observed_units}, elapsed_seconds={elapsed_seconds:.3f})"
    )
|
||||
|
||||
|
||||
def test_pipeline_cli_happy_path_outputs_json_predictions(
    compatible_checkpoint_path: Path,
) -> None:
    """End-to-end happy path: CLI exits 0 and emits schema-valid JSON.

    Uses non-overlapping windows (window 10, stride 10) over 120 frames
    and validates every emitted prediction line against the final schema.
    """
    _require_integration_assets()

    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "10",
        "--stride",
        "10",
        "--max-frames",
        "120",
        timeout_seconds=180,
    )

    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )
    predictions = _extract_prediction_json_lines(result.stdout)
    assert predictions, (
        "Expected at least one prediction JSON line in stdout. "
        f"stdout:\n{result.stdout}\nstderr:\n{result.stderr}"
    )
    for prediction in predictions:
        _assert_prediction_schema(prediction)

    # No NATS flag was passed, so the CLI must not have connected;
    # "Connected to NATS" is presumably the CLI's log line — verify
    # against the demo's logging.
    assert "Connected to NATS" not in result.stderr
|
||||
|
||||
|
||||
def test_pipeline_cli_max_frames_caps_output_frames(
    compatible_checkpoint_path: Path,
) -> None:
    """--max-frames must bound every emitted frame index.

    Runs 20 frames with a small sliding window and checks no prediction
    references a frame at or beyond the cap.
    """
    _require_integration_assets()

    max_frames = 20
    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "5",
        "--stride",
        "1",
        "--max-frames",
        str(max_frames),
        timeout_seconds=180,
    )

    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )
    predictions = _extract_prediction_json_lines(result.stdout)
    assert predictions, "Expected prediction output with --max-frames run"

    for prediction in predictions:
        _assert_prediction_schema(prediction)
        frame_idx_obj = prediction["frame"]
        assert isinstance(frame_idx_obj, int)
        # Strict < assumes 0-based frame indexing in the CLI output —
        # TODO confirm against the demo implementation.
        assert frame_idx_obj < max_frames
|
||||
|
||||
|
||||
def test_pipeline_cli_invalid_source_path_returns_user_error() -> None:
    """A missing video source must yield exit code 2 and a clear error.

    The checkpoint path is never reached because source validation is
    expected to fail first.
    """
    result = _run_pipeline_cli(
        "--source",
        "/definitely/not/a/real/video.mp4",
        "--checkpoint",
        "/tmp/unused-checkpoint.pt",
        "--config",
        str(CONFIG_PATH),
        timeout_seconds=30,
    )

    # Exit code 2 is the CLI's user-error convention (bad paths/args).
    assert result.returncode == 2
    assert "Error: Video source not found" in result.stderr
|
||||
|
||||
|
||||
def test_pipeline_cli_invalid_checkpoint_path_returns_user_error() -> None:
    """A missing checkpoint must yield exit code 2 and a clear error.

    Requires a real sample video and config so the CLI gets past source
    validation and fails specifically on the checkpoint.
    """
    if not SAMPLE_VIDEO_PATH.is_file():
        pytest.skip(f"Missing sample video: {SAMPLE_VIDEO_PATH}")
    if not CONFIG_PATH.is_file():
        pytest.skip(f"Missing config: {CONFIG_PATH}")

    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(REPO_ROOT / "ckpt" / "missing-checkpoint.pt"),
        "--config",
        str(CONFIG_PATH),
        timeout_seconds=30,
    )

    # Same user-error exit code convention as the invalid-source test.
    assert result.returncode == 2
    assert "Error: Checkpoint not found" in result.stderr
|
||||
@@ -0,0 +1,153 @@
|
||||
"""Unit tests for silhouette preprocessing functions."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
import pytest
|
||||
from beartype.roar import BeartypeCallHintParamViolation
|
||||
from jaxtyping import TypeCheckError
|
||||
|
||||
from opengait.demo.preprocess import mask_to_silhouette
|
||||
|
||||
|
||||
class TestMaskToSilhouette:
    """Tests for mask_to_silhouette() function.

    Conventions used throughout: masks are uint8 (H, W) with foreground
    255; bbox appears to be (x1, y1, x2, y2) pixel coordinates — verify
    against opengait.demo.preprocess. The contract under test: valid
    input yields a (64, 44) float32 silhouette in [0, 1]; too-small or
    empty masks yield None.
    """

    def test_valid_mask_returns_correct_shape_dtype_and_range(self) -> None:
        """Valid mask should return (64, 44) float32 array in [0, 1] range."""
        # Create a synthetic mask with sufficient area (person-shaped blob)
        h, w = 200, 150
        mask = np.zeros((h, w), dtype=np.uint8)
        # Draw a filled ellipse to simulate a person
        center_y, center_x = h // 2, w // 2
        axes_y, axes_x = h // 3, w // 4
        y, x = np.ogrid[:h, :w]
        ellipse_mask = ((x - center_x) / axes_x) ** 2 + (
            (y - center_y) / axes_y
        ) ** 2 <= 1
        mask[ellipse_mask] = 255

        bbox = (w // 4, h // 6, 3 * w // 4, 5 * h // 6)

        result = mask_to_silhouette(mask, bbox)

        assert result is not None
        result_arr = cast(NDArray[np.float32], result)
        assert result_arr.shape == (64, 44)
        assert result_arr.dtype == np.float32
        assert np.all(result_arr >= 0.0) and np.all(result_arr <= 1.0)

    def test_tiny_mask_returns_none(self) -> None:
        """Mask with area below MIN_MASK_AREA should return None."""
        # Create a tiny mask
        mask = np.zeros((50, 50), dtype=np.uint8)
        mask[20:22, 20:22] = 255  # Only 4 pixels, well below MIN_MASK_AREA (500)

        bbox = (20, 20, 22, 22)

        result = mask_to_silhouette(mask, bbox)

        assert result is None

    def test_empty_mask_returns_none(self) -> None:
        """Completely empty mask should return None."""
        mask = np.zeros((100, 100), dtype=np.uint8)
        bbox = (10, 10, 90, 90)

        result = mask_to_silhouette(mask, bbox)

        assert result is None

    def test_full_frame_mask_returns_valid_output(self) -> None:
        """Full-frame mask (large bbox covering entire image) should work."""
        h, w = 300, 200
        mask = np.zeros((h, w), dtype=np.uint8)
        # Create a large filled region
        mask[50:250, 30:170] = 255

        # Full frame bbox
        bbox = (0, 0, w, h)

        result = mask_to_silhouette(mask, bbox)

        assert result is not None
        result_arr = cast(NDArray[np.float32], result)
        assert result_arr.shape == (64, 44)
        assert result_arr.dtype == np.float32

    def test_determinism_same_input_same_output(self) -> None:
        """Same input should always produce same output.

        Three calls guard against hidden RNG or state inside preprocess.
        """
        h, w = 200, 150
        mask = np.zeros((h, w), dtype=np.uint8)
        # Create a person-shaped region
        mask[50:150, 40:110] = 255

        bbox = (40, 50, 110, 150)

        result1 = mask_to_silhouette(mask, bbox)
        result2 = mask_to_silhouette(mask, bbox)
        result3 = mask_to_silhouette(mask, bbox)

        assert result1 is not None
        assert result2 is not None
        assert result3 is not None
        np.testing.assert_array_equal(result1, result2)
        np.testing.assert_array_equal(result2, result3)

    def test_tall_narrow_mask_valid_output(self) -> None:
        """Tall narrow mask should produce valid silhouette."""
        h, w = 400, 50
        mask = np.zeros((h, w), dtype=np.uint8)
        # Tall narrow person
        mask[50:350, 10:40] = 255

        bbox = (10, 50, 40, 350)

        result = mask_to_silhouette(mask, bbox)

        assert result is not None
        result_arr = cast(NDArray[np.float32], result)
        assert result_arr.shape == (64, 44)

    def test_wide_short_mask_valid_output(self) -> None:
        """Wide short mask should produce valid silhouette."""
        h, w = 100, 400
        mask = np.zeros((h, w), dtype=np.uint8)
        # Wide short person
        mask[20:80, 50:350] = 255

        bbox = (50, 20, 350, 80)

        result = mask_to_silhouette(mask, bbox)

        assert result is not None
        result_arr = cast(NDArray[np.float32], result)
        assert result_arr.shape == (64, 44)

    def test_beartype_rejects_wrong_dtype(self) -> None:
        """Beartype should reject non-uint8 input."""
        # Float array instead of uint8
        mask = np.ones((100, 100), dtype=np.float32) * 255
        bbox = (10, 10, 90, 90)

        # Either beartype or jaxtyping may raise depending on how the
        # annotation is enforced; accept both.
        with pytest.raises((BeartypeCallHintParamViolation, TypeCheckError)):
            _ = mask_to_silhouette(mask, bbox)

    def test_beartype_rejects_wrong_ndim(self) -> None:
        """Beartype should reject non-2D array."""
        # 3D array instead of 2D
        mask = np.ones((100, 100, 3), dtype=np.uint8) * 255
        bbox = (10, 10, 90, 90)

        with pytest.raises((BeartypeCallHintParamViolation, TypeCheckError)):
            _ = mask_to_silhouette(mask, bbox)

    def test_beartype_rejects_wrong_bbox_type(self) -> None:
        """Beartype should reject non-tuple bbox."""
        mask = np.ones((100, 100), dtype=np.uint8) * 255
        # List instead of tuple
        bbox = [10, 10, 90, 90]

        with pytest.raises((BeartypeCallHintParamViolation, TypeCheckError)):
            _ = mask_to_silhouette(mask, bbox)
|
||||
@@ -0,0 +1,341 @@
|
||||
"""Unit tests for ScoNetDemo forward pass.
|
||||
|
||||
Tests cover:
|
||||
- Construction from config/checkpoint path handling
|
||||
- Forward output shape (N, 3, 16) and dtype float
|
||||
- Predict output (label_str, confidence_float) with valid label/range
|
||||
- No-DDP leakage check (no torch.distributed calls in unit behavior)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opengait.demo.sconet_demo import ScoNetDemo
|
||||
|
||||
# Constants for test configuration
|
||||
# Repo-relative config path shared by the construction tests below.
CONFIG_PATH = Path("configs/sconet/sconet_scoliosis1k.yaml")
|
||||
|
||||
|
||||
@pytest.fixture
def demo() -> "ScoNetDemo":
    """Create ScoNetDemo without loading checkpoint (CPU-only)."""
    from opengait.demo.sconet_demo import ScoNetDemo

    # No checkpoint: weights stay randomly initialized, which is enough
    # for shape/contract tests and keeps the fixture fast.
    instance = ScoNetDemo(
        cfg_path=str(CONFIG_PATH),
        checkpoint_path=None,
        device="cpu",
    )
    return instance
|
||||
|
||||
|
||||
@pytest.fixture
def dummy_sils_batch() -> Tensor:
    """Return dummy silhouette tensor of shape (N, 1, S, 64, 44)."""
    batch, channels, seq = 2, 1, 30
    return torch.randn(batch, channels, seq, 64, 44)
|
||||
|
||||
|
||||
@pytest.fixture
def dummy_sils_single() -> Tensor:
    """Return dummy silhouette tensor of shape (1, 1, S, 64, 44) for predict."""
    # predict() requires batch size 1, hence a single-sequence tensor.
    return torch.randn(1, 1, 30, 64, 44)
|
||||
|
||||
|
||||
@pytest.fixture
def synthetic_state_dict() -> dict[str, Tensor]:
    """Return a synthetic state dict compatible with ScoNetDemo structure."""
    state: dict[str, Tensor] = {
        "backbone.conv1.conv.weight": torch.randn(64, 1, 3, 3),
        "fcs.fc_bin": torch.randn(16, 512, 256),
        "bn_necks.fc_bin": torch.randn(16, 256, 3),
    }
    # BatchNorm blocks carry learnable affine params plus running stats.
    for prefix, dim in (("backbone.conv1.bn", 64), ("bn_necks.bn1d", 4096)):
        state[f"{prefix}.weight"] = torch.ones(dim)
        state[f"{prefix}.bias"] = torch.zeros(dim)
        state[f"{prefix}.running_mean"] = torch.zeros(dim)
        state[f"{prefix}.running_var"] = torch.ones(dim)
    return state
|
||||
|
||||
|
||||
class TestScoNetDemoConstruction:
    """Tests for ScoNetDemo construction and path handling."""

    def test_construction_from_config_no_checkpoint(self) -> None:
        """Test construction with config only, no checkpoint."""
        from opengait.demo.sconet_demo import ScoNetDemo

        built = ScoNetDemo(
            cfg_path=str(CONFIG_PATH),
            checkpoint_path=None,
            device="cpu",
        )

        # Path, device, and parsed config are all wired up.
        assert built.cfg_path.endswith("sconet_scoliosis1k.yaml")
        assert built.device == torch.device("cpu")
        assert built.cfg is not None
        assert "model_cfg" in built.cfg
        # Demo models come up in eval mode.
        assert built.training is False

    def test_construction_with_relative_path(self) -> None:
        """Test construction handles relative config path correctly."""
        from opengait.demo.sconet_demo import ScoNetDemo

        built = ScoNetDemo(
            cfg_path="configs/sconet/sconet_scoliosis1k.yaml",
            checkpoint_path=None,
            device="cpu",
        )

        assert built.cfg is not None
        assert built.backbone is not None

    def test_construction_invalid_config_raises(self) -> None:
        """Test construction raises with invalid config path."""
        from opengait.demo.sconet_demo import ScoNetDemo

        with pytest.raises((FileNotFoundError, TypeError)):
            _ = ScoNetDemo(
                cfg_path="/nonexistent/path/config.yaml",
                checkpoint_path=None,
                device="cpu",
            )
|
||||
|
||||
|
||||
class TestScoNetDemoForward:
    """Tests for ScoNetDemo forward pass."""

    @staticmethod
    def _forward(demo: "ScoNetDemo", sils: Tensor) -> dict[str, Tensor]:
        """Run a forward pass and cast the output dict for typed asserts."""
        return cast(dict[str, Tensor], demo.forward(sils))

    def test_forward_output_shape_and_dtype(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Test forward returns logits with shape (N, 3, 16) and correct dtypes."""
        out = self._forward(demo, dummy_sils_batch)

        assert "logits" in out
        # Expected shape: (batch_size, num_classes, parts_num) = (N, 3, 16)
        assert out["logits"].shape == (2, 3, 16)
        assert out["logits"].dtype == torch.float32
        assert out["label"].dtype == torch.int64
        assert out["confidence"].dtype == torch.float32

    def test_forward_returns_required_keys(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Test forward returns required output keys."""
        out = self._forward(demo, dummy_sils_batch)
        assert {"logits", "label", "confidence"} <= set(out.keys())

    def test_forward_batch_size_one(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """Test forward works with batch size 1."""
        out = self._forward(demo, dummy_sils_single)

        assert out["logits"].shape == (1, 3, 16)
        assert out["label"].shape == (1,)
        assert out["confidence"].shape == (1,)

    def test_forward_label_range(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Test forward returns valid label indices (0, 1, or 2)."""
        predicted = self._forward(demo, dummy_sils_batch)["label"]
        assert torch.all(predicted >= 0)
        assert torch.all(predicted <= 2)

    def test_forward_confidence_range(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Test forward returns confidence in [0, 1]."""
        scores = self._forward(demo, dummy_sils_batch)["confidence"]
        assert torch.all(scores >= 0.0)
        assert torch.all(scores <= 1.0)
|
||||
|
||||
|
||||
class TestScoNetDemoPredict:
    """Tests for ScoNetDemo predict method."""

    def test_predict_returns_tuple_with_valid_types(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """Test predict returns (str, float) tuple with valid label."""
        from opengait.demo.sconet_demo import ScoNetDemo

        prediction = cast(tuple[str, float], demo.predict(dummy_sils_single))

        assert isinstance(prediction, tuple)
        assert len(prediction) == 2
        label, confidence = prediction
        assert isinstance(label, str)
        assert isinstance(confidence, float)

        # Label string must come from the model's canonical label map.
        assert label in set(ScoNetDemo.LABEL_MAP.values())

    def test_predict_confidence_range(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """Test predict returns confidence in valid range [0, 1]."""
        prediction = cast(tuple[str, float], demo.predict(dummy_sils_single))
        confidence = prediction[1]

        assert 0.0 <= confidence <= 1.0

    def test_predict_rejects_batch_size_greater_than_one(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Test predict raises ValueError for batch size > 1."""
        with pytest.raises(ValueError, match="batch size 1"):
            _ = demo.predict(dummy_sils_batch)
|
||||
|
||||
|
||||
class TestScoNetDemoNoDDP:
    """Tests to verify no DDP leakage in unit behavior."""

    def test_no_distributed_init_in_construction(self) -> None:
        """Test that construction does not call torch.distributed."""
        from opengait.demo.sconet_demo import ScoNetDemo

        with patch("torch.distributed.is_initialized") as spy_is_init, patch(
            "torch.distributed.init_process_group"
        ) as spy_init_pg:
            _ = ScoNetDemo(
                cfg_path=str(CONFIG_PATH),
                checkpoint_path=None,
                device="cpu",
            )

        spy_init_pg.assert_not_called()
        spy_is_init.assert_not_called()

    def test_forward_no_distributed_calls(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Test forward pass does not call torch.distributed."""
        with patch("torch.distributed.all_reduce") as spy_all_reduce, patch(
            "torch.distributed.broadcast"
        ) as spy_broadcast:
            _ = demo.forward(dummy_sils_batch)

        spy_all_reduce.assert_not_called()
        spy_broadcast.assert_not_called()

    def test_predict_no_distributed_calls(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """Test predict does not call torch.distributed."""
        with patch("torch.distributed.all_reduce") as spy_all_reduce, patch(
            "torch.distributed.broadcast"
        ) as spy_broadcast:
            _ = demo.predict(dummy_sils_single)

        spy_all_reduce.assert_not_called()
        spy_broadcast.assert_not_called()

    def test_model_not_wrapped_in_ddp(self, demo: "ScoNetDemo") -> None:
        """Test model is not wrapped in DistributedDataParallel."""
        from torch.nn.parallel import DistributedDataParallel as DDP

        assert not isinstance(demo, DDP)
        assert not isinstance(demo.backbone, DDP)

    def test_device_is_cpu(self, demo: "ScoNetDemo") -> None:
        """Test model stays on CPU when device="cpu" specified."""
        assert demo.device.type == "cpu"

        # Every parameter must live on the CPU as well.
        for param in demo.parameters():
            assert param.device.type == "cpu"
|
||||
|
||||
|
||||
class TestScoNetDemoCheckpointLoading:
    """Tests for checkpoint loading behavior using synthetic state dict.

    Temp-checkpoint creation is factored into ``_save_checkpoint`` so each
    test only expresses the behavior under test. The helper also closes the
    temp-file handle *before* ``torch.save`` writes to the path, which the
    original inline version did not do — ``NamedTemporaryFile`` holds an
    exclusive handle while open, so saving into it fails on Windows.
    """

    @staticmethod
    def _save_checkpoint(state_dict: dict[str, Tensor]) -> str:
        """Serialize *state_dict* to a temporary ``.pt`` file.

        Returns the file path; the caller is responsible for unlinking it.
        """
        import tempfile

        with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
            temp_path = f.name
        # Write after the handle is closed so the save works cross-platform.
        torch.save(state_dict, temp_path)
        return temp_path

    def test_load_checkpoint_changes_weights(
        self,
        demo: "ScoNetDemo",
        synthetic_state_dict: dict[str, Tensor],
    ) -> None:
        """Test loading checkpoint actually changes model weights."""
        import os

        # Snapshot one weight tensor before loading.
        initial_weight = next(iter(demo.parameters())).clone()

        temp_path = self._save_checkpoint(synthetic_state_dict)
        try:
            _ = demo.load_checkpoint(temp_path, strict=False)
            new_weight = next(iter(demo.parameters()))
            assert not torch.equal(initial_weight, new_weight)
        finally:
            os.unlink(temp_path)

    def test_load_checkpoint_sets_eval_mode(
        self,
        demo: "ScoNetDemo",
        synthetic_state_dict: dict[str, Tensor],
    ) -> None:
        """Test loading checkpoint sets model to eval mode."""
        import os

        # Force train mode first so the transition is observable.
        _ = demo.train()
        assert demo.training is True

        temp_path = self._save_checkpoint(synthetic_state_dict)
        try:
            _ = demo.load_checkpoint(temp_path, strict=False)
            assert demo.training is False
        finally:
            os.unlink(temp_path)

    def test_load_checkpoint_invalid_path_raises(self, demo: "ScoNetDemo") -> None:
        """Test loading from invalid checkpoint path raises error."""
        with pytest.raises(FileNotFoundError):
            _ = demo.load_checkpoint("/nonexistent/checkpoint.pt")
|
||||
|
||||
|
||||
class TestScoNetDemoLabelMap:
    """Tests for LABEL_MAP constant."""

    def test_label_map_has_three_classes(self) -> None:
        """Test LABEL_MAP has exactly 3 classes."""
        from opengait.demo.sconet_demo import ScoNetDemo

        label_map = ScoNetDemo.LABEL_MAP
        assert len(label_map) == 3
        assert set(label_map.keys()) == {0, 1, 2}

    def test_label_map_values_are_valid_strings(self) -> None:
        """Test LABEL_MAP values are valid non-empty strings."""
        from opengait.demo.sconet_demo import ScoNetDemo

        for name in ScoNetDemo.LABEL_MAP.values():
            assert isinstance(name, str)
            assert len(name) > 0
|
||||
@@ -0,0 +1,346 @@
|
||||
"""Unit tests for SilhouetteWindow and select_person functions."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from opengait.demo.window import SilhouetteWindow, select_person
|
||||
|
||||
|
||||
class TestSilhouetteWindow:
    """Tests for SilhouetteWindow class.

    NOTE(review): these tests pin SilhouetteWindow's implicit temporal
    contract — fill_level as a frames/window_size ratio, buffer resets on
    track-id change and on large frame gaps, and stride-gated
    classification. The exact push/assert ordering is load-bearing.
    """

    def test_window_fill_and_ready_behavior(self) -> None:
        """Window should be ready only when filled to window_size."""
        window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
        sil = np.ones((64, 44), dtype=np.float32)

        # Not ready with fewer frames
        for i in range(4):
            window.push(sil, frame_idx=i, track_id=1)
            assert not window.is_ready()
            # fill_level is a ratio in [0, 1], not an absolute frame count.
            assert window.fill_level == (i + 1) / 5

        # Ready at exactly window_size
        window.push(sil, frame_idx=4, track_id=1)
        assert window.is_ready()
        assert window.fill_level == 1.0

    def test_underfilled_not_ready(self) -> None:
        """Underfilled window should never report ready."""
        window = SilhouetteWindow(window_size=10, stride=1, gap_threshold=5)
        sil = np.ones((64, 44), dtype=np.float32)

        # Push 9 frames (underfilled)
        for i in range(9):
            window.push(sil, frame_idx=i, track_id=1)

        assert not window.is_ready()
        assert window.fill_level == 0.9

        # get_tensor should raise when not ready
        with pytest.raises(ValueError, match="Window not ready"):
            window.get_tensor()

    def test_track_id_change_resets_buffer(self) -> None:
        """Changing track ID should reset the buffer."""
        window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
        sil = np.ones((64, 44), dtype=np.float32)

        # Fill window with track_id=1
        for i in range(5):
            window.push(sil, frame_idx=i, track_id=1)

        assert window.is_ready()
        assert window.current_track_id == 1
        assert window.fill_level == 1.0

        # Push with different track_id should reset
        # (reset keeps the new frame, hence fill_level 1/5 = 0.2)
        window.push(sil, frame_idx=5, track_id=2)
        assert not window.is_ready()
        assert window.fill_level == 0.2
        assert window.current_track_id == 2

    def test_frame_gap_reset_behavior(self) -> None:
        """Frame gap exceeding threshold should reset buffer."""
        window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=3)
        sil = np.ones((64, 44), dtype=np.float32)

        # Fill window with consecutive frames
        for i in range(5):
            window.push(sil, frame_idx=i, track_id=1)

        assert window.is_ready()
        assert window.fill_level == 1.0

        # Small gap (within threshold) - no reset
        window.push(sil, frame_idx=6, track_id=1)  # gap = 1
        assert window.is_ready()
        assert window.fill_level == 1.0  # deque maintains max

    # Reset and fill again
        window.reset()
        for i in range(5):
            window.push(sil, frame_idx=i, track_id=1)

        # Large gap (exceeds threshold) - should reset
        window.push(sil, frame_idx=10, track_id=1)  # gap = 5 > 3
        assert not window.is_ready()
        assert window.fill_level == 0.2

    def test_get_tensor_shape(self) -> None:
        """get_tensor should return tensor of shape [1, 1, window_size, 64, 44]."""
        window_size = 7
        window = SilhouetteWindow(window_size=window_size, stride=1, gap_threshold=10)

        # Push unique frames to verify ordering
        for i in range(window_size):
            sil = np.full((64, 44), fill_value=float(i), dtype=np.float32)
            window.push(sil, frame_idx=i, track_id=1)

        assert window.is_ready()

        tensor = window.get_tensor(device="cpu")

        assert isinstance(tensor, torch.Tensor)
        assert tensor.shape == (1, 1, window_size, 64, 44)
        assert tensor.dtype == torch.float32

        # Verify content ordering: first frame should have value 0.0
        assert tensor[0, 0, 0, 0, 0].item() == 0.0
        # Last frame should have value window_size-1
        assert tensor[0, 0, window_size - 1, 0, 0].item() == window_size - 1

    def test_should_classify_stride_behavior(self) -> None:
        """should_classify should respect stride setting."""
        window = SilhouetteWindow(window_size=5, stride=3, gap_threshold=10)
        sil = np.ones((64, 44), dtype=np.float32)

        # Fill window
        for i in range(5):
            window.push(sil, frame_idx=i, track_id=1)

        assert window.is_ready()

        # First classification should always trigger
        assert window.should_classify()

        # Mark as classified at frame 4
        window.mark_classified()

        # Not ready to classify yet (stride=3, only 0 frames passed)
        window.push(sil, frame_idx=5, track_id=1)
        assert not window.should_classify()

        # Still not ready (only 1 frame passed)
        window.push(sil, frame_idx=6, track_id=1)
        assert not window.should_classify()

        # Now ready (3 frames passed since last classification)
        window.push(sil, frame_idx=7, track_id=1)
        assert window.should_classify()

    def test_should_classify_not_ready(self) -> None:
        """should_classify should return False when window not ready."""
        window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
        sil = np.ones((64, 44), dtype=np.float32)

        # Push only 3 frames (not ready)
        for i in range(3):
            window.push(sil, frame_idx=i, track_id=1)

        assert not window.is_ready()
        assert not window.should_classify()

    def test_reset_clears_all_state(self) -> None:
        """reset should clear all internal state."""
        window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
        sil = np.ones((64, 44), dtype=np.float32)

        # Fill and classify
        for i in range(5):
            window.push(sil, frame_idx=i, track_id=1)
        window.mark_classified()

        assert window.is_ready()
        assert window.current_track_id == 1
        assert window.frame_count == 5

        # Reset
        window.reset()

        # Everything observable goes back to its empty-window value.
        assert not window.is_ready()
        assert window.fill_level == 0.0
        assert window.current_track_id is None
        assert window.frame_count == 0
        assert not window.should_classify()

    def test_push_invalid_shape_raises(self) -> None:
        """push should raise ValueError for invalid silhouette shape."""
        window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)

        # Wrong shape
        sil_wrong = np.ones((32, 32), dtype=np.float32)

        with pytest.raises(ValueError, match="Expected silhouette shape"):
            window.push(sil_wrong, frame_idx=0, track_id=1)

    def test_push_wrong_dtype_converts(self) -> None:
        """push should convert dtype to float32."""
        window = SilhouetteWindow(window_size=1, stride=1, gap_threshold=10)

        # uint8 input
        sil_uint8 = np.ones((64, 44), dtype=np.uint8) * 255
        window.push(sil_uint8, frame_idx=0, track_id=1)

        tensor = window.get_tensor()
        assert tensor.dtype == torch.float32
|
||||
|
||||
|
||||
class TestSelectPerson:
    """Tests for select_person function."""

    def _create_mock_results(
        self,
        boxes_xyxy: NDArray[np.float32],
        masks_data: NDArray[np.float32],
        track_ids: NDArray[np.int64] | None,
    ) -> Any:
        """Create a mock detection results object."""
        fake_boxes = MagicMock()
        fake_boxes.xyxy = boxes_xyxy
        fake_boxes.id = track_ids

        fake_masks = MagicMock()
        fake_masks.data = masks_data

        detection = MagicMock()
        detection.boxes = fake_boxes
        detection.masks = fake_masks
        return detection

    def test_select_person_single_detection(self) -> None:
        """Single detection should return that person's data."""
        boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32)  # area = 3200
        masks = np.random.rand(1, 100, 100).astype(np.float32)
        ids = np.array([42], dtype=np.int64)

        picked = select_person(self._create_mock_results(boxes, masks, ids))

        assert picked is not None
        mask, bbox, tid = picked
        assert mask.shape == (100, 100)
        assert bbox == (10, 10, 50, 90)
        assert tid == 42

    def test_select_person_multi_detection_selects_largest(self) -> None:
        """Multiple detections should select the one with largest bbox area."""
        # Areas: 100, 900 (largest), 400 — the middle entry should win.
        boxes = np.array(
            [
                [0.0, 0.0, 10.0, 10.0],
                [0.0, 0.0, 30.0, 30.0],
                [0.0, 0.0, 20.0, 20.0],
            ],
            dtype=np.float32,
        )
        masks = np.random.rand(3, 100, 100).astype(np.float32)
        ids = np.array([1, 2, 3], dtype=np.int64)

        picked = select_person(self._create_mock_results(boxes, masks, ids))

        assert picked is not None
        _, bbox, tid = picked
        assert bbox == (0, 0, 30, 30)  # largest box
        assert tid == 2  # its track ID

    def test_select_person_no_detections_returns_none(self) -> None:
        """No detections should return None."""
        empty_boxes = np.array([], dtype=np.float32).reshape(0, 4)
        empty_masks = np.array([], dtype=np.float32).reshape(0, 100, 100)
        empty_ids = np.array([], dtype=np.int64)

        picked = select_person(
            self._create_mock_results(empty_boxes, empty_masks, empty_ids)
        )
        assert picked is None

    def test_select_person_no_track_ids_returns_none(self) -> None:
        """Detections without track IDs should return None."""
        boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32)
        masks = np.random.rand(1, 100, 100).astype(np.float32)

        picked = select_person(
            self._create_mock_results(boxes, masks, track_ids=None)
        )
        assert picked is None

    def test_select_person_empty_track_ids_returns_none(self) -> None:
        """Empty track IDs array should return None."""
        boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32)
        masks = np.random.rand(1, 100, 100).astype(np.float32)
        empty_ids = np.array([], dtype=np.int64)

        picked = select_person(self._create_mock_results(boxes, masks, empty_ids))
        assert picked is None

    def test_select_person_missing_boxes_returns_none(self) -> None:
        """Missing boxes attribute should return None."""
        detection = MagicMock()
        detection.boxes = None

        assert select_person(detection) is None

    def test_select_person_missing_masks_returns_none(self) -> None:
        """Missing masks attribute should return None."""
        fake_boxes = MagicMock()
        fake_boxes.xyxy = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32)
        fake_boxes.id = np.array([1], dtype=np.int64)

        detection = MagicMock()
        detection.boxes = fake_boxes
        detection.masks = None

        assert select_person(detection) is None

    def test_select_person_1d_bbox_handling(self) -> None:
        """1D bbox array should be reshaped to 2D."""
        flat_box = np.array([10.0, 10.0, 50.0, 90.0], dtype=np.float32)  # 1-D
        masks = np.random.rand(1, 100, 100).astype(np.float32)
        ids = np.array([1], dtype=np.int64)

        picked = select_person(self._create_mock_results(flat_box, masks, ids))

        assert picked is not None
        _, bbox, tid = picked
        assert bbox == (10, 10, 50, 90)
        assert tid == 1

    def test_select_person_2d_mask_handling(self) -> None:
        """2D mask should be expanded to 3D."""
        boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32)
        flat_mask = np.random.rand(100, 100).astype(np.float32)  # 2-D
        ids = np.array([1], dtype=np.int64)

        picked = select_person(self._create_mock_results(boxes, flat_mask, ids))

        assert picked is not None
        mask, _, _ = picked
        # Extracted mask comes back 2-D after internal expansion.
        assert mask.shape == (100, 100)
|
||||
Reference in New Issue
Block a user