fix(demo): stabilize visualizer bbox and mask rendering
Align bbox coordinate handling across primary and fallback paths, normalize Both-mode raw mask rendering, and tighten demo result typing to reduce runtime/display inconsistencies.
This commit is contained in:
+10
-9
@@ -7,7 +7,8 @@ Provides generator-based interfaces for video sources:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from collections.abc import AsyncIterator, Generator, Iterable
|
from collections.abc import AsyncIterator, Generator, Iterable
|
||||||
from typing import TYPE_CHECKING, Protocol, cast
|
from typing import Protocol, cast
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -17,15 +18,15 @@ logger = logging.getLogger(__name__)
|
|||||||
# Type alias for frame stream: (frame_array, metadata_dict)
|
# Type alias for frame stream: (frame_array, metadata_dict)
|
||||||
FrameStream = Iterable[tuple[np.ndarray, dict[str, object]]]
|
FrameStream = Iterable[tuple[np.ndarray, dict[str, object]]]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
# Protocol for cv-mmap metadata (needed at runtime for nested function annotation)
|
||||||
# Protocol for cv-mmap metadata to avoid direct import
|
class _FrameMetadata(Protocol):
|
||||||
class _FrameMetadata(Protocol):
|
frame_count: int
|
||||||
frame_count: int
|
timestamp_ns: int
|
||||||
timestamp_ns: int
|
|
||||||
|
|
||||||
# Protocol for cv-mmap client
|
|
||||||
class _CvMmapClient(Protocol):
|
# Protocol for cv-mmap client (needed at runtime for cast)
|
||||||
def __aiter__(self) -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]: ...
|
class _CvMmapClient(Protocol):
|
||||||
|
def __aiter__(self) -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]: ...
|
||||||
|
|
||||||
|
|
||||||
def opencv_source(
|
def opencv_source(
|
||||||
|
|||||||
+24
-9
@@ -14,7 +14,7 @@ import logging
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Protocol, TextIO, cast, runtime_checkable
|
from typing import TYPE_CHECKING, Protocol, TextIO, TypedDict, cast, runtime_checkable
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
@@ -22,17 +22,31 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DemoResult(TypedDict):
|
||||||
|
"""Typed result dictionary for demo pipeline output.
|
||||||
|
|
||||||
|
Contains classification result with frame metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
|
frame: int
|
||||||
|
track_id: int
|
||||||
|
label: str
|
||||||
|
confidence: float
|
||||||
|
window: int
|
||||||
|
timestamp_ns: int
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class ResultPublisher(Protocol):
|
class ResultPublisher(Protocol):
|
||||||
"""Protocol for result publishers."""
|
"""Protocol for result publishers."""
|
||||||
|
|
||||||
def publish(self, result: dict[str, object]) -> None:
|
def publish(self, result: DemoResult) -> None:
|
||||||
"""
|
"""
|
||||||
Publish a result dictionary.
|
Publish a result dictionary.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
result : dict[str, object]
|
result : DemoResult
|
||||||
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
|
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
@@ -54,13 +68,13 @@ class ConsolePublisher:
|
|||||||
"""
|
"""
|
||||||
self._output = output
|
self._output = output
|
||||||
|
|
||||||
def publish(self, result: dict[str, object]) -> None:
|
def publish(self, result: DemoResult) -> None:
|
||||||
"""
|
"""
|
||||||
Publish result as JSON line.
|
Publish result as JSON line.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
result : dict[str, object]
|
result : DemoResult
|
||||||
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
|
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
@@ -214,13 +228,13 @@ class NatsPublisher:
|
|||||||
logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}")
|
logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def publish(self, result: dict[str, object]) -> None:
|
def publish(self, result: DemoResult) -> None:
|
||||||
"""
|
"""
|
||||||
Publish result to NATS subject.
|
Publish result to NATS subject.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
result : dict[str, object]
|
result : DemoResult
|
||||||
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
|
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
|
||||||
"""
|
"""
|
||||||
if not self._ensure_connected():
|
if not self._ensure_connected():
|
||||||
@@ -239,6 +253,7 @@ class NatsPublisher:
|
|||||||
).encode("utf-8")
|
).encode("utf-8")
|
||||||
_ = await self._nc.publish(self._subject, payload)
|
_ = await self._nc.publish(self._subject, payload)
|
||||||
_ = await self._nc.flush()
|
_ = await self._nc.flush()
|
||||||
|
|
||||||
# Run publish in background loop
|
# Run publish in background loop
|
||||||
future = asyncio.run_coroutine_threadsafe(
|
future = asyncio.run_coroutine_threadsafe(
|
||||||
_publish(),
|
_publish(),
|
||||||
@@ -331,7 +346,7 @@ def create_result(
|
|||||||
confidence: float,
|
confidence: float,
|
||||||
window: int | tuple[int, int],
|
window: int | tuple[int, int],
|
||||||
timestamp_ns: int | None = None,
|
timestamp_ns: int | None = None,
|
||||||
) -> dict[str, object]:
|
) -> DemoResult:
|
||||||
"""
|
"""
|
||||||
Create a standardized result dictionary.
|
Create a standardized result dictionary.
|
||||||
|
|
||||||
@@ -353,7 +368,7 @@ def create_result(
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
dict[str, object]
|
DemoResult
|
||||||
Standardized result dictionary
|
Standardized result dictionary
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
|
|||||||
+77
-27
@@ -17,8 +17,8 @@ from numpy.typing import NDArray
|
|||||||
from ultralytics.models.yolo.model import YOLO
|
from ultralytics.models.yolo.model import YOLO
|
||||||
|
|
||||||
from .input import FrameStream, create_source
|
from .input import FrameStream, create_source
|
||||||
from .output import ResultPublisher, create_publisher, create_result
|
from .output import DemoResult, ResultPublisher, create_publisher, create_result
|
||||||
from .preprocess import frame_to_person_mask, mask_to_silhouette
|
from .preprocess import BBoxXYXY, frame_to_person_mask, mask_to_silhouette
|
||||||
from .sconet_demo import ScoNetDemo
|
from .sconet_demo import ScoNetDemo
|
||||||
from .window import SilhouetteWindow, select_person
|
from .window import SilhouetteWindow, select_person
|
||||||
|
|
||||||
@@ -53,6 +53,7 @@ class _DetectionResultsLike(Protocol):
|
|||||||
def masks(self) -> _MasksLike: ...
|
def masks(self) -> _MasksLike: ...
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class _TrackCallable(Protocol):
|
class _TrackCallable(Protocol):
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -80,8 +81,9 @@ class ScoliosisPipeline:
|
|||||||
_silhouette_visualize_dir: Path | None
|
_silhouette_visualize_dir: Path | None
|
||||||
_result_export_path: Path | None
|
_result_export_path: Path | None
|
||||||
_result_export_format: str
|
_result_export_format: str
|
||||||
_result_buffer: list[dict[str, object]]
|
_result_buffer: list[DemoResult]
|
||||||
_visualizer: OpenCVVisualizer | None
|
_visualizer: OpenCVVisualizer | None
|
||||||
|
_last_viz_payload: dict[str, object] | None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -135,6 +137,7 @@ class ScoliosisPipeline:
|
|||||||
self._visualizer = OpenCVVisualizer()
|
self._visualizer = OpenCVVisualizer()
|
||||||
else:
|
else:
|
||||||
self._visualizer = None
|
self._visualizer = None
|
||||||
|
self._last_viz_payload = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
|
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
|
||||||
@@ -171,37 +174,59 @@ class ScoliosisPipeline:
|
|||||||
tuple[
|
tuple[
|
||||||
Float[ndarray, "64 44"],
|
Float[ndarray, "64 44"],
|
||||||
UInt8[ndarray, "h w"],
|
UInt8[ndarray, "h w"],
|
||||||
tuple[int, int, int, int],
|
BBoxXYXY,
|
||||||
int,
|
int,
|
||||||
]
|
]
|
||||||
| None
|
| None
|
||||||
):
|
):
|
||||||
selected = select_person(result)
|
selected = select_person(result)
|
||||||
if selected is not None:
|
if selected is not None:
|
||||||
mask_raw, bbox, track_id = selected
|
mask_raw, bbox_mask, bbox_frame, track_id = selected
|
||||||
silhouette = cast(
|
silhouette = cast(
|
||||||
Float[ndarray, "64 44"] | None,
|
Float[ndarray, "64 44"] | None,
|
||||||
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox),
|
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask),
|
||||||
)
|
)
|
||||||
if silhouette is not None:
|
if silhouette is not None:
|
||||||
return silhouette, mask_raw, bbox, int(track_id)
|
return silhouette, mask_raw, bbox_frame, int(track_id)
|
||||||
|
|
||||||
fallback = cast(
|
fallback = cast(
|
||||||
tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None,
|
tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
|
||||||
frame_to_person_mask(result),
|
frame_to_person_mask(result),
|
||||||
)
|
)
|
||||||
if fallback is None:
|
if fallback is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
mask_u8, bbox = fallback
|
mask_u8, bbox_mask = fallback
|
||||||
silhouette = cast(
|
silhouette = cast(
|
||||||
Float[ndarray, "64 44"] | None,
|
Float[ndarray, "64 44"] | None,
|
||||||
mask_to_silhouette(mask_u8, bbox),
|
mask_to_silhouette(mask_u8, bbox_mask),
|
||||||
)
|
)
|
||||||
if silhouette is None:
|
if silhouette is None:
|
||||||
return None
|
return None
|
||||||
|
# Convert mask-space bbox to frame-space for visualization
|
||||||
|
# Use result.orig_shape to get frame dimensions safely
|
||||||
|
orig_shape = getattr(result, "orig_shape", None)
|
||||||
|
if orig_shape is not None and isinstance(orig_shape, (tuple, list)) and len(orig_shape) >= 2:
|
||||||
|
frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
|
||||||
|
mask_h, mask_w = mask_u8.shape[0], mask_u8.shape[1]
|
||||||
|
if mask_w > 0 and mask_h > 0 and frame_w > 0 and frame_h > 0:
|
||||||
|
scale_x = frame_w / mask_w
|
||||||
|
scale_y = frame_h / mask_h
|
||||||
|
bbox_frame = (
|
||||||
|
int(bbox_mask[0] * scale_x),
|
||||||
|
int(bbox_mask[1] * scale_y),
|
||||||
|
int(bbox_mask[2] * scale_x),
|
||||||
|
int(bbox_mask[3] * scale_y),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fallback: use mask-space bbox if dimensions invalid
|
||||||
|
bbox_frame = bbox_mask
|
||||||
|
else:
|
||||||
|
# Fallback: use mask-space bbox if orig_shape unavailable
|
||||||
|
bbox_frame = bbox_mask
|
||||||
# For fallback case, mask_raw is the same as mask_u8
|
# For fallback case, mask_raw is the same as mask_u8
|
||||||
return silhouette, mask_u8, bbox, 0
|
return silhouette, mask_u8, bbox_frame, 0
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
def process_frame(
|
def process_frame(
|
||||||
@@ -342,23 +367,48 @@ class ScoliosisPipeline:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Update visualizer if enabled
|
# Update visualizer if enabled
|
||||||
if self._visualizer is not None and viz_payload is not None:
|
if self._visualizer is not None:
|
||||||
# Cast viz_payload to dict for type checking
|
# Cache valid payload for no-detection frames
|
||||||
viz_dict = cast(dict[str, object], viz_payload)
|
if viz_payload is not None:
|
||||||
mask_raw_obj = viz_dict.get("mask_raw")
|
# Cache a copy to prevent mutation of original data
|
||||||
bbox_obj = viz_dict.get("bbox")
|
viz_payload_dict = cast(dict[str, object], viz_payload)
|
||||||
silhouette_obj = viz_dict.get("silhouette")
|
cached: dict[str, object] = {}
|
||||||
track_id_val = viz_dict.get("track_id", 0)
|
for k, v in viz_payload_dict.items():
|
||||||
track_id = track_id_val if isinstance(track_id_val, int) else 0
|
copy_method = cast(Callable[[], object] | None, getattr(v, "copy", None))
|
||||||
label_obj = viz_dict.get("label")
|
if copy_method is not None:
|
||||||
confidence_obj = viz_dict.get("confidence")
|
cached[k] = copy_method()
|
||||||
|
else:
|
||||||
|
cached[k] = v
|
||||||
|
self._last_viz_payload = cached
|
||||||
|
|
||||||
# Cast extracted values to expected types
|
# Use cached payload if current is None
|
||||||
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
|
viz_data = viz_payload if viz_payload is not None else self._last_viz_payload
|
||||||
bbox = cast(tuple[int, int, int, int] | None, bbox_obj)
|
|
||||||
silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
|
if viz_data is not None:
|
||||||
label = cast(str | None, label_obj)
|
# Cast viz_payload to dict for type checking
|
||||||
confidence = cast(float | None, confidence_obj)
|
viz_dict = cast(dict[str, object], viz_data)
|
||||||
|
mask_raw_obj = viz_dict.get("mask_raw")
|
||||||
|
bbox_obj = viz_dict.get("bbox")
|
||||||
|
silhouette_obj = viz_dict.get("silhouette")
|
||||||
|
track_id_val = viz_dict.get("track_id", 0)
|
||||||
|
track_id = track_id_val if isinstance(track_id_val, int) else 0
|
||||||
|
label_obj = viz_dict.get("label")
|
||||||
|
confidence_obj = viz_dict.get("confidence")
|
||||||
|
|
||||||
|
# Cast extracted values to expected types
|
||||||
|
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
|
||||||
|
bbox = cast(BBoxXYXY | None, bbox_obj)
|
||||||
|
silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
|
||||||
|
label = cast(str | None, label_obj)
|
||||||
|
confidence = cast(float | None, confidence_obj)
|
||||||
|
else:
|
||||||
|
# No detection and no cache - use default values
|
||||||
|
mask_raw = None
|
||||||
|
bbox = None
|
||||||
|
track_id = 0
|
||||||
|
silhouette = None
|
||||||
|
label = None
|
||||||
|
confidence = None
|
||||||
|
|
||||||
keep_running = self._visualizer.update(
|
keep_running = self._visualizer.update(
|
||||||
frame_u8,
|
frame_u8,
|
||||||
|
|||||||
@@ -23,6 +23,9 @@ jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
|
|||||||
UInt8Array = NDArray[np.uint8]
|
UInt8Array = NDArray[np.uint8]
|
||||||
Float32Array = NDArray[np.float32]
|
Float32Array = NDArray[np.float32]
|
||||||
|
|
||||||
|
#: Bounding box in XYXY format: (x1, y1, x2, y2) where (x1,y1) is top-left and (x2,y2) is bottom-right.
|
||||||
|
BBoxXYXY = tuple[int, int, int, int]
|
||||||
|
|
||||||
|
|
||||||
def _read_attr(container: object, key: str) -> object | None:
|
def _read_attr(container: object, key: str) -> object | None:
|
||||||
if isinstance(container, dict):
|
if isinstance(container, dict):
|
||||||
@@ -59,7 +62,15 @@ def _to_numpy_array(value: object) -> NDArray[np.generic]:
|
|||||||
return cast(NDArray[np.generic], np.asarray(current))
|
return cast(NDArray[np.generic], np.asarray(current))
|
||||||
|
|
||||||
|
|
||||||
def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> tuple[int, int, int, int] | None:
|
def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> BBoxXYXY | None:
|
||||||
|
"""Extract bounding box from binary mask in XYXY format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask: Binary mask array of shape (H, W) with dtype uint8.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Bounding box as (x1, y1, x2, y2) in XYXY format, or None if mask is empty.
|
||||||
|
"""
|
||||||
mask_u8 = np.asarray(mask, dtype=np.uint8)
|
mask_u8 = np.asarray(mask, dtype=np.uint8)
|
||||||
coords = np.argwhere(mask_u8 > 0)
|
coords = np.argwhere(mask_u8 > 0)
|
||||||
if int(coords.size) == 0:
|
if int(coords.size) == 0:
|
||||||
@@ -76,9 +87,17 @@ def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> tuple[int, int, int, int] |
|
|||||||
return (x1, y1, x2, y2)
|
return (x1, y1, x2, y2)
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_bbox(
|
def _sanitize_bbox(bbox: BBoxXYXY, height: int, width: int) -> BBoxXYXY | None:
|
||||||
bbox: tuple[int, int, int, int], height: int, width: int
|
"""Sanitize bounding box to ensure it's within image bounds.
|
||||||
) -> tuple[int, int, int, int] | None:
|
|
||||||
|
Args:
|
||||||
|
bbox: Bounding box in XYXY format (x1, y1, x2, y2).
|
||||||
|
height: Image height.
|
||||||
|
width: Image width.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized bounding box in XYXY format, or None if invalid.
|
||||||
|
"""
|
||||||
x1, y1, x2, y2 = bbox
|
x1, y1, x2, y2 = bbox
|
||||||
x1c = max(0, min(int(x1), width - 1))
|
x1c = max(0, min(int(x1), width - 1))
|
||||||
y1c = max(0, min(int(y1), height - 1))
|
y1c = max(0, min(int(y1), height - 1))
|
||||||
@@ -92,7 +111,17 @@ def _sanitize_bbox(
|
|||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
def frame_to_person_mask(
|
def frame_to_person_mask(
|
||||||
result: object, min_area: int = MIN_MASK_AREA
|
result: object, min_area: int = MIN_MASK_AREA
|
||||||
) -> tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None:
|
) -> tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None:
|
||||||
|
"""Extract person mask and bounding box from detection result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: Detection results object with boxes and masks attributes.
|
||||||
|
min_area: Minimum mask area to consider valid.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (mask, bbox) where bbox is in XYXY format (x1, y1, x2, y2),
|
||||||
|
or None if no valid detections.
|
||||||
|
"""
|
||||||
masks_obj = _read_attr(result, "masks")
|
masks_obj = _read_attr(result, "masks")
|
||||||
if masks_obj is None:
|
if masks_obj is None:
|
||||||
return None
|
return None
|
||||||
@@ -152,7 +181,7 @@ def frame_to_person_mask(
|
|||||||
|
|
||||||
best_area = -1
|
best_area = -1
|
||||||
best_mask: UInt8[ndarray, "h w"] | None = None
|
best_mask: UInt8[ndarray, "h w"] | None = None
|
||||||
best_bbox: tuple[int, int, int, int] | None = None
|
best_bbox: BBoxXYXY | None = None
|
||||||
|
|
||||||
for idx in range(mask_count):
|
for idx in range(mask_count):
|
||||||
mask_float = np.asarray(masks_float[idx], dtype=np.float32)
|
mask_float = np.asarray(masks_float[idx], dtype=np.float32)
|
||||||
@@ -167,7 +196,7 @@ def frame_to_person_mask(
|
|||||||
if area < min_area:
|
if area < min_area:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
bbox: tuple[int, int, int, int] | None = None
|
bbox: BBoxXYXY | None = None
|
||||||
shape_2d = cast(tuple[int, int], mask_binary.shape)
|
shape_2d = cast(tuple[int, int], mask_binary.shape)
|
||||||
h = int(shape_2d[0])
|
h = int(shape_2d[0])
|
||||||
w = int(shape_2d[1])
|
w = int(shape_2d[1])
|
||||||
@@ -204,8 +233,18 @@ def frame_to_person_mask(
|
|||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
def mask_to_silhouette(
|
def mask_to_silhouette(
|
||||||
mask: UInt8[ndarray, "h w"],
|
mask: UInt8[ndarray, "h w"],
|
||||||
bbox: tuple[int, int, int, int],
|
bbox: BBoxXYXY,
|
||||||
) -> Float[ndarray, "64 44"] | None:
|
) -> Float[ndarray, "64 44"] | None:
|
||||||
|
"""Convert mask to standardized silhouette using bounding box.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask: Binary mask array of shape (H, W) with dtype uint8.
|
||||||
|
bbox: Bounding box in XYXY format (x1, y1, x2, y2).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Standardized silhouette array of shape (64, 44) with dtype float32,
|
||||||
|
or None if conversion fails.
|
||||||
|
"""
|
||||||
mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
|
mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
|
||||||
if int(np.count_nonzero(mask_u8)) < MIN_MASK_AREA:
|
if int(np.count_nonzero(mask_u8)) < MIN_MASK_AREA:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -13,8 +13,11 @@ import cv2
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
from .preprocess import BBoxXYXY
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Window names
|
# Window names
|
||||||
MAIN_WINDOW = "Scoliosis Detection"
|
MAIN_WINDOW = "Scoliosis Detection"
|
||||||
SEG_WINDOW = "Segmentation"
|
SEG_WINDOW = "Segmentation"
|
||||||
@@ -66,13 +69,13 @@ class OpenCVVisualizer:
|
|||||||
def _draw_bbox(
|
def _draw_bbox(
|
||||||
self,
|
self,
|
||||||
frame: ImageArray,
|
frame: ImageArray,
|
||||||
bbox: tuple[int, int, int, int] | None,
|
bbox: BBoxXYXY | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Draw bounding box on frame if present.
|
"""Draw bounding box on frame if present.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
frame: Input frame (H, W, 3) uint8 - modified in place
|
frame: Input frame (H, W, 3) uint8 - modified in place
|
||||||
bbox: Bounding box as (x1, y1, x2, y2) or None
|
bbox: Bounding box in XYXY format as (x1, y1, x2, y2) or None
|
||||||
"""
|
"""
|
||||||
if bbox is None:
|
if bbox is None:
|
||||||
return
|
return
|
||||||
@@ -145,7 +148,7 @@ class OpenCVVisualizer:
|
|||||||
def _prepare_main_frame(
|
def _prepare_main_frame(
|
||||||
self,
|
self,
|
||||||
frame: ImageArray,
|
frame: ImageArray,
|
||||||
bbox: tuple[int, int, int, int] | None,
|
bbox: BBoxXYXY | None,
|
||||||
track_id: int,
|
track_id: int,
|
||||||
fps: float,
|
fps: float,
|
||||||
label: str | None,
|
label: str | None,
|
||||||
@@ -155,7 +158,7 @@ class OpenCVVisualizer:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
frame: Input frame (H, W, C) uint8
|
frame: Input frame (H, W, C) uint8
|
||||||
bbox: Bounding box or None
|
bbox: Bounding box in XYXY format (x1, y1, x2, y2) or None
|
||||||
track_id: Tracking ID
|
track_id: Tracking ID
|
||||||
fps: Current FPS
|
fps: Current FPS
|
||||||
label: Classification label or None
|
label: Classification label or None
|
||||||
@@ -324,6 +327,9 @@ class OpenCVVisualizer:
|
|||||||
mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
|
mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
|
||||||
else:
|
else:
|
||||||
mask_gray = mask_raw
|
mask_gray = mask_raw
|
||||||
|
# Normalize to uint8 [0,255] for display (handles both float [0,1] and uint8 inputs)
|
||||||
|
if mask_gray.dtype == np.float32 or mask_gray.dtype == np.float64:
|
||||||
|
mask_gray = (mask_gray * 255).astype(np.uint8)
|
||||||
raw_gray = cast(
|
raw_gray = cast(
|
||||||
ImageArray,
|
ImageArray,
|
||||||
cv2.resize(
|
cv2.resize(
|
||||||
@@ -333,6 +339,7 @@ class OpenCVVisualizer:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Normalized view preparation (without indicator)
|
# Normalized view preparation (without indicator)
|
||||||
if silhouette is None:
|
if silhouette is None:
|
||||||
norm_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
|
norm_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
|
||||||
@@ -402,7 +409,7 @@ class OpenCVVisualizer:
|
|||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
frame: ImageArray,
|
frame: ImageArray,
|
||||||
bbox: tuple[int, int, int, int] | None,
|
bbox: BBoxXYXY | None,
|
||||||
track_id: int,
|
track_id: int,
|
||||||
mask_raw: ImageArray | None,
|
mask_raw: ImageArray | None,
|
||||||
silhouette: NDArray[np.float32] | None,
|
silhouette: NDArray[np.float32] | None,
|
||||||
@@ -414,7 +421,7 @@ class OpenCVVisualizer:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
frame: Input frame (H, W, C) uint8
|
frame: Input frame (H, W, C) uint8
|
||||||
bbox: Bounding box as (x1, y1, x2, y2) or None
|
bbox: Bounding box in XYXY format (x1, y1, x2, y2) or None
|
||||||
track_id: Tracking ID
|
track_id: Tracking ID
|
||||||
mask_raw: Raw binary mask (H, W) uint8 or None
|
mask_raw: Raw binary mask (H, W) uint8 or None
|
||||||
silhouette: Normalized silhouette (64, 44) float32 [0,1] or None
|
silhouette: Normalized silhouette (64, 44) float32 [0,1] or None
|
||||||
|
|||||||
+22
-10
@@ -12,10 +12,11 @@ import torch
|
|||||||
from jaxtyping import Float
|
from jaxtyping import Float
|
||||||
from numpy import ndarray
|
from numpy import ndarray
|
||||||
|
|
||||||
|
from .preprocess import BBoxXYXY
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
|
||||||
# Silhouette dimensions from preprocess.py
|
# Silhouette dimensions from preprocess.py
|
||||||
SIL_HEIGHT: int = 64
|
SIL_HEIGHT: int = 64
|
||||||
SIL_WIDTH: int = 44
|
SIL_WIDTH: int = 44
|
||||||
@@ -239,19 +240,23 @@ def _to_numpy(obj: _ArrayLike) -> ndarray:
|
|||||||
|
|
||||||
def select_person(
|
def select_person(
|
||||||
results: _DetectionResults,
|
results: _DetectionResults,
|
||||||
) -> tuple[ndarray, tuple[int, int, int, int], int] | None:
|
) -> tuple[ndarray, BBoxXYXY, BBoxXYXY, int] | None:
|
||||||
"""Select the person with largest bounding box from detection results.
|
"""Select the person with largest bounding box from detection results.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
results: Detection results object with boxes and masks attributes.
|
results: Detection results object with boxes and masks attributes.
|
||||||
Expected to have:
|
Expected to have:
|
||||||
- boxes.xyxy: array of bounding boxes [N, 4]
|
- boxes.xyxy: array of bounding boxes [N, 4] in frame coordinates (XYXY format)
|
||||||
- masks.data: array of masks [N, H, W]
|
- masks.data: array of masks [N, H, W] in mask coordinates
|
||||||
- boxes.id: optional track IDs [N]
|
- boxes.id: optional track IDs [N]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (mask, bbox, track_id) for the largest person,
|
Tuple of (mask, bbox_mask, bbox_frame, track_id) for the largest person,
|
||||||
or None if no valid detections or track IDs unavailable.
|
or None if no valid detections or track IDs unavailable.
|
||||||
|
- mask: the person's segmentation mask
|
||||||
|
- bbox_mask: bounding box in mask coordinate space (XYXY format: x1, y1, x2, y2)
|
||||||
|
- bbox_frame: bounding box in frame coordinate space (XYXY format: x1, y1, x2, y2)
|
||||||
|
- track_id: the person's track ID
|
||||||
"""
|
"""
|
||||||
# Check for track IDs
|
# Check for track IDs
|
||||||
boxes_obj: _Boxes | object = getattr(results, "boxes", None)
|
boxes_obj: _Boxes | object = getattr(results, "boxes", None)
|
||||||
@@ -329,20 +334,27 @@ def select_person(
|
|||||||
# Scale bbox from frame space to mask space
|
# Scale bbox from frame space to mask space
|
||||||
scale_x = mask_w / frame_w if frame_w > 0 else 1.0
|
scale_x = mask_w / frame_w if frame_w > 0 else 1.0
|
||||||
scale_y = mask_h / frame_h if frame_h > 0 else 1.0
|
scale_y = mask_h / frame_h if frame_h > 0 else 1.0
|
||||||
bbox = (
|
bbox_mask = (
|
||||||
int(float(bboxes[best_idx][0]) * scale_x),
|
int(float(bboxes[best_idx][0]) * scale_x),
|
||||||
int(float(bboxes[best_idx][1]) * scale_y),
|
int(float(bboxes[best_idx][1]) * scale_y),
|
||||||
int(float(bboxes[best_idx][2]) * scale_x),
|
int(float(bboxes[best_idx][2]) * scale_x),
|
||||||
int(float(bboxes[best_idx][3]) * scale_y),
|
int(float(bboxes[best_idx][3]) * scale_y),
|
||||||
)
|
)
|
||||||
else:
|
bbox_frame = (
|
||||||
# Fallback: use bbox as-is (assume same coordinate space)
|
|
||||||
bbox = (
|
|
||||||
int(float(bboxes[best_idx][0])),
|
int(float(bboxes[best_idx][0])),
|
||||||
int(float(bboxes[best_idx][1])),
|
int(float(bboxes[best_idx][1])),
|
||||||
int(float(bboxes[best_idx][2])),
|
int(float(bboxes[best_idx][2])),
|
||||||
int(float(bboxes[best_idx][3])),
|
int(float(bboxes[best_idx][3])),
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# Fallback: use bbox as-is for both (assume same coordinate space)
|
||||||
|
bbox_mask = (
|
||||||
|
int(float(bboxes[best_idx][0])),
|
||||||
|
int(float(bboxes[best_idx][1])),
|
||||||
|
int(float(bboxes[best_idx][2])),
|
||||||
|
int(float(bboxes[best_idx][3])),
|
||||||
|
)
|
||||||
|
bbox_frame = bbox_mask
|
||||||
track_id = int(track_ids[best_idx]) if best_idx < len(track_ids) else best_idx
|
track_id = int(track_ids[best_idx]) if best_idx < len(track_ids) else best_idx
|
||||||
|
|
||||||
return mask, bbox, track_id
|
return mask, bbox_mask, bbox_frame, track_id
|
||||||
|
|||||||
+223
-4
@@ -8,7 +8,10 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from typing import Final, cast
|
from typing import Final, cast
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -107,6 +110,7 @@ def _assert_prediction_schema(prediction: dict[str, object]) -> None:
|
|||||||
|
|
||||||
assert isinstance(prediction["timestamp_ns"], int)
|
assert isinstance(prediction["timestamp_ns"], int)
|
||||||
|
|
||||||
|
|
||||||
def test_pipeline_cli_fps_benchmark_smoke(
|
def test_pipeline_cli_fps_benchmark_smoke(
|
||||||
compatible_checkpoint_path: Path,
|
compatible_checkpoint_path: Path,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -280,7 +284,6 @@ def test_pipeline_cli_invalid_checkpoint_path_returns_user_error() -> None:
|
|||||||
assert "Error: Checkpoint not found" in result.stderr
|
assert "Error: Checkpoint not found" in result.stderr
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_pipeline_cli_preprocess_only_requires_export_path(
|
def test_pipeline_cli_preprocess_only_requires_export_path(
|
||||||
compatible_checkpoint_path: Path,
|
compatible_checkpoint_path: Path,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -511,7 +514,9 @@ def test_pipeline_cli_silhouette_and_result_export(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Verify both export files exist
|
# Verify both export files exist
|
||||||
assert silhouette_export.is_file(), f"Silhouette export not found: {silhouette_export}"
|
assert silhouette_export.is_file(), (
|
||||||
|
f"Silhouette export not found: {silhouette_export}"
|
||||||
|
)
|
||||||
assert result_export.is_file(), f"Result export not found: {result_export}"
|
assert result_export.is_file(), f"Result export not found: {result_export}"
|
||||||
|
|
||||||
# Verify silhouette export
|
# Verify silhouette export
|
||||||
@@ -522,7 +527,9 @@ def test_pipeline_cli_silhouette_and_result_export(
|
|||||||
|
|
||||||
# Verify result export
|
# Verify result export
|
||||||
with open(result_export, "r", encoding="utf-8") as f:
|
with open(result_export, "r", encoding="utf-8") as f:
|
||||||
predictions = [cast(dict[str, object], json.loads(line)) for line in f if line.strip()]
|
predictions = [
|
||||||
|
cast(dict[str, object], json.loads(line)) for line in f if line.strip()
|
||||||
|
]
|
||||||
assert len(predictions) > 0
|
assert len(predictions) > 0
|
||||||
|
|
||||||
|
|
||||||
@@ -538,6 +545,7 @@ def test_pipeline_cli_parquet_export_requires_pyarrow(
|
|||||||
pytest.skip("pyarrow is installed, skipping missing dependency test")
|
pytest.skip("pyarrow is installed, skipping missing dependency test")
|
||||||
try:
|
try:
|
||||||
import pyarrow # noqa: F401
|
import pyarrow # noqa: F401
|
||||||
|
|
||||||
pytest.skip("pyarrow is installed, skipping missing dependency test")
|
pytest.skip("pyarrow is installed, skipping missing dependency test")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
@@ -573,7 +581,6 @@ def test_pipeline_cli_parquet_export_requires_pyarrow(
|
|||||||
assert "parquet" in result.stderr.lower() or "pyarrow" in result.stderr.lower()
|
assert "parquet" in result.stderr.lower() or "pyarrow" in result.stderr.lower()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_pipeline_cli_silhouette_visualization(
|
def test_pipeline_cli_silhouette_visualization(
|
||||||
compatible_checkpoint_path: Path,
|
compatible_checkpoint_path: Path,
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
@@ -673,3 +680,215 @@ def test_pipeline_cli_preprocess_only_with_visualization(
|
|||||||
assert len(silhouettes) == len(png_files), (
|
assert len(silhouettes) == len(png_files), (
|
||||||
f"Mismatch: {len(silhouettes)} silhouettes exported but {len(png_files)} PNG files created"
|
f"Mismatch: {len(silhouettes)} silhouettes exported but {len(png_files)} PNG files created"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockVisualizer:
|
||||||
|
"""Mock visualizer to track update calls."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.update_calls: list[dict[str, object]] = []
|
||||||
|
self.return_value: bool = True
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
frame: NDArray[np.uint8],
|
||||||
|
bbox: tuple[int, int, int, int] | None,
|
||||||
|
track_id: int,
|
||||||
|
mask_raw: NDArray[np.uint8] | None,
|
||||||
|
silhouette: NDArray[np.float32] | None,
|
||||||
|
label: str | None,
|
||||||
|
confidence: float | None,
|
||||||
|
fps: float,
|
||||||
|
) -> bool:
|
||||||
|
self.update_calls.append(
|
||||||
|
{
|
||||||
|
"frame": frame,
|
||||||
|
"bbox": bbox,
|
||||||
|
"track_id": track_id,
|
||||||
|
"mask_raw": mask_raw,
|
||||||
|
"silhouette": silhouette,
|
||||||
|
"label": label,
|
||||||
|
"confidence": confidence,
|
||||||
|
"fps": fps,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return self.return_value
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_visualizer_updates_on_no_detection() -> None:
|
||||||
|
"""Test that visualizer is still updated even when process_frame returns None.
|
||||||
|
|
||||||
|
This is a regression test for the window freeze issue when no person is detected.
|
||||||
|
The window should refresh every frame to prevent freezing.
|
||||||
|
"""
|
||||||
|
from opengait.demo.pipeline import ScoliosisPipeline
|
||||||
|
|
||||||
|
# Create a minimal pipeline with mocked dependencies
|
||||||
|
with (
|
||||||
|
mock.patch("opengait.demo.pipeline.YOLO") as mock_yolo,
|
||||||
|
mock.patch("opengait.demo.pipeline.create_source") as mock_source,
|
||||||
|
mock.patch("opengait.demo.pipeline.create_publisher") as mock_publisher,
|
||||||
|
mock.patch("opengait.demo.pipeline.ScoNetDemo") as mock_classifier,
|
||||||
|
):
|
||||||
|
# Setup mock detector that returns no detections (causing process_frame to return None)
|
||||||
|
mock_detector = mock.MagicMock()
|
||||||
|
mock_detector.track.return_value = [] # No detections
|
||||||
|
mock_yolo.return_value = mock_detector
|
||||||
|
|
||||||
|
# Setup mock source with 3 frames
|
||||||
|
mock_frame = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||||
|
mock_source.return_value = [(mock_frame, {"frame_count": i}) for i in range(3)]
|
||||||
|
|
||||||
|
# Setup mock publisher and classifier
|
||||||
|
mock_publisher.return_value = mock.MagicMock()
|
||||||
|
mock_classifier.return_value = mock.MagicMock()
|
||||||
|
|
||||||
|
# Create pipeline with visualize enabled
|
||||||
|
pipeline = ScoliosisPipeline(
|
||||||
|
source="dummy.mp4",
|
||||||
|
checkpoint="dummy.pt",
|
||||||
|
config=str(CONFIG_PATH) if CONFIG_PATH.exists() else "dummy.yaml",
|
||||||
|
device="cpu",
|
||||||
|
yolo_model="dummy.pt",
|
||||||
|
window=30,
|
||||||
|
stride=30,
|
||||||
|
nats_url=None,
|
||||||
|
nats_subject="test",
|
||||||
|
max_frames=None,
|
||||||
|
visualize=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Replace the visualizer with our mock
|
||||||
|
mock_viz = MockVisualizer()
|
||||||
|
pipeline._visualizer = mock_viz # type: ignore[assignment]
|
||||||
|
|
||||||
|
# Run pipeline
|
||||||
|
_ = pipeline.run()
|
||||||
|
|
||||||
|
# Verify visualizer was updated for all 3 frames even with no detections
|
||||||
|
assert len(mock_viz.update_calls) == 3, (
|
||||||
|
f"Expected visualizer.update() to be called 3 times (once per frame), "
|
||||||
|
f"but was called {len(mock_viz.update_calls)} times. "
|
||||||
|
f"Window would freeze if not updated on no-detection frames."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify each call had the frame data
|
||||||
|
for call in mock_viz.update_calls:
|
||||||
|
assert call["track_id"] == 0 # Default track_id when no detection
|
||||||
|
assert call["bbox"] is None # No bbox when no detection
|
||||||
|
assert call["mask_raw"] is None # No mask when no detection
|
||||||
|
assert call["silhouette"] is None # No silhouette when no detection
|
||||||
|
assert call["label"] is None # No label when no detection
|
||||||
|
assert call["confidence"] is None # No confidence when no detection
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_visualizer_uses_cached_detection_on_no_detection() -> None:
|
||||||
|
"""Test that visualizer reuses last valid detection when current frame has no detection.
|
||||||
|
|
||||||
|
This is a regression test for the cache-reuse behavior when person temporarily
|
||||||
|
disappears from frame. The last valid silhouette/mask should be displayed.
|
||||||
|
"""
|
||||||
|
from opengait.demo.pipeline import ScoliosisPipeline
|
||||||
|
|
||||||
|
# Create a minimal pipeline with mocked dependencies
|
||||||
|
with (
|
||||||
|
mock.patch("opengait.demo.pipeline.YOLO") as mock_yolo,
|
||||||
|
mock.patch("opengait.demo.pipeline.create_source") as mock_source,
|
||||||
|
mock.patch("opengait.demo.pipeline.create_publisher") as mock_publisher,
|
||||||
|
mock.patch("opengait.demo.pipeline.ScoNetDemo") as mock_classifier,
|
||||||
|
mock.patch("opengait.demo.pipeline.select_person") as mock_select_person,
|
||||||
|
mock.patch("opengait.demo.pipeline.mask_to_silhouette") as mock_mask_to_sil,
|
||||||
|
):
|
||||||
|
# Create mock detection result for frames 0-1 (valid detection)
|
||||||
|
mock_box = mock.MagicMock()
|
||||||
|
mock_box.xyxy = np.array([[100, 100, 200, 300]], dtype=np.float32)
|
||||||
|
mock_box.id = np.array([1], dtype=np.int64)
|
||||||
|
mock_mask = mock.MagicMock()
|
||||||
|
mock_mask.data = np.random.rand(1, 480, 640).astype(np.float32)
|
||||||
|
mock_result = mock.MagicMock()
|
||||||
|
mock_result.boxes = mock_box
|
||||||
|
mock_result.masks = mock_mask
|
||||||
|
|
||||||
|
# Setup mock detector: detection for frames 0-1, then no detection for frames 2-3
|
||||||
|
mock_detector = mock.MagicMock()
|
||||||
|
mock_detector.track.side_effect = [
|
||||||
|
[mock_result], # Frame 0: valid detection
|
||||||
|
[mock_result], # Frame 1: valid detection
|
||||||
|
[], # Frame 2: no detection
|
||||||
|
[], # Frame 3: no detection
|
||||||
|
]
|
||||||
|
mock_yolo.return_value = mock_detector
|
||||||
|
|
||||||
|
# Setup mock source with 4 frames
|
||||||
|
mock_frame = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||||
|
mock_source.return_value = [(mock_frame, {"frame_count": i}) for i in range(4)]
|
||||||
|
|
||||||
|
# Setup mock publisher and classifier
|
||||||
|
mock_publisher.return_value = mock.MagicMock()
|
||||||
|
mock_classifier.return_value = mock.MagicMock()
|
||||||
|
|
||||||
|
# Setup mock select_person to return valid mask and bbox
|
||||||
|
dummy_mask = np.random.randint(0, 256, (480, 640), dtype=np.uint8)
|
||||||
|
dummy_bbox_mask = (100, 100, 200, 300)
|
||||||
|
dummy_bbox_frame = (100, 100, 200, 300)
|
||||||
|
mock_select_person.return_value = (dummy_mask, dummy_bbox_mask, dummy_bbox_frame, 1)
|
||||||
|
|
||||||
|
# Setup mock mask_to_silhouette to return valid silhouette
|
||||||
|
dummy_silhouette = np.random.rand(64, 44).astype(np.float32)
|
||||||
|
mock_mask_to_sil.return_value = dummy_silhouette
|
||||||
|
|
||||||
|
# Create pipeline with visualize enabled
|
||||||
|
pipeline = ScoliosisPipeline(
|
||||||
|
source="dummy.mp4",
|
||||||
|
checkpoint="dummy.pt",
|
||||||
|
config=str(CONFIG_PATH) if CONFIG_PATH.exists() else "dummy.yaml",
|
||||||
|
device="cpu",
|
||||||
|
yolo_model="dummy.pt",
|
||||||
|
window=30,
|
||||||
|
stride=30,
|
||||||
|
nats_url=None,
|
||||||
|
nats_subject="test",
|
||||||
|
max_frames=None,
|
||||||
|
visualize=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Replace the visualizer with our mock
|
||||||
|
mock_viz = MockVisualizer()
|
||||||
|
pipeline._visualizer = mock_viz # type: ignore[assignment]
|
||||||
|
|
||||||
|
# Run pipeline
|
||||||
|
_ = pipeline.run()
|
||||||
|
|
||||||
|
# Verify visualizer was updated for all 4 frames
|
||||||
|
assert len(mock_viz.update_calls) == 4, (
|
||||||
|
f"Expected visualizer.update() to be called 4 times, "
|
||||||
|
f"but was called {len(mock_viz.update_calls)} times."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the mask_raw values from each call
|
||||||
|
mask_raw_calls = [call["mask_raw"] for call in mock_viz.update_calls]
|
||||||
|
|
||||||
|
# Frames 0 and 1 should have valid masks (not None)
|
||||||
|
assert mask_raw_calls[0] is not None, "Frame 0 should have valid mask"
|
||||||
|
assert mask_raw_calls[1] is not None, "Frame 1 should have valid mask"
|
||||||
|
|
||||||
|
# Frames 2 and 3 should reuse the cached mask from frame 1 (not None)
|
||||||
|
assert mask_raw_calls[2] is not None, (
|
||||||
|
"Frame 2 (no detection) should display cached mask from last valid detection, "
|
||||||
|
"not None/blank"
|
||||||
|
)
|
||||||
|
assert mask_raw_calls[3] is not None, (
|
||||||
|
"Frame 3 (no detection) should display cached mask from last valid detection, "
|
||||||
|
"not None/blank"
|
||||||
|
)
|
||||||
|
|
||||||
|
# The cached masks should be copies (different objects) to prevent mutation issues
|
||||||
|
if mask_raw_calls[1] is not None and mask_raw_calls[2] is not None:
|
||||||
|
assert mask_raw_calls[1] is not mask_raw_calls[2], (
|
||||||
|
"Cached mask should be a copy, not the same object reference"
|
||||||
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user