# OpenGait/tests/demo/test_pipeline.py
#
# Source-viewer header captured with this file (895 lines, 28 KiB, Python).
# Last commit: 7f073179d7 by crosstyan, 2026-02-28 18:05:33 +08:00
#   fix(demo): stabilize visualizer bbox and mask rendering
#   Align bbox coordinate handling across primary and fallback paths, normalize
#   Both-mode raw mask rendering, and tighten demo result typing to reduce
#   runtime/display inconsistencies.

from __future__ import annotations
import importlib.util
import json
import pickle
from pathlib import Path
import subprocess
import sys
import time
from typing import Final, cast
from unittest import mock
import numpy as np
from numpy.typing import NDArray
import pytest
import torch
from opengait.demo.sconet_demo import ScoNetDemo
# Repository root: this module sits two directories below it (tests/demo/).
REPO_ROOT: Final[Path] = Path(__file__).resolve().parents[2]
# Integration assets; tests that need them skip when a file is absent.
SAMPLE_VIDEO_PATH: Final[Path] = REPO_ROOT.joinpath("assets", "sample.mp4")
# Released ScoNet weights. The CLI tests in this module generate their own
# compatible checkpoint (see the ``compatible_checkpoint_path`` fixture).
CHECKPOINT_PATH: Final[Path] = REPO_ROOT.joinpath("ckpt", "ScoNet-20000.pt")
CONFIG_PATH: Final[Path] = REPO_ROOT.joinpath(
    "configs", "sconet", "sconet_scoliosis1k.yaml"
)
YOLO_MODEL_PATH: Final[Path] = REPO_ROOT.joinpath("ckpt", "yolo11n-seg.pt")
def _device_for_runtime() -> str:
    """Choose the ``--device`` value passed to the pipeline CLI.

    Returns ``"cuda:0"`` when ``torch.cuda.is_available()`` reports a CUDA
    device, and ``"cpu"`` otherwise.
    """
    if torch.cuda.is_available():
        return "cuda:0"
    return "cpu"
def _run_pipeline_cli(
    *args: str, timeout_seconds: int = 120
) -> subprocess.CompletedProcess[str]:
    """Invoke ``python -m opengait.demo`` with *args* from the repository root.

    The subprocess uses the current interpreter (``sys.executable``). Its
    stdout/stderr are captured as text, and a non-zero exit status is returned
    rather than raised (``check=False``) so callers can assert on
    ``returncode``. ``subprocess.TimeoutExpired`` propagates if the run exceeds
    *timeout_seconds*.
    """
    return subprocess.run(
        [sys.executable, "-m", "opengait.demo", *args],
        cwd=REPO_ROOT,
        capture_output=True,
        text=True,
        check=False,
        timeout=timeout_seconds,
    )
def _require_integration_assets() -> None:
    """Skip the calling test unless every end-to-end asset file exists.

    Checks the sample video, the ScoNet config and the YOLO model, in that
    order. The first missing file triggers ``pytest.skip``.
    """
    required_assets = (
        ("Missing sample video", SAMPLE_VIDEO_PATH),
        ("Missing config", CONFIG_PATH),
        ("Missing YOLO model file", YOLO_MODEL_PATH),
    )
    for description, asset_path in required_assets:
        if not asset_path.is_file():
            pytest.skip(f"{description}: {asset_path}")
@pytest.fixture
def compatible_checkpoint_path(tmp_path: Path) -> Path:
    """Save a ``ScoNetDemo`` state dict built from ``CONFIG_PATH`` and return its path.

    The model is constructed on CPU with ``checkpoint_path=None``, so no weights
    are loaded. The saved parameters are presumably just the constructor's
    initialisation, which suffices for tests that exercise the CLI plumbing
    rather than prediction quality. Skips the test when the config is missing.
    """
    if not CONFIG_PATH.is_file():
        pytest.skip(f"Missing config: {CONFIG_PATH}")
    demo_model = ScoNetDemo(cfg_path=str(CONFIG_PATH), checkpoint_path=None, device="cpu")
    saved_path = tmp_path / "sconet-compatible.pt"
    torch.save(demo_model.state_dict(), saved_path)
    return saved_path
