feat(demo): implement ScoNet real-time pipeline runtime
Add the full demo runtime stack for single-person scoliosis inference, including input adapters, silhouette preprocessing, temporal windowing, the ScoNet wrapper, result publishing, and click-based CLI orchestration. This commit captures the executable pipeline behavior independently of tests and planning artifacts, for clearer review and rollback.
This commit is contained in:
@@ -0,0 +1,325 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from contextlib import suppress
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import Protocol, cast
|
||||
|
||||
from beartype import beartype
|
||||
import click
|
||||
import jaxtyping
|
||||
from jaxtyping import Float, UInt8
|
||||
import numpy as np
|
||||
from numpy import ndarray
|
||||
from numpy.typing import NDArray
|
||||
from ultralytics.models.yolo.model import YOLO
|
||||
|
||||
from .input import FrameStream, create_source
|
||||
from .output import ResultPublisher, create_publisher, create_result
|
||||
from .preprocess import frame_to_person_mask, mask_to_silhouette
|
||||
from .sconet_demo import ScoNetDemo
|
||||
from .window import SilhouetteWindow, select_person
|
||||
|
||||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# ``jaxtyping.jaxtyped`` is re-bound below through an explicit callable type:
# a factory that takes keyword options and returns a decorator.  The cast is
# purely for static type checkers; at runtime ``jaxtyped`` is the very same
# object as ``jaxtyping.jaxtyped``.
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
|
||||
|
||||
|
||||
class _BoxesLike(Protocol):
    """Structural type for the bounding-box container of one detection result."""

    @property
    def xyxy(self) -> NDArray[np.float32] | object:
        """Box coordinates.

        NOTE(review): presumably one (x1, y1, x2, y2) row per detection --
        confirm against the detector's result type.
        """

    @property
    def id(self) -> NDArray[np.int64] | object | None:
        """Per-box track identifiers; may be ``None``."""
|
||||
|
||||
|
||||
class _MasksLike(Protocol):
    """Structural type for the segmentation-mask container of one detection result."""

    @property
    def data(self) -> NDArray[np.float32] | object:
        """Raw mask data, as an array or an array-like object."""
|
||||
|
||||
|
||||
class _DetectionResultsLike(Protocol):
    """Structural type for a single per-frame detection result."""

    @property
    def boxes(self) -> _BoxesLike:
        """Bounding boxes (and track ids) found in the frame."""

    @property
    def masks(self) -> _MasksLike:
        """Segmentation masks found in the frame."""
|
||||
|
||||
|
||||
class _TrackCallable(Protocol):
    """Call signature expected from the detector's ``track`` method.

    Only ``source`` is positional; every option is keyword-only.
    """

    def __call__(
        self,
        source: object,
        *,
        persist: bool = True,
        verbose: bool = False,
        device: str | None = None,
        classes: list[int] | None = None,
    ) -> object:
        """Run tracking on ``source`` and return the detector's raw output."""
