feat: extract opengait_studio monorepo module

Move demo implementation into opengait_studio, retire Sports2D runtime integration, and align packaging with root-level monorepo dependency management.
This commit is contained in:
2026-03-03 17:16:17 +08:00
parent 5c6bef1ca1
commit 00fcda4fe3
39 changed files with 359 additions and 270 deletions
+1
View File
@@ -0,0 +1 @@
"""Tests for demo package."""
+23
View File
@@ -0,0 +1,23 @@
"""Test fixtures for demo package."""
import numpy as np
import pytest
import torch
@pytest.fixture
def mock_frame_tensor():
    """Provide a random video-frame tensor in CHW layout (3, 224, 224)."""
    frame = torch.randn(3, 224, 224)
    return frame
@pytest.fixture
def mock_frame_array():
    """Provide a random float32 frame array in HWC layout (224, 224, 3)."""
    raw = np.random.randn(224, 224, 3)
    return raw.astype(np.float32)
@pytest.fixture
def mock_video_sequence():
    """Provide a random 16-frame clip tensor in TCHW layout (16, 3, 224, 224)."""
    clip = torch.randn(16, 3, 224, 224)
    return clip
+524
View File
@@ -0,0 +1,524 @@
"""Integration tests for NATS publisher functionality.
Tests cover:
- NATS message receipt with JSON schema validation
- Skip behavior when Docker/NATS unavailable
- Container lifecycle management (cleanup)
- JSON schema field/type validation
"""
from __future__ import annotations
import json
import socket
import subprocess
import time
from typing import TYPE_CHECKING, cast
import pytest
if TYPE_CHECKING:
from collections.abc import Generator
def _find_open_port() -> int:
"""Find an available TCP port on localhost."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
sock.listen(1)
addr = cast(tuple[str, int], sock.getsockname())
port: int = addr[1]
return port
# Constants for test configuration
# Subject that the publisher emits on and these tests subscribe to.
NATS_SUBJECT = "scoliosis.result"
# Name of the disposable Docker container hosting the test NATS server.
CONTAINER_NAME = "opengait-nats-test"
def _docker_available() -> bool:
"""Check if Docker is available and running."""
try:
result = subprocess.run(
["docker", "info"],
capture_output=True,
timeout=5,
check=False,
)
return result.returncode == 0
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
return False
def _nats_container_running() -> bool:
    """Report whether the test NATS container appears in ``docker ps``."""
    listing_command = [
        "docker",
        "ps",
        "--filter",
        f"name={CONTAINER_NAME}",
        "--format",
        "{{.Names}}",
    ]
    try:
        completed = subprocess.run(
            listing_command,
            capture_output=True,
            timeout=5,
            check=False,
        )
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
        # Treat any environment failure as "not running".
        return False
    return CONTAINER_NAME in completed.stdout.decode()
def _start_nats_container(port: int) -> bool:
    """Start a disposable NATS container publishing on host *port*.

    Returns True once the in-container health check passes, False when
    Docker is unavailable, the container fails to start, or readiness is
    not reached within ~10 seconds of polling.
    """
    try:
        # Remove existing container if present
        _ = subprocess.run(
            ["docker", "rm", "-f", CONTAINER_NAME],
            capture_output=True,
            timeout=10,
            check=False,
        )
        # Start new NATS container
        result = subprocess.run(
            [
                "docker",
                "run",
                "-d",
                "--name",
                CONTAINER_NAME,
                "-p",
                f"{port}:4222",
                "nats:latest",
            ],
            capture_output=True,
            timeout=30,
            check=False,
        )
        if result.returncode != 0:
            return False
        # Wait for NATS to be ready (max 10 seconds): 20 polls x 0.5s sleep.
        for _ in range(20):
            time.sleep(0.5)
            try:
                # Probe readiness by exec-ing the `nats` CLI inside the
                # container; assumes the image ships that binary — TODO confirm.
                check_result = subprocess.run(
                    [
                        "docker",
                        "exec",
                        CONTAINER_NAME,
                        "nats",
                        "server",
                        "check",
                        "server",
                    ],
                    capture_output=True,
                    timeout=5,
                    check=False,
                )
                if check_result.returncode == 0:
                    return True
            except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
                # Transient probe failure; keep polling until attempts run out.
                continue
        return False
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
        return False
def _stop_nats_container() -> None:
    """Force-remove the test NATS container; failures are ignored."""
    removal = ["docker", "rm", "-f", CONTAINER_NAME]
    try:
        _ = subprocess.run(
            removal,
            capture_output=True,
            check=False,
            timeout=10,
        )
    except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
        # Cleanup is best-effort; a broken Docker environment is not an error.
        pass
def _validate_result_schema(data: dict[str, object]) -> tuple[bool, str]:
"""Validate JSON result schema.
Expected schema:
{
"frame": int,
"track_id": int,
"label": str (one of: "negative", "neutral", "positive"),
"confidence": float in [0, 1],
"window": int (non-negative),
"timestamp_ns": int
}
"""
required_fields: set[str] = {
"frame",
"track_id",
"label",
"confidence",
"window",
"timestamp_ns",
}
missing = required_fields - set(data.keys())
if missing:
return False, f"Missing required fields: {missing}"
# Validate frame (int)
frame = data["frame"]
if not isinstance(frame, int):
return False, f"frame must be int, got {type(frame)}"
# Validate track_id (int)
track_id = data["track_id"]
if not isinstance(track_id, int):
return False, f"track_id must be int, got {type(track_id)}"
# Validate label (str in specific set)
label = data["label"]
valid_labels = {"negative", "neutral", "positive"}
if not isinstance(label, str) or label not in valid_labels:
return False, f"label must be one of {valid_labels}, got {label}"
# Validate confidence (float in [0, 1])
confidence = data["confidence"]
if not isinstance(confidence, (int, float)):
return False, f"confidence must be numeric, got {type(confidence)}"
if not 0.0 <= float(confidence) <= 1.0:
return False, f"confidence must be in [0, 1], got {confidence}"
# Validate window (int, non-negative)
window = data["window"]
if not isinstance(window, int):
return False, f"window must be int, got {type(window)}"
if window < 0:
return False, f"window must be non-negative, got {window}"
# Validate timestamp_ns (int)
timestamp_ns = data["timestamp_ns"]
if not isinstance(timestamp_ns, int):
return False, f"timestamp_ns must be int, got {type(timestamp_ns)}"
return True, ""
@pytest.fixture(scope="module")
def nats_server() -> "Generator[tuple[bool, int], None, None]":
    """Module-scoped fixture managing the NATS container lifecycle.

    Yields ``(True, port)`` when a NATS server could be started via Docker,
    and ``(False, 0)`` when Docker is missing or startup failed. The
    container is always removed on teardown.
    """
    if not _docker_available():
        yield False, 0
        return
    chosen_port = _find_open_port()
    if not _start_nats_container(chosen_port):
        yield False, 0
        return
    try:
        yield True, chosen_port
    finally:
        _stop_nats_container()
class TestNatsPublisherIntegration:
    """Integration tests for NATS publisher with live server.

    Each test is double-guarded: the skipif decorator is evaluated once at
    collection time (Docker missing), and the ``nats_server`` fixture result
    is re-checked at runtime (server failed to start).
    """
    @pytest.mark.skipif(not _docker_available(), reason="Docker not available")
    def test_nats_message_receipt_and_schema_validation(
        self, nats_server: tuple[bool, int]
    ) -> None:
        """Test that messages are received and match expected schema."""
        server_available, port = nats_server
        if not server_available:
            pytest.skip("NATS server not available")
        nats_url = f"nats://127.0.0.1:{port}"
        # Import here to avoid issues if nats-py not installed
        try:
            import asyncio
            import nats  # type: ignore[import-untyped]
        except ImportError:
            pytest.skip("nats-py not installed")
        from opengait_studio.output import NatsPublisher, create_result
        # Create publisher
        publisher = NatsPublisher(nats_url, subject=NATS_SUBJECT)
        # Create test results
        test_results = [
            create_result(
                frame=100,
                track_id=1,
                label="positive",
                confidence=0.85,
                window=(70, 100),
                timestamp_ns=1234567890000,
            ),
            create_result(
                frame=130,
                track_id=1,
                label="negative",
                confidence=0.92,
                window=(100, 130),
                timestamp_ns=1234567890030,
            ),
            create_result(
                frame=160,
                track_id=2,
                label="neutral",
                confidence=0.78,
                window=(130, 160),
                timestamp_ns=1234567890060,
            ),
        ]
        # Collect received messages
        received_messages: list[dict[str, object]] = []
        async def subscribe_and_publish():
            """Subscribe to subject, publish messages, collect results."""
            nc = await nats.connect(nats_url)  # pyright: ignore[reportUnknownMemberType]
            # Subscribe before publishing so no message is missed.
            sub = await nc.subscribe(NATS_SUBJECT)  # pyright: ignore[reportUnknownMemberType]
            # Publish all test results
            # NOTE(review): publisher.publish is called synchronously inside
            # this coroutine; assumes it does not block the event loop — confirm.
            for result in test_results:
                publisher.publish(result)
            # Wait for messages with timeout (5s per expected message).
            for _ in range(len(test_results)):
                try:
                    msg = await asyncio.wait_for(sub.next_msg(), timeout=5.0)
                    data = json.loads(msg.data.decode("utf-8"))  # pyright: ignore[reportAny]
                    received_messages.append(cast(dict[str, object], data))
                except asyncio.TimeoutError:
                    break
            await sub.unsubscribe()
            await nc.close()
        # Run async subscriber
        asyncio.run(subscribe_and_publish())
        # Cleanup publisher
        publisher.close()
        # Verify all messages received
        assert len(received_messages) == len(test_results), (
            f"Expected {len(test_results)} messages, got {len(received_messages)}"
        )
        # Validate schema for each message
        for i, msg in enumerate(received_messages):
            is_valid, error = _validate_result_schema(msg)
            assert is_valid, f"Message {i} schema validation failed: {error}"
        # Verify specific values
        assert received_messages[0]["frame"] == 100
        assert received_messages[0]["label"] == "positive"
        assert received_messages[0]["track_id"] == 1
        _conf = received_messages[0]["confidence"]
        assert isinstance(_conf, (int, float))
        assert 0.0 <= float(_conf) <= 1.0
        assert received_messages[1]["label"] == "negative"
        assert received_messages[2]["track_id"] == 2
    @pytest.mark.skipif(not _docker_available(), reason="Docker not available")
    def test_nats_publisher_graceful_when_server_unavailable(self) -> None:
        """Test that publisher handles missing server gracefully."""
        try:
            from opengait_studio.output import NatsPublisher
        except ImportError:
            pytest.skip("output module not available")
        # Use wrong port where no server is running
        bad_url = "nats://127.0.0.1:14222"
        publisher = NatsPublisher(bad_url, subject=NATS_SUBJECT)
        # Should not raise when publishing without server
        test_result: dict[str, object] = {
            "frame": 1,
            "track_id": 1,
            "label": "positive",
            "confidence": 0.85,
            "window": 30,
            "timestamp_ns": 1234567890,
        }
        # Should not raise
        publisher.publish(test_result)
        # Cleanup should also not raise
        publisher.close()
    @pytest.mark.skipif(not _docker_available(), reason="Docker not available")
    def test_nats_publisher_context_manager(
        self, nats_server: tuple[bool, int]
    ) -> None:
        """Test that publisher works as context manager."""
        server_available, port = nats_server
        if not server_available:
            pytest.skip("NATS server not available")
        nats_url = f"nats://127.0.0.1:{port}"
        try:
            import asyncio
            import nats  # type: ignore[import-untyped]
            from opengait_studio.output import NatsPublisher, create_result
        except ImportError as e:
            pytest.skip(f"Required module not available: {e}")
        received_messages: list[dict[str, object]] = []
        async def subscribe_and_test():
            # Subscribe first, then publish inside the context manager so the
            # with-block exit also exercises publisher cleanup.
            nc = await nats.connect(nats_url)  # pyright: ignore[reportUnknownMemberType]
            sub = await nc.subscribe(NATS_SUBJECT)  # pyright: ignore[reportUnknownMemberType]
            # Use context manager
            with NatsPublisher(nats_url, subject=NATS_SUBJECT) as publisher:
                result = create_result(
                    frame=200,
                    track_id=5,
                    label="neutral",
                    confidence=0.65,
                    window=(170, 200),
                    timestamp_ns=9999999999,
                )
                publisher.publish(result)
            # Wait for message
            try:
                msg = await asyncio.wait_for(sub.next_msg(), timeout=5.0)
                data = json.loads(msg.data.decode("utf-8"))  # pyright: ignore[reportAny]
                received_messages.append(cast(dict[str, object], data))
            except asyncio.TimeoutError:
                pass
            await sub.unsubscribe()
            await nc.close()
        asyncio.run(subscribe_and_test())
        assert len(received_messages) == 1
        assert received_messages[0]["frame"] == 200
        assert received_messages[0]["track_id"] == 5
class TestNatsSchemaValidation:
    """Schema-validation unit tests that run without any NATS server."""
    @staticmethod
    def _payload(**overrides: object) -> dict[str, object]:
        """Return a schema-valid payload with selected fields overridden."""
        base: dict[str, object] = {
            "frame": 1234,
            "track_id": 42,
            "label": "positive",
            "confidence": 0.85,
            "window": 1230,
            "timestamp_ns": 1234567890000,
        }
        base.update(overrides)
        return base
    def test_validate_result_schema_valid(self) -> None:
        """A fully populated, well-typed payload passes validation."""
        is_valid, error = _validate_result_schema(self._payload())
        assert is_valid, f"Valid data rejected: {error}"
    def test_validate_result_schema_invalid_label(self) -> None:
        """An unknown label value is rejected with a label-specific error."""
        is_valid, error = _validate_result_schema(
            self._payload(label="invalid_label")
        )
        assert not is_valid
        assert "label" in error.lower()
    def test_validate_result_schema_confidence_out_of_range(self) -> None:
        """A confidence above 1.0 is rejected with a confidence error."""
        is_valid, error = _validate_result_schema(self._payload(confidence=1.5))
        assert not is_valid
        assert "confidence" in error.lower()
    def test_validate_result_schema_missing_fields(self) -> None:
        """A payload lacking required keys reports them as missing."""
        partial: dict[str, object] = {
            "frame": 1234,
            "label": "positive",
        }
        is_valid, error = _validate_result_schema(partial)
        assert not is_valid
        assert "missing" in error.lower()
    def test_validate_result_schema_wrong_types(self) -> None:
        """A string frame index is rejected with a frame-specific error."""
        is_valid, error = _validate_result_schema(
            self._payload(frame="not_an_int")
        )
        assert not is_valid
        assert "frame" in error.lower()
    def test_all_valid_labels_accepted(self) -> None:
        """Every supported label value validates successfully."""
        for label_str in ["negative", "neutral", "positive"]:
            candidate = self._payload(
                frame=100,
                track_id=1,
                label=label_str,
                confidence=0.5,
                window=100,
                timestamp_ns=1234567890,
            )
            is_valid, error = _validate_result_schema(candidate)
            assert is_valid, f"Valid label '{label_str}' rejected: {error}"
class TestDockerAvailability:
    """Sanity checks for the Docker environment probe helpers."""
    def test_docker_available_check(self) -> None:
        """The Docker probe returns a bool and never raises."""
        assert isinstance(_docker_available(), bool)
    def test_nats_container_running_check(self) -> None:
        """The container probe returns a bool even when Docker is absent."""
        assert isinstance(_nats_container_running(), bool)
+955
View File
@@ -0,0 +1,955 @@
from __future__ import annotations
import importlib.util
import json
import pickle
from pathlib import Path
import subprocess
import sys
import time
from typing import Final, Literal, cast
from unittest import mock
import numpy as np
from numpy.typing import NDArray
import pytest
import torch
from opengait_studio.sconet_demo import ScoNetDemo
# Repository root, resolved two directory levels above this test module.
REPO_ROOT: Final[Path] = Path(__file__).resolve().parents[2]
# On-disk assets required by the integration tests; tests skip when absent.
SAMPLE_VIDEO_PATH: Final[Path] = REPO_ROOT / "assets" / "sample.mp4"
# NOTE(review): CHECKPOINT_PATH is unreferenced in this module's tests — confirm before removing.
CHECKPOINT_PATH: Final[Path] = REPO_ROOT / "ckpt" / "ScoNet-20000.pt"
CONFIG_PATH: Final[Path] = REPO_ROOT / "configs" / "sconet" / "sconet_scoliosis1k.yaml"
YOLO_MODEL_PATH: Final[Path] = REPO_ROOT / "ckpt" / "yolo11n-seg.pt"
def _device_for_runtime() -> str:
return "cuda:0" if torch.cuda.is_available() else "cpu"
def _run_pipeline_cli(
    *args: str, timeout_seconds: int = 120
) -> subprocess.CompletedProcess[str]:
    """Invoke ``python -m opengait_studio`` with *args* and capture output.

    Runs from the repository root and never raises on a non-zero exit code;
    callers inspect ``returncode``/``stdout``/``stderr`` themselves.
    """
    invocation = [sys.executable, "-m", "opengait_studio", *args]
    return subprocess.run(
        invocation,
        cwd=REPO_ROOT,
        capture_output=True,
        text=True,
        check=False,
        timeout=timeout_seconds,
    )
def _require_integration_assets() -> None:
    """Skip the calling test when any required on-disk asset is missing."""
    for asset, description in (
        (SAMPLE_VIDEO_PATH, "Missing sample video"),
        (CONFIG_PATH, "Missing config"),
        (YOLO_MODEL_PATH, "Missing YOLO model file"),
    ):
        if not asset.is_file():
            pytest.skip(f"{description}: {asset}")
@pytest.fixture
def compatible_checkpoint_path(tmp_path: Path) -> Path:
    """Write a freshly initialized ScoNet state dict and return its path.

    Skips when the model config is missing. The saved weights match the
    architecture the CLI loads, so runs exercise the full pipeline.
    """
    if not CONFIG_PATH.is_file():
        pytest.skip(f"Missing config: {CONFIG_PATH}")
    target = tmp_path / "sconet-compatible.pt"
    demo = ScoNetDemo(cfg_path=str(CONFIG_PATH), checkpoint_path=None, device="cpu")
    torch.save(demo.state_dict(), target)
    return target
def _extract_prediction_json_lines(stdout: str) -> list[dict[str, object]]:
required_keys = {
"frame",
"track_id",
"label",
"confidence",
"window",
"timestamp_ns",
}
predictions: list[dict[str, object]] = []
for line in stdout.splitlines():
stripped = line.strip()
if not stripped:
continue
try:
payload_obj = cast(object, json.loads(stripped))
except json.JSONDecodeError:
continue
if not isinstance(payload_obj, dict):
continue
payload = cast(dict[str, object], payload_obj)
if required_keys.issubset(payload.keys()):
predictions.append(payload)
return predictions
def _assert_prediction_schema(prediction: dict[str, object]) -> None:
assert isinstance(prediction["frame"], int)
assert isinstance(prediction["track_id"], int)
label = prediction["label"]
assert isinstance(label, str)
assert label in {"negative", "neutral", "positive"}
confidence = prediction["confidence"]
assert isinstance(confidence, (int, float))
confidence_value = float(confidence)
assert 0.0 <= confidence_value <= 1.0
window_obj = prediction["window"]
assert isinstance(window_obj, int)
assert window_obj >= 0
assert isinstance(prediction["timestamp_ns"], int)
def test_pipeline_cli_fps_benchmark_smoke(
    compatible_checkpoint_path: Path,
) -> None:
    """Run the CLI end-to-end and enforce a very conservative FPS floor.

    Skips (rather than fails) when too few distinct frames were observed or
    the measured wall time is degenerate, so slow CI hosts stay green.
    """
    _require_integration_assets()
    max_frames = 90
    started_at = time.perf_counter()
    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "5",
        "--stride",
        "1",
        "--max-frames",
        str(max_frames),
        timeout_seconds=180,
    )
    elapsed_seconds = time.perf_counter() - started_at
    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )
    predictions = _extract_prediction_json_lines(result.stdout)
    assert predictions, "Expected prediction output for FPS benchmark run"
    for prediction in predictions:
        _assert_prediction_schema(prediction)
    # Distinct frame indices are the unit of observed work for the FPS figure.
    observed_frames = {
        frame_obj
        for prediction in predictions
        for frame_obj in [prediction["frame"]]
        if isinstance(frame_obj, int)
    }
    observed_units = len(observed_frames)
    if observed_units < 5:
        pytest.skip(
            "Insufficient observed frame samples for stable FPS benchmark in this environment"
        )
    if elapsed_seconds <= 0:
        pytest.skip("Non-positive elapsed time; cannot compute FPS benchmark")
    fps = observed_units / elapsed_seconds
    # Deliberately tiny threshold: this is a smoke test, not a benchmark.
    min_expected_fps = 0.2
    assert fps >= min_expected_fps, (
        "Observed FPS below conservative CI threshold: "
        f"{fps:.3f} < {min_expected_fps:.3f} "
        f"(observed_units={observed_units}, elapsed_seconds={elapsed_seconds:.3f})"
    )
def test_pipeline_cli_happy_path_outputs_json_predictions(
    compatible_checkpoint_path: Path,
) -> None:
    """Happy path: CLI exits 0 and prints schema-valid JSON prediction lines.

    Also checks that the "Connected to NATS" log line is absent on stderr,
    since no --nats-url was supplied.
    """
    _require_integration_assets()
    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "10",
        "--stride",
        "10",
        "--max-frames",
        "120",
        timeout_seconds=180,
    )
    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )
    predictions = _extract_prediction_json_lines(result.stdout)
    assert predictions, (
        "Expected at least one prediction JSON line in stdout. "
        f"stdout:\n{result.stdout}\nstderr:\n{result.stderr}"
    )
    for prediction in predictions:
        _assert_prediction_schema(prediction)
    # No NATS URL was passed, so the connection log line must not appear.
    assert "Connected to NATS" not in result.stderr
def test_pipeline_cli_max_frames_caps_output_frames(
    compatible_checkpoint_path: Path,
) -> None:
    """--max-frames caps processing: every emitted frame index stays below it."""
    _require_integration_assets()
    max_frames = 20
    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "5",
        "--stride",
        "1",
        "--max-frames",
        str(max_frames),
        timeout_seconds=180,
    )
    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )
    predictions = _extract_prediction_json_lines(result.stdout)
    assert predictions, "Expected prediction output with --max-frames run"
    for prediction in predictions:
        _assert_prediction_schema(prediction)
        frame_idx_obj = prediction["frame"]
        assert isinstance(frame_idx_obj, int)
        # All emitted frame indices must stay strictly below the cap.
        assert frame_idx_obj < max_frames
def test_pipeline_cli_invalid_source_path_returns_user_error() -> None:
    """A nonexistent video source must exit with code 2 and a clear message."""
    completed = _run_pipeline_cli(
        "--source",
        "/definitely/not/a/real/video.mp4",
        "--checkpoint",
        "/tmp/unused-checkpoint.pt",
        "--config",
        str(CONFIG_PATH),
        timeout_seconds=30,
    )
    assert completed.returncode == 2
    assert "Error: Video source not found" in completed.stderr
def test_pipeline_cli_invalid_checkpoint_path_returns_user_error() -> None:
    """A nonexistent checkpoint must exit with code 2 and a clear message."""
    if not SAMPLE_VIDEO_PATH.is_file():
        pytest.skip(f"Missing sample video: {SAMPLE_VIDEO_PATH}")
    if not CONFIG_PATH.is_file():
        pytest.skip(f"Missing config: {CONFIG_PATH}")
    missing_checkpoint = REPO_ROOT / "ckpt" / "missing-checkpoint.pt"
    completed = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(missing_checkpoint),
        "--config",
        str(CONFIG_PATH),
        timeout_seconds=30,
    )
    assert completed.returncode == 2
    assert "Error: Checkpoint not found" in completed.stderr
def test_pipeline_cli_preprocess_only_requires_export_path(
    compatible_checkpoint_path: Path,
) -> None:
    """Test that --preprocess-only requires --silhouette-export-path.

    Omitting the export path must be treated as a usage error: exit code 2
    with the missing option named on stderr.
    """
    _require_integration_assets()
    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--preprocess-only",
        "--max-frames",
        "10",
        timeout_seconds=30,
    )
    assert result.returncode == 2
    assert "--silhouette-export-path is required" in result.stderr
def test_pipeline_cli_preprocess_only_exports_pickle(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test preprocess-only mode exports silhouettes to pickle.

    Verifies the export file exists and that every exported record carries
    the expected keys with integer frame/track/timestamp values.
    """
    _require_integration_assets()
    export_path = tmp_path / "silhouettes.pkl"
    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--preprocess-only",
        "--silhouette-export-path",
        str(export_path),
        "--silhouette-export-format",
        "pickle",
        "--max-frames",
        "30",
        timeout_seconds=180,
    )
    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )
    # Verify export file exists and contains silhouettes
    assert export_path.is_file(), f"Export file not found: {export_path}"
    # Trusted file produced by our own CLI run; pickle is acceptable here.
    with open(export_path, "rb") as f:
        silhouettes = pickle.load(f)
    assert isinstance(silhouettes, list)
    assert len(silhouettes) > 0, "Expected at least one silhouette"
    # Verify silhouette schema
    for item in silhouettes:
        assert isinstance(item, dict)
        assert "frame" in item
        assert "track_id" in item
        assert "timestamp_ns" in item
        assert "silhouette" in item
        assert isinstance(item["frame"], int)
        assert isinstance(item["track_id"], int)
        assert isinstance(item["timestamp_ns"], int)
