Files
OpenGait/tests/demo/test_pipeline.py
T
crosstyan d6fd6c03e6 test(demo): add unit and integration coverage for pipeline
Introduce focused unit, integration, and NATS-path tests for demo modules, and align assertions with final schema and temporal contracts (window int, seq=30, fill-level ratio). This commit isolates validation logic from runtime changes and provides reproducible QA for pipeline behavior and failure modes.
2026-02-27 09:59:14 +08:00

280 lines
7.9 KiB
Python

from __future__ import annotations
import json
from pathlib import Path
import subprocess
import sys
import time
from typing import Final, cast
import pytest
import torch
from opengait.demo.sconet_demo import ScoNetDemo
# Repository root: the third ancestor of this file (parents[0] is the file's own directory).
REPO_ROOT: Final[Path] = Path(__file__).resolve().parents[2]

# On-disk assets referenced by the tests in this module, all resolved relative to the
# repository root.
SAMPLE_VIDEO_PATH: Final[Path] = REPO_ROOT.joinpath("assets", "sample.mp4")
CHECKPOINT_PATH: Final[Path] = REPO_ROOT.joinpath("ckpt", "ScoNet-20000.pt")
CONFIG_PATH: Final[Path] = REPO_ROOT.joinpath(
    "configs", "sconet", "sconet_scoliosis1k.yaml"
)
YOLO_MODEL_PATH: Final[Path] = REPO_ROOT.joinpath("yolo11n-seg.pt")
def _device_for_runtime() -> str:
    """Pick the torch device string handed to the CLI via ``--device``.

    Returns:
        ``"cuda:0"`` when ``torch.cuda.is_available()`` is true, otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda:0"
    return "cpu"
def _run_pipeline_cli(
    *args: str, timeout_seconds: int = 120
) -> subprocess.CompletedProcess[str]:
    """Run ``python -m opengait.demo`` in a child process and capture its output.

    The child uses the interpreter that is running the tests (``sys.executable``) and
    has the repository root as its working directory.

    Args:
        *args: Command-line arguments appended after the module name.
        timeout_seconds: Forwarded to ``subprocess.run`` as ``timeout``.

    Returns:
        The completed process with stdout and stderr captured as text. Because
        ``check=False`` is used, a non-zero exit status is reported through
        ``returncode`` instead of raising.
    """
    argv: list[str] = [sys.executable, "-m", "opengait.demo"]
    argv.extend(args)
    completed = subprocess.run(
        argv,
        timeout=timeout_seconds,
        check=False,
        text=True,
        capture_output=True,
        cwd=REPO_ROOT,
    )
    return completed
def _require_integration_assets() -> None:
    """Skip the calling test unless every required on-disk asset exists.

    The sample video, the model config and the YOLO model file are checked in that
    order. The first one that is not a regular file triggers ``pytest.skip`` with a
    message naming the missing path.
    """
    requirements = (
        (SAMPLE_VIDEO_PATH, f"Missing sample video: {SAMPLE_VIDEO_PATH}"),
        (CONFIG_PATH, f"Missing config: {CONFIG_PATH}"),
        (YOLO_MODEL_PATH, f"Missing YOLO model file: {YOLO_MODEL_PATH}"),
    )
    for asset_path, skip_reason in requirements:
        if not asset_path.is_file():
            pytest.skip(skip_reason)
@pytest.fixture
def compatible_checkpoint_path(tmp_path: Path) -> Path:
    """Provide a checkpoint file produced from a freshly constructed ``ScoNetDemo``.

    The model is built on CPU from ``CONFIG_PATH`` with ``checkpoint_path=None`` and its
    ``state_dict`` is saved under pytest's ``tmp_path``. The requesting test is skipped
    when the config file is missing.

    Returns:
        Path of the saved ``sconet-compatible.pt`` file.
    """
    if not CONFIG_PATH.is_file():
        pytest.skip(f"Missing config: {CONFIG_PATH}")

    demo_model = ScoNetDemo(
        cfg_path=str(CONFIG_PATH),
        checkpoint_path=None,
        device="cpu",
    )
    target = tmp_path / "sconet-compatible.pt"
    weights = demo_model.state_dict()
    torch.save(weights, target)
    return target
def _extract_prediction_json_lines(stdout: str) -> list[dict[str, object]]:
    """Collect the prediction records printed as JSON lines on stdout.

    Every non-blank line is parsed as JSON. Lines that are not valid JSON, that decode
    to something other than a JSON object, or that lack any of the required prediction
    keys are ignored, so other output mixed into stdout does not break extraction.

    Args:
        stdout: Captured standard output of a pipeline run.

    Returns:
        The matching JSON objects, in the order they appeared in ``stdout``.
    """
    expected_fields = frozenset(
        ("frame", "track_id", "label", "confidence", "window", "timestamp_ns")
    )
    records: list[dict[str, object]] = []
    for raw_line in stdout.splitlines():
        text = raw_line.strip()
        if text:
            try:
                decoded = cast(object, json.loads(text))
            except json.JSONDecodeError:
                # Not JSON at all; treat it like any other non-object payload.
                decoded = None
            if isinstance(decoded, dict):
                candidate = cast(dict[str, object], decoded)
                if expected_fields.issubset(candidate.keys()):
                    records.append(candidate)
    return records
def _assert_prediction_schema(prediction: dict[str, object]) -> None:
    """Assert that one prediction record has the expected field types and ranges.

    Checks performed:
        * ``frame``, ``track_id`` and ``timestamp_ns`` are integers.
        * ``label`` is one of ``"negative"``, ``"neutral"`` or ``"positive"``.
        * ``confidence`` is an int or float within ``[0.0, 1.0]``.
        * ``window`` is a non-negative integer.

    Booleans are rejected for every numeric field. ``bool`` is a subclass of ``int``, so
    a JSON ``true``/``false`` would otherwise pass a plain ``isinstance(..., int)`` check.

    Args:
        prediction: A decoded prediction JSON object.

    Raises:
        AssertionError: If a field has the wrong type or is out of range.
        KeyError: If a required field is missing.
    """
    frame = prediction["frame"]
    assert isinstance(frame, int) and not isinstance(frame, bool), (
        f"frame must be an int, got {frame!r}"
    )

    track_id = prediction["track_id"]
    assert isinstance(track_id, int) and not isinstance(track_id, bool), (
        f"track_id must be an int, got {track_id!r}"
    )

    label = prediction["label"]
    assert isinstance(label, str), f"label must be a str, got {label!r}"
    assert label in {"negative", "neutral", "positive"}, f"unexpected label {label!r}"

    confidence = prediction["confidence"]
    assert isinstance(confidence, (int, float)) and not isinstance(confidence, bool), (
        f"confidence must be a number, got {confidence!r}"
    )
    confidence_value = float(confidence)
    # NaN fails this comparison as well, so it is rejected along with out-of-range values.
    assert 0.0 <= confidence_value <= 1.0, (
        f"confidence out of range: {confidence_value!r}"
    )

    window_obj = prediction["window"]
    assert isinstance(window_obj, int) and not isinstance(window_obj, bool), (
        f"window must be an int, got {window_obj!r}"
    )
    assert window_obj >= 0, f"window must be non-negative, got {window_obj!r}"

    timestamp_ns = prediction["timestamp_ns"]
    assert isinstance(timestamp_ns, int) and not isinstance(timestamp_ns, bool), (
        f"timestamp_ns must be an int, got {timestamp_ns!r}"
    )
def test_pipeline_cli_fps_benchmark_smoke(
    compatible_checkpoint_path: Path,
) -> None:
    """Smoke-check throughput of a short CLI run against a very low FPS floor.

    Runs the pipeline on the sample video (window 5, stride 1, at most 90 frames),
    validates every emitted prediction, then divides the number of distinct predicted
    frame indices by the wall-clock duration of the whole subprocess call. The test is
    skipped when fewer than five distinct frames were observed.
    """
    _require_integration_assets()
    max_frames = 90

    # The clock starts before the arguments are assembled, so the measured span covers
    # argument construction plus the entire child process run.
    started_at = time.perf_counter()
    cli_args = [
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "5",
        "--stride",
        "1",
        "--max-frames",
        str(max_frames),
    ]
    result = _run_pipeline_cli(*cli_args, timeout_seconds=180)
    elapsed_seconds = time.perf_counter() - started_at

    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )

    predictions = _extract_prediction_json_lines(result.stdout)
    assert predictions, "Expected prediction output for FPS benchmark run"
    for record in predictions:
        _assert_prediction_schema(record)

    # Count each frame index once, however many predictions refer to it.
    distinct_frames = set()
    for record in predictions:
        frame_value = record["frame"]
        if isinstance(frame_value, int):
            distinct_frames.add(frame_value)
    observed_units = len(distinct_frames)

    if observed_units < 5:
        pytest.skip(
            "Insufficient observed frame samples for stable FPS benchmark in this environment"
        )
    if elapsed_seconds <= 0:
        pytest.skip("Non-positive elapsed time; cannot compute FPS benchmark")

    min_expected_fps = 0.2
    fps = observed_units / elapsed_seconds
    assert fps >= min_expected_fps, (
        "Observed FPS below conservative CI threshold: "
        f"{fps:.3f} < {min_expected_fps:.3f} "
        f"(observed_units={observed_units}, elapsed_seconds={elapsed_seconds:.3f})"
    )
def test_pipeline_cli_happy_path_outputs_json_predictions(
    compatible_checkpoint_path: Path,
) -> None:
    """A normal CLI run exits 0 and prints at least one schema-valid prediction line.

    Uses window 10, stride 10 and at most 120 frames of the sample video. It also checks
    that stderr does not contain "Connected to NATS"; presumably no NATS connection is
    expected when none is requested on the command line.
    """
    _require_integration_assets()

    cli_args = [
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "10",
        "--stride",
        "10",
        "--max-frames",
        "120",
    ]
    result = _run_pipeline_cli(*cli_args, timeout_seconds=180)

    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )

    predictions = _extract_prediction_json_lines(result.stdout)
    assert predictions, (
        "Expected at least one prediction JSON line in stdout. "
        f"stdout:\n{result.stdout}\nstderr:\n{result.stderr}"
    )
    for record in predictions:
        _assert_prediction_schema(record)

    assert "Connected to NATS" not in result.stderr
def test_pipeline_cli_max_frames_caps_output_frames(
    compatible_checkpoint_path: Path,
) -> None:
    """With ``--max-frames 20`` every emitted prediction has a frame index below 20.

    Runs the pipeline with window 5 and stride 1, requires a zero exit code and at least
    one prediction, and validates each prediction's schema before checking its frame.
    """
    _require_integration_assets()
    max_frames = 20

    cli_args = [
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "5",
        "--stride",
        "1",
        "--max-frames",
        str(max_frames),
    ]
    result = _run_pipeline_cli(*cli_args, timeout_seconds=180)

    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )

    predictions = _extract_prediction_json_lines(result.stdout)
    assert predictions, "Expected prediction output with --max-frames run"

    for record in predictions:
        _assert_prediction_schema(record)
        frame_index = record["frame"]
        assert isinstance(frame_index, int)
        assert frame_index < max_frames
def test_pipeline_cli_invalid_source_path_returns_user_error() -> None:
    """A non-existent ``--source`` path yields exit code 2 and a "not found" message.

    The checkpoint path passed here does not need to exist; the test only expects the
    missing-video error to be reported on stderr.
    """
    cli_args = (
        "--source",
        "/definitely/not/a/real/video.mp4",
        "--checkpoint",
        "/tmp/unused-checkpoint.pt",
        "--config",
        str(CONFIG_PATH),
    )
    result = _run_pipeline_cli(*cli_args, timeout_seconds=30)

    exit_code = result.returncode
    assert exit_code == 2
    assert "Error: Video source not found" in result.stderr
def test_pipeline_cli_invalid_checkpoint_path_returns_user_error() -> None:
    """A missing ``--checkpoint`` file yields exit code 2 and a "not found" message.

    The test is skipped when the sample video or the config file is absent, since both
    are passed to the CLI as real inputs.
    """
    prerequisites = (
        (SAMPLE_VIDEO_PATH, f"Missing sample video: {SAMPLE_VIDEO_PATH}"),
        (CONFIG_PATH, f"Missing config: {CONFIG_PATH}"),
    )
    for asset_path, skip_reason in prerequisites:
        if not asset_path.is_file():
            pytest.skip(skip_reason)

    missing_checkpoint = REPO_ROOT / "ckpt" / "missing-checkpoint.pt"
    cli_args = (
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(missing_checkpoint),
        "--config",
        str(CONFIG_PATH),
    )
    result = _run_pipeline_cli(*cli_args, timeout_seconds=30)

    exit_code = result.returncode
    assert exit_code == 2
    assert "Error: Checkpoint not found" in result.stderr