diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..d7cce71 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""OpenGait tests package.""" diff --git a/tests/demo/__init__.py b/tests/demo/__init__.py new file mode 100644 index 0000000..059f81e --- /dev/null +++ b/tests/demo/__init__.py @@ -0,0 +1 @@ +"""Tests for demo package.""" diff --git a/tests/demo/conftest.py b/tests/demo/conftest.py new file mode 100644 index 0000000..e4efdab --- /dev/null +++ b/tests/demo/conftest.py @@ -0,0 +1,23 @@ +"""Test fixtures for demo package.""" + +import numpy as np +import pytest +import torch + + +@pytest.fixture +def mock_frame_tensor(): + """Return a mock video frame tensor (C, H, W).""" + return torch.randn(3, 224, 224) + + +@pytest.fixture +def mock_frame_array(): + """Return a mock video frame as numpy array (H, W, C).""" + return np.random.randn(224, 224, 3).astype(np.float32) + + +@pytest.fixture +def mock_video_sequence(): + """Return a mock video sequence tensor (T, C, H, W).""" + return torch.randn(16, 3, 224, 224) diff --git a/tests/demo/test_nats.py b/tests/demo/test_nats.py new file mode 100644 index 0000000..d9b856a --- /dev/null +++ b/tests/demo/test_nats.py @@ -0,0 +1,525 @@ +"""Integration tests for NATS publisher functionality. + +Tests cover: +- NATS message receipt with JSON schema validation +- Skip behavior when Docker/NATS unavailable +- Container lifecycle management (cleanup) +- JSON schema field/type validation +""" + +from __future__ import annotations + +import json +import socket +import subprocess +import time +from typing import TYPE_CHECKING, cast + +import pytest + +if TYPE_CHECKING: + from collections.abc import Generator + + +def _find_open_port() -> int: + """Find an available TCP port on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + sock.listen(1) + addr = cast(tuple[str, int], sock.getsockname()) + port: int = addr[1] + return port + + +# Constants for test configuration +NATS_SUBJECT = "scoliosis.result" +CONTAINER_NAME = "opengait-nats-test" + + +def _docker_available() -> bool: + """Check if Docker is available and running.""" + try: + result = subprocess.run( + ["docker", "info"], + capture_output=True, + timeout=5, + check=False, + ) + return result.returncode == 0 + except (subprocess.TimeoutExpired, FileNotFoundError, OSError): + return False + + +def _nats_container_running() -> bool: + """Check if NATS container is already running.""" + try: + result = subprocess.run( + [ + "docker", + "ps", + "--filter", + f"name={CONTAINER_NAME}", + "--format", + "{{.Names}}", + ], + capture_output=True, + timeout=5, + check=False, + ) + return CONTAINER_NAME in result.stdout.decode() + except (subprocess.TimeoutExpired, FileNotFoundError, OSError): + return False + + +def _start_nats_container(port: int) -> bool: + """Start NATS container for testing.""" + try: + # Remove existing container if present + _ = subprocess.run( + ["docker", "rm", "-f", CONTAINER_NAME], + capture_output=True, + timeout=10, + check=False, + ) + + # Start new NATS container + result = subprocess.run( + [ + "docker", + "run", + "-d", + "--name", + CONTAINER_NAME, + "-p", + f"{port}:{port}", + "nats:latest", + ], + capture_output=True, + timeout=30, + check=False, + ) + + if result.returncode != 0: + return False + + # Wait for NATS to be ready (max 10 seconds) + for _ in range(20): + time.sleep(0.5) + try: + check_result = subprocess.run( + [ + "docker", + "exec", + CONTAINER_NAME, + "nats", + "server", + "check", + "server", + ], + capture_output=True, + timeout=5, + check=False, + ) + if check_result.returncode == 0: + return True + except (subprocess.TimeoutExpired, FileNotFoundError, OSError): + continue + + return False + except (subprocess.TimeoutExpired, FileNotFoundError, OSError): + return False + + +def _stop_nats_container() -> None: + """Stop and remove NATS container.""" + try: + _ = subprocess.run( + ["docker", "rm", "-f", CONTAINER_NAME], + capture_output=True, + timeout=10, + check=False, + ) + except (subprocess.TimeoutExpired, FileNotFoundError, OSError): + pass + + +def _validate_result_schema(data: dict[str, object]) -> tuple[bool, str]: + """Validate JSON result schema. + + Expected schema: + { + "frame": int, + "track_id": int, + "label": str (one of: "negative", "neutral", "positive"), + "confidence": float in [0, 1], + "window": list[int] (start, end), + "timestamp_ns": int + } + """ + required_fields: set[str] = { + "frame", + "track_id", + "label", + "confidence", + "window", + "timestamp_ns", + } + missing = required_fields - set(data.keys()) + if missing: + return False, f"Missing required fields: {missing}" + + # Validate frame (int) + frame = data["frame"] + if not isinstance(frame, int): + return False, f"frame must be int, got {type(frame)}" + + # Validate track_id (int) + track_id = data["track_id"] + if not isinstance(track_id, int): + return False, f"track_id must be int, got {type(track_id)}" + + # Validate label (str in specific set) + label = data["label"] + valid_labels = {"negative", "neutral", "positive"} + if not isinstance(label, str) or label not in valid_labels: + return False, f"label must be one of {valid_labels}, got {label}" + + # Validate confidence (float in [0, 1]) + confidence = data["confidence"] + if not isinstance(confidence, (int, float)): + return False, f"confidence must be numeric, got {type(confidence)}" + if not 0.0 <= float(confidence) <= 1.0: + return False, f"confidence must be in [0, 1], got {confidence}" + + # Validate window (list of 2 ints) + window = data["window"] + if not isinstance(window, list) or len(cast(list[object], window)) != 2: + return False, f"window must be list of 2 ints, got {window}" + window_list = cast(list[object], window) + if not all(isinstance(x, int) for x in window_list): + return False, f"window elements must be ints, got {window}" + + # Validate timestamp_ns (int) + timestamp_ns = data["timestamp_ns"] + if not isinstance(timestamp_ns, int): + return False, f"timestamp_ns must be int, got {type(timestamp_ns)}" + + return True, "" + + +@pytest.fixture(scope="module") +def nats_server() -> "Generator[tuple[bool, int], None, None]": + """Fixture to manage NATS container lifecycle. + + Yields (True, port) if NATS is available, (False, 0) otherwise. + Cleans up container after tests. + """ + if not _docker_available(): + yield False, 0 + return + + port = _find_open_port() + if not _start_nats_container(port): + yield False, 0 + return + + try: + yield True, port + finally: + _stop_nats_container() + + +class TestNatsPublisherIntegration: + """Integration tests for NATS publisher with live server.""" + + @pytest.mark.skipif(not _docker_available(), reason="Docker not available") + def test_nats_message_receipt_and_schema_validation( + self, nats_server: tuple[bool, int] + ) -> None: + """Test that messages are received and match expected schema.""" + server_available, port = nats_server + if not server_available: + pytest.skip("NATS server not available") + + nats_url = f"nats://127.0.0.1:{port}" + + # Import here to avoid issues if nats-py not installed + try: + import asyncio + + import nats # type: ignore[import-untyped] + + except ImportError: + pytest.skip("nats-py not installed") + + from opengait.demo.output import NatsPublisher, create_result + + # Create publisher + publisher = NatsPublisher(nats_url, subject=NATS_SUBJECT) + + # Create test results + test_results = [ + create_result( + frame=100, + track_id=1, + label="positive", + confidence=0.85, + window=(70, 100), + timestamp_ns=1234567890000, + ), + create_result( + frame=130, + track_id=1, + label="negative", + confidence=0.92, + window=(100, 130), + timestamp_ns=1234567890030, + ), + create_result( + frame=160, + track_id=2, + label="neutral", + confidence=0.78, + window=(130, 160), + timestamp_ns=1234567890060, + ), + ] + + # Collect received messages + received_messages: list[dict[str, object]] = [] + + async def subscribe_and_publish(): + """Subscribe to subject, publish messages, collect results.""" + nc = await nats.connect(nats_url) # pyright: ignore[reportUnknownMemberType] + + # Subscribe + sub = await nc.subscribe(NATS_SUBJECT) # pyright: ignore[reportUnknownMemberType] + + # Publish all test results + for result in test_results: + publisher.publish(result) + + # Wait for messages with timeout + for _ in range(len(test_results)): + try: + msg = await asyncio.wait_for(sub.next_msg(), timeout=5.0) + data = json.loads(msg.data.decode("utf-8")) # pyright: ignore[reportAny] + received_messages.append(cast(dict[str, object], data)) + except asyncio.TimeoutError: + break + + await sub.unsubscribe() + await nc.close() + + # Run async subscriber + asyncio.run(subscribe_and_publish()) + + # Cleanup publisher + publisher.close() + + # Verify all messages received + assert len(received_messages) == len(test_results), ( + f"Expected {len(test_results)} messages, got {len(received_messages)}" + ) + + # Validate schema for each message + for i, msg in enumerate(received_messages): + is_valid, error = _validate_result_schema(msg) + assert is_valid, f"Message {i} schema validation failed: {error}" + + # Verify specific values + assert received_messages[0]["frame"] == 100 + assert received_messages[0]["label"] == "positive" + assert received_messages[0]["track_id"] == 1 + _conf = received_messages[0]["confidence"] + assert isinstance(_conf, (int, float)) + assert 0.0 <= float(_conf) <= 1.0 + + assert received_messages[1]["label"] == "negative" + assert received_messages[2]["track_id"] == 2 + + @pytest.mark.skipif(not _docker_available(), reason="Docker not available") + def test_nats_publisher_graceful_when_server_unavailable(self) -> None: + """Test that publisher handles missing server gracefully.""" + try: + from opengait.demo.output import NatsPublisher + except ImportError: + pytest.skip("output module not available") + + # Use wrong port where no server is running + bad_url = "nats://127.0.0.1:14222" + publisher = NatsPublisher(bad_url, subject=NATS_SUBJECT) + + # Should not raise when publishing without server + test_result: dict[str, object] = { + "frame": 1, + "track_id": 1, + "label": "positive", + "confidence": 0.85, + "window": [0, 30], + "timestamp_ns": 1234567890, + } + + # Should not raise + publisher.publish(test_result) + + # Cleanup should also not raise + publisher.close() + + @pytest.mark.skipif(not _docker_available(), reason="Docker not available") + def test_nats_publisher_context_manager( + self, nats_server: tuple[bool, int] + ) -> None: + """Test that publisher works as context manager.""" + server_available, port = nats_server + if not server_available: + pytest.skip("NATS server not available") + + nats_url = f"nats://127.0.0.1:{port}" + + try: + import asyncio + + import nats # type: ignore[import-untyped] + from opengait.demo.output import NatsPublisher, create_result + except ImportError as e: + pytest.skip(f"Required module not available: {e}") + + received_messages: list[dict[str, object]] = [] + + async def subscribe_and_test(): + nc = await nats.connect(nats_url) # pyright: ignore[reportUnknownMemberType] + sub = await nc.subscribe(NATS_SUBJECT) # pyright: ignore[reportUnknownMemberType] + + # Use context manager + with NatsPublisher(nats_url, subject=NATS_SUBJECT) as publisher: + result = create_result( + frame=200, + track_id=5, + label="neutral", + confidence=0.65, + window=(170, 200), + timestamp_ns=9999999999, + ) + publisher.publish(result) + + # Wait for message + try: + msg = await asyncio.wait_for(sub.next_msg(), timeout=5.0) + data = json.loads(msg.data.decode("utf-8")) # pyright: ignore[reportAny] + received_messages.append(cast(dict[str, object], data)) + except asyncio.TimeoutError: + pass + + await sub.unsubscribe() + await nc.close() + + asyncio.run(subscribe_and_test()) + + assert len(received_messages) == 1 + assert received_messages[0]["frame"] == 200 + assert received_messages[0]["track_id"] == 5 + + +class TestNatsSchemaValidation: + """Tests for JSON schema validation without requiring NATS server.""" + + def test_validate_result_schema_valid(self) -> None: + """Test schema validation with valid data.""" + valid_data: dict[str, object] = { + "frame": 1234, + "track_id": 42, + "label": "positive", + "confidence": 0.85, + "window": [1200, 1230], + "timestamp_ns": 1234567890000, + } + + is_valid, error = _validate_result_schema(valid_data) + assert is_valid, f"Valid data rejected: {error}" + + def test_validate_result_schema_invalid_label(self) -> None: + """Test schema validation rejects invalid label.""" + invalid_data: dict[str, object] = { + "frame": 1234, + "track_id": 42, + "label": "invalid_label", + "confidence": 0.85, + "window": [1200, 1230], + "timestamp_ns": 1234567890000, + } + + is_valid, error = _validate_result_schema(invalid_data) + assert not is_valid + assert "label" in error.lower() + + def test_validate_result_schema_confidence_out_of_range(self) -> None: + """Test schema validation rejects confidence outside [0, 1].""" + invalid_data: dict[str, object] = { + "frame": 1234, + "track_id": 42, + "label": "positive", + "confidence": 1.5, + "window": [1200, 1230], + "timestamp_ns": 1234567890000, + } + + is_valid, error = _validate_result_schema(invalid_data) + assert not is_valid + assert "confidence" in error.lower() + + def test_validate_result_schema_missing_fields(self) -> None: + """Test schema validation detects missing fields.""" + incomplete_data: dict[str, object] = { + "frame": 1234, + "label": "positive", + } + + is_valid, error = _validate_result_schema(incomplete_data) + assert not is_valid + assert "missing" in error.lower() + + def test_validate_result_schema_wrong_types(self) -> None: + """Test schema validation rejects wrong types.""" + wrong_types: dict[str, object] = { + "frame": "not_an_int", + "track_id": 42, + "label": "positive", + "confidence": 0.85, + "window": [1200, 1230], + "timestamp_ns": 1234567890000, + } + + is_valid, error = _validate_result_schema(wrong_types) + assert not is_valid + assert "frame" in error.lower() + + def test_all_valid_labels_accepted(self) -> None: + """Test that all valid labels are accepted.""" + for label_str in ["negative", "neutral", "positive"]: + data: dict[str, object] = { + "frame": 100, + "track_id": 1, + "label": label_str, + "confidence": 0.5, + "window": [70, 100], + "timestamp_ns": 1234567890, + } + is_valid, error = _validate_result_schema(data) + assert is_valid, f"Valid label '{label_str}' rejected: {error}" + + +class TestDockerAvailability: + """Tests for Docker availability detection.""" + + def test_docker_available_check(self) -> None: + """Test Docker availability check doesn't crash.""" + # This should not raise + result = _docker_available() + assert isinstance(result, bool) + + def test_nats_container_running_check(self) -> None: + """Test container running check doesn't crash.""" + # This should not raise even if Docker not available + result = _nats_container_running() + assert isinstance(result, bool) diff --git a/tests/demo/test_pipeline.py b/tests/demo/test_pipeline.py new file mode 100644 index 0000000..5bc1fe4 --- /dev/null +++ b/tests/demo/test_pipeline.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +import json +from pathlib import Path +import subprocess +import sys +import time +from typing import Final, cast + +import pytest +import torch + +from opengait.demo.sconet_demo import ScoNetDemo + +REPO_ROOT: Final[Path] = Path(__file__).resolve().parents[2] +SAMPLE_VIDEO_PATH: Final[Path] = REPO_ROOT / "assets" / "sample.mp4" +CHECKPOINT_PATH: Final[Path] = REPO_ROOT / "ckpt" / "ScoNet-20000.pt" +CONFIG_PATH: Final[Path] = REPO_ROOT / "configs" / "sconet" / "sconet_scoliosis1k.yaml" +YOLO_MODEL_PATH: Final[Path] = REPO_ROOT / "yolo11n-seg.pt" + + +def _device_for_runtime() -> str: + return "cuda:0" if torch.cuda.is_available() else "cpu" + + +def _run_pipeline_cli( + *args: str, timeout_seconds: int = 120 +) -> subprocess.CompletedProcess[str]: + command = [sys.executable, "-m", "opengait.demo", *args] + return subprocess.run( + command, + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + timeout=timeout_seconds, + ) + + +def _require_integration_assets() -> None: + if not SAMPLE_VIDEO_PATH.is_file(): + pytest.skip(f"Missing sample video: {SAMPLE_VIDEO_PATH}") + if not CONFIG_PATH.is_file(): + pytest.skip(f"Missing config: {CONFIG_PATH}") + if not YOLO_MODEL_PATH.is_file(): + pytest.skip(f"Missing YOLO model file: {YOLO_MODEL_PATH}") + + +@pytest.fixture +def compatible_checkpoint_path(tmp_path: Path) -> Path: + if not CONFIG_PATH.is_file(): + pytest.skip(f"Missing config: {CONFIG_PATH}") + + checkpoint_file = tmp_path / "sconet-compatible.pt" + model = ScoNetDemo(cfg_path=str(CONFIG_PATH), checkpoint_path=None, device="cpu") + torch.save(model.state_dict(), checkpoint_file) + return checkpoint_file + + +def _extract_prediction_json_lines(stdout: str) -> list[dict[str, object]]: + required_keys = { + "frame", + "track_id", + "label", + "confidence", + "window", + "timestamp_ns", + } + predictions: list[dict[str, object]] = [] + + for line in stdout.splitlines(): + stripped = line.strip() + if not stripped: + continue + try: + payload_obj = cast(object, json.loads(stripped)) + except json.JSONDecodeError: + continue + + if not isinstance(payload_obj, dict): + continue + payload = cast(dict[str, object], payload_obj) + if required_keys.issubset(payload.keys()): + predictions.append(payload) + + return predictions + + +def _assert_prediction_schema(prediction: dict[str, object]) -> None: + assert isinstance(prediction["frame"], int) + assert isinstance(prediction["track_id"], int) + + label = prediction["label"] + assert isinstance(label, str) + assert label in {"negative", "neutral", "positive"} + + confidence = prediction["confidence"] + assert isinstance(confidence, (int, float)) + confidence_value = float(confidence) + assert 0.0 <= confidence_value <= 1.0 + + window_obj = prediction["window"] + assert isinstance(window_obj, int) + assert window_obj >= 0 + + assert isinstance(prediction["timestamp_ns"], int) + + +def test_pipeline_cli_fps_benchmark_smoke( + compatible_checkpoint_path: Path, +) -> None: + _require_integration_assets() + + max_frames = 90 + started_at = time.perf_counter() + result = _run_pipeline_cli( + "--source", + str(SAMPLE_VIDEO_PATH), + "--checkpoint", + str(compatible_checkpoint_path), + "--config", + str(CONFIG_PATH), + "--device", + _device_for_runtime(), + "--yolo-model", + str(YOLO_MODEL_PATH), + "--window", + "5", + "--stride", + "1", + "--max-frames", + str(max_frames), + timeout_seconds=180, + ) + elapsed_seconds = time.perf_counter() - started_at + + assert result.returncode == 0, ( + f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}" + ) + predictions = _extract_prediction_json_lines(result.stdout) + assert predictions, "Expected prediction output for FPS benchmark run" + + for prediction in predictions: + _assert_prediction_schema(prediction) + + observed_frames = { + frame_obj + for prediction in predictions + for frame_obj in [prediction["frame"]] + if isinstance(frame_obj, int) + } + observed_units = len(observed_frames) + if observed_units < 5: + pytest.skip( + "Insufficient observed frame samples for stable FPS benchmark in this environment" + ) + if elapsed_seconds <= 0: + pytest.skip("Non-positive elapsed time; cannot compute FPS benchmark") + + fps = observed_units / elapsed_seconds + min_expected_fps = 0.2 + assert fps >= min_expected_fps, ( + "Observed FPS below conservative CI threshold: " + f"{fps:.3f} < {min_expected_fps:.3f} " + f"(observed_units={observed_units}, elapsed_seconds={elapsed_seconds:.3f})" + ) + + +def test_pipeline_cli_happy_path_outputs_json_predictions( + compatible_checkpoint_path: Path, +) -> None: + _require_integration_assets() + + result = _run_pipeline_cli( + "--source", + str(SAMPLE_VIDEO_PATH), + "--checkpoint", + str(compatible_checkpoint_path), + "--config", + str(CONFIG_PATH), + "--device", + _device_for_runtime(), + "--yolo-model", + str(YOLO_MODEL_PATH), + "--window", + "10", + "--stride", + "10", + "--max-frames", + "120", + timeout_seconds=180, + ) + + assert result.returncode == 0, ( + f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}" + ) + predictions = _extract_prediction_json_lines(result.stdout) + assert predictions, ( + "Expected at least one prediction JSON line in stdout. " + f"stdout:\n{result.stdout}\nstderr:\n{result.stderr}" + ) + for prediction in predictions: + _assert_prediction_schema(prediction) + + assert "Connected to NATS" not in result.stderr + + +def test_pipeline_cli_max_frames_caps_output_frames( + compatible_checkpoint_path: Path, +) -> None: + _require_integration_assets() + + max_frames = 20 + result = _run_pipeline_cli( + "--source", + str(SAMPLE_VIDEO_PATH), + "--checkpoint", + str(compatible_checkpoint_path), + "--config", + str(CONFIG_PATH), + "--device", + _device_for_runtime(), + "--yolo-model", + str(YOLO_MODEL_PATH), + "--window", + "5", + "--stride", + "1", + "--max-frames", + str(max_frames), + timeout_seconds=180, + ) + + assert result.returncode == 0, ( + f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}" + ) + predictions = _extract_prediction_json_lines(result.stdout) + assert predictions, "Expected prediction output with --max-frames run" + + for prediction in predictions: + _assert_prediction_schema(prediction) + frame_idx_obj = prediction["frame"] + assert isinstance(frame_idx_obj, int) + assert frame_idx_obj < max_frames + + +def test_pipeline_cli_invalid_source_path_returns_user_error() -> None: + result = _run_pipeline_cli( + "--source", + "/definitely/not/a/real/video.mp4", + "--checkpoint", + "/tmp/unused-checkpoint.pt", + "--config", + str(CONFIG_PATH), + timeout_seconds=30, + ) + + assert result.returncode == 2 + assert "Error: Video source not found" in result.stderr + + +def test_pipeline_cli_invalid_checkpoint_path_returns_user_error() -> None: + if not SAMPLE_VIDEO_PATH.is_file(): + pytest.skip(f"Missing sample video: {SAMPLE_VIDEO_PATH}") + if not CONFIG_PATH.is_file(): + pytest.skip(f"Missing config: {CONFIG_PATH}") + + result = _run_pipeline_cli( + "--source", + str(SAMPLE_VIDEO_PATH), + "--checkpoint", + str(REPO_ROOT / "ckpt" / "missing-checkpoint.pt"), + "--config", + str(CONFIG_PATH), + timeout_seconds=30, + ) + + assert result.returncode == 2 + assert "Error: Checkpoint not found" in result.stderr diff --git a/tests/demo/test_preprocess.py b/tests/demo/test_preprocess.py new file mode 100644 index 0000000..9d58c39 --- /dev/null +++ b/tests/demo/test_preprocess.py @@ -0,0 +1,153 @@ +"""Unit tests for silhouette preprocessing functions.""" + +from typing import cast + +import numpy as np +from numpy.typing import NDArray +import pytest +from beartype.roar import BeartypeCallHintParamViolation +from jaxtyping import TypeCheckError + +from opengait.demo.preprocess import mask_to_silhouette + + +class TestMaskToSilhouette: + """Tests for mask_to_silhouette() function.""" + + def test_valid_mask_returns_correct_shape_dtype_and_range(self) -> None: + """Valid mask should return (64, 44) float32 array in [0, 1] range.""" + # Create a synthetic mask with sufficient area (person-shaped blob) + h, w = 200, 150 + mask = np.zeros((h, w), dtype=np.uint8) + # Draw a filled ellipse to simulate a person + center_y, center_x = h // 2, w // 2 + axes_y, axes_x = h // 3, w // 4 + y, x = np.ogrid[:h, :w] + ellipse_mask = ((x - center_x) / axes_x) ** 2 + ( + (y - center_y) / axes_y + ) ** 2 <= 1 + mask[ellipse_mask] = 255 + + bbox = (w // 4, h // 6, 3 * w // 4, 5 * h // 6) + + result = mask_to_silhouette(mask, bbox) + + assert result is not None + result_arr = cast(NDArray[np.float32], result) + assert result_arr.shape == (64, 44) + assert result_arr.dtype == np.float32 + assert np.all(result_arr >= 0.0) and np.all(result_arr <= 1.0) + + def test_tiny_mask_returns_none(self) -> None: + """Mask with area below MIN_MASK_AREA should return None.""" + # Create a tiny mask + mask = np.zeros((50, 50), dtype=np.uint8) + mask[20:22, 20:22] = 255 # Only 4 pixels, well below MIN_MASK_AREA (500) + + bbox = (20, 20, 22, 22) + + result = mask_to_silhouette(mask, bbox) + + assert result is None + + def test_empty_mask_returns_none(self) -> None: + """Completely empty mask should return None.""" + mask = np.zeros((100, 100), dtype=np.uint8) + bbox = (10, 10, 90, 90) + + result = mask_to_silhouette(mask, bbox) + + assert result is None + + def test_full_frame_mask_returns_valid_output(self) -> None: + """Full-frame mask (large bbox covering entire image) should work.""" + h, w = 300, 200 + mask = np.zeros((h, w), dtype=np.uint8) + # Create a large filled region + mask[50:250, 30:170] = 255 + + # Full frame bbox + bbox = (0, 0, w, h) + + result = mask_to_silhouette(mask, bbox) + + assert result is not None + result_arr = cast(NDArray[np.float32], result) + assert result_arr.shape == (64, 44) + assert result_arr.dtype == np.float32 + + def test_determinism_same_input_same_output(self) -> None: + """Same input should always produce same output.""" + h, w = 200, 150 + mask = np.zeros((h, w), dtype=np.uint8) + # Create a person-shaped region + mask[50:150, 40:110] = 255 + + bbox = (40, 50, 110, 150) + + result1 = mask_to_silhouette(mask, bbox) + result2 = mask_to_silhouette(mask, bbox) + result3 = mask_to_silhouette(mask, bbox) + + assert result1 is not None + assert result2 is not None + assert result3 is not None + np.testing.assert_array_equal(result1, result2) + np.testing.assert_array_equal(result2, result3) + + def test_tall_narrow_mask_valid_output(self) -> None: + """Tall narrow mask should produce valid silhouette.""" + h, w = 400, 50 + mask = np.zeros((h, w), dtype=np.uint8) + # Tall narrow person + mask[50:350, 10:40] = 255 + + bbox = (10, 50, 40, 350) + + result = mask_to_silhouette(mask, bbox) + + assert result is not None + result_arr = cast(NDArray[np.float32], result) + assert result_arr.shape == (64, 44) + + def test_wide_short_mask_valid_output(self) -> None: + """Wide short mask should produce valid silhouette.""" + h, w = 100, 400 + mask = np.zeros((h, w), dtype=np.uint8) + # Wide short person + mask[20:80, 50:350] = 255 + + bbox = (50, 20, 350, 80) + + result = mask_to_silhouette(mask, bbox) + + assert result is not None + result_arr = cast(NDArray[np.float32], result) + assert result_arr.shape == (64, 44) + + def test_beartype_rejects_wrong_dtype(self) -> None: + """Beartype should reject non-uint8 input.""" + # Float array instead of uint8 + mask = np.ones((100, 100), dtype=np.float32) * 255 + bbox = (10, 10, 90, 90) + + with pytest.raises((BeartypeCallHintParamViolation, TypeCheckError)): + _ = mask_to_silhouette(mask, bbox) + + def test_beartype_rejects_wrong_ndim(self) -> None: + """Beartype should reject non-2D array.""" + # 3D array instead of 2D + mask = np.ones((100, 100, 3), dtype=np.uint8) * 255 + bbox = (10, 10, 90, 90) + + with pytest.raises((BeartypeCallHintParamViolation, TypeCheckError)): + _ = mask_to_silhouette(mask, bbox) + + def test_beartype_rejects_wrong_bbox_type(self) -> None: + """Beartype should reject non-tuple bbox.""" + mask = np.ones((100, 100), dtype=np.uint8) * 255 + # List instead of tuple + bbox = [10, 10, 90, 90] + + with pytest.raises((BeartypeCallHintParamViolation, TypeCheckError)): + _ = mask_to_silhouette(mask, bbox) diff --git a/tests/demo/test_sconet_demo.py b/tests/demo/test_sconet_demo.py new file mode 100644 index 0000000..613c165 --- /dev/null +++ b/tests/demo/test_sconet_demo.py @@ -0,0 +1,341 @@ +"""Unit tests for ScoNetDemo forward pass. + +Tests cover: +- Construction from config/checkpoint path handling +- Forward output shape (N, 3, 16) and dtype float +- Predict output (label_str, confidence_float) with valid label/range +- No-DDP leakage check (no torch.distributed calls in unit behavior) +""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, cast +from unittest.mock import patch + +import pytest +import torch +from torch import Tensor + +if TYPE_CHECKING: + from opengait.demo.sconet_demo import ScoNetDemo + +# Constants for test configuration +CONFIG_PATH = Path("configs/sconet/sconet_scoliosis1k.yaml") + + +@pytest.fixture +def demo() -> "ScoNetDemo": + """Create ScoNetDemo without loading checkpoint (CPU-only).""" + from opengait.demo.sconet_demo import ScoNetDemo + + return ScoNetDemo( + cfg_path=str(CONFIG_PATH), + checkpoint_path=None, + device="cpu", + ) + + +@pytest.fixture +def dummy_sils_batch() -> Tensor: + """Return dummy silhouette tensor of shape (N, 1, S, 64, 44).""" + return torch.randn(2, 1, 30, 64, 44) + + +@pytest.fixture +def dummy_sils_single() -> Tensor: + """Return dummy silhouette tensor of shape (1, 1, S, 64, 44) for predict.""" + return torch.randn(1, 1, 30, 64, 44) + + +@pytest.fixture +def synthetic_state_dict() -> dict[str, Tensor]: + """Return a synthetic state dict compatible with ScoNetDemo structure.""" + return { + "backbone.conv1.conv.weight": torch.randn(64, 1, 3, 3), + "backbone.conv1.bn.weight": torch.ones(64), + "backbone.conv1.bn.bias": torch.zeros(64), + "backbone.conv1.bn.running_mean": torch.zeros(64), + "backbone.conv1.bn.running_var": torch.ones(64), + "fcs.fc_bin": torch.randn(16, 512, 256), + "bn_necks.fc_bin": torch.randn(16, 256, 3), + "bn_necks.bn1d.weight": torch.ones(4096), + "bn_necks.bn1d.bias": torch.zeros(4096), + "bn_necks.bn1d.running_mean": torch.zeros(4096), + "bn_necks.bn1d.running_var": torch.ones(4096), + } + + +class TestScoNetDemoConstruction: + """Tests for ScoNetDemo construction and path handling.""" + + def test_construction_from_config_no_checkpoint(self) -> None: + """Test construction with config only, no checkpoint.""" + from opengait.demo.sconet_demo import ScoNetDemo + + demo = ScoNetDemo( + cfg_path=str(CONFIG_PATH), + checkpoint_path=None, + device="cpu", + ) + + assert demo.cfg_path.endswith("sconet_scoliosis1k.yaml") + assert demo.device == torch.device("cpu") + assert demo.cfg is not None + assert "model_cfg" in demo.cfg + assert demo.training is False # eval mode + + def test_construction_with_relative_path(self) -> None: + """Test construction handles relative config path correctly.""" + from opengait.demo.sconet_demo import ScoNetDemo + + demo = ScoNetDemo( + cfg_path="configs/sconet/sconet_scoliosis1k.yaml", + checkpoint_path=None, + device="cpu", + ) + + assert demo.cfg is not None + assert demo.backbone is not None + + def test_construction_invalid_config_raises(self) -> None: + """Test construction raises with invalid config path.""" + from opengait.demo.sconet_demo import ScoNetDemo + + with pytest.raises((FileNotFoundError, TypeError)): + _ = ScoNetDemo( + cfg_path="/nonexistent/path/config.yaml", + checkpoint_path=None, + device="cpu", + ) + + +class TestScoNetDemoForward: + """Tests for ScoNetDemo forward pass.""" + + def test_forward_output_shape_and_dtype( + self, demo: "ScoNetDemo", dummy_sils_batch: Tensor + ) -> None: + """Test forward returns logits with shape (N, 3, 16) and correct dtypes.""" + outputs_raw = demo.forward(dummy_sils_batch) + outputs = cast(dict[str, Tensor], outputs_raw) + + assert "logits" in outputs + logits = outputs["logits"] + + # Expected shape: (batch_size, num_classes, parts_num) = (N, 3, 16) + assert logits.shape == (2, 3, 16) + assert logits.dtype == torch.float32 + assert outputs["label"].dtype == torch.int64 + assert outputs["confidence"].dtype == torch.float32 + + def test_forward_returns_required_keys( + self, demo: "ScoNetDemo", dummy_sils_batch: Tensor + ) -> None: + """Test forward returns required output keys.""" + outputs_raw = demo.forward(dummy_sils_batch) + outputs = cast(dict[str, Tensor], outputs_raw) + + required_keys = {"logits", "label", "confidence"} + assert set(outputs.keys()) >= required_keys + + def test_forward_batch_size_one( + self, demo: "ScoNetDemo", dummy_sils_single: Tensor + ) -> None: + """Test forward works with batch size 1.""" + outputs_raw = demo.forward(dummy_sils_single) + outputs = cast(dict[str, Tensor], outputs_raw) + + assert outputs["logits"].shape == (1, 3, 16) + assert outputs["label"].shape == (1,) + assert outputs["confidence"].shape == (1,) + + def test_forward_label_range( + self, demo: "ScoNetDemo", dummy_sils_batch: Tensor + ) -> None: + """Test forward returns valid label indices (0, 1, or 2).""" + outputs_raw = demo.forward(dummy_sils_batch) + outputs = cast(dict[str, Tensor], outputs_raw) + + labels = outputs["label"] + assert torch.all(labels >= 0) + assert torch.all(labels <= 2) + + def test_forward_confidence_range( + self, demo: "ScoNetDemo", dummy_sils_batch: Tensor + ) -> None: + """Test forward returns confidence in [0, 1].""" + outputs_raw = demo.forward(dummy_sils_batch) + outputs = cast(dict[str, Tensor], outputs_raw) + + confidence = outputs["confidence"] + assert torch.all(confidence >= 0.0) + assert torch.all(confidence <= 1.0) + + +class TestScoNetDemoPredict: + """Tests for ScoNetDemo predict method.""" + + def test_predict_returns_tuple_with_valid_types( + self, demo: "ScoNetDemo", dummy_sils_single: Tensor + ) -> None: + """Test predict returns (str, float) tuple with valid label.""" + from opengait.demo.sconet_demo import ScoNetDemo + + result_raw = demo.predict(dummy_sils_single) + result = cast(tuple[str, float], result_raw) + + assert isinstance(result, tuple) + assert len(result) == 2 + label, confidence = result + assert isinstance(label, str) + assert isinstance(confidence, float) + + valid_labels = set(ScoNetDemo.LABEL_MAP.values()) + assert label in valid_labels + + def test_predict_confidence_range( + self, demo: "ScoNetDemo", dummy_sils_single: Tensor + ) -> None: + """Test predict returns confidence in valid range [0, 1].""" + result_raw = demo.predict(dummy_sils_single) + result = cast(tuple[str, float], result_raw) + confidence = result[1] + + assert 0.0 <= confidence <= 1.0 + + def test_predict_rejects_batch_size_greater_than_one( + self, demo: "ScoNetDemo", dummy_sils_batch: Tensor + ) -> None: + """Test predict raises ValueError for batch size > 1.""" + with pytest.raises(ValueError, match="batch size 1"): + _ = demo.predict(dummy_sils_batch) + + +class TestScoNetDemoNoDDP: + """Tests to verify no DDP leakage in unit behavior.""" + + def test_no_distributed_init_in_construction(self) -> None: + """Test that construction does not call torch.distributed.""" + from opengait.demo.sconet_demo import ScoNetDemo + + with patch("torch.distributed.is_initialized") as mock_is_init: + with patch("torch.distributed.init_process_group") as mock_init_pg: + _ = ScoNetDemo( + cfg_path=str(CONFIG_PATH), + checkpoint_path=None, + device="cpu", + ) + + mock_init_pg.assert_not_called() + mock_is_init.assert_not_called() + + def test_forward_no_distributed_calls( + self, demo: "ScoNetDemo", dummy_sils_batch: Tensor + ) -> None: + """Test forward pass does not call torch.distributed.""" + with patch("torch.distributed.all_reduce") as mock_all_reduce: + with patch("torch.distributed.broadcast") as mock_broadcast: + _ = demo.forward(dummy_sils_batch) + + mock_all_reduce.assert_not_called() + mock_broadcast.assert_not_called() + + def test_predict_no_distributed_calls( + self, demo: "ScoNetDemo", dummy_sils_single: Tensor + ) -> None: + """Test predict does not call torch.distributed.""" + with patch("torch.distributed.all_reduce") as mock_all_reduce: + with patch("torch.distributed.broadcast") as mock_broadcast: + _ = demo.predict(dummy_sils_single) + + mock_all_reduce.assert_not_called() + mock_broadcast.assert_not_called() + + def test_model_not_wrapped_in_ddp(self, demo: "ScoNetDemo") -> None: + """Test model is not wrapped in DistributedDataParallel.""" + from torch.nn.parallel import DistributedDataParallel as DDP + + assert not isinstance(demo, DDP) + assert not isinstance(demo.backbone, DDP) + + def test_device_is_cpu(self, demo: "ScoNetDemo") -> None: + """Test model stays on CPU when device="cpu" specified.""" + assert demo.device.type == "cpu" + + for param in demo.parameters(): + assert param.device.type == "cpu" + + +class TestScoNetDemoCheckpointLoading: + """Tests for checkpoint loading behavior using synthetic state dict.""" + + def test_load_checkpoint_changes_weights( + self, + demo: "ScoNetDemo", + synthetic_state_dict: dict[str, Tensor], + ) -> None: + """Test loading checkpoint actually changes model weights.""" + import tempfile + import os + + # Get initial weight + initial_weight = next(iter(demo.parameters())).clone() + + # Create temp checkpoint file with synthetic state dict + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + torch.save(synthetic_state_dict, f.name) + temp_path = f.name + + try: + _ = demo.load_checkpoint(temp_path, strict=False) + new_weight = next(iter(demo.parameters())) + assert not torch.equal(initial_weight, new_weight) + finally: + os.unlink(temp_path) + + def test_load_checkpoint_sets_eval_mode( + self, + demo: "ScoNetDemo", + synthetic_state_dict: dict[str, Tensor], + ) -> None: + """Test loading checkpoint sets model to eval mode.""" + import tempfile + import os + + _ = demo.train() + assert demo.training is True + + with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f: + torch.save(synthetic_state_dict, f.name) + temp_path = f.name + + try: + _ = demo.load_checkpoint(temp_path, strict=False) + assert demo.training is False + finally: + os.unlink(temp_path) + + def test_load_checkpoint_invalid_path_raises(self, demo: "ScoNetDemo") -> None: + """Test loading from invalid checkpoint path raises error.""" + with pytest.raises(FileNotFoundError): + _ = demo.load_checkpoint("/nonexistent/checkpoint.pt") + + +class TestScoNetDemoLabelMap: + """Tests for LABEL_MAP constant.""" + + def test_label_map_has_three_classes(self) -> None: + """Test LABEL_MAP has exactly 3 classes.""" + from opengait.demo.sconet_demo import ScoNetDemo + + assert len(ScoNetDemo.LABEL_MAP) == 3 + assert set(ScoNetDemo.LABEL_MAP.keys()) == {0, 1, 2} + + def test_label_map_values_are_valid_strings(self) -> None: + """Test LABEL_MAP values are valid non-empty strings.""" + from opengait.demo.sconet_demo import ScoNetDemo + + for value in ScoNetDemo.LABEL_MAP.values(): + assert isinstance(value, str) + assert len(value) > 0 diff --git a/tests/demo/test_window.py b/tests/demo/test_window.py new file mode 100644 index 0000000..3c96bf0 --- /dev/null +++ b/tests/demo/test_window.py @@ -0,0 +1,346 @@ +"""Unit tests for SilhouetteWindow and select_person functions.""" + +from typing import Any +from unittest.mock import MagicMock + +import numpy as np +import pytest +import torch +from numpy.typing import NDArray + +from opengait.demo.window import SilhouetteWindow, select_person + + +class TestSilhouetteWindow: + """Tests for SilhouetteWindow class.""" + + def test_window_fill_and_ready_behavior(self) -> None: + """Window should be ready only when filled to window_size.""" + window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10) + sil = np.ones((64, 44), dtype=np.float32) + + # Not ready with fewer frames + for i in range(4): + window.push(sil, frame_idx=i, track_id=1) + assert not window.is_ready() + assert window.fill_level == (i + 1) / 5 + + # Ready at exactly window_size + window.push(sil, frame_idx=4, track_id=1) + assert window.is_ready() + assert window.fill_level == 1.0 + + def test_underfilled_not_ready(self) -> None: + """Underfilled window should never report ready.""" + window = SilhouetteWindow(window_size=10, stride=1, gap_threshold=5) + sil = np.ones((64, 44), dtype=np.float32) + + # Push 9 frames (underfilled) + for i in range(9): + window.push(sil, frame_idx=i, track_id=1) + + assert not window.is_ready() + assert window.fill_level == 0.9 + + # get_tensor should raise when not ready + with pytest.raises(ValueError, match="Window not ready"): + window.get_tensor() + + def test_track_id_change_resets_buffer(self) -> None: + """Changing track ID should reset the buffer.""" + window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10) + sil = np.ones((64, 44), dtype=np.float32) + + # Fill window with track_id=1 + for i in range(5): + window.push(sil, frame_idx=i, track_id=1) + + assert window.is_ready() + assert window.current_track_id == 1 + assert window.fill_level == 1.0 + + # Push with different track_id should reset + window.push(sil, frame_idx=5, track_id=2) + assert not window.is_ready() + assert window.fill_level == 0.2 + assert window.current_track_id == 2 + + def test_frame_gap_reset_behavior(self) -> None: + """Frame gap exceeding threshold should reset buffer.""" + window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=3) + sil = np.ones((64, 44), dtype=np.float32) + + # Fill window with consecutive frames + for i in range(5): + window.push(sil, frame_idx=i, track_id=1) + + assert window.is_ready() + assert window.fill_level == 1.0 + + # Small gap (within threshold) - no reset + window.push(sil, frame_idx=6, track_id=1) # gap = 1 + assert window.is_ready() + assert window.fill_level == 1.0 # deque maintains max + + # Reset and fill again + window.reset() + for i in range(5): + window.push(sil, frame_idx=i, track_id=1) + + # Large gap (exceeds threshold) - should reset + window.push(sil, frame_idx=10, track_id=1) # gap = 5 > 3 + assert not window.is_ready() + assert window.fill_level == 0.2 + + def test_get_tensor_shape(self) -> None: + """get_tensor should return tensor of shape [1, 1, window_size, 64, 44].""" + window_size = 7 + window = SilhouetteWindow(window_size=window_size, stride=1, gap_threshold=10) + + # Push unique frames to verify ordering + for i in range(window_size): + sil = np.full((64, 44), fill_value=float(i), dtype=np.float32) + window.push(sil, frame_idx=i, track_id=1) + + assert window.is_ready() + + tensor = window.get_tensor(device="cpu") + + assert isinstance(tensor, torch.Tensor) + assert tensor.shape == (1, 1, window_size, 64, 44) + assert tensor.dtype == torch.float32 + + # Verify content ordering: first frame should have value 0.0 + assert tensor[0, 0, 0, 0, 0].item() == 0.0 + # Last frame should have value window_size-1 + assert tensor[0, 0, window_size - 1, 0, 0].item() == window_size - 1 + + def test_should_classify_stride_behavior(self) -> None: + """should_classify should respect stride setting.""" + window = SilhouetteWindow(window_size=5, stride=3, gap_threshold=10) + sil = np.ones((64, 44), dtype=np.float32) + + # Fill window + for i in range(5): + window.push(sil, frame_idx=i, track_id=1) + + assert window.is_ready() + + # First classification should always trigger + assert window.should_classify() + + # Mark as classified at frame 4 + window.mark_classified() + + # Not ready to classify yet (stride=3, only 0 frames passed) + window.push(sil, frame_idx=5, track_id=1) + assert not window.should_classify() + + # Still not ready (only 1 frame passed) + window.push(sil, frame_idx=6, track_id=1) + assert not window.should_classify() + + # Now ready (3 frames passed since last classification) + window.push(sil, frame_idx=7, track_id=1) + assert window.should_classify() + + def test_should_classify_not_ready(self) -> None: + """should_classify should return False when window not ready.""" + window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10) + sil = np.ones((64, 44), dtype=np.float32) + + # Push only 3 frames (not ready) + for i in range(3): + window.push(sil, frame_idx=i, track_id=1) + + assert not window.is_ready() + assert not window.should_classify() + + def test_reset_clears_all_state(self) -> None: + """reset should clear all internal state.""" + window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10) + sil = np.ones((64, 44), dtype=np.float32) + + # Fill and classify + for i in range(5): + window.push(sil, frame_idx=i, track_id=1) + window.mark_classified() + + assert window.is_ready() + assert window.current_track_id == 1 + assert window.frame_count == 5 + + # Reset + window.reset() + + assert not window.is_ready() + assert window.fill_level == 0.0 + assert window.current_track_id is None + assert window.frame_count == 0 + assert not window.should_classify() + + def test_push_invalid_shape_raises(self) -> None: + """push should raise ValueError for invalid silhouette shape.""" + window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10) + + # Wrong shape + sil_wrong = np.ones((32, 32), dtype=np.float32) + + with pytest.raises(ValueError, match="Expected silhouette shape"): + window.push(sil_wrong, frame_idx=0, track_id=1) + + def test_push_wrong_dtype_converts(self) -> None: + """push should convert dtype to float32.""" + window = SilhouetteWindow(window_size=1, stride=1, gap_threshold=10) + + # uint8 input + sil_uint8 = np.ones((64, 44), dtype=np.uint8) * 255 + window.push(sil_uint8, frame_idx=0, track_id=1) + + tensor = window.get_tensor() + assert tensor.dtype == torch.float32 + + +class TestSelectPerson: + """Tests for select_person function.""" + + def _create_mock_results( + self, + boxes_xyxy: NDArray[np.float32], + masks_data: NDArray[np.float32], + track_ids: NDArray[np.int64] | None, + ) -> Any: + """Create a mock detection results object.""" + mock_boxes = MagicMock() + mock_boxes.xyxy = boxes_xyxy + mock_boxes.id = track_ids + + mock_masks = MagicMock() + mock_masks.data = masks_data + + mock_results = MagicMock() + mock_results.boxes = mock_boxes + mock_results.masks = mock_masks + + return mock_results + + def test_select_person_single_detection(self) -> None: + """Single detection should return that person's data.""" + boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32) # area = 3200 + masks = np.random.rand(1, 100, 100).astype(np.float32) + track_ids = np.array([42], dtype=np.int64) + + results = self._create_mock_results(boxes, masks, track_ids) + result = select_person(results) + + assert result is not None + mask, bbox, tid = result + assert mask.shape == (100, 100) + assert bbox == (10, 10, 50, 90) + assert tid == 42 + + def test_select_person_multi_detection_selects_largest(self) -> None: + """Multiple detections should select the one with largest bbox area.""" + # Two boxes: second one is larger + boxes = np.array( + [ + [0.0, 0.0, 10.0, 10.0], # area = 100 + [0.0, 0.0, 30.0, 30.0], # area = 900 (largest) + [0.0, 0.0, 20.0, 20.0], # area = 400 + ], + dtype=np.float32, + ) + masks = np.random.rand(3, 100, 100).astype(np.float32) + track_ids = np.array([1, 2, 3], dtype=np.int64) + + results = self._create_mock_results(boxes, masks, track_ids) + result = select_person(results) + + assert result is not None + mask, bbox, tid = result + assert bbox == (0, 0, 30, 30) # Largest box + assert tid == 2 # Corresponding track ID + + def test_select_person_no_detections_returns_none(self) -> None: + """No detections should return None.""" + boxes = np.array([], dtype=np.float32).reshape(0, 4) + masks = np.array([], dtype=np.float32).reshape(0, 100, 100) + track_ids = np.array([], dtype=np.int64) + + results = self._create_mock_results(boxes, masks, track_ids) + result = select_person(results) + + assert result is None + + def test_select_person_no_track_ids_returns_none(self) -> None: + """Detections without track IDs should return None.""" + boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32) + masks = np.random.rand(1, 100, 100).astype(np.float32) + + results = self._create_mock_results(boxes, masks, track_ids=None) + result = select_person(results) + + assert result is None + + def test_select_person_empty_track_ids_returns_none(self) -> None: + """Empty track IDs array should return None.""" + boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32) + masks = np.random.rand(1, 100, 100).astype(np.float32) + track_ids = np.array([], dtype=np.int64) + + results = self._create_mock_results(boxes, masks, track_ids) + result = select_person(results) + + assert result is None + + def test_select_person_missing_boxes_returns_none(self) -> None: + """Missing boxes attribute should return None.""" + mock_results = MagicMock() + mock_results.boxes = None + + result = select_person(mock_results) + assert result is None + + def test_select_person_missing_masks_returns_none(self) -> None: + """Missing masks attribute should return None.""" + boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32) + track_ids = np.array([1], dtype=np.int64) + + mock_boxes = MagicMock() + mock_boxes.xyxy = boxes + mock_boxes.id = track_ids + + mock_results = MagicMock() + mock_results.boxes = mock_boxes + mock_results.masks = None + + result = select_person(mock_results) + assert result is None + + def test_select_person_1d_bbox_handling(self) -> None: + """1D bbox array should be reshaped to 2D.""" + boxes = np.array([10.0, 10.0, 50.0, 90.0], dtype=np.float32) # 1D + masks = np.random.rand(1, 100, 100).astype(np.float32) + track_ids = np.array([1], dtype=np.int64) + + results = self._create_mock_results(boxes, masks, track_ids) + result = select_person(results) + + assert result is not None + _, bbox, tid = result + assert bbox == (10, 10, 50, 90) + assert tid == 1 + + def test_select_person_2d_mask_handling(self) -> None: + """2D mask should be expanded to 3D.""" + boxes = np.array([[10.0, 10.0, 50.0, 90.0]], dtype=np.float32) + masks = np.random.rand(100, 100).astype(np.float32) # 2D + track_ids = np.array([1], dtype=np.int64) + + results = self._create_mock_results(boxes, masks, track_ids) + result = select_person(results) + + assert result is not None + mask, _, _ = result + # Should be 2D (extracted from expanded 3D) + assert mask.shape == (100, 100)