def test_pipeline_cli_result_export_json(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test that results can be exported to JSON file.

    The export is JSON Lines: one prediction object per non-blank line,
    each matching the published schema.
    """
    _require_integration_assets()
    export_path = tmp_path / "results.jsonl"
    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "10",
        "--stride",
        "10",
        "--result-export-path",
        str(export_path),
        "--result-export-format",
        "json",
        "--max-frames",
        "60",
        timeout_seconds=180,
    )
    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )
    # Verify export file exists
    assert export_path.is_file(), f"Export file not found: {export_path}"
    # Read and verify JSON lines
    predictions: list[dict[str, object]] = []
    with open(export_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                predictions.append(cast(dict[str, object], json.loads(line)))
    assert len(predictions) > 0, "Expected at least one prediction in export"
    for prediction in predictions:
        _assert_prediction_schema(prediction)
def test_pipeline_cli_result_export_pickle(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test that results can be exported to pickle file.

    Mirrors the JSON-export test but reads the export back with pickle and
    validates each prediction against the schema.
    """
    _require_integration_assets()
    export_path = tmp_path / "results.pkl"
    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "10",
        "--stride",
        "10",
        "--result-export-path",
        str(export_path),
        "--result-export-format",
        "pickle",
        "--max-frames",
        "60",
        timeout_seconds=180,
    )
    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )
    # Verify export file exists
    assert export_path.is_file(), f"Export file not found: {export_path}"
    # Read and verify pickle (trusted file written by our own CLI run)
    with open(export_path, "rb") as f:
        predictions = pickle.load(f)
    assert isinstance(predictions, list)
    assert len(predictions) > 0, "Expected at least one prediction in export"
    for prediction in predictions:
        _assert_prediction_schema(prediction)
