refactor(demo): simplify visualizer wiring and typing
Apply Oracle-guided cleanup to make the demo pipeline contract explicit and remove defensive runtime indirection while preserving existing visualization behavior.
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
|
||||
@@ -90,7 +89,6 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
# Build kwargs based on what ScoliosisPipeline accepts
|
||||
sig = inspect.signature(ScoliosisPipeline.__init__)
|
||||
pipeline_kwargs = {
|
||||
"source": args.source,
|
||||
"checkpoint": args.checkpoint,
|
||||
@@ -108,10 +106,8 @@ if __name__ == "__main__":
|
||||
"silhouette_visualize_dir": args.silhouette_visualize_dir,
|
||||
"result_export_path": args.result_export_path,
|
||||
"result_export_format": args.result_export_format,
|
||||
"visualize": args.visualize,
|
||||
}
|
||||
if "visualize" in sig.parameters:
|
||||
pipeline_kwargs["visualize"] = args.visualize
|
||||
|
||||
pipeline = ScoliosisPipeline(**pipeline_kwargs)
|
||||
raise SystemExit(pipeline.run())
|
||||
except ValueError as err:
|
||||
|
||||
+19
-16
@@ -5,7 +5,7 @@ from contextlib import suppress
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import Protocol, cast
|
||||
from typing import TYPE_CHECKING, Protocol, cast
|
||||
|
||||
from beartype import beartype
|
||||
import click
|
||||
@@ -22,6 +22,9 @@ from .preprocess import frame_to_person_mask, mask_to_silhouette
|
||||
from .sconet_demo import ScoNetDemo
|
||||
from .window import SilhouetteWindow, select_person
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .visualizer import OpenCVVisualizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
|
||||
@@ -78,7 +81,7 @@ class ScoliosisPipeline:
|
||||
_result_export_path: Path | None
|
||||
_result_export_format: str
|
||||
_result_buffer: list[dict[str, object]]
|
||||
_visualizer: object | None
|
||||
_visualizer: OpenCVVisualizer | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -342,19 +345,22 @@ class ScoliosisPipeline:
|
||||
if self._visualizer is not None and viz_payload is not None:
|
||||
# Cast viz_payload to dict for type checking
|
||||
viz_dict = cast(dict[str, object], viz_payload)
|
||||
mask_raw = viz_dict.get("mask_raw")
|
||||
bbox = viz_dict.get("bbox")
|
||||
silhouette = viz_dict.get("silhouette")
|
||||
mask_raw_obj = viz_dict.get("mask_raw")
|
||||
bbox_obj = viz_dict.get("bbox")
|
||||
silhouette_obj = viz_dict.get("silhouette")
|
||||
track_id_val = viz_dict.get("track_id", 0)
|
||||
track_id = track_id_val if isinstance(track_id_val, int) else 0
|
||||
label = viz_dict.get("label")
|
||||
confidence = viz_dict.get("confidence")
|
||||
label_obj = viz_dict.get("label")
|
||||
confidence_obj = viz_dict.get("confidence")
|
||||
|
||||
# Cast _visualizer to object with update method
|
||||
visualizer = cast(object, self._visualizer)
|
||||
update_fn = getattr(visualizer, "update", None)
|
||||
if callable(update_fn):
|
||||
keep_running = update_fn(
|
||||
# Cast extracted values to expected types
|
||||
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
|
||||
bbox = cast(tuple[int, int, int, int] | None, bbox_obj)
|
||||
silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
|
||||
label = cast(str | None, label_obj)
|
||||
confidence = cast(float | None, confidence_obj)
|
||||
|
||||
keep_running = self._visualizer.update(
|
||||
frame_u8,
|
||||
bbox,
|
||||
track_id,
|
||||
@@ -385,11 +391,8 @@ class ScoliosisPipeline:
|
||||
|
||||
# Close visualizer if enabled
|
||||
if self._visualizer is not None:
|
||||
visualizer = cast(object, self._visualizer)
|
||||
close_viz = getattr(visualizer, "close", None)
|
||||
if callable(close_viz):
|
||||
with suppress(Exception):
|
||||
_ = close_viz()
|
||||
self._visualizer.close()
|
||||
|
||||
# Export silhouettes if requested
|
||||
if self._silhouette_export_path is not None and self._silhouette_buffer:
|
||||
|
||||
+22
-12
@@ -315,20 +315,30 @@ class OpenCVVisualizer:
|
||||
Returns:
|
||||
Displayable side-by-side image
|
||||
"""
|
||||
# Prepare individual views
|
||||
raw_view = self._prepare_raw_view(mask_raw)
|
||||
norm_view = self._prepare_normalized_view(silhouette)
|
||||
|
||||
# Convert to grayscale for side-by-side composition
|
||||
if len(raw_view.shape) == 3:
|
||||
raw_gray = cast(ImageArray, cv2.cvtColor(raw_view, cv2.COLOR_BGR2GRAY))
|
||||
# Prepare individual views without mode indicators (will be drawn on combined)
|
||||
# Raw view preparation (without indicator)
|
||||
if mask_raw is None:
|
||||
raw_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
|
||||
else:
|
||||
raw_gray = raw_view
|
||||
|
||||
if len(norm_view.shape) == 3:
|
||||
norm_gray = cast(ImageArray, cv2.cvtColor(norm_view, cv2.COLOR_BGR2GRAY))
|
||||
if len(mask_raw.shape) == 3:
|
||||
mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
|
||||
else:
|
||||
norm_gray = norm_view
|
||||
mask_gray = mask_raw
|
||||
raw_gray = cast(
|
||||
ImageArray,
|
||||
cv2.resize(
|
||||
mask_gray,
|
||||
(DISPLAY_WIDTH, DISPLAY_HEIGHT),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
),
|
||||
)
|
||||
|
||||
# Normalized view preparation (without indicator)
|
||||
if silhouette is None:
|
||||
norm_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
|
||||
else:
|
||||
upscaled = self._upscale_silhouette(silhouette)
|
||||
norm_gray = upscaled
|
||||
|
||||
# Stack horizontally
|
||||
combined = np.hstack([raw_gray, norm_gray])
|
||||
|
||||
Reference in New Issue
Block a user