fix(demo): stabilize visualizer bbox and mask rendering
Align bbox coordinate handling across primary and fallback paths, normalize Both-mode raw mask rendering, and tighten demo result typing to reduce runtime/display inconsistencies.
This commit is contained in:
+223
-4
@@ -8,7 +8,10 @@ import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import Final, cast
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -107,6 +110,7 @@ def _assert_prediction_schema(prediction: dict[str, object]) -> None:
|
||||
|
||||
assert isinstance(prediction["timestamp_ns"], int)
|
||||
|
||||
|
||||
def test_pipeline_cli_fps_benchmark_smoke(
|
||||
compatible_checkpoint_path: Path,
|
||||
) -> None:
|
||||
@@ -280,7 +284,6 @@ def test_pipeline_cli_invalid_checkpoint_path_returns_user_error() -> None:
|
||||
assert "Error: Checkpoint not found" in result.stderr
|
||||
|
||||
|
||||
|
||||
def test_pipeline_cli_preprocess_only_requires_export_path(
|
||||
compatible_checkpoint_path: Path,
|
||||
) -> None:
|
||||
@@ -511,7 +514,9 @@ def test_pipeline_cli_silhouette_and_result_export(
|
||||
)
|
||||
|
||||
# Verify both export files exist
|
||||
assert silhouette_export.is_file(), f"Silhouette export not found: {silhouette_export}"
|
||||
assert silhouette_export.is_file(), (
|
||||
f"Silhouette export not found: {silhouette_export}"
|
||||
)
|
||||
assert result_export.is_file(), f"Result export not found: {result_export}"
|
||||
|
||||
# Verify silhouette export
|
||||
@@ -522,7 +527,9 @@ def test_pipeline_cli_silhouette_and_result_export(
|
||||
|
||||
# Verify result export
|
||||
with open(result_export, "r", encoding="utf-8") as f:
|
||||
predictions = [cast(dict[str, object], json.loads(line)) for line in f if line.strip()]
|
||||
predictions = [
|
||||
cast(dict[str, object], json.loads(line)) for line in f if line.strip()
|
||||
]
|
||||
assert len(predictions) > 0
|
||||
|
||||
|
||||
@@ -538,6 +545,7 @@ def test_pipeline_cli_parquet_export_requires_pyarrow(
|
||||
pytest.skip("pyarrow is installed, skipping missing dependency test")
|
||||
try:
|
||||
import pyarrow # noqa: F401
|
||||
|
||||
pytest.skip("pyarrow is installed, skipping missing dependency test")
|
||||
except ImportError:
|
||||
pass
|
||||
@@ -573,7 +581,6 @@ def test_pipeline_cli_parquet_export_requires_pyarrow(
|
||||
assert "parquet" in result.stderr.lower() or "pyarrow" in result.stderr.lower()
|
||||
|
||||
|
||||
|
||||
def test_pipeline_cli_silhouette_visualization(
|
||||
compatible_checkpoint_path: Path,
|
||||
tmp_path: Path,
|
||||
@@ -673,3 +680,215 @@ def test_pipeline_cli_preprocess_only_with_visualization(
|
||||
assert len(silhouettes) == len(png_files), (
|
||||
f"Mismatch: {len(silhouettes)} silhouettes exported but {len(png_files)} PNG files created"
|
||||
)
|
||||
|
||||
|
||||
class MockVisualizer:
|
||||
"""Mock visualizer to track update calls."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.update_calls: list[dict[str, object]] = []
|
||||
self.return_value: bool = True
|
||||
|
||||
def update(
|
||||
self,
|
||||
frame: NDArray[np.uint8],
|
||||
bbox: tuple[int, int, int, int] | None,
|
||||
track_id: int,
|
||||
mask_raw: NDArray[np.uint8] | None,
|
||||
silhouette: NDArray[np.float32] | None,
|
||||
label: str | None,
|
||||
confidence: float | None,
|
||||
fps: float,
|
||||
) -> bool:
|
||||
self.update_calls.append(
|
||||
{
|
||||
"frame": frame,
|
||||
"bbox": bbox,
|
||||
"track_id": track_id,
|
||||
"mask_raw": mask_raw,
|
||||
"silhouette": silhouette,
|
||||
"label": label,
|
||||
"confidence": confidence,
|
||||
"fps": fps,
|
||||
}
|
||||
)
|
||||
return self.return_value
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def test_pipeline_visualizer_updates_on_no_detection() -> None:
    """Check the visualizer is refreshed on frames where nothing is detected.

    Regression test for the window freezing when no person is in view: the
    pipeline must call ``update`` once per frame, passing empty detection
    fields, even though the detector yields no results.
    """
    from opengait.demo.pipeline import ScoliosisPipeline

    # Three identical black frames, each paired with its frame counter.
    blank_frame = np.zeros((480, 640, 3), dtype=np.uint8)
    frames = [(blank_frame, {"frame_count": i}) for i in range(3)]
    config_arg = str(CONFIG_PATH) if CONFIG_PATH.exists() else "dummy.yaml"

    with (
        mock.patch("opengait.demo.pipeline.YOLO") as mock_yolo,
        mock.patch("opengait.demo.pipeline.create_source") as mock_source,
        mock.patch("opengait.demo.pipeline.create_publisher") as mock_publisher,
        mock.patch("opengait.demo.pipeline.ScoNetDemo") as mock_classifier,
    ):
        # Detector whose track() always reports an empty result list.
        detector = mock.MagicMock()
        detector.track.return_value = []
        mock_yolo.return_value = detector

        mock_source.return_value = frames
        mock_publisher.return_value = mock.MagicMock()
        mock_classifier.return_value = mock.MagicMock()

        pipeline = ScoliosisPipeline(
            source="dummy.mp4",
            checkpoint="dummy.pt",
            config=config_arg,
            device="cpu",
            yolo_model="dummy.pt",
            window=30,
            stride=30,
            nats_url=None,
            nats_subject="test",
            max_frames=None,
            visualize=True,
        )

        # Swap in the recording mock so update() calls can be inspected.
        viz = MockVisualizer()
        pipeline._visualizer = viz  # type: ignore[assignment]

        _ = pipeline.run()

    # One update per frame, detections or not.
    assert len(viz.update_calls) == 3, (
        f"Expected visualizer.update() to be called 3 times (once per frame), "
        f"but was called {len(viz.update_calls)} times. "
        f"Window would freeze if not updated on no-detection frames."
    )

    # With no detection, track_id is 0 and every detection field is None.
    for recorded in viz.update_calls:
        assert recorded["track_id"] == 0
        for key in ("bbox", "mask_raw", "silhouette", "label", "confidence"):
            assert recorded[key] is None
|
||||
|
||||
|
||||
|
||||
def test_pipeline_visualizer_uses_cached_detection_on_no_detection() -> None:
    """Test that visualizer reuses last valid detection when current frame has no detection.

    This is a regression test for the cache-reuse behavior when person temporarily
    disappears from frame. The last valid silhouette/mask should be displayed.

    The detector yields a detection for frames 0-1 and nothing for frames 2-3;
    the test then checks that ``mask_raw`` passed to the visualizer is non-None
    on all four frames and that the frame-2 mask is a distinct object from the
    frame-1 mask.
    """
    from opengait.demo.pipeline import ScoliosisPipeline

    # Seeded generator so the fixture arrays are identical on every run; the
    # assertions below do not depend on the values, but a failure is then
    # reproducible with the same data.
    rng = np.random.default_rng(0)

    # Create a minimal pipeline with mocked dependencies
    with (
        mock.patch("opengait.demo.pipeline.YOLO") as mock_yolo,
        mock.patch("opengait.demo.pipeline.create_source") as mock_source,
        mock.patch("opengait.demo.pipeline.create_publisher") as mock_publisher,
        mock.patch("opengait.demo.pipeline.ScoNetDemo") as mock_classifier,
        mock.patch("opengait.demo.pipeline.select_person") as mock_select_person,
        mock.patch("opengait.demo.pipeline.mask_to_silhouette") as mock_mask_to_sil,
    ):
        # Create mock detection result for frames 0-1 (valid detection)
        mock_box = mock.MagicMock()
        mock_box.xyxy = np.array([[100, 100, 200, 300]], dtype=np.float32)
        mock_box.id = np.array([1], dtype=np.int64)
        mock_mask = mock.MagicMock()
        mock_mask.data = rng.random((1, 480, 640), dtype=np.float32)
        mock_result = mock.MagicMock()
        mock_result.boxes = mock_box
        mock_result.masks = mock_mask

        # Setup mock detector: detection for frames 0-1, then no detection for frames 2-3
        mock_detector = mock.MagicMock()
        mock_detector.track.side_effect = [
            [mock_result],  # Frame 0: valid detection
            [mock_result],  # Frame 1: valid detection
            [],  # Frame 2: no detection
            [],  # Frame 3: no detection
        ]
        mock_yolo.return_value = mock_detector

        # Setup mock source with 4 frames
        mock_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        mock_source.return_value = [(mock_frame, {"frame_count": i}) for i in range(4)]

        # Setup mock publisher and classifier
        mock_publisher.return_value = mock.MagicMock()
        mock_classifier.return_value = mock.MagicMock()

        # Setup mock select_person to return valid mask and bbox
        dummy_mask = rng.integers(0, 256, (480, 640), dtype=np.uint8)
        dummy_bbox_mask = (100, 100, 200, 300)
        dummy_bbox_frame = (100, 100, 200, 300)
        mock_select_person.return_value = (dummy_mask, dummy_bbox_mask, dummy_bbox_frame, 1)

        # Setup mock mask_to_silhouette to return valid silhouette
        dummy_silhouette = rng.random((64, 44), dtype=np.float32)
        mock_mask_to_sil.return_value = dummy_silhouette

        # Create pipeline with visualize enabled
        pipeline = ScoliosisPipeline(
            source="dummy.mp4",
            checkpoint="dummy.pt",
            config=str(CONFIG_PATH) if CONFIG_PATH.exists() else "dummy.yaml",
            device="cpu",
            yolo_model="dummy.pt",
            window=30,
            stride=30,
            nats_url=None,
            nats_subject="test",
            max_frames=None,
            visualize=True,
        )

        # Replace the visualizer with our mock
        mock_viz = MockVisualizer()
        pipeline._visualizer = mock_viz  # type: ignore[assignment]

        # Run pipeline
        _ = pipeline.run()

    # Verify visualizer was updated for all 4 frames
    assert len(mock_viz.update_calls) == 4, (
        f"Expected visualizer.update() to be called 4 times, "
        f"but was called {len(mock_viz.update_calls)} times."
    )

    # Extract the mask_raw values from each call
    mask_raw_calls = [call["mask_raw"] for call in mock_viz.update_calls]

    # Frames 0 and 1 should have valid masks (not None)
    assert mask_raw_calls[0] is not None, "Frame 0 should have valid mask"
    assert mask_raw_calls[1] is not None, "Frame 1 should have valid mask"

    # Frames 2 and 3 should reuse the cached mask from frame 1 (not None)
    assert mask_raw_calls[2] is not None, (
        "Frame 2 (no detection) should display cached mask from last valid detection, "
        "not None/blank"
    )
    assert mask_raw_calls[3] is not None, (
        "Frame 3 (no detection) should display cached mask from last valid detection, "
        "not None/blank"
    )

    # The cached masks should be copies (different objects) to prevent mutation issues.
    # Both entries were asserted non-None above, so no extra guard is needed here.
    assert mask_raw_calls[1] is not mask_raw_calls[2], (
        "Cached mask should be a copy, not the same object reference"
    )
|
||||
|
||||
|
||||
Reference in New Issue
Block a user