def test_pipeline_cli_silhouette_and_result_export(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test exporting both silhouettes and results simultaneously."""
    _require_integration_assets()
    silhouette_export = tmp_path / "silhouettes.pkl"
    result_export = tmp_path / "results.jsonl"
    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "10",
        "--stride",
        "10",
        "--silhouette-export-path",
        str(silhouette_export),
        "--silhouette-export-format",
        "pickle",
        "--result-export-path",
        str(result_export),
        "--result-export-format",
        "json",
        "--max-frames",
        "60",
        timeout_seconds=180,
    )
    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )
    # Verify both export files exist
    assert silhouette_export.is_file(), (
        f"Silhouette export not found: {silhouette_export}"
    )
    assert result_export.is_file(), f"Result export not found: {result_export}"
    # Verify silhouette export (trusted file written by our own CLI run)
    with open(silhouette_export, "rb") as f:
        silhouettes = pickle.load(f)
    assert isinstance(silhouettes, list)
    assert len(silhouettes) > 0
    # Verify result export (JSON Lines: one object per non-blank line)
    with open(result_export, "r", encoding="utf-8") as f:
        predictions = [
            cast(dict[str, object], json.loads(line)) for line in f if line.strip()
        ]
    assert len(predictions) > 0
def test_pipeline_cli_parquet_export_requires_pyarrow(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test that parquet export fails gracefully when pyarrow is not available.

    Skips when pyarrow is installed (the failure path cannot then be
    exercised); otherwise expects exit code 1 with a parquet/pyarrow hint
    on stderr.
    """
    _require_integration_assets()
    # Skip if pyarrow is actually installed. find_spec answers this without
    # importing the package; the previous redundant `import pyarrow` retry
    # was dead code after this guard and has been removed.
    if importlib.util.find_spec("pyarrow") is not None:
        pytest.skip("pyarrow is installed, skipping missing dependency test")
    export_path = tmp_path / "results.parquet"
    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "10",
        "--stride",
        "10",
        "--result-export-path",
        str(export_path),
        "--result-export-format",
        "parquet",
        "--max-frames",
        "30",
        timeout_seconds=180,
    )
    # Should fail with RuntimeError about pyarrow
    assert result.returncode == 1
    assert "parquet" in result.stderr.lower() or "pyarrow" in result.stderr.lower()
