test(demo): add unit and integration coverage for pipeline
Introduce focused unit, integration, and NATS-path tests for demo modules, and align assertions with final schema and temporal contracts (window int, seq=30, fill-level ratio). This commit isolates validation logic from runtime changes and provides reproducible QA for pipeline behavior and failure modes.
This commit is contained in:
@@ -0,0 +1 @@
|
|||||||
|
"""OpenGait tests package."""
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for demo package."""
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
"""Test fixtures for demo package."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_frame_tensor():
|
||||||
|
"""Return a mock video frame tensor (C, H, W)."""
|
||||||
|
return torch.randn(3, 224, 224)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_frame_array():
|
||||||
|
"""Return a mock video frame as numpy array (H, W, C)."""
|
||||||
|
return np.random.randn(224, 224, 3).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_video_sequence():
|
||||||
|
"""Return a mock video sequence tensor (T, C, H, W)."""
|
||||||
|
return torch.randn(16, 3, 224, 224)
|
||||||
@@ -0,0 +1,525 @@
|
|||||||
|
"""Integration tests for NATS publisher functionality.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- NATS message receipt with JSON schema validation
|
||||||
|
- Skip behavior when Docker/NATS unavailable
|
||||||
|
- Container lifecycle management (cleanup)
|
||||||
|
- JSON schema field/type validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import socket
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from typing import TYPE_CHECKING, cast
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
|
||||||
|
def _find_open_port() -> int:
|
||||||
|
"""Find an available TCP port on localhost."""
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||||
|
sock.bind(("127.0.0.1", 0))
|
||||||
|
sock.listen(1)
|
||||||
|
addr = cast(tuple[str, int], sock.getsockname())
|
||||||
|
port: int = addr[1]
|
||||||
|
return port
|
||||||
|
|
||||||
|
|
||||||
|
# Constants for test configuration
|
||||||
|
NATS_SUBJECT = "scoliosis.result"
|
||||||
|
CONTAINER_NAME = "opengait-nats-test"
|
||||||
|
|
||||||
|
|
||||||
|
def _docker_available() -> bool:
|
||||||
|
"""Check if Docker is available and running."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["docker", "info"],
|
||||||
|
capture_output=True,
|
||||||
|
timeout=5,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
return result.returncode == 0
|
||||||
|
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _nats_container_running() -> bool:
|
||||||
|
"""Check if NATS container is already running."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
[
|
||||||
|
"docker",
|
||||||
|
"ps",
|
||||||
|
"--filter",
|
||||||
|
f"name={CONTAINER_NAME}",
|
||||||
|
"--format",
|
||||||
|
"{{.Names}}",
|
||||||
|
],
|
||||||
|
capture_output=True,
|
||||||
|
timeout=5,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
return CONTAINER_NAME in result.stdout.decode()
|
||||||
|
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _start_nats_container(port: int) -> bool:
|
||||||
|
"""Start NATS container for testing."""
|
||||||
|
try:
|
||||||
|
# Remove existing container if present
|
||||||
|
_ = subprocess.run(
|
||||||
|
["docker", "rm", "-f", CONTAINER_NAME],
|
||||||
|
capture_output=True,
|
||||||
|
timeout=10,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start new NATS container
|
||||||
|
result = subprocess.run(
|
||||||
|
[
|
||||||
|
"docker",
|
||||||
|
"run",
|
||||||
|
"-d",
|
||||||
|
"--name",
|
||||||
|
CONTAINER_NAME,
|
||||||
|
"-p",
|
||||||
|
f"{port}:{port}",
|
||||||
|
"nats:latest",
|
||||||
|
],
|
||||||
|
capture_output=True,
|
||||||
|
timeout=30,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Wait for NATS to be ready (max 10 seconds)
|
||||||
|
for _ in range(20):
|
||||||
|
time.sleep(0.5)
|
||||||
|
try:
|
||||||
|
check_result = subprocess.run(
|
||||||
|
[
|
||||||
|
"docker",
|
||||||
|
"exec",
|
||||||
|
CONTAINER_NAME,
|
||||||
|
"nats",
|
||||||
|
"server",
|
||||||
|
"check",
|
||||||
|
"server",
|
||||||
|
],
|
||||||
|
capture_output=True,
|
||||||
|
timeout=5,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
if check_result.returncode == 0:
|
||||||
|
return True
|
||||||
|
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
return False
|
||||||
|
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _stop_nats_container() -> None:
|
||||||
|
"""Stop and remove NATS container."""
|
||||||
|
try:
|
||||||
|
_ = subprocess.run(
|
||||||
|
["docker", "rm", "-f", CONTAINER_NAME],
|
||||||
|
capture_output=True,
|
||||||
|
timeout=10,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_result_schema(data: dict[str, object]) -> tuple[bool, str]:
|
||||||
|
"""Validate JSON result schema.
|
||||||
|
|
||||||
|
Expected schema:
|
||||||
|
{
|
||||||
|
"frame": int,
|
||||||
|
"track_id": int,
|
||||||
|
"label": str (one of: "negative", "neutral", "positive"),
|
||||||
|
"confidence": float in [0, 1],
|
||||||
|
"window": list[int] (start, end),
|
||||||
|
"timestamp_ns": int
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
required_fields: set[str] = {
|
||||||
|
"frame",
|
||||||
|
"track_id",
|
||||||
|
"label",
|
||||||
|
"confidence",
|
||||||
|
"window",
|
||||||
|
"timestamp_ns",
|
||||||
|
}
|
||||||
|
missing = required_fields - set(data.keys())
|
||||||
|
if missing:
|
||||||
|
return False, f"Missing required fields: {missing}"
|
||||||
|
|
||||||
|
# Validate frame (int)
|
||||||
|
frame = data["frame"]
|
||||||
|
if not isinstance(frame, int):
|
||||||
|
return False, f"frame must be int, got {type(frame)}"
|
||||||
|
|
||||||
|
# Validate track_id (int)
|
||||||
|
track_id = data["track_id"]
|
||||||
|
if not isinstance(track_id, int):
|
||||||
|
return False, f"track_id must be int, got {type(track_id)}"
|
||||||
|
|
||||||
|
# Validate label (str in specific set)
|
||||||
|
label = data["label"]
|
||||||
|
valid_labels = {"negative", "neutral", "positive"}
|
||||||
|
if not isinstance(label, str) or label not in valid_labels:
|
||||||
|
return False, f"label must be one of {valid_labels}, got {label}"
|
||||||
|
|
||||||
|
# Validate confidence (float in [0, 1])
|
||||||
|
confidence = data["confidence"]
|
||||||
|
if not isinstance(confidence, (int, float)):
|
||||||
|
return False, f"confidence must be numeric, got {type(confidence)}"
|
||||||
|
if not 0.0 <= float(confidence) <= 1.0:
|
||||||
|
return False, f"confidence must be in [0, 1], got {confidence}"
|
||||||
|
|
||||||
|
# Validate window (list of 2 ints)
|
||||||
|
window = data["window"]
|
||||||
|
if not isinstance(window, list) or len(cast(list[object], window)) != 2:
|
||||||
|
return False, f"window must be list of 2 ints, got {window}"
|
||||||
|
window_list = cast(list[object], window)
|
||||||
|
if not all(isinstance(x, int) for x in window_list):
|
||||||
|
return False, f"window elements must be ints, got {window}"
|
||||||
|
|
||||||
|
# Validate timestamp_ns (int)
|
||||||
|
timestamp_ns = data["timestamp_ns"]
|
||||||
|
if not isinstance(timestamp_ns, int):
|
||||||
|
return False, f"timestamp_ns must be int, got {type(timestamp_ns)}"
|
||||||
|
|
||||||
|
return True, ""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def nats_server() -> "Generator[tuple[bool, int], None, None]":
|
||||||
|
"""Fixture to manage NATS container lifecycle.
|
||||||
|
|
||||||
|
Yields (True, port) if NATS is available, (False, 0) otherwise.
|
||||||
|
Cleans up container after tests.
|
||||||
|
"""
|
||||||
|
if not _docker_available():
|
||||||
|
yield False, 0
|
||||||
|
return
|
||||||
|
|
||||||
|
port = _find_open_port()
|
||||||
|
if not _start_nats_container(port):
|
||||||
|
yield False, 0
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield True, port
|
||||||
|
finally:
|
||||||
|
_stop_nats_container()
|
||||||
|
|
||||||
|
|
||||||
|
class TestNatsPublisherIntegration:
|
||||||
|
"""Integration tests for NATS publisher with live server."""
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not _docker_available(), reason="Docker not available")
|
||||||
|
def test_nats_message_receipt_and_schema_validation(
|
||||||
|
self, nats_server: tuple[bool, int]
|
||||||
|
) -> None:
|
||||||
|
"""Test that messages are received and match expected schema."""
|
||||||
|
server_available, port = nats_server
|
||||||
|
if not server_available:
|
||||||
|
pytest.skip("NATS server not available")
|
||||||
|
|
||||||
|
nats_url = f"nats://127.0.0.1:{port}"
|
||||||
|
|
||||||
|
# Import here to avoid issues if nats-py not installed
|
||||||
|
try:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import nats # type: ignore[import-untyped]
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("nats-py not installed")
|
||||||
|
|
||||||
|
from opengait.demo.output import NatsPublisher, create_result
|
||||||
|
|
||||||
|
# Create publisher
|
||||||
|
publisher = NatsPublisher(nats_url, subject=NATS_SUBJECT)
|
||||||
|
|
||||||
|
# Create test results
|
||||||
|
test_results = [
|
||||||
|
create_result(
|
||||||
|
frame=100,
|
||||||
|
track_id=1,
|
||||||
|
label="positive",
|
||||||
|
confidence=0.85,
|
||||||
|
window=(70, 100),
|
||||||
|
timestamp_ns=1234567890000,
|
||||||
|
),
|
||||||
|
create_result(
|
||||||
|
frame=130,
|
||||||
|
track_id=1,
|
||||||
|
label="negative",
|
||||||
|
confidence=0.92,
|
||||||
|
window=(100, 130),
|
||||||
|
timestamp_ns=1234567890030,
|
||||||
|
),
|
||||||
|
create_result(
|
||||||
|
frame=160,
|
||||||
|
track_id=2,
|
||||||
|
label="neutral",
|
||||||
|
confidence=0.78,
|
||||||
|
window=(130, 160),
|
||||||
|
timestamp_ns=1234567890060,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Collect received messages
|
||||||
|
received_messages: list[dict[str, object]] = []
|
||||||
|
|
||||||
|
async def subscribe_and_publish():
|
||||||
|
"""Subscribe to subject, publish messages, collect results."""
|
||||||
|
nc = await nats.connect(nats_url) # pyright: ignore[reportUnknownMemberType]
|
||||||
|
|
||||||
|
# Subscribe
|
||||||
|
sub = await nc.subscribe(NATS_SUBJECT) # pyright: ignore[reportUnknownMemberType]
|
||||||
|
|
||||||
|
# Publish all test results
|
||||||
|
for result in test_results:
|
||||||
|
publisher.publish(result)
|
||||||
|
|
||||||
|
# Wait for messages with timeout
|
||||||
|
for _ in range(len(test_results)):
|
||||||
|
try:
|
||||||
|
msg = await asyncio.wait_for(sub.next_msg(), timeout=5.0)
|
||||||
|
data = json.loads(msg.data.decode("utf-8")) # pyright: ignore[reportAny]
|
||||||
|
received_messages.append(cast(dict[str, object], data))
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
break
|
||||||
|
|
||||||
|
await sub.unsubscribe()
|
||||||
|
await nc.close()
|
||||||
|
|
||||||
|
# Run async subscriber
|
||||||
|
asyncio.run(subscribe_and_publish())
|
||||||
|
|
||||||
|
# Cleanup publisher
|
||||||
|
publisher.close()
|
||||||
|
|
||||||
|
# Verify all messages received
|
||||||
|
assert len(received_messages) == len(test_results), (
|
||||||
|
f"Expected {len(test_results)} messages, got {len(received_messages)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate schema for each message
|
||||||
|
for i, msg in enumerate(received_messages):
|
||||||
|
is_valid, error = _validate_result_schema(msg)
|
||||||
|
assert is_valid, f"Message {i} schema validation failed: {error}"
|
||||||
|
|
||||||
|
# Verify specific values
|
||||||
|
assert received_messages[0]["frame"] == 100
|
||||||
|
assert received_messages[0]["label"] == "positive"
|
||||||
|
assert received_messages[0]["track_id"] == 1
|
||||||
|
_conf = received_messages[0]["confidence"]
|
||||||
|
assert isinstance(_conf, (int, float))
|
||||||
|
assert 0.0 <= float(_conf) <= 1.0
|
||||||
|
|
||||||
|
assert received_messages[1]["label"] == "negative"
|
||||||
|
assert received_messages[2]["track_id"] == 2
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not _docker_available(), reason="Docker not available")
|
||||||
|
def test_nats_publisher_graceful_when_server_unavailable(self) -> None:
|
||||||
|
"""Test that publisher handles missing server gracefully."""
|
||||||
|
try:
|
||||||
|
from opengait.demo.output import NatsPublisher
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("output module not available")
|
||||||
|
|
||||||
|
# Use wrong port where no server is running
|
||||||
|
bad_url = "nats://127.0.0.1:14222"
|
||||||
|
publisher = NatsPublisher(bad_url, subject=NATS_SUBJECT)
|
||||||
|
|
||||||
|
# Should not raise when publishing without server
|
||||||
|
test_result: dict[str, object] = {
|
||||||
|
"frame": 1,
|
||||||
|
"track_id": 1,
|
||||||
|
"label": "positive",
|
||||||
|
"confidence": 0.85,
|
||||||
|
"window": [0, 30],
|
||||||
|
"timestamp_ns": 1234567890,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Should not raise
|
||||||
|
publisher.publish(test_result)
|
||||||
|
|
||||||
|
# Cleanup should also not raise
|
||||||
|
publisher.close()
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not _docker_available(), reason="Docker not available")
|
||||||
|
def test_nats_publisher_context_manager(
|
||||||
|
self, nats_server: tuple[bool, int]
|
||||||
|
) -> None:
|
||||||
|
"""Test that publisher works as context manager."""
|
||||||
|
server_available, port = nats_server
|
||||||
|
if not server_available:
|
||||||
|
pytest.skip("NATS server not available")
|
||||||
|
|
||||||
|
nats_url = f"nats://127.0.0.1:{port}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import nats # type: ignore[import-untyped]
|
||||||
|
from opengait.demo.output import NatsPublisher, create_result
|
||||||
|
except ImportError as e:
|
||||||
|
pytest.skip(f"Required module not available: {e}")
|
||||||
|
|
||||||
|
received_messages: list[dict[str, object]] = []
|
||||||
|
|
||||||
|
async def subscribe_and_test():
|
||||||
|
nc = await nats.connect(nats_url) # pyright: ignore[reportUnknownMemberType]
|
||||||
|
sub = await nc.subscribe(NATS_SUBJECT) # pyright: ignore[reportUnknownMemberType]
|
||||||
|
|
||||||
|
# Use context manager
|
||||||
|
with NatsPublisher(nats_url, subject=NATS_SUBJECT) as publisher:
|
||||||
|
result = create_result(
|
||||||
|
frame=200,
|
||||||
|
track_id=5,
|
||||||
|
label="neutral",
|
||||||
|
confidence=0.65,
|
||||||
|
window=(170, 200),
|
||||||
|
timestamp_ns=9999999999,
|
||||||
|
)
|
||||||
|
publisher.publish(result)
|
||||||
|
|
||||||
|
# Wait for message
|
||||||
|
try:
|
||||||
|
msg = await asyncio.wait_for(sub.next_msg(), timeout=5.0)
|
||||||
|
data = json.loads(msg.data.decode("utf-8")) # pyright: ignore[reportAny]
|
||||||
|
received_messages.append(cast(dict[str, object], data))
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
await sub.unsubscribe()
|
||||||
|
await nc.close()
|
||||||
|
|
||||||
|
asyncio.run(subscribe_and_test())
|
||||||
|
|
||||||
|
assert len(received_messages) == 1
|
||||||
|
assert received_messages[0]["frame"] == 200
|
||||||
|
assert received_messages[0]["track_id"] == 5
|
||||||
|
|
||||||
|
|
||||||
|
class TestNatsSchemaValidation:
|
||||||
|
"""Tests for JSON schema validation without requiring NATS server."""
|
||||||
|
|
||||||
|
def test_validate_result_schema_valid(self) -> None:
|
||||||
|
"""Test schema validation with valid data."""
|
||||||
|
valid_data: dict[str, object] = {
|
||||||
|
"frame": 1234,
|
||||||
|
"track_id": 42,
|
||||||
|
"label": "positive",
|
||||||
|
"confidence": 0.85,
|
||||||
|
"window": [1200, 1230],
|
||||||
|
"timestamp_ns": 1234567890000,
|
||||||
|
}
|
||||||
|
|
||||||
|
is_valid, error = _validate_result_schema(valid_data)
|
||||||
|
assert is_valid, f"Valid data rejected: {error}"
|
||||||
|
|
||||||
|
def test_validate_result_schema_invalid_label(self) -> None:
|
||||||
|
"""Test schema validation rejects invalid label."""
|
||||||
|
invalid_data: dict[str, object] = {
|
||||||
|
"frame": 1234,
|
||||||
|
"track_id": 42,
|
||||||
|
"label": "invalid_label",
|
||||||
|
"confidence": 0.85,
|
||||||
|
"window": [1200, 1230],
|
||||||
|
"timestamp_ns": 1234567890000,
|
||||||
|
}
|
||||||
|
|
||||||
|
is_valid, error = _validate_result_schema(invalid_data)
|
||||||
|
assert not is_valid
|
||||||
|
assert "label" in error.lower()
|
||||||
|
|
||||||
|
def test_validate_result_schema_confidence_out_of_range(self) -> None:
|
||||||
|
"""Test schema validation rejects confidence outside [0, 1]."""
|
||||||
|
invalid_data: dict[str, object] = {
|
||||||
|
"frame": 1234,
|
||||||
|
"track_id": 42,
|
||||||
|
"label": "positive",
|
||||||
|
"confidence": 1.5,
|
||||||
|
"window": [1200, 1230],
|
||||||
|
"timestamp_ns": 1234567890000,
|
||||||
|
}
|
||||||
|
|
||||||
|
is_valid, error = _validate_result_schema(invalid_data)
|
||||||
|
assert not is_valid
|
||||||
|
assert "confidence" in error.lower()
|
||||||
|
|
||||||
|
def test_validate_result_schema_missing_fields(self) -> None:
|
||||||
|
"""Test schema validation detects missing fields."""
|
||||||
|
incomplete_data: dict[str, object] = {
|
||||||
|
"frame": 1234,
|
||||||
|
"label": "positive",
|
||||||
|
}
|
||||||
|
|
||||||
|
is_valid, error = _validate_result_schema(incomplete_data)
|
||||||
|
assert not is_valid
|
||||||
|
assert "missing" in error.lower()
|
||||||
|
|
||||||
|
def test_validate_result_schema_wrong_types(self) -> None:
|
||||||
|
"""Test schema validation rejects wrong types."""
|
||||||
|
wrong_types: dict[str, object] = {
|
||||||
|
"frame": "not_an_int",
|
||||||
|
"track_id": 42,
|
||||||
|
"label": "positive",
|
||||||
|
"confidence": 0.85,
|
||||||
|
"window": [1200, 1230],
|
||||||
|
"timestamp_ns": 1234567890000,
|
||||||
|
}
|
||||||
|
|
||||||
|
is_valid, error = _validate_result_schema(wrong_types)
|
||||||
|
assert not is_valid
|
||||||
|
assert "frame" in error.lower()
|
||||||
|
|
||||||
|
def test_all_valid_labels_accepted(self) -> None:
|
||||||
|
"""Test that all valid labels are accepted."""
|
||||||
|
for label_str in ["negative", "neutral", "positive"]:
|
||||||
|
data: dict[str, object] = {
|
||||||
|
"frame": 100,
|
||||||
|
"track_id": 1,
|
||||||
|
"label": label_str,
|
||||||
|
"confidence": 0.5,
|
||||||
|
"window": [70, 100],
|
||||||
|
"timestamp_ns": 1234567890,
|
||||||
|
}
|
||||||
|
is_valid, error = _validate_result_schema(data)
|
||||||
|
assert is_valid, f"Valid label '{label_str}' rejected: {error}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestDockerAvailability:
|
||||||
|
"""Tests for Docker availability detection."""
|
||||||
|
|
||||||
|
def test_docker_available_check(self) -> None:
|
||||||
|
"""Test Docker availability check doesn't crash."""
|
||||||
|
# This should not raise
|
||||||
|
result = _docker_available()
|
||||||
|
assert isinstance(result, bool)
|
||||||
|
|
||||||
|
def test_nats_container_running_check(self) -> None:
|
||||||
|
"""Test container running check doesn't crash."""
|
||||||
|
# This should not raise even if Docker not available
|
||||||
|
result = _nats_container_running()
|
||||||
|
assert isinstance(result, bool)
|
||||||
@@ -0,0 +1,279 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from typing import Final, cast
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from opengait.demo.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
|
REPO_ROOT: Final[Path] = Path(__file__).resolve().parents[2]
|
||||||
|
SAMPLE_VIDEO_PATH: Final[Path] = REPO_ROOT / "assets" / "sample.mp4"
|
||||||
|
CHECKPOINT_PATH: Final[Path] = REPO_ROOT / "ckpt" / "ScoNet-20000.pt"
|
||||||
|
CONFIG_PATH: Final[Path] = REPO_ROOT / "configs" / "sconet" / "sconet_scoliosis1k.yaml"
|
||||||
|
YOLO_MODEL_PATH: Final[Path] = REPO_ROOT / "yolo11n-seg.pt"
|
||||||
|
|
||||||
|
|
||||||
|
def _device_for_runtime() -> str:
|
||||||
|
return "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def _run_pipeline_cli(
|
||||||
|
*args: str, timeout_seconds: int = 120
|
||||||
|
) -> subprocess.CompletedProcess[str]:
|
||||||
|
command = [sys.executable, "-m", "opengait.demo", *args]
|
||||||
|
return subprocess.run(
|
||||||
|
command,
|
||||||
|
cwd=REPO_ROOT,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=False,
|
||||||
|
timeout=timeout_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _require_integration_assets() -> None:
|
||||||
|
if not SAMPLE_VIDEO_PATH.is_file():
|
||||||
|
pytest.skip(f"Missing sample video: {SAMPLE_VIDEO_PATH}")
|
||||||
|
if not CONFIG_PATH.is_file():
|
||||||
|
pytest.skip(f"Missing config: {CONFIG_PATH}")
|
||||||
|
if not YOLO_MODEL_PATH.is_file():
|
||||||
|
pytest.skip(f"Missing YOLO model file: {YOLO_MODEL_PATH}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def compatible_checkpoint_path(tmp_path: Path) -> Path:
|
||||||
|
if not CONFIG_PATH.is_file():
|
||||||
|
pytest.skip(f"Missing config: {CONFIG_PATH}")
|
||||||
|
|
||||||
|
checkpoint_file = tmp_path / "sconet-compatible.pt"
|
||||||
|
model = ScoNetDemo(cfg_path=str(CONFIG_PATH), checkpoint_path=None, device="cpu")
|
||||||
|
torch.save(model.state_dict(), checkpoint_file)
|
||||||
|
return checkpoint_file
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_prediction_json_lines(stdout: str) -> list[dict[str, object]]:
|
||||||
|
required_keys = {
|
||||||
|
"frame",
|
||||||
|
"track_id",
|
||||||
|
"label",
|
||||||
|
"confidence",
|
||||||
|
"window",
|
||||||
|
"timestamp_ns",
|
||||||
|
}
|
||||||
|
predictions: list[dict[str, object]] = []
|
||||||
|
|
||||||
|
for line in stdout.splitlines():
|
||||||
|
stripped = line.strip()
|
||||||
|
if not stripped:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
payload_obj = cast(object, json.loads(stripped))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not isinstance(payload_obj, dict):
|
||||||
|
continue
|
||||||
|
payload = cast(dict[str, object], payload_obj)
|
||||||
|
if required_keys.issubset(payload.keys()):
|
||||||
|
predictions.append(payload)
|
||||||
|
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_prediction_schema(prediction: dict[str, object]) -> None:
|
||||||
|
assert isinstance(prediction["frame"], int)
|
||||||
|
assert isinstance(prediction["track_id"], int)
|
||||||
|
|
||||||
|
label = prediction["label"]
|
||||||
|
assert isinstance(label, str)
|
||||||
|
assert label in {"negative", "neutral", "positive"}
|
||||||
|
|
||||||
|
confidence = prediction["confidence"]
|
||||||
|
assert isinstance(confidence, (int, float))
|
||||||
|
confidence_value = float(confidence)
|
||||||
|
assert 0.0 <= confidence_value <= 1.0
|
||||||
|
|
||||||
|
window_obj = prediction["window"]
|
||||||
|
assert isinstance(window_obj, int)
|
||||||
|
assert window_obj >= 0
|
||||||
|
|
||||||
|
assert isinstance(prediction["timestamp_ns"], int)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_cli_fps_benchmark_smoke(
|
||||||
|
compatible_checkpoint_path: Path,
|
||||||
|
) -> None:
|
||||||
|
_require_integration_assets()
|
||||||
|
|
||||||
|
max_frames = 90
|
||||||
|
started_at = time.perf_counter()
|
||||||
|
result = _run_pipeline_cli(
|
||||||
|
"--source",
|
||||||
|
str(SAMPLE_VIDEO_PATH),
|
||||||
|
"--checkpoint",
|
||||||
|
str(compatible_checkpoint_path),
|
||||||
|
"--config",
|
||||||
|
str(CONFIG_PATH),
|
||||||
|
"--device",
|
||||||
|
_device_for_runtime(),
|
||||||
|
"--yolo-model",
|
||||||
|
str(YOLO_MODEL_PATH),
|
||||||
|
"--window",
|
||||||
|
"5",
|
||||||
|
"--stride",
|
||||||
|
"1",
|
||||||
|
"--max-frames",
|
||||||
|
str(max_frames),
|
||||||
|
timeout_seconds=180,
|
||||||
|
)
|
||||||
|
elapsed_seconds = time.perf_counter() - started_at
|
||||||
|
|
||||||
|
assert result.returncode == 0, (
|
||||||
|
f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
|
||||||
|
)
|
||||||
|
predictions = _extract_prediction_json_lines(result.stdout)
|
||||||
|
assert predictions, "Expected prediction output for FPS benchmark run"
|
||||||
|
|
||||||
|
for prediction in predictions:
|
||||||
|
_assert_prediction_schema(prediction)
|
||||||
|
|
||||||
|
observed_frames = {
|
||||||
|
frame_obj
|
||||||
|
for prediction in predictions
|
||||||
|
for frame_obj in [prediction["frame"]]
|
||||||
|
if isinstance(frame_obj, int)
|
||||||
|
}
|
||||||
|
observed_units = len(observed_frames)
|
||||||
|
if observed_units < 5:
|
||||||
|
pytest.skip(
|
||||||
|
"Insufficient observed frame samples for stable FPS benchmark in this environment"
|
||||||
|
)
|
||||||
|
if elapsed_seconds <= 0:
|
||||||
|
pytest.skip("Non-positive elapsed time; cannot compute FPS benchmark")
|
||||||
|
|
||||||
|
fps = observed_units / elapsed_seconds
|
||||||
|
min_expected_fps = 0.2
|
||||||
|
assert fps >= min_expected_fps, (
|
||||||
|
"Observed FPS below conservative CI threshold: "
|
||||||
|
f"{fps:.3f} < {min_expected_fps:.3f} "
|
||||||
|
f"(observed_units={observed_units}, elapsed_seconds={elapsed_seconds:.3f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_cli_happy_path_outputs_json_predictions(
|
||||||
|
compatible_checkpoint_path: Path,
|
||||||
|
) -> None:
|
||||||
|
_require_integration_assets()
|
||||||
|
|
||||||
|
result = _run_pipeline_cli(
|
||||||
|
"--source",
|
||||||
|
str(SAMPLE_VIDEO_PATH),
|
||||||
|
"--checkpoint",
|
||||||
|
str(compatible_checkpoint_path),
|
||||||
|
"--config",
|
||||||
|
str(CONFIG_PATH),
|
||||||
|
"--device",
|
||||||
|
_device_for_runtime(),
|
||||||
|
"--yolo-model",
|
||||||
|
str(YOLO_MODEL_PATH),
|
||||||
|
"--window",
|
||||||
|
"10",
|
||||||
|
"--stride",
|
||||||
|
"10",
|
||||||
|
"--max-frames",
|
||||||
|
"120",
|
||||||
|
timeout_seconds=180,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.returncode == 0, (
|
||||||
|
f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
|
||||||
|
)
|
||||||
|
predictions = _extract_prediction_json_lines(result.stdout)
|
||||||
|
assert predictions, (
|
||||||
|
"Expected at least one prediction JSON line in stdout. "
|
||||||
|
f"stdout:\n{result.stdout}\nstderr:\n{result.stderr}"
|
||||||
|
)
|
||||||
|
for prediction in predictions:
|
||||||
|
_assert_prediction_schema(prediction)
|
||||||
|
|
||||||
|
assert "Connected to NATS" not in result.stderr
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_cli_max_frames_caps_output_frames(
|
||||||
|
compatible_checkpoint_path: Path,
|
||||||
|
) -> None:
|
||||||
|
_require_integration_assets()
|
||||||
|
|
||||||
|
max_frames = 20
|
||||||
|
result = _run_pipeline_cli(
|
||||||
|
"--source",
|
||||||
|
str(SAMPLE_VIDEO_PATH),
|
||||||
|
"--checkpoint",
|
||||||
|
str(compatible_checkpoint_path),
|
||||||
|
"--config",
|
||||||
|
str(CONFIG_PATH),
|
||||||
|
"--device",
|
||||||
|
_device_for_runtime(),
|
||||||
|
"--yolo-model",
|
||||||
|
str(YOLO_MODEL_PATH),
|
||||||
|
"--window",
|
||||||
|
"5",
|
||||||
|
"--stride",
|
||||||
|
"1",
|
||||||
|
"--max-frames",
|
||||||
|
str(max_frames),
|
||||||
|
timeout_seconds=180,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.returncode == 0, (
|
||||||
|
f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
|
||||||
|
)
|
||||||
|
predictions = _extract_prediction_json_lines(result.stdout)
|
||||||
|
assert predictions, "Expected prediction output with --max-frames run"
|
||||||
|
|
||||||
|
for prediction in predictions:
|
||||||
|
_assert_prediction_schema(prediction)
|
||||||
|
frame_idx_obj = prediction["frame"]
|
||||||
|
assert isinstance(frame_idx_obj, int)
|
||||||
|
assert frame_idx_obj < max_frames
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_cli_invalid_source_path_returns_user_error() -> None:
|
||||||
|
result = _run_pipeline_cli(
|
||||||
|
"--source",
|
||||||
|
"/definitely/not/a/real/video.mp4",
|
||||||
|
"--checkpoint",
|
||||||
|
"/tmp/unused-checkpoint.pt",
|
||||||
|
"--config",
|
||||||
|
str(CONFIG_PATH),
|
||||||
|
timeout_seconds=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.returncode == 2
|
||||||
|
assert "Error: Video source not found" in result.stderr
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_cli_invalid_checkpoint_path_returns_user_error() -> None:
|
||||||
|
if not SAMPLE_VIDEO_PATH.is_file():
|
||||||
|
pytest.skip(f"Missing sample video: {SAMPLE_VIDEO_PATH}")
|
||||||
|
if not CONFIG_PATH.is_file():
|
||||||
|
pytest.skip(f"Missing config: {CONFIG_PATH}")
|
||||||
|
|
||||||
|
result = _run_pipeline_cli(
|
||||||
|
"--source",
|
||||||
|
str(SAMPLE_VIDEO_PATH),
|
||||||
|
"--checkpoint",
|
||||||
|
str(REPO_ROOT / "ckpt" / "missing-checkpoint.pt"),
|
||||||
|
"--config",
|
||||||
|
str(CONFIG_PATH),
|
||||||
|
timeout_seconds=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.returncode == 2
|
||||||
|
assert "Error: Checkpoint not found" in result.stderr
|
||||||
@@ -0,0 +1,153 @@
|
|||||||
|
"""Unit tests for silhouette preprocessing functions."""
|
||||||
|
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
import pytest
|
||||||
|
from beartype.roar import BeartypeCallHintParamViolation
|
||||||
|
from jaxtyping import TypeCheckError
|
||||||
|
|
||||||
|
from opengait.demo.preprocess import mask_to_silhouette
|
||||||
|
|
||||||
|
|
||||||
|
class TestMaskToSilhouette:
|
||||||
|
"""Tests for mask_to_silhouette() function."""
|
||||||
|
|
||||||
|
def test_valid_mask_returns_correct_shape_dtype_and_range(self) -> None:
|
||||||
|
"""Valid mask should return (64, 44) float32 array in [0, 1] range."""
|
||||||
|
# Create a synthetic mask with sufficient area (person-shaped blob)
|
||||||
|
h, w = 200, 150
|
||||||
|
mask = np.zeros((h, w), dtype=np.uint8)
|
||||||
|
# Draw a filled ellipse to simulate a person
|
||||||
|
center_y, center_x = h // 2, w // 2
|
||||||
|
axes_y, axes_x = h // 3, w // 4
|
||||||
|
y, x = np.ogrid[:h, :w]
|
||||||
|
ellipse_mask = ((x - center_x) / axes_x) ** 2 + (
|
||||||
|
(y - center_y) / axes_y
|
||||||
|
) ** 2 <= 1
|
||||||
|
mask[ellipse_mask] = 255
|
||||||
|
|
||||||
|
bbox = (w // 4, h // 6, 3 * w // 4, 5 * h // 6)
|
||||||
|
|
||||||
|
result = mask_to_silhouette(mask, bbox)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
result_arr = cast(NDArray[np.float32], result)
|
||||||
|
assert result_arr.shape == (64, 44)
|
||||||
|
assert result_arr.dtype == np.float32
|
||||||
|
assert np.all(result_arr >= 0.0) and np.all(result_arr <= 1.0)
|
||||||
|
|
||||||
|
def test_tiny_mask_returns_none(self) -> None:
|
||||||
|
"""Mask with area below MIN_MASK_AREA should return None."""
|
||||||
|
# Create a tiny mask
|
||||||
|
mask = np.zeros((50, 50), dtype=np.uint8)
|
||||||
|
mask[20:22, 20:22] = 255 # Only 4 pixels, well below MIN_MASK_AREA (500)
|
||||||
|
|
||||||
|
bbox = (20, 20, 22, 22)
|
||||||
|
|
||||||
|
result = mask_to_silhouette(mask, bbox)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_empty_mask_returns_none(self) -> None:
|
||||||
|
"""Completely empty mask should return None."""
|
||||||
|
mask = np.zeros((100, 100), dtype=np.uint8)
|
||||||
|
bbox = (10, 10, 90, 90)
|
||||||
|
|
||||||
|
result = mask_to_silhouette(mask, bbox)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_full_frame_mask_returns_valid_output(self) -> None:
|
||||||
|
"""Full-frame mask (large bbox covering entire image) should work."""
|
||||||
|
h, w = 300, 200
|
||||||
|
mask = np.zeros((h, w), dtype=np.uint8)
|
||||||
|
# Create a large filled region
|
||||||
|
mask[50:250, 30:170] = 255
|
||||||
|
|
||||||
|
# Full frame bbox
|
||||||
|
bbox = (0, 0, w, h)
|
||||||
|
|
||||||
|
result = mask_to_silhouette(mask, bbox)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
result_arr = cast(NDArray[np.float32], result)
|
||||||
|
assert result_arr.shape == (64, 44)
|
||||||
|
assert result_arr.dtype == np.float32
|
||||||
|
|
||||||
|
def test_determinism_same_input_same_output(self) -> None:
|
||||||
|
"""Same input should always produce same output."""
|
||||||
|
h, w = 200, 150
|
||||||
|
mask = np.zeros((h, w), dtype=np.uint8)
|
||||||
|
# Create a person-shaped region
|
||||||
|
mask[50:150, 40:110] = 255
|
||||||
|
|
||||||
|
bbox = (40, 50, 110, 150)
|
||||||
|
|
||||||
|
result1 = mask_to_silhouette(mask, bbox)
|
||||||
|
result2 = mask_to_silhouette(mask, bbox)
|
||||||
|
result3 = mask_to_silhouette(mask, bbox)
|
||||||
|
|
||||||
|
assert result1 is not None
|
||||||
|
assert result2 is not None
|
||||||
|
assert result3 is not None
|
||||||
|
np.testing.assert_array_equal(result1, result2)
|
||||||
|
np.testing.assert_array_equal(result2, result3)
|
||||||
|
|
||||||
|
def test_tall_narrow_mask_valid_output(self) -> None:
|
||||||
|
"""Tall narrow mask should produce valid silhouette."""
|
||||||
|
h, w = 400, 50
|
||||||
|
mask = np.zeros((h, w), dtype=np.uint8)
|
||||||
|
# Tall narrow person
|
||||||
|
mask[50:350, 10:40] = 255
|
||||||
|
|
||||||
|
bbox = (10, 50, 40, 350)
|
||||||
|
|
||||||
|
result = mask_to_silhouette(mask, bbox)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
result_arr = cast(NDArray[np.float32], result)
|
||||||
|
assert result_arr.shape == (64, 44)
|
||||||
|
|
||||||
|
def test_wide_short_mask_valid_output(self) -> None:
|
||||||
|
"""Wide short mask should produce valid silhouette."""
|
||||||
|
h, w = 100, 400
|
||||||
|
mask = np.zeros((h, w), dtype=np.uint8)
|
||||||
|
# Wide short person
|
||||||
|
mask[20:80, 50:350] = 255
|
||||||
|
|
||||||
|
bbox = (50, 20, 350, 80)
|
||||||
|
|
||||||
|
result = mask_to_silhouette(mask, bbox)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
result_arr = cast(NDArray[np.float32], result)
|
||||||
|
assert result_arr.shape == (64, 44)
|
||||||
|
|
||||||
|
def test_beartype_rejects_wrong_dtype(self) -> None:
|
||||||
|
"""Beartype should reject non-uint8 input."""
|
||||||
|
# Float array instead of uint8
|
||||||
|
mask = np.ones((100, 100), dtype=np.float32) * 255
|
||||||
|
bbox = (10, 10, 90, 90)
|
||||||
|
|
||||||
|
with pytest.raises((BeartypeCallHintParamViolation, TypeCheckError)):
|
||||||
|
_ = mask_to_silhouette(mask, bbox)
|
||||||
|
|
||||||
|
def test_beartype_rejects_wrong_ndim(self) -> None:
|
||||||
|
"""Beartype should reject non-2D array."""
|
||||||
|
# 3D array instead of 2D
|
||||||
|
mask = np.ones((100, 100, 3), dtype=np.uint8) * 255
|
||||||
|
bbox = (10, 10, 90, 90)
|
||||||
|
|
||||||
|
with pytest.raises((BeartypeCallHintParamViolation, TypeCheckError)):
|
||||||
|
_ = mask_to_silhouette(mask, bbox)
|
||||||
|
|
||||||
|
def test_beartype_rejects_wrong_bbox_type(self) -> None:
|
||||||
|
"""Beartype should reject non-tuple bbox."""
|
||||||
|
mask = np.ones((100, 100), dtype=np.uint8) * 255
|
||||||
|
# List instead of tuple
|
||||||
|
bbox = [10, 10, 90, 90]
|
||||||
|
|
||||||
|
with pytest.raises((BeartypeCallHintParamViolation, TypeCheckError)):
|
||||||
|
_ = mask_to_silhouette(mask, bbox)
|
||||||
@@ -0,0 +1,341 @@
|
|||||||
|
"""Unit tests for ScoNetDemo forward pass.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- Construction from config/checkpoint path handling
|
||||||
|
- Forward output shape (N, 3, 16) and dtype float
|
||||||
|
- Predict output (label_str, confidence_float) with valid label/range
|
||||||
|
- No-DDP leakage check (no torch.distributed calls in unit behavior)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, cast
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from opengait.demo.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
|
# Constants for test configuration
|
||||||
|
CONFIG_PATH = Path("configs/sconet/sconet_scoliosis1k.yaml")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def demo() -> "ScoNetDemo":
|
||||||
|
"""Create ScoNetDemo without loading checkpoint (CPU-only)."""
|
||||||
|
from opengait.demo.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
|
return ScoNetDemo(
|
||||||
|
cfg_path=str(CONFIG_PATH),
|
||||||
|
checkpoint_path=None,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dummy_sils_batch() -> Tensor:
|
||||||
|
"""Return dummy silhouette tensor of shape (N, 1, S, 64, 44)."""
|
||||||
|
return torch.randn(2, 1, 30, 64, 44)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dummy_sils_single() -> Tensor:
|
||||||
|
"""Return dummy silhouette tensor of shape (1, 1, S, 64, 44) for predict."""
|
||||||
|
return torch.randn(1, 1, 30, 64, 44)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def synthetic_state_dict() -> dict[str, Tensor]:
|
||||||
|
"""Return a synthetic state dict compatible with ScoNetDemo structure."""
|
||||||
|
return {
|
||||||
|
"backbone.conv1.conv.weight": torch.randn(64, 1, 3, 3),
|
||||||
|
"backbone.conv1.bn.weight": torch.ones(64),
|
||||||
|
"backbone.conv1.bn.bias": torch.zeros(64),
|
||||||
|
"backbone.conv1.bn.running_mean": torch.zeros(64),
|
||||||
|
"backbone.conv1.bn.running_var": torch.ones(64),
|
||||||
|
"fcs.fc_bin": torch.randn(16, 512, 256),
|
||||||
|
"bn_necks.fc_bin": torch.randn(16, 256, 3),
|
||||||
|
"bn_necks.bn1d.weight": torch.ones(4096),
|
||||||
|
"bn_necks.bn1d.bias": torch.zeros(4096),
|
||||||
|
"bn_necks.bn1d.running_mean": torch.zeros(4096),
|
||||||
|
"bn_necks.bn1d.running_var": torch.ones(4096),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestScoNetDemoConstruction:
|
||||||
|
"""Tests for ScoNetDemo construction and path handling."""
|
||||||
|
|
||||||
|
def test_construction_from_config_no_checkpoint(self) -> None:
|
||||||
|
"""Test construction with config only, no checkpoint."""
|
||||||
|
from opengait.demo.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
|
demo = ScoNetDemo(
|
||||||
|
cfg_path=str(CONFIG_PATH),
|
||||||
|
checkpoint_path=None,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert demo.cfg_path.endswith("sconet_scoliosis1k.yaml")
|
||||||
|
assert demo.device == torch.device("cpu")
|
||||||
|
assert demo.cfg is not None
|
||||||
|
assert "model_cfg" in demo.cfg
|
||||||
|
assert demo.training is False # eval mode
|
||||||
|
|
||||||
|
def test_construction_with_relative_path(self) -> None:
|
||||||
|
"""Test construction handles relative config path correctly."""
|
||||||
|
from opengait.demo.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
|
demo = ScoNetDemo(
|
||||||
|
cfg_path="configs/sconet/sconet_scoliosis1k.yaml",
|
||||||
|
checkpoint_path=None,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert demo.cfg is not None
|
||||||
|
assert demo.backbone is not None
|
||||||
|
|
||||||
|
def test_construction_invalid_config_raises(self) -> None:
|
||||||
|
"""Test construction raises with invalid config path."""
|
||||||
|
from opengait.demo.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
|
with pytest.raises((FileNotFoundError, TypeError)):
|
||||||
|
_ = ScoNetDemo(
|
||||||
|
cfg_path="/nonexistent/path/config.yaml",
|
||||||
|
checkpoint_path=None,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestScoNetDemoForward:
|
||||||
|
"""Tests for ScoNetDemo forward pass."""
|
||||||
|
|
||||||
|
def test_forward_output_shape_and_dtype(
|
||||||
|
self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
|
||||||
|
) -> None:
|
||||||
|
"""Test forward returns logits with shape (N, 3, 16) and correct dtypes."""
|
||||||
|
outputs_raw = demo.forward(dummy_sils_batch)
|
||||||
|
outputs = cast(dict[str, Tensor], outputs_raw)
|
||||||
|
|
||||||
|
assert "logits" in outputs
|
||||||
|
logits = outputs["logits"]
|
||||||
|
|
||||||
|
# Expected shape: (batch_size, num_classes, parts_num) = (N, 3, 16)
|
||||||
|
assert logits.shape == (2, 3, 16)
|
||||||
|
assert logits.dtype == torch.float32
|
||||||
|
assert outputs["label"].dtype == torch.int64
|
||||||
|
assert outputs["confidence"].dtype == torch.float32
|
||||||
|
|
||||||
|
def test_forward_returns_required_keys(
|
||||||
|
self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
|
||||||
|
) -> None:
|
||||||
|
"""Test forward returns required output keys."""
|
||||||
|
outputs_raw = demo.forward(dummy_sils_batch)
|
||||||
|
outputs = cast(dict[str, Tensor], outputs_raw)
|
||||||
|
|
||||||
|
required_keys = {"logits", "label", "confidence"}
|
||||||
|
assert set(outputs.keys()) >= required_keys
|
||||||
|
|
||||||
|
def test_forward_batch_size_one(
|
||||||
|
self, demo: "ScoNetDemo", dummy_sils_single: Tensor
|
||||||
|
) -> None:
|
||||||
|
"""Test forward works with batch size 1."""
|
||||||
|
outputs_raw = demo.forward(dummy_sils_single)
|
||||||
|
outputs = cast(dict[str, Tensor], outputs_raw)
|
||||||
|
|
||||||
|
assert outputs["logits"].shape == (1, 3, 16)
|
||||||
|
assert outputs["label"].shape == (1,)
|
||||||
|
assert outputs["confidence"].shape == (1,)
|
||||||
|
|
||||||
|
def test_forward_label_range(
|
||||||
|
self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
|
||||||
|
) -> None:
|
||||||
|
"""Test forward returns valid label indices (0, 1, or 2)."""
|
||||||
|
outputs_raw = demo.forward(dummy_sils_batch)
|
||||||
|
outputs = cast(dict[str, Tensor], outputs_raw)
|
||||||
|
|
||||||
|
labels = outputs["label"]
|
||||||
|
assert torch.all(labels >= 0)
|
||||||
|
assert torch.all(labels <= 2)
|
||||||
|
|
||||||
|
def test_forward_confidence_range(
|
||||||
|
self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
|
||||||
|
) -> None:
|
||||||
|
"""Test forward returns confidence in [0, 1]."""
|
||||||
|
outputs_raw = demo.forward(dummy_sils_batch)
|
||||||
|
outputs = cast(dict[str, Tensor], outputs_raw)
|
||||||
|
|
||||||
|
confidence = outputs["confidence"]
|
||||||
|
assert torch.all(confidence >= 0.0)
|
||||||
|
assert torch.all(confidence <= 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
class TestScoNetDemoPredict:
|
||||||
|
"""Tests for ScoNetDemo predict method."""
|
||||||
|
|
||||||
|
def test_predict_returns_tuple_with_valid_types(
|
||||||
|
self, demo: "ScoNetDemo", dummy_sils_single: Tensor
|
||||||
|
) -> None:
|
||||||
|
"""Test predict returns (str, float) tuple with valid label."""
|
||||||
|
from opengait.demo.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
|
result_raw = demo.predict(dummy_sils_single)
|
||||||
|
result = cast(tuple[str, float], result_raw)
|
||||||
|
|
||||||
|
assert isinstance(result, tuple)
|
||||||
|
assert len(result) == 2
|
||||||
|
label, confidence = result
|
||||||
|
assert isinstance(label, str)
|
||||||
|
assert isinstance(confidence, float)
|
||||||
|
|
||||||
|
valid_labels = set(ScoNetDemo.LABEL_MAP.values())
|
||||||
|
assert label in valid_labels
|
||||||
|
|
||||||
|
def test_predict_confidence_range(
|
||||||
|
self, demo: "ScoNetDemo", dummy_sils_single: Tensor
|
||||||
|
) -> None:
|
||||||
|
"""Test predict returns confidence in valid range [0, 1]."""
|
||||||
|
result_raw = demo.predict(dummy_sils_single)
|
||||||
|
result = cast(tuple[str, float], result_raw)
|
||||||
|
confidence = result[1]
|
||||||
|
|
||||||
|
assert 0.0 <= confidence <= 1.0
|
||||||
|
|
||||||
|
def test_predict_rejects_batch_size_greater_than_one(
|
||||||
|
self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
|
||||||
|
) -> None:
|
||||||
|
"""Test predict raises ValueError for batch size > 1."""
|
||||||
|
with pytest.raises(ValueError, match="batch size 1"):
|
||||||
|
_ = demo.predict(dummy_sils_batch)
|
||||||
|
|
||||||
|
|
||||||
|
class TestScoNetDemoNoDDP:
|
||||||
|
"""Tests to verify no DDP leakage in unit behavior."""
|
||||||
|
|
||||||
|
def test_no_distributed_init_in_construction(self) -> None:
|
||||||
|
"""Test that construction does not call torch.distributed."""
|
||||||
|
from opengait.demo.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
|
with patch("torch.distributed.is_initialized") as mock_is_init:
|
||||||
|
with patch("torch.distributed.init_process_group") as mock_init_pg:
|
||||||
|
_ = ScoNetDemo(
|
||||||
|
cfg_path=str(CONFIG_PATH),
|
||||||
|
checkpoint_path=None,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_init_pg.assert_not_called()
|
||||||
|
mock_is_init.assert_not_called()
|
||||||
|
|
||||||
|
def test_forward_no_distributed_calls(
|
||||||
|
self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
|
||||||
|
) -> None:
|
||||||
|
"""Test forward pass does not call torch.distributed."""
|
||||||
|
with patch("torch.distributed.all_reduce") as mock_all_reduce:
|
||||||
|
with patch("torch.distributed.broadcast") as mock_broadcast:
|
||||||
|
_ = demo.forward(dummy_sils_batch)
|
||||||
|
|
||||||
|
mock_all_reduce.assert_not_called()
|
||||||
|
mock_broadcast.assert_not_called()
|
||||||
|
|
||||||
|
def test_predict_no_distributed_calls(
|
||||||
|
self, demo: "ScoNetDemo", dummy_sils_single: Tensor
|
||||||
|
) -> None:
|
||||||
|
"""Test predict does not call torch.distributed."""
|
||||||
|
with patch("torch.distributed.all_reduce") as mock_all_reduce:
|
||||||
|
with patch("torch.distributed.broadcast") as mock_broadcast:
|
||||||
|
_ = demo.predict(dummy_sils_single)
|
||||||
|
|
||||||
|
mock_all_reduce.assert_not_called()
|
||||||
|
mock_broadcast.assert_not_called()
|
||||||
|
|
||||||
|
def test_model_not_wrapped_in_ddp(self, demo: "ScoNetDemo") -> None:
|
||||||
|
"""Test model is not wrapped in DistributedDataParallel."""
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
|
assert not isinstance(demo, DDP)
|
||||||
|
assert not isinstance(demo.backbone, DDP)
|
||||||
|
|
||||||
|
def test_device_is_cpu(self, demo: "ScoNetDemo") -> None:
|
||||||
|
"""Test model stays on CPU when device="cpu" specified."""
|
||||||
|
assert demo.device.type == "cpu"
|
||||||
|
|
||||||
|
for param in demo.parameters():
|
||||||
|
assert param.device.type == "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
class TestScoNetDemoCheckpointLoading:
|
||||||
|
"""Tests for checkpoint loading behavior using synthetic state dict."""
|
||||||
|
|
||||||
|
def test_load_checkpoint_changes_weights(
|
||||||
|
self,
|
||||||
|
demo: "ScoNetDemo",
|
||||||
|
synthetic_state_dict: dict[str, Tensor],
|
||||||
|
) -> None:
|
||||||
|
"""Test loading checkpoint actually changes model weights."""
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Get initial weight
|
||||||
|
initial_weight = next(iter(demo.parameters())).clone()
|
||||||
|
|
||||||
|
# Create temp checkpoint file with synthetic state dict
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
|
||||||
|
torch.save(synthetic_state_dict, f.name)
|
||||||
|
temp_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
_ = demo.load_checkpoint(temp_path, strict=False)
|
||||||
|
new_weight = next(iter(demo.parameters()))
|
||||||
|
assert not torch.equal(initial_weight, new_weight)
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
|
||||||
|
def test_load_checkpoint_sets_eval_mode(
|
||||||
|
self,
|
||||||
|
demo: "ScoNetDemo",
|
||||||
|
synthetic_state_dict: dict[str, Tensor],
|
||||||
|
) -> None:
|
||||||
|
"""Test loading checkpoint sets model to eval mode."""
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
|
||||||
|
_ = demo.train()
|
||||||
|
assert demo.training is True
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
|
||||||
|
torch.save(synthetic_state_dict, f.name)
|
||||||
|
temp_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
_ = demo.load_checkpoint(temp_path, strict=False)
|
||||||
|
assert demo.training is False
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
|
||||||
|
def test_load_checkpoint_invalid_path_raises(self, demo: "ScoNetDemo") -> None:
|
||||||
|
"""Test loading from invalid checkpoint path raises error."""
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
_ = demo.load_checkpoint("/nonexistent/checkpoint.pt")
|
||||||
|
|
||||||
|
|
||||||
|
class TestScoNetDemoLabelMap:
|
||||||
|
"""Tests for LABEL_MAP constant."""
|
||||||
|
|
||||||
|
def test_label_map_has_three_classes(self) -> None:
|
||||||
|
"""Test LABEL_MAP has exactly 3 classes."""
|
||||||
|
from opengait.demo.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
|
assert len(ScoNetDemo.LABEL_MAP) == 3
|
||||||
|
assert set(ScoNetDemo.LABEL_MAP.keys()) == {0, 1, 2}
|
||||||
|
|
||||||
|
def test_label_map_values_are_valid_strings(self) -> None:
|
||||||
|
"""Test LABEL_MAP values are valid non-empty strings."""
|
||||||
|
from opengait.demo.sconet_demo import ScoNetDemo
|
||||||
|
|
||||||
|
for value in ScoNetDemo.LABEL_MAP.values():
|
||||||
|
assert isinstance(value, str)
|
||||||
|
assert len(value) > 0
|
||||||
@@ -0,0 +1,346 @@
|
|||||||
|
"""Unit tests for SilhouetteWindow and select_person functions."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
from opengait.demo.window import SilhouetteWindow, select_person
|
||||||
|
|
||||||
|
|
||||||
|
class TestSilhouetteWindow:
|
||||||
|
"""Tests for SilhouetteWindow class."""
|
||||||
|
|
||||||
|
def test_window_fill_and_ready_behavior(self) -> None:
|
||||||
|
"""Window should be ready only when filled to window_size."""
|
||||||
|
window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
|
||||||
|
sil = np.ones((64, 44), dtype=np.float32)
|
||||||
|
|
||||||
|
# Not ready with fewer frames
|
||||||
|
for i in range(4):
|
||||||
|
window.push(sil, frame_idx=i, track_id=1)
|
||||||
|
assert not window.is_ready()
|
||||||
|
assert window.fill_level == (i + 1) / 5
|
||||||
|
|
||||||
|
# Ready at exactly window_size
|
||||||
|
window.push(sil, frame_idx=4, track_id=1)
|
||||||
|
assert window.is_ready()
|
||||||
|
assert window.fill_level == 1.0
|
||||||
|
|
||||||
|
def test_underfilled_not_ready(self) -> None:
|
||||||
|
"""Underfilled window should never report ready."""
|
||||||
|
window = SilhouetteWindow(window_size=10, stride=1, gap_threshold=5)
|
||||||
|
sil = np.ones((64, 44), dtype=np.float32)
|
||||||
|
|
||||||
|
# Push 9 frames (underfilled)
|
||||||
|
for i in range(9):
|
||||||
|
window.push(sil, frame_idx=i, track_id=1)
|
||||||
|
|
||||||
|
assert not window.is_ready()
|
||||||
|
assert window.fill_level == 0.9
|
||||||
|
|
||||||
|
# get_tensor should raise when not ready
|
||||||
|
with pytest.raises(ValueError, match="Window not ready"):
|
||||||
|
window.get_tensor()
|
||||||
|
|
||||||
|
def test_track_id_change_resets_buffer(self) -> None:
|
||||||
|
"""Changing track ID should reset the buffer."""
|
||||||
|
window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
|
||||||
|
sil = np.ones((64, 44), dtype=np.float32)
|
||||||
|
|
||||||
|
# Fill window with track_id=1
|
||||||
|
for i in range(5):
|
||||||
|
window.push(sil, frame_idx=i, track_id=1)
|
||||||
|
|
||||||
|
assert window.is_ready()
|
||||||
|
assert window.current_track_id == 1
|
||||||
|
assert window.fill_level == 1.0
|
||||||
|
|
||||||
|
# Push with different track_id should reset
|
||||||
|
window.push(sil, frame_idx=5, track_id=2)
|
||||||
|
assert not window.is_ready()
|
||||||
|
assert window.fill_level == 0.2
|
||||||
|
assert window.current_track_id == 2
|
||||||
|
|
||||||
|
def test_frame_gap_reset_behavior(self) -> None:
|
||||||
|
"""Frame gap exceeding threshold should reset buffer."""
|
||||||
|
window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=3)
|
||||||
|
sil = np.ones((64, 44), dtype=np.float32)
|
||||||
|
|
||||||
|
# Fill window with consecutive frames
|
||||||
|
for i in range(5):
|
||||||
|
window.push(sil, frame_idx=i, track_id=1)
|
||||||
|
|
||||||
|
assert window.is_ready()
|
||||||
|
assert window.fill_level == 1.0
|
||||||
|
|
||||||
|
# Small gap (within threshold) - no reset
|
||||||
|
window.push(sil, frame_idx=6, track_id=1) # gap = 1
|
||||||
|
assert window.is_ready()
|
||||||
|
assert window.fill_level == 1.0 # deque maintains max
|
||||||
|
|
||||||
|
# Reset and fill again
|
||||||
|
window.reset()
|
||||||
|
for i in range(5):
|
||||||
|
window.push(sil, frame_idx=i, track_id=1)
|
||||||
|
|
||||||
|
# Large gap (exceeds threshold) - should reset
|
||||||
|
window.push(sil, frame_idx=10, track_id=1) # gap = 5 > 3
|
||||||
|
assert not window.is_ready()
|
||||||
|
assert window.fill_level == 0.2
|
||||||
|
|
||||||
|
def test_get_tensor_shape(self) -> None:
|
||||||
|
"""get_tensor should return tensor of shape [1, 1, window_size, 64, 44]."""
|
||||||
|
window_size = 7
|
||||||
|
window = SilhouetteWindow(window_size=window_size, stride=1, gap_threshold=10)
|
||||||
|
|
||||||
|
# Push unique frames to verify ordering
|
||||||
|
for i in range(window_size):
|
||||||
|
sil = np.full((64, 44), fill_value=float(i), dtype=np.float32)
|
||||||
|
window.push(sil, frame_idx=i, track_id=1)
|
||||||
|
|
||||||
|
assert window.is_ready()
|
||||||
|
|
||||||
|
tensor = window.get_tensor(device="cpu")
|
||||||
|
|
||||||
|
assert isinstance(tensor, torch.Tensor)
|
||||||
|
assert tensor.shape == (1, 1, window_size, 64, 44)
|
||||||
|
assert tensor.dtype == torch.float32
|
||||||
|
|
||||||
|
# Verify content ordering: first frame should have value 0.0
|
||||||
|
assert tensor[0, 0, 0, 0, 0].item() == 0.0
|
||||||
|
# Last frame should have value window_size-1
|
||||||
|
assert tensor[0, 0, window_size - 1, 0, 0].item() == window_size - 1
|
||||||
|
|
||||||
|
def test_should_classify_stride_behavior(self) -> None:
|
||||||
|
"""should_classify should respect stride setting."""
|
||||||
|
window = SilhouetteWindow(window_size=5, stride=3, gap_threshold=10)
|
||||||
|
sil = np.ones((64, 44), dtype=np.float32)
|
||||||
|
|
||||||
|
# Fill window
|
||||||
|
for i in range(5):
|
||||||
|
window.push(sil, frame_idx=i, track_id=1)
|
||||||
|
|
||||||
|
assert window.is_ready()
|
||||||
|
|
||||||
|
# First classification should always trigger
|
||||||
|
assert window.should_classify()
|
||||||
|
|
||||||
|
# Mark as classified at frame 4
|
||||||
|
window.mark_classified()
|
||||||
|
|
||||||
|
# Not ready to classify yet (stride=3, only 0 frames passed)
|
||||||
|
window.push(sil, frame_idx=5, track_id=1)
|
||||||
|
assert not window.should_classify()
|
||||||
|
|
||||||
|
# Still not ready (only 1 frame passed)
|
||||||
|
window.push(sil, frame_idx=6, track_id=1)
|
||||||
|
assert not window.should_classify()
|
||||||
|
|
||||||
|
# Now ready (3 frames passed since last classification)
|
||||||
|
window.push(sil, frame_idx=7, track_id=1)
|
||||||
|
assert window.should_classify()
|
||||||
|
|
||||||
|
def test_should_classify_not_ready(self) -> None:
|
||||||
|
"""should_classify should return False when window not ready."""
|
||||||
|
window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
|
||||||
|
sil = np.ones((64, 44), dtype=np.float32)
|
||||||
|
|
||||||
|
# Push only 3 frames (not ready)
|
||||||
|
for i in range(3):
|
||||||
|
window.push(sil, frame_idx=i, track_id=1)
|
||||||
|
|
||||||
|
assert not window.is_ready()
|
||||||
|
assert not window.should_classify()
|
||||||
|
|
||||||
|
def test_reset_clears_all_state(self) -> None:
|
||||||
|
"""reset should clear all internal state."""
|
||||||
|
window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
|
||||||
|
sil = np.ones((64, 44), dtype=np.float32)
|
||||||
|
|
||||||
|
# Fill and classify
|
||||||
|
for i in range(5):
|
||||||
|
window.push(sil, frame_idx=i, track_id=1)
|
||||||
|
window.mark_classified()
|
||||||
|
|
||||||
|
assert window.is_ready()
|
||||||
|
assert window.current_track_id == 1
|
||||||
|
assert window.frame_count == 5
|
||||||
|
|
||||||
|
# Reset
|
||||||
|
window.reset()
|
||||||
|
|
||||||
|
assert not window.is_ready()
|
||||||
|
assert window.fill_level == 0.0
|
||||||
|
assert window.current_track_id is None
|
||||||
|
assert window.frame_count == 0
|
||||||
|
assert not window.should_classify()
|
||||||
|
|
||||||
|
def test_push_invalid_shape_raises(self) -> None:
|
||||||
|
"""push should raise ValueError for invalid silhouette shape."""
|
||||||
|
window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
|
||||||
|
|
||||||
|
# Wrong shape
|
||||||
|
sil_wrong = np.ones((32, 32), dtype=np.float32)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Expected silhouette shape"):
|
||||||
|
window.push(sil_wrong, frame_idx=0, track_id=1)
|
||||||
|
|
||||||
|
def test_push_wrong_dtype_converts(self) -> None:
|
||||||
|
"""push should convert dtype to float32."""
|
||||||
|
window = SilhouetteWindow(window_size=1, stride=1, gap_threshold=10)
|
||||||
|
|
||||||
|
# uint8 input
|
||||||
|
sil_uint8 = np.ones((64, 44), dtype=np.uint8) * 255
|
||||||
|
window.push(sil_uint8, frame_idx=0, track_id=1)
|
||||||
|
|
||||||
|
tensor = window.get_tensor()
|
||||||
|
assert tensor.dtype == torch.float32
|
||||||
|
|
||||||
|
|
||||||
|
class TestSelectPerson:
|
||||||
|
"""Tests for select_person function."""
|
||||||
|
|
||||||
|
def _create_mock_results(
|
||||||
|
self,
|
||||||
|
boxes_xyxy: NDArray[np.float32],
|
||||||
|
masks_data: NDArray[np.float32],
|
||||||
|
track_ids: NDArray[np.int64] | None,
|
||||||
|
) -> Any:
|
||||||
|
"""Create a mock detection results object."""
|
||||||
|
mock_boxes = MagicMock()
|
||||||
|
mock_boxes.xyxy = boxes_xyxy
|
||||||
|
mock_boxes.id = track_ids
|
||||||
|
|
||||||
|
mock_masks = MagicMock()
|
||||||
|
mock_masks.data = masks_data
|
||||||
|
|
||||||
|
mock_results = MagicMock()
|
||||||
|
mock_results.boxes = mock_boxes
|
||||||
|
mock_results.masks = mock_masks
|
||||||
|
|
||||||
|
return mock_results
|
||||||
|
|
||||||
|
def test_select_person_single_detection(self) -> None:
|
||||||
|
"""Single detection should return that person's data."""
|
||||||
|
boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32) # area = 3200
|
||||||
|
masks = np.random.rand(1, 100, 100).astype(np.float32)
|
||||||
|
track_ids = np.array([42], dtype=np.int64)
|
||||||
|
|
||||||
|
results = self._create_mock_results(boxes, masks, track_ids)
|
||||||
|
result = select_person(results)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
mask, bbox, tid = result
|
||||||
|
assert mask.shape == (100, 100)
|
||||||
|
assert bbox == (10, 10, 50, 90)
|
||||||
|
assert tid == 42
|
||||||
|
|
||||||
|
def test_select_person_multi_detection_selects_largest(self) -> None:
|
||||||
|
"""Multiple detections should select the one with largest bbox area."""
|
||||||
|
# Two boxes: second one is larger
|
||||||
|
boxes = np.array(
|
||||||
|
[
|
||||||
|
[0.0, 0.0, 10.0, 10.0], # area = 100
|
||||||
|
[0.0, 0.0, 30.0, 30.0], # area = 900 (largest)
|
||||||
|
[0.0, 0.0, 20.0, 20.0], # area = 400
|
||||||
|
],
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
masks = np.random.rand(3, 100, 100).astype(np.float32)
|
||||||
|
track_ids = np.array([1, 2, 3], dtype=np.int64)
|
||||||
|
|
||||||
|
results = self._create_mock_results(boxes, masks, track_ids)
|
||||||
|
result = select_person(results)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
mask, bbox, tid = result
|
||||||
|
assert bbox == (0, 0, 30, 30) # Largest box
|
||||||
|
assert tid == 2 # Corresponding track ID
|
||||||
|
|
||||||
|
def test_select_person_no_detections_returns_none(self) -> None:
|
||||||
|
"""No detections should return None."""
|
||||||
|
boxes = np.array([], dtype=np.float32).reshape(0, 4)
|
||||||
|
masks = np.array([], dtype=np.float32).reshape(0, 100, 100)
|
||||||
|
track_ids = np.array([], dtype=np.int64)
|
||||||
|
|
||||||
|
results = self._create_mock_results(boxes, masks, track_ids)
|
||||||
|
result = select_person(results)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_select_person_no_track_ids_returns_none(self) -> None:
|
||||||
|
"""Detections without track IDs should return None."""
|
||||||
|
boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32)
|
||||||
|
masks = np.random.rand(1, 100, 100).astype(np.float32)
|
||||||
|
|
||||||
|
results = self._create_mock_results(boxes, masks, track_ids=None)
|
||||||
|
result = select_person(results)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_select_person_empty_track_ids_returns_none(self) -> None:
|
||||||
|
"""Empty track IDs array should return None."""
|
||||||
|
boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32)
|
||||||
|
masks = np.random.rand(1, 100, 100).astype(np.float32)
|
||||||
|
track_ids = np.array([], dtype=np.int64)
|
||||||
|
|
||||||
|
results = self._create_mock_results(boxes, masks, track_ids)
|
||||||
|
result = select_person(results)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_select_person_missing_boxes_returns_none(self) -> None:
|
||||||
|
"""Missing boxes attribute should return None."""
|
||||||
|
mock_results = MagicMock()
|
||||||
|
mock_results.boxes = None
|
||||||
|
|
||||||
|
result = select_person(mock_results)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_select_person_missing_masks_returns_none(self) -> None:
|
||||||
|
"""Missing masks attribute should return None."""
|
||||||
|
boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32)
|
||||||
|
track_ids = np.array([1], dtype=np.int64)
|
||||||
|
|
||||||
|
mock_boxes = MagicMock()
|
||||||
|
mock_boxes.xyxy = boxes
|
||||||
|
mock_boxes.id = track_ids
|
||||||
|
|
||||||
|
mock_results = MagicMock()
|
||||||
|
mock_results.boxes = mock_boxes
|
||||||
|
mock_results.masks = None
|
||||||
|
|
||||||
|
result = select_person(mock_results)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_select_person_1d_bbox_handling(self) -> None:
|
||||||
|
"""1D bbox array should be reshaped to 2D."""
|
||||||
|
boxes = np.array([10.0, 10.0, 50.0, 90.0], dtype=np.float32) # 1D
|
||||||
|
masks = np.random.rand(1, 100, 100).astype(np.float32)
|
||||||
|
track_ids = np.array([1], dtype=np.int64)
|
||||||
|
|
||||||
|
results = self._create_mock_results(boxes, masks, track_ids)
|
||||||
|
result = select_person(results)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
_, bbox, tid = result
|
||||||
|
assert bbox == (10, 10, 50, 90)
|
||||||
|
assert tid == 1
|
||||||
|
|
||||||
|
def test_select_person_2d_mask_handling(self) -> None:
|
||||||
|
"""2D mask should be expanded to 3D."""
|
||||||
|
boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32)
|
||||||
|
masks = np.random.rand(100, 100).astype(np.float32) # 2D
|
||||||
|
track_ids = np.array([1], dtype=np.int64)
|
||||||
|
|
||||||
|
results = self._create_mock_results(boxes, masks, track_ids)
|
||||||
|
result = select_person(results)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
mask, _, _ = result
|
||||||
|
# Should be 2D (extracted from expanded 3D)
|
||||||
|
assert mask.shape == (100, 100)
|
||||||
Reference in New Issue
Block a user