f501119d43
Add preprocess-only silhouette export and configurable result exporters so demo runs can be persisted for offline analysis and reproducible evaluation. Include optional parquet support and CLI visualization dumps while updating tests and tracking notes for the verified pipeline/debug workflow.
525 lines
16 KiB
Python
525 lines
16 KiB
Python
"""Integration tests for NATS publisher functionality.
|
|
|
|
Tests cover:
|
|
- NATS message receipt with JSON schema validation
|
|
- Skip behavior when Docker/NATS unavailable
|
|
- Container lifecycle management (cleanup)
|
|
- JSON schema field/type validation
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import socket
|
|
import subprocess
|
|
import time
|
|
from typing import TYPE_CHECKING, cast
|
|
|
|
import pytest
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Generator
|
|
|
|
|
|
def _find_open_port() -> int:
|
|
"""Find an available TCP port on localhost."""
|
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
|
sock.bind(("127.0.0.1", 0))
|
|
sock.listen(1)
|
|
addr = cast(tuple[str, int], sock.getsockname())
|
|
port: int = addr[1]
|
|
return port
|
|
|
|
|
|
# Shared configuration for the NATS integration tests.
# Subject the publisher sends on and the test subscriber listens to.
NATS_SUBJECT: str = "scoliosis.result"
# Fixed container name so setup and teardown can locate and force-remove it.
CONTAINER_NAME: str = "opengait-nats-test"
|
|
|
|
|
|
def _docker_available() -> bool:
|
|
"""Check if Docker is available and running."""
|
|
try:
|
|
result = subprocess.run(
|
|
["docker", "info"],
|
|
capture_output=True,
|
|
timeout=5,
|
|
check=False,
|
|
)
|
|
return result.returncode == 0
|
|
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
|
|
return False
|
|
|
|
|
|
def _nats_container_running() -> bool:
    """Check whether a container named exactly ``CONTAINER_NAME`` is running.

    ``docker ps --filter name=...`` also lists containers whose name merely
    *contains* the filter text (e.g. ``opengait-nats-test-old``). Each output
    line is therefore compared against ``CONTAINER_NAME`` exactly, instead of
    doing a substring search over the whole output.

    Returns:
        True if the exact container is listed by ``docker ps``; False
        otherwise, including when Docker is missing, fails, or takes longer
        than 5 seconds.
    """
    try:
        result = subprocess.run(
            [
                "docker",
                "ps",
                "--filter",
                f"name={CONTAINER_NAME}",
                "--format",
                "{{.Names}}",
            ],
            capture_output=True,
            timeout=5,
            check=False,
        )
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
        return False

    # --format "{{.Names}}" prints one container name per line.
    names = result.stdout.decode(errors="replace").splitlines()
    return any(name.strip() == CONTAINER_NAME for name in names)
|
|
|
|
|
|
def _start_nats_container(port: int) -> bool:
|
|
"""Start NATS container for testing."""
|
|
try:
|
|
# Remove existing container if present
|
|
_ = subprocess.run(
|
|
["docker", "rm", "-f", CONTAINER_NAME],
|
|
capture_output=True,
|
|
timeout=10,
|
|
check=False,
|
|
)
|
|
|
|
# Start new NATS container
|
|
result = subprocess.run(
|
|
[
|
|
"docker",
|
|
"run",
|
|
"-d",
|
|
"--name",
|
|
CONTAINER_NAME,
|
|
"-p",
|
|
f"{port}:4222",
|
|
"nats:latest",
|
|
],
|
|
capture_output=True,
|
|
timeout=30,
|
|
check=False,
|
|
)
|
|
|
|
if result.returncode != 0:
|
|
return False
|
|
|
|
# Wait for NATS to be ready (max 10 seconds)
|
|
for _ in range(20):
|
|
time.sleep(0.5)
|
|
try:
|
|
check_result = subprocess.run(
|
|
[
|
|
"docker",
|
|
"exec",
|
|
CONTAINER_NAME,
|
|
"nats",
|
|
"server",
|
|
"check",
|
|
"server",
|
|
],
|
|
capture_output=True,
|
|
timeout=5,
|
|
check=False,
|
|
)
|
|
if check_result.returncode == 0:
|
|
return True
|
|
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
|
|
continue
|
|
|
|
return False
|
|
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
|
|
return False
|
|
|
|
|
|
def _stop_nats_container() -> None:
    """Force-remove the test NATS container on a best-effort basis.

    Runs ``docker rm -f CONTAINER_NAME`` with a 10-second limit. A missing
    Docker binary, an OS error, or a timeout is swallowed so that teardown
    never raises; a non-zero exit status (e.g. no such container) is ignored.
    """
    teardown = ["docker", "rm", "-f", CONTAINER_NAME]
    try:
        _ = subprocess.run(teardown, check=False, timeout=10, capture_output=True)
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
        # Nothing useful to do if Docker is absent or hung during cleanup.
        return
|
|
|
|
|
|
def _validate_result_schema(data: dict[str, object]) -> tuple[bool, str]:
|
|
"""Validate JSON result schema.
|
|
|
|
Expected schema:
|
|
{
|
|
"frame": int,
|
|
"track_id": int,
|
|
"label": str (one of: "negative", "neutral", "positive"),
|
|
"confidence": float in [0, 1],
|
|
"window": int (non-negative),
|
|
"timestamp_ns": int
|
|
}
|
|
"""
|
|
required_fields: set[str] = {
|
|
"frame",
|
|
"track_id",
|
|
"label",
|
|
"confidence",
|
|
"window",
|
|
"timestamp_ns",
|
|
}
|
|
missing = required_fields - set(data.keys())
|
|
if missing:
|
|
return False, f"Missing required fields: {missing}"
|
|
|
|
# Validate frame (int)
|
|
frame = data["frame"]
|
|
if not isinstance(frame, int):
|
|
return False, f"frame must be int, got {type(frame)}"
|
|
|
|
# Validate track_id (int)
|
|
track_id = data["track_id"]
|
|
if not isinstance(track_id, int):
|
|
return False, f"track_id must be int, got {type(track_id)}"
|
|
|
|
# Validate label (str in specific set)
|
|
label = data["label"]
|
|
valid_labels = {"negative", "neutral", "positive"}
|
|
if not isinstance(label, str) or label not in valid_labels:
|
|
return False, f"label must be one of {valid_labels}, got {label}"
|
|
|
|
# Validate confidence (float in [0, 1])
|
|
confidence = data["confidence"]
|
|
if not isinstance(confidence, (int, float)):
|
|
return False, f"confidence must be numeric, got {type(confidence)}"
|
|
if not 0.0 <= float(confidence) <= 1.0:
|
|
return False, f"confidence must be in [0, 1], got {confidence}"
|
|
# Validate window (int, non-negative)
|
|
window = data["window"]
|
|
if not isinstance(window, int):
|
|
return False, f"window must be int, got {type(window)}"
|
|
if window < 0:
|
|
return False, f"window must be non-negative, got {window}"
|
|
|
|
|
|
# Validate timestamp_ns (int)
|
|
timestamp_ns = data["timestamp_ns"]
|
|
if not isinstance(timestamp_ns, int):
|
|
return False, f"timestamp_ns must be int, got {type(timestamp_ns)}"
|
|
|
|
return True, ""
|
|
|
|
|
|
@pytest.fixture(scope="module")
def nats_server() -> "Generator[tuple[bool, int], None, None]":
    """Fixture to manage NATS container lifecycle.

    Yields:
        ``(True, port)`` when a NATS container was started and is reachable on
        127.0.0.1:``port``. ``(False, 0)`` when Docker is unavailable or the
        container could not be started; dependent tests then skip themselves.

    Once Docker is known to be available, the container is removed on teardown
    in every case. This includes a start that failed part-way, e.g. ``docker
    run`` succeeded but the server never became ready, so no container leaks.
    """
    if not _docker_available():
        yield False, 0
        return

    port = _find_open_port()
    try:
        if _start_nats_container(port):
            yield True, port
        else:
            yield False, 0
    finally:
        _stop_nats_container()
|
|
|
|
|
|
class TestNatsPublisherIntegration:
    """Integration tests for NATS publisher with live server."""

    @pytest.mark.skipif(not _docker_available(), reason="Docker not available")
    def test_nats_message_receipt_and_schema_validation(
        self, nats_server: tuple[bool, int]
    ) -> None:
        """Test that messages are received and match expected schema."""
        server_available, port = nats_server
        if not server_available:
            pytest.skip("NATS server not available")

        nats_url = f"nats://127.0.0.1:{port}"

        import asyncio

        # Import here to avoid issues if nats-py not installed
        try:
            import nats  # type: ignore[import-untyped]
        except ImportError:
            pytest.skip("nats-py not installed")

        from opengait.demo.output import NatsPublisher, create_result

        # Create publisher
        publisher = NatsPublisher(nats_url, subject=NATS_SUBJECT)

        # Create test results
        test_results = [
            create_result(
                frame=100,
                track_id=1,
                label="positive",
                confidence=0.85,
                window=(70, 100),
                timestamp_ns=1234567890000,
            ),
            create_result(
                frame=130,
                track_id=1,
                label="negative",
                confidence=0.92,
                window=(100, 130),
                timestamp_ns=1234567890030,
            ),
            create_result(
                frame=160,
                track_id=2,
                label="neutral",
                confidence=0.78,
                window=(130, 160),
                timestamp_ns=1234567890060,
            ),
        ]

        # Collect received messages
        received_messages: list[dict[str, object]] = []

        async def subscribe_and_publish() -> None:
            """Subscribe to subject, publish messages, collect results."""
            nc = await nats.connect(nats_url)  # pyright: ignore[reportUnknownMemberType]
            try:
                sub = await nc.subscribe(NATS_SUBJECT)  # pyright: ignore[reportUnknownMemberType]
                # subscribe() does not wait for the server to process the SUB.
                # flush() round-trips a PING/PONG, so the subscription is active
                # before the publisher (a separate connection) sends anything;
                # otherwise early messages can be dropped.
                await nc.flush()  # pyright: ignore[reportUnknownMemberType]

                # Publish all test results
                for result in test_results:
                    publisher.publish(result)

                # Wait for each message. next_msg() enforces its own timeout
                # (1 s by default in nats-py), so give it the 5 s budget directly.
                for _ in range(len(test_results)):
                    try:
                        msg = await sub.next_msg(timeout=5.0)
                    except asyncio.TimeoutError:
                        break
                    data = json.loads(msg.data.decode("utf-8"))  # pyright: ignore[reportAny]
                    received_messages.append(cast(dict[str, object], data))

                await sub.unsubscribe()
            finally:
                # Release the subscriber connection even if something failed.
                await nc.close()

        # Run async subscriber; always release the publisher afterwards.
        try:
            asyncio.run(subscribe_and_publish())
        finally:
            publisher.close()

        # Verify all messages received
        assert len(received_messages) == len(test_results), (
            f"Expected {len(test_results)} messages, got {len(received_messages)}"
        )

        # Validate schema for each message
        for i, msg in enumerate(received_messages):
            is_valid, error = _validate_result_schema(msg)
            assert is_valid, f"Message {i} schema validation failed: {error}"

        # Verify specific values
        assert received_messages[0]["frame"] == 100
        assert received_messages[0]["label"] == "positive"
        assert received_messages[0]["track_id"] == 1
        _conf = received_messages[0]["confidence"]
        assert isinstance(_conf, (int, float))
        assert 0.0 <= float(_conf) <= 1.0

        assert received_messages[1]["label"] == "negative"
        assert received_messages[2]["track_id"] == 2

    def test_nats_publisher_graceful_when_server_unavailable(self) -> None:
        """Test that publisher handles missing server gracefully.

        Neither Docker nor a NATS server is needed here: the point is that
        nothing listens on the target port, so the test is not gated on
        Docker availability.
        """
        try:
            from opengait.demo.output import NatsPublisher
        except ImportError:
            pytest.skip("output module not available")

        # Ask the OS for a port that was free a moment ago instead of assuming
        # a fixed port is unused on this machine.
        bad_url = f"nats://127.0.0.1:{_find_open_port()}"
        publisher = NatsPublisher(bad_url, subject=NATS_SUBJECT)

        test_result: dict[str, object] = {
            "frame": 1,
            "track_id": 1,
            "label": "positive",
            "confidence": 0.85,
            "window": 30,
            "timestamp_ns": 1234567890,
        }

        try:
            # Should not raise when publishing without server
            publisher.publish(test_result)
        finally:
            # Cleanup should also not raise
            publisher.close()

    @pytest.mark.skipif(not _docker_available(), reason="Docker not available")
    def test_nats_publisher_context_manager(
        self, nats_server: tuple[bool, int]
    ) -> None:
        """Test that publisher works as context manager."""
        server_available, port = nats_server
        if not server_available:
            pytest.skip("NATS server not available")

        nats_url = f"nats://127.0.0.1:{port}"

        import asyncio

        try:
            import nats  # type: ignore[import-untyped]

            from opengait.demo.output import NatsPublisher, create_result
        except ImportError as e:
            pytest.skip(f"Required module not available: {e}")

        received_messages: list[dict[str, object]] = []

        async def subscribe_and_test() -> None:
            nc = await nats.connect(nats_url)  # pyright: ignore[reportUnknownMemberType]
            try:
                sub = await nc.subscribe(NATS_SUBJECT)  # pyright: ignore[reportUnknownMemberType]
                # Make sure the subscription is registered server-side before
                # the publisher sends anything.
                await nc.flush()  # pyright: ignore[reportUnknownMemberType]

                # Use context manager
                with NatsPublisher(nats_url, subject=NATS_SUBJECT) as publisher:
                    result = create_result(
                        frame=200,
                        track_id=5,
                        label="neutral",
                        confidence=0.65,
                        window=(170, 200),
                        timestamp_ns=9999999999,
                    )
                    publisher.publish(result)

                # Wait for message (pass the timeout to next_msg itself)
                try:
                    msg = await sub.next_msg(timeout=5.0)
                except asyncio.TimeoutError:
                    pass
                else:
                    data = json.loads(msg.data.decode("utf-8"))  # pyright: ignore[reportAny]
                    received_messages.append(cast(dict[str, object], data))

                await sub.unsubscribe()
            finally:
                await nc.close()

        asyncio.run(subscribe_and_test())

        assert len(received_messages) == 1
        assert received_messages[0]["frame"] == 200
        assert received_messages[0]["track_id"] == 5
|
|
|
|
|
|
class TestNatsSchemaValidation:
    """Unit tests for ``_validate_result_schema``; no Docker or NATS required."""

    @staticmethod
    def _payload(**overrides: object) -> dict[str, object]:
        """Build a schema-valid result dict, replacing any fields given in ``overrides``."""
        payload: dict[str, object] = {
            "frame": 1234,
            "track_id": 42,
            "label": "positive",
            "confidence": 0.85,
            "window": 1230,
            "timestamp_ns": 1234567890000,
        }
        payload.update(overrides)
        return payload

    def test_validate_result_schema_valid(self) -> None:
        """A fully populated, well-typed payload passes validation."""
        ok, reason = _validate_result_schema(self._payload())
        assert ok, f"Valid data rejected: {reason}"

    def test_validate_result_schema_invalid_label(self) -> None:
        """A label outside the allowed set is rejected with a label message."""
        ok, reason = _validate_result_schema(self._payload(label="invalid_label"))
        assert not ok
        assert "label" in reason.lower()

    def test_validate_result_schema_confidence_out_of_range(self) -> None:
        """A confidence above 1 is rejected with a confidence message."""
        ok, reason = _validate_result_schema(self._payload(confidence=1.5))
        assert not ok
        assert "confidence" in reason.lower()

    def test_validate_result_schema_missing_fields(self) -> None:
        """A payload lacking required keys is reported as missing fields."""
        partial: dict[str, object] = {"frame": 1234, "label": "positive"}
        ok, reason = _validate_result_schema(partial)
        assert not ok
        assert "missing" in reason.lower()

    def test_validate_result_schema_wrong_types(self) -> None:
        """A non-integer frame is rejected with a frame message."""
        ok, reason = _validate_result_schema(self._payload(frame="not_an_int"))
        assert not ok
        assert "frame" in reason.lower()

    def test_all_valid_labels_accepted(self) -> None:
        """Each of the three allowed labels passes validation."""
        for label_str in ("negative", "neutral", "positive"):
            candidate = self._payload(
                frame=100,
                track_id=1,
                label=label_str,
                confidence=0.5,
                window=100,
                timestamp_ns=1234567890,
            )
            ok, reason = _validate_result_schema(candidate)
            assert ok, f"Valid label '{label_str}' rejected: {reason}"
|
|
|
|
|
|
class TestDockerAvailability:
    """Smoke tests: the Docker probe helpers return a bool and never raise."""

    def test_docker_available_check(self) -> None:
        """``_docker_available`` returns a bool whether or not Docker is installed."""
        assert isinstance(_docker_available(), bool)

    def test_nats_container_running_check(self) -> None:
        """``_nats_container_running`` returns a bool even when Docker is absent."""
        assert isinstance(_nats_container_running(), bool)
|