d4e2a59ad2
Make the OpenGait-studio demo drop unpaced frames before they grow the silhouette window. Separate source-frame gap tracking from paced-frame stride tracking so runtime scheduling matches the documented demo-window-and-stride behavior. Add regressions for paced window growth and schedule-frame stride semantics.
1038 lines
33 KiB
Python
from __future__ import annotations
|
|
|
|
import importlib.util
|
|
import json
|
|
import pickle
|
|
from pathlib import Path
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
from typing import Final, Literal, cast
|
|
from unittest import mock
|
|
|
|
import numpy as np
|
|
from numpy.typing import NDArray
|
|
import pytest
|
|
import torch
|
|
|
|
from opengait_studio.sconet_demo import ScoNetDemo
|
|
|
|
# Repository root, resolved from this test file's location (two levels up).
REPO_ROOT: Final[Path] = Path(__file__).resolve().parents[2]
# Short sample clip driven through the full CLI by the integration tests.
SAMPLE_VIDEO_PATH: Final[Path] = REPO_ROOT / "assets" / "sample.mp4"
# Pretrained ScoNet checkpoint path.
# NOTE(review): appears unused in this module; tests build a compatible
# checkpoint via the `compatible_checkpoint_path` fixture instead — confirm
# before removing.
CHECKPOINT_PATH: Final[Path] = REPO_ROOT / "ckpt" / "ScoNet-20000.pt"
# Model configuration for the Scoliosis1K ScoNet variant.
CONFIG_PATH: Final[Path] = REPO_ROOT / "configs" / "sconet" / "sconet_scoliosis1k.yaml"
# YOLO segmentation weights used for person detection/tracking.
YOLO_MODEL_PATH: Final[Path] = REPO_ROOT / "ckpt" / "yolo11n-seg.pt"
|
|
|
|
|
|
def _device_for_runtime() -> str:
|
|
return "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
def _run_pipeline_cli(
    *args: str, timeout_seconds: int = 120
) -> subprocess.CompletedProcess[str]:
    """Invoke the opengait_studio CLI in a subprocess and capture its output.

    Runs from the repository root with the current interpreter, never raises
    on a non-zero exit code (callers assert on ``returncode``), and enforces
    ``timeout_seconds`` so a hung pipeline fails the test quickly.
    """
    return subprocess.run(
        [sys.executable, "-m", "opengait_studio", *args],
        cwd=REPO_ROOT,
        capture_output=True,
        text=True,
        check=False,
        timeout=timeout_seconds,
    )
|
|
|
|
|
|
def _require_integration_assets() -> None:
    """Skip the current test when any required on-disk asset is missing."""
    required: list[tuple[Path, str]] = [
        (SAMPLE_VIDEO_PATH, "Missing sample video"),
        (CONFIG_PATH, "Missing config"),
        (YOLO_MODEL_PATH, "Missing YOLO model file"),
    ]
    for asset_path, reason in required:
        if not asset_path.is_file():
            pytest.skip(f"{reason}: {asset_path}")
|
|
|
|
|
|
@pytest.fixture
def compatible_checkpoint_path(tmp_path: Path) -> Path:
    """Build a checkpoint whose weights are structurally compatible.

    Instantiates an untrained ScoNetDemo from the repo config and saves its
    state dict, so CLI runs can load a loadable checkpoint without shipping
    real trained weights.
    """
    if not CONFIG_PATH.is_file():
        pytest.skip(f"Missing config: {CONFIG_PATH}")

    demo = ScoNetDemo(cfg_path=str(CONFIG_PATH), checkpoint_path=None, device="cpu")
    target = tmp_path / "sconet-compatible.pt"
    torch.save(demo.state_dict(), target)
    return target
