# Source listing: OpenGait/opengait/demo/visualizer.py (Python, 588 lines, 18 KiB)
"""OpenCV-based visualizer for the demo pipeline.

Provides real-time visualization of detection, segmentation, and
classification results, with keyboard toggles for the raw-mask window
and its debug overlay.
"""
from __future__ import annotations
import logging
from typing import cast
import cv2
import numpy as np
from numpy.typing import NDArray
from .preprocess import BBoxXYXY
logger = logging.getLogger(__name__)

# --- OpenCV window titles ---------------------------------------------------
MAIN_WINDOW = "Scoliosis Detection"
SEG_WINDOW = "Normalized Silhouette"
RAW_WINDOW = "Raw Mask"
WINDOW_SEG_INPUT = "Segmentation Input"

# --- Geometry ---------------------------------------------------------------
# Silhouette size (height, width); mirrors the values used in preprocess.py.
SIL_HEIGHT, SIL_WIDTH = 64, 44
# On-screen silhouette size: a 4x upscale of the silhouette in each dimension.
DISPLAY_HEIGHT, DISPLAY_WIDTH = 4 * SIL_HEIGHT, 4 * SIL_WIDTH
# Pixels reserved at the top of the raw-mask view for the debug statistics.
RAW_STATS_PAD = 54
# Height in pixels of the label bar drawn along the bottom of each view.
MODE_LABEL_PAD = 26

# --- Colours (OpenCV uses BGR channel order) --------------------------------
COLOR_GREEN = (0, 255, 0)
COLOR_WHITE = (255, 255, 255)
COLOR_BLACK = (0, 0, 0)
COLOR_DARK_GRAY = (56, 56, 56)
COLOR_RED = (0, 0, 255)
COLOR_YELLOW = (0, 255, 255)