def test_pipeline_cli_silhouette_visualization(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test that silhouette visualization creates PNG files.

    Checks the output directory exists, contains at least one PNG, and that
    filenames encode frame and track identifiers.
    """
    _require_integration_assets()
    visualize_dir = tmp_path / "silhouette_viz"
    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "10",
        "--stride",
        "10",
        "--silhouette-visualize-dir",
        str(visualize_dir),
        "--max-frames",
        "30",
        timeout_seconds=180,
    )
    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )
    # Verify visualization directory exists and contains PNG files
    assert visualize_dir.is_dir(), f"Visualization directory not found: {visualize_dir}"
    png_files = list(visualize_dir.glob("*.png"))
    assert len(png_files) > 0, "Expected at least one PNG visualization file"
    # Verify filenames contain frame and track info
    for png_file in png_files:
        assert "silhouette_frame" in png_file.name
        assert "_track" in png_file.name
def test_pipeline_cli_preprocess_only_with_visualization(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test preprocess-only mode with both export and visualization.

    Also asserts a one-to-one correspondence between exported silhouette
    records and rendered PNG files.
    """
    _require_integration_assets()
    export_path = tmp_path / "silhouettes.pkl"
    visualize_dir = tmp_path / "silhouette_viz"
    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--preprocess-only",
        "--silhouette-export-path",
        str(export_path),
        "--silhouette-visualize-dir",
        str(visualize_dir),
        "--max-frames",
        "30",
        timeout_seconds=180,
    )
    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )
    # Verify export file exists
    assert export_path.is_file(), f"Export file not found: {export_path}"
    # Verify visualization files exist
    assert visualize_dir.is_dir(), f"Visualization directory not found: {visualize_dir}"
    png_files = list(visualize_dir.glob("*.png"))
    assert len(png_files) > 0, "Expected at least one PNG visualization file"
    # Load and verify pickle export (trusted file written by our own CLI run)
    with open(export_path, "rb") as f:
        silhouettes = pickle.load(f)
    assert isinstance(silhouettes, list)
    assert len(silhouettes) > 0
    # Number of exported silhouettes should match number of PNG files
    assert len(silhouettes) == len(png_files), (
        f"Mismatch: {len(silhouettes)} silhouettes exported but {len(png_files)} PNG files created"
    )
class MockVisualizer:
"""Mock visualizer to track update calls."""
def __init__(self) -> None:
self.update_calls: list[dict[str, object]] = []
self.return_value: bool = True
def update(
self,
frame: NDArray[np.uint8],
bbox: tuple[int, int, int, int] | None,
bbox_mask: tuple[int, int, int, int] | None,
track_id: int,
mask_raw: NDArray[np.uint8] | None,
silhouette: NDArray[np.float32] | None,
segmentation_input: NDArray[np.float32] | None,
label: str | None,
confidence: float | None,
fps: float,
) -> bool:
self.update_calls.append(
{
"frame": frame,
"bbox": bbox,
"bbox_mask": bbox_mask,
"track_id": track_id,
"mask_raw": mask_raw,
"silhouette": silhouette,
"segmentation_input": segmentation_input,
"label": label,
"confidence": confidence,
"fps": fps,
}
)
return self.return_value
def close(self) -> None:
pass
def test_pipeline_visualizer_updates_on_no_detection() -> None:
    """Test that visualizer is still updated even when process_frame returns None.
    This is a regression test for the window freeze issue when no person is detected.
    The window should refresh every frame to prevent freezing.
    """
    from opengait_studio.pipeline import ScoliosisPipeline
    # Create a minimal pipeline with mocked dependencies
    with (
        mock.patch("opengait_studio.pipeline.YOLO") as mock_yolo,
        mock.patch("opengait_studio.pipeline.create_source") as mock_source,
        mock.patch("opengait_studio.pipeline.create_publisher") as mock_publisher,
        mock.patch("opengait_studio.pipeline.ScoNetDemo") as mock_classifier,
    ):
        # Setup mock detector that returns no detections (causing process_frame to return None)
        mock_detector = mock.MagicMock()
        mock_detector.track.return_value = []  # No detections on every frame
        mock_yolo.return_value = mock_detector
        # Setup mock source with 3 frames (frame, metadata) pairs
        mock_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        mock_source.return_value = [(mock_frame, {"frame_count": i}) for i in range(3)]
        # Setup mock publisher and classifier
        mock_publisher.return_value = mock.MagicMock()
        mock_classifier.return_value = mock.MagicMock()
        # Create pipeline with visualize enabled
        pipeline = ScoliosisPipeline(
            source="dummy.mp4",
            checkpoint="dummy.pt",
            config=str(CONFIG_PATH) if CONFIG_PATH.exists() else "dummy.yaml",
            device="cpu",
            yolo_model="dummy.pt",
            window=30,
            stride=30,
            nats_url=None,
            nats_subject="test",
            max_frames=None,
            visualize=True,
        )
        # Inject a recording stub through the private attribute the pipeline
        # uses internally, leaving its public construction path untouched.
        mock_viz = MockVisualizer()
        setattr(pipeline, "_visualizer", mock_viz)
        # Run pipeline (return value not needed here)
        _ = pipeline.run()
        # Verify visualizer was updated for all 3 frames even with no detections
        assert len(mock_viz.update_calls) == 3, (
            f"Expected visualizer.update() to be called 3 times (once per frame), "
            f"but was called {len(mock_viz.update_calls)} times. "
            f"Window would freeze if not updated on no-detection frames."
        )
        # Verify each call had the frame data with all overlay fields cleared
        for call in mock_viz.update_calls:
            assert call["track_id"] == 0  # Default track_id when no detection
            assert call["bbox"] is None  # No bbox when no detection
            assert call["bbox_mask"] is None
            assert call["mask_raw"] is None  # No mask when no detection
            assert call["silhouette"] is None  # No silhouette when no detection
            assert call["segmentation_input"] is None
            assert call["label"] is None  # No label when no detection
            assert call["confidence"] is None  # No confidence when no detection
def test_pipeline_visualizer_clears_bbox_on_no_detection() -> None:
    """Overlay state is cleared, but masks stay cached, when detection drops out.

    Frames 0-1 carry a valid tracked detection; frames 2-3 have none. Every
    frame must still reach the visualizer: on no-detection frames the bbox,
    label and confidence are cleared, while the mask/segmentation views show
    a cached copy from the last valid detection so the windows do not blank.
    """
    from opengait_studio.pipeline import ScoliosisPipeline
    # Create a minimal pipeline with mocked dependencies
    with (
        mock.patch("opengait_studio.pipeline.YOLO") as mock_yolo,
        mock.patch("opengait_studio.pipeline.create_source") as mock_source,
        mock.patch("opengait_studio.pipeline.create_publisher") as mock_publisher,
        mock.patch("opengait_studio.pipeline.ScoNetDemo") as mock_classifier,
        mock.patch("opengait_studio.pipeline.select_person") as mock_select_person,
        mock.patch("opengait_studio.pipeline.mask_to_silhouette") as mock_mask_to_sil,
    ):
        # Create mock detection result for frames 0-1 (valid detection)
        mock_box = mock.MagicMock()
        mock_box.xyxy = np.array([[100, 100, 200, 300]], dtype=np.float32)
        mock_box.id = np.array([1], dtype=np.int64)
        mock_mask = mock.MagicMock()
        mock_mask.data = np.random.rand(1, 480, 640).astype(np.float32)
        mock_result = mock.MagicMock()
        mock_result.boxes = mock_box
        mock_result.masks = mock_mask
        # Setup mock detector: detection for frames 0-1, then no detection for frames 2-3
        mock_detector = mock.MagicMock()
        mock_detector.track.side_effect = [
            [mock_result],  # Frame 0: valid detection
            [mock_result],  # Frame 1: valid detection
            [],  # Frame 2: no detection
            [],  # Frame 3: no detection
        ]
        mock_yolo.return_value = mock_detector
        # Setup mock source with 4 frames
        mock_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        mock_source.return_value = [(mock_frame, {"frame_count": i}) for i in range(4)]
        # Setup mock publisher and classifier
        mock_publisher.return_value = mock.MagicMock()
        mock_classifier.return_value = mock.MagicMock()
        # Setup mock select_person to return valid mask and bbox
        dummy_mask = np.random.randint(0, 256, (480, 640), dtype=np.uint8)
        dummy_bbox_mask = (100, 100, 200, 300)
        dummy_bbox_frame = (100, 100, 200, 300)
        mock_select_person.return_value = (
            dummy_mask,
            dummy_bbox_mask,
            dummy_bbox_frame,
            1,
        )
        # Setup mock mask_to_silhouette to return valid silhouette
        dummy_silhouette = np.random.rand(64, 44).astype(np.float32)
        mock_mask_to_sil.return_value = dummy_silhouette
        # Create pipeline with visualize enabled
        pipeline = ScoliosisPipeline(
            source="dummy.mp4",
            checkpoint="dummy.pt",
            config=str(CONFIG_PATH) if CONFIG_PATH.exists() else "dummy.yaml",
            device="cpu",
            yolo_model="dummy.pt",
            window=30,
            stride=30,
            nats_url=None,
            nats_subject="test",
            max_frames=None,
            visualize=True,
        )
        # Inject a recording stub in place of the real visualizer.
        mock_viz = MockVisualizer()
        setattr(pipeline, "_visualizer", mock_viz)
        # Run pipeline
        _ = pipeline.run()
        # Verify visualizer was updated for all 4 frames
        assert len(mock_viz.update_calls) == 4, (
            f"Expected visualizer.update() to be called 4 times, "
            f"but was called {len(mock_viz.update_calls)} times."
        )
        # Extract the mask_raw values from each call
        mask_raw_calls = [call["mask_raw"] for call in mock_viz.update_calls]
        # Frames 0 and 1 should have valid masks (not None)
        assert mask_raw_calls[0] is not None, "Frame 0 should have valid mask"
        assert mask_raw_calls[1] is not None, "Frame 1 should have valid mask"
        # Frames 2-3: no detection, yet the cached mask must still be shown.
        assert mask_raw_calls[2] is not None, (
            "Frame 2 (no detection) should display cached mask from last valid detection, "
            "not None/blank"
        )
        assert mask_raw_calls[3] is not None, (
            "Frame 3 (no detection) should display cached mask from last valid detection, "
            "not None/blank"
        )
        # Segmentation input views persist via the same cache on frames 2-3.
        segmentation_inputs = [
            call["segmentation_input"] for call in mock_viz.update_calls
        ]
        bbox_mask_calls = [call["bbox_mask"] for call in mock_viz.update_calls]
        assert segmentation_inputs[0] is not None
        assert segmentation_inputs[1] is not None
        assert segmentation_inputs[2] is not None
        assert segmentation_inputs[3] is not None
        # Bounding boxes, by contrast, must be cleared on no-detection frames.
        bbox_calls = [call["bbox"] for call in mock_viz.update_calls]
        assert bbox_calls[0] == dummy_bbox_frame
        assert bbox_calls[1] == dummy_bbox_frame
        assert bbox_calls[2] is None
        assert bbox_calls[3] is None
        assert bbox_mask_calls[0] == dummy_bbox_mask
        assert bbox_mask_calls[1] == dummy_bbox_mask
        assert bbox_mask_calls[2] is None
        assert bbox_mask_calls[3] is None
        # Classification overlay (label/confidence) is cleared as well.
        label_calls = [call["label"] for call in mock_viz.update_calls]
        confidence_calls = [call["confidence"] for call in mock_viz.update_calls]
        assert label_calls[2] is None
        assert label_calls[3] is None
        assert confidence_calls[2] is None
        assert confidence_calls[3] is None
        # The cache must hold a defensive copy, not the live array object.
        if mask_raw_calls[1] is not None and mask_raw_calls[2] is not None:
            assert mask_raw_calls[1] is not mask_raw_calls[2], (
                "Cached mask should be a copy, not the same object reference"
            )
