refactor(demo): simplify visualizer wiring and typing

Apply Oracle-guided cleanup to make the demo pipeline contract explicit and remove defensive runtime indirection while preserving existing visualization behavior.
This commit is contained in:
2026-02-27 22:15:30 +08:00
parent e90e53ffaf
commit 433e673807
3 changed files with 54 additions and 45 deletions
+32 -29
View File
@@ -5,7 +5,7 @@ from contextlib import suppress
import logging
from pathlib import Path
import time
from typing import Protocol, cast
from typing import TYPE_CHECKING, Protocol, cast
from beartype import beartype
import click
@@ -22,6 +22,9 @@ from .preprocess import frame_to_person_mask, mask_to_silhouette
from .sconet_demo import ScoNetDemo
from .window import SilhouetteWindow, select_person
if TYPE_CHECKING:
from .visualizer import OpenCVVisualizer
logger = logging.getLogger(__name__)
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
@@ -78,7 +81,7 @@ class ScoliosisPipeline:
_result_export_path: Path | None
_result_export_format: str
_result_buffer: list[dict[str, object]]
_visualizer: object | None
_visualizer: OpenCVVisualizer | None
def __init__(
self,
@@ -342,31 +345,34 @@ class ScoliosisPipeline:
if self._visualizer is not None and viz_payload is not None:
# Cast viz_payload to dict for type checking
viz_dict = cast(dict[str, object], viz_payload)
mask_raw = viz_dict.get("mask_raw")
bbox = viz_dict.get("bbox")
silhouette = viz_dict.get("silhouette")
mask_raw_obj = viz_dict.get("mask_raw")
bbox_obj = viz_dict.get("bbox")
silhouette_obj = viz_dict.get("silhouette")
track_id_val = viz_dict.get("track_id", 0)
track_id = track_id_val if isinstance(track_id_val, int) else 0
label = viz_dict.get("label")
confidence = viz_dict.get("confidence")
label_obj = viz_dict.get("label")
confidence_obj = viz_dict.get("confidence")
# Cast _visualizer to object with update method
visualizer = cast(object, self._visualizer)
update_fn = getattr(visualizer, "update", None)
if callable(update_fn):
keep_running = update_fn(
frame_u8,
bbox,
track_id,
mask_raw,
silhouette,
label,
confidence,
ema_fps,
)
if not keep_running:
logger.info("Visualization closed by user.")
break
# Cast extracted values to expected types
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
bbox = cast(tuple[int, int, int, int] | None, bbox_obj)
silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
label = cast(str | None, label_obj)
confidence = cast(float | None, confidence_obj)
keep_running = self._visualizer.update(
frame_u8,
bbox,
track_id,
mask_raw,
silhouette,
label,
confidence,
ema_fps,
)
if not keep_running:
logger.info("Visualization closed by user.")
break
if frame_count % 100 == 0:
elapsed = time.perf_counter() - start_time
@@ -385,11 +391,8 @@ class ScoliosisPipeline:
# Close visualizer if enabled
if self._visualizer is not None:
visualizer = cast(object, self._visualizer)
close_viz = getattr(visualizer, "close", None)
if callable(close_viz):
with suppress(Exception):
_ = close_viz()
with suppress(Exception):
self._visualizer.close()
# Export silhouettes if requested
if self._silhouette_export_path is not None and self._silhouette_buffer: