00fcda4fe3
Move demo implementation into opengait_studio, retire Sports2D runtime integration, and align packaging with root-level monorepo dependency management.
525 lines
16 KiB
Python
525 lines
16 KiB
Python
"""Integration tests for NATS publisher functionality.
|
|
|
|
Tests cover:
|
|
- NATS message receipt with JSON schema validation
|
|
- Skip behavior when Docker/NATS unavailable
|
|
- Container lifecycle management (cleanup)
|
|
- JSON schema field/type validation
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import socket
|
|
import subprocess
|
|
import time
|
|
from typing import TYPE_CHECKING, cast
|
|
|
|
import pytest
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Generator
|
|
|
|
|
|
def _find_open_port() -> int:
    """Return a TCP port on 127.0.0.1 that was unused at the time of the call.

    Binding to port 0 lets the kernel pick a free port. The probe socket is
    closed before returning, so another process could still claim the port
    before the caller binds it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("127.0.0.1", 0))
        probe.listen(1)
        _host, chosen_port = cast(tuple[str, int], probe.getsockname())
    return chosen_port
|
|
|
|
|
|
# Constants for test configuration
# Subject that NatsPublisher is configured with in these tests; the test
# subscribers listen on the same subject.
NATS_SUBJECT: str = "scoliosis.result"
# Fixed name of the throwaway Docker container. Any existing container with
# this name is force-removed (`docker rm -f`) before a new one is started.
CONTAINER_NAME: str = "opengait-nats-test"
|
|
|
|
|
|
def _docker_available() -> bool:
    """Report whether a Docker daemon can be reached.

    Runs ``docker info`` with a 5 second timeout; a zero exit status means
    "available". A missing ``docker`` binary, a timeout, or any other OS-level
    failure to launch the command yields False instead of raising.
    """
    probe_cmd = ["docker", "info"]
    try:
        completed = subprocess.run(
            probe_cmd, capture_output=True, timeout=5, check=False
        )
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
        return False
    exit_code = completed.returncode
    return exit_code == 0
|
|
|
|
|
|
def _nats_container_running() -> bool:
    """Return True if a running container is named exactly ``CONTAINER_NAME``.

    ``docker ps --filter name=...`` matches names by substring. A plain
    substring test on its output would also report e.g.
    ``opengait-nats-test-old``, so each listed name is compared exactly.
    A missing Docker binary, a timeout, an OS error, or a non-zero exit from
    ``docker ps`` are all reported as "not running".
    """
    try:
        result = subprocess.run(
            [
                "docker",
                "ps",
                "--filter",
                f"name={CONTAINER_NAME}",
                "--format",
                "{{.Names}}",
            ],
            capture_output=True,
            timeout=5,
            check=False,
        )
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
        return False

    if result.returncode != 0:
        return False

    # One container name per line; names cannot contain whitespace.
    # errors="replace" keeps stray non-UTF-8 bytes from raising.
    listed_names = result.stdout.decode("utf-8", errors="replace").split()
    return CONTAINER_NAME in listed_names
|
|
|
|
|
|
def _nats_accepts_clients(port: int, timeout: float = 1.0) -> bool:
    """Return True if a NATS server on ``127.0.0.1:port`` greets a new client.

    A NATS server sends an ``INFO {...}`` line as soon as a client connects.
    The probe waits for that greeting instead of relying on a bare TCP connect,
    because Docker's host-side port forwarding may accept a connection before
    the server inside the container is listening.
    """
    try:
        with socket.create_connection(("127.0.0.1", port), timeout=timeout) as conn:
            greeting = conn.recv(64)
    except OSError:  # refused, reset, or timed out (socket.timeout is an OSError)
        return False
    return greeting.startswith(b"INFO")