def _extract_prediction_json_lines(stdout: str) -> list[dict[str, object]]:
    """Collect prediction records printed as JSON lines on the CLI's stdout.

    Every line is stripped and parsed as JSON. A line is skipped if it is
    blank, is not valid JSON, does not decode to an object, or lacks any of
    ``frame``, ``track_id``, ``label``, ``confidence``, ``window`` or
    ``timestamp_ns``. Matching objects are returned in stdout order; extra
    keys are kept.
    """
    required_keys = frozenset(
        ("frame", "track_id", "label", "confidence", "window", "timestamp_ns")
    )

    def _parse_object(raw_line: str) -> dict[str, object] | None:
        # Decode one stdout line; None for anything that is not a JSON object.
        text_line = raw_line.strip()
        if text_line == "":
            return None
        try:
            decoded = cast(object, json.loads(text_line))
        except json.JSONDecodeError:
            return None
        if isinstance(decoded, dict):
            return cast(dict[str, object], decoded)
        return None

    parsed = (_parse_object(raw_line) for raw_line in stdout.splitlines())
    return [
        record
        for record in parsed
        if record is not None and required_keys.issubset(record)
    ]
def _assert_prediction_schema(prediction: dict[str, object]) -> None:
    """Assert that *prediction* matches the demo's prediction record schema.

    Expected fields:
      * ``frame``, ``track_id``, ``timestamp_ns``: int
      * ``window``: non-negative int
      * ``label``: one of ``"negative"``, ``"neutral"``, ``"positive"``
      * ``confidence``: real number in ``[0, 1]``

    ``bool`` is rejected for the integer and numeric fields. ``bool``
    subclasses ``int``, so a bare ``isinstance(x, int)`` would let a JSON
    ``true``/``false`` pass as a frame index or confidence.

    Raises:
        AssertionError: if a field is missing or has the wrong type or value.
    """
    required = ("frame", "track_id", "label", "confidence", "window", "timestamp_ns")
    missing = [key for key in required if key not in prediction]
    assert not missing, f"Prediction missing keys {missing}: {prediction!r}"

    for int_key in ("frame", "track_id", "timestamp_ns"):
        int_value = prediction[int_key]
        assert isinstance(int_value, int) and not isinstance(int_value, bool), (
            f"{int_key!r} must be an int, got {int_value!r}"
        )

    label = prediction["label"]
    assert isinstance(label, str), f"'label' must be a str, got {label!r}"
    assert label in {"negative", "neutral", "positive"}, f"Unexpected label {label!r}"

    confidence = prediction["confidence"]
    assert isinstance(confidence, (int, float)) and not isinstance(confidence, bool), (
        f"'confidence' must be a real number, got {confidence!r}"
    )
    # NaN fails this chained comparison, so it is rejected as well.
    assert 0.0 <= float(confidence) <= 1.0, (
        f"'confidence' must lie in [0, 1], got {confidence!r}"
    )

    window_obj = prediction["window"]
    assert isinstance(window_obj, int) and not isinstance(window_obj, bool), (
        f"'window' must be an int, got {window_obj!r}"
    )
    assert window_obj >= 0, f"'window' must be non-negative, got {window_obj!r}"
def test_pipeline_cli_fps_benchmark_smoke(
    compatible_checkpoint_path: Path,
) -> None:
    """Smoke-check end-to-end CLI throughput on up to 90 sample-video frames.

    Throughput is the number of distinct frames that produced a prediction,
    divided by the wall-clock time of the whole CLI subprocess (interpreter
    start-up and model loading included). The 0.2 FPS floor is therefore
    deliberately conservative. Skips when fewer than five distinct frames
    are observed.
    """
    _require_integration_assets()
    max_frames = 90
    t_start = time.perf_counter()
    cli_args = [
        "--source", str(SAMPLE_VIDEO_PATH),
        "--checkpoint", str(compatible_checkpoint_path),
        "--config", str(CONFIG_PATH),
        "--device", _device_for_runtime(),
        "--yolo-model", str(YOLO_MODEL_PATH),
        "--window", "5",
        "--stride", "1",
        "--max-frames", str(max_frames),
    ]
    completed = _run_pipeline_cli(*cli_args, timeout_seconds=180)
    elapsed_seconds = time.perf_counter() - t_start

    assert completed.returncode == 0, (
        f"Expected exit code 0, got {completed.returncode}. stderr:\n{completed.stderr}"
    )
    predictions = _extract_prediction_json_lines(completed.stdout)
    assert predictions, "Expected prediction output for FPS benchmark run"
    for record in predictions:
        _assert_prediction_schema(record)

    distinct_frames = {
        record["frame"] for record in predictions if isinstance(record["frame"], int)
    }
    observed_units = len(distinct_frames)
    if observed_units < 5:
        pytest.skip(
            "Insufficient observed frame samples for stable FPS benchmark in this environment"
        )
    if elapsed_seconds <= 0:
        pytest.skip("Non-positive elapsed time; cannot compute FPS benchmark")

    fps = observed_units / elapsed_seconds
    min_expected_fps = 0.2
    assert fps >= min_expected_fps, (
        "Observed FPS below conservative CI threshold: "
        f"{fps:.3f} < {min_expected_fps:.3f} "
        f"(observed_units={observed_units}, elapsed_seconds={elapsed_seconds:.3f})"
    )
