chore: update demo runtime, tests, and agent docs

This commit is contained in:
2026-03-02 12:33:17 +08:00
parent 1f8f959ad7
commit cbb3284c13
14 changed files with 1491 additions and 236 deletions
+53 -22
View File
@@ -3,8 +3,16 @@ from __future__ import annotations
import argparse
import logging
import sys
from typing import cast
from .pipeline import ScoliosisPipeline
from .pipeline import ScoliosisPipeline, WindowMode, resolve_stride
def _positive_float(value: str) -> float:
parsed = float(value)
if parsed <= 0:
raise argparse.ArgumentTypeError("target-fps must be positive")
return parsed
if __name__ == "__main__":
@@ -29,6 +37,24 @@ if __name__ == "__main__":
"--window", type=int, default=30, help="Window size for classification"
)
parser.add_argument("--stride", type=int, default=30, help="Stride for window")
parser.add_argument(
"--target-fps",
type=_positive_float,
default=15.0,
help="Target FPS for temporal downsampling before windowing",
)
parser.add_argument(
"--window-mode",
type=str,
choices=["manual", "sliding", "chunked"],
default="manual",
help="Window scheduling mode: manual uses --stride; sliding uses stride=1; chunked uses stride=window",
)
parser.add_argument(
"--no-target-fps",
action="store_true",
help="Disable temporal downsampling and use all frames",
)
parser.add_argument(
"--nats-url", type=str, default=None, help="NATS URL for result publishing"
)
@@ -88,27 +114,32 @@ if __name__ == "__main__":
source=args.source, checkpoint=args.checkpoint, config=args.config
)
# Build kwargs based on what ScoliosisPipeline accepts
pipeline_kwargs = {
"source": args.source,
"checkpoint": args.checkpoint,
"config": args.config,
"device": args.device,
"yolo_model": args.yolo_model,
"window": args.window,
"stride": args.stride,
"nats_url": args.nats_url,
"nats_subject": args.nats_subject,
"max_frames": args.max_frames,
"preprocess_only": args.preprocess_only,
"silhouette_export_path": args.silhouette_export_path,
"silhouette_export_format": args.silhouette_export_format,
"silhouette_visualize_dir": args.silhouette_visualize_dir,
"result_export_path": args.result_export_path,
"result_export_format": args.result_export_format,
"visualize": args.visualize,
}
pipeline = ScoliosisPipeline(**pipeline_kwargs)
effective_stride = resolve_stride(
window=cast(int, args.window),
stride=cast(int, args.stride),
window_mode=cast(WindowMode, args.window_mode),
)
pipeline = ScoliosisPipeline(
source=cast(str, args.source),
checkpoint=cast(str, args.checkpoint),
config=cast(str, args.config),
device=cast(str, args.device),
yolo_model=cast(str, args.yolo_model),
window=cast(int, args.window),
stride=effective_stride,
target_fps=(None if args.no_target_fps else cast(float, args.target_fps)),
nats_url=cast(str | None, args.nats_url),
nats_subject=cast(str, args.nats_subject),
max_frames=cast(int | None, args.max_frames),
preprocess_only=cast(bool, args.preprocess_only),
silhouette_export_path=cast(str | None, args.silhouette_export_path),
silhouette_export_format=cast(str, args.silhouette_export_format),
silhouette_visualize_dir=cast(str | None, args.silhouette_visualize_dir),
result_export_path=cast(str | None, args.result_export_path),
result_export_format=cast(str, args.result_export_format),
visualize=cast(bool, args.visualize),
)
raise SystemExit(pipeline.run())
except ValueError as err:
print(f"Error: {err}", file=sys.stderr)
+18 -3
View File
@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
# Type alias for frame stream: (frame_array, metadata_dict)
FrameStream = Iterable[tuple[np.ndarray, dict[str, object]]]
# Protocol for cv-mmap metadata (needed at runtime for nested function annotation)
class _FrameMetadata(Protocol):
    """Structural interface for cv-mmap frame metadata objects.

    Declares only the attribute this module actually reads; any object
    exposing it satisfies the protocol (duck typing, no inheritance needed).
    """

    # Frame counter reported by the cv-mmap source.
    frame_count: int