|
|
|
|
|
|
def _extract_prediction_json_lines(stdout: str) -> list[dict[str, object]]:
|
|
required_keys = {
|
|
"frame",
|
|
"track_id",
|
|
"label",
|
|
"confidence",
|
|
"window",
|
|
"timestamp_ns",
|
|
}
|
|
predictions: list[dict[str, object]] = []
|
|
|
|
for line in stdout.splitlines():
|
|
stripped = line.strip()
|
|
if not stripped:
|
|
continue
|
|
try:
|
|
payload_obj = cast(object, json.loads(stripped))
|
|
except json.JSONDecodeError:
|
|
continue
|
|
|
|
if not isinstance(payload_obj, dict):
|
|
continue
|
|
payload = cast(dict[str, object], payload_obj)
|
|
if required_keys.issubset(payload.keys()):
|
|
predictions.append(payload)
|
|
|
|
return predictions
|
|
|
|
|
|
def _assert_prediction_schema(prediction: dict[str, object]) -> None:
|
|
assert isinstance(prediction["frame"], int)
|
|
assert isinstance(prediction["track_id"], int)
|
|
|
|
label = prediction["label"]
|
|
assert isinstance(label, str)
|
|
assert label in {"negative", "neutral", "positive"}
|
|
|
|
confidence = prediction["confidence"]
|
|
assert isinstance(confidence, (int, float))
|
|
confidence_value = float(confidence)
|
|
assert 0.0 <= confidence_value <= 1.0
|
|
|
|
window_obj = prediction["window"]
|
|
assert isinstance(window_obj, int)
|
|
assert window_obj >= 0
|
|
|
|
assert isinstance(prediction["timestamp_ns"], int)
|
|
|
|
|
|
def test_pipeline_cli_fps_benchmark_smoke(
    compatible_checkpoint_path: Path,
) -> None:
    """Smoke-check end-to-end throughput of the CLI on the sample video.

    Runs the pipeline over up to 90 frames with a small window/stride, then
    asserts the observed FPS clears a deliberately conservative CI floor.
    Skips instead of failing when the environment yields too few samples.
    """
    _require_integration_assets()

    max_frames = 90
    started_at = time.perf_counter()
    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "5",
        "--stride",
        "1",
        "--max-frames",
        str(max_frames),
        timeout_seconds=180,
    )
    elapsed_seconds = time.perf_counter() - started_at

    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )
    predictions = _extract_prediction_json_lines(result.stdout)
    assert predictions, "Expected prediction output for FPS benchmark run"

    for prediction in predictions:
        _assert_prediction_schema(prediction)

    # Count distinct source frames that produced predictions; these are the
    # throughput units used in the FPS estimate below.
    observed_frames = {
        frame_obj
        for prediction in predictions
        for frame_obj in [prediction["frame"]]
        if isinstance(frame_obj, int)
    }
    observed_units = len(observed_frames)
    if observed_units < 5:
        pytest.skip(
            "Insufficient observed frame samples for stable FPS benchmark in this environment"
        )
    if elapsed_seconds <= 0:
        pytest.skip("Non-positive elapsed time; cannot compute FPS benchmark")

    fps = observed_units / elapsed_seconds
    # Very low bar on purpose: CI machines may be slow and CPU-only.
    min_expected_fps = 0.2
    assert fps >= min_expected_fps, (
        "Observed FPS below conservative CI threshold: "
        f"{fps:.3f} < {min_expected_fps:.3f} "
        f"(observed_units={observed_units}, elapsed_seconds={elapsed_seconds:.3f})"
    )
|
|
|
|
|
|
def test_pipeline_cli_happy_path_outputs_json_predictions(
    compatible_checkpoint_path: Path,
) -> None:
    """Happy path: the CLI emits schema-valid prediction JSON lines on stdout.

    Uses a non-overlapping window/stride (10/10) over 120 frames and checks
    that, without a NATS URL, the CLI does not report a NATS connection.
    """
    _require_integration_assets()

    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "10",
        "--stride",
        "10",
        "--max-frames",
        "120",
        timeout_seconds=180,
    )

    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )
    predictions = _extract_prediction_json_lines(result.stdout)
    assert predictions, (
        "Expected at least one prediction JSON line in stdout. "
        f"stdout:\n{result.stdout}\nstderr:\n{result.stderr}"
    )
    for prediction in predictions:
        _assert_prediction_schema(prediction)

    # No --nats-url was passed, so no NATS connection should be reported.
    assert "Connected to NATS" not in result.stderr