def _start_nats_container(port: int) -> bool:
    """Start a fresh NATS container and wait until it accepts clients.

    Steps:
      1. Force-remove any leftover container named ``CONTAINER_NAME``.
      2. ``docker run -d nats:latest``, publishing container port 4222 on
         ``127.0.0.1:port``. Loopback only, so the throwaway server is not
         exposed to the network.
      3. Poll every 0.5 s, for up to ~10 s, until the server sends its
         ``INFO`` greeting.

    Args:
        port: Free host TCP port on which to publish the NATS client port.

    Returns:
        True once the server is ready. Returns False if Docker is missing or
        times out, ``docker run`` fails, or the server never becomes ready.
        On every failure after the launch attempt the container is removed so
        it does not outlive the test session.
    """
    try:
        # Remove a container left over from a previous (possibly crashed) run.
        _ = subprocess.run(
            ["docker", "rm", "-f", CONTAINER_NAME],
            capture_output=True,
            timeout=10,
            check=False,
        )

        result = subprocess.run(
            [
                "docker",
                "run",
                "-d",
                "--name",
                CONTAINER_NAME,
                "-p",
                f"127.0.0.1:{port}:4222",
                "nats:latest",
            ],
            capture_output=True,
            timeout=30,
            check=False,
        )
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
        # A `docker run` that timed out may still have created the container.
        _stop_nats_container()
        return False

    if result.returncode != 0:
        # A failed start (e.g. a port clash) may leave a created-but-stopped
        # container behind.
        _stop_nats_container()
        return False

    # Probe readiness from the host instead of via `docker exec ... nats ...`.
    # The official `nats` image ships the `nats-server` binary but not the
    # `nats` CLI, so an exec-based probe fails regardless of server state.
    for _ in range(20):
        time.sleep(0.5)
        if _nats_accepts_clients(port):
            return True

    # The server never became ready: do not leave the container running.
    _stop_nats_container()
    return False
|
|
|
|
|
|
def _stop_nats_container() -> None:
    """Force-remove the test NATS container as a best-effort cleanup.

    ``docker rm -f`` kills the container if it is running, then deletes it.
    A missing Docker binary, a 10 second timeout, or an OS error is
    deliberately ignored so that teardown never fails a test run.
    """
    removal_cmd = ["docker", "rm", "-f", CONTAINER_NAME]
    try:
        _ = subprocess.run(
            removal_cmd, capture_output=True, timeout=10, check=False
        )
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
        return
|
|
|
|
|
|
def _validate_result_schema(data: dict[str, object]) -> tuple[bool, str]:
    """Validate JSON result schema.

    Expected schema:
        {
            "frame": int,
            "track_id": int,
            "label": str (one of: "negative", "neutral", "positive"),
            "confidence": float in [0, 1],
            "window": int (non-negative),
            "timestamp_ns": int
        }

    ``bool`` is a subclass of ``int`` in Python, so a decoded JSON ``true`` or
    ``false`` would otherwise satisfy every integer/numeric check. Booleans are
    therefore rejected explicitly. Fields are checked in the order listed
    above, and the first failure is reported.

    Args:
        data: A decoded JSON result message.

    Returns:
        ``(True, "")`` if the message is valid, otherwise ``(False, reason)``.
    """
    required_fields: set[str] = {
        "frame",
        "track_id",
        "label",
        "confidence",
        "window",
        "timestamp_ns",
    }
    missing = required_fields - set(data.keys())
    if missing:
        return False, f"Missing required fields: {missing}"

    # Validate frame (int, not bool)
    frame = data["frame"]
    if isinstance(frame, bool) or not isinstance(frame, int):
        return False, f"frame must be int, got {type(frame)}"

    # Validate track_id (int, not bool)
    track_id = data["track_id"]
    if isinstance(track_id, bool) or not isinstance(track_id, int):
        return False, f"track_id must be int, got {type(track_id)}"

    # Validate label (str in specific set)
    label = data["label"]
    valid_labels = {"negative", "neutral", "positive"}
    if not isinstance(label, str) or label not in valid_labels:
        return False, f"label must be one of {valid_labels}, got {label}"

    # Validate confidence (int/float, not bool, in [0, 1])
    confidence = data["confidence"]
    if isinstance(confidence, bool) or not isinstance(confidence, (int, float)):
        return False, f"confidence must be numeric, got {type(confidence)}"
    # NaN fails this chained comparison, so it is rejected as out of range.
    if not 0.0 <= float(confidence) <= 1.0:
        return False, f"confidence must be in [0, 1], got {confidence}"

    # Validate window (int, not bool, non-negative)
    window = data["window"]
    if isinstance(window, bool) or not isinstance(window, int):
        return False, f"window must be int, got {type(window)}"
    if window < 0:
        return False, f"window must be non-negative, got {window}"

    # Validate timestamp_ns (int, not bool)
    timestamp_ns = data["timestamp_ns"]
    if isinstance(timestamp_ns, bool) or not isinstance(timestamp_ns, int):
        return False, f"timestamp_ns must be int, got {type(timestamp_ns)}"

    return True, ""
