refactor(demo): simplify visualizer wiring and typing

Apply Oracle-guided cleanup to make the demo pipeline contract explicit and remove defensive runtime indirection while preserving existing visualization behavior.
This commit is contained in:
2026-02-27 22:15:30 +08:00
parent e90e53ffaf
commit 433e673807
3 changed files with 54 additions and 45 deletions
+1 -5
View File
@@ -1,7 +1,6 @@
from __future__ import annotations
import argparse
import inspect
import logging
import sys
@@ -90,7 +89,6 @@ if __name__ == "__main__":
)
# Build kwargs based on what ScoliosisPipeline accepts
sig = inspect.signature(ScoliosisPipeline.__init__)
pipeline_kwargs = {
"source": args.source,
"checkpoint": args.checkpoint,
@@ -108,10 +106,8 @@ if __name__ == "__main__":
"silhouette_visualize_dir": args.silhouette_visualize_dir,
"result_export_path": args.result_export_path,
"result_export_format": args.result_export_format,
"visualize": args.visualize,
}
if "visualize" in sig.parameters:
pipeline_kwargs["visualize"] = args.visualize
pipeline = ScoliosisPipeline(**pipeline_kwargs)
raise SystemExit(pipeline.run())
except ValueError as err:
+32 -29
View File
@@ -5,7 +5,7 @@ from contextlib import suppress
import logging
from pathlib import Path
import time
from typing import Protocol, cast
from typing import TYPE_CHECKING, Protocol, cast
from beartype import beartype
import click
@@ -22,6 +22,9 @@ from .preprocess import frame_to_person_mask, mask_to_silhouette
from .sconet_demo import ScoNetDemo
from .window import SilhouetteWindow, select_person
if TYPE_CHECKING:
from .visualizer import OpenCVVisualizer
logger = logging.getLogger(__name__)
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
@@ -78,7 +81,7 @@ class ScoliosisPipeline:
_result_export_path: Path | None
_result_export_format: str
_result_buffer: list[dict[str, object]]
_visualizer: object | None
_visualizer: OpenCVVisualizer | None
def __init__(
self,
@@ -342,31 +345,34 @@ class ScoliosisPipeline:
if self._visualizer is not None and viz_payload is not None:
# Cast viz_payload to dict for type checking
viz_dict = cast(dict[str, object], viz_payload)
mask_raw = viz_dict.get("mask_raw")
bbox = viz_dict.get("bbox")
silhouette = viz_dict.get("silhouette")
mask_raw_obj = viz_dict.get("mask_raw")
bbox_obj = viz_dict.get("bbox")
silhouette_obj = viz_dict.get("silhouette")
track_id_val = viz_dict.get("track_id", 0)
track_id = track_id_val if isinstance(track_id_val, int) else 0
label = viz_dict.get("label")
confidence = viz_dict.get("confidence")
label_obj = viz_dict.get("label")
confidence_obj = viz_dict.get("confidence")
# Cast _visualizer to object with update method
visualizer = cast(object, self._visualizer)
update_fn = getattr(visualizer, "update", None)
if callable(update_fn):
keep_running = update_fn(
frame_u8,
bbox,
track_id,
mask_raw,
silhouette,
label,
confidence,
ema_fps,
)
if not keep_running:
logger.info("Visualization closed by user.")
break
# Cast extracted values to expected types
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
bbox = cast(tuple[int, int, int, int] | None, bbox_obj)
silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
label = cast(str | None, label_obj)
confidence = cast(float | None, confidence_obj)
keep_running = self._visualizer.update(
frame_u8,
bbox,
track_id,
mask_raw,
silhouette,
label,
confidence,
ema_fps,
)
if not keep_running:
logger.info("Visualization closed by user.")
break
if frame_count % 100 == 0:
elapsed = time.perf_counter() - start_time
@@ -385,11 +391,8 @@ class ScoliosisPipeline:
# Close visualizer if enabled
if self._visualizer is not None:
visualizer = cast(object, self._visualizer)
close_viz = getattr(visualizer, "close", None)
if callable(close_viz):
with suppress(Exception):
_ = close_viz()
with suppress(Exception):
self._visualizer.close()
# Export silhouettes if requested
if self._silhouette_export_path is not None and self._silhouette_buffer:
+21 -11
View File
@@ -315,20 +315,30 @@ class OpenCVVisualizer:
Returns:
Displayable side-by-side image
"""
# Prepare individual views
raw_view = self._prepare_raw_view(mask_raw)
norm_view = self._prepare_normalized_view(silhouette)
# Convert to grayscale for side-by-side composition
if len(raw_view.shape) == 3:
raw_gray = cast(ImageArray, cv2.cvtColor(raw_view, cv2.COLOR_BGR2GRAY))
# Prepare individual views without mode indicators (will be drawn on combined)
# Raw view preparation (without indicator)
if mask_raw is None:
raw_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
else:
raw_gray = raw_view
if len(mask_raw.shape) == 3:
mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
else:
mask_gray = mask_raw
raw_gray = cast(
ImageArray,
cv2.resize(
mask_gray,
(DISPLAY_WIDTH, DISPLAY_HEIGHT),
interpolation=cv2.INTER_NEAREST,
),
)
if len(norm_view.shape) == 3:
norm_gray = cast(ImageArray, cv2.cvtColor(norm_view, cv2.COLOR_BGR2GRAY))
# Normalized view preparation (without indicator)
if silhouette is None:
norm_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
else:
norm_gray = norm_view
upscaled = self._upscale_silhouette(silhouette)
norm_gray = upscaled
# Stack horizontally
combined = np.hstack([raw_gray, norm_gray])