fix(demo): stabilize visualizer bbox and mask rendering

Align bbox coordinate handling across primary and fallback paths, normalize Both-mode raw mask rendering, and tighten demo result typing to reduce runtime/display inconsistencies.
This commit is contained in:
2026-02-28 18:05:33 +08:00
parent 06a6cd1ccf
commit 7f073179d7
7 changed files with 416 additions and 73 deletions
+10 -9
View File
@@ -7,7 +7,8 @@ Provides generator-based interfaces for video sources:
""" """
from collections.abc import AsyncIterator, Generator, Iterable from collections.abc import AsyncIterator, Generator, Iterable
from typing import TYPE_CHECKING, Protocol, cast from typing import Protocol, cast
import logging import logging
import numpy as np import numpy as np
@@ -17,15 +18,15 @@ logger = logging.getLogger(__name__)
# Type alias for frame stream: (frame_array, metadata_dict) # Type alias for frame stream: (frame_array, metadata_dict)
FrameStream = Iterable[tuple[np.ndarray, dict[str, object]]] FrameStream = Iterable[tuple[np.ndarray, dict[str, object]]]
if TYPE_CHECKING: # Protocol for cv-mmap metadata (needed at runtime for nested function annotation)
# Protocol for cv-mmap metadata to avoid direct import class _FrameMetadata(Protocol):
class _FrameMetadata(Protocol): frame_count: int
frame_count: int timestamp_ns: int
timestamp_ns: int
# Protocol for cv-mmap client
class _CvMmapClient(Protocol): # Protocol for cv-mmap client (needed at runtime for cast)
def __aiter__(self) -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]: ... class _CvMmapClient(Protocol):
def __aiter__(self) -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]: ...
def opencv_source( def opencv_source(
+24 -9
View File
@@ -14,7 +14,7 @@ import logging
import sys import sys
import threading import threading
import time import time
from typing import TYPE_CHECKING, Protocol, TextIO, cast, runtime_checkable from typing import TYPE_CHECKING, Protocol, TextIO, TypedDict, cast, runtime_checkable
if TYPE_CHECKING: if TYPE_CHECKING:
from types import TracebackType from types import TracebackType
@@ -22,17 +22,31 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DemoResult(TypedDict):
"""Typed result dictionary for demo pipeline output.
Contains classification result with frame metadata.
"""
frame: int
track_id: int
label: str
confidence: float
window: int
timestamp_ns: int
@runtime_checkable @runtime_checkable
class ResultPublisher(Protocol): class ResultPublisher(Protocol):
"""Protocol for result publishers.""" """Protocol for result publishers."""
def publish(self, result: dict[str, object]) -> None: def publish(self, result: DemoResult) -> None:
""" """
Publish a result dictionary. Publish a result dictionary.
Parameters Parameters
---------- ----------
result : dict[str, object] result : DemoResult
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
""" """
... ...
@@ -54,13 +68,13 @@ class ConsolePublisher:
""" """
self._output = output self._output = output
def publish(self, result: dict[str, object]) -> None: def publish(self, result: DemoResult) -> None:
""" """
Publish result as JSON line. Publish result as JSON line.
Parameters Parameters
---------- ----------
result : dict[str, object] result : DemoResult
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
""" """
try: try:
@@ -214,13 +228,13 @@ class NatsPublisher:
logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}") logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}")
return False return False
def publish(self, result: dict[str, object]) -> None: def publish(self, result: DemoResult) -> None:
""" """
Publish result to NATS subject. Publish result to NATS subject.
Parameters Parameters
---------- ----------
result : dict[str, object] result : DemoResult
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
""" """
if not self._ensure_connected(): if not self._ensure_connected():
@@ -239,6 +253,7 @@ class NatsPublisher:
).encode("utf-8") ).encode("utf-8")
_ = await self._nc.publish(self._subject, payload) _ = await self._nc.publish(self._subject, payload)
_ = await self._nc.flush() _ = await self._nc.flush()
# Run publish in background loop # Run publish in background loop
future = asyncio.run_coroutine_threadsafe( future = asyncio.run_coroutine_threadsafe(
_publish(), _publish(),
@@ -331,7 +346,7 @@ def create_result(
confidence: float, confidence: float,
window: int | tuple[int, int], window: int | tuple[int, int],
timestamp_ns: int | None = None, timestamp_ns: int | None = None,
) -> dict[str, object]: ) -> DemoResult:
""" """
Create a standardized result dictionary. Create a standardized result dictionary.
@@ -353,7 +368,7 @@ def create_result(
Returns Returns
------- -------
dict[str, object] DemoResult
Standardized result dictionary Standardized result dictionary
""" """
return { return {
+77 -27
View File
@@ -17,8 +17,8 @@ from numpy.typing import NDArray
from ultralytics.models.yolo.model import YOLO from ultralytics.models.yolo.model import YOLO
from .input import FrameStream, create_source from .input import FrameStream, create_source
from .output import ResultPublisher, create_publisher, create_result from .output import DemoResult, ResultPublisher, create_publisher, create_result
from .preprocess import frame_to_person_mask, mask_to_silhouette from .preprocess import BBoxXYXY, frame_to_person_mask, mask_to_silhouette
from .sconet_demo import ScoNetDemo from .sconet_demo import ScoNetDemo
from .window import SilhouetteWindow, select_person from .window import SilhouetteWindow, select_person
@@ -53,6 +53,7 @@ class _DetectionResultsLike(Protocol):
def masks(self) -> _MasksLike: ... def masks(self) -> _MasksLike: ...
class _TrackCallable(Protocol): class _TrackCallable(Protocol):
def __call__( def __call__(
self, self,
@@ -80,8 +81,9 @@ class ScoliosisPipeline:
_silhouette_visualize_dir: Path | None _silhouette_visualize_dir: Path | None
_result_export_path: Path | None _result_export_path: Path | None
_result_export_format: str _result_export_format: str
_result_buffer: list[dict[str, object]] _result_buffer: list[DemoResult]
_visualizer: OpenCVVisualizer | None _visualizer: OpenCVVisualizer | None
_last_viz_payload: dict[str, object] | None
def __init__( def __init__(
self, self,
@@ -135,6 +137,7 @@ class ScoliosisPipeline:
self._visualizer = OpenCVVisualizer() self._visualizer = OpenCVVisualizer()
else: else:
self._visualizer = None self._visualizer = None
self._last_viz_payload = None
@staticmethod @staticmethod
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int: def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
@@ -171,37 +174,59 @@ class ScoliosisPipeline:
tuple[ tuple[
Float[ndarray, "64 44"], Float[ndarray, "64 44"],
UInt8[ndarray, "h w"], UInt8[ndarray, "h w"],
tuple[int, int, int, int], BBoxXYXY,
int, int,
] ]
| None | None
): ):
selected = select_person(result) selected = select_person(result)
if selected is not None: if selected is not None:
mask_raw, bbox, track_id = selected mask_raw, bbox_mask, bbox_frame, track_id = selected
silhouette = cast( silhouette = cast(
Float[ndarray, "64 44"] | None, Float[ndarray, "64 44"] | None,
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox), mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask),
) )
if silhouette is not None: if silhouette is not None:
return silhouette, mask_raw, bbox, int(track_id) return silhouette, mask_raw, bbox_frame, int(track_id)
fallback = cast( fallback = cast(
tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None, tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
frame_to_person_mask(result), frame_to_person_mask(result),
) )
if fallback is None: if fallback is None:
return None return None
mask_u8, bbox = fallback mask_u8, bbox_mask = fallback
silhouette = cast( silhouette = cast(
Float[ndarray, "64 44"] | None, Float[ndarray, "64 44"] | None,
mask_to_silhouette(mask_u8, bbox), mask_to_silhouette(mask_u8, bbox_mask),
) )
if silhouette is None: if silhouette is None:
return None return None
# Convert mask-space bbox to frame-space for visualization
# Use result.orig_shape to get frame dimensions safely
orig_shape = getattr(result, "orig_shape", None)
if orig_shape is not None and isinstance(orig_shape, (tuple, list)) and len(orig_shape) >= 2:
frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
mask_h, mask_w = mask_u8.shape[0], mask_u8.shape[1]
if mask_w > 0 and mask_h > 0 and frame_w > 0 and frame_h > 0:
scale_x = frame_w / mask_w
scale_y = frame_h / mask_h
bbox_frame = (
int(bbox_mask[0] * scale_x),
int(bbox_mask[1] * scale_y),
int(bbox_mask[2] * scale_x),
int(bbox_mask[3] * scale_y),
)
else:
# Fallback: use mask-space bbox if dimensions invalid
bbox_frame = bbox_mask
else:
# Fallback: use mask-space bbox if orig_shape unavailable
bbox_frame = bbox_mask
# For fallback case, mask_raw is the same as mask_u8 # For fallback case, mask_raw is the same as mask_u8
return silhouette, mask_u8, bbox, 0 return silhouette, mask_u8, bbox_frame, 0
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
def process_frame( def process_frame(
@@ -342,23 +367,48 @@ class ScoliosisPipeline:
) )
# Update visualizer if enabled # Update visualizer if enabled
if self._visualizer is not None and viz_payload is not None: if self._visualizer is not None:
# Cast viz_payload to dict for type checking # Cache valid payload for no-detection frames
viz_dict = cast(dict[str, object], viz_payload) if viz_payload is not None:
mask_raw_obj = viz_dict.get("mask_raw") # Cache a copy to prevent mutation of original data
bbox_obj = viz_dict.get("bbox") viz_payload_dict = cast(dict[str, object], viz_payload)
silhouette_obj = viz_dict.get("silhouette") cached: dict[str, object] = {}
track_id_val = viz_dict.get("track_id", 0) for k, v in viz_payload_dict.items():
track_id = track_id_val if isinstance(track_id_val, int) else 0 copy_method = cast(Callable[[], object] | None, getattr(v, "copy", None))
label_obj = viz_dict.get("label") if copy_method is not None:
confidence_obj = viz_dict.get("confidence") cached[k] = copy_method()
else:
cached[k] = v
self._last_viz_payload = cached
# Use cached payload if current is None
viz_data = viz_payload if viz_payload is not None else self._last_viz_payload
if viz_data is not None:
# Cast viz_payload to dict for type checking
viz_dict = cast(dict[str, object], viz_data)
mask_raw_obj = viz_dict.get("mask_raw")
bbox_obj = viz_dict.get("bbox")
silhouette_obj = viz_dict.get("silhouette")
track_id_val = viz_dict.get("track_id", 0)
track_id = track_id_val if isinstance(track_id_val, int) else 0
label_obj = viz_dict.get("label")
confidence_obj = viz_dict.get("confidence")
# Cast extracted values to expected types # Cast extracted values to expected types
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj) mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
bbox = cast(tuple[int, int, int, int] | None, bbox_obj) bbox = cast(BBoxXYXY | None, bbox_obj)
silhouette = cast(NDArray[np.float32] | None, silhouette_obj) silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
label = cast(str | None, label_obj) label = cast(str | None, label_obj)
confidence = cast(float | None, confidence_obj) confidence = cast(float | None, confidence_obj)
else:
# No detection and no cache - use default values
mask_raw = None
bbox = None
track_id = 0
silhouette = None
label = None
confidence = None
keep_running = self._visualizer.update( keep_running = self._visualizer.update(
frame_u8, frame_u8,
+47 -8
View File
@@ -23,6 +23,9 @@ jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
UInt8Array = NDArray[np.uint8] UInt8Array = NDArray[np.uint8]
Float32Array = NDArray[np.float32] Float32Array = NDArray[np.float32]
#: Bounding box in XYXY format: (x1, y1, x2, y2) where (x1,y1) is top-left and (x2,y2) is bottom-right.
BBoxXYXY = tuple[int, int, int, int]
def _read_attr(container: object, key: str) -> object | None: def _read_attr(container: object, key: str) -> object | None:
if isinstance(container, dict): if isinstance(container, dict):
@@ -59,7 +62,15 @@ def _to_numpy_array(value: object) -> NDArray[np.generic]:
return cast(NDArray[np.generic], np.asarray(current)) return cast(NDArray[np.generic], np.asarray(current))
def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> tuple[int, int, int, int] | None: def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> BBoxXYXY | None:
"""Extract bounding box from binary mask in XYXY format.
Args:
mask: Binary mask array of shape (H, W) with dtype uint8.
Returns:
Bounding box as (x1, y1, x2, y2) in XYXY format, or None if mask is empty.
"""
mask_u8 = np.asarray(mask, dtype=np.uint8) mask_u8 = np.asarray(mask, dtype=np.uint8)
coords = np.argwhere(mask_u8 > 0) coords = np.argwhere(mask_u8 > 0)
if int(coords.size) == 0: if int(coords.size) == 0:
@@ -76,9 +87,17 @@ def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> tuple[int, int, int, int] |
return (x1, y1, x2, y2) return (x1, y1, x2, y2)
def _sanitize_bbox( def _sanitize_bbox(bbox: BBoxXYXY, height: int, width: int) -> BBoxXYXY | None:
bbox: tuple[int, int, int, int], height: int, width: int """Sanitize bounding box to ensure it's within image bounds.
) -> tuple[int, int, int, int] | None:
Args:
bbox: Bounding box in XYXY format (x1, y1, x2, y2).
height: Image height.
width: Image width.
Returns:
Sanitized bounding box in XYXY format, or None if invalid.
"""
x1, y1, x2, y2 = bbox x1, y1, x2, y2 = bbox
x1c = max(0, min(int(x1), width - 1)) x1c = max(0, min(int(x1), width - 1))
y1c = max(0, min(int(y1), height - 1)) y1c = max(0, min(int(y1), height - 1))
@@ -92,7 +111,17 @@ def _sanitize_bbox(
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
def frame_to_person_mask( def frame_to_person_mask(
result: object, min_area: int = MIN_MASK_AREA result: object, min_area: int = MIN_MASK_AREA
) -> tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None: ) -> tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None:
"""Extract person mask and bounding box from detection result.
Args:
result: Detection results object with boxes and masks attributes.
min_area: Minimum mask area to consider valid.
Returns:
Tuple of (mask, bbox) where bbox is in XYXY format (x1, y1, x2, y2),
or None if no valid detections.
"""
masks_obj = _read_attr(result, "masks") masks_obj = _read_attr(result, "masks")
if masks_obj is None: if masks_obj is None:
return None return None
@@ -152,7 +181,7 @@ def frame_to_person_mask(
best_area = -1 best_area = -1
best_mask: UInt8[ndarray, "h w"] | None = None best_mask: UInt8[ndarray, "h w"] | None = None
best_bbox: tuple[int, int, int, int] | None = None best_bbox: BBoxXYXY | None = None
for idx in range(mask_count): for idx in range(mask_count):
mask_float = np.asarray(masks_float[idx], dtype=np.float32) mask_float = np.asarray(masks_float[idx], dtype=np.float32)
@@ -167,7 +196,7 @@ def frame_to_person_mask(
if area < min_area: if area < min_area:
continue continue
bbox: tuple[int, int, int, int] | None = None bbox: BBoxXYXY | None = None
shape_2d = cast(tuple[int, int], mask_binary.shape) shape_2d = cast(tuple[int, int], mask_binary.shape)
h = int(shape_2d[0]) h = int(shape_2d[0])
w = int(shape_2d[1]) w = int(shape_2d[1])
@@ -204,8 +233,18 @@ def frame_to_person_mask(
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
def mask_to_silhouette( def mask_to_silhouette(
mask: UInt8[ndarray, "h w"], mask: UInt8[ndarray, "h w"],
bbox: tuple[int, int, int, int], bbox: BBoxXYXY,
) -> Float[ndarray, "64 44"] | None: ) -> Float[ndarray, "64 44"] | None:
"""Convert mask to standardized silhouette using bounding box.
Args:
mask: Binary mask array of shape (H, W) with dtype uint8.
bbox: Bounding box in XYXY format (x1, y1, x2, y2).
Returns:
Standardized silhouette array of shape (64, 44) with dtype float32,
or None if conversion fails.
"""
mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8) mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
if int(np.count_nonzero(mask_u8)) < MIN_MASK_AREA: if int(np.count_nonzero(mask_u8)) < MIN_MASK_AREA:
return None return None
+13 -6
View File
@@ -13,8 +13,11 @@ import cv2
import numpy as np import numpy as np
from numpy.typing import NDArray from numpy.typing import NDArray
from .preprocess import BBoxXYXY
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Window names # Window names
MAIN_WINDOW = "Scoliosis Detection" MAIN_WINDOW = "Scoliosis Detection"
SEG_WINDOW = "Segmentation" SEG_WINDOW = "Segmentation"
@@ -66,13 +69,13 @@ class OpenCVVisualizer:
def _draw_bbox( def _draw_bbox(
self, self,
frame: ImageArray, frame: ImageArray,
bbox: tuple[int, int, int, int] | None, bbox: BBoxXYXY | None,
) -> None: ) -> None:
"""Draw bounding box on frame if present. """Draw bounding box on frame if present.
Args: Args:
frame: Input frame (H, W, 3) uint8 - modified in place frame: Input frame (H, W, 3) uint8 - modified in place
bbox: Bounding box as (x1, y1, x2, y2) or None bbox: Bounding box in XYXY format as (x1, y1, x2, y2) or None
""" """
if bbox is None: if bbox is None:
return return
@@ -145,7 +148,7 @@ class OpenCVVisualizer:
def _prepare_main_frame( def _prepare_main_frame(
self, self,
frame: ImageArray, frame: ImageArray,
bbox: tuple[int, int, int, int] | None, bbox: BBoxXYXY | None,
track_id: int, track_id: int,
fps: float, fps: float,
label: str | None, label: str | None,
@@ -155,7 +158,7 @@ class OpenCVVisualizer:
Args: Args:
frame: Input frame (H, W, C) uint8 frame: Input frame (H, W, C) uint8
bbox: Bounding box or None bbox: Bounding box in XYXY format (x1, y1, x2, y2) or None
track_id: Tracking ID track_id: Tracking ID
fps: Current FPS fps: Current FPS
label: Classification label or None label: Classification label or None
@@ -324,6 +327,9 @@ class OpenCVVisualizer:
mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY)) mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
else: else:
mask_gray = mask_raw mask_gray = mask_raw
# Normalize to uint8 [0,255] for display (handles both float [0,1] and uint8 inputs)
if mask_gray.dtype == np.float32 or mask_gray.dtype == np.float64:
mask_gray = (mask_gray * 255).astype(np.uint8)
raw_gray = cast( raw_gray = cast(
ImageArray, ImageArray,
cv2.resize( cv2.resize(
@@ -333,6 +339,7 @@ class OpenCVVisualizer:
), ),
) )
# Normalized view preparation (without indicator) # Normalized view preparation (without indicator)
if silhouette is None: if silhouette is None:
norm_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8) norm_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
@@ -402,7 +409,7 @@ class OpenCVVisualizer:
def update( def update(
self, self,
frame: ImageArray, frame: ImageArray,
bbox: tuple[int, int, int, int] | None, bbox: BBoxXYXY | None,
track_id: int, track_id: int,
mask_raw: ImageArray | None, mask_raw: ImageArray | None,
silhouette: NDArray[np.float32] | None, silhouette: NDArray[np.float32] | None,
@@ -414,7 +421,7 @@ class OpenCVVisualizer:
Args: Args:
frame: Input frame (H, W, C) uint8 frame: Input frame (H, W, C) uint8
bbox: Bounding box as (x1, y1, x2, y2) or None bbox: Bounding box in XYXY format (x1, y1, x2, y2) or None
track_id: Tracking ID track_id: Tracking ID
mask_raw: Raw binary mask (H, W) uint8 or None mask_raw: Raw binary mask (H, W) uint8 or None
silhouette: Normalized silhouette (64, 44) float32 [0,1] or None silhouette: Normalized silhouette (64, 44) float32 [0,1] or None
+22 -10
View File
@@ -12,10 +12,11 @@ import torch
from jaxtyping import Float from jaxtyping import Float
from numpy import ndarray from numpy import ndarray
from .preprocess import BBoxXYXY
if TYPE_CHECKING: if TYPE_CHECKING:
from numpy.typing import NDArray from numpy.typing import NDArray
# Silhouette dimensions from preprocess.py # Silhouette dimensions from preprocess.py
SIL_HEIGHT: int = 64 SIL_HEIGHT: int = 64
SIL_WIDTH: int = 44 SIL_WIDTH: int = 44
@@ -239,19 +240,23 @@ def _to_numpy(obj: _ArrayLike) -> ndarray:
def select_person( def select_person(
results: _DetectionResults, results: _DetectionResults,
) -> tuple[ndarray, tuple[int, int, int, int], int] | None: ) -> tuple[ndarray, BBoxXYXY, BBoxXYXY, int] | None:
"""Select the person with largest bounding box from detection results. """Select the person with largest bounding box from detection results.
Args: Args:
results: Detection results object with boxes and masks attributes. results: Detection results object with boxes and masks attributes.
Expected to have: Expected to have:
- boxes.xyxy: array of bounding boxes [N, 4] - boxes.xyxy: array of bounding boxes [N, 4] in frame coordinates (XYXY format)
- masks.data: array of masks [N, H, W] - masks.data: array of masks [N, H, W] in mask coordinates
- boxes.id: optional track IDs [N] - boxes.id: optional track IDs [N]
Returns: Returns:
Tuple of (mask, bbox, track_id) for the largest person, Tuple of (mask, bbox_mask, bbox_frame, track_id) for the largest person,
or None if no valid detections or track IDs unavailable. or None if no valid detections or track IDs unavailable.
- mask: the person's segmentation mask
- bbox_mask: bounding box in mask coordinate space (XYXY format: x1, y1, x2, y2)
- bbox_frame: bounding box in frame coordinate space (XYXY format: x1, y1, x2, y2)
- track_id: the person's track ID
""" """
# Check for track IDs # Check for track IDs
boxes_obj: _Boxes | object = getattr(results, "boxes", None) boxes_obj: _Boxes | object = getattr(results, "boxes", None)
@@ -329,20 +334,27 @@ def select_person(
# Scale bbox from frame space to mask space # Scale bbox from frame space to mask space
scale_x = mask_w / frame_w if frame_w > 0 else 1.0 scale_x = mask_w / frame_w if frame_w > 0 else 1.0
scale_y = mask_h / frame_h if frame_h > 0 else 1.0 scale_y = mask_h / frame_h if frame_h > 0 else 1.0
bbox = ( bbox_mask = (
int(float(bboxes[best_idx][0]) * scale_x), int(float(bboxes[best_idx][0]) * scale_x),
int(float(bboxes[best_idx][1]) * scale_y), int(float(bboxes[best_idx][1]) * scale_y),
int(float(bboxes[best_idx][2]) * scale_x), int(float(bboxes[best_idx][2]) * scale_x),
int(float(bboxes[best_idx][3]) * scale_y), int(float(bboxes[best_idx][3]) * scale_y),
) )
else: bbox_frame = (
# Fallback: use bbox as-is (assume same coordinate space)
bbox = (
int(float(bboxes[best_idx][0])), int(float(bboxes[best_idx][0])),
int(float(bboxes[best_idx][1])), int(float(bboxes[best_idx][1])),
int(float(bboxes[best_idx][2])), int(float(bboxes[best_idx][2])),
int(float(bboxes[best_idx][3])), int(float(bboxes[best_idx][3])),
) )
else:
# Fallback: use bbox as-is for both (assume same coordinate space)
bbox_mask = (
int(float(bboxes[best_idx][0])),
int(float(bboxes[best_idx][1])),
int(float(bboxes[best_idx][2])),
int(float(bboxes[best_idx][3])),
)
bbox_frame = bbox_mask
track_id = int(track_ids[best_idx]) if best_idx < len(track_ids) else best_idx track_id = int(track_ids[best_idx]) if best_idx < len(track_ids) else best_idx
return mask, bbox, track_id return mask, bbox_mask, bbox_frame, track_id
+223 -4
View File
@@ -8,7 +8,10 @@ import subprocess
import sys import sys
import time import time
from typing import Final, cast from typing import Final, cast
from unittest import mock
import numpy as np
from numpy.typing import NDArray
import pytest import pytest
import torch import torch
@@ -107,6 +110,7 @@ def _assert_prediction_schema(prediction: dict[str, object]) -> None:
assert isinstance(prediction["timestamp_ns"], int) assert isinstance(prediction["timestamp_ns"], int)
def test_pipeline_cli_fps_benchmark_smoke( def test_pipeline_cli_fps_benchmark_smoke(
compatible_checkpoint_path: Path, compatible_checkpoint_path: Path,
) -> None: ) -> None:
@@ -280,7 +284,6 @@ def test_pipeline_cli_invalid_checkpoint_path_returns_user_error() -> None:
assert "Error: Checkpoint not found" in result.stderr assert "Error: Checkpoint not found" in result.stderr
def test_pipeline_cli_preprocess_only_requires_export_path( def test_pipeline_cli_preprocess_only_requires_export_path(
compatible_checkpoint_path: Path, compatible_checkpoint_path: Path,
) -> None: ) -> None:
@@ -511,7 +514,9 @@ def test_pipeline_cli_silhouette_and_result_export(
) )
# Verify both export files exist # Verify both export files exist
assert silhouette_export.is_file(), f"Silhouette export not found: {silhouette_export}" assert silhouette_export.is_file(), (
f"Silhouette export not found: {silhouette_export}"
)
assert result_export.is_file(), f"Result export not found: {result_export}" assert result_export.is_file(), f"Result export not found: {result_export}"
# Verify silhouette export # Verify silhouette export
@@ -522,7 +527,9 @@ def test_pipeline_cli_silhouette_and_result_export(
# Verify result export # Verify result export
with open(result_export, "r", encoding="utf-8") as f: with open(result_export, "r", encoding="utf-8") as f:
predictions = [cast(dict[str, object], json.loads(line)) for line in f if line.strip()] predictions = [
cast(dict[str, object], json.loads(line)) for line in f if line.strip()
]
assert len(predictions) > 0 assert len(predictions) > 0
@@ -538,6 +545,7 @@ def test_pipeline_cli_parquet_export_requires_pyarrow(
pytest.skip("pyarrow is installed, skipping missing dependency test") pytest.skip("pyarrow is installed, skipping missing dependency test")
try: try:
import pyarrow # noqa: F401 import pyarrow # noqa: F401
pytest.skip("pyarrow is installed, skipping missing dependency test") pytest.skip("pyarrow is installed, skipping missing dependency test")
except ImportError: except ImportError:
pass pass
@@ -573,7 +581,6 @@ def test_pipeline_cli_parquet_export_requires_pyarrow(
assert "parquet" in result.stderr.lower() or "pyarrow" in result.stderr.lower() assert "parquet" in result.stderr.lower() or "pyarrow" in result.stderr.lower()
def test_pipeline_cli_silhouette_visualization( def test_pipeline_cli_silhouette_visualization(
compatible_checkpoint_path: Path, compatible_checkpoint_path: Path,
tmp_path: Path, tmp_path: Path,
@@ -673,3 +680,215 @@ def test_pipeline_cli_preprocess_only_with_visualization(
assert len(silhouettes) == len(png_files), ( assert len(silhouettes) == len(png_files), (
f"Mismatch: {len(silhouettes)} silhouettes exported but {len(png_files)} PNG files created" f"Mismatch: {len(silhouettes)} silhouettes exported but {len(png_files)} PNG files created"
) )
class MockVisualizer:
"""Mock visualizer to track update calls."""
def __init__(self) -> None:
self.update_calls: list[dict[str, object]] = []
self.return_value: bool = True
def update(
self,
frame: NDArray[np.uint8],
bbox: tuple[int, int, int, int] | None,
track_id: int,
mask_raw: NDArray[np.uint8] | None,
silhouette: NDArray[np.float32] | None,
label: str | None,
confidence: float | None,
fps: float,
) -> bool:
self.update_calls.append(
{
"frame": frame,
"bbox": bbox,
"track_id": track_id,
"mask_raw": mask_raw,
"silhouette": silhouette,
"label": label,
"confidence": confidence,
"fps": fps,
}
)
return self.return_value
def close(self) -> None:
pass
def test_pipeline_visualizer_updates_on_no_detection() -> None:
"""Test that visualizer is still updated even when process_frame returns None.
This is a regression test for the window freeze issue when no person is detected.
The window should refresh every frame to prevent freezing.
"""
from opengait.demo.pipeline import ScoliosisPipeline
# Create a minimal pipeline with mocked dependencies
with (
mock.patch("opengait.demo.pipeline.YOLO") as mock_yolo,
mock.patch("opengait.demo.pipeline.create_source") as mock_source,
mock.patch("opengait.demo.pipeline.create_publisher") as mock_publisher,
mock.patch("opengait.demo.pipeline.ScoNetDemo") as mock_classifier,
):
# Setup mock detector that returns no detections (causing process_frame to return None)
mock_detector = mock.MagicMock()
mock_detector.track.return_value = [] # No detections
mock_yolo.return_value = mock_detector
# Setup mock source with 3 frames
mock_frame = np.zeros((480, 640, 3), dtype=np.uint8)
mock_source.return_value = [(mock_frame, {"frame_count": i}) for i in range(3)]
# Setup mock publisher and classifier
mock_publisher.return_value = mock.MagicMock()
mock_classifier.return_value = mock.MagicMock()
# Create pipeline with visualize enabled
pipeline = ScoliosisPipeline(
source="dummy.mp4",
checkpoint="dummy.pt",
config=str(CONFIG_PATH) if CONFIG_PATH.exists() else "dummy.yaml",
device="cpu",
yolo_model="dummy.pt",
window=30,
stride=30,
nats_url=None,
nats_subject="test",
max_frames=None,
visualize=True,
)
# Replace the visualizer with our mock
mock_viz = MockVisualizer()
pipeline._visualizer = mock_viz # type: ignore[assignment]
# Run pipeline
_ = pipeline.run()
# Verify visualizer was updated for all 3 frames even with no detections
assert len(mock_viz.update_calls) == 3, (
f"Expected visualizer.update() to be called 3 times (once per frame), "
f"but was called {len(mock_viz.update_calls)} times. "
f"Window would freeze if not updated on no-detection frames."
)
# Verify each call had the frame data
for call in mock_viz.update_calls:
assert call["track_id"] == 0 # Default track_id when no detection
assert call["bbox"] is None # No bbox when no detection
assert call["mask_raw"] is None # No mask when no detection
assert call["silhouette"] is None # No silhouette when no detection
assert call["label"] is None # No label when no detection
assert call["confidence"] is None # No confidence when no detection
def test_pipeline_visualizer_uses_cached_detection_on_no_detection() -> None:
"""Test that visualizer reuses last valid detection when current frame has no detection.
This is a regression test for the cache-reuse behavior when person temporarily
disappears from frame. The last valid silhouette/mask should be displayed.
"""
from opengait.demo.pipeline import ScoliosisPipeline
# Create a minimal pipeline with mocked dependencies
with (
mock.patch("opengait.demo.pipeline.YOLO") as mock_yolo,
mock.patch("opengait.demo.pipeline.create_source") as mock_source,
mock.patch("opengait.demo.pipeline.create_publisher") as mock_publisher,
mock.patch("opengait.demo.pipeline.ScoNetDemo") as mock_classifier,
mock.patch("opengait.demo.pipeline.select_person") as mock_select_person,
mock.patch("opengait.demo.pipeline.mask_to_silhouette") as mock_mask_to_sil,
):
# Create mock detection result for frames 0-1 (valid detection)
mock_box = mock.MagicMock()
mock_box.xyxy = np.array([[100, 100, 200, 300]], dtype=np.float32)
mock_box.id = np.array([1], dtype=np.int64)
mock_mask = mock.MagicMock()
mock_mask.data = np.random.rand(1, 480, 640).astype(np.float32)
mock_result = mock.MagicMock()
mock_result.boxes = mock_box
mock_result.masks = mock_mask
# Setup mock detector: detection for frames 0-1, then no detection for frames 2-3
mock_detector = mock.MagicMock()
mock_detector.track.side_effect = [
[mock_result], # Frame 0: valid detection
[mock_result], # Frame 1: valid detection
[], # Frame 2: no detection
[], # Frame 3: no detection
]
mock_yolo.return_value = mock_detector
# Setup mock source with 4 frames
mock_frame = np.zeros((480, 640, 3), dtype=np.uint8)
mock_source.return_value = [(mock_frame, {"frame_count": i}) for i in range(4)]
# Setup mock publisher and classifier
mock_publisher.return_value = mock.MagicMock()
mock_classifier.return_value = mock.MagicMock()
# Setup mock select_person to return valid mask and bbox
dummy_mask = np.random.randint(0, 256, (480, 640), dtype=np.uint8)
dummy_bbox_mask = (100, 100, 200, 300)
dummy_bbox_frame = (100, 100, 200, 300)
mock_select_person.return_value = (dummy_mask, dummy_bbox_mask, dummy_bbox_frame, 1)
# Setup mock mask_to_silhouette to return valid silhouette
dummy_silhouette = np.random.rand(64, 44).astype(np.float32)
mock_mask_to_sil.return_value = dummy_silhouette
# Create pipeline with visualize enabled
pipeline = ScoliosisPipeline(
source="dummy.mp4",
checkpoint="dummy.pt",
config=str(CONFIG_PATH) if CONFIG_PATH.exists() else "dummy.yaml",
device="cpu",
yolo_model="dummy.pt",
window=30,
stride=30,
nats_url=None,
nats_subject="test",
max_frames=None,
visualize=True,
)
# Replace the visualizer with our mock
mock_viz = MockVisualizer()
pipeline._visualizer = mock_viz # type: ignore[assignment]
# Run pipeline
_ = pipeline.run()
# Verify visualizer was updated for all 4 frames
assert len(mock_viz.update_calls) == 4, (
f"Expected visualizer.update() to be called 4 times, "
f"but was called {len(mock_viz.update_calls)} times."
)
# Extract the mask_raw values from each call
mask_raw_calls = [call["mask_raw"] for call in mock_viz.update_calls]
# Frames 0 and 1 should have valid masks (not None)
assert mask_raw_calls[0] is not None, "Frame 0 should have valid mask"
assert mask_raw_calls[1] is not None, "Frame 1 should have valid mask"
# Frames 2 and 3 should reuse the cached mask from frame 1 (not None)
assert mask_raw_calls[2] is not None, (
"Frame 2 (no detection) should display cached mask from last valid detection, "
"not None/blank"
)
assert mask_raw_calls[3] is not None, (
"Frame 3 (no detection) should display cached mask from last valid detection, "
"not None/blank"
)
# The cached masks should be copies (different objects) to prevent mutation issues
if mask_raw_calls[1] is not None and mask_raw_calls[2] is not None:
assert mask_raw_calls[1] is not mask_raw_calls[2], (
"Cached mask should be a copy, not the same object reference"
)