b24644f16e
Add the full demo runtime stack for single-person scoliosis inference, including input adapters, silhouette preprocessing, temporal windowing, ScoNet wrapper, result publishing, and click-based CLI orchestration. This commit captures the executable pipeline behavior independently from tests and planning artifacts for clearer review and rollback.
326 lines
10 KiB
Python
326 lines
10 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from contextlib import suppress
|
|
import logging
|
|
from pathlib import Path
|
|
import time
|
|
from typing import Protocol, cast
|
|
|
|
from beartype import beartype
|
|
import click
|
|
import jaxtyping
|
|
from jaxtyping import Float, UInt8
|
|
import numpy as np
|
|
from numpy import ndarray
|
|
from numpy.typing import NDArray
|
|
from ultralytics.models.yolo.model import YOLO
|
|
|
|
from .input import FrameStream, create_source
|
|
from .output import ResultPublisher, create_publisher, create_result
|
|
from .preprocess import frame_to_person_mask, mask_to_silhouette
|
|
from .sconet_demo import ScoNetDemo
|
|
from .window import SilhouetteWindow, select_person
|
|
|
|
# Module-scoped logger; its handlers and level are set up by the CLI entry
# point through ``logging.basicConfig``.
logger = logging.getLogger(__name__)

# ``jaxtyping.jaxtyped`` is re-bound under a deliberately loose callable type:
# "any arguments in, a function decorator out". Presumably this keeps static
# type checkers quiet about ``@jaxtyped(typechecker=...)`` -- confirm.
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
|
|
|
|
|
|
class _BoxesLike(Protocol):
    """Structural type for the ``boxes`` attribute of a detection result.

    Only the two read-only attributes declared below belong to the contract.
    """

    @property
    def xyxy(self) -> NDArray[np.float32] | object:
        """Box coordinates.

        Presumably corner-format ``(x1, y1, x2, y2)`` rows, as the attribute
        name suggests -- confirm against the detector library.
        """
        ...

    @property
    def id(self) -> NDArray[np.int64] | object | None:
        """Per-box identifiers, or ``None`` when they are not available.

        Presumably tracker-assigned IDs -- confirm against the detector
        library.
        """
        ...
|
|
|
|
|
|
class _MasksLike(Protocol):
    """Structural type for the ``masks`` attribute of a detection result."""

    @property
    def data(self) -> NDArray[np.float32] | object:
        """Raw mask data, either a float32 array or an opaque array-like.

        The exact shape is not established here -- confirm against the
        detector library before relying on it.
        """
        ...
|
|
|
|
|
|
class _DetectionResultsLike(Protocol):
    """Structural type for one per-frame detection result.

    The pipeline hands objects of this shape to ``select_person`` and
    ``frame_to_person_mask``.
    """

    @property
    def boxes(self) -> _BoxesLike:
        """Detected boxes for the frame."""
        ...

    @property
    def masks(self) -> _MasksLike:
        """Segmentation masks for the frame.

        NOTE(review): detector libraries may expose ``None`` here when no
        mask was produced -- confirm whether this should be optional.
        """
        ...
|
|
|
|
|
|
class _TrackCallable(Protocol):
    """Call signature the pipeline expects from the detector's ``track``.

    ``source`` is positional; every other option is keyword-only. The return
    value is left as ``object`` and narrowed by the caller.
    """

    def __call__(
        self,
        source: object,
        *,
        persist: bool = True,
        verbose: bool = False,
        device: str | None = None,
        classes: list[int] | None = None,
    ) -> object:
        """Run tracking on ``source`` and return the raw detector output."""
        ...