|
|
|
|
|
|
def test_pipeline_cli_max_frames_caps_output_frames(
    compatible_checkpoint_path: Path,
) -> None:
    """--max-frames must cap the source frames reflected in predictions.

    Every emitted prediction's frame index must be strictly below the
    requested cap (frame indices are treated as 0-based).
    """
    _require_integration_assets()

    max_frames = 20
    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "5",
        "--stride",
        "1",
        "--max-frames",
        str(max_frames),
        timeout_seconds=180,
    )

    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )
    predictions = _extract_prediction_json_lines(result.stdout)
    assert predictions, "Expected prediction output with --max-frames run"

    for prediction in predictions:
        _assert_prediction_schema(prediction)
        frame_idx_obj = prediction["frame"]
        assert isinstance(frame_idx_obj, int)
        # A cap of N means all emitted frame indices are < N.
        assert frame_idx_obj < max_frames
|
|
|
|
|
|
def test_pipeline_cli_invalid_source_path_returns_user_error() -> None:
    """A nonexistent --source must exit with code 2 and a clear error message."""
    result = _run_pipeline_cli(
        "--source",
        "/definitely/not/a/real/video.mp4",
        "--checkpoint",
        "/tmp/unused-checkpoint.pt",
        "--config",
        str(CONFIG_PATH),
        timeout_seconds=30,
    )

    # Exit code 2 is the CLI's usage/user-error convention.
    assert result.returncode == 2
    assert "Error: Video source not found" in result.stderr
|
|
|
|
|
|
def test_pipeline_cli_invalid_checkpoint_path_returns_user_error() -> None:
    """A nonexistent --checkpoint must exit with code 2 and a clear message.

    Only the assets this test actually needs are checked before running, so
    the test can fail on the checkpoint path specifically.
    """
    if not SAMPLE_VIDEO_PATH.is_file():
        pytest.skip(f"Missing sample video: {SAMPLE_VIDEO_PATH}")
    if not CONFIG_PATH.is_file():
        pytest.skip(f"Missing config: {CONFIG_PATH}")

    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(REPO_ROOT / "ckpt" / "missing-checkpoint.pt"),
        "--config",
        str(CONFIG_PATH),
        timeout_seconds=30,
    )

    # Exit code 2 is the CLI's usage/user-error convention.
    assert result.returncode == 2
    assert "Error: Checkpoint not found" in result.stderr
|
|
|
|
|
|
def test_pipeline_cli_preprocess_only_requires_export_path(
    compatible_checkpoint_path: Path,
) -> None:
    """Test that --preprocess-only requires --silhouette-export-path."""
    _require_integration_assets()

    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--preprocess-only",
        "--max-frames",
        "10",
        timeout_seconds=30,
    )

    # The missing companion flag is a user error (exit code 2), reported
    # before any processing starts — hence the short timeout above.
    assert result.returncode == 2
    assert "--silhouette-export-path is required" in result.stderr
