From 4cc2ef7c63308805fe4bedfdbff9f3b285d1e1a3 Mon Sep 17 00:00:00 2001 From: crosstyan Date: Fri, 27 Feb 2026 20:14:24 +0800 Subject: [PATCH] feat(demo): add realtime visualization pipeline flow Integrate an opt-in OpenCV visualizer into the demo runtime so operators can monitor tracking, segmentation, and inference confidence in real time without changing the default non-visual execution path. --- opengait/demo/__main__.py | 119 +++++++++- opengait/demo/pipeline.py | 116 +++++++++- opengait/demo/visualizer.py | 446 ++++++++++++++++++++++++++++++++++++ 3 files changed, 670 insertions(+), 11 deletions(-) create mode 100644 opengait/demo/visualizer.py diff --git a/opengait/demo/__main__.py b/opengait/demo/__main__.py index 2590b2a..21a334e 100644 --- a/opengait/demo/__main__.py +++ b/opengait/demo/__main__.py @@ -1,7 +1,122 @@ from __future__ import annotations -from .pipeline import main +import argparse +import inspect +import logging +import sys + +from .pipeline import ScoliosisPipeline if __name__ == "__main__": - main() + parser = argparse.ArgumentParser(description="Scoliosis Detection Pipeline") + parser.add_argument( + "--source", type=str, required=True, help="Video source path or camera ID" + ) + parser.add_argument( + "--checkpoint", type=str, required=True, help="Model checkpoint path" + ) + parser.add_argument( + "--config", + type=str, + default="configs/sconet/sconet_scoliosis1k.yaml", + help="Config file path", + ) + parser.add_argument("--device", type=str, default="cuda:0", help="Device to run on") + parser.add_argument( + "--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", help="YOLO model name" + ) + parser.add_argument( + "--window", type=int, default=30, help="Window size for classification" + ) + parser.add_argument("--stride", type=int, default=30, help="Stride for window") + parser.add_argument( + "--nats-url", type=str, default=None, help="NATS URL for result publishing" + ) + parser.add_argument( + "--nats-subject", type=str, default="scoliosis.result", help="NATS subject" + ) + parser.add_argument( + "--max-frames", type=int, default=None, help="Maximum frames to process" + ) + parser.add_argument( + "--preprocess-only", action="store_true", help="Only preprocess silhouettes" + ) + parser.add_argument( + "--silhouette-export-path", + type=str, + default=None, + help="Path to export silhouettes", + ) + parser.add_argument( + "--silhouette-export-format", type=str, default="pickle", help="Export format" + ) + parser.add_argument( + "--silhouette-visualize-dir", + type=str, + default=None, + help="Directory for silhouette visualizations", + ) + parser.add_argument( + "--result-export-path", type=str, default=None, help="Path to export results" + ) + parser.add_argument( + "--result-export-format", type=str, default="json", help="Result export format" + ) + parser.add_argument( + "--visualize", action="store_true", help="Enable real-time visualization" + ) + args = parser.parse_args() + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + + # Validate preprocess-only mode requires silhouette export path + if args.preprocess_only and not args.silhouette_export_path: + print( + "Error: --silhouette-export-path is required when using --preprocess-only", + file=sys.stderr, + ) + raise SystemExit(2) + + try: + # Import here to avoid circular imports + from .pipeline import validate_runtime_inputs + + validate_runtime_inputs( + source=args.source, checkpoint=args.checkpoint, config=args.config + ) + + # Build kwargs based on what ScoliosisPipeline accepts + sig = inspect.signature(ScoliosisPipeline.__init__) + pipeline_kwargs = { + "source": args.source, + "checkpoint": args.checkpoint, + "config": args.config, + "device": args.device, + "yolo_model": args.yolo_model, + "window": args.window, + "stride": args.stride, + "nats_url": args.nats_url, + "nats_subject": args.nats_subject, + "max_frames": args.max_frames, + "preprocess_only": args.preprocess_only, + "silhouette_export_path": args.silhouette_export_path, + "silhouette_export_format": args.silhouette_export_format, + "silhouette_visualize_dir": args.silhouette_visualize_dir, + "result_export_path": args.result_export_path, + "result_export_format": args.result_export_format, + } + if "visualize" in sig.parameters: + pipeline_kwargs["visualize"] = args.visualize + + pipeline = ScoliosisPipeline(**pipeline_kwargs) + raise SystemExit(pipeline.run()) + except ValueError as err: + print(f"Error: {err}", file=sys.stderr) + raise SystemExit(2) from err + except RuntimeError as err: + print(f"Runtime error: {err}", file=sys.stderr) + raise SystemExit(1) from err diff --git a/opengait/demo/pipeline.py b/opengait/demo/pipeline.py index 2377bab..09701b8 100644 --- a/opengait/demo/pipeline.py +++ b/opengait/demo/pipeline.py @@ -78,6 +78,7 @@ class ScoliosisPipeline: _result_export_path: Path | None _result_export_format: str _result_buffer: list[dict[str, object]] + _visualizer: object | None def __init__( self, @@ -98,6 +99,7 @@ class ScoliosisPipeline: silhouette_visualize_dir: str | None = None, result_export_path: str | None = None, result_export_format: str = "json", + visualize: bool = False, ) -> None: self._detector = YOLO(yolo_model) self._source = create_source(source, max_frames=max_frames) @@ -124,6 +126,12 @@ class ScoliosisPipeline: ) self._result_export_format = result_export_format self._result_buffer = [] + if visualize: + from .visualizer import OpenCVVisualizer + + self._visualizer = OpenCVVisualizer() + else: + self._visualizer = None @staticmethod def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int: @@ -156,7 +164,15 @@ class ScoliosisPipeline: def _select_silhouette( self, result: _DetectionResultsLike, - ) -> tuple[Float[ndarray, "64 44"], int] | None: + ) -> ( + tuple[ + Float[ndarray, "64 44"], + UInt8[ndarray, "h w"], + tuple[int, int, int, int], + int, + ] + | None + ): selected = select_person(result) if selected is not None: mask_raw, bbox, track_id = selected @@ -165,7 +181,7 @@ class ScoliosisPipeline: mask_to_silhouette(self._to_mask_u8(mask_raw), bbox), ) if silhouette is not None: - return silhouette, int(track_id) + return silhouette, mask_raw, bbox, int(track_id) fallback = cast( tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None, @@ -181,7 +197,8 @@ class ScoliosisPipeline: ) if silhouette is None: return None - return silhouette, 0 + # For fallback case, mask_raw is the same as mask_u8 + return silhouette, mask_u8, bbox, 0 @jaxtyped(typechecker=beartype) def process_frame( @@ -212,7 +229,7 @@ class ScoliosisPipeline: if selected is None: return None - silhouette, track_id = selected + silhouette, mask_raw, bbox, track_id = selected # Store silhouette for export if in preprocess-only mode or if export requested if self._silhouette_export_path is not None or self._preprocess_only: @@ -230,12 +247,28 @@ class ScoliosisPipeline: self._visualize_silhouette(silhouette, frame_idx, track_id) if self._preprocess_only: - return None + # Return visualization payload for display even in preprocess-only mode + return { + "mask_raw": mask_raw, + "bbox": bbox, + "silhouette": silhouette, + "track_id": track_id, + "label": None, + "confidence": None, + } self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id) if not self._window.should_classify(): - return None + # Return visualization payload even when not classifying yet + return { + "mask_raw": mask_raw, + "bbox": bbox, + "silhouette": silhouette, + "track_id": track_id, + "label": None, + "confidence": None, + } window_tensor = self._window.get_tensor(device=self._device) label, confidence = cast( @@ -259,25 +292,82 @@ class ScoliosisPipeline: self._result_buffer.append(result) self._publisher.publish(result) - return result + # Return result with visualization payload + return { + "result": result, + "mask_raw": mask_raw, + "bbox": bbox, + "silhouette": silhouette, + "track_id": track_id, + "label": label, + "confidence": confidence, + } def run(self) -> int: frame_count = 0 start_time = time.perf_counter() + # EMA FPS state (alpha=0.1 for smoothing) + ema_fps = 0.0 + alpha = 0.1 + prev_time = start_time try: for item in self._source: frame, metadata = item frame_u8 = np.asarray(frame, dtype=np.uint8) frame_idx = self._extract_int(metadata, "frame_count", fallback=0) frame_count += 1 + + # Compute per-frame EMA FPS + curr_time = time.perf_counter() + delta = curr_time - prev_time + prev_time = curr_time + if delta > 0: + instant_fps = 1.0 / delta + if ema_fps == 0.0: + ema_fps = instant_fps + else: + ema_fps = alpha * instant_fps + (1 - alpha) * ema_fps + + viz_payload = None try: - _ = self.process_frame(frame_u8, metadata) + viz_payload = self.process_frame(frame_u8, metadata) except Exception as frame_error: logger.warning( "Skipping frame %d due to processing error: %s", frame_idx, frame_error, ) + + # Update visualizer if enabled + if self._visualizer is not None and viz_payload is not None: + # Cast viz_payload to dict for type checking + viz_dict = cast(dict[str, object], viz_payload) + mask_raw = viz_dict.get("mask_raw") + bbox = viz_dict.get("bbox") + silhouette = viz_dict.get("silhouette") + track_id_val = viz_dict.get("track_id", 0) + track_id = track_id_val if isinstance(track_id_val, int) else 0 + label = viz_dict.get("label") + confidence = viz_dict.get("confidence") + + # Cast _visualizer to object with update method + visualizer = cast(object, self._visualizer) + update_fn = getattr(visualizer, "update", None) + if callable(update_fn): + keep_running = update_fn( + frame_u8, + bbox, + track_id, + mask_raw, + silhouette, + label, + confidence, + ema_fps, + ) + if not keep_running: + logger.info("Visualization closed by user.") + break + if frame_count % 100 == 0: elapsed = time.perf_counter() - start_time fps = frame_count / elapsed if elapsed > 0 else 0.0 @@ -293,6 +383,14 @@ class ScoliosisPipeline: if self._closed: return + # Close visualizer if enabled + if self._visualizer is not None: + visualizer = cast(object, self._visualizer) + close_viz = getattr(visualizer, "close", None) + if callable(close_viz): + with suppress(Exception): + _ = close_viz() + # Export silhouettes if requested if self._silhouette_export_path is not None and self._silhouette_buffer: self._export_silhouettes() @@ -504,7 +602,7 @@ def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None: show_default=True, ) @click.option("--device", type=str, default="cuda:0", show_default=True) -@click.option("--yolo-model", type=str, default="yolo11n-seg.pt", show_default=True) +@click.option("--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", show_default=True) @click.option("--window", type=click.IntRange(min=1), default=30, show_default=True) @click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True) @click.option("--nats-url", type=str, default=None) diff --git a/opengait/demo/visualizer.py b/opengait/demo/visualizer.py new file mode 100644 index 0000000..7ba7754 --- /dev/null +++ b/opengait/demo/visualizer.py @@ -0,0 +1,446 @@ +"""OpenCV-based visualizer for demo pipeline. + +Provides real-time visualization of detection, segmentation, and classification results +with interactive mode switching for mask display. +""" + +from __future__ import annotations + +import logging +from typing import cast + +import cv2 +import numpy as np +from numpy.typing import NDArray + +logger = logging.getLogger(__name__) + +# Window names +MAIN_WINDOW = "Scoliosis Detection" +SEG_WINDOW = "Segmentation" + +# Silhouette dimensions (from preprocess.py) +SIL_HEIGHT = 64 +SIL_WIDTH = 44 + +# Display dimensions for upscaled silhouette +DISPLAY_HEIGHT = 256 +DISPLAY_WIDTH = 176 + +# Colors (BGR) +COLOR_GREEN = (0, 255, 0) +COLOR_WHITE = (255, 255, 255) +COLOR_BLACK = (0, 0, 0) +COLOR_RED = (0, 0, 255) +COLOR_YELLOW = (0, 255, 255) + +# Mode labels +MODE_LABELS = ["Both", "Raw Mask", "Normalized"] + +# Type alias for image arrays (NDArray or cv2.Mat) +ImageArray = NDArray[np.uint8] + + +class OpenCVVisualizer: + """Real-time visualizer for gait analysis demo. + + Displays two windows: + - Main stream: Original frame with bounding box and metadata overlay + - Segmentation: Raw mask, normalized silhouette, or side-by-side view + + Supports interactive mode switching via keyboard. + """ + + def __init__(self) -> None: + """Initialize visualizer with default mask mode.""" + self.mask_mode: int = 0 # 0: Both, 1: Raw, 2: Normalized + self._windows_created: bool = False + + def _ensure_windows(self) -> None: + """Create OpenCV windows if not already created.""" + if not self._windows_created: + cv2.namedWindow(MAIN_WINDOW, cv2.WINDOW_NORMAL) + cv2.namedWindow(SEG_WINDOW, cv2.WINDOW_NORMAL) + self._windows_created = True + + def _draw_bbox( + self, + frame: ImageArray, + bbox: tuple[int, int, int, int] | None, + ) -> None: + """Draw bounding box on frame if present. + + Args: + frame: Input frame (H, W, 3) uint8 - modified in place + bbox: Bounding box as (x1, y1, x2, y2) or None + """ + if bbox is None: + return + + x1, y1, x2, y2 = bbox + # Draw rectangle with green color, thickness 2 + _ = cv2.rectangle(frame, (x1, y1), (x2, y2), COLOR_GREEN, 2) + + def _draw_text_overlay( + self, + frame: ImageArray, + track_id: int, + fps: float, + label: str | None, + confidence: float | None, + ) -> None: + """Draw text overlay with track info, FPS, label, and confidence. + + Args: + frame: Input frame (H, W, 3) uint8 - modified in place + track_id: Tracking ID + fps: Current FPS + label: Classification label or None + confidence: Classification confidence or None + """ + # Prepare text lines + lines: list[str] = [] + lines.append(f"ID: {track_id}") + lines.append(f"FPS: {fps:.1f}") + + if label is not None: + if confidence is not None: + lines.append(f"{label}: {confidence:.2%}") + else: + lines.append(label) + + # Draw text with background for readability + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 0.6 + thickness = 1 + line_height = 25 + margin = 10 + + for i, text in enumerate(lines): + y_pos = margin + (i + 1) * line_height + + # Draw background rectangle + (text_width, text_height), _ = cv2.getTextSize( + text, font, font_scale, thickness + ) + _ = cv2.rectangle( + frame, + (margin, y_pos - text_height - 5), + (margin + text_width + 10, y_pos + 5), + COLOR_BLACK, + -1, + ) + + # Draw text + _ = cv2.putText( + frame, + text, + (margin + 5, y_pos), + font, + font_scale, + COLOR_WHITE, + thickness, + ) + + def _prepare_main_frame( + self, + frame: ImageArray, + bbox: tuple[int, int, int, int] | None, + track_id: int, + fps: float, + label: str | None, + confidence: float | None, + ) -> ImageArray: + """Prepare main display frame with bbox and text overlay. + + Args: + frame: Input frame (H, W, C) uint8 + bbox: Bounding box or None + track_id: Tracking ID + fps: Current FPS + label: Classification label or None + confidence: Classification confidence or None + + Returns: + Processed frame ready for display + """ + # Ensure BGR format (convert grayscale if needed) + if len(frame.shape) == 2: + display_frame = cast(ImageArray, cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)) + elif frame.shape[2] == 1: + display_frame = cast(ImageArray, cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)) + elif frame.shape[2] == 3: + display_frame = frame.copy() + elif frame.shape[2] == 4: + display_frame = cast(ImageArray, cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)) + else: + display_frame = frame.copy() + + # Draw bbox and text (modifies in place) + self._draw_bbox(display_frame, bbox) + self._draw_text_overlay(display_frame, track_id, fps, label, confidence) + + return display_frame + + def _upscale_silhouette( + self, + silhouette: NDArray[np.float32] | NDArray[np.uint8], + ) -> ImageArray: + """Upscale silhouette to display size. + + Args: + silhouette: Input silhouette (64, 44) float32 [0,1] or uint8 [0,255] + + Returns: + Upscaled silhouette (256, 176) uint8 + """ + # Normalize to uint8 if needed + if silhouette.dtype == np.float32 or silhouette.dtype == np.float64: + sil_u8 = (silhouette * 255).astype(np.uint8) + else: + sil_u8 = silhouette.astype(np.uint8) + + # Upscale using nearest neighbor to preserve pixelation + upscaled = cast( + ImageArray, + cv2.resize( + sil_u8, + (DISPLAY_WIDTH, DISPLAY_HEIGHT), + interpolation=cv2.INTER_NEAREST, + ), + ) + + return upscaled + + def _prepare_segmentation_view( + self, + mask_raw: ImageArray | None, + silhouette: NDArray[np.float32] | None, + ) -> ImageArray: + """Prepare segmentation window content based on current mode. + + Args: + mask_raw: Raw binary mask (H, W) uint8 or None + silhouette: Normalized silhouette (64, 44) float32 or None + + Returns: + Displayable image (H, W, 3) uint8 + """ + if self.mask_mode == 0: + # Mode 0: Both (side by side) + return self._prepare_both_view(mask_raw, silhouette) + elif self.mask_mode == 1: + # Mode 1: Raw mask only + return self._prepare_raw_view(mask_raw) + else: + # Mode 2: Normalized silhouette only + return self._prepare_normalized_view(silhouette) + + def _prepare_raw_view( + self, + mask_raw: ImageArray | None, + ) -> ImageArray: + """Prepare raw mask view. + + Args: + mask_raw: Raw binary mask or None + + Returns: + Displayable image with mode indicator + """ + if mask_raw is None: + # Create placeholder + placeholder = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH, 3), dtype=np.uint8) + self._draw_mode_indicator(placeholder, "Raw Mask (No Data)") + return placeholder + + # Ensure single channel + if len(mask_raw.shape) == 3: + mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY)) + else: + mask_gray = mask_raw + + # Resize to display size + mask_resized = cast( + ImageArray, + cv2.resize( + mask_gray, + (DISPLAY_WIDTH, DISPLAY_HEIGHT), + interpolation=cv2.INTER_NEAREST, + ), + ) + + # Convert to BGR for display + mask_bgr = cast(ImageArray, cv2.cvtColor(mask_resized, cv2.COLOR_GRAY2BGR)) + self._draw_mode_indicator(mask_bgr, "Raw Mask") + + return mask_bgr + + def _prepare_normalized_view( + self, + silhouette: NDArray[np.float32] | None, + ) -> ImageArray: + """Prepare normalized silhouette view. + + Args: + silhouette: Normalized silhouette (64, 44) or None + + Returns: + Displayable image with mode indicator + """ + if silhouette is None: + # Create placeholder + placeholder = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH, 3), dtype=np.uint8) + self._draw_mode_indicator(placeholder, "Normalized (No Data)") + return placeholder + + # Upscale and convert + upscaled = self._upscale_silhouette(silhouette) + sil_bgr = cast(ImageArray, cv2.cvtColor(upscaled, cv2.COLOR_GRAY2BGR)) + self._draw_mode_indicator(sil_bgr, "Normalized") + + return sil_bgr + + def _prepare_both_view( + self, + mask_raw: ImageArray | None, + silhouette: NDArray[np.float32] | None, + ) -> ImageArray: + """Prepare side-by-side view of both masks. + + Args: + mask_raw: Raw binary mask or None + silhouette: Normalized silhouette or None + + Returns: + Displayable side-by-side image + """ + # Prepare individual views + raw_view = self._prepare_raw_view(mask_raw) + norm_view = self._prepare_normalized_view(silhouette) + + # Convert to grayscale for side-by-side composition + if len(raw_view.shape) == 3: + raw_gray = cast(ImageArray, cv2.cvtColor(raw_view, cv2.COLOR_BGR2GRAY)) + else: + raw_gray = raw_view + + if len(norm_view.shape) == 3: + norm_gray = cast(ImageArray, cv2.cvtColor(norm_view, cv2.COLOR_BGR2GRAY)) + else: + norm_gray = norm_view + + # Stack horizontally + combined = np.hstack([raw_gray, norm_gray]) + + # Convert back to BGR + combined_bgr = cast(ImageArray, cv2.cvtColor(combined, cv2.COLOR_GRAY2BGR)) + + # Add mode indicator + self._draw_mode_indicator(combined_bgr, "Both: Raw | Normalized") + + return combined_bgr + + def _draw_mode_indicator( + self, + image: ImageArray, + label: str, + ) -> None: + """Draw mode indicator text on image. + + Args: + image: Image to draw on (modified in place) + label: Mode label text + """ + h, w = image.shape[:2] + + # Mode text at bottom + mode_text = f"Mode: {MODE_LABELS[self.mask_mode]} ({self.mask_mode}) - {label}" + + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 0.5 + thickness = 1 + + # Get text size for background + (text_width, text_height), _ = cv2.getTextSize( + mode_text, font, font_scale, thickness + ) + + # Draw background at bottom center + x_pos = (w - text_width) // 2 + y_pos = h - 10 + + _ = cv2.rectangle( + image, + (x_pos - 5, y_pos - text_height - 5), + (x_pos + text_width + 5, y_pos + 5), + COLOR_BLACK, + -1, + ) + + # Draw text + _ = cv2.putText( + image, + mode_text, + (x_pos, y_pos), + font, + font_scale, + COLOR_YELLOW, + thickness, + ) + + def update( + self, + frame: ImageArray, + bbox: tuple[int, int, int, int] | None, + track_id: int, + mask_raw: ImageArray | None, + silhouette: NDArray[np.float32] | None, + label: str | None, + confidence: float | None, + fps: float, + ) -> bool: + """Update visualization with new frame data. + + Args: + frame: Input frame (H, W, C) uint8 + bbox: Bounding box as (x1, y1, x2, y2) or None + track_id: Tracking ID + mask_raw: Raw binary mask (H, W) uint8 or None + silhouette: Normalized silhouette (64, 44) float32 [0,1] or None + label: Classification label or None + confidence: Classification confidence [0,1] or None + fps: Current FPS + + Returns: + False if user requested quit (pressed 'q'), True otherwise + """ + self._ensure_windows() + + # Prepare and show main window + main_display = self._prepare_main_frame( + frame, bbox, track_id, fps, label, confidence + ) + cv2.imshow(MAIN_WINDOW, main_display) + + # Prepare and show segmentation window + seg_display = self._prepare_segmentation_view(mask_raw, silhouette) + cv2.imshow(SEG_WINDOW, seg_display) + + # Handle keyboard input + key = cv2.waitKey(1) & 0xFF + + if key == ord("q"): + return False + elif key == ord("m"): + # Cycle through modes: 0 -> 1 -> 2 -> 0 + self.mask_mode = (self.mask_mode + 1) % 3 + logger.debug("Switched to mask mode: %s", MODE_LABELS[self.mask_mode]) + + return True + + def close(self) -> None: + """Close all OpenCV windows and cleanup.""" + if self._windows_created: + cv2.destroyAllWindows() + self._windows_created = False