|
|
|
|
|
|
@pytest.fixture(scope="module")
def nats_server() -> Generator[tuple[bool, int], None, None]:
    """Provide a Dockerised NATS server for the tests in this module.

    Yields:
        ``(True, port)`` when a container was started and is reachable on
        ``127.0.0.1:port``. ``(False, 0)`` when Docker is unavailable or the
        container could not be started; dependent tests then skip themselves.

    The container is removed at module teardown. It is also removed right
    away when startup fails, because ``docker run`` may have created the
    container even though the server never became ready.
    """
    if not _docker_available():
        yield False, 0
        return

    port = _find_open_port()
    if not _start_nats_container(port):
        # Best-effort: don't leave a half-started container behind.
        _stop_nats_container()
        yield False, 0
        return

    try:
        yield True, port
    finally:
        _stop_nats_container()
|
|
|
|
|
|
# Evaluated once at collection time for the whole class, instead of running
# `docker info` separately for every decorated test method.
@pytest.mark.skipif(not _docker_available(), reason="Docker not available")
class TestNatsPublisherIntegration:
    """Integration tests for NATS publisher with live server.

    Every test is skipped when Docker is unavailable. Tests that need the
    containerised server additionally skip when ``nats_server`` reports that
    it could not be started, or when ``nats-py`` is not installed.
    """

    def test_nats_message_receipt_and_schema_validation(
        self, nats_server: tuple[bool, int]
    ) -> None:
        """Test that messages are received and match expected schema.

        An independent nats-py client subscribes on ``NATS_SUBJECT``. Three
        results are published through ``NatsPublisher``, and each must arrive,
        pass ``_validate_result_schema`` and carry the published values.
        """
        server_available, port = nats_server
        if not server_available:
            pytest.skip("NATS server not available")

        nats_url = f"nats://127.0.0.1:{port}"

        # Import here to avoid issues if nats-py not installed
        try:
            import asyncio

            import nats  # type: ignore[import-untyped]
        except ImportError:
            pytest.skip("nats-py not installed")

        from opengait_studio.output import NatsPublisher, create_result

        test_results = [
            create_result(
                frame=100,
                track_id=1,
                label="positive",
                confidence=0.85,
                window=(70, 100),
                timestamp_ns=1234567890000,
            ),
            create_result(
                frame=130,
                track_id=1,
                label="negative",
                confidence=0.92,
                window=(100, 130),
                timestamp_ns=1234567890030,
            ),
            create_result(
                frame=160,
                track_id=2,
                label="neutral",
                confidence=0.78,
                window=(130, 160),
                timestamp_ns=1234567890060,
            ),
        ]

        # Collect received messages
        received_messages: list[dict[str, object]] = []

        publisher = NatsPublisher(nats_url, subject=NATS_SUBJECT)

        async def subscribe_and_publish() -> None:
            """Subscribe to subject, publish messages, collect results."""
            nc = await nats.connect(nats_url)  # pyright: ignore[reportUnknownMemberType]
            try:
                sub = await nc.subscribe(NATS_SUBJECT)  # pyright: ignore[reportUnknownMemberType]
                # The publisher uses its own connection. Round-trip to the
                # server first so our SUB is registered before anything is
                # published; otherwise early messages can be missed.
                await nc.flush()  # pyright: ignore[reportUnknownMemberType]

                for result in test_results:
                    publisher.publish(result)

                for _ in range(len(test_results)):
                    try:
                        # next_msg() applies its own default timeout (shorter
                        # than 5 s in nats-py), so pass the intended
                        # per-message timeout explicitly.
                        msg = await sub.next_msg(timeout=5.0)  # pyright: ignore[reportUnknownMemberType]
                    except asyncio.TimeoutError:
                        break
                    data = json.loads(msg.data.decode("utf-8"))  # pyright: ignore[reportAny]
                    received_messages.append(cast(dict[str, object], data))

                await sub.unsubscribe()
            finally:
                await nc.close()

        try:
            asyncio.run(subscribe_and_publish())
        finally:
            # Close the publisher even if the subscriber side raised.
            publisher.close()

        # Verify all messages received
        assert len(received_messages) == len(test_results), (
            f"Expected {len(test_results)} messages, got {len(received_messages)}"
        )

        # Validate schema for each message
        for i, message in enumerate(received_messages):
            is_valid, error = _validate_result_schema(message)
            assert is_valid, f"Message {i} schema validation failed: {error}"

        # Verify specific values
        assert received_messages[0]["frame"] == 100
        assert received_messages[0]["label"] == "positive"
        assert received_messages[0]["track_id"] == 1
        _conf = received_messages[0]["confidence"]
        assert isinstance(_conf, (int, float))
        assert 0.0 <= float(_conf) <= 1.0

        assert received_messages[1]["label"] == "negative"
        assert received_messages[2]["track_id"] == 2

    def test_nats_publisher_graceful_when_server_unavailable(self) -> None:
        """Test that publisher handles missing server gracefully.

        Publishing and closing against a URL with no server behind it must
        not raise. This test does not use the container.
        """
        try:
            from opengait_studio.output import NatsPublisher
        except ImportError:
            pytest.skip("output module not available")

        # A port that was just free, so nothing is listening there. A
        # hard-coded port could be occupied by some other local service.
        bad_url = f"nats://127.0.0.1:{_find_open_port()}"
        publisher = NatsPublisher(bad_url, subject=NATS_SUBJECT)

        test_result: dict[str, object] = {
            "frame": 1,
            "track_id": 1,
            "label": "positive",
            "confidence": 0.85,
            "window": 30,
            "timestamp_ns": 1234567890,
        }

        try:
            # Should not raise when publishing without server
            publisher.publish(test_result)
        finally:
            # Cleanup should also not raise
            publisher.close()

    def test_nats_publisher_context_manager(
        self, nats_server: tuple[bool, int]
    ) -> None:
        """Test that publisher works as context manager.

        A result published inside ``with NatsPublisher(...)`` must reach an
        independent subscriber.
        """
        server_available, port = nats_server
        if not server_available:
            pytest.skip("NATS server not available")

        nats_url = f"nats://127.0.0.1:{port}"

        try:
            import asyncio

            import nats  # type: ignore[import-untyped]
            from opengait_studio.output import NatsPublisher, create_result
        except ImportError as e:
            pytest.skip(f"Required module not available: {e}")

        received_messages: list[dict[str, object]] = []

        async def subscribe_and_test() -> None:
            nc = await nats.connect(nats_url)  # pyright: ignore[reportUnknownMemberType]
            try:
                sub = await nc.subscribe(NATS_SUBJECT)  # pyright: ignore[reportUnknownMemberType]
                # Ensure the SUB is registered before the publisher's own
                # connection sends the message.
                await nc.flush()  # pyright: ignore[reportUnknownMemberType]

                # Use context manager
                with NatsPublisher(nats_url, subject=NATS_SUBJECT) as publisher:
                    result = create_result(
                        frame=200,
                        track_id=5,
                        label="neutral",
                        confidence=0.65,
                        window=(170, 200),
                        timestamp_ns=9999999999,
                    )
                    publisher.publish(result)

                # Wait for message (explicit timeout; see the test above).
                try:
                    msg = await sub.next_msg(timeout=5.0)  # pyright: ignore[reportUnknownMemberType]
                except asyncio.TimeoutError:
                    pass
                else:
                    data = json.loads(msg.data.decode("utf-8"))  # pyright: ignore[reportAny]
                    received_messages.append(cast(dict[str, object], data))

                await sub.unsubscribe()
            finally:
                await nc.close()

        asyncio.run(subscribe_and_test())

        assert len(received_messages) == 1
        assert received_messages[0]["frame"] == 200
        assert received_messages[0]["track_id"] == 5