|
|
|
|
|
|
class ScoliosisPipeline:
    """Frame-by-frame pipeline: detect, extract a silhouette, window, classify, publish.

    For each frame the detector's ``track()`` is called, one silhouette is
    selected and pushed into the silhouette window, and once the window
    reports that it should classify, the classifier predicts a label and the
    result is handed to the publisher.
    """

    _detector: object
    _source: FrameStream
    _window: SilhouetteWindow
    _publisher: ResultPublisher
    _classifier: ScoNetDemo
    _device: str
    _closed: bool

    def __init__(
        self,
        *,
        source: str,
        checkpoint: str,
        config: str,
        device: str,
        yolo_model: str,
        window: int,
        stride: int,
        nats_url: str | None,
        nats_subject: str,
        max_frames: int | None,
    ) -> None:
        """Construct every pipeline stage.

        Args:
            source: Frame source specifier, forwarded to ``create_source``.
            checkpoint: Classifier checkpoint path, forwarded to ``ScoNetDemo``.
            config: Classifier config path, forwarded to ``ScoNetDemo``.
            device: Device string used for the classifier, for the detector's
                ``track()`` call and when materialising the window tensor.
            yolo_model: Model identifier passed to ``YOLO``.
            window: Window size for ``SilhouetteWindow``.
            stride: Stride for ``SilhouetteWindow``.
            nats_url: Publisher URL (or ``None``), forwarded to
                ``create_publisher``.
            nats_subject: Publisher subject, forwarded to ``create_publisher``.
            max_frames: Optional frame limit, forwarded to ``create_source``.
        """
        self._detector = YOLO(yolo_model)
        self._source = create_source(source, max_frames=max_frames)
        self._window = SilhouetteWindow(window_size=window, stride=stride)
        self._publisher = create_publisher(nats_url=nats_url, subject=nats_subject)
        self._classifier = ScoNetDemo(
            cfg_path=config,
            checkpoint_path=checkpoint,
            device=device,
        )
        self._device = device
        self._closed = False

    @staticmethod
    def _coerce_int(value: object) -> int | None:
        """Return ``value`` as a plain ``int``, or ``None`` if it is not an integer.

        NumPy integer scalars are accepted because they are not instances of
        the built-in ``int``. ``bool`` is rejected even though it subclasses
        ``int``, since a flag is never a meaningful counter or timestamp.
        """
        if isinstance(value, bool):
            return None
        if isinstance(value, (int, np.integer)):
            return int(value)
        return None

    @staticmethod
    def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
        """Return the integer stored at ``meta[key]``, or ``fallback``.

        ``fallback`` is used when the key is missing or its value is not an
        integer (built-in or NumPy).
        """
        value = ScoliosisPipeline._coerce_int(meta.get(key))
        return fallback if value is None else value

    @staticmethod
    def _extract_timestamp(meta: dict[str, object]) -> int:
        """Return ``meta["timestamp_ns"]`` or the current monotonic clock in ns.

        The clock is read only when no usable integer timestamp is present.
        """
        value = ScoliosisPipeline._coerce_int(meta.get("timestamp_ns"))
        return time.monotonic_ns() if value is None else value

    @staticmethod
    def _to_mask_u8(mask: ndarray) -> UInt8[ndarray, "h w"]:
        """Binarise ``mask``: values above 0.5 become 255, everything else 0."""
        foreground = np.asarray(mask) > 0.5
        # uint8 * uint8 stays uint8; 1 * 255 cannot overflow.
        binary = foreground.astype(np.uint8) * np.uint8(255)
        return cast(UInt8[ndarray, "h w"], binary)

    def _first_result(self, detections: object) -> _DetectionResultsLike | None:
        """Return the first detection result, or ``None`` for an empty sequence.

        A list or tuple yields its first element; any other object is assumed
        to already be a single result and is returned as-is.
        """
        if isinstance(detections, (list, tuple)):
            return cast(_DetectionResultsLike, detections[0]) if detections else None
        return cast(_DetectionResultsLike, detections)

    def _select_silhouette(
        self,
        result: _DetectionResultsLike,
    ) -> tuple[Float[ndarray, "64 44"], int] | None:
        """Produce ``(silhouette, track_id)`` for one detection result.

        The person chosen by ``select_person`` is tried first. If that yields
        nothing, ``frame_to_person_mask`` is used as a fallback and the track
        id is reported as ``0``. Returns ``None`` when neither path yields a
        silhouette.
        """
        selected = select_person(result)
        if selected is not None:
            mask_raw, bbox, track_id = selected
            silhouette = cast(
                Float[ndarray, "64 44"] | None,
                mask_to_silhouette(self._to_mask_u8(mask_raw), bbox),
            )
            if silhouette is not None:
                return silhouette, int(track_id)

        # Either no person was selected or its mask gave no silhouette.
        fallback = cast(
            tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None,
            frame_to_person_mask(result),
        )
        if fallback is None:
            return None

        mask_u8, bbox = fallback
        silhouette = cast(
            Float[ndarray, "64 44"] | None,
            mask_to_silhouette(mask_u8, bbox),
        )
        if silhouette is None:
            return None
        return silhouette, 0

    @jaxtyped(typechecker=beartype)
    def process_frame(
        self,
        frame: UInt8[ndarray, "h w c"],
        metadata: dict[str, object],
    ) -> dict[str, object] | None:
        """Process one frame and publish a result when the window is ready.

        Args:
            frame: Image array, type-checked at runtime as ``UInt8["h w c"]``.
            metadata: Per-frame metadata; ``frame_count`` and ``timestamp_ns``
                are read when present.

        Returns:
            The object built by ``create_result`` (after it has been passed
            to the publisher), or ``None`` when there is no detection result,
            no usable silhouette, or the window is not ready to classify.

        Raises:
            RuntimeError: If the detector has no callable ``track`` attribute.
        """
        frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
        timestamp_ns = self._extract_timestamp(metadata)

        track_fn_obj = getattr(self._detector, "track", None)
        if not callable(track_fn_obj):
            raise RuntimeError("YOLO detector does not expose a callable track()")

        track_fn = cast(_TrackCallable, track_fn_obj)
        detections = track_fn(
            frame,
            persist=True,
            verbose=False,
            device=self._device,
            # Presumably class 0 is "person" for the loaded model -- confirm.
            classes=[0],
        )
        first = self._first_result(detections)
        if first is None:
            return None

        selected = self._select_silhouette(first)
        if selected is None:
            return None

        silhouette, track_id = selected
        self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)

        if not self._window.should_classify():
            return None

        window_tensor = self._window.get_tensor(device=self._device)
        label, confidence = cast(
            tuple[str, float],
            self._classifier.predict(window_tensor),
        )
        self._window.mark_classified()

        # Inclusive frame range covered by this classification, clamped at 0.
        window_start = frame_idx - self._window.window_size + 1
        result = create_result(
            frame=frame_idx,
            track_id=track_id,
            label=label,
            confidence=float(confidence),
            window=(max(0, window_start), frame_idx),
            timestamp_ns=timestamp_ns,
        )
        self._publisher.publish(result)
        return result

    def run(self) -> int:
        """Consume the frame source until it is exhausted or interrupted.

        A failure while processing a single frame is logged and that frame is
        skipped. Throughput is logged every 100 frames. ``close()`` is always
        called on the way out.

        Returns:
            ``0`` when the source is exhausted, ``130`` on ``KeyboardInterrupt``.
        """
        frame_count = 0
        start_time = time.perf_counter()
        try:
            for frame, metadata in self._source:
                frame_u8 = np.asarray(frame, dtype=np.uint8)
                frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
                frame_count += 1
                try:
                    _ = self.process_frame(frame_u8, metadata)
                except Exception as frame_error:
                    # Best-effort streaming: one bad frame must not stop the run.
                    logger.warning(
                        "Skipping frame %d due to processing error: %s",
                        frame_idx,
                        frame_error,
                    )
                if frame_count % 100 == 0:
                    elapsed = time.perf_counter() - start_time
                    fps = frame_count / elapsed if elapsed > 0 else 0.0
                    logger.info("Processed %d frames (%.2f FPS)", frame_count, fps)
            return 0
        except KeyboardInterrupt:
            logger.info("Interrupted by user, shutting down cleanly.")
            return 130
        finally:
            self.close()

    def close(self) -> None:
        """Close the publisher once; later calls are no-ops.

        Closing is best-effort: an error raised by the publisher's ``close``
        is logged and otherwise ignored so that shutdown always completes.
        """
        if self._closed:
            return
        close_fn = getattr(self._publisher, "close", None)
        if callable(close_fn):
            try:
                _ = close_fn()
            except Exception as close_error:
                logger.warning(
                    "Ignoring error while closing publisher: %s", close_error
                )
        self._closed = True