@@ -58,6 +59,13 @@ def opencv_source(
if not cap.isOpened():
raise RuntimeError(f"Failed to open video source: {path}")
is_file_source = isinstance(path, str)
source_fps = float(cap.get(cv2.CAP_PROP_FPS)) if is_file_source else 0.0
fps_valid = source_fps > 0.0 and np.isfinite(source_fps)
fallback_fps = source_fps if fps_valid else 30.0
fallback_interval_ns = int(1_000_000_000 / fallback_fps)
start_ns = time.monotonic_ns()
frame_idx = 0
try:
while max_frames is None or frame_idx < max_frames:
@@ -66,14 +74,22 @@ def opencv_source(
# End of stream
break
# Get timestamp if available (some backends support this)
timestamp_ns = time.monotonic_ns()
if is_file_source:
pos_msec = float(cap.get(cv2.CAP_PROP_POS_MSEC))
if np.isfinite(pos_msec) and pos_msec > 0.0:
timestamp_ns = start_ns + int(pos_msec * 1_000_000)
else:
timestamp_ns = start_ns + frame_idx * fallback_interval_ns
else:
timestamp_ns = time.monotonic_ns()
metadata: dict[str, object] = {
"frame_count": frame_idx,
"timestamp_ns": timestamp_ns,
"source": path,
}
if fps_valid:
metadata["source_fps"] = source_fps
yield frame, metadata
frame_idx += 1
@@ -118,7 +134,6 @@ def cvmmap_source(
# Import cvmmap only when function is called
# Use try/except for runtime import check
try:
from cvmmap import CvMmapClient as _CvMmapClientReal # pyright: ignore[reportMissingTypeStubs]
except ImportError as e:
raise ImportError(
+105 -5
View File
@@ -5,7 +5,7 @@ from contextlib import suppress
import logging
from pathlib import Path
import time
from typing import TYPE_CHECKING, Protocol, cast
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, cast
from beartype import beartype
import click
@@ -31,6 +31,16 @@ JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
WindowMode: TypeAlias = Literal["manual", "sliding", "chunked"]
def resolve_stride(window: int, stride: int, window_mode: WindowMode) -> int:
if window_mode == "manual":
return stride
if window_mode == "sliding":
return 1
return window
class _BoxesLike(Protocol):
@property
@@ -65,6 +75,27 @@ class _TrackCallable(Protocol):
) -> object: ...
class _FramePacer:
_interval_ns: int
_next_emit_ns: int | None
def __init__(self, target_fps: float) -> None:
if target_fps <= 0:
raise ValueError(f"target_fps must be positive, got {target_fps}")
self._interval_ns = int(1_000_000_000 / target_fps)
self._next_emit_ns = None
def should_emit(self, timestamp_ns: int) -> bool:
if self._next_emit_ns is None:
self._next_emit_ns = timestamp_ns + self._interval_ns
return True
if timestamp_ns >= self._next_emit_ns:
while self._next_emit_ns <= timestamp_ns:
self._next_emit_ns += self._interval_ns
return True
return False
class ScoliosisPipeline:
_detector: object
_source: FrameStream
@@ -83,6 +114,7 @@ class ScoliosisPipeline:
_result_buffer: list[DemoResult]
_visualizer: OpenCVVisualizer | None
_last_viz_payload: dict[str, object] | None
_frame_pacer: _FramePacer | None
def __init__(
self,
@@ -104,6 +136,7 @@ class ScoliosisPipeline:
result_export_path: str | None = None,
result_export_format: str = "json",
visualize: bool = False,
target_fps: float | None = 15.0,
) -> None:
self._detector = YOLO(yolo_model)
self._source = create_source(source, max_frames=max_frames)
@@ -140,6 +173,7 @@ class ScoliosisPipeline:
else:
self._visualizer = None
self._last_viz_payload = None
self._frame_pacer = _FramePacer(target_fps) if target_fps is not None else None
@staticmethod
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
@@ -177,6 +211,7 @@ class ScoliosisPipeline:
Float[ndarray, "64 44"],
UInt8[ndarray, "h w"],
BBoxXYXY,
BBoxXYXY,
int,
]
| None
@@ -189,7 +224,7 @@ class ScoliosisPipeline:
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask),
)
if silhouette is not None:
return silhouette, mask_raw, bbox_frame, int(track_id)
return silhouette, mask_raw, bbox_frame, bbox_mask, int(track_id)
fallback = cast(
tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
@@ -231,7 +266,7 @@ class ScoliosisPipeline:
# Fallback: use mask-space bbox if orig_shape unavailable
bbox_frame = bbox_mask
# For fallback case, mask_raw is the same as mask_u8
return silhouette, mask_u8, bbox_frame, 0
return silhouette, mask_u8, bbox_frame, bbox_mask, 0
@jaxtyped(typechecker=beartype)
def process_frame(
@@ -262,7 +297,7 @@ class ScoliosisPipeline:
if selected is None:
return None
silhouette, mask_raw, bbox, track_id = selected
silhouette, mask_raw, bbox, bbox_mask, track_id = selected
# Store silhouette for export if in preprocess-only mode or if export requested
if self._silhouette_export_path is not None or self._preprocess_only:
@@ -284,20 +319,39 @@ class ScoliosisPipeline:
return {
"mask_raw": mask_raw,
"bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette,
"segmentation_input": None,
"track_id": track_id,
"label": None,
"confidence": None,
}
if self._frame_pacer is not None and not self._frame_pacer.should_emit(
timestamp_ns
):
return {
"mask_raw": mask_raw,
"bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette,
"segmentation_input": self._window.buffered_silhouettes,
"track_id": track_id,
"label": None,
"confidence": None,
}
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
segmentation_input = self._window.buffered_silhouettes
if not self._window.should_classify():
# Return visualization payload even when not classifying yet
return {
"mask_raw": mask_raw,
"bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette,
"segmentation_input": segmentation_input,
"track_id": track_id,
"label": None,
"confidence": None,
@@ -330,7 +384,9 @@ class ScoliosisPipeline:
"result": result,
"mask_raw": mask_raw,
"bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette,
"segmentation_input": segmentation_input,
"track_id": track_id,
"label": label,
"confidence": confidence,
@@ -400,7 +456,9 @@ class ScoliosisPipeline:
viz_dict = cast(dict[str, object], viz_data)
mask_raw_obj = viz_dict.get("mask_raw")
bbox_obj = viz_dict.get("bbox")
bbox_mask_obj = viz_dict.get("bbox_mask")
silhouette_obj = viz_dict.get("silhouette")
segmentation_input_obj = viz_dict.get("segmentation_input")
track_id_val = viz_dict.get("track_id", 0)
track_id = track_id_val if isinstance(track_id_val, int) else 0
label_obj = viz_dict.get("label")
@@ -409,24 +467,33 @@ class ScoliosisPipeline:
# Cast extracted values to expected types
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
bbox = cast(BBoxXYXY | None, bbox_obj)
bbox_mask = cast(BBoxXYXY | None, bbox_mask_obj)
silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
segmentation_input = cast(
NDArray[np.float32] | None,
segmentation_input_obj,
)
label = cast(str | None, label_obj)
confidence = cast(float | None, confidence_obj)
else:
# No detection and no cache - use default values
mask_raw = None
bbox = None
bbox_mask = None
track_id = 0
silhouette = None
segmentation_input = None
label = None
confidence = None
keep_running = self._visualizer.update(
frame_u8,
bbox,
bbox_mask,
track_id,
mask_raw,
silhouette,
segmentation_input,
label,
confidence,
ema_fps,
@@ -671,6 +738,23 @@ def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
@click.option(
"--window-mode",
type=click.Choice(["manual", "sliding", "chunked"], case_sensitive=False),
default="manual",
show_default=True,
help=(
"Window scheduling mode: manual uses --stride; "
"sliding forces stride=1; chunked forces stride=window"
),
)
@click.option(
"--target-fps",
type=click.FloatRange(min=0.1),
default=15.0,
show_default=True,
)
@click.option("--no-target-fps", is_flag=True, default=False)
@click.option("--nats-url", type=str, default=None)
@click.option(
"--nats-subject",
@@ -725,6 +809,9 @@ def main(
yolo_model: str,
window: int,
stride: int,
window_mode: str,
target_fps: float | None,
no_target_fps: bool,
nats_url: str | None,
nats_subject: str,
max_frames: int | None,
@@ -748,6 +835,18 @@ def main(
try:
validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
effective_stride = resolve_stride(
window=window,
stride=stride,
window_mode=cast(WindowMode, window_mode.lower()),
)
if effective_stride != stride:
logger.info(
"window_mode=%s overrides stride=%d -> effective_stride=%d",
window_mode,
stride,
effective_stride,
)
pipeline = ScoliosisPipeline(
source=source,
checkpoint=checkpoint,
@@ -755,7 +854,8 @@ def main(
device=device,
yolo_model=yolo_model,
window=window,
stride=stride,
stride=effective_stride,
target_fps=None if no_target_fps else target_fps,
nats_url=nats_url,
nats_subject=nats_subject,
max_frames=max_frames,
+3 -1
View File
@@ -23,8 +23,10 @@ jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
UInt8Array = NDArray[np.uint8]
Float32Array = NDArray[np.float32]
#: Bounding box in XYXY format: (x1, y1, x2, y2) where (x1,y1) is top-left and (x2,y2) is bottom-right.
BBoxXYXY = tuple[int, int, int, int]
"""
Bounding box in XYXY format: (x1, y1, x2, y2) where (x1,y1) is top-left and (x2,y2) is bottom-right.
"""
def _read_attr(container: object, key: str) -> object | None:
+246 -122
View File
@@ -20,7 +20,9 @@ logger = logging.getLogger(__name__)
# Window names
MAIN_WINDOW = "Scoliosis Detection"
SEG_WINDOW = "Segmentation"
SEG_WINDOW = "Normalized Silhouette"
RAW_WINDOW = "Raw Mask"
WINDOW_SEG_INPUT = "Segmentation Input"
# Silhouette dimensions (from preprocess.py)
SIL_HEIGHT = 64
@@ -29,43 +31,45 @@ SIL_WIDTH = 44
# Display dimensions for upscaled silhouette
DISPLAY_HEIGHT = 256
DISPLAY_WIDTH = 176
RAW_STATS_PAD = 54
MODE_LABEL_PAD = 26
# Colors (BGR)
COLOR_GREEN = (0, 255, 0)
COLOR_WHITE = (255, 255, 255)
COLOR_BLACK = (0, 0, 0)
COLOR_DARK_GRAY = (56, 56, 56)
COLOR_RED = (0, 0, 255)
COLOR_YELLOW = (0, 255, 255)
# Mode labels
MODE_LABELS = ["Both", "Raw Mask", "Normalized"]
# Type alias for image arrays (NDArray or cv2.Mat)
ImageArray = NDArray[np.uint8]
class OpenCVVisualizer:
"""Real-time visualizer for gait analysis demo.
Displays two windows:
- Main stream: Original frame with bounding box and metadata overlay
- Segmentation: Raw mask, normalized silhouette, or side-by-side view
Supports interactive mode switching via keyboard.
"""
    def __init__(self) -> None:
        """Initialize visualizer with default mask mode."""
        self.mask_mode: int = 0  # 0: Both, 1: Raw, 2: Normalized ('m' key cycles)
        self.show_raw_window: bool = False  # toggled at runtime with the 'r' key
        self.show_raw_debug: bool = False  # 'd' key: stats overlay on the raw view
        self._windows_created: bool = False  # main/seg windows are created lazily
        self._raw_window_created: bool = False  # raw window has its own lifecycle
    def _ensure_windows(self) -> None:
        """Create OpenCV windows if not already created.

        Creates the main, silhouette, and segmentation-input windows; the
        raw-mask window is managed separately by ``_ensure_raw_window``.
        """
        if not self._windows_created:
            cv2.namedWindow(MAIN_WINDOW, cv2.WINDOW_NORMAL)
            cv2.namedWindow(SEG_WINDOW, cv2.WINDOW_NORMAL)
            cv2.namedWindow(WINDOW_SEG_INPUT, cv2.WINDOW_NORMAL)
            self._windows_created = True
    def _ensure_raw_window(self) -> None:
        """Lazily create the raw-mask window; no-op when it already exists."""
        if not self._raw_window_created:
            cv2.namedWindow(RAW_WINDOW, cv2.WINDOW_NORMAL)
            self._raw_window_created = True
    def _hide_raw_window(self) -> None:
        """Destroy the raw-mask window if present and mark it as closed."""
        if self._raw_window_created:
            cv2.destroyWindow(RAW_WINDOW)
            self._raw_window_created = False
def _draw_bbox(
self,
frame: ImageArray,
@@ -215,33 +219,181 @@ class OpenCVVisualizer:
return upscaled
    def _normalize_mask_for_display(self, mask: NDArray[np.generic]) -> ImageArray:
        """Map a mask of any dtype onto a displayable uint8 [0, 255] image.

        Dispatches on dtype:
        - bool: True -> 255, False -> 0
        - uint8: values in {0, 1} are stretched to {0, 255}; anything with a
          max above 1 is assumed to already be display-scaled and passed through
        - other integers: {0, 1} scaled to {0, 255}; otherwise clipped to [0, 255]
        - floats (fallback): scaled so the max value maps to 255; an all-zero
          (or non-positive-max) mask becomes a black image
        """
        mask_array = np.asarray(mask)
        if mask_array.dtype == np.bool_:
            bool_scaled = np.where(mask_array, np.uint8(255), np.uint8(0)).astype(
                np.uint8
            )
            return cast(ImageArray, bool_scaled)
        if mask_array.dtype == np.uint8:
            mask_array = cast(ImageArray, mask_array)
            # Guard np.max against empty arrays, which would raise.
            max_u8 = int(np.max(mask_array)) if mask_array.size > 0 else 0
            if max_u8 <= 1:
                # A binary 0/1 mask would render nearly black; stretch to 0/255.
                scaled_u8 = np.where(mask_array > 0, np.uint8(255), np.uint8(0)).astype(
                    np.uint8
                )
                return cast(ImageArray, scaled_u8)
            return cast(ImageArray, mask_array)
        if np.issubdtype(mask_array.dtype, np.integer):
            max_int = float(np.max(mask_array)) if mask_array.size > 0 else 0.0
            if max_int <= 1.0:
                return cast(
                    ImageArray, (mask_array.astype(np.float32) * 255.0).astype(np.uint8)
                )
            clipped = np.clip(mask_array, 0, 255).astype(np.uint8)
            return cast(ImageArray, clipped)
        # Float (or other) dtype: normalize by the max so relative intensity
        # is preserved regardless of the input's scale.
        mask_float = np.asarray(mask_array, dtype=np.float32)
        max_val = float(np.max(mask_float)) if mask_float.size > 0 else 0.0
        if max_val <= 0.0:
            return np.zeros(mask_float.shape, dtype=np.uint8)
        normalized = np.clip((mask_float / max_val) * 255.0, 0.0, 255.0).astype(
            np.uint8
        )
        return cast(ImageArray, normalized)
    def _draw_raw_stats(self, image: ImageArray, mask_raw: ImageArray | None) -> None:
        """Overlay dtype/min-max/nonzero-count stats of *mask_raw* onto *image*.

        Draws yellow text on black background boxes near the top-left corner;
        no-op when the mask is missing or empty. *image* is modified in place.
        """
        if mask_raw is None:
            return
        mask = np.asarray(mask_raw)
        if mask.size == 0:
            return
        stats = [
            f"raw: {mask.dtype}",
            f"min/max: {float(mask.min()):.3f}/{float(mask.max()):.3f}",
            f"nnz: {int(np.count_nonzero(mask))}",
        ]
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.45
        thickness = 1
        line_h = 18  # vertical spacing between stat lines, in pixels
        x0 = 8
        y0 = 20
        for i, txt in enumerate(stats):
            y = y0 + i * line_h
            # Measure the text so the background box fits it exactly.
            (tw, th), _ = cv2.getTextSize(txt, font, font_scale, thickness)
            _ = cv2.rectangle(
                image, (x0 - 4, y - th - 4), (x0 + tw + 4, y + 4), COLOR_BLACK, -1
            )
            _ = cv2.putText(
                image, txt, (x0, y), font, font_scale, COLOR_YELLOW, thickness
            )
def _prepare_segmentation_view(
self,
mask_raw: ImageArray | None,
silhouette: NDArray[np.float32] | None,
bbox: BBoxXYXY | None,
) -> ImageArray:
"""Prepare segmentation window content based on current mode.
_ = mask_raw
_ = bbox
return self._prepare_normalized_view(silhouette)
Args:
mask_raw: Raw binary mask (H, W) uint8 or None
silhouette: Normalized silhouette (64, 44) float32 or None
def _fit_gray_to_display(
self,
gray: ImageArray,
out_h: int = DISPLAY_HEIGHT,
out_w: int = DISPLAY_WIDTH,
) -> ImageArray:
src_h, src_w = gray.shape[:2]
if src_h <= 0 or src_w <= 0:
return np.zeros((out_h, out_w), dtype=np.uint8)
Returns:
Displayable image (H, W, 3) uint8
"""
if self.mask_mode == 0:
# Mode 0: Both (side by side)
return self._prepare_both_view(mask_raw, silhouette)
elif self.mask_mode == 1:
# Mode 1: Raw mask only
return self._prepare_raw_view(mask_raw)
else:
# Mode 2: Normalized silhouette only
return self._prepare_normalized_view(silhouette)
scale = min(out_w / src_w, out_h / src_h)
new_w = max(1, int(round(src_w * scale)))
new_h = max(1, int(round(src_h * scale)))
resized = cast(
ImageArray,
cv2.resize(gray, (new_w, new_h), interpolation=cv2.INTER_NEAREST),
)
canvas = np.zeros((out_h, out_w), dtype=np.uint8)
x0 = (out_w - new_w) // 2
y0 = (out_h - new_h) // 2
canvas[y0 : y0 + new_h, x0 : x0 + new_w] = resized
return cast(ImageArray, canvas)
def _crop_mask_to_bbox(
self,
mask_gray: ImageArray,
bbox: BBoxXYXY | None,
) -> ImageArray:
if bbox is None:
return mask_gray
h, w = mask_gray.shape[:2]
x1, y1, x2, y2 = bbox
x1c = max(0, min(w, int(x1)))
x2c = max(0, min(w, int(x2)))
y1c = max(0, min(h, int(y1)))
y2c = max(0, min(h, int(y2)))
if x2c <= x1c or y2c <= y1c:
return mask_gray
cropped = mask_gray[y1c:y2c, x1c:x2c]
if cropped.size == 0:
return mask_gray
return cast(ImageArray, cropped)
    def _prepare_segmentation_input_view(
        self,
        silhouettes: NDArray[np.float32] | None,
    ) -> ImageArray:
        """Render the buffered silhouettes as a numbered tile grid.

        Returns a black placeholder panel labeled "No Data" when nothing is
        buffered yet; otherwise each silhouette is upscaled to the display
        tile size and laid out in a near-square grid with its index overlaid.
        """
        if silhouettes is None or silhouettes.size == 0:
            placeholder = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH, 3), dtype=np.uint8)
            self._draw_mode_indicator(placeholder, "Input Silhouettes (No Data)")
            return placeholder
        n_frames = int(silhouettes.shape[0])
        # Near-square layout: ceil(sqrt(n)) columns, rows fill the remainder.
        tiles_per_row = int(np.ceil(np.sqrt(n_frames)))
        rows = int(np.ceil(n_frames / tiles_per_row))
        tile_h = DISPLAY_HEIGHT
        tile_w = DISPLAY_WIDTH
        grid = np.zeros((rows * tile_h, tiles_per_row * tile_w), dtype=np.uint8)
        for idx in range(n_frames):
            sil = silhouettes[idx]
            tile = self._upscale_silhouette(sil)
            r = idx // tiles_per_row
            c = idx % tiles_per_row
            y0, y1 = r * tile_h, (r + 1) * tile_h
            x0, x1 = c * tile_w, (c + 1) * tile_w
            grid[y0:y1, x0:x1] = tile
        grid_bgr = cast(ImageArray, cv2.cvtColor(grid, cv2.COLOR_GRAY2BGR))
        # Second pass after the BGR conversion so the index labels can be
        # drawn in color on top of the grayscale tiles.
        for idx in range(n_frames):
            r = idx // tiles_per_row
            c = idx % tiles_per_row
            y0 = r * tile_h
            x0 = c * tile_w
            cv2.putText(
                grid_bgr,
                str(idx),
                (x0 + 8, y0 + 22),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                (0, 255, 255),
                2,
                cv2.LINE_AA,
            )
        return grid_bgr
def _prepare_raw_view(
self,
mask_raw: ImageArray | None,
bbox: BBoxXYXY | None = None,
) -> ImageArray:
"""Prepare raw mask view.
@@ -261,20 +413,23 @@ class OpenCVVisualizer:
if len(mask_raw.shape) == 3:
mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
else:
mask_gray = mask_raw
mask_gray = cast(ImageArray, mask_raw)
# Resize to display size
mask_resized = cast(
ImageArray,
cv2.resize(
mask_gray,
(DISPLAY_WIDTH, DISPLAY_HEIGHT),
interpolation=cv2.INTER_NEAREST,
),
mask_gray = self._normalize_mask_for_display(mask_gray)
mask_gray = self._crop_mask_to_bbox(mask_gray, bbox)
debug_pad = RAW_STATS_PAD if self.show_raw_debug else 0
content_h = max(1, DISPLAY_HEIGHT - debug_pad - MODE_LABEL_PAD)
mask_resized = self._fit_gray_to_display(
mask_gray, out_h=content_h, out_w=DISPLAY_WIDTH
)
full_mask = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
full_mask[debug_pad : debug_pad + content_h, :] = mask_resized
# Convert to BGR for display
mask_bgr = cast(ImageArray, cv2.cvtColor(mask_resized, cv2.COLOR_GRAY2BGR))
mask_bgr = cast(ImageArray, cv2.cvtColor(full_mask, cv2.COLOR_GRAY2BGR))
if self.show_raw_debug:
self._draw_raw_stats(mask_bgr, mask_raw)
self._draw_mode_indicator(mask_bgr, "Raw Mask")
return mask_bgr
@@ -299,80 +454,21 @@ class OpenCVVisualizer:
# Upscale and convert
upscaled = self._upscale_silhouette(silhouette)
sil_bgr = cast(ImageArray, cv2.cvtColor(upscaled, cv2.COLOR_GRAY2BGR))
content_h = max(1, DISPLAY_HEIGHT - MODE_LABEL_PAD)
sil_compact = self._fit_gray_to_display(
upscaled, out_h=content_h, out_w=DISPLAY_WIDTH
)
sil_canvas = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
sil_canvas[:content_h, :] = sil_compact
sil_bgr = cast(ImageArray, cv2.cvtColor(sil_canvas, cv2.COLOR_GRAY2BGR))
self._draw_mode_indicator(sil_bgr, "Normalized")
return sil_bgr
def _prepare_both_view(
self,
mask_raw: ImageArray | None,
silhouette: NDArray[np.float32] | None,
) -> ImageArray:
"""Prepare side-by-side view of both masks.
Args:
mask_raw: Raw binary mask or None
silhouette: Normalized silhouette or None
Returns:
Displayable side-by-side image
"""
# Prepare individual views without mode indicators (will be drawn on combined)
# Raw view preparation (without indicator)
if mask_raw is None:
raw_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
else:
if len(mask_raw.shape) == 3:
mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
else:
mask_gray = mask_raw
# Normalize to uint8 [0,255] for display (handles both float [0,1] and uint8 inputs)
if mask_gray.dtype == np.float32 or mask_gray.dtype == np.float64:
mask_gray = (mask_gray * 255).astype(np.uint8)
raw_gray = cast(
ImageArray,
cv2.resize(
mask_gray,
(DISPLAY_WIDTH, DISPLAY_HEIGHT),
interpolation=cv2.INTER_NEAREST,
),
)
# Normalized view preparation (without indicator)
if silhouette is None:
norm_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
else:
upscaled = self._upscale_silhouette(silhouette)
norm_gray = upscaled
# Stack horizontally
combined = np.hstack([raw_gray, norm_gray])
# Convert back to BGR
combined_bgr = cast(ImageArray, cv2.cvtColor(combined, cv2.COLOR_GRAY2BGR))
# Add mode indicator
self._draw_mode_indicator(combined_bgr, "Both: Raw | Normalized")
return combined_bgr
def _draw_mode_indicator(
self,
image: ImageArray,
label: str,
) -> None:
"""Draw mode indicator text on image.
Args:
image: Image to draw on (modified in place)
label: Mode label text
"""
def _draw_mode_indicator(self, image: ImageArray, label: str) -> None:
h, w = image.shape[:2]
# Mode text at bottom
mode_text = f"Mode: {MODE_LABELS[self.mask_mode]} ({self.mask_mode}) - {label}"
mode_text = label
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.5
@@ -383,15 +479,22 @@ class OpenCVVisualizer:
mode_text, font, font_scale, thickness
)
# Draw background at bottom center
x_pos = (w - text_width) // 2
y_pos = h - 10
x_pos = 14
y_pos = h - 8
y_top = max(0, h - MODE_LABEL_PAD)
_ = cv2.rectangle(
image,
(x_pos - 5, y_pos - text_height - 5),
(x_pos + text_width + 5, y_pos + 5),
COLOR_BLACK,
(0, y_top),
(w, h),
COLOR_DARK_GRAY,
-1,
)
_ = cv2.rectangle(
image,
(x_pos - 6, y_pos - text_height - 6),
(x_pos + text_width + 8, y_pos + 6),
COLOR_DARK_GRAY,
-1,
)
@@ -410,9 +513,11 @@ class OpenCVVisualizer:
self,
frame: ImageArray,
bbox: BBoxXYXY | None,
bbox_mask: BBoxXYXY | None,
track_id: int,
mask_raw: ImageArray | None,
silhouette: NDArray[np.float32] | None,
segmentation_input: NDArray[np.float32] | None,
label: str | None,
confidence: float | None,
fps: float,
@@ -441,23 +546,42 @@ class OpenCVVisualizer:
cv2.imshow(MAIN_WINDOW, main_display)
# Prepare and show segmentation window
seg_display = self._prepare_segmentation_view(mask_raw, silhouette)
seg_display = self._prepare_segmentation_view(mask_raw, silhouette, bbox)
cv2.imshow(SEG_WINDOW, seg_display)
if self.show_raw_window:
self._ensure_raw_window()
raw_display = self._prepare_raw_view(mask_raw, bbox_mask)
cv2.imshow(RAW_WINDOW, raw_display)
seg_input_display = self._prepare_segmentation_input_view(segmentation_input)
cv2.imshow(WINDOW_SEG_INPUT, seg_input_display)
# Handle keyboard input
key = cv2.waitKey(1) & 0xFF
if key == ord("q"):
return False
elif key == ord("m"):
# Cycle through modes: 0 -> 1 -> 2 -> 0
self.mask_mode = (self.mask_mode + 1) % 3
logger.debug("Switched to mask mode: %s", MODE_LABELS[self.mask_mode])
elif key == ord("r"):
self.show_raw_window = not self.show_raw_window
if self.show_raw_window:
self._ensure_raw_window()
logger.debug("Raw mask window enabled")
else:
self._hide_raw_window()
logger.debug("Raw mask window disabled")
elif key == ord("d"):
self.show_raw_debug = not self.show_raw_debug
logger.debug(
"Raw mask debug overlay %s",
"enabled" if self.show_raw_debug else "disabled",
)
return True
    def close(self) -> None:
        """Close all OpenCV windows and cleanup.

        Tears down the separately managed raw-mask window first, then
        destroys everything else and resets the creation flags so windows
        can be recreated on a later update.
        """
        if self._windows_created:
            self._hide_raw_window()
            cv2.destroyAllWindows()
            self._windows_created = False
            self._raw_window_created = False
+9
View File
@@ -216,6 +216,15 @@ class SilhouetteWindow:
raise ValueError("Window is empty")
return int(self._frame_indices[0])
@property
def buffered_silhouettes(self) -> Float[ndarray, "n 64 44"]:
if not self._buffer:
return np.empty((0, SIL_HEIGHT, SIL_WIDTH), dtype=np.float32)
return cast(
Float[ndarray, "n 64 44"],
np.stack(list(self._buffer), axis=0).astype(np.float32, copy=True),
)
def _to_numpy(obj: _ArrayLike) -> ndarray:
"""Safely convert array-like object to numpy array.