from __future__ import annotations from collections.abc import Callable from contextlib import suppress import logging from pathlib import Path import time from typing import Protocol, cast from beartype import beartype import click import jaxtyping from jaxtyping import Float, UInt8 import numpy as np from numpy import ndarray from numpy.typing import NDArray from ultralytics.models.yolo.model import YOLO from .input import FrameStream, create_source from .output import ResultPublisher, create_publisher, create_result from .preprocess import frame_to_person_mask, mask_to_silhouette from .sconet_demo import ScoNetDemo from .window import SilhouetteWindow, select_person logger = logging.getLogger(__name__) JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]] JaxtypedFactory = Callable[..., JaxtypedDecorator] jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped) class _BoxesLike(Protocol): @property def xyxy(self) -> NDArray[np.float32] | object: ... @property def id(self) -> NDArray[np.int64] | object | None: ... class _MasksLike(Protocol): @property def data(self) -> NDArray[np.float32] | object: ... class _DetectionResultsLike(Protocol): @property def boxes(self) -> _BoxesLike: ... @property def masks(self) -> _MasksLike: ... class _TrackCallable(Protocol): def __call__( self, source: object, *, persist: bool = True, verbose: bool = False, device: str | None = None, classes: list[int] | None = None, ) -> object: ... 
class ScoliosisPipeline:
    """Per-frame pipeline: track, build silhouette window, classify, publish.

    Construction wires together a YOLO detector, a frame source, a
    ``SilhouetteWindow``, a result publisher and a ``ScoNetDemo`` classifier.
    ``run()`` drives the whole stream; ``process_frame()`` handles one frame.
    """

    _detector: object
    _source: FrameStream
    _window: SilhouetteWindow
    _publisher: ResultPublisher
    _classifier: ScoNetDemo
    _device: str
    _closed: bool

    def __init__(
        self,
        *,
        source: str,
        checkpoint: str,
        config: str,
        device: str,
        yolo_model: str,
        window: int,
        stride: int,
        nats_url: str | None,
        nats_subject: str,
        max_frames: int | None,
    ) -> None:
        """Build every pipeline component from plain configuration values.

        Args:
            source: Source specifier forwarded to ``create_source``.
            checkpoint: Checkpoint path forwarded to ``ScoNetDemo``.
            config: Config path forwarded to ``ScoNetDemo``.
            device: Device string used for the classifier, the detector's
                ``track`` call and the window tensor.
            yolo_model: Model name/path passed to ``YOLO``.
            window: ``window_size`` of the ``SilhouetteWindow``.
            stride: ``stride`` of the ``SilhouetteWindow``.
            nats_url: URL forwarded to ``create_publisher`` (may be ``None``).
            nats_subject: Subject forwarded to ``create_publisher``.
            max_frames: Optional frame limit forwarded to ``create_source``.
        """
        self._detector = YOLO(yolo_model)
        self._source = create_source(source, max_frames=max_frames)
        self._window = SilhouetteWindow(window_size=window, stride=stride)
        self._publisher = create_publisher(nats_url=nats_url, subject=nats_subject)
        self._classifier = ScoNetDemo(
            cfg_path=config,
            checkpoint_path=checkpoint,
            device=device,
        )
        self._device = device
        self._closed = False

    @staticmethod
    def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
        """Return ``meta[key]`` as a plain ``int``, or ``fallback``.

        Python ints and NumPy integer scalars are accepted and normalised to
        ``int``. ``bool`` is rejected even though it subclasses ``int``,
        because a flag is never a meaningful counter or timestamp.
        """
        value = meta.get(key)
        if isinstance(value, bool):
            return fallback
        if isinstance(value, (int, np.integer)):
            # int() turns NumPy scalars into a builtin int for downstream use.
            return int(value)
        return fallback

    @staticmethod
    def _extract_timestamp(meta: dict[str, object]) -> int:
        """Return ``meta["timestamp_ns"]`` or the current monotonic clock in ns."""
        return ScoliosisPipeline._extract_int(
            meta,
            "timestamp_ns",
            fallback=time.monotonic_ns(),
        )

    @staticmethod
    def _to_mask_u8(mask: ndarray) -> UInt8[ndarray, "h w"]:
        """Binarise ``mask`` at 0.5 into a uint8 array holding only 0 and 255."""
        binary = np.where(np.asarray(mask) > 0.5, np.uint8(255), np.uint8(0)).astype(
            np.uint8
        )
        return cast(UInt8[ndarray, "h w"], binary)

    def _first_result(self, detections: object) -> _DetectionResultsLike | None:
        """Return the first element of a list/tuple of detections.

        An empty list/tuple yields ``None``; any other object is returned
        as-is (cast to the result protocol).
        """
        if isinstance(detections, (list, tuple)):
            return cast(_DetectionResultsLike, detections[0]) if detections else None
        return cast(_DetectionResultsLike, detections)

    def _select_silhouette(
        self,
        result: _DetectionResultsLike,
    ) -> tuple[Float[ndarray, "64 44"], int] | None:
        """Produce ``(silhouette, track_id)`` for one detection result.

        First tries ``select_person``; if that yields nothing usable, falls
        back to ``frame_to_person_mask`` with a track id of 0.

        Returns:
            The silhouette and track id, or ``None`` when neither path
            produces a silhouette.
        """
        selected = select_person(result)
        if selected is not None:
            mask_raw, bbox, track_id = selected
            silhouette = cast(
                Float[ndarray, "64 44"] | None,
                mask_to_silhouette(self._to_mask_u8(mask_raw), bbox),
            )
            if silhouette is not None:
                return silhouette, int(track_id)

        # Tracked selection failed or gave no silhouette: use the fallback mask.
        fallback = cast(
            tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None,
            frame_to_person_mask(result),
        )
        if fallback is None:
            return None

        mask_u8, bbox = fallback
        silhouette = cast(
            Float[ndarray, "64 44"] | None,
            mask_to_silhouette(mask_u8, bbox),
        )
        if silhouette is None:
            return None
        return silhouette, 0

    @jaxtyped(typechecker=beartype)
    def process_frame(
        self,
        frame: UInt8[ndarray, "h w c"],
        metadata: dict[str, object],
    ) -> dict[str, object] | None:
        """Run one frame through tracking, windowing and classification.

        Args:
            frame: The image to process.
            metadata: Per-frame metadata; ``"frame_count"`` and
                ``"timestamp_ns"`` are read when present as integers.

        Returns:
            The published result, or ``None`` when no detection or silhouette
            was obtained or the window is not ready to classify.

        Raises:
            RuntimeError: If the detector has no callable ``track`` attribute.
        """
        frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
        timestamp_ns = self._extract_timestamp(metadata)

        track_fn_obj = getattr(self._detector, "track", None)
        if not callable(track_fn_obj):
            raise RuntimeError("YOLO detector does not expose a callable track()")
        track_fn = cast(_TrackCallable, track_fn_obj)

        detections = track_fn(
            frame,
            persist=True,
            verbose=False,
            device=self._device,
            classes=[0],  # NOTE(review): presumably the "person" class — confirm
        )

        first = self._first_result(detections)
        if first is None:
            return None

        selected = self._select_silhouette(first)
        if selected is None:
            return None

        silhouette, track_id = selected
        self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
        if not self._window.should_classify():
            return None

        window_tensor = self._window.get_tensor(device=self._device)
        label, confidence = cast(
            tuple[str, float],
            self._classifier.predict(window_tensor),
        )
        self._window.mark_classified()

        # Reported window is [frame_idx - size + 1, frame_idx], clamped at 0.
        window_start = frame_idx - self._window.window_size + 1
        result = create_result(
            frame=frame_idx,
            track_id=track_id,
            label=label,
            confidence=float(confidence),
            window=(max(0, window_start), frame_idx),
            timestamp_ns=timestamp_ns,
        )
        self._publisher.publish(result)
        return result

    def run(self) -> int:
        """Process every item of the source and return a process exit code.

        A failure inside ``process_frame`` is logged and the frame skipped.
        Progress is logged every 100 frames. ``close()`` is always called.

        Returns:
            0 when the source is exhausted, 130 on ``KeyboardInterrupt``.
        """
        frame_count = 0
        start_time = time.perf_counter()
        try:
            for item in self._source:
                frame, metadata = item
                frame_u8 = np.asarray(frame, dtype=np.uint8)
                frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
                frame_count += 1
                try:
                    _ = self.process_frame(frame_u8, metadata)
                except Exception as frame_error:
                    # Deliberate best-effort: one bad frame must not stop the stream.
                    logger.warning(
                        "Skipping frame %d due to processing error: %s",
                        frame_idx,
                        frame_error,
                    )
                if frame_count % 100 == 0:
                    elapsed = time.perf_counter() - start_time
                    fps = frame_count / elapsed if elapsed > 0 else 0.0
                    logger.info("Processed %d frames (%.2f FPS)", frame_count, fps)
            return 0
        except KeyboardInterrupt:
            logger.info("Interrupted by user, shutting down cleanly.")
            return 130
        finally:
            self.close()

    @staticmethod
    def _close_quietly(resource: object, label: str) -> None:
        """Call ``resource.close()`` if it exists, logging instead of raising."""
        close_fn = getattr(resource, "close", None)
        if not callable(close_fn):
            return
        try:
            _ = close_fn()
        except Exception as close_error:
            logger.warning("Ignoring error while closing %s: %s", label, close_error)

    def close(self) -> None:
        """Release the publisher and the frame source; safe to call repeatedly.

        Both resources are closed on a best-effort basis: a missing ``close``
        attribute is ignored and a failing one is logged, not raised.
        """
        if self._closed:
            return
        # Flag first so a failure below can never lead to a second close attempt.
        self._closed = True
        self._close_quietly(self._publisher, "publisher")
        # The source is created in __init__ and must be released here as well;
        # a generator-backed source runs its cleanup when closed.
        self._close_quietly(self._source, "source")


def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
    """Check that the file-based inputs exist before building the pipeline.

    ``source`` is exempt from the file check when it starts with
    ``"cvmmap://"`` or consists only of digits.
    NOTE(review): presumably a stream URL or a camera index — confirm against
    ``create_source``.

    Raises:
        ValueError: If the source file, the checkpoint or the config is not
            an existing file.
    """
    skips_file_check = source.startswith("cvmmap://") or source.isdigit()
    if not skips_file_check and not Path(source).is_file():
        raise ValueError(f"Video source not found: {source}")

    if not Path(checkpoint).is_file():
        raise ValueError(f"Checkpoint not found: {checkpoint}")

    if not Path(config).is_file():
        raise ValueError(f"Config not found: {config}")


@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option("--source", type=str, required=True)
@click.option("--checkpoint", type=str, required=True)
@click.option(
    "--config",
    type=str,
    default="configs/sconet/sconet_scoliosis1k.yaml",
    show_default=True,
)
@click.option("--device", type=str, default="cuda:0", show_default=True)
@click.option("--yolo-model", type=str, default="yolo11n-seg.pt", show_default=True)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--nats-url", type=str, default=None)
@click.option(
    "--nats-subject",
    type=str,
    default="scoliosis.result",
    show_default=True,
)
@click.option("--max-frames", type=click.IntRange(min=1), default=None)
def main(
    source: str,
    checkpoint: str,
    config: str,
    device: str,
    yolo_model: str,
    window: int,
    stride: int,
    nats_url: str | None,
    nats_subject: str,
    max_frames: int | None,
) -> None:
    """Command-line entry point: validate inputs, build the pipeline, run it.

    Exits with the code returned by ``ScoliosisPipeline.run()``; a
    ``ValueError`` exits with 2 and a ``RuntimeError`` with 1, each after
    printing the message to stderr.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    try:
        validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
        pipeline = ScoliosisPipeline(
            source=source,
            checkpoint=checkpoint,
            config=config,
            device=device,
            yolo_model=yolo_model,
            window=window,
            stride=stride,
            nats_url=nats_url,
            nats_subject=nats_subject,
            max_frames=max_frames,
        )
        # SystemExit is not caught by the handlers below, so it propagates.
        raise SystemExit(pipeline.run())
    except ValueError as err:
        click.echo(f"Error: {err}", err=True)
        raise SystemExit(2) from err
    except RuntimeError as err:
        click.echo(f"Runtime error: {err}", err=True)
        raise SystemExit(1) from err