def test_frame_pacer_emission_count_24_to_15() -> None:
    """Feeding 100 frames at 24 fps through a 15 fps pacer emits roughly 62."""
    from opengait_studio.pipeline import _FramePacer

    pacer = _FramePacer(15.0)
    # Nanosecond timestamps spaced at a 24 fps cadence.
    frame_gap_ns = int(1_000_000_000 / 24)
    emitted = 0
    for frame_idx in range(100):
        if pacer.should_emit(frame_idx * frame_gap_ns):
            emitted += 1
    # 100 * (15 / 24) = 62.5 ideal emissions; allow slack for rounding.
    assert 60 <= emitted <= 65
def test_frame_pacer_requires_positive_target_fps() -> None:
    """Constructing a pacer with a non-positive target fps is rejected."""
    from opengait_studio.pipeline import _FramePacer

    with pytest.raises(ValueError, match="target_fps must be positive"):
        _ = _FramePacer(0.0)
@pytest.mark.parametrize(
    ("window", "stride", "mode", "expected"),
    [
        # "manual" passes the configured stride straight through.
        (30, 30, "manual", 30),
        (30, 7, "manual", 7),
        # "sliding" classifies on every frame regardless of stride.
        (30, 30, "sliding", 1),
        # "chunked" advances by one full window at a time.
        (30, 1, "chunked", 30),
        (15, 3, "chunked", 15),
    ],
)
def test_resolve_stride_modes(
    window: int,
    stride: int,
    mode: Literal["manual", "sliding", "chunked"],
    expected: int,
) -> None:
    """resolve_stride maps (window, stride, mode) to the effective stride."""
    from opengait_studio.pipeline import resolve_stride

    actual = resolve_stride(window, stride, mode)
    assert actual == expected
