test(demo): add unit and integration coverage for pipeline
Introduce focused unit, integration, and NATS-path tests for demo modules, and align assertions with final schema and temporal contracts (window int, seq=30, fill-level ratio). This commit isolates validation logic from runtime changes and provides reproducible QA for pipeline behavior and failure modes.
This commit is contained in:
@@ -0,0 +1,525 @@
|
||||
"""Integration tests for NATS publisher functionality.
|
||||
|
||||
Tests cover:
|
||||
- NATS message receipt with JSON schema validation
|
||||
- Skip behavior when Docker/NATS unavailable
|
||||
- Container lifecycle management (cleanup)
|
||||
- JSON schema field/type validation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import socket
|
||||
import subprocess
|
||||
import time
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import pytest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
|
||||
def _find_open_port() -> int:
|
||||
"""Find an available TCP port on localhost."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(("127.0.0.1", 0))
|
||||
sock.listen(1)
|
||||
addr = cast(tuple[str, int], sock.getsockname())
|
||||
port: int = addr[1]
|
||||
return port
|
||||
|
||||
|
||||
# Constants for test configuration
|
||||
NATS_SUBJECT = "scoliosis.result"
|
||||
CONTAINER_NAME = "opengait-nats-test"
|
||||
|
||||
|
||||
def _docker_available() -> bool:
|
||||
"""Check if Docker is available and running."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["docker", "info"],
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
check=False,
|
||||
)
|
||||
return result.returncode == 0
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
|
||||
return False
|
||||
|
||||
|
||||
def _nats_container_running() -> bool:
|
||||
"""Check if NATS container is already running."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
"docker",
|
||||
"ps",
|
||||
"--filter",
|
||||
f"name={CONTAINER_NAME}",
|
||||
"--format",
|
||||
"{{.Names}}",
|
||||
],
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
check=False,
|
||||
)
|
||||
return CONTAINER_NAME in result.stdout.decode()
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
|
||||
return False
|
||||
|
||||
|
||||
def _start_nats_container(port: int) -> bool:
|
||||
"""Start NATS container for testing."""
|
||||
try:
|
||||
# Remove existing container if present
|
||||
_ = subprocess.run(
|
||||
["docker", "rm", "-f", CONTAINER_NAME],
|
||||
capture_output=True,
|
||||
timeout=10,
|
||||
check=False,
|
||||
)
|
||||
|
||||
# Start new NATS container
|
||||
result = subprocess.run(
|
||||
[
|
||||
"docker",
|
||||
"run",
|
||||
"-d",
|
||||
"--name",
|
||||
CONTAINER_NAME,
|
||||
"-p",
|
||||
f"{port}:{port}",
|
||||
"nats:latest",
|
||||
],
|
||||
capture_output=True,
|
||||
timeout=30,
|
||||
check=False,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
return False
|
||||
|
||||
# Wait for NATS to be ready (max 10 seconds)
|
||||
for _ in range(20):
|
||||
time.sleep(0.5)
|
||||
try:
|
||||
check_result = subprocess.run(
|
||||
[
|
||||
"docker",
|
||||
"exec",
|
||||
CONTAINER_NAME,
|
||||
"nats",
|
||||
"server",
|
||||
"check",
|
||||
"server",
|
||||
],
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
check=False,
|
||||
)
|
||||
if check_result.returncode == 0:
|
||||
return True
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
|
||||
continue
|
||||
|
||||
return False
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
|
||||
return False
|
||||
|
||||
|
||||
def _stop_nats_container() -> None:
|
||||
"""Stop and remove NATS container."""
|
||||
try:
|
||||
_ = subprocess.run(
|
||||
["docker", "rm", "-f", CONTAINER_NAME],
|
||||
capture_output=True,
|
||||
timeout=10,
|
||||
check=False,
|
||||
)
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
|
||||
pass
|
||||
|
||||
|
||||
def _validate_result_schema(data: dict[str, object]) -> tuple[bool, str]:
|
||||
"""Validate JSON result schema.
|
||||
|
||||
Expected schema:
|
||||
{
|
||||
"frame": int,
|
||||
"track_id": int,
|
||||
"label": str (one of: "negative", "neutral", "positive"),
|
||||
"confidence": float in [0, 1],
|
||||
"window": list[int] (start, end),
|
||||
"timestamp_ns": int
|
||||
}
|
||||
"""
|
||||
required_fields: set[str] = {
|
||||
"frame",
|
||||
"track_id",
|
||||
"label",
|
||||
"confidence",
|
||||
"window",
|
||||
"timestamp_ns",
|
||||
}
|
||||
missing = required_fields - set(data.keys())
|
||||
if missing:
|
||||
return False, f"Missing required fields: {missing}"
|
||||
|
||||
# Validate frame (int)
|
||||
frame = data["frame"]
|
||||
if not isinstance(frame, int):
|
||||
return False, f"frame must be int, got {type(frame)}"
|
||||
|
||||
# Validate track_id (int)
|
||||
track_id = data["track_id"]
|
||||
if not isinstance(track_id, int):
|
||||
return False, f"track_id must be int, got {type(track_id)}"
|
||||
|
||||
# Validate label (str in specific set)
|
||||
label = data["label"]
|
||||
valid_labels = {"negative", "neutral", "positive"}
|
||||
if not isinstance(label, str) or label not in valid_labels:
|
||||
return False, f"label must be one of {valid_labels}, got {label}"
|
||||
|
||||
# Validate confidence (float in [0, 1])
|
||||
confidence = data["confidence"]
|
||||
if not isinstance(confidence, (int, float)):
|
||||
return False, f"confidence must be numeric, got {type(confidence)}"
|
||||
if not 0.0 <= float(confidence) <= 1.0:
|
||||
return False, f"confidence must be in [0, 1], got {confidence}"
|
||||
|
||||
# Validate window (list of 2 ints)
|
||||
window = data["window"]
|
||||
if not isinstance(window, list) or len(cast(list[object], window)) != 2:
|
||||
return False, f"window must be list of 2 ints, got {window}"
|
||||
window_list = cast(list[object], window)
|
||||
if not all(isinstance(x, int) for x in window_list):
|
||||
return False, f"window elements must be ints, got {window}"
|
||||
|
||||
# Validate timestamp_ns (int)
|
||||
timestamp_ns = data["timestamp_ns"]
|
||||
if not isinstance(timestamp_ns, int):
|
||||
return False, f"timestamp_ns must be int, got {type(timestamp_ns)}"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def nats_server() -> "Generator[tuple[bool, int], None, None]":
|
||||
"""Fixture to manage NATS container lifecycle.
|
||||
|
||||
Yields (True, port) if NATS is available, (False, 0) otherwise.
|
||||
Cleans up container after tests.
|
||||
"""
|
||||
if not _docker_available():
|
||||
yield False, 0
|
||||
return
|
||||
|
||||
port = _find_open_port()
|
||||
if not _start_nats_container(port):
|
||||
yield False, 0
|
||||
return
|
||||
|
||||
try:
|
||||
yield True, port
|
||||
finally:
|
||||
_stop_nats_container()
|
||||
|
||||
|
||||
class TestNatsPublisherIntegration:
|
||||
"""Integration tests for NATS publisher with live server."""
|
||||
|
||||
@pytest.mark.skipif(not _docker_available(), reason="Docker not available")
|
||||
def test_nats_message_receipt_and_schema_validation(
|
||||
self, nats_server: tuple[bool, int]
|
||||
) -> None:
|
||||
"""Test that messages are received and match expected schema."""
|
||||
server_available, port = nats_server
|
||||
if not server_available:
|
||||
pytest.skip("NATS server not available")
|
||||
|
||||
nats_url = f"nats://127.0.0.1:{port}"
|
||||
|
||||
# Import here to avoid issues if nats-py not installed
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
import nats # type: ignore[import-untyped]
|
||||
|
||||
except ImportError:
|
||||
pytest.skip("nats-py not installed")
|
||||
|
||||
from opengait.demo.output import NatsPublisher, create_result
|
||||
|
||||
# Create publisher
|
||||
publisher = NatsPublisher(nats_url, subject=NATS_SUBJECT)
|
||||
|
||||
# Create test results
|
||||
test_results = [
|
||||
create_result(
|
||||
frame=100,
|
||||
track_id=1,
|
||||
label="positive",
|
||||
confidence=0.85,
|
||||
window=(70, 100),
|
||||
timestamp_ns=1234567890000,
|
||||
),
|
||||
create_result(
|
||||
frame=130,
|
||||
track_id=1,
|
||||
label="negative",
|
||||
confidence=0.92,
|
||||
window=(100, 130),
|
||||
timestamp_ns=1234567890030,
|
||||
),
|
||||
create_result(
|
||||
frame=160,
|
||||
track_id=2,
|
||||
label="neutral",
|
||||
confidence=0.78,
|
||||
window=(130, 160),
|
||||
timestamp_ns=1234567890060,
|
||||
),
|
||||
]
|
||||
|
||||
# Collect received messages
|
||||
received_messages: list[dict[str, object]] = []
|
||||
|
||||
async def subscribe_and_publish():
|
||||
"""Subscribe to subject, publish messages, collect results."""
|
||||
nc = await nats.connect(nats_url) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
# Subscribe
|
||||
sub = await nc.subscribe(NATS_SUBJECT) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
# Publish all test results
|
||||
for result in test_results:
|
||||
publisher.publish(result)
|
||||
|
||||
# Wait for messages with timeout
|
||||
for _ in range(len(test_results)):
|
||||
try:
|
||||
msg = await asyncio.wait_for(sub.next_msg(), timeout=5.0)
|
||||
data = json.loads(msg.data.decode("utf-8")) # pyright: ignore[reportAny]
|
||||
received_messages.append(cast(dict[str, object], data))
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
|
||||
await sub.unsubscribe()
|
||||
await nc.close()
|
||||
|
||||
# Run async subscriber
|
||||
asyncio.run(subscribe_and_publish())
|
||||
|
||||
# Cleanup publisher
|
||||
publisher.close()
|
||||
|
||||
# Verify all messages received
|
||||
assert len(received_messages) == len(test_results), (
|
||||
f"Expected {len(test_results)} messages, got {len(received_messages)}"
|
||||
)
|
||||
|
||||
# Validate schema for each message
|
||||
for i, msg in enumerate(received_messages):
|
||||
is_valid, error = _validate_result_schema(msg)
|
||||
assert is_valid, f"Message {i} schema validation failed: {error}"
|
||||
|
||||
# Verify specific values
|
||||
assert received_messages[0]["frame"] == 100
|
||||
assert received_messages[0]["label"] == "positive"
|
||||
assert received_messages[0]["track_id"] == 1
|
||||
_conf = received_messages[0]["confidence"]
|
||||
assert isinstance(_conf, (int, float))
|
||||
assert 0.0 <= float(_conf) <= 1.0
|
||||
|
||||
assert received_messages[1]["label"] == "negative"
|
||||
assert received_messages[2]["track_id"] == 2
|
||||
|
||||
@pytest.mark.skipif(not _docker_available(), reason="Docker not available")
|
||||
def test_nats_publisher_graceful_when_server_unavailable(self) -> None:
|
||||
"""Test that publisher handles missing server gracefully."""
|
||||
try:
|
||||
from opengait.demo.output import NatsPublisher
|
||||
except ImportError:
|
||||
pytest.skip("output module not available")
|
||||
|
||||
# Use wrong port where no server is running
|
||||
bad_url = "nats://127.0.0.1:14222"
|
||||
publisher = NatsPublisher(bad_url, subject=NATS_SUBJECT)
|
||||
|
||||
# Should not raise when publishing without server
|
||||
test_result: dict[str, object] = {
|
||||
"frame": 1,
|
||||
"track_id": 1,
|
||||
"label": "positive",
|
||||
"confidence": 0.85,
|
||||
"window": [0, 30],
|
||||
"timestamp_ns": 1234567890,
|
||||
}
|
||||
|
||||
# Should not raise
|
||||
publisher.publish(test_result)
|
||||
|
||||
# Cleanup should also not raise
|
||||
publisher.close()
|
||||
|
||||
@pytest.mark.skipif(not _docker_available(), reason="Docker not available")
|
||||
def test_nats_publisher_context_manager(
|
||||
self, nats_server: tuple[bool, int]
|
||||
) -> None:
|
||||
"""Test that publisher works as context manager."""
|
||||
server_available, port = nats_server
|
||||
if not server_available:
|
||||
pytest.skip("NATS server not available")
|
||||
|
||||
nats_url = f"nats://127.0.0.1:{port}"
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
import nats # type: ignore[import-untyped]
|
||||
from opengait.demo.output import NatsPublisher, create_result
|
||||
except ImportError as e:
|
||||
pytest.skip(f"Required module not available: {e}")
|
||||
|
||||
received_messages: list[dict[str, object]] = []
|
||||
|
||||
async def subscribe_and_test():
|
||||
nc = await nats.connect(nats_url) # pyright: ignore[reportUnknownMemberType]
|
||||
sub = await nc.subscribe(NATS_SUBJECT) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
# Use context manager
|
||||
with NatsPublisher(nats_url, subject=NATS_SUBJECT) as publisher:
|
||||
result = create_result(
|
||||
frame=200,
|
||||
track_id=5,
|
||||
label="neutral",
|
||||
confidence=0.65,
|
||||
window=(170, 200),
|
||||
timestamp_ns=9999999999,
|
||||
)
|
||||
publisher.publish(result)
|
||||
|
||||
# Wait for message
|
||||
try:
|
||||
msg = await asyncio.wait_for(sub.next_msg(), timeout=5.0)
|
||||
data = json.loads(msg.data.decode("utf-8")) # pyright: ignore[reportAny]
|
||||
received_messages.append(cast(dict[str, object], data))
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
await sub.unsubscribe()
|
||||
await nc.close()
|
||||
|
||||
asyncio.run(subscribe_and_test())
|
||||
|
||||
assert len(received_messages) == 1
|
||||
assert received_messages[0]["frame"] == 200
|
||||
assert received_messages[0]["track_id"] == 5
|
||||
|
||||
|
||||
class TestNatsSchemaValidation:
|
||||
"""Tests for JSON schema validation without requiring NATS server."""
|
||||
|
||||
def test_validate_result_schema_valid(self) -> None:
|
||||
"""Test schema validation with valid data."""
|
||||
valid_data: dict[str, object] = {
|
||||
"frame": 1234,
|
||||
"track_id": 42,
|
||||
"label": "positive",
|
||||
"confidence": 0.85,
|
||||
"window": [1200, 1230],
|
||||
"timestamp_ns": 1234567890000,
|
||||
}
|
||||
|
||||
is_valid, error = _validate_result_schema(valid_data)
|
||||
assert is_valid, f"Valid data rejected: {error}"
|
||||
|
||||
def test_validate_result_schema_invalid_label(self) -> None:
|
||||
"""Test schema validation rejects invalid label."""
|
||||
invalid_data: dict[str, object] = {
|
||||
"frame": 1234,
|
||||
"track_id": 42,
|
||||
"label": "invalid_label",
|
||||
"confidence": 0.85,
|
||||
"window": [1200, 1230],
|
||||
"timestamp_ns": 1234567890000,
|
||||
}
|
||||
|
||||
is_valid, error = _validate_result_schema(invalid_data)
|
||||
assert not is_valid
|
||||
assert "label" in error.lower()
|
||||
|
||||
def test_validate_result_schema_confidence_out_of_range(self) -> None:
|
||||
"""Test schema validation rejects confidence outside [0, 1]."""
|
||||
invalid_data: dict[str, object] = {
|
||||
"frame": 1234,
|
||||
"track_id": 42,
|
||||
"label": "positive",
|
||||
"confidence": 1.5,
|
||||
"window": [1200, 1230],
|
||||
"timestamp_ns": 1234567890000,
|
||||
}
|
||||
|
||||
is_valid, error = _validate_result_schema(invalid_data)
|
||||
assert not is_valid
|
||||
assert "confidence" in error.lower()
|
||||
|
||||
def test_validate_result_schema_missing_fields(self) -> None:
|
||||
"""Test schema validation detects missing fields."""
|
||||
incomplete_data: dict[str, object] = {
|
||||
"frame": 1234,
|
||||
"label": "positive",
|
||||
}
|
||||
|
||||
is_valid, error = _validate_result_schema(incomplete_data)
|
||||
assert not is_valid
|
||||
assert "missing" in error.lower()
|
||||
|
||||
def test_validate_result_schema_wrong_types(self) -> None:
|
||||
"""Test schema validation rejects wrong types."""
|
||||
wrong_types: dict[str, object] = {
|
||||
"frame": "not_an_int",
|
||||
"track_id": 42,
|
||||
"label": "positive",
|
||||
"confidence": 0.85,
|
||||
"window": [1200, 1230],
|
||||
"timestamp_ns": 1234567890000,
|
||||
}
|
||||
|
||||
is_valid, error = _validate_result_schema(wrong_types)
|
||||
assert not is_valid
|
||||
assert "frame" in error.lower()
|
||||
|
||||
def test_all_valid_labels_accepted(self) -> None:
|
||||
"""Test that all valid labels are accepted."""
|
||||
for label_str in ["negative", "neutral", "positive"]:
|
||||
data: dict[str, object] = {
|
||||
"frame": 100,
|
||||
"track_id": 1,
|
||||
"label": label_str,
|
||||
"confidence": 0.5,
|
||||
"window": [70, 100],
|
||||
"timestamp_ns": 1234567890,
|
||||
}
|
||||
is_valid, error = _validate_result_schema(data)
|
||||
assert is_valid, f"Valid label '{label_str}' rejected: {error}"
|
||||
|
||||
|
||||
class TestDockerAvailability:
|
||||
"""Tests for Docker availability detection."""
|
||||
|
||||
def test_docker_available_check(self) -> None:
|
||||
"""Test Docker availability check doesn't crash."""
|
||||
# This should not raise
|
||||
result = _docker_available()
|
||||
assert isinstance(result, bool)
|
||||
|
||||
def test_nats_container_running_check(self) -> None:
|
||||
"""Test container running check doesn't crash."""
|
||||
# This should not raise even if Docker not available
|
||||
result = _nats_container_running()
|
||||
assert isinstance(result, bool)
|
||||
Reference in New Issue
Block a user