refactor(demo): simplify visualizer wiring and typing
Apply Oracle-guided cleanup to make the demo pipeline contract explicit and remove defensive runtime indirection while preserving existing visualization behavior.
This commit is contained in:
+32
-29
@@ -5,7 +5,7 @@ from contextlib import suppress
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import Protocol, cast
|
||||
from typing import TYPE_CHECKING, Protocol, cast
|
||||
|
||||
from beartype import beartype
|
||||
import click
|
||||
@@ -22,6 +22,9 @@ from .preprocess import frame_to_person_mask, mask_to_silhouette
|
||||
from .sconet_demo import ScoNetDemo
|
||||
from .window import SilhouetteWindow, select_person
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .visualizer import OpenCVVisualizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
|
||||
@@ -78,7 +81,7 @@ class ScoliosisPipeline:
|
||||
_result_export_path: Path | None
|
||||
_result_export_format: str
|
||||
_result_buffer: list[dict[str, object]]
|
||||
_visualizer: object | None
|
||||
_visualizer: OpenCVVisualizer | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -342,31 +345,34 @@ class ScoliosisPipeline:
|
||||
if self._visualizer is not None and viz_payload is not None:
|
||||
# Cast viz_payload to dict for type checking
|
||||
viz_dict = cast(dict[str, object], viz_payload)
|
||||
mask_raw = viz_dict.get("mask_raw")
|
||||
bbox = viz_dict.get("bbox")
|
||||
silhouette = viz_dict.get("silhouette")
|
||||
mask_raw_obj = viz_dict.get("mask_raw")
|
||||
bbox_obj = viz_dict.get("bbox")
|
||||
silhouette_obj = viz_dict.get("silhouette")
|
||||
track_id_val = viz_dict.get("track_id", 0)
|
||||
track_id = track_id_val if isinstance(track_id_val, int) else 0
|
||||
label = viz_dict.get("label")
|
||||
confidence = viz_dict.get("confidence")
|
||||
label_obj = viz_dict.get("label")
|
||||
confidence_obj = viz_dict.get("confidence")
|
||||
|
||||
# Cast _visualizer to object with update method
|
||||
visualizer = cast(object, self._visualizer)
|
||||
update_fn = getattr(visualizer, "update", None)
|
||||
if callable(update_fn):
|
||||
keep_running = update_fn(
|
||||
frame_u8,
|
||||
bbox,
|
||||
track_id,
|
||||
mask_raw,
|
||||
silhouette,
|
||||
label,
|
||||
confidence,
|
||||
ema_fps,
|
||||
)
|
||||
if not keep_running:
|
||||
logger.info("Visualization closed by user.")
|
||||
break
|
||||
# Cast extracted values to expected types
|
||||
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
|
||||
bbox = cast(tuple[int, int, int, int] | None, bbox_obj)
|
||||
silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
|
||||
label = cast(str | None, label_obj)
|
||||
confidence = cast(float | None, confidence_obj)
|
||||
|
||||
keep_running = self._visualizer.update(
|
||||
frame_u8,
|
||||
bbox,
|
||||
track_id,
|
||||
mask_raw,
|
||||
silhouette,
|
||||
label,
|
||||
confidence,
|
||||
ema_fps,
|
||||
)
|
||||
if not keep_running:
|
||||
logger.info("Visualization closed by user.")
|
||||
break
|
||||
|
||||
if frame_count % 100 == 0:
|
||||
elapsed = time.perf_counter() - start_time
|
||||
@@ -385,11 +391,8 @@ class ScoliosisPipeline:
|
||||
|
||||
# Close visualizer if enabled
|
||||
if self._visualizer is not None:
|
||||
visualizer = cast(object, self._visualizer)
|
||||
close_viz = getattr(visualizer, "close", None)
|
||||
if callable(close_viz):
|
||||
with suppress(Exception):
|
||||
_ = close_viz()
|
||||
with suppress(Exception):
|
||||
self._visualizer.close()
|
||||
|
||||
# Export silhouettes if requested
|
||||
if self._silhouette_export_path is not None and self._silhouette_buffer:
|
||||
|
||||
Reference in New Issue
Block a user