def test_pipeline_cli_happy_path_outputs_json_predictions(
    compatible_checkpoint_path: Path,
) -> None:
    """Run the CLI on 120 frames (window 10, stride 10) and validate stdout.

    At least one schema-valid prediction JSON line must appear on stdout. No
    NATS URL is passed, so stderr must not report a NATS connection.
    """
    _require_integration_assets()
    cli_args = [
        "--source", str(SAMPLE_VIDEO_PATH),
        "--checkpoint", str(compatible_checkpoint_path),
        "--config", str(CONFIG_PATH),
        "--device", _device_for_runtime(),
        "--yolo-model", str(YOLO_MODEL_PATH),
        "--window", "10",
        "--stride", "10",
        "--max-frames", "120",
    ]
    completed = _run_pipeline_cli(*cli_args, timeout_seconds=180)

    assert completed.returncode == 0, (
        f"Expected exit code 0, got {completed.returncode}. stderr:\n{completed.stderr}"
    )
    predictions = _extract_prediction_json_lines(completed.stdout)
    assert predictions, (
        "Expected at least one prediction JSON line in stdout. "
        f"stdout:\n{completed.stdout}\nstderr:\n{completed.stderr}"
    )
    for record in predictions:
        _assert_prediction_schema(record)
    assert "Connected to NATS" not in completed.stderr
def test_pipeline_cli_max_frames_caps_output_frames(
    compatible_checkpoint_path: Path,
) -> None:
    """``--max-frames 20`` must bound the frame index of every prediction.

    Each emitted prediction's ``frame`` must be strictly below 20, which
    presumes frame indices are 0-based.
    """
    _require_integration_assets()
    max_frames = 20
    cli_args = [
        "--source", str(SAMPLE_VIDEO_PATH),
        "--checkpoint", str(compatible_checkpoint_path),
        "--config", str(CONFIG_PATH),
        "--device", _device_for_runtime(),
        "--yolo-model", str(YOLO_MODEL_PATH),
        "--window", "5",
        "--stride", "1",
        "--max-frames", str(max_frames),
    ]
    completed = _run_pipeline_cli(*cli_args, timeout_seconds=180)

    assert completed.returncode == 0, (
        f"Expected exit code 0, got {completed.returncode}. stderr:\n{completed.stderr}"
    )
    predictions = _extract_prediction_json_lines(completed.stdout)
    assert predictions, "Expected prediction output with --max-frames run"
    for record in predictions:
        _assert_prediction_schema(record)
        frame_index = record["frame"]
        assert isinstance(frame_index, int)
        assert frame_index < max_frames
def test_pipeline_cli_invalid_source_path_returns_user_error() -> None:
    """A nonexistent ``--source`` must exit with status 2 and a clear message.

    The checkpoint argument is only a placeholder path. The assertion expects
    the video-source error to be the one reported on stderr.
    """
    bogus_source = "/definitely/not/a/real/video.mp4"
    placeholder_checkpoint = "/tmp/unused-checkpoint.pt"
    completed = _run_pipeline_cli(
        "--source", bogus_source,
        "--checkpoint", placeholder_checkpoint,
        "--config", str(CONFIG_PATH),
        timeout_seconds=30,
    )
    assert completed.returncode == 2
    assert "Error: Video source not found" in completed.stderr
def test_pipeline_cli_invalid_checkpoint_path_returns_user_error() -> None:
    """A missing ``--checkpoint`` file must exit with status 2 and a clear message.

    Needs the sample video and config (skips otherwise), so the checkpoint is
    the only invalid input.
    """
    for description, asset_path in (
        ("Missing sample video", SAMPLE_VIDEO_PATH),
        ("Missing config", CONFIG_PATH),
    ):
        if not asset_path.is_file():
            pytest.skip(f"{description}: {asset_path}")

    missing_checkpoint = REPO_ROOT / "ckpt" / "missing-checkpoint.pt"
    completed = _run_pipeline_cli(
        "--source", str(SAMPLE_VIDEO_PATH),
        "--checkpoint", str(missing_checkpoint),
        "--config", str(CONFIG_PATH),
        timeout_seconds=30,
    )
    assert completed.returncode == 2
    assert "Error: Checkpoint not found" in completed.stderr
def test_pipeline_cli_preprocess_only_requires_export_path(
    compatible_checkpoint_path: Path,
) -> None:
    """Test that --preprocess-only requires --silhouette-export-path.

    Omitting the export path must yield exit status 2 and an error naming the
    missing option.
    """
    _require_integration_assets()
    cli_args = [
        "--source", str(SAMPLE_VIDEO_PATH),
        "--checkpoint", str(compatible_checkpoint_path),
        "--config", str(CONFIG_PATH),
        "--device", _device_for_runtime(),
        "--yolo-model", str(YOLO_MODEL_PATH),
        "--preprocess-only",
        "--max-frames", "10",
    ]
    completed = _run_pipeline_cli(*cli_args, timeout_seconds=30)
    assert completed.returncode == 2
    assert "--silhouette-export-path is required" in completed.stderr
