# Commit d6fd6c03e6
# Introduce focused unit, integration, and NATS-path tests for demo modules,
# and align assertions with final schema and temporal contracts (window int,
# seq=30, fill-level ratio). This commit isolates validation logic from
# runtime changes and provides reproducible QA for pipeline behavior and
# failure modes.
# 280 lines, 7.9 KiB, Python
from __future__ import annotations
|
|
|
|
import json
|
|
from pathlib import Path
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
from typing import Final, cast
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from opengait.demo.sconet_demo import ScoNetDemo
|
|
|
|
# Repository root, resolved two directories above this test file.
REPO_ROOT: Final[Path] = Path(__file__).resolve().parents[2]
# Short demo clip used as the video source for integration runs.
SAMPLE_VIDEO_PATH: Final[Path] = REPO_ROOT / "assets" / "sample.mp4"
# Pre-trained ScoNet weights. NOTE(review): unused by the tests in this
# module, which build a compatible checkpoint via a fixture instead.
CHECKPOINT_PATH: Final[Path] = REPO_ROOT / "ckpt" / "ScoNet-20000.pt"
# YAML configuration for the ScoNet demo model/runtime.
CONFIG_PATH: Final[Path] = REPO_ROOT / "configs" / "sconet" / "sconet_scoliosis1k.yaml"
# Segmentation model weights passed to the CLI via --yolo-model
# (filename suggests YOLO11-nano segmentation — confirm against the demo).
YOLO_MODEL_PATH: Final[Path] = REPO_ROOT / "yolo11n-seg.pt"
|
|
|
|
|
|
def _device_for_runtime() -> str:
|
|
return "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
def _run_pipeline_cli(
    *args: str, timeout_seconds: int = 120
) -> subprocess.CompletedProcess[str]:
    """Invoke the demo pipeline CLI as a subprocess and capture its output.

    Runs ``python -m opengait.demo`` from the repository root with *args*
    appended, capturing stdout/stderr as text. The exit code is deliberately
    NOT checked (``check=False``) so callers can assert on failure modes.
    """
    cli_invocation = [sys.executable, "-m", "opengait.demo", *args]
    return subprocess.run(
        cli_invocation,
        cwd=REPO_ROOT,
        text=True,
        capture_output=True,
        timeout=timeout_seconds,
        check=False,
    )
|
|
|
|
|
|
def _require_integration_assets() -> None:
    """Skip the current test when any on-disk integration asset is absent."""
    required_assets = (
        (SAMPLE_VIDEO_PATH, "sample video"),
        (CONFIG_PATH, "config"),
        (YOLO_MODEL_PATH, "YOLO model file"),
    )
    for asset_path, description in required_assets:
        if not asset_path.is_file():
            pytest.skip(f"Missing {description}: {asset_path}")
|
|
|
|
|
|
@pytest.fixture
def compatible_checkpoint_path(tmp_path: Path) -> Path:
    """Build a throwaway checkpoint structurally compatible with CONFIG_PATH.

    Instantiates an untrained ``ScoNetDemo`` on CPU and serialises its state
    dict under ``tmp_path`` so CLI runs can load a valid checkpoint without
    shipping real weights. Skips when the config file is missing.
    """
    if not CONFIG_PATH.is_file():
        pytest.skip(f"Missing config: {CONFIG_PATH}")

    demo = ScoNetDemo(cfg_path=str(CONFIG_PATH), checkpoint_path=None, device="cpu")
    destination = tmp_path / "sconet-compatible.pt"
    torch.save(demo.state_dict(), destination)
    return destination