|
|
|
|
|
|
def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
    """Fail fast when a required input file is missing.

    A ``source`` that starts with ``cvmmap://`` or consists only of digits is
    accepted without touching the filesystem (presumably a shared-memory
    stream or a camera index -- confirm against ``create_source``). Any other
    ``source`` must name an existing regular file. ``checkpoint`` and
    ``config`` must always be existing regular files.

    Raises:
        ValueError: If a required file does not exist. The source is checked
            first, then the checkpoint, then the config.
    """
    skips_file_check = source.isdigit() or source.startswith("cvmmap://")
    if not skips_file_check and not Path(source).is_file():
        raise ValueError(f"Video source not found: {source}")

    if not Path(checkpoint).is_file():
        raise ValueError(f"Checkpoint not found: {checkpoint}")

    if not Path(config).is_file():
        raise ValueError(f"Config not found: {config}")
|
|
|
|
|
|
# CLI entry point. Documented with comments rather than a docstring on
# purpose: click would surface a docstring as the command's --help text.
#
# Exit status: the value returned by ``ScoliosisPipeline.run()`` on a normal
# run, 2 when a ``ValueError`` is raised (e.g. by input validation), and 1
# when a ``RuntimeError`` is raised.
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option("--source", type=str, required=True)
@click.option("--checkpoint", type=str, required=True)
@click.option(
    "--config",
    type=str,
    default="configs/sconet/sconet_scoliosis1k.yaml",
    show_default=True,
)
@click.option("--device", type=str, default="cuda:0", show_default=True)
@click.option("--yolo-model", type=str, default="yolo11n-seg.pt", show_default=True)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--nats-url", type=str, default=None)
@click.option(
    "--nats-subject",
    type=str,
    default="scoliosis.result",
    show_default=True,
)
@click.option("--max-frames", type=click.IntRange(min=1), default=None)
def main(
    source: str,
    checkpoint: str,
    config: str,
    device: str,
    yolo_model: str,
    window: int,
    stride: int,
    nats_url: str | None,
    nats_subject: str,
    max_frames: int | None,
) -> None:
    logging.basicConfig(
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
        level=logging.INFO,
    )

    try:
        # Check the input paths before any model is loaded.
        validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
        pipeline = ScoliosisPipeline(
            source=source,
            checkpoint=checkpoint,
            config=config,
            device=device,
            yolo_model=yolo_model,
            window=window,
            stride=stride,
            nats_url=nats_url,
            nats_subject=nats_subject,
            max_frames=max_frames,
        )
        exit_code = pipeline.run()
    except ValueError as err:
        click.echo(f"Error: {err}", err=True)
        raise SystemExit(2) from err
    except RuntimeError as err:
        click.echo(f"Runtime error: {err}", err=True)
        raise SystemExit(1) from err

    # Neither handler above catches SystemExit, so raising it here is
    # equivalent to raising it inside the try block.
    raise SystemExit(exit_code)
|