def test_pipeline_cli_preprocess_only_exports_pickle(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test preprocess-only mode exports silhouettes to pickle.

    The export must be a non-empty list of dicts, each with ``frame``,
    ``track_id``, ``timestamp_ns`` (ints) and a ``silhouette`` entry. The
    pickle is produced by this test's own CLI run, so loading it is trusted.
    """
    _require_integration_assets()
    export_path = tmp_path / "silhouettes.pkl"
    cli_args = [
        "--source", str(SAMPLE_VIDEO_PATH),
        "--checkpoint", str(compatible_checkpoint_path),
        "--config", str(CONFIG_PATH),
        "--device", _device_for_runtime(),
        "--yolo-model", str(YOLO_MODEL_PATH),
        "--preprocess-only",
        "--silhouette-export-path", str(export_path),
        "--silhouette-export-format", "pickle",
        "--max-frames", "30",
    ]
    completed = _run_pipeline_cli(*cli_args, timeout_seconds=180)
    assert completed.returncode == 0, (
        f"Expected exit code 0, got {completed.returncode}. stderr:\n{completed.stderr}"
    )

    assert export_path.is_file(), f"Export file not found: {export_path}"
    silhouettes = pickle.loads(export_path.read_bytes())
    assert isinstance(silhouettes, list)
    assert len(silhouettes) > 0, "Expected at least one silhouette"

    for record in silhouettes:
        assert isinstance(record, dict)
        for key in ("frame", "track_id", "timestamp_ns", "silhouette"):
            assert key in record
        for int_key in ("frame", "track_id", "timestamp_ns"):
            assert isinstance(record[int_key], int)
def test_pipeline_cli_result_export_json(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test that results can be exported to JSON file.

    The export is read as JSON Lines (one object per non-blank line). It must
    contain at least one record, and every record must satisfy the prediction
    schema.
    """
    _require_integration_assets()
    export_path = tmp_path / "results.jsonl"
    cli_args = [
        "--source", str(SAMPLE_VIDEO_PATH),
        "--checkpoint", str(compatible_checkpoint_path),
        "--config", str(CONFIG_PATH),
        "--device", _device_for_runtime(),
        "--yolo-model", str(YOLO_MODEL_PATH),
        "--window", "10",
        "--stride", "10",
        "--result-export-path", str(export_path),
        "--result-export-format", "json",
        "--max-frames", "60",
    ]
    completed = _run_pipeline_cli(*cli_args, timeout_seconds=180)
    assert completed.returncode == 0, (
        f"Expected exit code 0, got {completed.returncode}. stderr:\n{completed.stderr}"
    )

    assert export_path.is_file(), f"Export file not found: {export_path}"
    with export_path.open("r", encoding="utf-8") as handle:
        predictions: list[dict[str, object]] = [
            cast(dict[str, object], json.loads(raw.strip()))
            for raw in handle
            if raw.strip()
        ]
    assert len(predictions) > 0, "Expected at least one prediction in export"
    for record in predictions:
        _assert_prediction_schema(record)
def test_pipeline_cli_result_export_pickle(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test that results can be exported to pickle file.

    The pickle must hold a non-empty list whose entries satisfy the prediction
    schema. The file is produced by this test's own CLI run, so loading it is
    trusted.
    """
    _require_integration_assets()
    export_path = tmp_path / "results.pkl"
    cli_args = [
        "--source", str(SAMPLE_VIDEO_PATH),
        "--checkpoint", str(compatible_checkpoint_path),
        "--config", str(CONFIG_PATH),
        "--device", _device_for_runtime(),
        "--yolo-model", str(YOLO_MODEL_PATH),
        "--window", "10",
        "--stride", "10",
        "--result-export-path", str(export_path),
        "--result-export-format", "pickle",
        "--max-frames", "60",
    ]
    completed = _run_pipeline_cli(*cli_args, timeout_seconds=180)
    assert completed.returncode == 0, (
        f"Expected exit code 0, got {completed.returncode}. stderr:\n{completed.stderr}"
    )

    assert export_path.is_file(), f"Export file not found: {export_path}"
    predictions = pickle.loads(export_path.read_bytes())
    assert isinstance(predictions, list)
    assert len(predictions) > 0, "Expected at least one prediction in export"
    for record in predictions:
        _assert_prediction_schema(record)
def test_pipeline_cli_silhouette_and_result_export(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test exporting both silhouettes and results simultaneously.

    Both files must exist and hold well-formed records. The checks match the
    single-export tests that use the same flags:
      * each silhouette record is a dict with ``frame``, ``track_id``,
        ``timestamp_ns`` and ``silhouette`` keys;
      * each result line satisfies the prediction schema.
    Both exports are produced by this test's own CLI run, so loading the
    pickle is trusted.
    """
    _require_integration_assets()
    silhouette_export = tmp_path / "silhouettes.pkl"
    result_export = tmp_path / "results.jsonl"
    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "10",
        "--stride",
        "10",
        "--silhouette-export-path",
        str(silhouette_export),
        "--silhouette-export-format",
        "pickle",
        "--result-export-path",
        str(result_export),
        "--result-export-format",
        "json",
        "--max-frames",
        "60",
        timeout_seconds=180,
    )
    assert result.returncode == 0, (
        f"Expected exit code 0, got {result.returncode}. stderr:\n{result.stderr}"
    )

    # Verify both export files exist
    assert silhouette_export.is_file(), (
        f"Silhouette export not found: {silhouette_export}"
    )
    assert result_export.is_file(), f"Result export not found: {result_export}"

    # Verify silhouette export: non-empty list of records with the expected keys
    with open(silhouette_export, "rb") as f:
        silhouettes = pickle.load(f)
    assert isinstance(silhouettes, list)
    assert len(silhouettes) > 0
    for item in silhouettes:
        assert isinstance(item, dict), f"Silhouette record is not a dict: {item!r}"
        for key in ("frame", "track_id", "timestamp_ns", "silhouette"):
            assert key in item, f"Silhouette record missing {key!r}"

    # Verify result export: non-empty JSON Lines, each a valid prediction
    with open(result_export, "r", encoding="utf-8") as f:
        predictions = [
            cast(dict[str, object], json.loads(line)) for line in f if line.strip()
        ]
    assert len(predictions) > 0
    for prediction in predictions:
        _assert_prediction_schema(prediction)
def test_pipeline_cli_parquet_export_requires_pyarrow(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test that parquet export fails gracefully when pyarrow is not available.

    The test is only meaningful without pyarrow, so it skips when pyarrow is
    importable. The CLI subprocess runs under the same interpreter
    (``sys.executable``), so a spec lookup in this process reflects what the
    subprocess can import. Expects exit status 1 and stderr mentioning parquet
    or pyarrow.
    """
    _require_integration_assets()
    # Skip if pyarrow is actually installed. When find_spec returns None,
    # `import pyarrow` would fail, so no extra import probe is needed.
    if importlib.util.find_spec("pyarrow") is not None:
        pytest.skip("pyarrow is installed, skipping missing dependency test")

    export_path = tmp_path / "results.parquet"
    result = _run_pipeline_cli(
        "--source",
        str(SAMPLE_VIDEO_PATH),
        "--checkpoint",
        str(compatible_checkpoint_path),
        "--config",
        str(CONFIG_PATH),
        "--device",
        _device_for_runtime(),
        "--yolo-model",
        str(YOLO_MODEL_PATH),
        "--window",
        "10",
        "--stride",
        "10",
        "--result-export-path",
        str(export_path),
        "--result-export-format",
        "parquet",
        "--max-frames",
        "30",
        timeout_seconds=180,
    )
    # Should fail with exit status 1 and an error about parquet/pyarrow
    assert result.returncode == 1
    stderr_lower = result.stderr.lower()
    assert "parquet" in stderr_lower or "pyarrow" in stderr_lower
def test_pipeline_cli_silhouette_visualization(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test that silhouette visualization creates PNG files.

    ``--silhouette-visualize-dir`` must produce at least one PNG. Each file
    name must contain ``silhouette_frame`` and ``_track``.
    """
    _require_integration_assets()
    visualize_dir = tmp_path / "silhouette_viz"
    cli_args = [
        "--source", str(SAMPLE_VIDEO_PATH),
        "--checkpoint", str(compatible_checkpoint_path),
        "--config", str(CONFIG_PATH),
        "--device", _device_for_runtime(),
        "--yolo-model", str(YOLO_MODEL_PATH),
        "--window", "10",
        "--stride", "10",
        "--silhouette-visualize-dir", str(visualize_dir),
        "--max-frames", "30",
    ]
    completed = _run_pipeline_cli(*cli_args, timeout_seconds=180)
    assert completed.returncode == 0, (
        f"Expected exit code 0, got {completed.returncode}. stderr:\n{completed.stderr}"
    )

    assert visualize_dir.is_dir(), f"Visualization directory not found: {visualize_dir}"
    png_paths = list(visualize_dir.glob("*.png"))
    assert len(png_paths) > 0, "Expected at least one PNG visualization file"
    for png_path in png_paths:
        file_name = png_path.name
        assert "silhouette_frame" in file_name
        assert "_track" in file_name
def test_pipeline_cli_preprocess_only_with_visualization(
    compatible_checkpoint_path: Path,
    tmp_path: Path,
) -> None:
    """Test preprocess-only mode with both export and visualization.

    Both the silhouette pickle and the PNG directory must be produced. The
    number of exported silhouettes must equal the number of PNG files. The
    pickle comes from this test's own CLI run, so loading it is trusted.
    """
    _require_integration_assets()
    export_path = tmp_path / "silhouettes.pkl"
    visualize_dir = tmp_path / "silhouette_viz"
    cli_args = [
        "--source", str(SAMPLE_VIDEO_PATH),
        "--checkpoint", str(compatible_checkpoint_path),
        "--config", str(CONFIG_PATH),
        "--device", _device_for_runtime(),
        "--yolo-model", str(YOLO_MODEL_PATH),
        "--preprocess-only",
        "--silhouette-export-path", str(export_path),
        "--silhouette-visualize-dir", str(visualize_dir),
        "--max-frames", "30",
    ]
    completed = _run_pipeline_cli(*cli_args, timeout_seconds=180)
    assert completed.returncode == 0, (
        f"Expected exit code 0, got {completed.returncode}. stderr:\n{completed.stderr}"
    )

    assert export_path.is_file(), f"Export file not found: {export_path}"
    assert visualize_dir.is_dir(), f"Visualization directory not found: {visualize_dir}"
    png_paths = list(visualize_dir.glob("*.png"))
    assert len(png_paths) > 0, "Expected at least one PNG visualization file"

    silhouettes = pickle.loads(export_path.read_bytes())
    assert isinstance(silhouettes, list)
    assert len(silhouettes) > 0
    # One PNG is expected per exported silhouette.
    assert len(silhouettes) == len(png_paths), (
        f"Mismatch: {len(silhouettes)} silhouettes exported but {len(png_paths)} PNG files created"
    )
class MockVisualizer:
    """Test double for the pipeline visualizer that records every update.

    Each ``update`` call appends a dict of its arguments (keyed by parameter
    name, in signature order) to ``update_calls`` and returns
    ``return_value``, which defaults to ``True``. ``close`` does nothing.
    """

    def __init__(self) -> None:
        # Recorded update() arguments, one dict per call, in call order.
        self.update_calls: list[dict[str, object]] = []
        # Configurable value handed back by update().
        self.return_value: bool = True

    def update(
        self,
        frame: NDArray[np.uint8],
        bbox: tuple[int, int, int, int] | None,
        track_id: int,
        mask_raw: NDArray[np.uint8] | None,
        silhouette: NDArray[np.float32] | None,
        label: str | None,
        confidence: float | None,
        fps: float,
    ) -> bool:
        """Store this call's arguments and return ``self.return_value``."""
        recorded: dict[str, object] = dict(
            frame=frame,
            bbox=bbox,
            track_id=track_id,
            mask_raw=mask_raw,
            silhouette=silhouette,
            label=label,
            confidence=confidence,
            fps=fps,
        )
        self.update_calls.append(recorded)
        return self.return_value

    def close(self) -> None:
        """Do nothing; the mock holds no resources."""
def test_pipeline_visualizer_updates_on_no_detection() -> None:
    """Test that visualizer is still updated even when process_frame returns None.

    Regression test for the window freeze when no person is detected: the
    window must refresh on every frame. YOLO, the frame source, the publisher
    and ScoNetDemo are patched. The detector's ``track`` returns an empty list
    for each of three frames, and every recorded ``update`` call must carry the
    no-detection defaults (track_id 0, everything else None).
    """
    from opengait.demo.pipeline import ScoliosisPipeline

    frame_count = 3
    blank_frame = np.zeros((480, 640, 3), dtype=np.uint8)
    with (
        mock.patch("opengait.demo.pipeline.YOLO") as yolo_cls,
        mock.patch("opengait.demo.pipeline.create_source") as source_factory,
        mock.patch("opengait.demo.pipeline.create_publisher") as publisher_factory,
        mock.patch("opengait.demo.pipeline.ScoNetDemo") as classifier_cls,
    ):
        # Detector that never reports anyone, so process_frame yields None.
        detector = mock.MagicMock()
        detector.track.return_value = []
        yolo_cls.return_value = detector

        source_factory.return_value = [
            (blank_frame, {"frame_count": idx}) for idx in range(frame_count)
        ]
        publisher_factory.return_value = mock.MagicMock()
        classifier_cls.return_value = mock.MagicMock()

        pipeline = ScoliosisPipeline(
            source="dummy.mp4",
            checkpoint="dummy.pt",
            config=str(CONFIG_PATH) if CONFIG_PATH.exists() else "dummy.yaml",
            device="cpu",
            yolo_model="dummy.pt",
            window=30,
            stride=30,
            nats_url=None,
            nats_subject="test",
            max_frames=None,
            visualize=True,
        )
        # Swap in the recording double before running.
        recorder = MockVisualizer()
        pipeline._visualizer = recorder  # type: ignore[assignment]
        _ = pipeline.run()

        assert len(recorder.update_calls) == frame_count, (
            "Expected visualizer.update() to be called 3 times (once per frame), "
            f"but was called {len(recorder.update_calls)} times. "
            "Window would freeze if not updated on no-detection frames."
        )
        for recorded in recorder.update_calls:
            assert recorded["track_id"] == 0
            for empty_key in ("bbox", "mask_raw", "silhouette", "label", "confidence"):
                assert recorded[empty_key] is None
def test_pipeline_visualizer_uses_cached_detection_on_no_detection() -> None:
    """Test that visualizer reuses last valid detection when current frame has no detection.

    Regression test for the cache-reuse behavior when a person temporarily
    disappears from the frame: the last valid silhouette/mask should be
    displayed. Frames 0-1 produce a (mocked) detection and frames 2-3 produce
    none. The raw mask passed to the visualizer must be non-None on all four
    frames, and the frame-2 mask must be a different object from the frame-1
    mask (a copy, not a shared reference).
    """
    from opengait.demo.pipeline import ScoliosisPipeline

    # Local seeded generator: makes the synthetic data reproducible and leaves
    # NumPy's global random state (shared with other tests) untouched.
    rng = np.random.default_rng(0)

    # Create a minimal pipeline with mocked dependencies
    with (
        mock.patch("opengait.demo.pipeline.YOLO") as mock_yolo,
        mock.patch("opengait.demo.pipeline.create_source") as mock_source,
        mock.patch("opengait.demo.pipeline.create_publisher") as mock_publisher,
        mock.patch("opengait.demo.pipeline.ScoNetDemo") as mock_classifier,
        mock.patch("opengait.demo.pipeline.select_person") as mock_select_person,
        mock.patch("opengait.demo.pipeline.mask_to_silhouette") as mock_mask_to_sil,
    ):
        # Create mock detection result for frames 0-1 (valid detection)
        mock_box = mock.MagicMock()
        mock_box.xyxy = np.array([[100, 100, 200, 300]], dtype=np.float32)
        mock_box.id = np.array([1], dtype=np.int64)
        mock_mask = mock.MagicMock()
        mock_mask.data = rng.random((1, 480, 640)).astype(np.float32)
        mock_result = mock.MagicMock()
        mock_result.boxes = mock_box
        mock_result.masks = mock_mask

        # Setup mock detector: detection for frames 0-1, then no detection for frames 2-3
        mock_detector = mock.MagicMock()
        mock_detector.track.side_effect = [
            [mock_result],  # Frame 0: valid detection
            [mock_result],  # Frame 1: valid detection
            [],  # Frame 2: no detection
            [],  # Frame 3: no detection
        ]
        mock_yolo.return_value = mock_detector

        # Setup mock source with 4 frames
        mock_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        mock_source.return_value = [(mock_frame, {"frame_count": i}) for i in range(4)]

        # Setup mock publisher and classifier
        mock_publisher.return_value = mock.MagicMock()
        mock_classifier.return_value = mock.MagicMock()

        # select_person is patched to return a 4-tuple; judging by the dummy
        # names this is presumably (mask, mask-space bbox, frame-space bbox, track id).
        dummy_mask = rng.integers(0, 256, (480, 640), dtype=np.uint8)
        dummy_bbox_mask = (100, 100, 200, 300)
        dummy_bbox_frame = (100, 100, 200, 300)
        mock_select_person.return_value = (dummy_mask, dummy_bbox_mask, dummy_bbox_frame, 1)

        # Setup mock mask_to_silhouette to return valid silhouette
        dummy_silhouette = rng.random((64, 44)).astype(np.float32)
        mock_mask_to_sil.return_value = dummy_silhouette

        # Create pipeline with visualize enabled
        pipeline = ScoliosisPipeline(
            source="dummy.mp4",
            checkpoint="dummy.pt",
            config=str(CONFIG_PATH) if CONFIG_PATH.exists() else "dummy.yaml",
            device="cpu",
            yolo_model="dummy.pt",
            window=30,
            stride=30,
            nats_url=None,
            nats_subject="test",
            max_frames=None,
            visualize=True,
        )

        # Replace the visualizer with our mock
        mock_viz = MockVisualizer()
        pipeline._visualizer = mock_viz  # type: ignore[assignment]

        # Run pipeline
        _ = pipeline.run()

        # Verify visualizer was updated for all 4 frames
        assert len(mock_viz.update_calls) == 4, (
            f"Expected visualizer.update() to be called 4 times, "
            f"but was called {len(mock_viz.update_calls)} times."
        )

        # Extract the mask_raw values from each call
        mask_raw_calls = [call["mask_raw"] for call in mock_viz.update_calls]

        # Frames 0 and 1 should have valid masks (not None)
        assert mask_raw_calls[0] is not None, "Frame 0 should have valid mask"
        assert mask_raw_calls[1] is not None, "Frame 1 should have valid mask"

        # Frames 2 and 3 should reuse the cached mask from frame 1 (not None)
        assert mask_raw_calls[2] is not None, (
            "Frame 2 (no detection) should display cached mask from last valid detection, "
            "not None/blank"
        )
        assert mask_raw_calls[3] is not None, (
            "Frame 3 (no detection) should display cached mask from last valid detection, "
            "not None/blank"
        )

        # The cached masks should be copies (different objects) to prevent mutation
        # issues. Both are already asserted non-None above, so no guard is needed.
        assert mask_raw_calls[1] is not mask_raw_calls[2], (
            "Cached mask should be a copy, not the same object reference"
        )