feat(demo): add realtime visualization pipeline flow

Integrate an opt-in OpenCV visualizer into the demo runtime so operators can monitor tracking, segmentation, and inference confidence in real time without changing the default non-visual execution path.
This commit is contained in:
2026-02-27 20:14:24 +08:00
parent 846549498c
commit 4cc2ef7c63
3 changed files with 670 additions and 11 deletions
+107 -9
View File
@@ -78,6 +78,7 @@ class ScoliosisPipeline:
_result_export_path: Path | None
_result_export_format: str
_result_buffer: list[dict[str, object]]
_visualizer: object | None
def __init__(
self,
@@ -98,6 +99,7 @@ class ScoliosisPipeline:
silhouette_visualize_dir: str | None = None,
result_export_path: str | None = None,
result_export_format: str = "json",
visualize: bool = False,
) -> None:
self._detector = YOLO(yolo_model)
self._source = create_source(source, max_frames=max_frames)
@@ -124,6 +126,12 @@ class ScoliosisPipeline:
)
self._result_export_format = result_export_format
self._result_buffer = []
if visualize:
from .visualizer import OpenCVVisualizer
self._visualizer = OpenCVVisualizer()
else:
self._visualizer = None
@staticmethod
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
@@ -156,7 +164,15 @@ class ScoliosisPipeline:
def _select_silhouette(
self,
result: _DetectionResultsLike,
) -> tuple[Float[ndarray, "64 44"], int] | None:
) -> (
tuple[
Float[ndarray, "64 44"],
UInt8[ndarray, "h w"],
tuple[int, int, int, int],
int,
]
| None
):
selected = select_person(result)
if selected is not None:
mask_raw, bbox, track_id = selected
@@ -165,7 +181,7 @@ class ScoliosisPipeline:
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox),
)
if silhouette is not None:
return silhouette, int(track_id)
return silhouette, mask_raw, bbox, int(track_id)
fallback = cast(
tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None,
@@ -181,7 +197,8 @@ class ScoliosisPipeline:
)
if silhouette is None:
return None
return silhouette, 0
# For fallback case, mask_raw is the same as mask_u8
return silhouette, mask_u8, bbox, 0
@jaxtyped(typechecker=beartype)
def process_frame(
@@ -212,7 +229,7 @@ class ScoliosisPipeline:
if selected is None:
return None
silhouette, track_id = selected
silhouette, mask_raw, bbox, track_id = selected
# Store silhouette for export if in preprocess-only mode or if export requested
if self._silhouette_export_path is not None or self._preprocess_only:
@@ -230,12 +247,28 @@ class ScoliosisPipeline:
self._visualize_silhouette(silhouette, frame_idx, track_id)
if self._preprocess_only:
return None
# Return visualization payload for display even in preprocess-only mode
return {
"mask_raw": mask_raw,
"bbox": bbox,
"silhouette": silhouette,
"track_id": track_id,
"label": None,
"confidence": None,
}
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
if not self._window.should_classify():
return None
# Return visualization payload even when not classifying yet
return {
"mask_raw": mask_raw,
"bbox": bbox,
"silhouette": silhouette,
"track_id": track_id,
"label": None,
"confidence": None,
}
window_tensor = self._window.get_tensor(device=self._device)
label, confidence = cast(
@@ -259,25 +292,82 @@ class ScoliosisPipeline:
self._result_buffer.append(result)
self._publisher.publish(result)
return result
# Return result with visualization payload
return {
"result": result,
"mask_raw": mask_raw,
"bbox": bbox,
"silhouette": silhouette,
"track_id": track_id,
"label": label,
"confidence": confidence,
}
def run(self) -> int:
frame_count = 0
start_time = time.perf_counter()
# EMA FPS state (alpha=0.1 for smoothing)
ema_fps = 0.0
alpha = 0.1
prev_time = start_time
try:
for item in self._source:
frame, metadata = item
frame_u8 = np.asarray(frame, dtype=np.uint8)
frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
frame_count += 1
# Compute per-frame EMA FPS
curr_time = time.perf_counter()
delta = curr_time - prev_time
prev_time = curr_time
if delta > 0:
instant_fps = 1.0 / delta
if ema_fps == 0.0:
ema_fps = instant_fps
else:
ema_fps = alpha * instant_fps + (1 - alpha) * ema_fps
viz_payload = None
try:
_ = self.process_frame(frame_u8, metadata)
viz_payload = self.process_frame(frame_u8, metadata)
except Exception as frame_error:
logger.warning(
"Skipping frame %d due to processing error: %s",
frame_idx,
frame_error,
)
# Update visualizer if enabled
if self._visualizer is not None and viz_payload is not None:
# Cast viz_payload to dict for type checking
viz_dict = cast(dict[str, object], viz_payload)
mask_raw = viz_dict.get("mask_raw")
bbox = viz_dict.get("bbox")
silhouette = viz_dict.get("silhouette")
track_id_val = viz_dict.get("track_id", 0)
track_id = track_id_val if isinstance(track_id_val, int) else 0
label = viz_dict.get("label")
confidence = viz_dict.get("confidence")
# Cast _visualizer to object with update method
visualizer = cast(object, self._visualizer)
update_fn = getattr(visualizer, "update", None)
if callable(update_fn):
keep_running = update_fn(
frame_u8,
bbox,
track_id,
mask_raw,
silhouette,
label,
confidence,
ema_fps,
)
if not keep_running:
logger.info("Visualization closed by user.")
break
if frame_count % 100 == 0:
elapsed = time.perf_counter() - start_time
fps = frame_count / elapsed if elapsed > 0 else 0.0
@@ -293,6 +383,14 @@ class ScoliosisPipeline:
if self._closed:
return
# Close visualizer if enabled
if self._visualizer is not None:
visualizer = cast(object, self._visualizer)
close_viz = getattr(visualizer, "close", None)
if callable(close_viz):
with suppress(Exception):
_ = close_viz()
# Export silhouettes if requested
if self._silhouette_export_path is not None and self._silhouette_buffer:
self._export_silhouettes()
@@ -504,7 +602,7 @@ def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
show_default=True,
)
@click.option("--device", type=str, default="cuda:0", show_default=True)
@click.option("--yolo-model", type=str, default="yolo11n-seg.pt", show_default=True)
@click.option("--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", show_default=True)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--nats-url", type=str, default=None)