|
|
|
|
|
|
def test_pipeline_cli_preprocess_only_exports_pickle(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test preprocess-only mode exports silhouettes to pickle."""
    _require_integration_assets()

    export_path = tmp_path / "silhouettes.pkl"

    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--preprocess-only",
        "--silhouette-export-path",
        str(export_path),
        "--silhouette-export-format",
        "pickle",
        "--max-frames",
        "30",
        timeout_seconds=180,
    )

    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )

    # Verify export file exists and contains silhouettes
    assert export_path.is_file(), f"Export file not found: {export_path}"

    # Safe to unpickle: the file was produced by our own CLI run above.
    with open(export_path, "rb") as f:
        silhouettes = pickle.load(f)

    assert isinstance(silhouettes, list)
    assert len(silhouettes) > 0, "Expected at least one silhouette"

    # Verify silhouette schema
    for item in silhouettes:
        assert isinstance(item, dict)
        assert "frame" in item
        assert "track_id" in item
        assert "timestamp_ns" in item
        assert "silhouette" in item
        assert isinstance(item["frame"], int)
        assert isinstance(item["track_id"], int)
        assert isinstance(item["timestamp_ns"], int)
|
|
|
|
|
|
def test_pipeline_cli_result_export_json(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test that results can be exported to JSON file."""
    _require_integration_assets()

    export_path = tmp_path / "results.jsonl"

    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "10",
        "--stride",
        "10",
        "--result-export-path",
        str(export_path),
        "--result-export-format",
        "json",
        "--max-frames",
        "60",
        timeout_seconds=180,
    )

    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )

    # Verify export file exists
    assert export_path.is_file(), f"Export file not found: {export_path}"

    # Read and verify JSON lines (one prediction object per non-empty line)
    predictions: list[dict[str, object]] = []
    with open(export_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                predictions.append(cast(dict[str, object], json.loads(line)))

    assert len(predictions) > 0, "Expected at least one prediction in export"

    for prediction in predictions:
        _assert_prediction_schema(prediction)
|
|
|
|
|
|
def test_pipeline_cli_result_export_pickle(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test that results can be exported to pickle file."""
    _require_integration_assets()

    export_path = tmp_path / "results.pkl"

    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "10",
        "--stride",
        "10",
        "--result-export-path",
        str(export_path),
        "--result-export-format",
        "pickle",
        "--max-frames",
        "60",
        timeout_seconds=180,
    )

    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )

    # Verify export file exists
    assert export_path.is_file(), f"Export file not found: {export_path}"

    # Read and verify pickle (safe: produced by our own CLI run above)
    with open(export_path, "rb") as f:
        predictions = pickle.load(f)

    assert isinstance(predictions, list)
    assert len(predictions) > 0, "Expected at least one prediction in export"

    for prediction in predictions:
        _assert_prediction_schema(prediction)
|
|
|
|
|
|
def test_pipeline_cli_silhouette_and_result_export(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test exporting both silhouettes and results simultaneously."""
    _require_integration_assets()

    silhouette_export = tmp_path / "silhouettes.pkl"
    result_export = tmp_path / "results.jsonl"

    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "10",
        "--stride",
        "10",
        "--silhouette-export-path",
        str(silhouette_export),
        "--silhouette-export-format",
        "pickle",
        "--result-export-path",
        str(result_export),
        "--result-export-format",
        "json",
        "--max-frames",
        "60",
        timeout_seconds=180,
    )

    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )

    # Verify both export files exist
    assert silhouette_export.is_file(), (
        f"Silhouette export not found: {silhouette_export}"
    )
    assert result_export.is_file(), f"Result export not found: {result_export}"

    # Verify silhouette export (safe: produced by our own CLI run above)
    with open(silhouette_export, "rb") as f:
        silhouettes = pickle.load(f)
    assert isinstance(silhouettes, list)
    assert len(silhouettes) > 0

    # Verify result export (JSON lines, one prediction per non-empty line)
    with open(result_export, "r", encoding="utf-8") as f:
        predictions = [
            cast(dict[str, object], json.loads(line)) for line in f if line.strip()
        ]
    assert len(predictions) > 0
|
|
|
|
|
|
def test_pipeline_cli_parquet_export_requires_pyarrow(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test that parquet export fails gracefully when pyarrow is not available.

    The run must exit with code 1 (runtime error, not usage error) and its
    stderr must mention the parquet/pyarrow problem.
    """
    _require_integration_assets()

    # Skip if pyarrow is actually installed. find_spec detects the package
    # without importing it; the previous additional `import pyarrow` probe was
    # redundant (same skip) and needlessly executed the package's import side
    # effects.
    if importlib.util.find_spec("pyarrow") is not None:
        pytest.skip("pyarrow is installed, skipping missing dependency test")

    export_path = tmp_path / "results.parquet"

    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "10",
        "--stride",
        "10",
        "--result-export-path",
        str(export_path),
        "--result-export-format",
        "parquet",
        "--max-frames",
        "30",
        timeout_seconds=180,
    )

    # Should fail with RuntimeError about pyarrow
    assert result.returncode == 1
    assert "parquet" in result.stderr.lower() or "pyarrow" in result.stderr.lower()
|
|
|
|
|
|
def test_pipeline_cli_silhouette_visualization(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test that silhouette visualization creates PNG files."""
    _require_integration_assets()

    visualize_dir = tmp_path / "silhouette_viz"

    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "10",
        "--stride",
        "10",
        "--silhouette-visualize-dir",
        str(visualize_dir),
        "--max-frames",
        "30",
        timeout_seconds=180,
    )

    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )

    # Verify visualization directory exists and contains PNG files
    assert visualize_dir.is_dir(), f"Visualization directory not found: {visualize_dir}"

    png_files = list(visualize_dir.glob("*.png"))
    assert len(png_files) > 0, "Expected at least one PNG visualization file"

    # Verify filenames contain frame and track info
    # (expected pattern: silhouette_frame<N>_track<M>.png)
    for png_file in png_files:
        assert "silhouette_frame" in png_file.name
        assert "_track" in png_file.name
