feat(demo): implement ScoNet real-time pipeline runtime

Add the full demo runtime stack for single-person scoliosis inference, including input adapters, silhouette preprocessing, temporal windowing, ScoNet wrapper, result publishing, and click-based CLI orchestration. This commit captures the executable pipeline behavior independently from tests and planning artifacts for clearer review and rollback.
This commit is contained in:
2026-02-27 09:59:04 +08:00
parent cd754ffcfb
commit b24644f16e
8 changed files with 1785 additions and 0 deletions
+325
View File
@@ -0,0 +1,325 @@
from __future__ import annotations
from collections.abc import Callable
from contextlib import suppress
import logging
from pathlib import Path
import time
from typing import Protocol, cast
from beartype import beartype
import click
import jaxtyping
from jaxtyping import Float, UInt8
import numpy as np
from numpy import ndarray
from numpy.typing import NDArray
from ultralytics.models.yolo.model import YOLO
from .input import FrameStream, create_source
from .output import ResultPublisher, create_publisher, create_result
from .preprocess import frame_to_person_mask, mask_to_silhouette
from .sconet_demo import ScoNetDemo
from .window import SilhouetteWindow, select_person
logger = logging.getLogger(__name__)
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
class _BoxesLike(Protocol):
@property
def xyxy(self) -> NDArray[np.float32] | object: ...
@property
def id(self) -> NDArray[np.int64] | object | None: ...
class _MasksLike(Protocol):
@property
def data(self) -> NDArray[np.float32] | object: ...
class _DetectionResultsLike(Protocol):
@property
def boxes(self) -> _BoxesLike: ...
@property
def masks(self) -> _MasksLike: ...
class _TrackCallable(Protocol):
def __call__(
self,
source: object,
*,
persist: bool = True,
verbose: bool = False,
device: str | None = None,
classes: list[int] | None = None,
) -> object: ...
class ScoliosisPipeline:
_detector: object
_source: FrameStream
_window: SilhouetteWindow
_publisher: ResultPublisher
_classifier: ScoNetDemo
_device: str
_closed: bool
def __init__(
self,
*,
source: str,
checkpoint: str,
config: str,
device: str,
yolo_model: str,
window: int,
stride: int,
nats_url: str | None,
nats_subject: str,
max_frames: int | None,
) -> None:
self._detector = YOLO(yolo_model)
self._source = create_source(source, max_frames=max_frames)
self._window = SilhouetteWindow(window_size=window, stride=stride)
self._publisher = create_publisher(nats_url=nats_url, subject=nats_subject)
self._classifier = ScoNetDemo(
cfg_path=config,
checkpoint_path=checkpoint,
device=device,
)
self._device = device
self._closed = False
@staticmethod
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
value = meta.get(key)
if isinstance(value, int):
return value
return fallback
@staticmethod
def _extract_timestamp(meta: dict[str, object]) -> int:
value = meta.get("timestamp_ns")
if isinstance(value, int):
return value
return time.monotonic_ns()
@staticmethod
def _to_mask_u8(mask: ndarray) -> UInt8[ndarray, "h w"]:
binary = np.where(np.asarray(mask) > 0.5, np.uint8(255), np.uint8(0)).astype(
np.uint8
)
return cast(UInt8[ndarray, "h w"], binary)
def _first_result(self, detections: object) -> _DetectionResultsLike | None:
if isinstance(detections, list):
return cast(_DetectionResultsLike, detections[0]) if detections else None
if isinstance(detections, tuple):
return cast(_DetectionResultsLike, detections[0]) if detections else None
return cast(_DetectionResultsLike, detections)
def _select_silhouette(
self,
result: _DetectionResultsLike,
) -> tuple[Float[ndarray, "64 44"], int] | None:
selected = select_person(result)
if selected is not None:
mask_raw, bbox, track_id = selected
silhouette = cast(
Float[ndarray, "64 44"] | None,
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox),
)
if silhouette is not None:
return silhouette, int(track_id)
fallback = cast(
tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None,
frame_to_person_mask(result),
)
if fallback is None:
return None
mask_u8, bbox = fallback
silhouette = cast(
Float[ndarray, "64 44"] | None,
mask_to_silhouette(mask_u8, bbox),
)
if silhouette is None:
return None
return silhouette, 0
@jaxtyped(typechecker=beartype)
def process_frame(
self,
frame: UInt8[ndarray, "h w c"],
metadata: dict[str, object],
) -> dict[str, object] | None:
frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
timestamp_ns = self._extract_timestamp(metadata)
track_fn_obj = getattr(self._detector, "track", None)
if not callable(track_fn_obj):
raise RuntimeError("YOLO detector does not expose a callable track()")
track_fn = cast(_TrackCallable, track_fn_obj)
detections = track_fn(
frame,
persist=True,
verbose=False,
device=self._device,
classes=[0],
)
first = self._first_result(detections)
if first is None:
return None
selected = self._select_silhouette(first)
if selected is None:
return None
silhouette, track_id = selected
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
if not self._window.should_classify():
return None
window_tensor = self._window.get_tensor(device=self._device)
label, confidence = cast(
tuple[str, float],
self._classifier.predict(window_tensor),
)
self._window.mark_classified()
window_start = frame_idx - self._window.window_size + 1
result = create_result(
frame=frame_idx,
track_id=track_id,
label=label,
confidence=float(confidence),
window=(max(0, window_start), frame_idx),
timestamp_ns=timestamp_ns,
)
self._publisher.publish(result)
return result
def run(self) -> int:
frame_count = 0
start_time = time.perf_counter()
try:
for item in self._source:
frame, metadata = item
frame_u8 = np.asarray(frame, dtype=np.uint8)
frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
frame_count += 1
try:
_ = self.process_frame(frame_u8, metadata)
except Exception as frame_error:
logger.warning(
"Skipping frame %d due to processing error: %s",
frame_idx,
frame_error,
)
if frame_count % 100 == 0:
elapsed = time.perf_counter() - start_time
fps = frame_count / elapsed if elapsed > 0 else 0.0
logger.info("Processed %d frames (%.2f FPS)", frame_count, fps)
return 0
except KeyboardInterrupt:
logger.info("Interrupted by user, shutting down cleanly.")
return 130
finally:
self.close()
def close(self) -> None:
if self._closed:
return
close_fn = getattr(self._publisher, "close", None)
if callable(close_fn):
with suppress(Exception):
_ = close_fn()
self._closed = True
def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
if source.startswith("cvmmap://") or source.isdigit():
pass
else:
source_path = Path(source)
if not source_path.is_file():
raise ValueError(f"Video source not found: {source}")
checkpoint_path = Path(checkpoint)
if not checkpoint_path.is_file():
raise ValueError(f"Checkpoint not found: {checkpoint}")
config_path = Path(config)
if not config_path.is_file():
raise ValueError(f"Config not found: {config}")
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option("--source", type=str, required=True)
@click.option("--checkpoint", type=str, required=True)
@click.option(
"--config",
type=str,
default="configs/sconet/sconet_scoliosis1k.yaml",
show_default=True,
)
@click.option("--device", type=str, default="cuda:0", show_default=True)
@click.option("--yolo-model", type=str, default="yolo11n-seg.pt", show_default=True)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--nats-url", type=str, default=None)
@click.option(
"--nats-subject",
type=str,
default="scoliosis.result",
show_default=True,
)
@click.option("--max-frames", type=click.IntRange(min=1), default=None)
def main(
source: str,
checkpoint: str,
config: str,
device: str,
yolo_model: str,
window: int,
stride: int,
nats_url: str | None,
nats_subject: str,
max_frames: int | None,
) -> None:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
try:
validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
pipeline = ScoliosisPipeline(
source=source,
checkpoint=checkpoint,
config=config,
device=device,
yolo_model=yolo_model,
window=window,
stride=stride,
nats_url=nats_url,
nats_subject=nats_subject,
max_frames=max_frames,
)
raise SystemExit(pipeline.run())
except ValueError as err:
click.echo(f"Error: {err}", err=True)
raise SystemExit(2) from err
except RuntimeError as err:
click.echo(f"Runtime error: {err}", err=True)
raise SystemExit(1) from err