|
||||
|
||||
|
||||
class ScoliosisPipeline:
    """Run tracking, silhouette windowing and ScoNet classification over a frame source.

    Each frame pulled from the source is passed to the YOLO tracker; one
    person's mask is converted into a silhouette and pushed into a temporal
    window.  Whenever the window reports that it should be classified, the
    window tensor is fed to ScoNet and the resulting record is published.
    """

    _detector: object
    _source: FrameStream
    _window: SilhouetteWindow
    _publisher: ResultPublisher
    _classifier: ScoNetDemo
    _device: str
    _closed: bool

    def __init__(
        self,
        *,
        source: str,
        checkpoint: str,
        config: str,
        device: str,
        yolo_model: str,
        window: int,
        stride: int,
        nats_url: str | None,
        nats_subject: str,
        max_frames: int | None,
    ) -> None:
        """Build every pipeline stage.

        Args:
            source: Frame source specifier, forwarded to ``create_source``.
            checkpoint: ScoNet checkpoint path.
            config: ScoNet config path.
            device: Device string used for both tracking and classification.
            yolo_model: Weights specifier handed to ``YOLO``.
            window: Number of silhouettes per classification window.
            stride: Window stride, forwarded to ``SilhouetteWindow``.
            nats_url: Publisher URL, or ``None``; forwarded to ``create_publisher``.
            nats_subject: Subject the publisher is created with.
            max_frames: Optional frame limit, forwarded to ``create_source``.
        """
        self._detector = YOLO(yolo_model)
        self._source = create_source(source, max_frames=max_frames)
        self._window = SilhouetteWindow(window_size=window, stride=stride)
        self._publisher = create_publisher(nats_url=nats_url, subject=nats_subject)
        self._classifier = ScoNetDemo(
            cfg_path=config,
            checkpoint_path=checkpoint,
            device=device,
        )
        self._device = device
        self._closed = False

    @staticmethod
    def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
        """Return ``meta[key]`` as a plain ``int``, or ``fallback``.

        Both Python ints and NumPy integer scalars are accepted (NumPy
        integers are not ``int`` subclasses, so a bare ``isinstance(x, int)``
        would silently discard them).  ``bool`` is rejected even though it
        subclasses ``int``: a flag is not a frame number.
        """
        value = meta.get(key)
        if isinstance(value, bool):
            return fallback
        if isinstance(value, (int, np.integer)):
            return int(value)
        return fallback

    @staticmethod
    def _extract_timestamp(meta: dict[str, object]) -> int:
        """Return ``meta["timestamp_ns"]`` as an ``int``.

        Falls back to ``time.monotonic_ns()`` when the key is missing or is
        not an integer (same acceptance rules as ``_extract_int``).
        """
        value = meta.get("timestamp_ns")
        if isinstance(value, (int, np.integer)) and not isinstance(value, bool):
            return int(value)
        # Evaluated lazily so the clock is only read when actually needed.
        return time.monotonic_ns()

    @staticmethod
    def _to_mask_u8(mask: ndarray) -> UInt8[ndarray, "h w"]:
        """Binarise ``mask`` at 0.5 into a ``uint8`` array of 0 / 255 values."""
        binary = np.where(np.asarray(mask) > 0.5, 255, 0).astype(np.uint8)
        return cast(UInt8[ndarray, "h w"], binary)

    def _resolve_track_fn(self) -> _TrackCallable:
        """Return the detector's ``track`` callable.

        Raises:
            RuntimeError: If the detector has no callable ``track`` attribute.
        """
        track_fn_obj = getattr(self._detector, "track", None)
        if not callable(track_fn_obj):
            raise RuntimeError("YOLO detector does not expose a callable track()")
        return cast(_TrackCallable, track_fn_obj)

    def _first_result(self, detections: object) -> _DetectionResultsLike | None:
        """Return the first detection result, or ``None`` for an empty sequence.

        The tracker output may be a list/tuple of results or a single result
        object; a non-sequence value is returned as-is.
        """
        if isinstance(detections, (list, tuple)):
            return cast(_DetectionResultsLike, detections[0]) if detections else None
        return cast(_DetectionResultsLike, detections)

    def _select_silhouette(
        self,
        result: _DetectionResultsLike,
    ) -> tuple[Float[ndarray, "64 44"], int] | None:
        """Pick one person from ``result`` and return ``(silhouette, track_id)``.

        The tracked selection from ``select_person`` is tried first.  If it
        yields nothing usable, ``frame_to_person_mask`` is used instead and
        the track id is reported as 0.  Returns ``None`` when neither path
        produces a silhouette.
        """
        selected = select_person(result)
        if selected is not None:
            mask_raw, bbox, track_id = selected
            silhouette = cast(
                Float[ndarray, "64 44"] | None,
                mask_to_silhouette(self._to_mask_u8(mask_raw), bbox),
            )
            if silhouette is not None:
                return silhouette, int(track_id)

        # Tracked path failed: fall back to the untracked person mask.
        fallback = cast(
            tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None,
            frame_to_person_mask(result),
        )
        if fallback is None:
            return None

        mask_u8, bbox = fallback
        silhouette = cast(
            Float[ndarray, "64 44"] | None,
            mask_to_silhouette(mask_u8, bbox),
        )
        if silhouette is None:
            return None
        # No track id is available on the fallback path.
        return silhouette, 0

    @jaxtyped(typechecker=beartype)
    def process_frame(
        self,
        frame: UInt8[ndarray, "h w c"],
        metadata: dict[str, object],
    ) -> dict[str, object] | None:
        """Process one frame and publish a result if the window is ready.

        Args:
            frame: ``uint8`` image with three axes.
            metadata: Per-frame metadata; ``"frame_count"`` and
                ``"timestamp_ns"`` are read when present.

        Returns:
            The published result, or ``None`` when nothing was detected, no
            silhouette could be built, or the window is not yet due.

        Raises:
            RuntimeError: If the detector has no callable ``track``.
        """
        frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
        timestamp_ns = self._extract_timestamp(metadata)

        track_fn = self._resolve_track_fn()
        detections = track_fn(
            frame,
            persist=True,
            verbose=False,
            device=self._device,
            # NOTE(review): class 0 is presumably "person" for the configured
            # YOLO weights -- confirm against the checkpoint's class list.
            classes=[0],
        )
        first = self._first_result(detections)
        if first is None:
            return None

        selected = self._select_silhouette(first)
        if selected is None:
            return None

        silhouette, track_id = selected
        self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)

        if not self._window.should_classify():
            return None

        window_tensor = self._window.get_tensor(device=self._device)
        label, confidence = cast(
            tuple[str, float],
            self._classifier.predict(window_tensor),
        )
        self._window.mark_classified()

        # Inclusive frame range covered by the window, clamped at frame 0.
        window_start = frame_idx - self._window.window_size + 1
        result = create_result(
            frame=frame_idx,
            track_id=track_id,
            label=label,
            confidence=float(confidence),
            window=(max(0, window_start), frame_idx),
            timestamp_ns=timestamp_ns,
        )
        self._publisher.publish(result)
        return result

    def run(self) -> int:
        """Consume the frame source until it is exhausted.

        Per-frame processing errors are logged and the frame is skipped.
        Frames whose metadata carries no integer ``"frame_count"`` are
        numbered by their position in the stream, so that window ranges stay
        meaningful.  The pipeline is always closed on exit.

        Returns:
            ``0`` when the source ends normally, ``130`` on ``KeyboardInterrupt``.

        Raises:
            RuntimeError: If the detector has no callable ``track``.  This is
                checked once up front; otherwise the same error would be
                swallowed for every single frame and the run would "succeed"
                without processing anything.
        """
        frame_count = 0
        start_time = time.perf_counter()
        try:
            _ = self._resolve_track_fn()
            for frame, metadata in self._source:
                frame_u8 = np.asarray(frame, dtype=np.uint8)
                # Positional fallback (0-based) instead of a constant 0.
                frame_idx = self._extract_int(
                    metadata, "frame_count", fallback=frame_count
                )
                frame_count += 1
                # Shallow copy: never mutate the source's own metadata dict.
                frame_meta = {**metadata, "frame_count": frame_idx}
                try:
                    _ = self.process_frame(frame_u8, frame_meta)
                except Exception as frame_error:
                    logger.warning(
                        "Skipping frame %d due to processing error: %s",
                        frame_idx,
                        frame_error,
                        # Keep warnings to one line unless debugging.
                        exc_info=logger.isEnabledFor(logging.DEBUG),
                    )
                if frame_count % 100 == 0:
                    elapsed = time.perf_counter() - start_time
                    fps = frame_count / elapsed if elapsed > 0 else 0.0
                    logger.info("Processed %d frames (%.2f FPS)", frame_count, fps)
            return 0
        except KeyboardInterrupt:
            logger.info("Interrupted by user, shutting down cleanly.")
            return 130
        finally:
            self.close()

    def close(self) -> None:
        """Release the publisher and the frame source; safe to call repeatedly.

        Closing is best-effort: a resource without a callable ``close`` is
        skipped, and a failing ``close`` is logged rather than raised.
        """
        if self._closed:
            return
        # Mark first so a failing close() is never retried.
        self._closed = True
        for attr_name in ("_publisher", "_source"):
            close_fn = getattr(getattr(self, attr_name, None), "close", None)
            if not callable(close_fn):
                continue
            try:
                _ = close_fn()
            except Exception as close_error:
                logger.warning(
                    "Ignoring error while closing %s: %s",
                    attr_name.lstrip("_"),
                    close_error,
                )
