fix(demo): stabilize visualizer bbox and mask rendering

Align bbox coordinate handling across primary and fallback paths, normalize Both-mode raw mask rendering, and tighten demo result typing to reduce runtime/display inconsistencies.
This commit is contained in:
2026-02-28 18:05:33 +08:00
parent 06a6cd1ccf
commit 7f073179d7
7 changed files with 416 additions and 73 deletions
+10 -9
View File
@@ -7,7 +7,8 @@ Provides generator-based interfaces for video sources:
"""
from collections.abc import AsyncIterator, Generator, Iterable
from typing import TYPE_CHECKING, Protocol, cast
from typing import Protocol, cast
import logging
import numpy as np
@@ -17,15 +18,15 @@ logger = logging.getLogger(__name__)
# Type alias for frame stream: (frame_array, metadata_dict)
FrameStream = Iterable[tuple[np.ndarray, dict[str, object]]]
if TYPE_CHECKING:
# Protocol for cv-mmap metadata to avoid direct import
class _FrameMetadata(Protocol):
frame_count: int
timestamp_ns: int
# Protocol for cv-mmap metadata (needed at runtime for nested function annotation)
class _FrameMetadata(Protocol):
    """Structural type for the per-frame metadata produced by cv-mmap.

    Declared as a Protocol so the cv-mmap package does not have to be
    imported directly. It lives at module level (not under
    ``TYPE_CHECKING``) because the name is needed at runtime for a nested
    function annotation.
    """

    # presumably a running frame counter from the source; verify against cv-mmap
    frame_count: int
    # presumably a capture timestamp in nanoseconds (per the name); TODO confirm
    timestamp_ns: int
# Protocol for cv-mmap client
class _CvMmapClient(Protocol):
def __aiter__(self) -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]: ...
# Protocol for cv-mmap client (needed at runtime for cast)
class _CvMmapClient(Protocol):
    """Structural type for a cv-mmap client that supports ``async for``.

    Defined at module level (not under ``TYPE_CHECKING``) because the name
    is used as the target of a runtime ``cast``.
    """

    def __aiter__(self) -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]:
        """Return an async iterator yielding ``(frame, metadata)`` pairs."""
        ...
def opencv_source(
+24 -9
View File
@@ -14,7 +14,7 @@ import logging
import sys
import threading
import time
from typing import TYPE_CHECKING, Protocol, TextIO, cast, runtime_checkable
from typing import TYPE_CHECKING, Protocol, TextIO, TypedDict, cast, runtime_checkable
if TYPE_CHECKING:
from types import TracebackType
@@ -22,17 +22,31 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class DemoResult(TypedDict):
    """Typed result dictionary for demo pipeline output.

    Contains one classification result together with the frame metadata it
    belongs to. This is the payload type accepted by every
    ``ResultPublisher.publish`` implementation and returned by
    ``create_result``.
    """

    # Frame the result refers to (presumably the source frame index; TODO confirm).
    frame: int
    # Identifier of the tracked person the result refers to.
    track_id: int
    # Classification label.
    label: str
    # Confidence attached to ``label``.
    confidence: float
    # NOTE(review): ``create_result`` accepts ``window: int | tuple[int, int]``;
    # confirm it always normalizes the value to an int before building this dict.
    window: int
    # NOTE(review): ``create_result`` accepts ``timestamp_ns: int | None``;
    # confirm a concrete int is always filled in.
    timestamp_ns: int
@runtime_checkable
class ResultPublisher(Protocol):
"""Protocol for result publishers."""
def publish(self, result: dict[str, object]) -> None:
def publish(self, result: DemoResult) -> None:
"""
Publish a result dictionary.
Parameters
----------
result : dict[str, object]
result : DemoResult
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
"""
...
@@ -54,13 +68,13 @@ class ConsolePublisher:
"""
self._output = output
def publish(self, result: dict[str, object]) -> None:
def publish(self, result: DemoResult) -> None:
"""
Publish result as JSON line.
Parameters
----------
result : dict[str, object]
result : DemoResult
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
"""
try:
@@ -214,13 +228,13 @@ class NatsPublisher:
logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}")
return False
def publish(self, result: dict[str, object]) -> None:
def publish(self, result: DemoResult) -> None:
"""
Publish result to NATS subject.
Parameters
----------
result : dict[str, object]
result : DemoResult
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
"""
if not self._ensure_connected():
@@ -239,6 +253,7 @@ class NatsPublisher:
).encode("utf-8")
_ = await self._nc.publish(self._subject, payload)
_ = await self._nc.flush()
# Run publish in background loop
future = asyncio.run_coroutine_threadsafe(
_publish(),
@@ -331,7 +346,7 @@ def create_result(
confidence: float,
window: int | tuple[int, int],
timestamp_ns: int | None = None,
) -> dict[str, object]:
) -> DemoResult:
"""
Create a standardized result dictionary.
@@ -353,7 +368,7 @@ def create_result(
Returns
-------
dict[str, object]
DemoResult
Standardized result dictionary
"""
return {
+77 -27
View File
@@ -17,8 +17,8 @@ from numpy.typing import NDArray
from ultralytics.models.yolo.model import YOLO
from .input import FrameStream, create_source
from .output import ResultPublisher, create_publisher, create_result
from .preprocess import frame_to_person_mask, mask_to_silhouette
from .output import DemoResult, ResultPublisher, create_publisher, create_result
from .preprocess import BBoxXYXY, frame_to_person_mask, mask_to_silhouette
from .sconet_demo import ScoNetDemo
from .window import SilhouetteWindow, select_person
@@ -53,6 +53,7 @@ class _DetectionResultsLike(Protocol):
def masks(self) -> _MasksLike: ...
class _TrackCallable(Protocol):
def __call__(
self,
@@ -80,8 +81,9 @@ class ScoliosisPipeline:
_silhouette_visualize_dir: Path | None
_result_export_path: Path | None
_result_export_format: str
_result_buffer: list[dict[str, object]]
_result_buffer: list[DemoResult]
_visualizer: OpenCVVisualizer | None
_last_viz_payload: dict[str, object] | None
def __init__(
self,
@@ -135,6 +137,7 @@ class ScoliosisPipeline:
self._visualizer = OpenCVVisualizer()
else:
self._visualizer = None
self._last_viz_payload = None
@staticmethod
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
@@ -171,37 +174,59 @@ class ScoliosisPipeline:
tuple[
Float[ndarray, "64 44"],
UInt8[ndarray, "h w"],
tuple[int, int, int, int],
BBoxXYXY,
int,
]
| None
):
selected = select_person(result)
if selected is not None:
mask_raw, bbox, track_id = selected
mask_raw, bbox_mask, bbox_frame, track_id = selected
silhouette = cast(
Float[ndarray, "64 44"] | None,
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox),
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask),
)
if silhouette is not None:
return silhouette, mask_raw, bbox, int(track_id)
return silhouette, mask_raw, bbox_frame, int(track_id)
fallback = cast(
tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None,
tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
frame_to_person_mask(result),
)
if fallback is None:
return None
mask_u8, bbox = fallback
mask_u8, bbox_mask = fallback
silhouette = cast(
Float[ndarray, "64 44"] | None,
mask_to_silhouette(mask_u8, bbox),
mask_to_silhouette(mask_u8, bbox_mask),
)
if silhouette is None:
return None
# Convert mask-space bbox to frame-space for visualization
# Use result.orig_shape to get frame dimensions safely
orig_shape = getattr(result, "orig_shape", None)
if orig_shape is not None and isinstance(orig_shape, (tuple, list)) and len(orig_shape) >= 2:
frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
mask_h, mask_w = mask_u8.shape[0], mask_u8.shape[1]
if mask_w > 0 and mask_h > 0 and frame_w > 0 and frame_h > 0:
scale_x = frame_w / mask_w
scale_y = frame_h / mask_h
bbox_frame = (
int(bbox_mask[0] * scale_x),
int(bbox_mask[1] * scale_y),
int(bbox_mask[2] * scale_x),
int(bbox_mask[3] * scale_y),
)
else:
# Fallback: use mask-space bbox if dimensions invalid
bbox_frame = bbox_mask
else:
# Fallback: use mask-space bbox if orig_shape unavailable
bbox_frame = bbox_mask
# For fallback case, mask_raw is the same as mask_u8
return silhouette, mask_u8, bbox, 0
return silhouette, mask_u8, bbox_frame, 0
@jaxtyped(typechecker=beartype)
def process_frame(
@@ -342,23 +367,48 @@ class ScoliosisPipeline:
)
# Update visualizer if enabled
if self._visualizer is not None and viz_payload is not None:
# Cast viz_payload to dict for type checking
viz_dict = cast(dict[str, object], viz_payload)
mask_raw_obj = viz_dict.get("mask_raw")
bbox_obj = viz_dict.get("bbox")
silhouette_obj = viz_dict.get("silhouette")
track_id_val = viz_dict.get("track_id", 0)
track_id = track_id_val if isinstance(track_id_val, int) else 0
label_obj = viz_dict.get("label")
confidence_obj = viz_dict.get("confidence")
if self._visualizer is not None:
# Cache valid payload for no-detection frames
if viz_payload is not None:
# Cache a copy to prevent mutation of original data
viz_payload_dict = cast(dict[str, object], viz_payload)
cached: dict[str, object] = {}
for k, v in viz_payload_dict.items():
copy_method = cast(Callable[[], object] | None, getattr(v, "copy", None))
if copy_method is not None:
cached[k] = copy_method()
else:
cached[k] = v
self._last_viz_payload = cached
# Use cached payload if current is None
viz_data = viz_payload if viz_payload is not None else self._last_viz_payload
if viz_data is not None:
# Cast viz_payload to dict for type checking
viz_dict = cast(dict[str, object], viz_data)
mask_raw_obj = viz_dict.get("mask_raw")
bbox_obj = viz_dict.get("bbox")
silhouette_obj = viz_dict.get("silhouette")
track_id_val = viz_dict.get("track_id", 0)
track_id = track_id_val if isinstance(track_id_val, int) else 0
label_obj = viz_dict.get("label")
confidence_obj = viz_dict.get("confidence")
# Cast extracted values to expected types
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
bbox = cast(tuple[int, int, int, int] | None, bbox_obj)
silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
label = cast(str | None, label_obj)
confidence = cast(float | None, confidence_obj)
# Cast extracted values to expected types
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
bbox = cast(BBoxXYXY | None, bbox_obj)
silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
label = cast(str | None, label_obj)
confidence = cast(float | None, confidence_obj)
else:
# No detection and no cache - use default values
mask_raw = None
bbox = None
track_id = 0
silhouette = None
label = None
confidence = None
keep_running = self._visualizer.update(
frame_u8,
+47 -8
View File
@@ -23,6 +23,9 @@ jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
UInt8Array = NDArray[np.uint8]
Float32Array = NDArray[np.float32]
#: Bounding box in XYXY format: (x1, y1, x2, y2), where (x1, y1) is the
#: top-left corner and (x2, y2) is the bottom-right corner.
#: NOTE(review): the alias does not encode the coordinate space (mask vs.
#: frame); callers should make that explicit in variable names.
BBoxXYXY = tuple[int, int, int, int]
def _read_attr(container: object, key: str) -> object | None:
if isinstance(container, dict):
@@ -59,7 +62,15 @@ def _to_numpy_array(value: object) -> NDArray[np.generic]:
return cast(NDArray[np.generic], np.asarray(current))
def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> tuple[int, int, int, int] | None:
def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> BBoxXYXY | None:
"""Extract bounding box from binary mask in XYXY format.
Args:
mask: Binary mask array of shape (H, W) with dtype uint8.
Returns:
Bounding box as (x1, y1, x2, y2) in XYXY format, or None if mask is empty.
"""
mask_u8 = np.asarray(mask, dtype=np.uint8)
coords = np.argwhere(mask_u8 > 0)
if int(coords.size) == 0:
@@ -76,9 +87,17 @@ def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> tuple[int, int, int, int] |
return (x1, y1, x2, y2)
def _sanitize_bbox(
bbox: tuple[int, int, int, int], height: int, width: int
) -> tuple[int, int, int, int] | None:
def _sanitize_bbox(bbox: BBoxXYXY, height: int, width: int) -> BBoxXYXY | None:
"""Sanitize bounding box to ensure it's within image bounds.
Args:
bbox: Bounding box in XYXY format (x1, y1, x2, y2).
height: Image height.
width: Image width.
Returns:
Sanitized bounding box in XYXY format, or None if invalid.
"""
x1, y1, x2, y2 = bbox
x1c = max(0, min(int(x1), width - 1))
y1c = max(0, min(int(y1), height - 1))
@@ -92,7 +111,17 @@ def _sanitize_bbox(
@jaxtyped(typechecker=beartype)
def frame_to_person_mask(
result: object, min_area: int = MIN_MASK_AREA
) -> tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None:
) -> tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None:
"""Extract person mask and bounding box from detection result.
Args:
result: Detection results object with boxes and masks attributes.
min_area: Minimum mask area to consider valid.
Returns:
Tuple of (mask, bbox) where bbox is in XYXY format (x1, y1, x2, y2),
or None if no valid detections.
"""
masks_obj = _read_attr(result, "masks")
if masks_obj is None:
return None
@@ -152,7 +181,7 @@ def frame_to_person_mask(
best_area = -1
best_mask: UInt8[ndarray, "h w"] | None = None
best_bbox: tuple[int, int, int, int] | None = None
best_bbox: BBoxXYXY | None = None
for idx in range(mask_count):
mask_float = np.asarray(masks_float[idx], dtype=np.float32)
@@ -167,7 +196,7 @@ def frame_to_person_mask(
if area < min_area:
continue
bbox: tuple[int, int, int, int] | None = None
bbox: BBoxXYXY | None = None
shape_2d = cast(tuple[int, int], mask_binary.shape)
h = int(shape_2d[0])
w = int(shape_2d[1])
@@ -204,8 +233,18 @@ def frame_to_person_mask(
@jaxtyped(typechecker=beartype)
def mask_to_silhouette(
mask: UInt8[ndarray, "h w"],
bbox: tuple[int, int, int, int],
bbox: BBoxXYXY,
) -> Float[ndarray, "64 44"] | None:
"""Convert mask to standardized silhouette using bounding box.
Args:
mask: Binary mask array of shape (H, W) with dtype uint8.
bbox: Bounding box in XYXY format (x1, y1, x2, y2).
Returns:
Standardized silhouette array of shape (64, 44) with dtype float32,
or None if conversion fails.
"""
mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
if int(np.count_nonzero(mask_u8)) < MIN_MASK_AREA:
return None
+13 -6
View File
@@ -13,8 +13,11 @@ import cv2
import numpy as np
from numpy.typing import NDArray
from .preprocess import BBoxXYXY
logger = logging.getLogger(__name__)
# Window names
MAIN_WINDOW = "Scoliosis Detection"
SEG_WINDOW = "Segmentation"
@@ -66,13 +69,13 @@ class OpenCVVisualizer:
def _draw_bbox(
self,
frame: ImageArray,
bbox: tuple[int, int, int, int] | None,
bbox: BBoxXYXY | None,
) -> None:
"""Draw bounding box on frame if present.
Args:
frame: Input frame (H, W, 3) uint8 - modified in place
bbox: Bounding box as (x1, y1, x2, y2) or None
bbox: Bounding box in XYXY format as (x1, y1, x2, y2) or None
"""
if bbox is None:
return
@@ -145,7 +148,7 @@ class OpenCVVisualizer:
def _prepare_main_frame(
self,
frame: ImageArray,
bbox: tuple[int, int, int, int] | None,
bbox: BBoxXYXY | None,
track_id: int,
fps: float,
label: str | None,
@@ -155,7 +158,7 @@ class OpenCVVisualizer:
Args:
frame: Input frame (H, W, C) uint8
bbox: Bounding box or None
bbox: Bounding box in XYXY format (x1, y1, x2, y2) or None
track_id: Tracking ID
fps: Current FPS
label: Classification label or None
@@ -324,6 +327,9 @@ class OpenCVVisualizer:
mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
else:
mask_gray = mask_raw
# Normalize to uint8 [0,255] for display (handles both float [0,1] and uint8 inputs)
if mask_gray.dtype == np.float32 or mask_gray.dtype == np.float64:
mask_gray = (mask_gray * 255).astype(np.uint8)
raw_gray = cast(
ImageArray,
cv2.resize(
@@ -333,6 +339,7 @@ class OpenCVVisualizer:
),
)
# Normalized view preparation (without indicator)
if silhouette is None:
norm_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
@@ -402,7 +409,7 @@ class OpenCVVisualizer:
def update(
self,
frame: ImageArray,
bbox: tuple[int, int, int, int] | None,
bbox: BBoxXYXY | None,
track_id: int,
mask_raw: ImageArray | None,
silhouette: NDArray[np.float32] | None,
@@ -414,7 +421,7 @@ class OpenCVVisualizer:
Args:
frame: Input frame (H, W, C) uint8
bbox: Bounding box as (x1, y1, x2, y2) or None
bbox: Bounding box in XYXY format (x1, y1, x2, y2) or None
track_id: Tracking ID
mask_raw: Raw binary mask (H, W) uint8 or None
silhouette: Normalized silhouette (64, 44) float32 [0,1] or None
+22 -10
View File
@@ -12,10 +12,11 @@ import torch
from jaxtyping import Float
from numpy import ndarray
from .preprocess import BBoxXYXY
if TYPE_CHECKING:
from numpy.typing import NDArray
# Silhouette dimensions from preprocess.py
SIL_HEIGHT: int = 64
SIL_WIDTH: int = 44
@@ -239,19 +240,23 @@ def _to_numpy(obj: _ArrayLike) -> ndarray:
def select_person(
results: _DetectionResults,
) -> tuple[ndarray, tuple[int, int, int, int], int] | None:
) -> tuple[ndarray, BBoxXYXY, BBoxXYXY, int] | None:
"""Select the person with largest bounding box from detection results.
Args:
results: Detection results object with boxes and masks attributes.
Expected to have:
- boxes.xyxy: array of bounding boxes [N, 4]
- masks.data: array of masks [N, H, W]
- boxes.xyxy: array of bounding boxes [N, 4] in frame coordinates (XYXY format)
- masks.data: array of masks [N, H, W] in mask coordinates
- boxes.id: optional track IDs [N]
Returns:
Tuple of (mask, bbox, track_id) for the largest person,
Tuple of (mask, bbox_mask, bbox_frame, track_id) for the largest person,
or None if no valid detections or track IDs unavailable.
- mask: the person's segmentation mask
- bbox_mask: bounding box in mask coordinate space (XYXY format: x1, y1, x2, y2)
- bbox_frame: bounding box in frame coordinate space (XYXY format: x1, y1, x2, y2)
- track_id: the person's track ID
"""
# Check for track IDs
boxes_obj: _Boxes | object = getattr(results, "boxes", None)
@@ -329,20 +334,27 @@ def select_person(
# Scale bbox from frame space to mask space
scale_x = mask_w / frame_w if frame_w > 0 else 1.0
scale_y = mask_h / frame_h if frame_h > 0 else 1.0
bbox = (
bbox_mask = (
int(float(bboxes[best_idx][0]) * scale_x),
int(float(bboxes[best_idx][1]) * scale_y),
int(float(bboxes[best_idx][2]) * scale_x),
int(float(bboxes[best_idx][3]) * scale_y),
)
else:
# Fallback: use bbox as-is (assume same coordinate space)
bbox = (
bbox_frame = (
int(float(bboxes[best_idx][0])),
int(float(bboxes[best_idx][1])),
int(float(bboxes[best_idx][2])),
int(float(bboxes[best_idx][3])),
)
else:
# Fallback: use bbox as-is for both (assume same coordinate space)
bbox_mask = (
int(float(bboxes[best_idx][0])),
int(float(bboxes[best_idx][1])),
int(float(bboxes[best_idx][2])),
int(float(bboxes[best_idx][3])),
)
bbox_frame = bbox_mask
track_id = int(track_ids[best_idx]) if best_idx < len(track_ids) else best_idx
return mask, bbox, track_id
return mask, bbox_mask, bbox_frame, track_id