test(demo): add unit and integration coverage for pipeline

Introduce focused unit, integration, and NATS-path tests for demo modules, and align assertions with final schema and temporal contracts (window int, seq=30, fill-level ratio). This commit isolates validation logic from runtime changes and provides reproducible QA for pipeline behavior and failure modes.
This commit is contained in:
2026-02-27 09:59:14 +08:00
parent b24644f16e
commit d6fd6c03e6
8 changed files with 1669 additions and 0 deletions
+1
View File
@@ -0,0 +1 @@
"""OpenGait tests package."""
+1
View File
@@ -0,0 +1 @@
"""Tests for demo package."""
+23
View File
@@ -0,0 +1,23 @@
"""Test fixtures for demo package."""
import numpy as np
import pytest
import torch
@pytest.fixture
def mock_frame_tensor():
"""Return a mock video frame tensor (C, H, W)."""
return torch.randn(3, 224, 224)
@pytest.fixture
def mock_frame_array():
"""Return a mock video frame as numpy array (H, W, C)."""
return np.random.randn(224, 224, 3).astype(np.float32)
@pytest.fixture
def mock_video_sequence():
"""Return a mock video sequence tensor (T, C, H, W)."""
return torch.randn(16, 3, 224, 224)
+525
View File
@@ -0,0 +1,525 @@
"""Integration tests for NATS publisher functionality.
Tests cover:
- NATS message receipt with JSON schema validation
- Skip behavior when Docker/NATS unavailable
- Container lifecycle management (cleanup)
- JSON schema field/type validation
"""
from __future__ import annotations
import json
import socket
import subprocess
import time
from typing import TYPE_CHECKING, cast
import pytest
if TYPE_CHECKING:
from collections.abc import Generator
def _find_open_port() -> int:
"""Find an available TCP port on localhost."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
sock.listen(1)
addr = cast(tuple[str, int], sock.getsockname())
port: int = addr[1]
return port
# Constants for test configuration
NATS_SUBJECT = "scoliosis.result"
CONTAINER_NAME = "opengait-nats-test"
def _docker_available() -> bool:
"""Check if Docker is available and running."""
try:
result = subprocess.run(
["docker", "info"],
capture_output=True,
timeout=5,
check=False,
)
return result.returncode == 0
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
return False
def _nats_container_running() -> bool:
"""Check if NATS container is already running."""
try:
result = subprocess.run(
[
"docker",
"ps",
"--filter",
f"name={CONTAINER_NAME}",
"--format",
"{{.Names}}",
],
capture_output=True,
timeout=5,
check=False,
)
return CONTAINER_NAME in result.stdout.decode()
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
return False
def _start_nats_container(port: int) -> bool:
"""Start NATS container for testing."""
try:
# Remove existing container if present
_ = subprocess.run(
["docker", "rm", "-f", CONTAINER_NAME],
capture_output=True,
timeout=10,
check=False,
)
# Start new NATS container
result = subprocess.run(
[
"docker",
"run",
"-d",
"--name",
CONTAINER_NAME,
"-p",
f"{port}:{port}",
"nats:latest",
],
capture_output=True,
timeout=30,
check=False,
)
if result.returncode != 0:
return False
# Wait for NATS to be ready (max 10 seconds)
for _ in range(20):
time.sleep(0.5)
try:
check_result = subprocess.run(
[
"docker",
"exec",
CONTAINER_NAME,
"nats",
"server",
"check",
"server",
],
capture_output=True,
timeout=5,
check=False,
)
if check_result.returncode == 0:
return True
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
continue
return False
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
return False
def _stop_nats_container() -> None:
"""Stop and remove NATS container."""
try:
_ = subprocess.run(
["docker", "rm", "-f", CONTAINER_NAME],
capture_output=True,
timeout=10,
check=False,
)
except (subprocess.TimeoutExpired, FileNotFoundError, OSError):
pass
def _validate_result_schema(data: dict[str, object]) -> tuple[bool, str]:
"""Validate JSON result schema.
Expected schema:
{
"frame": int,
"track_id": int,
"label": str (one of: "negative", "neutral", "positive"),
"confidence": float in [0, 1],
"window": list[int] (start, end),
"timestamp_ns": int
}
"""
required_fields: set[str] = {
"frame",
"track_id",
"label",
"confidence",
"window",
"timestamp_ns",
}
missing = required_fields - set(data.keys())
if missing:
return False, f"Missing required fields: {missing}"
# Validate frame (int)
frame = data["frame"]
if not isinstance(frame, int):
return False, f"frame must be int, got {type(frame)}"
# Validate track_id (int)
track_id = data["track_id"]
if not isinstance(track_id, int):
return False, f"track_id must be int, got {type(track_id)}"
# Validate label (str in specific set)
label = data["label"]
valid_labels = {"negative", "neutral", "positive"}
if not isinstance(label, str) or label not in valid_labels:
return False, f"label must be one of {valid_labels}, got {label}"
# Validate confidence (float in [0, 1])
confidence = data["confidence"]
if not isinstance(confidence, (int, float)):
return False, f"confidence must be numeric, got {type(confidence)}"
if not 0.0 <= float(confidence) <= 1.0:
return False, f"confidence must be in [0, 1], got {confidence}"
# Validate window (list of 2 ints)
window = data["window"]
if not isinstance(window, list) or len(cast(list[object], window)) != 2:
return False, f"window must be list of 2 ints, got {window}"
window_list = cast(list[object], window)
if not all(isinstance(x, int) for x in window_list):
return False, f"window elements must be ints, got {window}"
# Validate timestamp_ns (int)
timestamp_ns = data["timestamp_ns"]
if not isinstance(timestamp_ns, int):
return False, f"timestamp_ns must be int, got {type(timestamp_ns)}"
return True, ""
@pytest.fixture(scope="module")
def nats_server() -> "Generator[tuple[bool, int], None, None]":
"""Fixture to manage NATS container lifecycle.
Yields (True, port) if NATS is available, (False, 0) otherwise.
Cleans up container after tests.
"""
if not _docker_available():
yield False, 0
return
port = _find_open_port()
if not _start_nats_container(port):
yield False, 0
return
try:
yield True, port
finally:
_stop_nats_container()
class TestNatsPublisherIntegration:
"""Integration tests for NATS publisher with live server."""
@pytest.mark.skipif(not _docker_available(), reason="Docker not available")
def test_nats_message_receipt_and_schema_validation(
self, nats_server: tuple[bool, int]
) -> None:
"""Test that messages are received and match expected schema."""
server_available, port = nats_server
if not server_available:
pytest.skip("NATS server not available")
nats_url = f"nats://127.0.0.1:{port}"
# Import here to avoid issues if nats-py not installed
try:
import asyncio
import nats # type: ignore[import-untyped]
except ImportError:
pytest.skip("nats-py not installed")
from opengait.demo.output import NatsPublisher, create_result
# Create publisher
publisher = NatsPublisher(nats_url, subject=NATS_SUBJECT)
# Create test results
test_results = [
create_result(
frame=100,
track_id=1,
label="positive",
confidence=0.85,
window=(70, 100),
timestamp_ns=1234567890000,
),
create_result(
frame=130,
track_id=1,
label="negative",
confidence=0.92,
window=(100, 130),
timestamp_ns=1234567890030,
),
create_result(
frame=160,
track_id=2,
label="neutral",
confidence=0.78,
window=(130, 160),
timestamp_ns=1234567890060,
),
]
# Collect received messages
received_messages: list[dict[str, object]] = []
async def subscribe_and_publish():
"""Subscribe to subject, publish messages, collect results."""
nc = await nats.connect(nats_url) # pyright: ignore[reportUnknownMemberType]
# Subscribe
sub = await nc.subscribe(NATS_SUBJECT) # pyright: ignore[reportUnknownMemberType]
# Publish all test results
for result in test_results:
publisher.publish(result)
# Wait for messages with timeout
for _ in range(len(test_results)):
try:
msg = await asyncio.wait_for(sub.next_msg(), timeout=5.0)
data = json.loads(msg.data.decode("utf-8")) # pyright: ignore[reportAny]
received_messages.append(cast(dict[str, object], data))
except asyncio.TimeoutError:
break
await sub.unsubscribe()
await nc.close()
# Run async subscriber
asyncio.run(subscribe_and_publish())
# Cleanup publisher
publisher.close()
# Verify all messages received
assert len(received_messages) == len(test_results), (
f"Expected {len(test_results)} messages, got {len(received_messages)}"
)
# Validate schema for each message
for i, msg in enumerate(received_messages):
is_valid, error = _validate_result_schema(msg)
assert is_valid, f"Message {i} schema validation failed: {error}"
# Verify specific values
assert received_messages[0]["frame"] == 100
assert received_messages[0]["label"] == "positive"
assert received_messages[0]["track_id"] == 1
_conf = received_messages[0]["confidence"]
assert isinstance(_conf, (int, float))
assert 0.0 <= float(_conf) <= 1.0
assert received_messages[1]["label"] == "negative"
assert received_messages[2]["track_id"] == 2
@pytest.mark.skipif(not _docker_available(), reason="Docker not available")
def test_nats_publisher_graceful_when_server_unavailable(self) -> None:
"""Test that publisher handles missing server gracefully."""
try:
from opengait.demo.output import NatsPublisher
except ImportError:
pytest.skip("output module not available")
# Use wrong port where no server is running
bad_url = "nats://127.0.0.1:14222"
publisher = NatsPublisher(bad_url, subject=NATS_SUBJECT)
# Should not raise when publishing without server
test_result: dict[str, object] = {
"frame": 1,
"track_id": 1,
"label": "positive",
"confidence": 0.85,
"window": [0, 30],
"timestamp_ns": 1234567890,
}
# Should not raise
publisher.publish(test_result)
# Cleanup should also not raise
publisher.close()
@pytest.mark.skipif(not _docker_available(), reason="Docker not available")
def test_nats_publisher_context_manager(
self, nats_server: tuple[bool, int]
) -> None:
"""Test that publisher works as context manager."""
server_available, port = nats_server
if not server_available:
pytest.skip("NATS server not available")
nats_url = f"nats://127.0.0.1:{port}"
try:
import asyncio
import nats # type: ignore[import-untyped]
from opengait.demo.output import NatsPublisher, create_result
except ImportError as e:
pytest.skip(f"Required module not available: {e}")
received_messages: list[dict[str, object]] = []
async def subscribe_and_test():
nc = await nats.connect(nats_url) # pyright: ignore[reportUnknownMemberType]
sub = await nc.subscribe(NATS_SUBJECT) # pyright: ignore[reportUnknownMemberType]
# Use context manager
with NatsPublisher(nats_url, subject=NATS_SUBJECT) as publisher:
result = create_result(
frame=200,
track_id=5,
label="neutral",
confidence=0.65,
window=(170, 200),
timestamp_ns=9999999999,
)
publisher.publish(result)
# Wait for message
try:
msg = await asyncio.wait_for(sub.next_msg(), timeout=5.0)
data = json.loads(msg.data.decode("utf-8")) # pyright: ignore[reportAny]
received_messages.append(cast(dict[str, object], data))
except asyncio.TimeoutError:
pass
await sub.unsubscribe()
await nc.close()
asyncio.run(subscribe_and_test())
assert len(received_messages) == 1
assert received_messages[0]["frame"] == 200
assert received_messages[0]["track_id"] == 5
class TestNatsSchemaValidation:
"""Tests for JSON schema validation without requiring NATS server."""
def test_validate_result_schema_valid(self) -> None:
"""Test schema validation with valid data."""
valid_data: dict[str, object] = {
"frame": 1234,
"track_id": 42,
"label": "positive",
"confidence": 0.85,
"window": [1200, 1230],
"timestamp_ns": 1234567890000,
}
is_valid, error = _validate_result_schema(valid_data)
assert is_valid, f"Valid data rejected: {error}"
def test_validate_result_schema_invalid_label(self) -> None:
"""Test schema validation rejects invalid label."""
invalid_data: dict[str, object] = {
"frame": 1234,
"track_id": 42,
"label": "invalid_label",
"confidence": 0.85,
"window": [1200, 1230],
"timestamp_ns": 1234567890000,
}
is_valid, error = _validate_result_schema(invalid_data)
assert not is_valid
assert "label" in error.lower()
def test_validate_result_schema_confidence_out_of_range(self) -> None:
"""Test schema validation rejects confidence outside [0, 1]."""
invalid_data: dict[str, object] = {
"frame": 1234,
"track_id": 42,
"label": "positive",
"confidence": 1.5,
"window": [1200, 1230],
"timestamp_ns": 1234567890000,
}
is_valid, error = _validate_result_schema(invalid_data)
assert not is_valid
assert "confidence" in error.lower()
def test_validate_result_schema_missing_fields(self) -> None:
"""Test schema validation detects missing fields."""
incomplete_data: dict[str, object] = {
"frame": 1234,
"label": "positive",
}
is_valid, error = _validate_result_schema(incomplete_data)
assert not is_valid
assert "missing" in error.lower()
def test_validate_result_schema_wrong_types(self) -> None:
"""Test schema validation rejects wrong types."""
wrong_types: dict[str, object] = {
"frame": "not_an_int",
"track_id": 42,
"label": "positive",
"confidence": 0.85,
"window": [1200, 1230],
"timestamp_ns": 1234567890000,
}
is_valid, error = _validate_result_schema(wrong_types)
assert not is_valid
assert "frame" in error.lower()
def test_all_valid_labels_accepted(self) -> None:
"""Test that all valid labels are accepted."""
for label_str in ["negative", "neutral", "positive"]:
data: dict[str, object] = {
"frame": 100,
"track_id": 1,
"label": label_str,
"confidence": 0.5,
"window": [70, 100],
"timestamp_ns": 1234567890,
}
is_valid, error = _validate_result_schema(data)
assert is_valid, f"Valid label '{label_str}' rejected: {error}"
class TestDockerAvailability:
"""Tests for Docker availability detection."""
def test_docker_available_check(self) -> None:
"""Test Docker availability check doesn't crash."""
# This should not raise
result = _docker_available()
assert isinstance(result, bool)
def test_nats_container_running_check(self) -> None:
"""Test container running check doesn't crash."""
# This should not raise even if Docker not available
result = _nats_container_running()
assert isinstance(result, bool)
+279
View File
@@ -0,0 +1,279 @@
from __future__ import annotations
import json
from pathlib import Path
import subprocess
import sys
import time
from typing import Final, cast
import pytest
import torch
from opengait.demo.sconet_demo import ScoNetDemo
REPO_ROOT: Final[Path] = Path(__file__).resolve().parents[2]
SAMPLE_VIDEO_PATH: Final[Path] = REPO_ROOT / "assets" / "sample.mp4"
CHECKPOINT_PATH: Final[Path] = REPO_ROOT / "ckpt" / "ScoNet-20000.pt"
CONFIG_PATH: Final[Path] = REPO_ROOT / "configs" / "sconet" / "sconet_scoliosis1k.yaml"
YOLO_MODEL_PATH: Final[Path] = REPO_ROOT / "yolo11n-seg.pt"
def _device_for_runtime() -> str:
return "cuda:0" if torch.cuda.is_available() else "cpu"
def _run_pipeline_cli(
*args: str, timeout_seconds: int = 120
) -> subprocess.CompletedProcess[str]:
command = [sys.executable, "-m", "opengait.demo", *args]
return subprocess.run(
command,
cwd=REPO_ROOT,
capture_output=True,
text=True,
check=False,
timeout=timeout_seconds,
)
def _require_integration_assets() -> None:
if not SAMPLE_VIDEO_PATH.is_file():
pytest.skip(f"Missing sample video: {SAMPLE_VIDEO_PATH}")
if not CONFIG_PATH.is_file():
pytest.skip(f"Missing config: {CONFIG_PATH}")
if not YOLO_MODEL_PATH.is_file():
pytest.skip(f"Missing YOLO model file: {YOLO_MODEL_PATH}")
@pytest.fixture
def compatible_checkpoint_path(tmp_path: Path) -> Path:
if not CONFIG_PATH.is_file():
pytest.skip(f"Missing config: {CONFIG_PATH}")
checkpoint_file = tmp_path / "sconet-compatible.pt"
model = ScoNetDemo(cfg_path=str(CONFIG_PATH), checkpoint_path=None, device="cpu")
torch.save(model.state_dict(), checkpoint_file)
return checkpoint_file
def _extract_prediction_json_lines(stdout: str) -> list[dict[str, object]]:
required_keys = {
"frame",
"track_id",
"label",
"confidence",
"window",
"timestamp_ns",
}
predictions: list[dict[str, object]] = []
for line in stdout.splitlines():
stripped = line.strip()
if not stripped:
continue
try:
payload_obj = cast(object, json.loads(stripped))
except json.JSONDecodeError:
continue
if not isinstance(payload_obj, dict):
continue
payload = cast(dict[str, object], payload_obj)
if required_keys.issubset(payload.keys()):
predictions.append(payload)
return predictions
def _assert_prediction_schema(prediction: dict[str, object]) -> None:
assert isinstance(prediction["frame"], int)
assert isinstance(prediction["track_id"], int)
label = prediction["label"]
assert isinstance(label, str)
assert label in {"negative", "neutral", "positive"}
confidence = prediction["confidence"]
assert isinstance(confidence, (int, float))
confidence_value = float(confidence)
assert 0.0 <= confidence_value <= 1.0
window_obj = prediction["window"]
assert isinstance(window_obj, int)
assert window_obj >= 0
assert isinstance(prediction["timestamp_ns"], int)
def test_pipeline_cli_fps_benchmark_smoke(
compatible_checkpoint_path: Path,
) -> None:
_require_integration_assets()
max_frames = 90
started_at = time.perf_counter()
result = _run_pipeline_cli(
"--source",
str(SAMPLE_VIDEO_PATH),
"--checkpoint",
str(compatible_checkpoint_path),
"--config",
str(CONFIG_PATH),
"--device",
_device_for_runtime(),
"--yolo-model",
str(YOLO_MODEL_PATH),
"--window",
"5",
"--stride",
"1",
"--max-frames",
str(max_frames),
timeout_seconds=180,
)
elapsed_seconds = time.perf_counter() - started_at
assert result.returncode == 0, (
f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
)
predictions = _extract_prediction_json_lines(result.stdout)
assert predictions, "Expected prediction output for FPS benchmark run"
for prediction in predictions:
_assert_prediction_schema(prediction)
observed_frames = {
frame_obj
for prediction in predictions
for frame_obj in [prediction["frame"]]
if isinstance(frame_obj, int)
}
observed_units = len(observed_frames)
if observed_units < 5:
pytest.skip(
"Insufficient observed frame samples for stable FPS benchmark in this environment"
)
if elapsed_seconds <= 0:
pytest.skip("Non-positive elapsed time; cannot compute FPS benchmark")
fps = observed_units / elapsed_seconds
min_expected_fps = 0.2
assert fps >= min_expected_fps, (
"Observed FPS below conservative CI threshold: "
f"{fps:.3f} < {min_expected_fps:.3f} "
f"(observed_units={observed_units}, elapsed_seconds={elapsed_seconds:.3f})"
)
def test_pipeline_cli_happy_path_outputs_json_predictions(
compatible_checkpoint_path: Path,
) -> None:
_require_integration_assets()
result = _run_pipeline_cli(
"--source",
str(SAMPLE_VIDEO_PATH),
"--checkpoint",
str(compatible_checkpoint_path),
"--config",
str(CONFIG_PATH),
"--device",
_device_for_runtime(),
"--yolo-model",
str(YOLO_MODEL_PATH),
"--window",
"10",
"--stride",
"10",
"--max-frames",
"120",
timeout_seconds=180,
)
assert result.returncode == 0, (
f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
)
predictions = _extract_prediction_json_lines(result.stdout)
assert predictions, (
"Expected at least one prediction JSON line in stdout. "
f"stdout:\n{result.stdout}\nstderr:\n{result.stderr}"
)
for prediction in predictions:
_assert_prediction_schema(prediction)
assert "Connected to NATS" not in result.stderr
def test_pipeline_cli_max_frames_caps_output_frames(
compatible_checkpoint_path: Path,
) -> None:
_require_integration_assets()
max_frames = 20
result = _run_pipeline_cli(
"--source",
str(SAMPLE_VIDEO_PATH),
"--checkpoint",
str(compatible_checkpoint_path),
"--config",
str(CONFIG_PATH),
"--device",
_device_for_runtime(),
"--yolo-model",
str(YOLO_MODEL_PATH),
"--window",
"5",
"--stride",
"1",
"--max-frames",
str(max_frames),
timeout_seconds=180,
)
assert result.returncode == 0, (
f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
)
predictions = _extract_prediction_json_lines(result.stdout)
assert predictions, "Expected prediction output with --max-frames run"
for prediction in predictions:
_assert_prediction_schema(prediction)
frame_idx_obj = prediction["frame"]
assert isinstance(frame_idx_obj, int)
assert frame_idx_obj < max_frames
def test_pipeline_cli_invalid_source_path_returns_user_error() -> None:
result = _run_pipeline_cli(
"--source",
"/definitely/not/a/real/video.mp4",
"--checkpoint",
"/tmp/unused-checkpoint.pt",
"--config",
str(CONFIG_PATH),
timeout_seconds=30,
)
assert result.returncode == 2
assert "Error: Video source not found" in result.stderr
def test_pipeline_cli_invalid_checkpoint_path_returns_user_error() -> None:
if not SAMPLE_VIDEO_PATH.is_file():
pytest.skip(f"Missing sample video: {SAMPLE_VIDEO_PATH}")
if not CONFIG_PATH.is_file():
pytest.skip(f"Missing config: {CONFIG_PATH}")
result = _run_pipeline_cli(
"--source",
str(SAMPLE_VIDEO_PATH),
"--checkpoint",
str(REPO_ROOT / "ckpt" / "missing-checkpoint.pt"),
"--config",
str(CONFIG_PATH),
timeout_seconds=30,
)
assert result.returncode == 2
assert "Error: Checkpoint not found" in result.stderr
+153
View File
@@ -0,0 +1,153 @@
"""Unit tests for silhouette preprocessing functions."""
from typing import cast
import numpy as np
from numpy.typing import NDArray
import pytest
from beartype.roar import BeartypeCallHintParamViolation
from jaxtyping import TypeCheckError
from opengait.demo.preprocess import mask_to_silhouette
class TestMaskToSilhouette:
"""Tests for mask_to_silhouette() function."""
def test_valid_mask_returns_correct_shape_dtype_and_range(self) -> None:
"""Valid mask should return (64, 44) float32 array in [0, 1] range."""
# Create a synthetic mask with sufficient area (person-shaped blob)
h, w = 200, 150
mask = np.zeros((h, w), dtype=np.uint8)
# Draw a filled ellipse to simulate a person
center_y, center_x = h // 2, w // 2
axes_y, axes_x = h // 3, w // 4
y, x = np.ogrid[:h, :w]
ellipse_mask = ((x - center_x) / axes_x) ** 2 + (
(y - center_y) / axes_y
) ** 2 <= 1
mask[ellipse_mask] = 255
bbox = (w // 4, h // 6, 3 * w // 4, 5 * h // 6)
result = mask_to_silhouette(mask, bbox)
assert result is not None
result_arr = cast(NDArray[np.float32], result)
assert result_arr.shape == (64, 44)
assert result_arr.dtype == np.float32
assert np.all(result_arr >= 0.0) and np.all(result_arr <= 1.0)
def test_tiny_mask_returns_none(self) -> None:
"""Mask with area below MIN_MASK_AREA should return None."""
# Create a tiny mask
mask = np.zeros((50, 50), dtype=np.uint8)
mask[20:22, 20:22] = 255 # Only 4 pixels, well below MIN_MASK_AREA (500)
bbox = (20, 20, 22, 22)
result = mask_to_silhouette(mask, bbox)
assert result is None
def test_empty_mask_returns_none(self) -> None:
"""Completely empty mask should return None."""
mask = np.zeros((100, 100), dtype=np.uint8)
bbox = (10, 10, 90, 90)
result = mask_to_silhouette(mask, bbox)
assert result is None
def test_full_frame_mask_returns_valid_output(self) -> None:
"""Full-frame mask (large bbox covering entire image) should work."""
h, w = 300, 200
mask = np.zeros((h, w), dtype=np.uint8)
# Create a large filled region
mask[50:250, 30:170] = 255
# Full frame bbox
bbox = (0, 0, w, h)
result = mask_to_silhouette(mask, bbox)
assert result is not None
result_arr = cast(NDArray[np.float32], result)
assert result_arr.shape == (64, 44)
assert result_arr.dtype == np.float32
def test_determinism_same_input_same_output(self) -> None:
"""Same input should always produce same output."""
h, w = 200, 150
mask = np.zeros((h, w), dtype=np.uint8)
# Create a person-shaped region
mask[50:150, 40:110] = 255
bbox = (40, 50, 110, 150)
result1 = mask_to_silhouette(mask, bbox)
result2 = mask_to_silhouette(mask, bbox)
result3 = mask_to_silhouette(mask, bbox)
assert result1 is not None
assert result2 is not None
assert result3 is not None
np.testing.assert_array_equal(result1, result2)
np.testing.assert_array_equal(result2, result3)
def test_tall_narrow_mask_valid_output(self) -> None:
"""Tall narrow mask should produce valid silhouette."""
h, w = 400, 50
mask = np.zeros((h, w), dtype=np.uint8)
# Tall narrow person
mask[50:350, 10:40] = 255
bbox = (10, 50, 40, 350)
result = mask_to_silhouette(mask, bbox)
assert result is not None
result_arr = cast(NDArray[np.float32], result)
assert result_arr.shape == (64, 44)
def test_wide_short_mask_valid_output(self) -> None:
"""Wide short mask should produce valid silhouette."""
h, w = 100, 400
mask = np.zeros((h, w), dtype=np.uint8)
# Wide short person
mask[20:80, 50:350] = 255
bbox = (50, 20, 350, 80)
result = mask_to_silhouette(mask, bbox)
assert result is not None
result_arr = cast(NDArray[np.float32], result)
assert result_arr.shape == (64, 44)
def test_beartype_rejects_wrong_dtype(self) -> None:
"""Beartype should reject non-uint8 input."""
# Float array instead of uint8
mask = np.ones((100, 100), dtype=np.float32) * 255
bbox = (10, 10, 90, 90)
with pytest.raises((BeartypeCallHintParamViolation, TypeCheckError)):
_ = mask_to_silhouette(mask, bbox)
def test_beartype_rejects_wrong_ndim(self) -> None:
"""Beartype should reject non-2D array."""
# 3D array instead of 2D
mask = np.ones((100, 100, 3), dtype=np.uint8) * 255
bbox = (10, 10, 90, 90)
with pytest.raises((BeartypeCallHintParamViolation, TypeCheckError)):
_ = mask_to_silhouette(mask, bbox)
def test_beartype_rejects_wrong_bbox_type(self) -> None:
"""Beartype should reject non-tuple bbox."""
mask = np.ones((100, 100), dtype=np.uint8) * 255
# List instead of tuple
bbox = [10, 10, 90, 90]
with pytest.raises((BeartypeCallHintParamViolation, TypeCheckError)):
_ = mask_to_silhouette(mask, bbox)
+341
View File
@@ -0,0 +1,341 @@
"""Unit tests for ScoNetDemo forward pass.
Tests cover:
- Construction from config/checkpoint path handling
- Forward output shape (N, 3, 16) and dtype float
- Predict output (label_str, confidence_float) with valid label/range
- No-DDP leakage check (no torch.distributed calls in unit behavior)
"""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, cast
from unittest.mock import patch
import pytest
import torch
from torch import Tensor
if TYPE_CHECKING:
from opengait.demo.sconet_demo import ScoNetDemo
# Constants for test configuration
CONFIG_PATH = Path("configs/sconet/sconet_scoliosis1k.yaml")
@pytest.fixture
def demo() -> "ScoNetDemo":
"""Create ScoNetDemo without loading checkpoint (CPU-only)."""
from opengait.demo.sconet_demo import ScoNetDemo
return ScoNetDemo(
cfg_path=str(CONFIG_PATH),
checkpoint_path=None,
device="cpu",
)
@pytest.fixture
def dummy_sils_batch() -> Tensor:
"""Return dummy silhouette tensor of shape (N, 1, S, 64, 44)."""
return torch.randn(2, 1, 30, 64, 44)
@pytest.fixture
def dummy_sils_single() -> Tensor:
"""Return dummy silhouette tensor of shape (1, 1, S, 64, 44) for predict."""
return torch.randn(1, 1, 30, 64, 44)
@pytest.fixture
def synthetic_state_dict() -> dict[str, Tensor]:
"""Return a synthetic state dict compatible with ScoNetDemo structure."""
return {
"backbone.conv1.conv.weight": torch.randn(64, 1, 3, 3),
"backbone.conv1.bn.weight": torch.ones(64),
"backbone.conv1.bn.bias": torch.zeros(64),
"backbone.conv1.bn.running_mean": torch.zeros(64),
"backbone.conv1.bn.running_var": torch.ones(64),
"fcs.fc_bin": torch.randn(16, 512, 256),
"bn_necks.fc_bin": torch.randn(16, 256, 3),
"bn_necks.bn1d.weight": torch.ones(4096),
"bn_necks.bn1d.bias": torch.zeros(4096),
"bn_necks.bn1d.running_mean": torch.zeros(4096),
"bn_necks.bn1d.running_var": torch.ones(4096),
}
class TestScoNetDemoConstruction:
"""Tests for ScoNetDemo construction and path handling."""
def test_construction_from_config_no_checkpoint(self) -> None:
"""Test construction with config only, no checkpoint."""
from opengait.demo.sconet_demo import ScoNetDemo
demo = ScoNetDemo(
cfg_path=str(CONFIG_PATH),
checkpoint_path=None,
device="cpu",
)
assert demo.cfg_path.endswith("sconet_scoliosis1k.yaml")
assert demo.device == torch.device("cpu")
assert demo.cfg is not None
assert "model_cfg" in demo.cfg
assert demo.training is False # eval mode
def test_construction_with_relative_path(self) -> None:
"""Test construction handles relative config path correctly."""
from opengait.demo.sconet_demo import ScoNetDemo
demo = ScoNetDemo(
cfg_path="configs/sconet/sconet_scoliosis1k.yaml",
checkpoint_path=None,
device="cpu",
)
assert demo.cfg is not None
assert demo.backbone is not None
def test_construction_invalid_config_raises(self) -> None:
"""Test construction raises with invalid config path."""
from opengait.demo.sconet_demo import ScoNetDemo
with pytest.raises((FileNotFoundError, TypeError)):
_ = ScoNetDemo(
cfg_path="/nonexistent/path/config.yaml",
checkpoint_path=None,
device="cpu",
)
class TestScoNetDemoForward:
"""Tests for ScoNetDemo forward pass."""
def test_forward_output_shape_and_dtype(
self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
) -> None:
"""Test forward returns logits with shape (N, 3, 16) and correct dtypes."""
outputs_raw = demo.forward(dummy_sils_batch)
outputs = cast(dict[str, Tensor], outputs_raw)
assert "logits" in outputs
logits = outputs["logits"]
# Expected shape: (batch_size, num_classes, parts_num) = (N, 3, 16)
assert logits.shape == (2, 3, 16)
assert logits.dtype == torch.float32
assert outputs["label"].dtype == torch.int64
assert outputs["confidence"].dtype == torch.float32
def test_forward_returns_required_keys(
self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
) -> None:
"""Test forward returns required output keys."""
outputs_raw = demo.forward(dummy_sils_batch)
outputs = cast(dict[str, Tensor], outputs_raw)
required_keys = {"logits", "label", "confidence"}
assert set(outputs.keys()) >= required_keys
def test_forward_batch_size_one(
self, demo: "ScoNetDemo", dummy_sils_single: Tensor
) -> None:
"""Test forward works with batch size 1."""
outputs_raw = demo.forward(dummy_sils_single)
outputs = cast(dict[str, Tensor], outputs_raw)
assert outputs["logits"].shape == (1, 3, 16)
assert outputs["label"].shape == (1,)
assert outputs["confidence"].shape == (1,)
def test_forward_label_range(
self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
) -> None:
"""Test forward returns valid label indices (0, 1, or 2)."""
outputs_raw = demo.forward(dummy_sils_batch)
outputs = cast(dict[str, Tensor], outputs_raw)
labels = outputs["label"]
assert torch.all(labels >= 0)
assert torch.all(labels <= 2)
def test_forward_confidence_range(
self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
) -> None:
"""Test forward returns confidence in [0, 1]."""
outputs_raw = demo.forward(dummy_sils_batch)
outputs = cast(dict[str, Tensor], outputs_raw)
confidence = outputs["confidence"]
assert torch.all(confidence >= 0.0)
assert torch.all(confidence <= 1.0)
class TestScoNetDemoPredict:
"""Tests for ScoNetDemo predict method."""
def test_predict_returns_tuple_with_valid_types(
self, demo: "ScoNetDemo", dummy_sils_single: Tensor
) -> None:
"""Test predict returns (str, float) tuple with valid label."""
from opengait.demo.sconet_demo import ScoNetDemo
result_raw = demo.predict(dummy_sils_single)
result = cast(tuple[str, float], result_raw)
assert isinstance(result, tuple)
assert len(result) == 2
label, confidence = result
assert isinstance(label, str)
assert isinstance(confidence, float)
valid_labels = set(ScoNetDemo.LABEL_MAP.values())
assert label in valid_labels
def test_predict_confidence_range(
self, demo: "ScoNetDemo", dummy_sils_single: Tensor
) -> None:
"""Test predict returns confidence in valid range [0, 1]."""
result_raw = demo.predict(dummy_sils_single)
result = cast(tuple[str, float], result_raw)
confidence = result[1]
assert 0.0 <= confidence <= 1.0
def test_predict_rejects_batch_size_greater_than_one(
self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
) -> None:
"""Test predict raises ValueError for batch size > 1."""
with pytest.raises(ValueError, match="batch size 1"):
_ = demo.predict(dummy_sils_batch)
class TestScoNetDemoNoDDP:
"""Tests to verify no DDP leakage in unit behavior."""
def test_no_distributed_init_in_construction(self) -> None:
"""Test that construction does not call torch.distributed."""
from opengait.demo.sconet_demo import ScoNetDemo
with patch("torch.distributed.is_initialized") as mock_is_init:
with patch("torch.distributed.init_process_group") as mock_init_pg:
_ = ScoNetDemo(
cfg_path=str(CONFIG_PATH),
checkpoint_path=None,
device="cpu",
)
mock_init_pg.assert_not_called()
mock_is_init.assert_not_called()
def test_forward_no_distributed_calls(
self, demo: "ScoNetDemo", dummy_sils_batch: Tensor
) -> None:
"""Test forward pass does not call torch.distributed."""
with patch("torch.distributed.all_reduce") as mock_all_reduce:
with patch("torch.distributed.broadcast") as mock_broadcast:
_ = demo.forward(dummy_sils_batch)
mock_all_reduce.assert_not_called()
mock_broadcast.assert_not_called()
def test_predict_no_distributed_calls(
self, demo: "ScoNetDemo", dummy_sils_single: Tensor
) -> None:
"""Test predict does not call torch.distributed."""
with patch("torch.distributed.all_reduce") as mock_all_reduce:
with patch("torch.distributed.broadcast") as mock_broadcast:
_ = demo.predict(dummy_sils_single)
mock_all_reduce.assert_not_called()
mock_broadcast.assert_not_called()
def test_model_not_wrapped_in_ddp(self, demo: "ScoNetDemo") -> None:
"""Test model is not wrapped in DistributedDataParallel."""
from torch.nn.parallel import DistributedDataParallel as DDP
assert not isinstance(demo, DDP)
assert not isinstance(demo.backbone, DDP)
def test_device_is_cpu(self, demo: "ScoNetDemo") -> None:
"""Test model stays on CPU when device="cpu" specified."""
assert demo.device.type == "cpu"
for param in demo.parameters():
assert param.device.type == "cpu"
class TestScoNetDemoCheckpointLoading:
"""Tests for checkpoint loading behavior using synthetic state dict."""
def test_load_checkpoint_changes_weights(
self,
demo: "ScoNetDemo",
synthetic_state_dict: dict[str, Tensor],
) -> None:
"""Test loading checkpoint actually changes model weights."""
import tempfile
import os
# Get initial weight
initial_weight = next(iter(demo.parameters())).clone()
# Create temp checkpoint file with synthetic state dict
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
torch.save(synthetic_state_dict, f.name)
temp_path = f.name
try:
_ = demo.load_checkpoint(temp_path, strict=False)
new_weight = next(iter(demo.parameters()))
assert not torch.equal(initial_weight, new_weight)
finally:
os.unlink(temp_path)
def test_load_checkpoint_sets_eval_mode(
self,
demo: "ScoNetDemo",
synthetic_state_dict: dict[str, Tensor],
) -> None:
"""Test loading checkpoint sets model to eval mode."""
import tempfile
import os
_ = demo.train()
assert demo.training is True
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
torch.save(synthetic_state_dict, f.name)
temp_path = f.name
try:
_ = demo.load_checkpoint(temp_path, strict=False)
assert demo.training is False
finally:
os.unlink(temp_path)
def test_load_checkpoint_invalid_path_raises(self, demo: "ScoNetDemo") -> None:
"""Test loading from invalid checkpoint path raises error."""
with pytest.raises(FileNotFoundError):
_ = demo.load_checkpoint("/nonexistent/checkpoint.pt")
class TestScoNetDemoLabelMap:
"""Tests for LABEL_MAP constant."""
def test_label_map_has_three_classes(self) -> None:
"""Test LABEL_MAP has exactly 3 classes."""
from opengait.demo.sconet_demo import ScoNetDemo
assert len(ScoNetDemo.LABEL_MAP) == 3
assert set(ScoNetDemo.LABEL_MAP.keys()) == {0, 1, 2}
def test_label_map_values_are_valid_strings(self) -> None:
"""Test LABEL_MAP values are valid non-empty strings."""
from opengait.demo.sconet_demo import ScoNetDemo
for value in ScoNetDemo.LABEL_MAP.values():
assert isinstance(value, str)
assert len(value) > 0
+346
View File
@@ -0,0 +1,346 @@
"""Unit tests for SilhouetteWindow and select_person functions."""
from typing import Any
from unittest.mock import MagicMock
import numpy as np
import pytest
import torch
from numpy.typing import NDArray
from opengait.demo.window import SilhouetteWindow, select_person
class TestSilhouetteWindow:
"""Tests for SilhouetteWindow class."""
def test_window_fill_and_ready_behavior(self) -> None:
"""Window should be ready only when filled to window_size."""
window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
sil = np.ones((64, 44), dtype=np.float32)
# Not ready with fewer frames
for i in range(4):
window.push(sil, frame_idx=i, track_id=1)
assert not window.is_ready()
assert window.fill_level == (i + 1) / 5
# Ready at exactly window_size
window.push(sil, frame_idx=4, track_id=1)
assert window.is_ready()
assert window.fill_level == 1.0
def test_underfilled_not_ready(self) -> None:
"""Underfilled window should never report ready."""
window = SilhouetteWindow(window_size=10, stride=1, gap_threshold=5)
sil = np.ones((64, 44), dtype=np.float32)
# Push 9 frames (underfilled)
for i in range(9):
window.push(sil, frame_idx=i, track_id=1)
assert not window.is_ready()
assert window.fill_level == 0.9
# get_tensor should raise when not ready
with pytest.raises(ValueError, match="Window not ready"):
window.get_tensor()
def test_track_id_change_resets_buffer(self) -> None:
"""Changing track ID should reset the buffer."""
window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
sil = np.ones((64, 44), dtype=np.float32)
# Fill window with track_id=1
for i in range(5):
window.push(sil, frame_idx=i, track_id=1)
assert window.is_ready()
assert window.current_track_id == 1
assert window.fill_level == 1.0
# Push with different track_id should reset
window.push(sil, frame_idx=5, track_id=2)
assert not window.is_ready()
assert window.fill_level == 0.2
assert window.current_track_id == 2
def test_frame_gap_reset_behavior(self) -> None:
"""Frame gap exceeding threshold should reset buffer."""
window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=3)
sil = np.ones((64, 44), dtype=np.float32)
# Fill window with consecutive frames
for i in range(5):
window.push(sil, frame_idx=i, track_id=1)
assert window.is_ready()
assert window.fill_level == 1.0
# Small gap (within threshold) - no reset
window.push(sil, frame_idx=6, track_id=1) # gap = 1
assert window.is_ready()
assert window.fill_level == 1.0 # deque maintains max
# Reset and fill again
window.reset()
for i in range(5):
window.push(sil, frame_idx=i, track_id=1)
# Large gap (exceeds threshold) - should reset
window.push(sil, frame_idx=10, track_id=1) # gap = 5 > 3
assert not window.is_ready()
assert window.fill_level == 0.2
def test_get_tensor_shape(self) -> None:
"""get_tensor should return tensor of shape [1, 1, window_size, 64, 44]."""
window_size = 7
window = SilhouetteWindow(window_size=window_size, stride=1, gap_threshold=10)
# Push unique frames to verify ordering
for i in range(window_size):
sil = np.full((64, 44), fill_value=float(i), dtype=np.float32)
window.push(sil, frame_idx=i, track_id=1)
assert window.is_ready()
tensor = window.get_tensor(device="cpu")
assert isinstance(tensor, torch.Tensor)
assert tensor.shape == (1, 1, window_size, 64, 44)
assert tensor.dtype == torch.float32
# Verify content ordering: first frame should have value 0.0
assert tensor[0, 0, 0, 0, 0].item() == 0.0
# Last frame should have value window_size-1
assert tensor[0, 0, window_size - 1, 0, 0].item() == window_size - 1
def test_should_classify_stride_behavior(self) -> None:
"""should_classify should respect stride setting."""
window = SilhouetteWindow(window_size=5, stride=3, gap_threshold=10)
sil = np.ones((64, 44), dtype=np.float32)
# Fill window
for i in range(5):
window.push(sil, frame_idx=i, track_id=1)
assert window.is_ready()
# First classification should always trigger
assert window.should_classify()
# Mark as classified at frame 4
window.mark_classified()
# Not ready to classify yet (stride=3, only 0 frames passed)
window.push(sil, frame_idx=5, track_id=1)
assert not window.should_classify()
# Still not ready (only 1 frame passed)
window.push(sil, frame_idx=6, track_id=1)
assert not window.should_classify()
# Now ready (3 frames passed since last classification)
window.push(sil, frame_idx=7, track_id=1)
assert window.should_classify()
def test_should_classify_not_ready(self) -> None:
"""should_classify should return False when window not ready."""
window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
sil = np.ones((64, 44), dtype=np.float32)
# Push only 3 frames (not ready)
for i in range(3):
window.push(sil, frame_idx=i, track_id=1)
assert not window.is_ready()
assert not window.should_classify()
def test_reset_clears_all_state(self) -> None:
"""reset should clear all internal state."""
window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
sil = np.ones((64, 44), dtype=np.float32)
# Fill and classify
for i in range(5):
window.push(sil, frame_idx=i, track_id=1)
window.mark_classified()
assert window.is_ready()
assert window.current_track_id == 1
assert window.frame_count == 5
# Reset
window.reset()
assert not window.is_ready()
assert window.fill_level == 0.0
assert window.current_track_id is None
assert window.frame_count == 0
assert not window.should_classify()
def test_push_invalid_shape_raises(self) -> None:
"""push should raise ValueError for invalid silhouette shape."""
window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
# Wrong shape
sil_wrong = np.ones((32, 32), dtype=np.float32)
with pytest.raises(ValueError, match="Expected silhouette shape"):
window.push(sil_wrong, frame_idx=0, track_id=1)
def test_push_wrong_dtype_converts(self) -> None:
"""push should convert dtype to float32."""
window = SilhouetteWindow(window_size=1, stride=1, gap_threshold=10)
# uint8 input
sil_uint8 = np.ones((64, 44), dtype=np.uint8) * 255
window.push(sil_uint8, frame_idx=0, track_id=1)
tensor = window.get_tensor()
assert tensor.dtype == torch.float32
class TestSelectPerson:
"""Tests for select_person function."""
def _create_mock_results(
self,
boxes_xyxy: NDArray[np.float32],
masks_data: NDArray[np.float32],
track_ids: NDArray[np.int64] | None,
) -> Any:
"""Create a mock detection results object."""
mock_boxes = MagicMock()
mock_boxes.xyxy = boxes_xyxy
mock_boxes.id = track_ids
mock_masks = MagicMock()
mock_masks.data = masks_data
mock_results = MagicMock()
mock_results.boxes = mock_boxes
mock_results.masks = mock_masks
return mock_results
def test_select_person_single_detection(self) -> None:
"""Single detection should return that person's data."""
boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32) # area = 3200
masks = np.random.rand(1, 100, 100).astype(np.float32)
track_ids = np.array([42], dtype=np.int64)
results = self._create_mock_results(boxes, masks, track_ids)
result = select_person(results)
assert result is not None
mask, bbox, tid = result
assert mask.shape == (100, 100)
assert bbox == (10, 10, 50, 90)
assert tid == 42
def test_select_person_multi_detection_selects_largest(self) -> None:
"""Multiple detections should select the one with largest bbox area."""
# Two boxes: second one is larger
boxes = np.array(
[
[0.0, 0.0, 10.0, 10.0], # area = 100
[0.0, 0.0, 30.0, 30.0], # area = 900 (largest)
[0.0, 0.0, 20.0, 20.0], # area = 400
],
dtype=np.float32,
)
masks = np.random.rand(3, 100, 100).astype(np.float32)
track_ids = np.array([1, 2, 3], dtype=np.int64)
results = self._create_mock_results(boxes, masks, track_ids)
result = select_person(results)
assert result is not None
mask, bbox, tid = result
assert bbox == (0, 0, 30, 30) # Largest box
assert tid == 2 # Corresponding track ID
def test_select_person_no_detections_returns_none(self) -> None:
"""No detections should return None."""
boxes = np.array([], dtype=np.float32).reshape(0, 4)
masks = np.array([], dtype=np.float32).reshape(0, 100, 100)
track_ids = np.array([], dtype=np.int64)
results = self._create_mock_results(boxes, masks, track_ids)
result = select_person(results)
assert result is None
def test_select_person_no_track_ids_returns_none(self) -> None:
"""Detections without track IDs should return None."""
boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32)
masks = np.random.rand(1, 100, 100).astype(np.float32)
results = self._create_mock_results(boxes, masks, track_ids=None)
result = select_person(results)
assert result is None
def test_select_person_empty_track_ids_returns_none(self) -> None:
"""Empty track IDs array should return None."""
boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32)
masks = np.random.rand(1, 100, 100).astype(np.float32)
track_ids = np.array([], dtype=np.int64)
results = self._create_mock_results(boxes, masks, track_ids)
result = select_person(results)
assert result is None
def test_select_person_missing_boxes_returns_none(self) -> None:
"""Missing boxes attribute should return None."""
mock_results = MagicMock()
mock_results.boxes = None
result = select_person(mock_results)
assert result is None
def test_select_person_missing_masks_returns_none(self) -> None:
"""Missing masks attribute should return None."""
boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32)
track_ids = np.array([1], dtype=np.int64)
mock_boxes = MagicMock()
mock_boxes.xyxy = boxes
mock_boxes.id = track_ids
mock_results = MagicMock()
mock_results.boxes = mock_boxes
mock_results.masks = None
result = select_person(mock_results)
assert result is None
def test_select_person_1d_bbox_handling(self) -> None:
"""1D bbox array should be reshaped to 2D."""
boxes = np.array([10.0, 10.0, 50.0, 90.0], dtype=np.float32) # 1D
masks = np.random.rand(1, 100, 100).astype(np.float32)
track_ids = np.array([1], dtype=np.int64)
results = self._create_mock_results(boxes, masks, track_ids)
result = select_person(results)
assert result is not None
_, bbox, tid = result
assert bbox == (10, 10, 50, 90)
assert tid == 1
def test_select_person_2d_mask_handling(self) -> None:
"""2D mask should be expanded to 3D."""
boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32)
masks = np.random.rand(100, 100).astype(np.float32) # 2D
track_ids = np.array([1], dtype=np.int64)
results = self._create_mock_results(boxes, masks, track_ids)
result = select_person(results)
assert result is not None
mask, _, _ = result
# Should be 2D (extracted from expanded 3D)
assert mask.shape == (100, 100)