refactor(demo): simplify visualizer wiring and typing
Apply Oracle-guided cleanup to make the demo pipeline contract explicit and remove defensive runtime indirection while preserving existing visualization behavior.
This commit is contained in:
@@ -1,7 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import inspect
|
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
@@ -90,7 +89,6 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Build kwargs based on what ScoliosisPipeline accepts
|
# Build kwargs based on what ScoliosisPipeline accepts
|
||||||
sig = inspect.signature(ScoliosisPipeline.__init__)
|
|
||||||
pipeline_kwargs = {
|
pipeline_kwargs = {
|
||||||
"source": args.source,
|
"source": args.source,
|
||||||
"checkpoint": args.checkpoint,
|
"checkpoint": args.checkpoint,
|
||||||
@@ -108,10 +106,8 @@ if __name__ == "__main__":
|
|||||||
"silhouette_visualize_dir": args.silhouette_visualize_dir,
|
"silhouette_visualize_dir": args.silhouette_visualize_dir,
|
||||||
"result_export_path": args.result_export_path,
|
"result_export_path": args.result_export_path,
|
||||||
"result_export_format": args.result_export_format,
|
"result_export_format": args.result_export_format,
|
||||||
|
"visualize": args.visualize,
|
||||||
}
|
}
|
||||||
if "visualize" in sig.parameters:
|
|
||||||
pipeline_kwargs["visualize"] = args.visualize
|
|
||||||
|
|
||||||
pipeline = ScoliosisPipeline(**pipeline_kwargs)
|
pipeline = ScoliosisPipeline(**pipeline_kwargs)
|
||||||
raise SystemExit(pipeline.run())
|
raise SystemExit(pipeline.run())
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
|
|||||||
+32
-29
@@ -5,7 +5,7 @@ from contextlib import suppress
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
from typing import Protocol, cast
|
from typing import TYPE_CHECKING, Protocol, cast
|
||||||
|
|
||||||
from beartype import beartype
|
from beartype import beartype
|
||||||
import click
|
import click
|
||||||
@@ -22,6 +22,9 @@ from .preprocess import frame_to_person_mask, mask_to_silhouette
|
|||||||
from .sconet_demo import ScoNetDemo
|
from .sconet_demo import ScoNetDemo
|
||||||
from .window import SilhouetteWindow, select_person
|
from .window import SilhouetteWindow, select_person
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .visualizer import OpenCVVisualizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
|
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
|
||||||
@@ -78,7 +81,7 @@ class ScoliosisPipeline:
|
|||||||
_result_export_path: Path | None
|
_result_export_path: Path | None
|
||||||
_result_export_format: str
|
_result_export_format: str
|
||||||
_result_buffer: list[dict[str, object]]
|
_result_buffer: list[dict[str, object]]
|
||||||
_visualizer: object | None
|
_visualizer: OpenCVVisualizer | None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -342,31 +345,34 @@ class ScoliosisPipeline:
|
|||||||
if self._visualizer is not None and viz_payload is not None:
|
if self._visualizer is not None and viz_payload is not None:
|
||||||
# Cast viz_payload to dict for type checking
|
# Cast viz_payload to dict for type checking
|
||||||
viz_dict = cast(dict[str, object], viz_payload)
|
viz_dict = cast(dict[str, object], viz_payload)
|
||||||
mask_raw = viz_dict.get("mask_raw")
|
mask_raw_obj = viz_dict.get("mask_raw")
|
||||||
bbox = viz_dict.get("bbox")
|
bbox_obj = viz_dict.get("bbox")
|
||||||
silhouette = viz_dict.get("silhouette")
|
silhouette_obj = viz_dict.get("silhouette")
|
||||||
track_id_val = viz_dict.get("track_id", 0)
|
track_id_val = viz_dict.get("track_id", 0)
|
||||||
track_id = track_id_val if isinstance(track_id_val, int) else 0
|
track_id = track_id_val if isinstance(track_id_val, int) else 0
|
||||||
label = viz_dict.get("label")
|
label_obj = viz_dict.get("label")
|
||||||
confidence = viz_dict.get("confidence")
|
confidence_obj = viz_dict.get("confidence")
|
||||||
|
|
||||||
# Cast _visualizer to object with update method
|
# Cast extracted values to expected types
|
||||||
visualizer = cast(object, self._visualizer)
|
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
|
||||||
update_fn = getattr(visualizer, "update", None)
|
bbox = cast(tuple[int, int, int, int] | None, bbox_obj)
|
||||||
if callable(update_fn):
|
silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
|
||||||
keep_running = update_fn(
|
label = cast(str | None, label_obj)
|
||||||
frame_u8,
|
confidence = cast(float | None, confidence_obj)
|
||||||
bbox,
|
|
||||||
track_id,
|
keep_running = self._visualizer.update(
|
||||||
mask_raw,
|
frame_u8,
|
||||||
silhouette,
|
bbox,
|
||||||
label,
|
track_id,
|
||||||
confidence,
|
mask_raw,
|
||||||
ema_fps,
|
silhouette,
|
||||||
)
|
label,
|
||||||
if not keep_running:
|
confidence,
|
||||||
logger.info("Visualization closed by user.")
|
ema_fps,
|
||||||
break
|
)
|
||||||
|
if not keep_running:
|
||||||
|
logger.info("Visualization closed by user.")
|
||||||
|
break
|
||||||
|
|
||||||
if frame_count % 100 == 0:
|
if frame_count % 100 == 0:
|
||||||
elapsed = time.perf_counter() - start_time
|
elapsed = time.perf_counter() - start_time
|
||||||
@@ -385,11 +391,8 @@ class ScoliosisPipeline:
|
|||||||
|
|
||||||
# Close visualizer if enabled
|
# Close visualizer if enabled
|
||||||
if self._visualizer is not None:
|
if self._visualizer is not None:
|
||||||
visualizer = cast(object, self._visualizer)
|
with suppress(Exception):
|
||||||
close_viz = getattr(visualizer, "close", None)
|
self._visualizer.close()
|
||||||
if callable(close_viz):
|
|
||||||
with suppress(Exception):
|
|
||||||
_ = close_viz()
|
|
||||||
|
|
||||||
# Export silhouettes if requested
|
# Export silhouettes if requested
|
||||||
if self._silhouette_export_path is not None and self._silhouette_buffer:
|
if self._silhouette_export_path is not None and self._silhouette_buffer:
|
||||||
|
|||||||
+21
-11
@@ -315,20 +315,30 @@ class OpenCVVisualizer:
|
|||||||
Returns:
|
Returns:
|
||||||
Displayable side-by-side image
|
Displayable side-by-side image
|
||||||
"""
|
"""
|
||||||
# Prepare individual views
|
# Prepare individual views without mode indicators (will be drawn on combined)
|
||||||
raw_view = self._prepare_raw_view(mask_raw)
|
# Raw view preparation (without indicator)
|
||||||
norm_view = self._prepare_normalized_view(silhouette)
|
if mask_raw is None:
|
||||||
|
raw_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
|
||||||
# Convert to grayscale for side-by-side composition
|
|
||||||
if len(raw_view.shape) == 3:
|
|
||||||
raw_gray = cast(ImageArray, cv2.cvtColor(raw_view, cv2.COLOR_BGR2GRAY))
|
|
||||||
else:
|
else:
|
||||||
raw_gray = raw_view
|
if len(mask_raw.shape) == 3:
|
||||||
|
mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
|
||||||
|
else:
|
||||||
|
mask_gray = mask_raw
|
||||||
|
raw_gray = cast(
|
||||||
|
ImageArray,
|
||||||
|
cv2.resize(
|
||||||
|
mask_gray,
|
||||||
|
(DISPLAY_WIDTH, DISPLAY_HEIGHT),
|
||||||
|
interpolation=cv2.INTER_NEAREST,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
if len(norm_view.shape) == 3:
|
# Normalized view preparation (without indicator)
|
||||||
norm_gray = cast(ImageArray, cv2.cvtColor(norm_view, cv2.COLOR_BGR2GRAY))
|
if silhouette is None:
|
||||||
|
norm_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
|
||||||
else:
|
else:
|
||||||
norm_gray = norm_view
|
upscaled = self._upscale_silhouette(silhouette)
|
||||||
|
norm_gray = upscaled
|
||||||
|
|
||||||
# Stack horizontally
|
# Stack horizontally
|
||||||
combined = np.hstack([raw_gray, norm_gray])
|
combined = np.hstack([raw_gray, norm_gray])
|
||||||
|
|||||||
Reference in New Issue
Block a user