|
||||
|
||||
|
||||
def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
    """Fail fast when the source, checkpoint or config cannot be used.

    Args:
        source: ``cvmmap://`` URI, a decimal number, or a path to a video file.
        checkpoint: Path that must point at an existing file.
        config: Path that must point at an existing file.

    Raises:
        ValueError: If ``source`` is a path that is not a file, or if
            ``checkpoint`` / ``config`` is not an existing file.
    """
    # ``cvmmap://`` URIs and numeric sources are accepted without touching the
    # filesystem.  ``isdecimal`` is used instead of ``isdigit`` because the
    # latter also accepts characters such as superscript digits that ``int()``
    # cannot parse.
    is_stream_uri = source.startswith("cvmmap://")
    is_numeric = source.isdecimal()
    if not (is_stream_uri or is_numeric) and not Path(source).is_file():
        raise ValueError(f"Video source not found: {source}")

    # Checkpoint is checked before config, matching the order of the arguments.
    for label, location in (("Checkpoint", checkpoint), ("Config", config)):
        if not Path(location).is_file():
            raise ValueError(f"{label} not found: {location}")
|
||||
|
||||
|
||||
# CLI entry point: validates the inputs, builds the pipeline and exits with
# the pipeline's return code.  Exit code 2 signals invalid inputs
# (ValueError), exit code 1 a runtime failure (RuntimeError).
# Deliberately documented with comments only: a docstring on a click command
# would become part of the ``--help`` output.
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option("--source", type=str, required=True)
@click.option("--checkpoint", type=str, required=True)
@click.option(
    "--config",
    type=str,
    default="configs/sconet/sconet_scoliosis1k.yaml",
    show_default=True,
)
@click.option("--device", type=str, default="cuda:0", show_default=True)
@click.option("--yolo-model", type=str, default="yolo11n-seg.pt", show_default=True)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--nats-url", type=str, default=None)
@click.option(
    "--nats-subject",
    type=str,
    default="scoliosis.result",
    show_default=True,
)
@click.option("--max-frames", type=click.IntRange(min=1), default=None)
def main(
    source: str,
    checkpoint: str,
    config: str,
    device: str,
    yolo_model: str,
    window: int,
    stride: int,
    nats_url: str | None,
    nats_subject: str,
    max_frames: int | None,
) -> None:
    log_format = "%(asctime)s %(levelname)s %(name)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)

    try:
        validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
        pipeline = ScoliosisPipeline(
            source=source,
            checkpoint=checkpoint,
            config=config,
            device=device,
            yolo_model=yolo_model,
            window=window,
            stride=stride,
            nats_url=nats_url,
            nats_subject=nats_subject,
            max_frames=max_frames,
        )
        # run() stays inside the try so its RuntimeError is reported below.
        exit_code = pipeline.run()
        raise SystemExit(exit_code)
    except ValueError as invalid_input:
        click.echo(f"Error: {invalid_input}", err=True)
        raise SystemExit(2) from invalid_input
    except RuntimeError as runtime_failure:
        click.echo(f"Runtime error: {runtime_failure}", err=True)
        raise SystemExit(1) from runtime_failure
|
||||
Reference in New Issue
Block a user