feat(demo): add realtime visualization pipeline flow
Integrate an opt-in OpenCV visualizer into the demo runtime so operators can monitor tracking, segmentation, and inference confidence in real time without changing the default non-visual execution path.
This commit is contained in:
+117
-2
@@ -1,7 +1,122 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .pipeline import main
|
import argparse
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from .pipeline import ScoliosisPipeline
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
parser = argparse.ArgumentParser(description="Scoliosis Detection Pipeline")
|
||||||
|
parser.add_argument(
|
||||||
|
"--source", type=str, required=True, help="Video source path or camera ID"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint", type=str, required=True, help="Model checkpoint path"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--config",
|
||||||
|
type=str,
|
||||||
|
default="configs/sconet/sconet_scoliosis1k.yaml",
|
||||||
|
help="Config file path",
|
||||||
|
)
|
||||||
|
parser.add_argument("--device", type=str, default="cuda:0", help="Device to run on")
|
||||||
|
parser.add_argument(
|
||||||
|
"--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", help="YOLO model name"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--window", type=int, default=30, help="Window size for classification"
|
||||||
|
)
|
||||||
|
parser.add_argument("--stride", type=int, default=30, help="Stride for window")
|
||||||
|
parser.add_argument(
|
||||||
|
"--nats-url", type=str, default=None, help="NATS URL for result publishing"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--nats-subject", type=str, default="scoliosis.result", help="NATS subject"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-frames", type=int, default=None, help="Maximum frames to process"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--preprocess-only", action="store_true", help="Only preprocess silhouettes"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--silhouette-export-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to export silhouettes",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--silhouette-export-format", type=str, default="pickle", help="Export format"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--silhouette-visualize-dir",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Directory for silhouette visualizations",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--result-export-path", type=str, default=None, help="Path to export results"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--result-export-format", type=str, default="json", help="Result export format"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--visualize", action="store_true", help="Enable real-time visualization"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate preprocess-only mode requires silhouette export path
|
||||||
|
if args.preprocess_only and not args.silhouette_export_path:
|
||||||
|
print(
|
||||||
|
"Error: --silhouette-export-path is required when using --preprocess-only",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
raise SystemExit(2)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from .pipeline import validate_runtime_inputs
|
||||||
|
|
||||||
|
validate_runtime_inputs(
|
||||||
|
source=args.source, checkpoint=args.checkpoint, config=args.config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build kwargs based on what ScoliosisPipeline accepts
|
||||||
|
sig = inspect.signature(ScoliosisPipeline.__init__)
|
||||||
|
pipeline_kwargs = {
|
||||||
|
"source": args.source,
|
||||||
|
"checkpoint": args.checkpoint,
|
||||||
|
"config": args.config,
|
||||||
|
"device": args.device,
|
||||||
|
"yolo_model": args.yolo_model,
|
||||||
|
"window": args.window,
|
||||||
|
"stride": args.stride,
|
||||||
|
"nats_url": args.nats_url,
|
||||||
|
"nats_subject": args.nats_subject,
|
||||||
|
"max_frames": args.max_frames,
|
||||||
|
"preprocess_only": args.preprocess_only,
|
||||||
|
"silhouette_export_path": args.silhouette_export_path,
|
||||||
|
"silhouette_export_format": args.silhouette_export_format,
|
||||||
|
"silhouette_visualize_dir": args.silhouette_visualize_dir,
|
||||||
|
"result_export_path": args.result_export_path,
|
||||||
|
"result_export_format": args.result_export_format,
|
||||||
|
}
|
||||||
|
if "visualize" in sig.parameters:
|
||||||
|
pipeline_kwargs["visualize"] = args.visualize
|
||||||
|
|
||||||
|
pipeline = ScoliosisPipeline(**pipeline_kwargs)
|
||||||
|
raise SystemExit(pipeline.run())
|
||||||
|
except ValueError as err:
|
||||||
|
print(f"Error: {err}", file=sys.stderr)
|
||||||
|
raise SystemExit(2) from err
|
||||||
|
except RuntimeError as err:
|
||||||
|
print(f"Runtime error: {err}", file=sys.stderr)
|
||||||
|
raise SystemExit(1) from err
|
||||||
|
|||||||
+107
-9
@@ -78,6 +78,7 @@ class ScoliosisPipeline:
|
|||||||
_result_export_path: Path | None
|
_result_export_path: Path | None
|
||||||
_result_export_format: str
|
_result_export_format: str
|
||||||
_result_buffer: list[dict[str, object]]
|
_result_buffer: list[dict[str, object]]
|
||||||
|
_visualizer: object | None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -98,6 +99,7 @@ class ScoliosisPipeline:
|
|||||||
silhouette_visualize_dir: str | None = None,
|
silhouette_visualize_dir: str | None = None,
|
||||||
result_export_path: str | None = None,
|
result_export_path: str | None = None,
|
||||||
result_export_format: str = "json",
|
result_export_format: str = "json",
|
||||||
|
visualize: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._detector = YOLO(yolo_model)
|
self._detector = YOLO(yolo_model)
|
||||||
self._source = create_source(source, max_frames=max_frames)
|
self._source = create_source(source, max_frames=max_frames)
|
||||||
@@ -124,6 +126,12 @@ class ScoliosisPipeline:
|
|||||||
)
|
)
|
||||||
self._result_export_format = result_export_format
|
self._result_export_format = result_export_format
|
||||||
self._result_buffer = []
|
self._result_buffer = []
|
||||||
|
if visualize:
|
||||||
|
from .visualizer import OpenCVVisualizer
|
||||||
|
|
||||||
|
self._visualizer = OpenCVVisualizer()
|
||||||
|
else:
|
||||||
|
self._visualizer = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
|
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
|
||||||
@@ -156,7 +164,15 @@ class ScoliosisPipeline:
|
|||||||
def _select_silhouette(
|
def _select_silhouette(
|
||||||
self,
|
self,
|
||||||
result: _DetectionResultsLike,
|
result: _DetectionResultsLike,
|
||||||
) -> tuple[Float[ndarray, "64 44"], int] | None:
|
) -> (
|
||||||
|
tuple[
|
||||||
|
Float[ndarray, "64 44"],
|
||||||
|
UInt8[ndarray, "h w"],
|
||||||
|
tuple[int, int, int, int],
|
||||||
|
int,
|
||||||
|
]
|
||||||
|
| None
|
||||||
|
):
|
||||||
selected = select_person(result)
|
selected = select_person(result)
|
||||||
if selected is not None:
|
if selected is not None:
|
||||||
mask_raw, bbox, track_id = selected
|
mask_raw, bbox, track_id = selected
|
||||||
@@ -165,7 +181,7 @@ class ScoliosisPipeline:
|
|||||||
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox),
|
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox),
|
||||||
)
|
)
|
||||||
if silhouette is not None:
|
if silhouette is not None:
|
||||||
return silhouette, int(track_id)
|
return silhouette, mask_raw, bbox, int(track_id)
|
||||||
|
|
||||||
fallback = cast(
|
fallback = cast(
|
||||||
tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None,
|
tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None,
|
||||||
@@ -181,7 +197,8 @@ class ScoliosisPipeline:
|
|||||||
)
|
)
|
||||||
if silhouette is None:
|
if silhouette is None:
|
||||||
return None
|
return None
|
||||||
return silhouette, 0
|
# For fallback case, mask_raw is the same as mask_u8
|
||||||
|
return silhouette, mask_u8, bbox, 0
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
def process_frame(
|
def process_frame(
|
||||||
@@ -212,7 +229,7 @@ class ScoliosisPipeline:
|
|||||||
if selected is None:
|
if selected is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
silhouette, track_id = selected
|
silhouette, mask_raw, bbox, track_id = selected
|
||||||
|
|
||||||
# Store silhouette for export if in preprocess-only mode or if export requested
|
# Store silhouette for export if in preprocess-only mode or if export requested
|
||||||
if self._silhouette_export_path is not None or self._preprocess_only:
|
if self._silhouette_export_path is not None or self._preprocess_only:
|
||||||
@@ -230,12 +247,28 @@ class ScoliosisPipeline:
|
|||||||
self._visualize_silhouette(silhouette, frame_idx, track_id)
|
self._visualize_silhouette(silhouette, frame_idx, track_id)
|
||||||
|
|
||||||
if self._preprocess_only:
|
if self._preprocess_only:
|
||||||
return None
|
# Return visualization payload for display even in preprocess-only mode
|
||||||
|
return {
|
||||||
|
"mask_raw": mask_raw,
|
||||||
|
"bbox": bbox,
|
||||||
|
"silhouette": silhouette,
|
||||||
|
"track_id": track_id,
|
||||||
|
"label": None,
|
||||||
|
"confidence": None,
|
||||||
|
}
|
||||||
|
|
||||||
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
|
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
|
||||||
|
|
||||||
if not self._window.should_classify():
|
if not self._window.should_classify():
|
||||||
return None
|
# Return visualization payload even when not classifying yet
|
||||||
|
return {
|
||||||
|
"mask_raw": mask_raw,
|
||||||
|
"bbox": bbox,
|
||||||
|
"silhouette": silhouette,
|
||||||
|
"track_id": track_id,
|
||||||
|
"label": None,
|
||||||
|
"confidence": None,
|
||||||
|
}
|
||||||
|
|
||||||
window_tensor = self._window.get_tensor(device=self._device)
|
window_tensor = self._window.get_tensor(device=self._device)
|
||||||
label, confidence = cast(
|
label, confidence = cast(
|
||||||
@@ -259,25 +292,82 @@ class ScoliosisPipeline:
|
|||||||
self._result_buffer.append(result)
|
self._result_buffer.append(result)
|
||||||
|
|
||||||
self._publisher.publish(result)
|
self._publisher.publish(result)
|
||||||
return result
|
# Return result with visualization payload
|
||||||
|
return {
|
||||||
|
"result": result,
|
||||||
|
"mask_raw": mask_raw,
|
||||||
|
"bbox": bbox,
|
||||||
|
"silhouette": silhouette,
|
||||||
|
"track_id": track_id,
|
||||||
|
"label": label,
|
||||||
|
"confidence": confidence,
|
||||||
|
}
|
||||||
|
|
||||||
def run(self) -> int:
|
def run(self) -> int:
|
||||||
frame_count = 0
|
frame_count = 0
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
# EMA FPS state (alpha=0.1 for smoothing)
|
||||||
|
ema_fps = 0.0
|
||||||
|
alpha = 0.1
|
||||||
|
prev_time = start_time
|
||||||
try:
|
try:
|
||||||
for item in self._source:
|
for item in self._source:
|
||||||
frame, metadata = item
|
frame, metadata = item
|
||||||
frame_u8 = np.asarray(frame, dtype=np.uint8)
|
frame_u8 = np.asarray(frame, dtype=np.uint8)
|
||||||
frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
|
frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
|
||||||
frame_count += 1
|
frame_count += 1
|
||||||
|
|
||||||
|
# Compute per-frame EMA FPS
|
||||||
|
curr_time = time.perf_counter()
|
||||||
|
delta = curr_time - prev_time
|
||||||
|
prev_time = curr_time
|
||||||
|
if delta > 0:
|
||||||
|
instant_fps = 1.0 / delta
|
||||||
|
if ema_fps == 0.0:
|
||||||
|
ema_fps = instant_fps
|
||||||
|
else:
|
||||||
|
ema_fps = alpha * instant_fps + (1 - alpha) * ema_fps
|
||||||
|
|
||||||
|
viz_payload = None
|
||||||
try:
|
try:
|
||||||
_ = self.process_frame(frame_u8, metadata)
|
viz_payload = self.process_frame(frame_u8, metadata)
|
||||||
except Exception as frame_error:
|
except Exception as frame_error:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Skipping frame %d due to processing error: %s",
|
"Skipping frame %d due to processing error: %s",
|
||||||
frame_idx,
|
frame_idx,
|
||||||
frame_error,
|
frame_error,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Update visualizer if enabled
|
||||||
|
if self._visualizer is not None and viz_payload is not None:
|
||||||
|
# Cast viz_payload to dict for type checking
|
||||||
|
viz_dict = cast(dict[str, object], viz_payload)
|
||||||
|
mask_raw = viz_dict.get("mask_raw")
|
||||||
|
bbox = viz_dict.get("bbox")
|
||||||
|
silhouette = viz_dict.get("silhouette")
|
||||||
|
track_id_val = viz_dict.get("track_id", 0)
|
||||||
|
track_id = track_id_val if isinstance(track_id_val, int) else 0
|
||||||
|
label = viz_dict.get("label")
|
||||||
|
confidence = viz_dict.get("confidence")
|
||||||
|
|
||||||
|
# Cast _visualizer to object with update method
|
||||||
|
visualizer = cast(object, self._visualizer)
|
||||||
|
update_fn = getattr(visualizer, "update", None)
|
||||||
|
if callable(update_fn):
|
||||||
|
keep_running = update_fn(
|
||||||
|
frame_u8,
|
||||||
|
bbox,
|
||||||
|
track_id,
|
||||||
|
mask_raw,
|
||||||
|
silhouette,
|
||||||
|
label,
|
||||||
|
confidence,
|
||||||
|
ema_fps,
|
||||||
|
)
|
||||||
|
if not keep_running:
|
||||||
|
logger.info("Visualization closed by user.")
|
||||||
|
break
|
||||||
|
|
||||||
if frame_count % 100 == 0:
|
if frame_count % 100 == 0:
|
||||||
elapsed = time.perf_counter() - start_time
|
elapsed = time.perf_counter() - start_time
|
||||||
fps = frame_count / elapsed if elapsed > 0 else 0.0
|
fps = frame_count / elapsed if elapsed > 0 else 0.0
|
||||||
@@ -293,6 +383,14 @@ class ScoliosisPipeline:
|
|||||||
if self._closed:
|
if self._closed:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Close visualizer if enabled
|
||||||
|
if self._visualizer is not None:
|
||||||
|
visualizer = cast(object, self._visualizer)
|
||||||
|
close_viz = getattr(visualizer, "close", None)
|
||||||
|
if callable(close_viz):
|
||||||
|
with suppress(Exception):
|
||||||
|
_ = close_viz()
|
||||||
|
|
||||||
# Export silhouettes if requested
|
# Export silhouettes if requested
|
||||||
if self._silhouette_export_path is not None and self._silhouette_buffer:
|
if self._silhouette_export_path is not None and self._silhouette_buffer:
|
||||||
self._export_silhouettes()
|
self._export_silhouettes()
|
||||||
@@ -504,7 +602,7 @@ def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
|
|||||||
show_default=True,
|
show_default=True,
|
||||||
)
|
)
|
||||||
@click.option("--device", type=str, default="cuda:0", show_default=True)
|
@click.option("--device", type=str, default="cuda:0", show_default=True)
|
||||||
@click.option("--yolo-model", type=str, default="yolo11n-seg.pt", show_default=True)
|
@click.option("--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", show_default=True)
|
||||||
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
|
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
|
||||||
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
|
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
|
||||||
@click.option("--nats-url", type=str, default=None)
|
@click.option("--nats-url", type=str, default=None)
|
||||||
|
|||||||
@@ -0,0 +1,446 @@
|
|||||||
|
"""OpenCV-based visualizer for demo pipeline.
|
||||||
|
|
||||||
|
Provides real-time visualization of detection, segmentation, and classification results
|
||||||
|
with interactive mode switching for mask display.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Window names
|
||||||
|
MAIN_WINDOW = "Scoliosis Detection"
|
||||||
|
SEG_WINDOW = "Segmentation"
|
||||||
|
|
||||||
|
# Silhouette dimensions (from preprocess.py)
|
||||||
|
SIL_HEIGHT = 64
|
||||||
|
SIL_WIDTH = 44
|
||||||
|
|
||||||
|
# Display dimensions for upscaled silhouette
|
||||||
|
DISPLAY_HEIGHT = 256
|
||||||
|
DISPLAY_WIDTH = 176
|
||||||
|
|
||||||
|
# Colors (BGR)
|
||||||
|
COLOR_GREEN = (0, 255, 0)
|
||||||
|
COLOR_WHITE = (255, 255, 255)
|
||||||
|
COLOR_BLACK = (0, 0, 0)
|
||||||
|
COLOR_RED = (0, 0, 255)
|
||||||
|
COLOR_YELLOW = (0, 255, 255)
|
||||||
|
|
||||||
|
# Mode labels
|
||||||
|
MODE_LABELS = ["Both", "Raw Mask", "Normalized"]
|
||||||
|
|
||||||
|
# Type alias for image arrays (NDArray or cv2.Mat)
|
||||||
|
ImageArray = NDArray[np.uint8]
|
||||||
|
|
||||||
|
|
||||||
|
class OpenCVVisualizer:
|
||||||
|
"""Real-time visualizer for gait analysis demo.
|
||||||
|
|
||||||
|
Displays two windows:
|
||||||
|
- Main stream: Original frame with bounding box and metadata overlay
|
||||||
|
- Segmentation: Raw mask, normalized silhouette, or side-by-side view
|
||||||
|
|
||||||
|
Supports interactive mode switching via keyboard.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize visualizer with default mask mode."""
|
||||||
|
self.mask_mode: int = 0 # 0: Both, 1: Raw, 2: Normalized
|
||||||
|
self._windows_created: bool = False
|
||||||
|
|
||||||
|
def _ensure_windows(self) -> None:
|
||||||
|
"""Create OpenCV windows if not already created."""
|
||||||
|
if not self._windows_created:
|
||||||
|
cv2.namedWindow(MAIN_WINDOW, cv2.WINDOW_NORMAL)
|
||||||
|
cv2.namedWindow(SEG_WINDOW, cv2.WINDOW_NORMAL)
|
||||||
|
self._windows_created = True
|
||||||
|
|
||||||
|
def _draw_bbox(
|
||||||
|
self,
|
||||||
|
frame: ImageArray,
|
||||||
|
bbox: tuple[int, int, int, int] | None,
|
||||||
|
) -> None:
|
||||||
|
"""Draw bounding box on frame if present.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: Input frame (H, W, 3) uint8 - modified in place
|
||||||
|
bbox: Bounding box as (x1, y1, x2, y2) or None
|
||||||
|
"""
|
||||||
|
if bbox is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
x1, y1, x2, y2 = bbox
|
||||||
|
# Draw rectangle with green color, thickness 2
|
||||||
|
_ = cv2.rectangle(frame, (x1, y1), (x2, y2), COLOR_GREEN, 2)
|
||||||
|
|
||||||
|
def _draw_text_overlay(
|
||||||
|
self,
|
||||||
|
frame: ImageArray,
|
||||||
|
track_id: int,
|
||||||
|
fps: float,
|
||||||
|
label: str | None,
|
||||||
|
confidence: float | None,
|
||||||
|
) -> None:
|
||||||
|
"""Draw text overlay with track info, FPS, label, and confidence.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: Input frame (H, W, 3) uint8 - modified in place
|
||||||
|
track_id: Tracking ID
|
||||||
|
fps: Current FPS
|
||||||
|
label: Classification label or None
|
||||||
|
confidence: Classification confidence or None
|
||||||
|
"""
|
||||||
|
# Prepare text lines
|
||||||
|
lines: list[str] = []
|
||||||
|
lines.append(f"ID: {track_id}")
|
||||||
|
lines.append(f"FPS: {fps:.1f}")
|
||||||
|
|
||||||
|
if label is not None:
|
||||||
|
if confidence is not None:
|
||||||
|
lines.append(f"{label}: {confidence:.2%}")
|
||||||
|
else:
|
||||||
|
lines.append(label)
|
||||||
|
|
||||||
|
# Draw text with background for readability
|
||||||
|
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||||
|
font_scale = 0.6
|
||||||
|
thickness = 1
|
||||||
|
line_height = 25
|
||||||
|
margin = 10
|
||||||
|
|
||||||
|
for i, text in enumerate(lines):
|
||||||
|
y_pos = margin + (i + 1) * line_height
|
||||||
|
|
||||||
|
# Draw background rectangle
|
||||||
|
(text_width, text_height), _ = cv2.getTextSize(
|
||||||
|
text, font, font_scale, thickness
|
||||||
|
)
|
||||||
|
_ = cv2.rectangle(
|
||||||
|
frame,
|
||||||
|
(margin, y_pos - text_height - 5),
|
||||||
|
(margin + text_width + 10, y_pos + 5),
|
||||||
|
COLOR_BLACK,
|
||||||
|
-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Draw text
|
||||||
|
_ = cv2.putText(
|
||||||
|
frame,
|
||||||
|
text,
|
||||||
|
(margin + 5, y_pos),
|
||||||
|
font,
|
||||||
|
font_scale,
|
||||||
|
COLOR_WHITE,
|
||||||
|
thickness,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare_main_frame(
|
||||||
|
self,
|
||||||
|
frame: ImageArray,
|
||||||
|
bbox: tuple[int, int, int, int] | None,
|
||||||
|
track_id: int,
|
||||||
|
fps: float,
|
||||||
|
label: str | None,
|
||||||
|
confidence: float | None,
|
||||||
|
) -> ImageArray:
|
||||||
|
"""Prepare main display frame with bbox and text overlay.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: Input frame (H, W, C) uint8
|
||||||
|
bbox: Bounding box or None
|
||||||
|
track_id: Tracking ID
|
||||||
|
fps: Current FPS
|
||||||
|
label: Classification label or None
|
||||||
|
confidence: Classification confidence or None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed frame ready for display
|
||||||
|
"""
|
||||||
|
# Ensure BGR format (convert grayscale if needed)
|
||||||
|
if len(frame.shape) == 2:
|
||||||
|
display_frame = cast(ImageArray, cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR))
|
||||||
|
elif frame.shape[2] == 1:
|
||||||
|
display_frame = cast(ImageArray, cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR))
|
||||||
|
elif frame.shape[2] == 3:
|
||||||
|
display_frame = frame.copy()
|
||||||
|
elif frame.shape[2] == 4:
|
||||||
|
display_frame = cast(ImageArray, cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR))
|
||||||
|
else:
|
||||||
|
display_frame = frame.copy()
|
||||||
|
|
||||||
|
# Draw bbox and text (modifies in place)
|
||||||
|
self._draw_bbox(display_frame, bbox)
|
||||||
|
self._draw_text_overlay(display_frame, track_id, fps, label, confidence)
|
||||||
|
|
||||||
|
return display_frame
|
||||||
|
|
||||||
|
def _upscale_silhouette(
|
||||||
|
self,
|
||||||
|
silhouette: NDArray[np.float32] | NDArray[np.uint8],
|
||||||
|
) -> ImageArray:
|
||||||
|
"""Upscale silhouette to display size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
silhouette: Input silhouette (64, 44) float32 [0,1] or uint8 [0,255]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Upscaled silhouette (256, 176) uint8
|
||||||
|
"""
|
||||||
|
# Normalize to uint8 if needed
|
||||||
|
if silhouette.dtype == np.float32 or silhouette.dtype == np.float64:
|
||||||
|
sil_u8 = (silhouette * 255).astype(np.uint8)
|
||||||
|
else:
|
||||||
|
sil_u8 = silhouette.astype(np.uint8)
|
||||||
|
|
||||||
|
# Upscale using nearest neighbor to preserve pixelation
|
||||||
|
upscaled = cast(
|
||||||
|
ImageArray,
|
||||||
|
cv2.resize(
|
||||||
|
sil_u8,
|
||||||
|
(DISPLAY_WIDTH, DISPLAY_HEIGHT),
|
||||||
|
interpolation=cv2.INTER_NEAREST,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return upscaled
|
||||||
|
|
||||||
|
def _prepare_segmentation_view(
|
||||||
|
self,
|
||||||
|
mask_raw: ImageArray | None,
|
||||||
|
silhouette: NDArray[np.float32] | None,
|
||||||
|
) -> ImageArray:
|
||||||
|
"""Prepare segmentation window content based on current mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask_raw: Raw binary mask (H, W) uint8 or None
|
||||||
|
silhouette: Normalized silhouette (64, 44) float32 or None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Displayable image (H, W, 3) uint8
|
||||||
|
"""
|
||||||
|
if self.mask_mode == 0:
|
||||||
|
# Mode 0: Both (side by side)
|
||||||
|
return self._prepare_both_view(mask_raw, silhouette)
|
||||||
|
elif self.mask_mode == 1:
|
||||||
|
# Mode 1: Raw mask only
|
||||||
|
return self._prepare_raw_view(mask_raw)
|
||||||
|
else:
|
||||||
|
# Mode 2: Normalized silhouette only
|
||||||
|
return self._prepare_normalized_view(silhouette)
|
||||||
|
|
||||||
|
def _prepare_raw_view(
|
||||||
|
self,
|
||||||
|
mask_raw: ImageArray | None,
|
||||||
|
) -> ImageArray:
|
||||||
|
"""Prepare raw mask view.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask_raw: Raw binary mask or None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Displayable image with mode indicator
|
||||||
|
"""
|
||||||
|
if mask_raw is None:
|
||||||
|
# Create placeholder
|
||||||
|
placeholder = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH, 3), dtype=np.uint8)
|
||||||
|
self._draw_mode_indicator(placeholder, "Raw Mask (No Data)")
|
||||||
|
return placeholder
|
||||||
|
|
||||||
|
# Ensure single channel
|
||||||
|
if len(mask_raw.shape) == 3:
|
||||||
|
mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
|
||||||
|
else:
|
||||||
|
mask_gray = mask_raw
|
||||||
|
|
||||||
|
# Resize to display size
|
||||||
|
mask_resized = cast(
|
||||||
|
ImageArray,
|
||||||
|
cv2.resize(
|
||||||
|
mask_gray,
|
||||||
|
(DISPLAY_WIDTH, DISPLAY_HEIGHT),
|
||||||
|
interpolation=cv2.INTER_NEAREST,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to BGR for display
|
||||||
|
mask_bgr = cast(ImageArray, cv2.cvtColor(mask_resized, cv2.COLOR_GRAY2BGR))
|
||||||
|
self._draw_mode_indicator(mask_bgr, "Raw Mask")
|
||||||
|
|
||||||
|
return mask_bgr
|
||||||
|
|
||||||
|
def _prepare_normalized_view(
|
||||||
|
self,
|
||||||
|
silhouette: NDArray[np.float32] | None,
|
||||||
|
) -> ImageArray:
|
||||||
|
"""Prepare normalized silhouette view.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
silhouette: Normalized silhouette (64, 44) or None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Displayable image with mode indicator
|
||||||
|
"""
|
||||||
|
if silhouette is None:
|
||||||
|
# Create placeholder
|
||||||
|
placeholder = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH, 3), dtype=np.uint8)
|
||||||
|
self._draw_mode_indicator(placeholder, "Normalized (No Data)")
|
||||||
|
return placeholder
|
||||||
|
|
||||||
|
# Upscale and convert
|
||||||
|
upscaled = self._upscale_silhouette(silhouette)
|
||||||
|
sil_bgr = cast(ImageArray, cv2.cvtColor(upscaled, cv2.COLOR_GRAY2BGR))
|
||||||
|
self._draw_mode_indicator(sil_bgr, "Normalized")
|
||||||
|
|
||||||
|
return sil_bgr
|
||||||
|
|
||||||
|
def _prepare_both_view(
|
||||||
|
self,
|
||||||
|
mask_raw: ImageArray | None,
|
||||||
|
silhouette: NDArray[np.float32] | None,
|
||||||
|
) -> ImageArray:
|
||||||
|
"""Prepare side-by-side view of both masks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask_raw: Raw binary mask or None
|
||||||
|
silhouette: Normalized silhouette or None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Displayable side-by-side image
|
||||||
|
"""
|
||||||
|
# Prepare individual views
|
||||||
|
raw_view = self._prepare_raw_view(mask_raw)
|
||||||
|
norm_view = self._prepare_normalized_view(silhouette)
|
||||||
|
|
||||||
|
# Convert to grayscale for side-by-side composition
|
||||||
|
if len(raw_view.shape) == 3:
|
||||||
|
raw_gray = cast(ImageArray, cv2.cvtColor(raw_view, cv2.COLOR_BGR2GRAY))
|
||||||
|
else:
|
||||||
|
raw_gray = raw_view
|
||||||
|
|
||||||
|
if len(norm_view.shape) == 3:
|
||||||
|
norm_gray = cast(ImageArray, cv2.cvtColor(norm_view, cv2.COLOR_BGR2GRAY))
|
||||||
|
else:
|
||||||
|
norm_gray = norm_view
|
||||||
|
|
||||||
|
# Stack horizontally
|
||||||
|
combined = np.hstack([raw_gray, norm_gray])
|
||||||
|
|
||||||
|
# Convert back to BGR
|
||||||
|
combined_bgr = cast(ImageArray, cv2.cvtColor(combined, cv2.COLOR_GRAY2BGR))
|
||||||
|
|
||||||
|
# Add mode indicator
|
||||||
|
self._draw_mode_indicator(combined_bgr, "Both: Raw | Normalized")
|
||||||
|
|
||||||
|
return combined_bgr
|
||||||
|
|
||||||
|
def _draw_mode_indicator(
|
||||||
|
self,
|
||||||
|
image: ImageArray,
|
||||||
|
label: str,
|
||||||
|
) -> None:
|
||||||
|
"""Draw mode indicator text on image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Image to draw on (modified in place)
|
||||||
|
label: Mode label text
|
||||||
|
"""
|
||||||
|
h, w = image.shape[:2]
|
||||||
|
|
||||||
|
# Mode text at bottom
|
||||||
|
mode_text = f"Mode: {MODE_LABELS[self.mask_mode]} ({self.mask_mode}) - {label}"
|
||||||
|
|
||||||
|
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||||
|
font_scale = 0.5
|
||||||
|
thickness = 1
|
||||||
|
|
||||||
|
# Get text size for background
|
||||||
|
(text_width, text_height), _ = cv2.getTextSize(
|
||||||
|
mode_text, font, font_scale, thickness
|
||||||
|
)
|
||||||
|
|
||||||
|
# Draw background at bottom center
|
||||||
|
x_pos = (w - text_width) // 2
|
||||||
|
y_pos = h - 10
|
||||||
|
|
||||||
|
_ = cv2.rectangle(
|
||||||
|
image,
|
||||||
|
(x_pos - 5, y_pos - text_height - 5),
|
||||||
|
(x_pos + text_width + 5, y_pos + 5),
|
||||||
|
COLOR_BLACK,
|
||||||
|
-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Draw text
|
||||||
|
_ = cv2.putText(
|
||||||
|
image,
|
||||||
|
mode_text,
|
||||||
|
(x_pos, y_pos),
|
||||||
|
font,
|
||||||
|
font_scale,
|
||||||
|
COLOR_YELLOW,
|
||||||
|
thickness,
|
||||||
|
)
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
frame: ImageArray,
|
||||||
|
bbox: tuple[int, int, int, int] | None,
|
||||||
|
track_id: int,
|
||||||
|
mask_raw: ImageArray | None,
|
||||||
|
silhouette: NDArray[np.float32] | None,
|
||||||
|
label: str | None,
|
||||||
|
confidence: float | None,
|
||||||
|
fps: float,
|
||||||
|
) -> bool:
|
||||||
|
"""Update visualization with new frame data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: Input frame (H, W, C) uint8
|
||||||
|
bbox: Bounding box as (x1, y1, x2, y2) or None
|
||||||
|
track_id: Tracking ID
|
||||||
|
mask_raw: Raw binary mask (H, W) uint8 or None
|
||||||
|
silhouette: Normalized silhouette (64, 44) float32 [0,1] or None
|
||||||
|
label: Classification label or None
|
||||||
|
confidence: Classification confidence [0,1] or None
|
||||||
|
fps: Current FPS
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
False if user requested quit (pressed 'q'), True otherwise
|
||||||
|
"""
|
||||||
|
self._ensure_windows()
|
||||||
|
|
||||||
|
# Prepare and show main window
|
||||||
|
main_display = self._prepare_main_frame(
|
||||||
|
frame, bbox, track_id, fps, label, confidence
|
||||||
|
)
|
||||||
|
cv2.imshow(MAIN_WINDOW, main_display)
|
||||||
|
|
||||||
|
# Prepare and show segmentation window
|
||||||
|
seg_display = self._prepare_segmentation_view(mask_raw, silhouette)
|
||||||
|
cv2.imshow(SEG_WINDOW, seg_display)
|
||||||
|
|
||||||
|
# Handle keyboard input
|
||||||
|
key = cv2.waitKey(1) & 0xFF
|
||||||
|
|
||||||
|
if key == ord("q"):
|
||||||
|
return False
|
||||||
|
elif key == ord("m"):
|
||||||
|
# Cycle through modes: 0 -> 1 -> 2 -> 0
|
||||||
|
self.mask_mode = (self.mask_mode + 1) % 3
|
||||||
|
logger.debug("Switched to mask mode: %s", MODE_LABELS[self.mask_mode])
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close all OpenCV windows and cleanup."""
|
||||||
|
if self._windows_created:
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
self._windows_created = False
|
||||||
Reference in New Issue
Block a user