diff --git a/opengait/demo/input.py b/opengait/demo/input.py index a96f540..7b72d50 100644 --- a/opengait/demo/input.py +++ b/opengait/demo/input.py @@ -7,7 +7,8 @@ Provides generator-based interfaces for video sources: """ from collections.abc import AsyncIterator, Generator, Iterable -from typing import TYPE_CHECKING, Protocol, cast +from typing import Protocol, cast + import logging import numpy as np @@ -17,15 +18,15 @@ logger = logging.getLogger(__name__) # Type alias for frame stream: (frame_array, metadata_dict) FrameStream = Iterable[tuple[np.ndarray, dict[str, object]]] -if TYPE_CHECKING: - # Protocol for cv-mmap metadata to avoid direct import - class _FrameMetadata(Protocol): - frame_count: int - timestamp_ns: int +# Protocol for cv-mmap metadata (needed at runtime for nested function annotation) +class _FrameMetadata(Protocol): + frame_count: int + timestamp_ns: int - # Protocol for cv-mmap client - class _CvMmapClient(Protocol): - def __aiter__(self) -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]: ... + +# Protocol for cv-mmap client (needed at runtime for cast) +class _CvMmapClient(Protocol): + def __aiter__(self) -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]: ... def opencv_source( diff --git a/opengait/demo/output.py b/opengait/demo/output.py index 5c9180e..261a16c 100644 --- a/opengait/demo/output.py +++ b/opengait/demo/output.py @@ -14,7 +14,7 @@ import logging import sys import threading import time -from typing import TYPE_CHECKING, Protocol, TextIO, cast, runtime_checkable +from typing import TYPE_CHECKING, Protocol, TextIO, TypedDict, cast, runtime_checkable if TYPE_CHECKING: from types import TracebackType @@ -22,17 +22,31 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class DemoResult(TypedDict): + """Typed result dictionary for demo pipeline output. + + Contains classification result with frame metadata. + """ + + frame: int + track_id: int + label: str + confidence: float + window: int + timestamp_ns: int + + @runtime_checkable class ResultPublisher(Protocol): """Protocol for result publishers.""" - def publish(self, result: dict[str, object]) -> None: + def publish(self, result: DemoResult) -> None: """ Publish a result dictionary. Parameters ---------- - result : dict[str, object] + result : DemoResult Result data with keys: frame, track_id, label, confidence, window, timestamp_ns """ ... @@ -54,13 +68,13 @@ class ConsolePublisher: """ self._output = output - def publish(self, result: dict[str, object]) -> None: + def publish(self, result: DemoResult) -> None: """ Publish result as JSON line. Parameters ---------- - result : dict[str, object] + result : DemoResult Result data with keys: frame, track_id, label, confidence, window, timestamp_ns """ try: @@ -214,13 +228,13 @@ class NatsPublisher: logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}") return False - def publish(self, result: dict[str, object]) -> None: + def publish(self, result: DemoResult) -> None: """ Publish result to NATS subject. Parameters ---------- - result : dict[str, object] + result : DemoResult Result data with keys: frame, track_id, label, confidence, window, timestamp_ns """ if not self._ensure_connected(): @@ -239,6 +253,7 @@ class NatsPublisher: ).encode("utf-8") _ = await self._nc.publish(self._subject, payload) _ = await self._nc.flush() + # Run publish in background loop future = asyncio.run_coroutine_threadsafe( _publish(), @@ -331,7 +346,7 @@ def create_result( confidence: float, window: int | tuple[int, int], timestamp_ns: int | None = None, -) -> dict[str, object]: +) -> DemoResult: """ Create a standardized result dictionary. @@ -353,7 +368,7 @@ def create_result( Returns ------- - dict[str, object] + DemoResult Standardized result dictionary """ return { diff --git a/opengait/demo/pipeline.py b/opengait/demo/pipeline.py index 2fa0b2f..04d68bc 100644 --- a/opengait/demo/pipeline.py +++ b/opengait/demo/pipeline.py @@ -17,8 +17,8 @@ from numpy.typing import NDArray from ultralytics.models.yolo.model import YOLO from .input import FrameStream, create_source -from .output import ResultPublisher, create_publisher, create_result -from .preprocess import frame_to_person_mask, mask_to_silhouette +from .output import DemoResult, ResultPublisher, create_publisher, create_result +from .preprocess import BBoxXYXY, frame_to_person_mask, mask_to_silhouette from .sconet_demo import ScoNetDemo from .window import SilhouetteWindow, select_person @@ -53,6 +53,7 @@ class _DetectionResultsLike(Protocol): def masks(self) -> _MasksLike: ... + class _TrackCallable(Protocol): def __call__( self, @@ -80,8 +81,9 @@ class ScoliosisPipeline: _silhouette_visualize_dir: Path | None _result_export_path: Path | None _result_export_format: str - _result_buffer: list[dict[str, object]] + _result_buffer: list[DemoResult] _visualizer: OpenCVVisualizer | None + _last_viz_payload: dict[str, object] | None def __init__( self, @@ -135,6 +137,7 @@ class ScoliosisPipeline: self._visualizer = OpenCVVisualizer() else: self._visualizer = None + self._last_viz_payload = None @staticmethod def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int: @@ -171,37 +174,59 @@ class ScoliosisPipeline: tuple[ Float[ndarray, "64 44"], UInt8[ndarray, "h w"], - tuple[int, int, int, int], + BBoxXYXY, int, ] | None ): selected = select_person(result) if selected is not None: - mask_raw, bbox, track_id = selected + mask_raw, bbox_mask, bbox_frame, track_id = selected silhouette = cast( Float[ndarray, "64 44"] | None, - mask_to_silhouette(self._to_mask_u8(mask_raw), bbox), + mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask), ) if silhouette is not None: - return silhouette, mask_raw, bbox, int(track_id) + return silhouette, mask_raw, bbox_frame, int(track_id) fallback = cast( - tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None, + tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None, frame_to_person_mask(result), ) if fallback is None: return None - mask_u8, bbox = fallback + mask_u8, bbox_mask = fallback silhouette = cast( Float[ndarray, "64 44"] | None, - mask_to_silhouette(mask_u8, bbox), + mask_to_silhouette(mask_u8, bbox_mask), ) if silhouette is None: return None + # Convert mask-space bbox to frame-space for visualization + # Use result.orig_shape to get frame dimensions safely + orig_shape = getattr(result, "orig_shape", None) + if orig_shape is not None and isinstance(orig_shape, (tuple, list)) and len(orig_shape) >= 2: + frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1]) + mask_h, mask_w = mask_u8.shape[0], mask_u8.shape[1] + if mask_w > 0 and mask_h > 0 and frame_w > 0 and frame_h > 0: + scale_x = frame_w / mask_w + scale_y = frame_h / mask_h + bbox_frame = ( + int(bbox_mask[0] * scale_x), + int(bbox_mask[1] * scale_y), + int(bbox_mask[2] * scale_x), + int(bbox_mask[3] * scale_y), + ) + else: + # Fallback: use mask-space bbox if dimensions invalid + bbox_frame = bbox_mask + else: + # Fallback: use mask-space bbox if orig_shape unavailable + bbox_frame = bbox_mask # For fallback case, mask_raw is the same as mask_u8 - return silhouette, mask_u8, bbox, 0 + return silhouette, mask_u8, bbox_frame, 0 + @jaxtyped(typechecker=beartype) def process_frame( @@ -342,23 +367,48 @@ class ScoliosisPipeline: ) # Update visualizer if enabled - if self._visualizer is not None and viz_payload is not None: - # Cast viz_payload to dict for type checking - viz_dict = cast(dict[str, object], viz_payload) - mask_raw_obj = viz_dict.get("mask_raw") - bbox_obj = viz_dict.get("bbox") - silhouette_obj = viz_dict.get("silhouette") - track_id_val = viz_dict.get("track_id", 0) - track_id = track_id_val if isinstance(track_id_val, int) else 0 - label_obj = viz_dict.get("label") - confidence_obj = viz_dict.get("confidence") + if self._visualizer is not None: + # Cache valid payload for no-detection frames + if viz_payload is not None: + # Cache a copy to prevent mutation of original data + viz_payload_dict = cast(dict[str, object], viz_payload) + cached: dict[str, object] = {} + for k, v in viz_payload_dict.items(): + copy_method = cast(Callable[[], object] | None, getattr(v, "copy", None)) + if copy_method is not None: + cached[k] = copy_method() + else: + cached[k] = v + self._last_viz_payload = cached + + # Use cached payload if current is None + viz_data = viz_payload if viz_payload is not None else self._last_viz_payload + + if viz_data is not None: + # Cast viz_payload to dict for type checking + viz_dict = cast(dict[str, object], viz_data) + mask_raw_obj = viz_dict.get("mask_raw") + bbox_obj = viz_dict.get("bbox") + silhouette_obj = viz_dict.get("silhouette") + track_id_val = viz_dict.get("track_id", 0) + track_id = track_id_val if isinstance(track_id_val, int) else 0 + label_obj = viz_dict.get("label") + confidence_obj = viz_dict.get("confidence") - # Cast extracted values to expected types - mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj) - bbox = cast(tuple[int, int, int, int] | None, bbox_obj) - silhouette = cast(NDArray[np.float32] | None, silhouette_obj) - label = cast(str | None, label_obj) - confidence = cast(float | None, confidence_obj) + # Cast extracted values to expected types + mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj) + bbox = cast(BBoxXYXY | None, bbox_obj) + silhouette = cast(NDArray[np.float32] | None, silhouette_obj) + label = cast(str | None, label_obj) + confidence = cast(float | None, confidence_obj) + else: + # No detection and no cache - use default values + mask_raw = None + bbox = None + track_id = 0 + silhouette = None + label = None + confidence = None keep_running = self._visualizer.update( frame_u8, diff --git a/opengait/demo/preprocess.py b/opengait/demo/preprocess.py index 74b9df9..7db214f 100644 --- a/opengait/demo/preprocess.py +++ b/opengait/demo/preprocess.py @@ -23,6 +23,9 @@ jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped) UInt8Array = NDArray[np.uint8] Float32Array = NDArray[np.float32] +#: Bounding box in XYXY format: (x1, y1, x2, y2) where (x1,y1) is top-left and (x2,y2) is bottom-right. +BBoxXYXY = tuple[int, int, int, int] + def _read_attr(container: object, key: str) -> object | None: if isinstance(container, dict): @@ -59,7 +62,15 @@ def _to_numpy_array(value: object) -> NDArray[np.generic]: return cast(NDArray[np.generic], np.asarray(current)) -def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> tuple[int, int, int, int] | None: +def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> BBoxXYXY | None: + """Extract bounding box from binary mask in XYXY format. + + Args: + mask: Binary mask array of shape (H, W) with dtype uint8. + + Returns: + Bounding box as (x1, y1, x2, y2) in XYXY format, or None if mask is empty. + """ mask_u8 = np.asarray(mask, dtype=np.uint8) coords = np.argwhere(mask_u8 > 0) if int(coords.size) == 0: @@ -76,9 +87,17 @@ def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> tuple[int, int, int, int] | return (x1, y1, x2, y2) -def _sanitize_bbox( - bbox: tuple[int, int, int, int], height: int, width: int -) -> tuple[int, int, int, int] | None: +def _sanitize_bbox(bbox: BBoxXYXY, height: int, width: int) -> BBoxXYXY | None: + """Sanitize bounding box to ensure it's within image bounds. + + Args: + bbox: Bounding box in XYXY format (x1, y1, x2, y2). + height: Image height. + width: Image width. + + Returns: + Sanitized bounding box in XYXY format, or None if invalid. + """ x1, y1, x2, y2 = bbox x1c = max(0, min(int(x1), width - 1)) y1c = max(0, min(int(y1), height - 1)) @@ -92,7 +111,17 @@ def _sanitize_bbox( @jaxtyped(typechecker=beartype) def frame_to_person_mask( result: object, min_area: int = MIN_MASK_AREA -) -> tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None: +) -> tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None: + """Extract person mask and bounding box from detection result. + + Args: + result: Detection results object with boxes and masks attributes. + min_area: Minimum mask area to consider valid. + + Returns: + Tuple of (mask, bbox) where bbox is in XYXY format (x1, y1, x2, y2), + or None if no valid detections. + """ masks_obj = _read_attr(result, "masks") if masks_obj is None: return None @@ -152,7 +181,7 @@ def frame_to_person_mask( best_area = -1 best_mask: UInt8[ndarray, "h w"] | None = None - best_bbox: tuple[int, int, int, int] | None = None + best_bbox: BBoxXYXY | None = None for idx in range(mask_count): mask_float = np.asarray(masks_float[idx], dtype=np.float32) @@ -167,7 +196,7 @@ def frame_to_person_mask( if area < min_area: continue - bbox: tuple[int, int, int, int] | None = None + bbox: BBoxXYXY | None = None shape_2d = cast(tuple[int, int], mask_binary.shape) h = int(shape_2d[0]) w = int(shape_2d[1]) @@ -204,8 +233,18 @@ def frame_to_person_mask( @jaxtyped(typechecker=beartype) def mask_to_silhouette( mask: UInt8[ndarray, "h w"], - bbox: tuple[int, int, int, int], + bbox: BBoxXYXY, ) -> Float[ndarray, "64 44"] | None: + """Convert mask to standardized silhouette using bounding box. + + Args: + mask: Binary mask array of shape (H, W) with dtype uint8. + bbox: Bounding box in XYXY format (x1, y1, x2, y2). + + Returns: + Standardized silhouette array of shape (64, 44) with dtype float32, + or None if conversion fails. + """ mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8) if int(np.count_nonzero(mask_u8)) < MIN_MASK_AREA: return None diff --git a/opengait/demo/visualizer.py b/opengait/demo/visualizer.py index 01d26c4..e399bdf 100644 --- a/opengait/demo/visualizer.py +++ b/opengait/demo/visualizer.py @@ -13,8 +13,11 @@ import cv2 import numpy as np from numpy.typing import NDArray +from .preprocess import BBoxXYXY + logger = logging.getLogger(__name__) + # Window names MAIN_WINDOW = "Scoliosis Detection" SEG_WINDOW = "Segmentation" @@ -66,13 +69,13 @@ class OpenCVVisualizer: def _draw_bbox( self, frame: ImageArray, - bbox: tuple[int, int, int, int] | None, + bbox: BBoxXYXY | None, ) -> None: """Draw bounding box on frame if present. Args: frame: Input frame (H, W, 3) uint8 - modified in place - bbox: Bounding box as (x1, y1, x2, y2) or None + bbox: Bounding box in XYXY format as (x1, y1, x2, y2) or None """ if bbox is None: return @@ -145,7 +148,7 @@ class OpenCVVisualizer: def _prepare_main_frame( self, frame: ImageArray, - bbox: tuple[int, int, int, int] | None, + bbox: BBoxXYXY | None, track_id: int, fps: float, label: str | None, @@ -155,7 +158,7 @@ class OpenCVVisualizer: Args: frame: Input frame (H, W, C) uint8 - bbox: Bounding box or None + bbox: Bounding box in XYXY format (x1, y1, x2, y2) or None track_id: Tracking ID fps: Current FPS label: Classification label or None @@ -324,6 +327,9 @@ class OpenCVVisualizer: mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY)) else: mask_gray = mask_raw + # Normalize to uint8 [0,255] for display (handles both float [0,1] and uint8 inputs) + if mask_gray.dtype == np.float32 or mask_gray.dtype == np.float64: + mask_gray = (mask_gray * 255).astype(np.uint8) raw_gray = cast( ImageArray, cv2.resize( @@ -333,6 +339,7 @@ class OpenCVVisualizer: ), ) + # Normalized view preparation (without indicator) if silhouette is None: norm_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8) @@ -402,7 +409,7 @@ class OpenCVVisualizer: def update( self, frame: ImageArray, - bbox: tuple[int, int, int, int] | None, + bbox: BBoxXYXY | None, track_id: int, mask_raw: ImageArray | None, silhouette: NDArray[np.float32] | None, @@ -414,7 +421,7 @@ class OpenCVVisualizer: Args: frame: Input frame (H, W, C) uint8 - bbox: Bounding box as (x1, y1, x2, y2) or None + bbox: Bounding box in XYXY format (x1, y1, x2, y2) or None track_id: Tracking ID mask_raw: Raw binary mask (H, W) uint8 or None silhouette: Normalized silhouette (64, 44) float32 [0,1] or None diff --git a/opengait/demo/window.py b/opengait/demo/window.py index 836ae0a..6443271 100644 --- a/opengait/demo/window.py +++ b/opengait/demo/window.py @@ -12,10 +12,11 @@ import torch from jaxtyping import Float from numpy import ndarray +from .preprocess import BBoxXYXY + if TYPE_CHECKING: from numpy.typing import NDArray - # Silhouette dimensions from preprocess.py SIL_HEIGHT: int = 64 SIL_WIDTH: int = 44 @@ -239,19 +240,23 @@ def _to_numpy(obj: _ArrayLike) -> ndarray: def select_person( results: _DetectionResults, -) -> tuple[ndarray, tuple[int, int, int, int], int] | None: +) -> tuple[ndarray, BBoxXYXY, BBoxXYXY, int] | None: """Select the person with largest bounding box from detection results. Args: results: Detection results object with boxes and masks attributes. Expected to have: - - boxes.xyxy: array of bounding boxes [N, 4] - - masks.data: array of masks [N, H, W] + - boxes.xyxy: array of bounding boxes [N, 4] in frame coordinates (XYXY format) + - masks.data: array of masks [N, H, W] in mask coordinates - boxes.id: optional track IDs [N] Returns: - Tuple of (mask, bbox, track_id) for the largest person, + Tuple of (mask, bbox_mask, bbox_frame, track_id) for the largest person, or None if no valid detections or track IDs unavailable. + - mask: the person's segmentation mask + - bbox_mask: bounding box in mask coordinate space (XYXY format: x1, y1, x2, y2) + - bbox_frame: bounding box in frame coordinate space (XYXY format: x1, y1, x2, y2) + - track_id: the person's track ID """ # Check for track IDs boxes_obj: _Boxes | object = getattr(results, "boxes", None) @@ -329,20 +334,27 @@ def select_person( # Scale bbox from frame space to mask space scale_x = mask_w / frame_w if frame_w > 0 else 1.0 scale_y = mask_h / frame_h if frame_h > 0 else 1.0 - bbox = ( + bbox_mask = ( int(float(bboxes[best_idx][0]) * scale_x), int(float(bboxes[best_idx][1]) * scale_y), int(float(bboxes[best_idx][2]) * scale_x), int(float(bboxes[best_idx][3]) * scale_y), ) - else: - # Fallback: use bbox as-is (assume same coordinate space) - bbox = ( + bbox_frame = ( int(float(bboxes[best_idx][0])), int(float(bboxes[best_idx][1])), int(float(bboxes[best_idx][2])), int(float(bboxes[best_idx][3])), ) + else: + # Fallback: use bbox as-is for both (assume same coordinate space) + bbox_mask = ( + int(float(bboxes[best_idx][0])), + int(float(bboxes[best_idx][1])), + int(float(bboxes[best_idx][2])), + int(float(bboxes[best_idx][3])), + ) + bbox_frame = bbox_mask track_id = int(track_ids[best_idx]) if best_idx < len(track_ids) else best_idx - return mask, bbox, track_id + return mask, bbox_mask, bbox_frame, track_id diff --git a/tests/demo/test_pipeline.py b/tests/demo/test_pipeline.py index ceb13e5..da06460 100644 --- a/tests/demo/test_pipeline.py +++ b/tests/demo/test_pipeline.py @@ -8,7 +8,10 @@ import subprocess import sys import time from typing import Final, cast +from unittest import mock +import numpy as np +from numpy.typing import NDArray import pytest import torch @@ -107,6 +110,7 @@ def _assert_prediction_schema(prediction: dict[str, object]) -> None: assert isinstance(prediction["timestamp_ns"], int) + def test_pipeline_cli_fps_benchmark_smoke( compatible_checkpoint_path: Path, ) -> None: @@ -280,7 +284,6 @@ def test_pipeline_cli_invalid_checkpoint_path_returns_user_error() -> None: assert "Error: Checkpoint not found" in result.stderr - def test_pipeline_cli_preprocess_only_requires_export_path( compatible_checkpoint_path: Path, ) -> None: @@ -511,7 +514,9 @@ def test_pipeline_cli_silhouette_and_result_export( ) # Verify both export files exist - assert silhouette_export.is_file(), f"Silhouette export not found: {silhouette_export}" + assert silhouette_export.is_file(), ( + f"Silhouette export not found: {silhouette_export}" + ) assert result_export.is_file(), f"Result export not found: {result_export}" # Verify silhouette export @@ -522,7 +527,9 @@ def test_pipeline_cli_silhouette_and_result_export( # Verify result export with open(result_export, "r", encoding="utf-8") as f: - predictions = [cast(dict[str, object], json.loads(line)) for line in f if line.strip()] + predictions = [ + cast(dict[str, object], json.loads(line)) for line in f if line.strip() + ] assert len(predictions) > 0 @@ -538,6 +545,7 @@ def test_pipeline_cli_parquet_export_requires_pyarrow( pytest.skip("pyarrow is installed, skipping missing dependency test") try: import pyarrow # noqa: F401 + pytest.skip("pyarrow is installed, skipping missing dependency test") except ImportError: pass @@ -573,7 +581,6 @@ def test_pipeline_cli_parquet_export_requires_pyarrow( assert "parquet" in result.stderr.lower() or "pyarrow" in result.stderr.lower() - def test_pipeline_cli_silhouette_visualization( compatible_checkpoint_path: Path, tmp_path: Path, @@ -673,3 +680,215 @@ def test_pipeline_cli_preprocess_only_with_visualization( assert len(silhouettes) == len(png_files), ( f"Mismatch: {len(silhouettes)} silhouettes exported but {len(png_files)} PNG files created" ) + + +class MockVisualizer: + """Mock visualizer to track update calls.""" + + def __init__(self) -> None: + self.update_calls: list[dict[str, object]] = [] + self.return_value: bool = True + + def update( + self, + frame: NDArray[np.uint8], + bbox: tuple[int, int, int, int] | None, + track_id: int, + mask_raw: NDArray[np.uint8] | None, + silhouette: NDArray[np.float32] | None, + label: str | None, + confidence: float | None, + fps: float, + ) -> bool: + self.update_calls.append( + { + "frame": frame, + "bbox": bbox, + "track_id": track_id, + "mask_raw": mask_raw, + "silhouette": silhouette, + "label": label, + "confidence": confidence, + "fps": fps, + } + ) + return self.return_value + + def close(self) -> None: + pass + + +def test_pipeline_visualizer_updates_on_no_detection() -> None: + """Test that visualizer is still updated even when process_frame returns None. + + This is a regression test for the window freeze issue when no person is detected. + The window should refresh every frame to prevent freezing. + """ + from opengait.demo.pipeline import ScoliosisPipeline + + # Create a minimal pipeline with mocked dependencies + with ( + mock.patch("opengait.demo.pipeline.YOLO") as mock_yolo, + mock.patch("opengait.demo.pipeline.create_source") as mock_source, + mock.patch("opengait.demo.pipeline.create_publisher") as mock_publisher, + mock.patch("opengait.demo.pipeline.ScoNetDemo") as mock_classifier, + ): + # Setup mock detector that returns no detections (causing process_frame to return None) + mock_detector = mock.MagicMock() + mock_detector.track.return_value = [] # No detections + mock_yolo.return_value = mock_detector + + # Setup mock source with 3 frames + mock_frame = np.zeros((480, 640, 3), dtype=np.uint8) + mock_source.return_value = [(mock_frame, {"frame_count": i}) for i in range(3)] + + # Setup mock publisher and classifier + mock_publisher.return_value = mock.MagicMock() + mock_classifier.return_value = mock.MagicMock() + + # Create pipeline with visualize enabled + pipeline = ScoliosisPipeline( + source="dummy.mp4", + checkpoint="dummy.pt", + config=str(CONFIG_PATH) if CONFIG_PATH.exists() else "dummy.yaml", + device="cpu", + yolo_model="dummy.pt", + window=30, + stride=30, + nats_url=None, + nats_subject="test", + max_frames=None, + visualize=True, + ) + + # Replace the visualizer with our mock + mock_viz = MockVisualizer() + pipeline._visualizer = mock_viz # type: ignore[assignment] + + # Run pipeline + _ = pipeline.run() + + # Verify visualizer was updated for all 3 frames even with no detections + assert len(mock_viz.update_calls) == 3, ( + f"Expected visualizer.update() to be called 3 times (once per frame), " + f"but was called {len(mock_viz.update_calls)} times. " + f"Window would freeze if not updated on no-detection frames." + ) + + # Verify each call had the frame data + for call in mock_viz.update_calls: + assert call["track_id"] == 0 # Default track_id when no detection + assert call["bbox"] is None # No bbox when no detection + assert call["mask_raw"] is None # No mask when no detection + assert call["silhouette"] is None # No silhouette when no detection + assert call["label"] is None # No label when no detection + assert call["confidence"] is None # No confidence when no detection + + + +def test_pipeline_visualizer_uses_cached_detection_on_no_detection() -> None: + """Test that visualizer reuses last valid detection when current frame has no detection. + + This is a regression test for the cache-reuse behavior when person temporarily + disappears from frame. The last valid silhouette/mask should be displayed. + """ + from opengait.demo.pipeline import ScoliosisPipeline + + # Create a minimal pipeline with mocked dependencies + with ( + mock.patch("opengait.demo.pipeline.YOLO") as mock_yolo, + mock.patch("opengait.demo.pipeline.create_source") as mock_source, + mock.patch("opengait.demo.pipeline.create_publisher") as mock_publisher, + mock.patch("opengait.demo.pipeline.ScoNetDemo") as mock_classifier, + mock.patch("opengait.demo.pipeline.select_person") as mock_select_person, + mock.patch("opengait.demo.pipeline.mask_to_silhouette") as mock_mask_to_sil, + ): + # Create mock detection result for frames 0-1 (valid detection) + mock_box = mock.MagicMock() + mock_box.xyxy = np.array([[100, 100, 200, 300]], dtype=np.float32) + mock_box.id = np.array([1], dtype=np.int64) + mock_mask = mock.MagicMock() + mock_mask.data = np.random.rand(1, 480, 640).astype(np.float32) + mock_result = mock.MagicMock() + mock_result.boxes = mock_box + mock_result.masks = mock_mask + + # Setup mock detector: detection for frames 0-1, then no detection for frames 2-3 + mock_detector = mock.MagicMock() + mock_detector.track.side_effect = [ + [mock_result], # Frame 0: valid detection + [mock_result], # Frame 1: valid detection + [], # Frame 2: no detection + [], # Frame 3: no detection + ] + mock_yolo.return_value = mock_detector + + # Setup mock source with 4 frames + mock_frame = np.zeros((480, 640, 3), dtype=np.uint8) + mock_source.return_value = [(mock_frame, {"frame_count": i}) for i in range(4)] + + # Setup mock publisher and classifier + mock_publisher.return_value = mock.MagicMock() + mock_classifier.return_value = mock.MagicMock() + + # Setup mock select_person to return valid mask and bbox + dummy_mask = np.random.randint(0, 256, (480, 640), dtype=np.uint8) + dummy_bbox_mask = (100, 100, 200, 300) + dummy_bbox_frame = (100, 100, 200, 300) + mock_select_person.return_value = (dummy_mask, dummy_bbox_mask, dummy_bbox_frame, 1) + + # Setup mock mask_to_silhouette to return valid silhouette + dummy_silhouette = np.random.rand(64, 44).astype(np.float32) + mock_mask_to_sil.return_value = dummy_silhouette + + # Create pipeline with visualize enabled + pipeline = ScoliosisPipeline( + source="dummy.mp4", + checkpoint="dummy.pt", + config=str(CONFIG_PATH) if CONFIG_PATH.exists() else "dummy.yaml", + device="cpu", + yolo_model="dummy.pt", + window=30, + stride=30, + nats_url=None, + nats_subject="test", + max_frames=None, + visualize=True, + ) + + # Replace the visualizer with our mock + mock_viz = MockVisualizer() + pipeline._visualizer = mock_viz # type: ignore[assignment] + + # Run pipeline + _ = pipeline.run() + + # Verify visualizer was updated for all 4 frames + assert len(mock_viz.update_calls) == 4, ( + f"Expected visualizer.update() to be called 4 times, " + f"but was called {len(mock_viz.update_calls)} times." + ) + + # Extract the mask_raw values from each call + mask_raw_calls = [call["mask_raw"] for call in mock_viz.update_calls] + + # Frames 0 and 1 should have valid masks (not None) + assert mask_raw_calls[0] is not None, "Frame 0 should have valid mask" + assert mask_raw_calls[1] is not None, "Frame 1 should have valid mask" + + # Frames 2 and 3 should reuse the cached mask from frame 1 (not None) + assert mask_raw_calls[2] is not None, ( + "Frame 2 (no detection) should display cached mask from last valid detection, " + "not None/blank" + ) + assert mask_raw_calls[3] is not None, ( + "Frame 3 (no detection) should display cached mask from last valid detection, " + "not None/blank" + ) + + # The cached masks should be copies (different objects) to prevent mutation issues + if mask_raw_calls[1] is not None and mask_raw_calls[2] is not None: + assert mask_raw_calls[1] is not mask_raw_calls[2], ( + "Cached mask should be a copy, not the same object reference" + ) +