refactor(demo): simplify visualizer wiring and typing

Apply Oracle-guided cleanup to make the demo pipeline contract explicit and remove defensive runtime indirection while preserving existing visualization behavior.
This commit is contained in:
2026-02-27 22:15:30 +08:00
parent e90e53ffaf
commit 433e673807
3 changed files with 54 additions and 45 deletions
+1 -5
View File
@@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import inspect
import logging import logging
import sys import sys
@@ -90,7 +89,6 @@ if __name__ == "__main__":
) )
# Build kwargs based on what ScoliosisPipeline accepts # Build kwargs based on what ScoliosisPipeline accepts
sig = inspect.signature(ScoliosisPipeline.__init__)
pipeline_kwargs = { pipeline_kwargs = {
"source": args.source, "source": args.source,
"checkpoint": args.checkpoint, "checkpoint": args.checkpoint,
@@ -108,10 +106,8 @@ if __name__ == "__main__":
"silhouette_visualize_dir": args.silhouette_visualize_dir, "silhouette_visualize_dir": args.silhouette_visualize_dir,
"result_export_path": args.result_export_path, "result_export_path": args.result_export_path,
"result_export_format": args.result_export_format, "result_export_format": args.result_export_format,
"visualize": args.visualize,
} }
if "visualize" in sig.parameters:
pipeline_kwargs["visualize"] = args.visualize
pipeline = ScoliosisPipeline(**pipeline_kwargs) pipeline = ScoliosisPipeline(**pipeline_kwargs)
raise SystemExit(pipeline.run()) raise SystemExit(pipeline.run())
except ValueError as err: except ValueError as err:
+32 -29
View File
@@ -5,7 +5,7 @@ from contextlib import suppress
import logging import logging
from pathlib import Path from pathlib import Path
import time import time
from typing import Protocol, cast from typing import TYPE_CHECKING, Protocol, cast
from beartype import beartype from beartype import beartype
import click import click
@@ -22,6 +22,9 @@ from .preprocess import frame_to_person_mask, mask_to_silhouette
from .sconet_demo import ScoNetDemo from .sconet_demo import ScoNetDemo
from .window import SilhouetteWindow, select_person from .window import SilhouetteWindow, select_person
if TYPE_CHECKING:
from .visualizer import OpenCVVisualizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]] JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
@@ -78,7 +81,7 @@ class ScoliosisPipeline:
_result_export_path: Path | None _result_export_path: Path | None
_result_export_format: str _result_export_format: str
_result_buffer: list[dict[str, object]] _result_buffer: list[dict[str, object]]
_visualizer: object | None _visualizer: OpenCVVisualizer | None
def __init__( def __init__(
self, self,
@@ -342,31 +345,34 @@ class ScoliosisPipeline:
if self._visualizer is not None and viz_payload is not None: if self._visualizer is not None and viz_payload is not None:
# Cast viz_payload to dict for type checking # Cast viz_payload to dict for type checking
viz_dict = cast(dict[str, object], viz_payload) viz_dict = cast(dict[str, object], viz_payload)
mask_raw = viz_dict.get("mask_raw") mask_raw_obj = viz_dict.get("mask_raw")
bbox = viz_dict.get("bbox") bbox_obj = viz_dict.get("bbox")
silhouette = viz_dict.get("silhouette") silhouette_obj = viz_dict.get("silhouette")
track_id_val = viz_dict.get("track_id", 0) track_id_val = viz_dict.get("track_id", 0)
track_id = track_id_val if isinstance(track_id_val, int) else 0 track_id = track_id_val if isinstance(track_id_val, int) else 0
label = viz_dict.get("label") label_obj = viz_dict.get("label")
confidence = viz_dict.get("confidence") confidence_obj = viz_dict.get("confidence")
# Cast _visualizer to object with update method # Cast extracted values to expected types
visualizer = cast(object, self._visualizer) mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
update_fn = getattr(visualizer, "update", None) bbox = cast(tuple[int, int, int, int] | None, bbox_obj)
if callable(update_fn): silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
keep_running = update_fn( label = cast(str | None, label_obj)
frame_u8, confidence = cast(float | None, confidence_obj)
bbox,
track_id, keep_running = self._visualizer.update(
mask_raw, frame_u8,
silhouette, bbox,
label, track_id,
confidence, mask_raw,
ema_fps, silhouette,
) label,
if not keep_running: confidence,
logger.info("Visualization closed by user.") ema_fps,
break )
if not keep_running:
logger.info("Visualization closed by user.")
break
if frame_count % 100 == 0: if frame_count % 100 == 0:
elapsed = time.perf_counter() - start_time elapsed = time.perf_counter() - start_time
@@ -385,11 +391,8 @@ class ScoliosisPipeline:
# Close visualizer if enabled # Close visualizer if enabled
if self._visualizer is not None: if self._visualizer is not None:
visualizer = cast(object, self._visualizer) with suppress(Exception):
close_viz = getattr(visualizer, "close", None) self._visualizer.close()
if callable(close_viz):
with suppress(Exception):
_ = close_viz()
# Export silhouettes if requested # Export silhouettes if requested
if self._silhouette_export_path is not None and self._silhouette_buffer: if self._silhouette_export_path is not None and self._silhouette_buffer:
+21 -11
View File
@@ -315,20 +315,30 @@ class OpenCVVisualizer:
Returns: Returns:
Displayable side-by-side image Displayable side-by-side image
""" """
# Prepare individual views # Prepare individual views without mode indicators (will be drawn on combined)
raw_view = self._prepare_raw_view(mask_raw) # Raw view preparation (without indicator)
norm_view = self._prepare_normalized_view(silhouette) if mask_raw is None:
raw_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
# Convert to grayscale for side-by-side composition
if len(raw_view.shape) == 3:
raw_gray = cast(ImageArray, cv2.cvtColor(raw_view, cv2.COLOR_BGR2GRAY))
else: else:
raw_gray = raw_view if len(mask_raw.shape) == 3:
mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
else:
mask_gray = mask_raw
raw_gray = cast(
ImageArray,
cv2.resize(
mask_gray,
(DISPLAY_WIDTH, DISPLAY_HEIGHT),
interpolation=cv2.INTER_NEAREST,
),
)
if len(norm_view.shape) == 3: # Normalized view preparation (without indicator)
norm_gray = cast(ImageArray, cv2.cvtColor(norm_view, cv2.COLOR_BGR2GRAY)) if silhouette is None:
norm_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
else: else:
norm_gray = norm_view upscaled = self._upscale_silhouette(silhouette)
norm_gray = upscaled
# Stack horizontally # Stack horizontally
combined = np.hstack([raw_gray, norm_gray]) combined = np.hstack([raw_gray, norm_gray])