|
|
|
|
|
|
def _extract_prediction_json_lines(stdout: str) -> list[dict[str, object]]:
|
|
required_keys = {
|
|
"frame",
|
|
"track_id",
|
|
"label",
|
|
"confidence",
|
|
"window",
|
|
"timestamp_ns",
|
|
}
|
|
predictions: list[dict[str, object]] = []
|
|
|
|
for line in stdout.splitlines():
|
|
stripped = line.strip()
|
|
if not stripped:
|
|
continue
|
|
try:
|
|
payload_obj = cast(object, json.loads(stripped))
|
|
except json.JSONDecodeError:
|
|
continue
|
|
|
|
if not isinstance(payload_obj, dict):
|
|
continue
|
|
payload = cast(dict[str, object], payload_obj)
|
|
if required_keys.issubset(payload.keys()):
|
|
predictions.append(payload)
|
|
|
|
return predictions
|
|
|
|
|
|
def _assert_prediction_schema(prediction: dict[str, object]) -> None:
|
|
assert isinstance(prediction["frame"], int)
|
|
assert isinstance(prediction["track_id"], int)
|
|
|
|
label = prediction["label"]
|
|
assert isinstance(label, str)
|
|
assert label in {"negative", "neutral", "positive"}
|
|
|
|
confidence = prediction["confidence"]
|
|
assert isinstance(confidence, (int, float))
|
|
confidence_value = float(confidence)
|
|
assert 0.0 <= confidence_value <= 1.0
|
|
|
|
window_obj = prediction["window"]
|
|
assert isinstance(window_obj, int)
|
|
assert window_obj >= 0
|
|
|
|
assert isinstance(prediction["timestamp_ns"], int)
|
|
|
|
|
|
def test_pipeline_cli_fps_benchmark_smoke(
    compatible_checkpoint_path: Path,
) -> None:
    """Smoke-check end-to-end CLI throughput against a conservative FPS floor."""
    _require_integration_assets()

    max_frames = 90
    started_at = time.perf_counter()
    result = _run_pipeline_cli(
        "--source", str(SAMPLE_VIDEO_PATH),
        "--checkpoint", str(compatible_checkpoint_path),
        "--config", str(CONFIG_PATH),
        "--device", _device_for_runtime(),
        "--yolo-model", str(YOLO_MODEL_PATH),
        "--window", "5",
        "--stride", "1",
        "--max-frames", str(max_frames),
        timeout_seconds=180,
    )
    elapsed_seconds = time.perf_counter() - started_at

    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )
    predictions = _extract_prediction_json_lines(result.stdout)
    assert predictions, "Expected prediction output for FPS benchmark run"

    for prediction in predictions:
        _assert_prediction_schema(prediction)

    # Each distinct integer frame index counts as one processed unit.
    observed_frames = {
        prediction["frame"]
        for prediction in predictions
        if isinstance(prediction["frame"], int)
    }
    observed_units = len(observed_frames)
    if observed_units < 5:
        pytest.skip(
            "Insufficient observed frame samples for stable FPS benchmark in this environment"
        )
    if elapsed_seconds <= 0:
        pytest.skip("Non-positive elapsed time; cannot compute FPS benchmark")

    fps = observed_units / elapsed_seconds
    min_expected_fps = 0.2
    assert fps >= min_expected_fps, (
        "Observed FPS below conservative CI threshold: "
        f"{fps:.3f} < {min_expected_fps:.3f} "
        f"(observed_units={observed_units}, elapsed_seconds={elapsed_seconds:.3f})"
    )
|
|
|
|
|
|
def test_pipeline_cli_happy_path_outputs_json_predictions(
    compatible_checkpoint_path: Path,
) -> None:
    """Happy path: the CLI exits 0 and emits schema-valid JSON predictions."""
    _require_integration_assets()

    cli_args = (
        "--source", str(SAMPLE_VIDEO_PATH),
        "--checkpoint", str(compatible_checkpoint_path),
        "--config", str(CONFIG_PATH),
        "--device", _device_for_runtime(),
        "--yolo-model", str(YOLO_MODEL_PATH),
        "--window", "10",
        "--stride", "10",
        "--max-frames", "120",
    )
    result = _run_pipeline_cli(*cli_args, timeout_seconds=180)

    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )
    predictions = _extract_prediction_json_lines(result.stdout)
    assert predictions, (
        "Expected at least one prediction JSON line in stdout. "
        f"stdout:\n{result.stdout}\nstderr:\n{result.stderr}"
    )
    for prediction in predictions:
        _assert_prediction_schema(prediction)

    # Local-only run: stderr must not report a NATS broker connection.
    assert "Connected to NATS" not in result.stderr
|
|
|
|
|
|
def test_pipeline_cli_max_frames_caps_output_frames(
    compatible_checkpoint_path: Path,
) -> None:
    """--max-frames must bound every emitted frame index strictly below the cap."""
    _require_integration_assets()

    frame_cap = 20
    result = _run_pipeline_cli(
        "--source", str(SAMPLE_VIDEO_PATH),
        "--checkpoint", str(compatible_checkpoint_path),
        "--config", str(CONFIG_PATH),
        "--device", _device_for_runtime(),
        "--yolo-model", str(YOLO_MODEL_PATH),
        "--window", "5",
        "--stride", "1",
        "--max-frames", str(frame_cap),
        timeout_seconds=180,
    )

    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )
    predictions = _extract_prediction_json_lines(result.stdout)
    assert predictions, "Expected prediction output with --max-frames run"

    for prediction in predictions:
        _assert_prediction_schema(prediction)
        frame_index = prediction["frame"]
        assert isinstance(frame_index, int)
        assert frame_index < frame_cap
|
|
|
|
|
|
def test_pipeline_cli_invalid_source_path_returns_user_error() -> None:
    """A nonexistent --source must yield exit code 2 and a clear user error."""
    bogus_source = "/definitely/not/a/real/video.mp4"
    result = _run_pipeline_cli(
        "--source", bogus_source,
        "--checkpoint", "/tmp/unused-checkpoint.pt",
        "--config", str(CONFIG_PATH),
        timeout_seconds=30,
    )

    assert result.returncode == 2
    assert "Error: Video source not found" in result.stderr
|
|
|
|
|
|
def test_pipeline_cli_invalid_checkpoint_path_returns_user_error() -> None:
    """A nonexistent --checkpoint must yield exit code 2 and a clear user error."""
    for required_path, description in (
        (SAMPLE_VIDEO_PATH, "sample video"),
        (CONFIG_PATH, "config"),
    ):
        if not required_path.is_file():
            pytest.skip(f"Missing {description}: {required_path}")

    result = _run_pipeline_cli(
        "--source", str(SAMPLE_VIDEO_PATH),
        "--checkpoint", str(REPO_ROOT / "ckpt" / "missing-checkpoint.pt"),
        "--config", str(CONFIG_PATH),
        timeout_seconds=30,
    )

    assert result.returncode == 2
    assert "Error: Checkpoint not found" in result.stderr
|