|
|
|
|
|
|
def test_pipeline_cli_preprocess_only_with_visualization(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test preprocess-only mode with both export and visualization."""
    _require_integration_assets()

    export_path = tmp_path / "silhouettes.pkl"
    visualize_dir = tmp_path / "silhouette_viz"

    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--preprocess-only",
        "--silhouette-export-path",
        str(export_path),
        "--silhouette-visualize-dir",
        str(visualize_dir),
        "--max-frames",
        "30",
        timeout_seconds=180,
    )

    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )

    # Verify export file exists
    assert export_path.is_file(), f"Export file not found: {export_path}"

    # Verify visualization files exist
    assert visualize_dir.is_dir(), f"Visualization directory not found: {visualize_dir}"
    png_files = list(visualize_dir.glob("*.png"))
    assert len(png_files) > 0, "Expected at least one PNG visualization file"

    # Load and verify pickle export (safe: produced by our own CLI run above)
    with open(export_path, "rb") as f:
        silhouettes = pickle.load(f)
    assert isinstance(silhouettes, list)
    assert len(silhouettes) > 0
    # Number of exported silhouettes should match number of PNG files
    assert len(silhouettes) == len(png_files), (
        f"Mismatch: {len(silhouettes)} silhouettes exported but {len(png_files)} PNG files created"
    )
|
|
|
|
|
|
class MockVisualizer:
    """Test double for the pipeline visualizer that records every update call."""

    def __init__(self) -> None:
        # One dict per update() call, keyed by the update() parameter names.
        self.update_calls: list[dict[str, object]] = []
        # Value update() reports back; tests flip it to simulate a closed window.
        self.return_value: bool = True

    def update(
        self,
        frame: NDArray[np.uint8],
        bbox: tuple[int, int, int, int] | None,
        bbox_mask: tuple[int, int, int, int] | None,
        track_id: int,
        mask_raw: NDArray[np.uint8] | None,
        silhouette: NDArray[np.float32] | None,
        segmentation_input: NDArray[np.float32] | None,
        label: str | None,
        confidence: float | None,
        fps: float,
    ) -> bool:
        """Record this call's arguments and return the configured result."""
        recorded: dict[str, object] = {
            "frame": frame,
            "bbox": bbox,
            "bbox_mask": bbox_mask,
            "track_id": track_id,
            "mask_raw": mask_raw,
            "silhouette": silhouette,
            "segmentation_input": segmentation_input,
            "label": label,
            "confidence": confidence,
            "fps": fps,
        }
        self.update_calls.append(recorded)
        return self.return_value

    def close(self) -> None:
        """No-op; the real visualizer releases its window resources here."""
|
|
|
|
|
|
def test_pipeline_visualizer_updates_on_no_detection() -> None:
    """Test that visualizer is still updated even when process_frame returns None.

    This is a regression test for the window freeze issue when no person is detected.
    The window should refresh every frame to prevent freezing.
    """
    from opengait_studio.pipeline import ScoliosisPipeline

    # Create a minimal pipeline with mocked dependencies
    with (
        mock.patch("opengait_studio.pipeline.YOLO") as mock_yolo,
        mock.patch("opengait_studio.pipeline.create_source") as mock_source,
        mock.patch("opengait_studio.pipeline.create_publisher") as mock_publisher,
        mock.patch("opengait_studio.pipeline.ScoNetDemo") as mock_classifier,
    ):
        # Setup mock detector that returns no detections (causing process_frame to return None)
        mock_detector = mock.MagicMock()
        mock_detector.track.return_value = []  # No detections
        mock_yolo.return_value = mock_detector

        # Setup mock source with 3 frames
        mock_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        mock_source.return_value = [(mock_frame, {"frame_count": i}) for i in range(3)]

        # Setup mock publisher and classifier
        mock_publisher.return_value = mock.MagicMock()
        mock_classifier.return_value = mock.MagicMock()

        # Create pipeline with visualize enabled
        pipeline = ScoliosisPipeline(
            source="dummy.mp4",
            checkpoint="dummy.pt",
            config=str(CONFIG_PATH) if CONFIG_PATH.exists() else "dummy.yaml",
            device="cpu",
            yolo_model="dummy.pt",
            window=30,
            stride=30,
            nats_url=None,
            nats_subject="test",
            max_frames=None,
            visualize=True,
        )

        # Swap in the recording double for the real visualizer.
        mock_viz = MockVisualizer()
        setattr(pipeline, "_visualizer", mock_viz)

        # Run pipeline
        _ = pipeline.run()

        # Verify visualizer was updated for all 3 frames even with no detections
        assert len(mock_viz.update_calls) == 3, (
            f"Expected visualizer.update() to be called 3 times (once per frame), "
            f"but was called {len(mock_viz.update_calls)} times. "
            f"Window would freeze if not updated on no-detection frames."
        )

        # Verify each call had the frame data
        for call in mock_viz.update_calls:
            assert call["track_id"] == 0  # Default track_id when no detection
            assert call["bbox"] is None  # No bbox when no detection
            assert call["bbox_mask"] is None
            assert call["mask_raw"] is None  # No mask when no detection
            assert call["silhouette"] is None  # No silhouette when no detection
            assert call["segmentation_input"] is None
            assert call["label"] is None  # No label when no detection
            assert call["confidence"] is None  # No confidence when no detection
|
|
|
|
|
|
def test_pipeline_visualizer_clears_bbox_on_no_detection() -> None:
    """On no-detection frames, cached panels persist but bbox/label clear.

    Frames 0-1 carry a valid detection; frames 2-3 carry none. The visualizer
    must keep showing the last mask and segmentation input (as a copy, not
    the same object), while bbox, bbox_mask, label, and confidence are
    cleared to None.
    """
    from opengait_studio.pipeline import ScoliosisPipeline

    # Create a minimal pipeline with mocked dependencies
    with (
        mock.patch("opengait_studio.pipeline.YOLO") as mock_yolo,
        mock.patch("opengait_studio.pipeline.create_source") as mock_source,
        mock.patch("opengait_studio.pipeline.create_publisher") as mock_publisher,
        mock.patch("opengait_studio.pipeline.ScoNetDemo") as mock_classifier,
        mock.patch("opengait_studio.pipeline.select_person") as mock_select_person,
        mock.patch("opengait_studio.pipeline.mask_to_silhouette") as mock_mask_to_sil,
    ):
        # Create mock detection result for frames 0-1 (valid detection)
        mock_box = mock.MagicMock()
        mock_box.xyxy = np.array([[100, 100, 200, 300]], dtype=np.float32)
        mock_box.id = np.array([1], dtype=np.int64)
        mock_mask = mock.MagicMock()
        mock_mask.data = np.random.rand(1, 480, 640).astype(np.float32)
        mock_result = mock.MagicMock()
        mock_result.boxes = mock_box
        mock_result.masks = mock_mask

        # Setup mock detector: detection for frames 0-1, then no detection for frames 2-3
        mock_detector = mock.MagicMock()
        mock_detector.track.side_effect = [
            [mock_result],  # Frame 0: valid detection
            [mock_result],  # Frame 1: valid detection
            [],  # Frame 2: no detection
            [],  # Frame 3: no detection
        ]
        mock_yolo.return_value = mock_detector

        # Setup mock source with 4 frames
        mock_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        mock_source.return_value = [(mock_frame, {"frame_count": i}) for i in range(4)]

        # Setup mock publisher and classifier
        mock_publisher.return_value = mock.MagicMock()
        mock_classifier.return_value = mock.MagicMock()

        # Setup mock select_person to return valid mask and bbox
        dummy_mask = np.random.randint(0, 256, (480, 640), dtype=np.uint8)
        dummy_bbox_mask = (100, 100, 200, 300)
        dummy_bbox_frame = (100, 100, 200, 300)
        mock_select_person.return_value = (
            dummy_mask,
            dummy_bbox_mask,
            dummy_bbox_frame,
            1,
        )

        # Setup mock mask_to_silhouette to return valid silhouette
        dummy_silhouette = np.random.rand(64, 44).astype(np.float32)
        mock_mask_to_sil.return_value = dummy_silhouette

        # Create pipeline with visualize enabled
        pipeline = ScoliosisPipeline(
            source="dummy.mp4",
            checkpoint="dummy.pt",
            config=str(CONFIG_PATH) if CONFIG_PATH.exists() else "dummy.yaml",
            device="cpu",
            yolo_model="dummy.pt",
            window=30,
            stride=30,
            nats_url=None,
            nats_subject="test",
            max_frames=None,
            visualize=True,
        )

        # Swap in the recording double for the real visualizer.
        mock_viz = MockVisualizer()
        setattr(pipeline, "_visualizer", mock_viz)

        # Run pipeline
        _ = pipeline.run()

        # Verify visualizer was updated for all 4 frames
        assert len(mock_viz.update_calls) == 4, (
            f"Expected visualizer.update() to be called 4 times, "
            f"but was called {len(mock_viz.update_calls)} times."
        )

        # Extract the mask_raw values from each call
        mask_raw_calls = [call["mask_raw"] for call in mock_viz.update_calls]

        # Frames 0 and 1 should have valid masks (not None)
        assert mask_raw_calls[0] is not None, "Frame 0 should have valid mask"
        assert mask_raw_calls[1] is not None, "Frame 1 should have valid mask"

        # Frames 2 and 3 must keep showing the cached mask instead of blanking.
        assert mask_raw_calls[2] is not None, (
            "Frame 2 (no detection) should display cached mask from last valid detection, "
            "not None/blank"
        )
        assert mask_raw_calls[3] is not None, (
            "Frame 3 (no detection) should display cached mask from last valid detection, "
            "not None/blank"
        )

        # The segmentation-input panel is also cached across no-detection frames.
        segmentation_inputs = [
            call["segmentation_input"] for call in mock_viz.update_calls
        ]
        bbox_mask_calls = [call["bbox_mask"] for call in mock_viz.update_calls]
        assert segmentation_inputs[0] is not None
        assert segmentation_inputs[1] is not None
        assert segmentation_inputs[2] is not None
        assert segmentation_inputs[3] is not None
        # Bounding boxes, by contrast, are cleared on no-detection frames.
        bbox_calls = [call["bbox"] for call in mock_viz.update_calls]
        assert bbox_calls[0] == dummy_bbox_frame
        assert bbox_calls[1] == dummy_bbox_frame
        assert bbox_calls[2] is None
        assert bbox_calls[3] is None
        assert bbox_mask_calls[0] == dummy_bbox_mask
        assert bbox_mask_calls[1] == dummy_bbox_mask
        assert bbox_mask_calls[2] is None
        assert bbox_mask_calls[3] is None
        # Classification output is likewise cleared when nobody is detected.
        label_calls = [call["label"] for call in mock_viz.update_calls]
        confidence_calls = [call["confidence"] for call in mock_viz.update_calls]
        assert label_calls[2] is None
        assert label_calls[3] is None
        assert confidence_calls[2] is None
        assert confidence_calls[3] is None

        # The cached mask must be a defensive copy, not the same ndarray object.
        if mask_raw_calls[1] is not None and mask_raw_calls[2] is not None:
            assert mask_raw_calls[1] is not mask_raw_calls[2], (
                "Cached mask should be a copy, not the same object reference"
            )
|
|
|
|
|
|
def test_frame_pacer_emission_count_24_to_15() -> None:
    """Pacing 100 frames of a 24 fps stream to 15 fps emits roughly 62 frames.

    100 frames at 24 fps span ~4.17 s; at a 15 fps target that is ~62.5
    emissions, so the allowed band is [60, 65].
    """
    from opengait_studio.pipeline import _FramePacer

    pacer = _FramePacer(15.0)
    interval_ns = int(1_000_000_000 / 24)
    # should_emit returns a bool per frame; summing counts the emissions.
    emitted = sum(pacer.should_emit(i * interval_ns) for i in range(100))
    assert 60 <= emitted <= 65
|
|
|
|
|
|
def test_pipeline_pacing_skips_window_growth_until_emitted() -> None:
    """Regression: frames dropped by the pacer must not grow the window.

    With a ~30 fps source and target_fps=15, the second frame arrives ~33 ms
    after the first (< 1/15 s) and should be dropped by the pacer: its window
    stays at 1 silhouette and no label is produced. The third frame is paced
    in, fills the 2-frame window, and triggers a prediction.
    """
    from opengait_studio.pipeline import ScoliosisPipeline

    with (
        mock.patch("opengait_studio.pipeline.YOLO") as mock_yolo,
        mock.patch("opengait_studio.pipeline.create_source") as mock_source,
        mock.patch("opengait_studio.pipeline.create_publisher") as mock_publisher,
        mock.patch("opengait_studio.pipeline.ScoNetDemo") as mock_classifier,
        mock.patch("opengait_studio.pipeline.select_person") as mock_select_person,
        mock.patch("opengait_studio.pipeline.mask_to_silhouette") as mock_mask_to_sil,
    ):
        # Mock detector that always reports one tracked person.
        mock_detector = mock.MagicMock()
        mock_box = mock.MagicMock()
        mock_box.xyxy = np.array([[100, 100, 200, 300]], dtype=np.float32)
        mock_box.id = np.array([1], dtype=np.int64)
        mock_mask = mock.MagicMock()
        mock_mask.data = np.random.rand(1, 480, 640).astype(np.float32)
        mock_result = mock.MagicMock()
        mock_result.boxes = mock_box
        mock_result.masks = mock_mask
        mock_detector.track.return_value = [mock_result]
        mock_yolo.return_value = mock_detector
        # Empty source: frames are fed manually via process_frame below.
        mock_source.return_value = []
        mock_publisher.return_value = mock.MagicMock()

        # Classifier always predicts ("neutral", 0.7) once the window fills.
        mock_model = mock.MagicMock()
        mock_model.predict.return_value = ("neutral", 0.7)
        mock_classifier.return_value = mock_model

        dummy_mask = np.random.randint(0, 256, (480, 640), dtype=np.uint8)
        dummy_bbox_mask = (100, 100, 200, 300)
        dummy_bbox_frame = (100, 100, 200, 300)
        dummy_silhouette = np.random.rand(64, 44).astype(np.float32)
        mock_select_person.return_value = (
            dummy_mask,
            dummy_bbox_mask,
            dummy_bbox_frame,
            1,
        )
        mock_mask_to_sil.return_value = dummy_silhouette

        pipeline = ScoliosisPipeline(
            source="dummy.mp4",
            checkpoint="dummy.pt",
            config=str(CONFIG_PATH) if CONFIG_PATH.exists() else "dummy.yaml",
            device="cpu",
            yolo_model="dummy.pt",
            window=2,
            stride=1,
            nats_url=None,
            nats_subject="test",
            max_frames=None,
            target_fps=15.0,
        )
        frame = np.zeros((480, 640, 3), dtype=np.uint8)

        # Three frames ~33 ms apart (~30 fps source timestamps).
        first = pipeline.process_frame(
            frame,
            {"frame_count": 0, "timestamp_ns": 1_000_000_000},
        )
        second = pipeline.process_frame(
            frame,
            {"frame_count": 1, "timestamp_ns": 1_033_000_000},
        )
        third = pipeline.process_frame(
            frame,
            {"frame_count": 2, "timestamp_ns": 1_067_000_000},
        )

        assert first is not None
        assert second is not None
        assert third is not None
        assert first["segmentation_input"] is not None
        assert second["segmentation_input"] is not None
        assert third["segmentation_input"] is not None
        # Frame 0 is emitted: window holds one silhouette.
        assert first["segmentation_input"].shape[0] == 1
        # Frame 1 is dropped by the pacer: window must NOT grow, no label.
        assert second["segmentation_input"].shape[0] == 1
        assert second["label"] is None
        # Frame 2 is paced in: window reaches 2 and triggers a prediction.
        assert third["segmentation_input"].shape[0] == 2
        assert third["label"] == "neutral"
|
|
|
|
|
|
def test_frame_pacer_requires_positive_target_fps() -> None:
    """_FramePacer must reject a non-positive target FPS at construction."""
    from opengait_studio.pipeline import _FramePacer

    with pytest.raises(ValueError, match="target_fps must be positive"):
        _FramePacer(0.0)
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("window", "stride", "mode", "expected"),
    [
        (30, 30, "manual", 30),
        (30, 7, "manual", 7),
        (30, 30, "sliding", 1),
        (30, 1, "chunked", 30),
        (15, 3, "chunked", 15),
    ],
)
def test_resolve_stride_modes(
    window: int,
    stride: int,
    mode: Literal["manual", "sliding", "chunked"],
    expected: int,
) -> None:
    """resolve_stride semantics per the table above: 'manual' keeps the
    requested stride, 'sliding' forces a stride of 1, and 'chunked' forces
    stride == window (non-overlapping schedule windows)."""
    from opengait_studio.pipeline import resolve_stride

    assert resolve_stride(window, stride, mode) == expected
|