# uint8 NumPy image array, the form passed to and returned by the cv2 calls.
ImageArray = NDArray[np.uint8]
class OpenCVVisualizer:
    """Real-time OpenCV display for the demo pipeline.

    Windows are created lazily on the first call to :meth:`update`. Each
    ``update`` call shows:

    * the input frame annotated with the bounding box, track ID, FPS and
      classification result (``MAIN_WINDOW``);
    * the normalized silhouette (``SEG_WINDOW``);
    * a tiled grid of the segmentation-input silhouettes
      (``WINDOW_SEG_INPUT``);
    * optionally, the raw mask (``RAW_WINDOW``).

    Keyboard controls polled by ``update``: ``q`` quits (``update`` returns
    ``False``), ``r`` toggles the raw-mask window, ``d`` toggles the
    raw-mask debug statistics overlay.
    """

    def __init__(self) -> None:
        # Runtime toggles, flipped by the 'r' and 'd' keys in update().
        self.show_raw_window: bool = False
        self.show_raw_debug: bool = False
        # Track which OpenCV windows currently exist so they are created
        # once and destroyed only if present.
        self._windows_created: bool = False
        self._raw_window_created: bool = False

    def _ensure_windows(self) -> None:
        """Create the always-visible windows on first use."""
        if not self._windows_created:
            cv2.namedWindow(MAIN_WINDOW, cv2.WINDOW_NORMAL)
            cv2.namedWindow(SEG_WINDOW, cv2.WINDOW_NORMAL)
            cv2.namedWindow(WINDOW_SEG_INPUT, cv2.WINDOW_NORMAL)
            self._windows_created = True

    def _ensure_raw_window(self) -> None:
        """Create the raw-mask window if it does not exist yet."""
        if not self._raw_window_created:
            cv2.namedWindow(RAW_WINDOW, cv2.WINDOW_NORMAL)
            self._raw_window_created = True

    def _hide_raw_window(self) -> None:
        """Destroy the raw-mask window if it was created."""
        if self._raw_window_created:
            cv2.destroyWindow(RAW_WINDOW)
            self._raw_window_created = False

    def _draw_bbox(
        self,
        frame: ImageArray,
        bbox: BBoxXYXY | None,
    ) -> None:
        """Draw a bounding box on the frame if one is present.

        Args:
            frame: Input frame (H, W, 3) uint8 - modified in place
            bbox: Bounding box in XYXY format as (x1, y1, x2, y2) or None
        """
        if bbox is None:
            return
        # OpenCV drawing functions require integer point coordinates; a box
        # with float coordinates would otherwise raise inside cv2.rectangle.
        # Truncation matches the int() conversion in _crop_mask_to_bbox.
        x1, y1, x2, y2 = (int(v) for v in bbox)
        # Green rectangle, thickness 2.
        _ = cv2.rectangle(frame, (x1, y1), (x2, y2), COLOR_GREEN, 2)

    def _draw_text_overlay(
        self,
        frame: ImageArray,
        track_id: int,
        fps: float,
        label: str | None,
        confidence: float | None,
    ) -> None:
        """Draw text overlay with track info, FPS, label, and confidence.

        Args:
            frame: Input frame (H, W, 3) uint8 - modified in place
            track_id: Tracking ID
            fps: Current FPS
            label: Classification label or None
            confidence: Classification confidence or None
        """
        lines: list[str] = [f"ID: {track_id}", f"FPS: {fps:.1f}"]
        if label is not None:
            if confidence is not None:
                lines.append(f"{label}: {confidence:.2%}")
            else:
                lines.append(label)

        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.6
        thickness = 1
        line_height = 25
        margin = 10
        for i, text in enumerate(lines):
            y_pos = margin + (i + 1) * line_height
            # Black background box behind each line for readability.
            (text_width, text_height), _ = cv2.getTextSize(
                text, font, font_scale, thickness
            )
            _ = cv2.rectangle(
                frame,
                (margin, y_pos - text_height - 5),
                (margin + text_width + 10, y_pos + 5),
                COLOR_BLACK,
                -1,
            )
            _ = cv2.putText(
                frame,
                text,
                (margin + 5, y_pos),
                font,
                font_scale,
                COLOR_WHITE,
                thickness,
            )

    def _prepare_main_frame(
        self,
        frame: ImageArray,
        bbox: BBoxXYXY | None,
        track_id: int,
        fps: float,
        label: str | None,
        confidence: float | None,
    ) -> ImageArray:
        """Prepare main display frame with bbox and text overlay.

        Args:
            frame: Input frame (H, W, C) uint8
            bbox: Bounding box in XYXY format (x1, y1, x2, y2) or None
            track_id: Tracking ID
            fps: Current FPS
            label: Classification label or None
            confidence: Classification confidence or None

        Returns:
            A new BGR frame ready for display; ``frame`` is not modified.
        """
        if frame.ndim == 2 or frame.shape[2] == 1:
            display_frame = cast(ImageArray, cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR))
        elif frame.shape[2] == 4:
            display_frame = cast(ImageArray, cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR))
        else:
            # 3-channel (and any other) frames are copied so the in-place
            # drawing below does not touch the caller's array.
            display_frame = frame.copy()
        self._draw_bbox(display_frame, bbox)
        self._draw_text_overlay(display_frame, track_id, fps, label, confidence)
        return display_frame

    def _silhouette_to_u8(
        self,
        silhouette: NDArray[np.float32] | NDArray[np.uint8],
    ) -> ImageArray:
        """Convert a silhouette to uint8 without resizing it.

        Floating-point inputs (any float dtype) are treated as [0, 1]
        intensities: they are clipped to that range and scaled to [0, 255],
        so out-of-range values saturate instead of wrapping around in the
        uint8 cast. Any other dtype is cast to uint8 directly.
        """
        if np.issubdtype(silhouette.dtype, np.floating):
            scaled = np.clip(silhouette, 0.0, 1.0) * 255.0
            return cast(ImageArray, scaled.astype(np.uint8))
        return cast(ImageArray, silhouette.astype(np.uint8))

    def _upscale_silhouette(
        self,
        silhouette: NDArray[np.float32] | NDArray[np.uint8],
    ) -> ImageArray:
        """Upscale silhouette to display size.

        Args:
            silhouette: Input silhouette (64, 44) float [0,1] or uint8 [0,255]

        Returns:
            Upscaled silhouette (DISPLAY_HEIGHT, DISPLAY_WIDTH) uint8
        """
        sil_u8 = self._silhouette_to_u8(silhouette)
        # Nearest neighbour keeps the silhouette's pixel blocks crisp.
        return cast(
            ImageArray,
            cv2.resize(
                sil_u8,
                (DISPLAY_WIDTH, DISPLAY_HEIGHT),
                interpolation=cv2.INTER_NEAREST,
            ),
        )

    def _normalize_mask_for_display(self, mask: NDArray[np.generic]) -> ImageArray:
        """Convert a mask of any dtype to a uint8 image for display.

        * bool: True -> 255, False -> 0.
        * integer with max <= 1 (a 0/1 mask): positive -> 255, otherwise 0.
        * other integer masks: values clipped to [0, 255].
        * floating point: scaled so the maximum maps to 255, negatives
          clipped to 0; an all-zero image if the maximum is <= 0.
        """
        mask_array = np.asarray(mask)
        if mask_array.dtype == np.bool_:
            return cast(
                ImageArray,
                np.where(mask_array, np.uint8(255), np.uint8(0)).astype(np.uint8),
            )
        if np.issubdtype(mask_array.dtype, np.integer):
            max_int = int(np.max(mask_array)) if mask_array.size > 0 else 0
            if max_int <= 1:
                # 0/1 mask: stretch to full white. Thresholding (rather than
                # multiplying by 255) keeps negative values black instead of
                # wrapping around in the uint8 cast.
                return cast(
                    ImageArray,
                    np.where(mask_array > 0, np.uint8(255), np.uint8(0)).astype(
                        np.uint8
                    ),
                )
            if mask_array.dtype == np.uint8:
                # Already in display range; return without copying.
                return cast(ImageArray, mask_array)
            # Clip in float64 so the 255 bound is representable for every
            # integer dtype (e.g. int8) before the uint8 cast.
            clipped = np.clip(mask_array.astype(np.float64), 0.0, 255.0)
            return cast(ImageArray, clipped.astype(np.uint8))
        mask_float = np.asarray(mask_array, dtype=np.float32)
        max_val = float(np.max(mask_float)) if mask_float.size > 0 else 0.0
        if max_val <= 0.0:
            return np.zeros(mask_float.shape, dtype=np.uint8)
        normalized = np.clip((mask_float / max_val) * 255.0, 0.0, 255.0)
        return cast(ImageArray, normalized.astype(np.uint8))

    def _draw_raw_stats(self, image: ImageArray, mask_raw: ImageArray | None) -> None:
        """Draw dtype, min/max and non-zero count of ``mask_raw`` on ``image``.

        Text is drawn in place near the top-left corner; nothing is drawn for
        a missing or empty mask.
        """
        if mask_raw is None:
            return
        mask = np.asarray(mask_raw)
        if mask.size == 0:
            return
        stats = [
            f"raw: {mask.dtype}",
            f"min/max: {float(mask.min()):.3f}/{float(mask.max()):.3f}",
            f"nnz: {int(np.count_nonzero(mask))}",
        ]
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.45
        thickness = 1
        line_h = 18
        x0 = 8
        y0 = 20
        for i, txt in enumerate(stats):
            y = y0 + i * line_h
            (tw, th), _ = cv2.getTextSize(txt, font, font_scale, thickness)
            _ = cv2.rectangle(
                image, (x0 - 4, y - th - 4), (x0 + tw + 4, y + 4), COLOR_BLACK, -1
            )
            _ = cv2.putText(
                image, txt, (x0, y), font, font_scale, COLOR_YELLOW, thickness
            )

    def _prepare_segmentation_view(
        self,
        mask_raw: ImageArray | None,
        silhouette: NDArray[np.float32] | None,
        bbox: BBoxXYXY | None,
    ) -> ImageArray:
        """Build the segmentation window image (the normalized silhouette).

        ``mask_raw`` and ``bbox`` are accepted for interface compatibility
        but are currently unused.
        """
        _ = mask_raw
        _ = bbox
        return self._prepare_normalized_view(silhouette)

    def _fit_gray_to_display(
        self,
        gray: ImageArray,
        out_h: int = DISPLAY_HEIGHT,
        out_w: int = DISPLAY_WIDTH,
    ) -> ImageArray:
        """Letterbox a single-channel image into an (out_h, out_w) canvas.

        The image is resized with nearest-neighbour interpolation, keeping
        its aspect ratio, and centred on a black canvas. A degenerate input
        (zero height or width) yields an all-black canvas.
        """
        src_h, src_w = gray.shape[:2]
        if src_h <= 0 or src_w <= 0:
            return np.zeros((out_h, out_w), dtype=np.uint8)
        # Largest uniform scale that fits both dimensions.
        scale = min(out_w / src_w, out_h / src_h)
        new_w = max(1, int(round(src_w * scale)))
        new_h = max(1, int(round(src_h * scale)))
        resized = cast(
            ImageArray,
            cv2.resize(gray, (new_w, new_h), interpolation=cv2.INTER_NEAREST),
        )
        canvas = np.zeros((out_h, out_w), dtype=np.uint8)
        x0 = (out_w - new_w) // 2
        y0 = (out_h - new_h) // 2
        canvas[y0 : y0 + new_h, x0 : x0 + new_w] = resized
        return cast(ImageArray, canvas)

    def _crop_mask_to_bbox(
        self,
        mask_gray: ImageArray,
        bbox: BBoxXYXY | None,
    ) -> ImageArray:
        """Crop ``mask_gray`` to ``bbox`` clamped to the image bounds.

        Returns the uncropped mask when ``bbox`` is None or the clamped box
        is empty.
        """
        if bbox is None:
            return mask_gray
        h, w = mask_gray.shape[:2]
        x1, y1, x2, y2 = bbox
        x1c = max(0, min(w, int(x1)))
        x2c = max(0, min(w, int(x2)))
        y1c = max(0, min(h, int(y1)))
        y2c = max(0, min(h, int(y2)))
        # A non-empty clamped box always yields a non-empty crop.
        if x2c <= x1c or y2c <= y1c:
            return mask_gray
        return cast(ImageArray, mask_gray[y1c:y2c, x1c:x2c])

    def _prepare_segmentation_input_view(
        self,
        silhouettes: NDArray[np.float32] | None,
    ) -> ImageArray:
        """Tile the segmentation-input silhouettes into one BGR grid image.

        Args:
            silhouettes: Stack of silhouettes indexed along axis 0, e.g.
                (N, 64, 44); a single 2-D silhouette is shown as one tile.
                None or empty input yields a labelled placeholder.

        Returns:
            Grid of DISPLAY_HEIGHT x DISPLAY_WIDTH tiles, roughly square,
            each labelled with its index.
        """
        if silhouettes is None or silhouettes.size == 0:
            placeholder = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH, 3), dtype=np.uint8)
            self._draw_mode_indicator(placeholder, "Input Silhouettes (No Data)")
            return placeholder
        # Treat a lone (H, W) silhouette as a stack of one instead of
        # iterating over its rows.
        if silhouettes.ndim == 2:
            silhouettes = silhouettes[np.newaxis]

        n_frames = int(silhouettes.shape[0])
        tiles_per_row = int(np.ceil(np.sqrt(n_frames)))
        rows = int(np.ceil(n_frames / tiles_per_row))
        tile_h = DISPLAY_HEIGHT
        tile_w = DISPLAY_WIDTH
        grid = np.zeros((rows * tile_h, tiles_per_row * tile_w), dtype=np.uint8)
        origins: list[tuple[int, int]] = []
        for idx in range(n_frames):
            r, c = divmod(idx, tiles_per_row)
            y0, x0 = r * tile_h, c * tile_w
            grid[y0 : y0 + tile_h, x0 : x0 + tile_w] = self._upscale_silhouette(
                silhouettes[idx]
            )
            origins.append((x0, y0))

        # Labels are drawn after the grayscale -> BGR conversion so they
        # keep their colour.
        grid_bgr = cast(ImageArray, cv2.cvtColor(grid, cv2.COLOR_GRAY2BGR))
        for idx, (x0, y0) in enumerate(origins):
            _ = cv2.putText(
                grid_bgr,
                str(idx),
                (x0 + 8, y0 + 22),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                COLOR_YELLOW,
                2,
                cv2.LINE_AA,
            )
        return grid_bgr

    def _prepare_raw_view(
        self,
        mask_raw: ImageArray | None,
        bbox: BBoxXYXY | None = None,
    ) -> ImageArray:
        """Prepare raw mask view.

        Args:
            mask_raw: Raw mask (H, W) or (H, W, C), any dtype, or None
            bbox: Optional box used to crop the mask before display

        Returns:
            Displayable BGR image with mode indicator (and debug statistics
            when ``show_raw_debug`` is set)
        """
        if mask_raw is None:
            placeholder = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH, 3), dtype=np.uint8)
            self._draw_mode_indicator(placeholder, "Raw Mask (No Data)")
            return placeholder

        # Reduce to a single channel. cv2's BGR2GRAY conversion rejects
        # 1-channel input, so (H, W, 1) masks are squeezed instead.
        if mask_raw.ndim == 3 and mask_raw.shape[2] == 1:
            mask_gray = cast(ImageArray, mask_raw[:, :, 0])
        elif mask_raw.ndim == 3:
            mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
        else:
            mask_gray = cast(ImageArray, mask_raw)

        mask_gray = self._normalize_mask_for_display(mask_gray)
        mask_gray = self._crop_mask_to_bbox(mask_gray, bbox)

        # Reserve space for the debug stats (top) and mode label (bottom).
        debug_pad = RAW_STATS_PAD if self.show_raw_debug else 0
        content_h = max(1, DISPLAY_HEIGHT - debug_pad - MODE_LABEL_PAD)
        mask_resized = self._fit_gray_to_display(
            mask_gray, out_h=content_h, out_w=DISPLAY_WIDTH
        )
        full_mask = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
        full_mask[debug_pad : debug_pad + content_h, :] = mask_resized

        mask_bgr = cast(ImageArray, cv2.cvtColor(full_mask, cv2.COLOR_GRAY2BGR))
        if self.show_raw_debug:
            self._draw_raw_stats(mask_bgr, mask_raw)
        self._draw_mode_indicator(mask_bgr, "Raw Mask")
        return mask_bgr

    def _prepare_normalized_view(
        self,
        silhouette: NDArray[np.float32] | None,
    ) -> ImageArray:
        """Prepare normalized silhouette view.

        Args:
            silhouette: Normalized silhouette (64, 44) or None

        Returns:
            Displayable BGR image with mode indicator
        """
        if silhouette is None:
            placeholder = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH, 3), dtype=np.uint8)
            self._draw_mode_indicator(placeholder, "Normalized (No Data)")
            return placeholder

        # Resize the silhouette once, straight into the area above the mode
        # label. Upscaling first and then shrinking would apply
        # nearest-neighbour twice and give uneven pixel blocks.
        sil_u8 = self._silhouette_to_u8(silhouette)
        content_h = max(1, DISPLAY_HEIGHT - MODE_LABEL_PAD)
        sil_compact = self._fit_gray_to_display(
            sil_u8, out_h=content_h, out_w=DISPLAY_WIDTH
        )
        sil_canvas = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
        sil_canvas[:content_h, :] = sil_compact
        sil_bgr = cast(ImageArray, cv2.cvtColor(sil_canvas, cv2.COLOR_GRAY2BGR))
        self._draw_mode_indicator(sil_bgr, "Normalized")
        return sil_bgr

    def _draw_mode_indicator(self, image: ImageArray, label: str) -> None:
        """Draw ``label`` in yellow on a dark bar along the bottom of ``image``.

        The bar is MODE_LABEL_PAD pixels high (clamped to the image) and is
        drawn in place.
        """
        h, w = image.shape[:2]
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        thickness = 1
        (text_width, text_height), _ = cv2.getTextSize(
            label, font, font_scale, thickness
        )
        x_pos = 14
        y_pos = h - 8
        y_top = max(0, h - MODE_LABEL_PAD)
        # Full-width bar, then a box sized to the text.
        _ = cv2.rectangle(image, (0, y_top), (w, h), COLOR_DARK_GRAY, -1)
        _ = cv2.rectangle(
            image,
            (x_pos - 6, y_pos - text_height - 6),
            (x_pos + text_width + 8, y_pos + 6),
            COLOR_DARK_GRAY,
            -1,
        )
        _ = cv2.putText(
            image,
            label,
            (x_pos, y_pos),
            font,
            font_scale,
            COLOR_YELLOW,
            thickness,
        )

    def update(
        self,
        frame: ImageArray,
        bbox: BBoxXYXY | None,
        bbox_mask: BBoxXYXY | None,
        track_id: int,
        mask_raw: ImageArray | None,
        silhouette: NDArray[np.float32] | None,
        segmentation_input: NDArray[np.float32] | None,
        label: str | None,
        confidence: float | None,
        fps: float,
    ) -> bool:
        """Update visualization with new frame data.

        Args:
            frame: Input frame (H, W, C) uint8
            bbox: Bounding box in XYXY format (x1, y1, x2, y2) drawn on the
                main frame, or None
            bbox_mask: Box used to crop ``mask_raw`` in the raw-mask window,
                or None
            track_id: Tracking ID
            mask_raw: Raw mask (e.g. (H, W) uint8) or None
            silhouette: Normalized silhouette (64, 44) float32 [0,1] or None
            segmentation_input: Stack of silhouettes tiled in the
                segmentation-input window, or None
            label: Classification label or None
            confidence: Classification confidence [0,1] or None
            fps: Current FPS

        Returns:
            False if user requested quit (pressed 'q'), True otherwise
        """
        self._ensure_windows()

        main_display = self._prepare_main_frame(
            frame, bbox, track_id, fps, label, confidence
        )
        cv2.imshow(MAIN_WINDOW, main_display)

        seg_display = self._prepare_segmentation_view(mask_raw, silhouette, bbox)
        cv2.imshow(SEG_WINDOW, seg_display)

        if self.show_raw_window:
            self._ensure_raw_window()
            raw_display = self._prepare_raw_view(mask_raw, bbox_mask)
            cv2.imshow(RAW_WINDOW, raw_display)

        seg_input_display = self._prepare_segmentation_input_view(segmentation_input)
        cv2.imshow(WINDOW_SEG_INPUT, seg_input_display)

        # waitKey also pumps the HighGUI event loop; 1 ms keeps it non-blocking.
        key = cv2.waitKey(1) & 0xFF
        if key == ord("q"):
            return False
        if key == ord("r"):
            self.show_raw_window = not self.show_raw_window
            if self.show_raw_window:
                self._ensure_raw_window()
                logger.debug("Raw mask window enabled")
            else:
                self._hide_raw_window()
                logger.debug("Raw mask window disabled")
        elif key == ord("d"):
            self.show_raw_debug = not self.show_raw_debug
            logger.debug(
                "Raw mask debug overlay %s",
                "enabled" if self.show_raw_debug else "disabled",
            )
        return True

    def close(self) -> None:
        """Destroy all windows created by this visualizer (safe to repeat)."""
        if self._windows_created:
            self._hide_raw_window()
            cv2.destroyAllWindows()
            self._windows_created = False
            self._raw_window_created = False