|
|
|
|
|
|
class TestNatsSchemaValidation:
    """Offline checks of ``_validate_result_schema``; no NATS server or Docker needed."""

    @staticmethod
    def _message(**overrides: object) -> dict[str, object]:
        """Build a schema-valid result message, replacing any given fields."""
        message: dict[str, object] = {
            "frame": 1234,
            "track_id": 42,
            "label": "positive",
            "confidence": 0.85,
            "window": 1230,
            "timestamp_ns": 1234567890000,
        }
        message.update(overrides)
        return message

    def test_validate_result_schema_valid(self) -> None:
        """A complete, in-range message passes validation."""
        ok, reason = _validate_result_schema(self._message())
        assert ok, f"Valid data rejected: {reason}"

    def test_validate_result_schema_invalid_label(self) -> None:
        """An unknown label is rejected with a label-related reason."""
        ok, reason = _validate_result_schema(self._message(label="invalid_label"))
        assert not ok
        assert "label" in reason.lower()

    def test_validate_result_schema_confidence_out_of_range(self) -> None:
        """A confidence above 1 is rejected with a confidence-related reason."""
        ok, reason = _validate_result_schema(self._message(confidence=1.5))
        assert not ok
        assert "confidence" in reason.lower()

    def test_validate_result_schema_missing_fields(self) -> None:
        """A message lacking required keys is reported as missing fields."""
        partial: dict[str, object] = {"frame": 1234, "label": "positive"}
        ok, reason = _validate_result_schema(partial)
        assert not ok
        assert "missing" in reason.lower()

    def test_validate_result_schema_wrong_types(self) -> None:
        """A non-integer frame is rejected with a frame-related reason."""
        ok, reason = _validate_result_schema(self._message(frame="not_an_int"))
        assert not ok
        assert "frame" in reason.lower()

    def test_all_valid_labels_accepted(self) -> None:
        """Each of the three allowed labels passes validation."""
        for label_str in ("negative", "neutral", "positive"):
            candidate = self._message(
                frame=100,
                track_id=1,
                label=label_str,
                confidence=0.5,
                window=100,
                timestamp_ns=1234567890,
            )
            ok, reason = _validate_result_schema(candidate)
            assert ok, f"Valid label '{label_str}' rejected: {reason}"
|
|
|
|
|
|
class TestDockerAvailability:
    """Smoke tests for the Docker probe helpers: each must return a bool and never raise."""

    def test_docker_available_check(self) -> None:
        """``_docker_available`` yields a bool whether or not Docker is installed."""
        assert isinstance(_docker_available(), bool)

    def test_nats_container_running_check(self) -> None:
        """``_nats_container_running`` yields a bool even when Docker is absent."""
        assert isinstance(_nats_container_running(), bool)
|