+181
View File
@@ -0,0 +1,181 @@
"""Unit tests for silhouette preprocessing functions."""
from typing import cast
import numpy as np
from numpy.typing import NDArray
import pytest
from beartype.roar import BeartypeCallHintParamViolation
from jaxtyping import TypeCheckError
from opengait_studio.preprocess import mask_to_silhouette
class TestMaskToSilhouette:
    """Tests for mask_to_silhouette() function."""
    def test_valid_mask_returns_correct_shape_dtype_and_range(self) -> None:
        """Valid mask should return (64, 44) float32 array in [0, 1] range."""
        # Create a synthetic mask with sufficient area (person-shaped blob)
        h, w = 200, 150
        mask = np.zeros((h, w), dtype=np.uint8)
        # Draw a filled ellipse to simulate a person
        center_y, center_x = h // 2, w // 2
        axes_y, axes_x = h // 3, w // 4
        y, x = np.ogrid[:h, :w]
        ellipse_mask = ((x - center_x) / axes_x) ** 2 + (
            (y - center_y) / axes_y
        ) ** 2 <= 1
        mask[ellipse_mask] = 255
        # bbox chosen to comfortably enclose the ellipse
        bbox = (w // 4, h // 6, 3 * w // 4, 5 * h // 6)
        result = mask_to_silhouette(mask, bbox)
        assert result is not None
        result_arr = cast(NDArray[np.float32], result)
        assert result_arr.shape == (64, 44)
        assert result_arr.dtype == np.float32
        assert np.all(result_arr >= 0.0) and np.all(result_arr <= 1.0)
    def test_tiny_mask_returns_none(self) -> None:
        """Mask with area below MIN_MASK_AREA should return None."""
        # Create a tiny mask
        mask = np.zeros((50, 50), dtype=np.uint8)
        mask[20:22, 20:22] = 255  # Only 4 pixels, well below MIN_MASK_AREA (500)
        bbox = (20, 20, 22, 22)
        result = mask_to_silhouette(mask, bbox)
        assert result is None
    def test_empty_mask_returns_none(self) -> None:
        """Completely empty mask should return None."""
        mask = np.zeros((100, 100), dtype=np.uint8)
        bbox = (10, 10, 90, 90)
        result = mask_to_silhouette(mask, bbox)
        assert result is None
    def test_full_frame_mask_returns_valid_output(self) -> None:
        """Full-frame mask (large bbox covering entire image) should work."""
        h, w = 300, 200
        mask = np.zeros((h, w), dtype=np.uint8)
        # Create a large filled region
        mask[50:250, 30:170] = 255
        # Full frame bbox
        bbox = (0, 0, w, h)
        result = mask_to_silhouette(mask, bbox)
        assert result is not None
        result_arr = cast(NDArray[np.float32], result)
        assert result_arr.shape == (64, 44)
        assert result_arr.dtype == np.float32
    def test_determinism_same_input_same_output(self) -> None:
        """Same input should always produce same output."""
        h, w = 200, 150
        mask = np.zeros((h, w), dtype=np.uint8)
        # Create a person-shaped region
        mask[50:150, 40:110] = 255
        bbox = (40, 50, 110, 150)
        # Three identical calls must yield bitwise-identical arrays
        result1 = mask_to_silhouette(mask, bbox)
        result2 = mask_to_silhouette(mask, bbox)
        result3 = mask_to_silhouette(mask, bbox)
        assert result1 is not None
        assert result2 is not None
        assert result3 is not None
        np.testing.assert_array_equal(result1, result2)
        np.testing.assert_array_equal(result2, result3)
    def test_hole_inside_mask_is_filled(self) -> None:
        """An interior hole in the blob is filled in the output silhouette.

        A 40x20 rectangle is cut out of the middle of a solid block; the
        (26:38, 18:26) patch targets roughly where that cut-out maps after
        resizing to 64x44 — a mean near 1.0 shows the hole was filled.
        """
        h, w = 200, 160
        mask = np.zeros((h, w), dtype=np.uint8)
        mask[30:170, 40:120] = 255
        mask[80:120, 70:90] = 0
        bbox = (40, 30, 120, 170)
        result = mask_to_silhouette(mask, bbox)
        assert result is not None
        result_arr = cast(NDArray[np.float32], result)
        hole_patch = result_arr[26:38, 18:26]
        assert float(np.mean(hole_patch)) > 0.8
    def test_hole_fill_works_when_mask_touches_corner(self) -> None:
        """Hole filling still works when the blob touches the image corner.

        Flood-fill-style hole filling can fail when the foreground touches
        the border; this places the blob at (0, 0) with an interior hole and
        checks the mapped patch is still mostly filled.
        """
        h, w = 220, 180
        mask = np.zeros((h, w), dtype=np.uint8)
        mask[0:180, 0:130] = 255
        mask[70:120, 55:95] = 0
        bbox = (0, 0, 130, 180)
        result = mask_to_silhouette(mask, bbox)
        assert result is not None
        result_arr = cast(NDArray[np.float32], result)
        hole_patch = result_arr[24:40, 16:28]
        assert float(np.mean(hole_patch)) > 0.75
    def test_tall_narrow_mask_valid_output(self) -> None:
        """Tall narrow mask should produce valid silhouette."""
        h, w = 400, 50
        mask = np.zeros((h, w), dtype=np.uint8)
        # Tall narrow person
        mask[50:350, 10:40] = 255
        bbox = (10, 50, 40, 350)
        result = mask_to_silhouette(mask, bbox)
        assert result is not None
        result_arr = cast(NDArray[np.float32], result)
        assert result_arr.shape == (64, 44)
    def test_wide_short_mask_valid_output(self) -> None:
        """Wide short mask should produce valid silhouette."""
        h, w = 100, 400
        mask = np.zeros((h, w), dtype=np.uint8)
        # Wide short person
        mask[20:80, 50:350] = 255
        bbox = (50, 20, 350, 80)
        result = mask_to_silhouette(mask, bbox)
        assert result is not None
        result_arr = cast(NDArray[np.float32], result)
        assert result_arr.shape == (64, 44)
    def test_beartype_rejects_wrong_dtype(self) -> None:
        """Beartype should reject non-uint8 input."""
        # Float array instead of uint8
        mask = np.ones((100, 100), dtype=np.float32) * 255
        bbox = (10, 10, 90, 90)
        with pytest.raises((BeartypeCallHintParamViolation, TypeCheckError)):
            _ = mask_to_silhouette(mask, bbox)
    def test_beartype_rejects_wrong_ndim(self) -> None:
        """Beartype should reject non-2D array."""
        # 3D array instead of 2D
        mask = np.ones((100, 100, 3), dtype=np.uint8) * 255
        bbox = (10, 10, 90, 90)
        with pytest.raises((BeartypeCallHintParamViolation, TypeCheckError)):
            _ = mask_to_silhouette(mask, bbox)
    def test_beartype_rejects_wrong_bbox_type(self) -> None:
        """Beartype should reject non-tuple bbox."""
        mask = np.ones((100, 100), dtype=np.uint8) * 255
        # List instead of tuple (deliberate type violation)
        bbox = [10, 10, 90, 90]
        with pytest.raises((BeartypeCallHintParamViolation, TypeCheckError)):
            _ = mask_to_silhouette(mask, bbox)
+341
View File
@@ -0,0 +1,341 @@
"""Unit tests for ScoNetDemo forward pass.
Tests cover:
- Construction from config/checkpoint path handling
- Forward output shape (N, 3, 16) and dtype float
- Predict output (label_str, confidence_float) with valid label/range
- No-DDP leakage check (no torch.distributed calls in unit behavior)
"""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, cast
from unittest.mock import patch
import pytest
import torch
from torch import Tensor
if TYPE_CHECKING:
from opengait_studio.sconet_demo import ScoNetDemo
# Constants for test configuration
CONFIG_PATH = Path("configs/sconet/sconet_scoliosis1k.yaml")
@pytest.fixture
def demo() -> "ScoNetDemo":
    """Build a CPU-only ScoNetDemo from the repo config, without a checkpoint."""
    from opengait_studio.sconet_demo import ScoNetDemo

    return ScoNetDemo(cfg_path=str(CONFIG_PATH), checkpoint_path=None, device="cpu")
@pytest.fixture
def dummy_sils_batch() -> Tensor:
    """Random silhouette batch shaped (N=2, C=1, S=30, H=64, W=44)."""
    return torch.randn((2, 1, 30, 64, 44))
@pytest.fixture
def dummy_sils_single() -> Tensor:
    """Random single-sample silhouette tensor (1, 1, S=30, 64, 44) for predict."""
    return torch.randn((1, 1, 30, 64, 44))
@pytest.fixture
def synthetic_state_dict() -> dict[str, Tensor]:
    """Minimal state dict whose keys mirror ScoNetDemo's parameter layout."""
    state: dict[str, Tensor] = {
        "backbone.conv1.conv.weight": torch.randn(64, 1, 3, 3),
        "fcs.fc_bin": torch.randn(16, 512, 256),
        "bn_necks.fc_bin": torch.randn(16, 256, 3),
    }
    # BatchNorm blocks: unit scale, zero shift, identity running statistics.
    for prefix, width in (("backbone.conv1.bn", 64), ("bn_necks.bn1d", 4096)):
        state[f"{prefix}.weight"] = torch.ones(width)
        state[f"{prefix}.bias"] = torch.zeros(width)
        state[f"{prefix}.running_mean"] = torch.zeros(width)
        state[f"{prefix}.running_var"] = torch.ones(width)
    return state
class TestScoNetDemoConstruction:
    """Construction behaviour: config resolution, device placement, failure modes."""

    def test_construction_from_config_no_checkpoint(self) -> None:
        """A config-only build yields an eval-mode CPU model with a parsed cfg."""
        from opengait_studio.sconet_demo import ScoNetDemo

        model = ScoNetDemo(
            cfg_path=str(CONFIG_PATH),
            checkpoint_path=None,
            device="cpu",
        )
        assert model.cfg_path.endswith("sconet_scoliosis1k.yaml")
        assert model.device == torch.device("cpu")
        assert model.cfg is not None
        assert "model_cfg" in model.cfg
        # Inference-only demo must come up in eval mode.
        assert model.training is False

    def test_construction_with_relative_path(self) -> None:
        """A repo-relative config path resolves and builds the backbone."""
        from opengait_studio.sconet_demo import ScoNetDemo

        model = ScoNetDemo(
            cfg_path="configs/sconet/sconet_scoliosis1k.yaml",
            checkpoint_path=None,
            device="cpu",
        )
        assert model.cfg is not None
        assert model.backbone is not None

    def test_construction_invalid_config_raises(self) -> None:
        """A nonexistent config path fails loudly instead of building silently."""
        from opengait_studio.sconet_demo import ScoNetDemo

        with pytest.raises((FileNotFoundError, TypeError)):
            _ = ScoNetDemo(
                cfg_path="/nonexistent/path/config.yaml",
                checkpoint_path=None,
                device="cpu",
            )
class TestScoNetDemoForward:
    """Forward-pass contract: required keys, shapes, dtypes, and value ranges."""

    @staticmethod
    def _forward_dict(model: "ScoNetDemo", sils: Tensor) -> dict[str, Tensor]:
        """Run a forward pass and narrow the untyped result to a typed dict."""
        return cast(dict[str, Tensor], model.forward(sils))

    def test_forward_output_shape_and_dtype(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Logits come back as (N, 3, 16) float32; label/confidence are typed."""
        outputs = self._forward_dict(demo, dummy_sils_batch)
        assert "logits" in outputs
        logits = outputs["logits"]
        # (batch_size, num_classes, parts_num) == (2, 3, 16)
        assert logits.shape == (2, 3, 16)
        assert logits.dtype == torch.float32
        assert outputs["label"].dtype == torch.int64
        assert outputs["confidence"].dtype == torch.float32

    def test_forward_returns_required_keys(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Output dict exposes at least logits, label, and confidence."""
        outputs = self._forward_dict(demo, dummy_sils_batch)
        assert {"logits", "label", "confidence"} <= set(outputs.keys())

    def test_forward_batch_size_one(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """A single-sample batch produces singleton output shapes."""
        outputs = self._forward_dict(demo, dummy_sils_single)
        assert outputs["logits"].shape == (1, 3, 16)
        assert outputs["label"].shape == (1,)
        assert outputs["confidence"].shape == (1,)

    def test_forward_label_range(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Predicted label indices stay within the three known classes."""
        labels = self._forward_dict(demo, dummy_sils_batch)["label"]
        assert bool(torch.all(labels >= 0))
        assert bool(torch.all(labels <= 2))

    def test_forward_confidence_range(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Confidences are valid probabilities in [0, 1]."""
        confidence = self._forward_dict(demo, dummy_sils_batch)["confidence"]
        assert bool(torch.all(confidence >= 0.0))
        assert bool(torch.all(confidence <= 1.0))
class TestScoNetDemoPredict:
    """Predict-method contract: return types, confidence range, batch guard."""

    def test_predict_returns_tuple_with_valid_types(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """predict yields a (label_str, confidence_float) pair with a known label."""
        from opengait_studio.sconet_demo import ScoNetDemo

        result = cast(tuple[str, float], demo.predict(dummy_sils_single))
        assert isinstance(result, tuple)
        assert len(result) == 2
        label, confidence = result
        assert isinstance(label, str)
        assert isinstance(confidence, float)
        # Label string must be one of the mapped class names.
        assert label in set(ScoNetDemo.LABEL_MAP.values())

    def test_predict_confidence_range(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """Confidence is a probability in [0, 1]."""
        _, confidence = cast(tuple[str, float], demo.predict(dummy_sils_single))
        assert 0.0 <= confidence <= 1.0

    def test_predict_rejects_batch_size_greater_than_one(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """Batches larger than one are rejected with a clear ValueError."""
        with pytest.raises(ValueError, match="batch size 1"):
            _ = demo.predict(dummy_sils_batch)
class TestScoNetDemoNoDDP:
    """Guards against distributed-training (DDP) leakage into the demo path."""

    def test_no_distributed_init_in_construction(self) -> None:
        """Constructing the demo never touches torch.distributed."""
        from opengait_studio.sconet_demo import ScoNetDemo

        with (
            patch("torch.distributed.is_initialized") as mock_is_init,
            patch("torch.distributed.init_process_group") as mock_init_pg,
        ):
            _ = ScoNetDemo(
                cfg_path=str(CONFIG_PATH),
                checkpoint_path=None,
                device="cpu",
            )
            mock_init_pg.assert_not_called()
            mock_is_init.assert_not_called()

    def test_forward_no_distributed_calls(
        self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
    ) -> None:
        """The forward pass performs no collective communication."""
        with (
            patch("torch.distributed.all_reduce") as mock_all_reduce,
            patch("torch.distributed.broadcast") as mock_broadcast,
        ):
            _ = demo.forward(dummy_sils_batch)
            mock_all_reduce.assert_not_called()
            mock_broadcast.assert_not_called()

    def test_predict_no_distributed_calls(
        self, demo: "ScoNetDemo", dummy_sils_single: Tensor
    ) -> None:
        """predict performs no collective communication."""
        with (
            patch("torch.distributed.all_reduce") as mock_all_reduce,
            patch("torch.distributed.broadcast") as mock_broadcast,
        ):
            _ = demo.predict(dummy_sils_single)
            mock_all_reduce.assert_not_called()
            mock_broadcast.assert_not_called()

    def test_model_not_wrapped_in_ddp(self, demo: "ScoNetDemo") -> None:
        """Neither the demo nor its backbone is a DistributedDataParallel wrapper."""
        from torch.nn.parallel import DistributedDataParallel as DDP

        assert not isinstance(demo, DDP)
        assert not isinstance(demo.backbone, DDP)

    def test_device_is_cpu(self, demo: "ScoNetDemo") -> None:
        """With device="cpu" every parameter stays on the CPU."""
        assert demo.device.type == "cpu"
        for param in demo.parameters():
            assert param.device.type == "cpu"
class TestScoNetDemoCheckpointLoading:
    """Tests for checkpoint loading behavior using synthetic state dict."""

    @staticmethod
    def _write_checkpoint(state_dict: dict[str, Tensor]) -> str:
        """Serialize *state_dict* to a temporary .pt file and return its path.

        Writes through the still-open file handle rather than reopening the
        file by name: on Windows, NamedTemporaryFile holds the file open
        exclusively, so ``torch.save(..., f.name)`` would raise
        PermissionError. With delete=False the file persists after close;
        callers are responsible for unlinking it.
        """
        import tempfile

        with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
            torch.save(state_dict, f)
            return f.name

    def test_load_checkpoint_changes_weights(
        self,
        demo: "ScoNetDemo",
        synthetic_state_dict: dict[str, Tensor],
    ) -> None:
        """Test loading checkpoint actually changes model weights."""
        import os

        # Snapshot one parameter before loading; clone so it is not aliased.
        initial_weight = next(iter(demo.parameters())).clone()
        temp_path = self._write_checkpoint(synthetic_state_dict)
        try:
            _ = demo.load_checkpoint(temp_path, strict=False)
            new_weight = next(iter(demo.parameters()))
            assert not torch.equal(initial_weight, new_weight)
        finally:
            os.unlink(temp_path)

    def test_load_checkpoint_sets_eval_mode(
        self,
        demo: "ScoNetDemo",
        synthetic_state_dict: dict[str, Tensor],
    ) -> None:
        """Test loading checkpoint sets model to eval mode."""
        import os

        # Force train mode first so the transition back to eval is observable.
        _ = demo.train()
        assert demo.training is True
        temp_path = self._write_checkpoint(synthetic_state_dict)
        try:
            _ = demo.load_checkpoint(temp_path, strict=False)
            assert demo.training is False
        finally:
            os.unlink(temp_path)

    def test_load_checkpoint_invalid_path_raises(self, demo: "ScoNetDemo") -> None:
        """Test loading from invalid checkpoint path raises error."""
        with pytest.raises(FileNotFoundError):
            _ = demo.load_checkpoint("/nonexistent/checkpoint.pt")
class TestScoNetDemoLabelMap:
    """Sanity checks on the LABEL_MAP class constant."""

    def test_label_map_has_three_classes(self) -> None:
        """Exactly three classes, keyed 0 through 2."""
        from opengait_studio.sconet_demo import ScoNetDemo

        assert len(ScoNetDemo.LABEL_MAP) == 3
        assert set(ScoNetDemo.LABEL_MAP.keys()) == {0, 1, 2}

    def test_label_map_values_are_valid_strings(self) -> None:
        """Every class name is a non-empty string."""
        from opengait_studio.sconet_demo import ScoNetDemo

        for name in ScoNetDemo.LABEL_MAP.values():
            assert isinstance(name, str)
            assert len(name) > 0
+171
View File
@@ -0,0 +1,171 @@
from __future__ import annotations
from pathlib import Path
from typing import cast
from unittest import mock
import numpy as np
import pytest
from opengait_studio.input import create_source
from opengait_studio.visualizer import (
DISPLAY_HEIGHT,
DISPLAY_WIDTH,
ImageArray,
OpenCVVisualizer,
)
from opengait_studio.window import select_person
REPO_ROOT = Path(__file__).resolve().parents[2]
SAMPLE_VIDEO_PATH = REPO_ROOT / "assets" / "sample.mp4"
YOLO_MODEL_PATH = REPO_ROOT / "ckpt" / "yolo11n-seg.pt"
def test_prepare_raw_view_float_mask_has_visible_signal() -> None:
    """A float mask in [0, 1] renders as uint8 and differs from a blank mask."""
    viz = OpenCVVisualizer()
    active = np.zeros((64, 64), dtype=np.float32)
    active[16:48, 16:48] = 1.0
    rendered_active = viz._prepare_raw_view(cast(ImageArray, active))
    assert rendered_active.dtype == np.uint8
    assert rendered_active.shape == (256, 176, 3)
    blank = np.zeros((64, 64), dtype=np.float32)
    rendered_blank = viz._prepare_raw_view(cast(ImageArray, blank))
    # Compare only above the bottom 40 rows of the view.
    roi = slice(0, DISPLAY_HEIGHT - 40)
    delta = np.abs(
        rendered_active[roi].astype(np.int16) - rendered_blank[roi].astype(np.int16)
    )
    assert int(np.count_nonzero(delta)) > 0
def test_prepare_raw_view_handles_values_slightly_above_one() -> None:
    """Float masks marginally above 1.0 still render with visible content."""
    viz = OpenCVVisualizer()
    mask = np.zeros((64, 64), dtype=np.float32)
    # Simulate upstream float error pushing mask values just past 1.0.
    mask[20:40, 20:40] = 1.0001
    rendered = viz._prepare_raw_view(cast(ImageArray, mask))
    # Inspect channel 0 above the bottom 40 rows of the view.
    visible = rendered[: DISPLAY_HEIGHT - 40, :, 0]
    assert int(np.count_nonzero(visible)) > 0
def test_segmentation_view_is_normalized_only_shape() -> None:
    """The segmentation view always comes out at the fixed display resolution."""
    viz = OpenCVVisualizer()
    raw_mask = np.zeros((480, 640), dtype=np.uint8)
    silhouette = np.random.rand(64, 44).astype(np.float32)
    view = viz._prepare_segmentation_view(
        cast(ImageArray, raw_mask), silhouette, (0, 0, 100, 100)
    )
    assert view.shape == (DISPLAY_HEIGHT, DISPLAY_WIDTH, 3)
def test_update_toggles_raw_window_with_r_key() -> None:
    """Pressing 'r' toggles the auxiliary raw-mask window; 'q' stops updates.

    cv2 is fully mocked: waitKey returns 'r', 'r', 'q' across three update()
    calls, so the raw window is opened, then closed (destroyWindow called),
    and the final call returns False to signal quit.
    """
    viz = OpenCVVisualizer()
    frame = np.zeros((240, 320, 3), dtype=np.uint8)
    mask = np.zeros((240, 320), dtype=np.uint8)
    mask[20:100, 30:120] = 255
    sil = np.random.rand(64, 44).astype(np.float32)
    seg_input = np.random.rand(4, 64, 44).astype(np.float32)
    with (
        mock.patch("cv2.namedWindow") as named_window,
        mock.patch("cv2.imshow"),
        mock.patch("cv2.destroyWindow") as destroy_window,
        mock.patch("cv2.waitKey", side_effect=[ord("r"), ord("r"), ord("q")]),
    ):
        # First call ('r'): raw window toggled on and created.
        assert viz.update(
            frame,
            (10, 10, 120, 150),
            (10, 10, 120, 150),
            1,
            cast(ImageArray, mask),
            sil,
            seg_input,
            None,
            None,
            15.0,
        )
        assert viz.show_raw_window is True
        assert viz._raw_window_created is True
        # Second call ('r'): raw window toggled off and destroyed.
        assert viz.update(
            frame,
            (10, 10, 120, 150),
            (10, 10, 120, 150),
            1,
            cast(ImageArray, mask),
            sil,
            seg_input,
            None,
            None,
            15.0,
        )
        assert viz.show_raw_window is False
        assert viz._raw_window_created is False
        assert destroy_window.called
        # Third call ('q'): update reports False (quit requested).
        assert (
            viz.update(
                frame,
                (10, 10, 120, 150),
                (10, 10, 120, 150),
                1,
                cast(ImageArray, mask),
                sil,
                seg_input,
                None,
                None,
                15.0,
            )
            is False
        )
        assert named_window.called
def test_sample_video_raw_mask_shape_range_and_render_signal() -> None:
    """Integration check on real video: tracked person masks are 2D, in-range,
    non-empty, and render with visible signal through _prepare_raw_view.

    Skipped when the sample video, the YOLO segmentation checkpoint, or the
    ultralytics package is unavailable.
    """
    if not SAMPLE_VIDEO_PATH.is_file():
        pytest.skip(f"Missing sample video: {SAMPLE_VIDEO_PATH}")
    if not YOLO_MODEL_PATH.is_file():
        pytest.skip(f"Missing YOLO model file: {YOLO_MODEL_PATH}")
    ultralytics = pytest.importorskip("ultralytics")
    yolo_cls = getattr(ultralytics, "YOLO")
    viz = OpenCVVisualizer()
    detector = yolo_cls(str(YOLO_MODEL_PATH))
    masks_seen = 0
    rendered_nonzero: list[int] = []
    # Track persons (class 0) on CPU over at most 30 frames of the sample clip.
    for frame, _meta in create_source(str(SAMPLE_VIDEO_PATH), max_frames=30):
        detections = detector.track(
            frame,
            persist=True,
            verbose=False,
            classes=[0],
            device="cpu",
        )
        if not isinstance(detections, list) or not detections:
            continue
        selected = select_person(detections[0])
        if selected is None:
            continue
        mask_raw, _, _, _ = selected
        masks_seen += 1
        arr = np.asarray(mask_raw)
        # Raw masks are 2D numeric arrays with values in [0, 255] and some foreground.
        assert arr.ndim == 2
        assert arr.shape[0] > 0 and arr.shape[1] > 0
        assert np.issubdtype(arr.dtype, np.number)
        assert float(arr.min()) >= 0.0
        assert float(arr.max()) <= 255.0
        assert int(np.count_nonzero(arr)) > 0
        rendered = viz._prepare_raw_view(arr)
        # Count visible pixels above the bottom 40 rows of the rendered view.
        roi = rendered[: DISPLAY_HEIGHT - 40, :, 0]
        rendered_nonzero.append(int(np.count_nonzero(roi)))
    # At least one tracked person must appear, and every render must be visible.
    assert masks_seen > 0
    assert min(rendered_nonzero) > 0
+398
View File
@@ -0,0 +1,398 @@
"""Unit tests for SilhouetteWindow and select_person functions."""
from typing import Any
from unittest.mock import MagicMock
import numpy as np
import pytest
import torch
from numpy.typing import NDArray
from opengait_studio.window import SilhouetteWindow, select_person
class TestSilhouetteWindow:
    """Tests for SilhouetteWindow class.

    NOTE(review): several assertions compare fill_level with exact float
    equality (e.g. ``(i + 1) / 5``); this assumes fill_level is computed as
    ``len(buffer) / window_size`` with the same division — confirm against
    the implementation.
    """
    def test_window_fill_and_ready_behavior(self) -> None:
        """Window should be ready only when filled to window_size."""
        window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
        sil = np.ones((64, 44), dtype=np.float32)
        # Not ready with fewer frames
        for i in range(4):
            window.push(sil, frame_idx=i, track_id=1)
            assert not window.is_ready()
            assert window.fill_level == (i + 1) / 5
        # Ready at exactly window_size
        window.push(sil, frame_idx=4, track_id=1)
        assert window.is_ready()
        assert window.fill_level == 1.0
    def test_underfilled_not_ready(self) -> None:
        """Underfilled window should never report ready."""
        window = SilhouetteWindow(window_size=10, stride=1, gap_threshold=5)
        sil = np.ones((64, 44), dtype=np.float32)
        # Push 9 frames (underfilled)
        for i in range(9):
            window.push(sil, frame_idx=i, track_id=1)
        assert not window.is_ready()
        assert window.fill_level == 0.9
        # get_tensor should raise when not ready
        with pytest.raises(ValueError, match="Window not ready"):
            window.get_tensor()
    def test_track_id_change_resets_buffer(self) -> None:
        """Changing track ID should reset the buffer."""
        window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
        sil = np.ones((64, 44), dtype=np.float32)
        # Fill window with track_id=1
        for i in range(5):
            window.push(sil, frame_idx=i, track_id=1)
        assert window.is_ready()
        assert window.current_track_id == 1
        assert window.fill_level == 1.0
        # Push with different track_id should reset (only the new frame remains)
        window.push(sil, frame_idx=5, track_id=2)
        assert not window.is_ready()
        assert window.fill_level == 0.2
        assert window.current_track_id == 2
    def test_frame_gap_reset_behavior(self) -> None:
        """Frame gap exceeding threshold should reset buffer."""
        window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=3)
        sil = np.ones((64, 44), dtype=np.float32)
        # Fill window with consecutive frames
        for i in range(5):
            window.push(sil, frame_idx=i, track_id=1)
        assert window.is_ready()
        assert window.fill_level == 1.0
        # Small gap (within threshold) - no reset
        window.push(sil, frame_idx=6, track_id=1)  # gap = 1
        assert window.is_ready()
        assert window.fill_level == 1.0  # deque maintains max
        # Reset and fill again
        window.reset()
        for i in range(5):
            window.push(sil, frame_idx=i, track_id=1)
        # Large gap (exceeds threshold) - should reset
        window.push(sil, frame_idx=10, track_id=1)  # gap = 5 > 3
        assert not window.is_ready()
        assert window.fill_level == 0.2
    def test_get_tensor_shape(self) -> None:
        """get_tensor should return tensor of shape [1, 1, window_size, 64, 44]."""
        window_size = 7
        window = SilhouetteWindow(window_size=window_size, stride=1, gap_threshold=10)
        # Push unique frames (constant value == frame index) to verify ordering
        for i in range(window_size):
            sil = np.full((64, 44), fill_value=float(i), dtype=np.float32)
            window.push(sil, frame_idx=i, track_id=1)
        assert window.is_ready()
        tensor = window.get_tensor(device="cpu")
        assert isinstance(tensor, torch.Tensor)
        assert tensor.shape == (1, 1, window_size, 64, 44)
        assert tensor.dtype == torch.float32
        # Verify content ordering: first frame should have value 0.0
        assert tensor[0, 0, 0, 0, 0].item() == 0.0
        # Last frame should have value window_size-1
        assert tensor[0, 0, window_size - 1, 0, 0].item() == window_size - 1
    def test_should_classify_stride_behavior(self) -> None:
        """should_classify should respect stride setting."""
        window = SilhouetteWindow(window_size=5, stride=3, gap_threshold=10)
        sil = np.ones((64, 44), dtype=np.float32)
        # Fill window
        for i in range(5):
            window.push(sil, frame_idx=i, track_id=1)
        assert window.is_ready()
        # First classification should always trigger
        assert window.should_classify()
        # Mark as classified at frame 4
        window.mark_classified()
        # Not ready to classify yet (stride=3, only 0 frames passed)
        window.push(sil, frame_idx=5, track_id=1)
        assert not window.should_classify()
        # Still not ready (only 1 frame passed)
        window.push(sil, frame_idx=6, track_id=1)
        assert not window.should_classify()
        # Now ready (3 frames passed since last classification)
        window.push(sil, frame_idx=7, track_id=1)
        assert window.should_classify()
    def test_should_classify_not_ready(self) -> None:
        """should_classify should return False when window not ready."""
        window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
        sil = np.ones((64, 44), dtype=np.float32)
        # Push only 3 frames (not ready)
        for i in range(3):
            window.push(sil, frame_idx=i, track_id=1)
        assert not window.is_ready()
        assert not window.should_classify()
    def test_reset_clears_all_state(self) -> None:
        """reset should clear all internal state."""
        window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
        sil = np.ones((64, 44), dtype=np.float32)
        # Fill and classify
        for i in range(5):
            window.push(sil, frame_idx=i, track_id=1)
        window.mark_classified()
        assert window.is_ready()
        assert window.current_track_id == 1
        assert window.frame_count == 5
        # Reset: readiness, fill level, track id, frame count all cleared
        window.reset()
        assert not window.is_ready()
        assert window.fill_level == 0.0
        assert window.current_track_id is None
        assert window.frame_count == 0
        assert not window.should_classify()
    def test_push_invalid_shape_raises(self) -> None:
        """push should raise ValueError for invalid silhouette shape."""
        window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
        # Wrong shape (expected (64, 44))
        sil_wrong = np.ones((32, 32), dtype=np.float32)
        with pytest.raises(ValueError, match="Expected silhouette shape"):
            window.push(sil_wrong, frame_idx=0, track_id=1)
    def test_push_wrong_dtype_converts(self) -> None:
        """push should convert dtype to float32."""
        window = SilhouetteWindow(window_size=1, stride=1, gap_threshold=10)
        # uint8 input
        sil_uint8 = np.ones((64, 44), dtype=np.uint8) * 255
        window.push(sil_uint8, frame_idx=0, track_id=1)
        tensor = window.get_tensor()
        assert tensor.dtype == torch.float32
class TestSelectPerson:
"""Tests for select_person function."""
def _create_mock_results(
self,
boxes_xyxy: NDArray[np.float32] | torch.Tensor,
masks_data: NDArray[np.float32] | torch.Tensor,
track_ids: NDArray[np.int64] | torch.Tensor | None,
) -> Any:
"""Create a mock detection results object."""
mock_boxes = MagicMock()
mock_boxes.xyxy = boxes_xyxy
mock_boxes.id = track_ids
mock_masks = MagicMock()
mock_masks.data = masks_data
mock_results = MagicMock()
mock_results.boxes = mock_boxes
mock_results.masks = mock_masks
return mock_results
def test_select_person_single_detection(self) -> None:
"""Single detection should return that person's data."""
boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32) # area = 3200
masks = np.random.rand(1, 100, 100).astype(np.float32)
track_ids = np.array([42], dtype=np.int64)
results = self._create_mock_results(boxes, masks, track_ids)
result = select_person(results)
assert result is not None
mask, bbox, _, tid = result
assert mask.shape == (100, 100)
assert bbox == (10, 10, 50, 90)
assert tid == 42
def test_select_person_multi_detection_selects_largest(self) -> None:
"""Multiple detections should select the one with largest bbox area."""
# Two boxes: second one is larger
boxes = np.array(
[
[0.0, 0.0, 10.0, 10.0], # area = 100
[0.0, 0.0, 30.0, 30.0], # area = 900 (largest)
[0.0, 0.0, 20.0, 20.0], # area = 400
],
dtype=np.float32,
)
masks = np.random.rand(3, 100, 100).astype(np.float32)
track_ids = np.array([1, 2, 3], dtype=np.int64)
results = self._create_mock_results(boxes, masks, track_ids)
result = select_person(results)
assert result is not None
mask, bbox, _, tid = result
assert bbox == (0, 0, 30, 30) # Largest box
assert tid == 2 # Corresponding track ID
def test_select_person_no_detections_returns_none(self) -> None:
"""No detections should return None."""
boxes = np.array([], dtype=np.float32).reshape(0, 4)
masks = np.array([], dtype=np.float32).reshape(0, 100, 100)
track_ids = np.array([], dtype=np.int64)
results = self._create_mock_results(boxes, masks, track_ids)
result = select_person(results)
assert result is None
def test_select_person_no_track_ids_returns_none(self) -> None:
"""Detections without track IDs should return None."""
boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32)
masks = np.random.rand(1, 100, 100).astype(np.float32)
results = self._create_mock_results(boxes, masks, track_ids=None)
result = select_person(results)
assert result is None
def test_select_person_empty_track_ids_returns_none(self) -> None:
"""Empty track IDs array should return None."""
boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32)
masks = np.random.rand(1, 100, 100).astype(np.float32)
track_ids = np.array([], dtype=np.int64)
results = self._create_mock_results(boxes, masks, track_ids)
result = select_person(results)
assert result is None
def test_select_person_missing_boxes_returns_none(self) -> None:
"""Missing boxes attribute should return None."""
mock_results = MagicMock()
mock_results.boxes = None
result = select_person(mock_results)
assert result is None
def test_select_person_missing_masks_returns_none(self) -> None:
"""Missing masks attribute should return None."""
boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32)
track_ids = np.array([1], dtype=np.int64)
mock_boxes = MagicMock()
mock_boxes.xyxy = boxes
mock_boxes.id = track_ids
mock_results = MagicMock()
mock_results.boxes = mock_boxes
mock_results.masks = None
result = select_person(mock_results)
assert result is None
def test_select_person_1d_bbox_handling(self) -> None:
"""1D bbox array should be reshaped to 2D."""
boxes = np.array([10.0, 10.0, 50.0, 90.0], dtype=np.float32) # 1D
masks = np.random.rand(1, 100, 100).astype(np.float32)
track_ids = np.array([1], dtype=np.int64)
results = self._create_mock_results(boxes, masks, track_ids)
result = select_person(results)
assert result is not None
_, bbox, _, tid = result
assert bbox == (10, 10, 50, 90)
assert tid == 1
def test_select_person_2d_mask_handling(self) -> None:
"""2D mask should be expanded to 3D."""
boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32)
masks = np.random.rand(100, 100).astype(np.float32) # 2D
track_ids = np.array([1], dtype=np.int64)
results = self._create_mock_results(boxes, masks, track_ids)
result = select_person(results)
assert result is not None
mask, _, _, _ = result
# Should be 2D (extracted from expanded 3D)
assert mask.shape == (100, 100)
def test_select_person_tensor_cpu_inputs(self) -> None:
"""Tensor-backed inputs (CPU) should work correctly."""
boxes = torch.tensor([[10.0, 10.0, 50.0, 90.0]], dtype=torch.float32)
masks = torch.rand(1, 100, 100, dtype=torch.float32)
track_ids = torch.tensor([42], dtype=torch.int64)
results = self._create_mock_results(boxes, masks, track_ids)
result = select_person(results)
assert result is not None
mask, bbox, _, tid = result
assert mask.shape == (100, 100)
assert bbox == (10, 10, 50, 90)
assert tid == 42
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_select_person_tensor_cuda_inputs(self) -> None:
"""Tensor-backed inputs (CUDA) should work correctly."""
boxes = torch.tensor([[10.0, 10.0, 50.0, 90.0]], dtype=torch.float32).cuda()
masks = torch.rand(1, 100, 100, dtype=torch.float32).cuda()
track_ids = torch.tensor([42], dtype=torch.int64).cuda()
results = self._create_mock_results(boxes, masks, track_ids)
result = select_person(results)
assert result is not None
mask, bbox, _, tid = result
assert mask.shape == (100, 100)
assert bbox == (10, 10, 50, 90)
assert tid == 42
def test_select_person_tensor_multi_detection(self) -> None:
"""Multiple tensor detections should select largest bbox."""
boxes = torch.tensor(
[
[0.0, 0.0, 10.0, 10.0], # area = 100
[0.0, 0.0, 30.0, 30.0], # area = 900 (largest)
[0.0, 0.0, 20.0, 20.0], # area = 400
],
dtype=torch.float32,
)
masks = torch.rand(3, 100, 100, dtype=torch.float32)
track_ids = torch.tensor([1, 2, 3], dtype=torch.int64)
results = self._create_mock_results(boxes, masks, track_ids)
result = select_person(results)
assert result is not None
_, bbox, _, tid = result
assert bbox == (0, 0, 30, 30) # Largest box
assert tid == 2 # Corresponding track ID