"""Scoliosis screening demo pipeline.

Reads frames from a source, segments/tracks people with a YOLO segmentation
model, converts the selected person mask into a fixed-size silhouette, buffers
silhouettes in a ``SilhouetteWindow`` and classifies ready windows with ScoNet.
Results can be published, exported to disk, and visualized.
"""

from __future__ import annotations

from collections.abc import Callable
from contextlib import suppress
import logging
from pathlib import Path
import time
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, cast

from beartype import beartype
import click
import jaxtyping
from jaxtyping import Float, UInt8
import numpy as np
from numpy import ndarray
from numpy.typing import NDArray
from ultralytics.models.yolo.model import YOLO

from .input import FrameStream, create_source
from .output import DemoResult, ResultPublisher, create_publisher, create_result
from .preprocess import BBoxXYXY, frame_to_person_mask, mask_to_silhouette
from .sconet_demo import ScoNetDemo
from .window import SilhouetteWindow, select_person

if TYPE_CHECKING:
    from .visualizer import OpenCVVisualizer

logger = logging.getLogger(__name__)

JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
# Re-typed alias of jaxtyping.jaxtyped so static checkers see a concrete
# decorator factory type.
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)

WindowMode: TypeAlias = Literal["manual", "sliding", "chunked"]


def resolve_stride(window: int, stride: int, window_mode: WindowMode) -> int:
    """Return the effective window stride for ``window_mode``.

    ``manual`` keeps ``stride`` as given, ``sliding`` forces a stride of 1,
    and any other mode (``chunked``) uses a stride equal to ``window``.
    """
    if window_mode == "manual":
        return stride
    if window_mode == "sliding":
        return 1
    return window


class _BoxesLike(Protocol):
    """Structural type for the detector's ``boxes`` attribute."""

    @property
    def xyxy(self) -> NDArray[np.float32] | object: ...

    @property
    def id(self) -> NDArray[np.int64] | object | None: ...


class _MasksLike(Protocol):
    """Structural type for the detector's ``masks`` attribute."""

    @property
    def data(self) -> NDArray[np.float32] | object: ...


class _DetectionResultsLike(Protocol):
    """Structural type for one detector result (boxes + masks)."""

    @property
    def boxes(self) -> _BoxesLike: ...

    @property
    def masks(self) -> _MasksLike: ...


class _TrackCallable(Protocol):
    """Signature of the detector's ``track()`` method as used by this module."""

    def __call__(
        self,
        source: object,
        *,
        persist: bool = True,
        verbose: bool = False,
        device: str | None = None,
        classes: list[int] | None = None,
    ) -> object: ...


class _SelectedSilhouette(TypedDict):
    """Selected silhouette payload produced from detector outputs.

    Fields:
        silhouette: Normalized silhouette tensor fed into ScoNet `(64, 44)`.
        mask_raw: Full-resolution binary person mask in mask/image space.
        bbox_frame: Person bbox in frame coordinates `(x1, y1, x2, y2)` for
            visualization.
        bbox_mask: Person bbox in mask coordinates `(x1, y1, x2, y2)` for
            cropping.
        track_id: Tracking ID from detector, or `0` for fallback path.
    """

    silhouette: Float[ndarray, "64 44"]
    mask_raw: UInt8[ndarray, "h w"]
    bbox_frame: BBoxXYXY
    bbox_mask: BBoxXYXY
    track_id: int


class _FramePacer:
    """Rate-limit emission to at most ``target_fps`` using frame timestamps.

    The first timestamp always emits; afterwards a frame emits only once its
    timestamp reaches the next scheduled slot.
    """

    _interval_ns: int
    _next_emit_ns: int | None

    def __init__(self, target_fps: float) -> None:
        if target_fps <= 0:
            raise ValueError(f"target_fps must be positive, got {target_fps}")
        # Clamp to >= 1 ns: for target_fps > 1e9 the integer division would
        # yield 0 and the schedule could never advance.
        self._interval_ns = max(1, int(1_000_000_000 / target_fps))
        self._next_emit_ns = None

    def should_emit(self, timestamp_ns: int) -> bool:
        """Return True if a frame with ``timestamp_ns`` should be emitted."""
        if self._next_emit_ns is None:
            self._next_emit_ns = timestamp_ns + self._interval_ns
            return True
        if timestamp_ns >= self._next_emit_ns:
            # Jump straight to the first slot strictly after timestamp_ns
            # (same result as stepping one interval at a time, but O(1) even
            # after a large timestamp gap).
            missed = (timestamp_ns - self._next_emit_ns) // self._interval_ns + 1
            self._next_emit_ns += missed * self._interval_ns
            return True
        return False


class ScoliosisPipeline:
    """Frame-by-frame detection → silhouette → windowed ScoNet classification."""

    _detector: object
    _source: FrameStream
    _window: SilhouetteWindow
    _publisher: ResultPublisher
    _classifier: ScoNetDemo
    _device: str
    _closed: bool
    _preprocess_only: bool
    _silhouette_export_path: Path | None
    _silhouette_export_format: str
    _silhouette_buffer: list[dict[str, object]]
    _silhouette_visualize_dir: Path | None
    _result_export_path: Path | None
    _result_export_format: str
    _result_buffer: list[DemoResult]
    _visualizer: OpenCVVisualizer | None
    _last_viz_payload: dict[str, object] | None
    _frame_pacer: _FramePacer | None

    def __init__(
        self,
        *,
        source: str,
        checkpoint: str,
        config: str,
        device: str,
        yolo_model: str,
        window: int,
        stride: int,
        nats_url: str | None,
        nats_subject: str,
        max_frames: int | None,
        preprocess_only: bool = False,
        silhouette_export_path: str | None = None,
        silhouette_export_format: str = "pickle",
        silhouette_visualize_dir: str | None = None,
        result_export_path: str | None = None,
        result_export_format: str = "json",
        visualize: bool = False,
        target_fps: float | None = 15.0,
    ) -> None:
        """Build the detector, frame source, window, publisher and classifier.

        ``target_fps=None`` disables frame pacing. Note that the ScoNet
        classifier is constructed even when ``preprocess_only`` is True.
        """
        self._detector = YOLO(yolo_model)
        self._source = create_source(source, max_frames=max_frames)
        self._window = SilhouetteWindow(window_size=window, stride=stride)
        self._publisher = create_publisher(nats_url=nats_url, subject=nats_subject)
        self._classifier = ScoNetDemo(
            cfg_path=config,
            checkpoint_path=checkpoint,
            device=device,
        )
        self._device = device
        self._closed = False
        self._preprocess_only = preprocess_only
        self._silhouette_export_path = (
            Path(silhouette_export_path) if silhouette_export_path else None
        )
        self._silhouette_export_format = silhouette_export_format
        # Normalize format alias: pkl -> pickle
        if self._silhouette_export_format == "pkl":
            self._silhouette_export_format = "pickle"
        self._silhouette_buffer = []
        self._silhouette_visualize_dir = (
            Path(silhouette_visualize_dir) if silhouette_visualize_dir else None
        )
        self._result_export_path = (
            Path(result_export_path) if result_export_path else None
        )
        self._result_export_format = result_export_format
        self._result_buffer = []
        if visualize:
            # Imported lazily so OpenCV is only required when visualizing.
            from .visualizer import OpenCVVisualizer

            self._visualizer = OpenCVVisualizer()
        else:
            self._visualizer = None
        self._last_viz_payload = None
        self._frame_pacer = _FramePacer(target_fps) if target_fps is not None else None

    @staticmethod
    def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
        """Return ``meta[key]`` if it is an int, otherwise ``fallback``."""
        value = meta.get(key)
        if isinstance(value, int):
            return value
        return fallback

    @staticmethod
    def _extract_timestamp(meta: dict[str, object]) -> int:
        """Return ``meta["timestamp_ns"]`` if int, else ``time.monotonic_ns()``."""
        value = meta.get("timestamp_ns")
        if isinstance(value, int):
            return value
        return time.monotonic_ns()

    @staticmethod
    def _to_mask_u8(mask: ndarray) -> UInt8[ndarray, "h w"]:
        """Binarize ``mask`` at 0.5 into a uint8 array of 0/255."""
        binary = np.where(np.asarray(mask) > 0.5, np.uint8(255), np.uint8(0)).astype(
            np.uint8
        )
        return cast(UInt8[ndarray, "h w"], binary)

    @staticmethod
    def _snapshot_payload(payload: dict[str, object]) -> dict[str, object]:
        """Shallow-copy ``payload``, calling ``.copy()`` on values that have one."""
        snapshot: dict[str, object] = {}
        for key, value in payload.items():
            copy_method = getattr(value, "copy", None)
            snapshot[key] = copy_method() if callable(copy_method) else value
        return snapshot

    def _first_result(self, detections: object) -> _DetectionResultsLike | None:
        """Return the first result from a list/tuple of detections.

        Returns ``None`` for an empty list/tuple; any other object is returned
        as-is.
        """
        if isinstance(detections, list):
            return cast(_DetectionResultsLike, detections[0]) if detections else None
        if isinstance(detections, tuple):
            return cast(_DetectionResultsLike, detections[0]) if detections else None
        return cast(_DetectionResultsLike, detections)

    def _select_silhouette(
        self,
        result: _DetectionResultsLike,
    ) -> _SelectedSilhouette | None:
        """Pick one person from ``result`` and build its ScoNet silhouette.

        Tries ``select_person`` first. If it finds nobody, or its mask cannot
        be turned into a silhouette, falls back to ``frame_to_person_mask``
        with ``track_id`` 0. Returns ``None`` when neither path yields one.
        """
        selected = select_person(result)
        if selected is not None:
            mask_raw, bbox_mask, bbox_frame, track_id = selected
            silhouette = cast(
                Float[ndarray, "64 44"] | None,
                mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask),
            )
            if silhouette is not None:
                return {
                    "silhouette": silhouette,
                    "mask_raw": mask_raw,
                    "bbox_frame": bbox_frame,
                    "bbox_mask": bbox_mask,
                    "track_id": int(track_id),
                }

        fallback = cast(
            tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
            frame_to_person_mask(result),
        )
        if fallback is None:
            return None
        mask_u8, bbox_mask = fallback
        silhouette = cast(
            Float[ndarray, "64 44"] | None,
            mask_to_silhouette(mask_u8, bbox_mask),
        )
        if silhouette is None:
            return None

        # Convert the mask-space bbox to frame space for visualization, using
        # result.orig_shape (h, w) when available. Otherwise keep the
        # mask-space bbox.
        bbox_frame = bbox_mask
        orig_shape = getattr(result, "orig_shape", None)
        if isinstance(orig_shape, (tuple, list)) and len(orig_shape) >= 2:
            frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
            mask_h, mask_w = mask_u8.shape[0], mask_u8.shape[1]
            if mask_w > 0 and mask_h > 0 and frame_w > 0 and frame_h > 0:
                scale_x = frame_w / mask_w
                scale_y = frame_h / mask_h
                bbox_frame = (
                    int(bbox_mask[0] * scale_x),
                    int(bbox_mask[1] * scale_y),
                    int(bbox_mask[2] * scale_x),
                    int(bbox_mask[3] * scale_y),
                )

        # For the fallback case, mask_raw is the same as mask_u8.
        return {
            "silhouette": silhouette,
            "mask_raw": mask_u8,
            "bbox_frame": bbox_frame,
            "bbox_mask": bbox_mask,
            "track_id": 0,
        }

    @jaxtyped(typechecker=beartype)
    def process_frame(
        self,
        frame: UInt8[ndarray, "h w c"],
        metadata: dict[str, object],
    ) -> dict[str, object] | None:
        """Run detection/tracking on one frame and classify when a window is ready.

        Returns ``None`` when no usable person silhouette is found. Otherwise
        returns a visualization payload with keys ``mask_raw``, ``bbox``,
        ``bbox_mask``, ``silhouette``, ``segmentation_input``, ``track_id``,
        ``label`` and ``confidence``. ``label`` and ``confidence`` are ``None``
        unless a classification happened on this frame. In that case the
        payload also carries the published ``result``.

        Raises:
            RuntimeError: If the detector has no callable ``track()``.
        """
        frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
        timestamp_ns = self._extract_timestamp(metadata)

        track_fn_obj = getattr(self._detector, "track", None)
        if not callable(track_fn_obj):
            raise RuntimeError("YOLO detector does not expose a callable track()")
        track_fn = cast(_TrackCallable, track_fn_obj)
        # classes=[0] restricts detection to a single class id (person, per the
        # surrounding naming).
        detections = track_fn(
            frame,
            persist=True,
            verbose=False,
            device=self._device,
            classes=[0],
        )
        first = self._first_result(detections)
        if first is None:
            return None

        selected = self._select_silhouette(first)
        if selected is None:
            return None

        silhouette = selected["silhouette"]
        mask_raw = selected["mask_raw"]
        bbox = selected["bbox_frame"]
        bbox_mask = selected["bbox_mask"]
        track_id = selected["track_id"]

        # Store silhouette for export if in preprocess-only mode or if export requested
        if self._silhouette_export_path is not None or self._preprocess_only:
            self._silhouette_buffer.append(
                {
                    "frame": frame_idx,
                    "track_id": track_id,
                    "timestamp_ns": timestamp_ns,
                    "silhouette": silhouette.copy(),
                }
            )

        # Visualize silhouette if requested
        if self._silhouette_visualize_dir is not None:
            self._visualize_silhouette(silhouette, frame_idx, track_id)

        if self._preprocess_only:
            # Return visualization payload for display even in preprocess-only mode
            return {
                "mask_raw": mask_raw,
                "bbox": bbox,
                "bbox_mask": bbox_mask,
                "silhouette": silhouette,
                "segmentation_input": None,
                "track_id": track_id,
                "label": None,
                "confidence": None,
            }

        if self._frame_pacer is not None and not self._frame_pacer.should_emit(
            timestamp_ns
        ):
            # Paced out: do not push into the window, but still return a payload
            # so the display keeps updating.
            return {
                "mask_raw": mask_raw,
                "bbox": bbox,
                "bbox_mask": bbox_mask,
                "silhouette": silhouette,
                "segmentation_input": self._window.buffered_silhouettes,
                "track_id": track_id,
                "label": None,
                "confidence": None,
            }

        self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
        segmentation_input = self._window.buffered_silhouettes

        if not self._window.should_classify():
            # Return visualization payload even when not classifying yet
            return {
                "mask_raw": mask_raw,
                "bbox": bbox,
                "bbox_mask": bbox_mask,
                "silhouette": silhouette,
                "segmentation_input": segmentation_input,
                "track_id": track_id,
                "label": None,
                "confidence": None,
            }

        window_tensor = self._window.get_tensor(device=self._device)
        label, confidence = cast(
            tuple[str, float],
            self._classifier.predict(window_tensor),
        )
        self._window.mark_classified()

        window_start = self._window.window_start_frame
        result = create_result(
            frame=frame_idx,
            track_id=track_id,
            label=label,
            confidence=float(confidence),
            window=(max(0, window_start), frame_idx),
            timestamp_ns=timestamp_ns,
        )

        # Store result for export if export path specified
        if self._result_export_path is not None:
            self._result_buffer.append(result)

        self._publisher.publish(result)

        # Return result with visualization payload
        return {
            "result": result,
            "mask_raw": mask_raw,
            "bbox": bbox,
            "bbox_mask": bbox_mask,
            "silhouette": silhouette,
            "segmentation_input": segmentation_input,
            "track_id": track_id,
            "label": label,
            "confidence": confidence,
        }

    def run(self) -> int:
        """Process frames until the source ends or the user stops.

        Per-frame processing errors are logged and the frame is skipped.

        Returns:
            ``0`` on normal completion (including the visualizer being closed)
            and ``130`` on ``KeyboardInterrupt``. ``close()`` is always called.
        """
        frame_count = 0
        start_time = time.perf_counter()
        # EMA FPS state (alpha=0.1 for smoothing)
        ema_fps = 0.0
        alpha = 0.1
        prev_time = start_time

        try:
            for item in self._source:
                frame, metadata = item
                frame_u8 = np.asarray(frame, dtype=np.uint8)
                frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
                frame_count += 1

                # Compute per-frame EMA FPS
                curr_time = time.perf_counter()
                delta = curr_time - prev_time
                prev_time = curr_time
                if delta > 0:
                    instant_fps = 1.0 / delta
                    if ema_fps == 0.0:
                        ema_fps = instant_fps
                    else:
                        ema_fps = alpha * instant_fps + (1 - alpha) * ema_fps

                viz_payload = None
                try:
                    viz_payload = self.process_frame(frame_u8, metadata)
                except Exception as frame_error:
                    # Best-effort: a failing frame is skipped, not fatal.
                    logger.warning(
                        "Skipping frame %d due to processing error: %s",
                        frame_idx,
                        frame_error,
                    )

                # Update visualizer if enabled
                if self._visualizer is not None:
                    if viz_payload is not None:
                        # Cache a copy for no-detection frames, so later
                        # mutation of the live arrays cannot alter it.
                        self._last_viz_payload = self._snapshot_payload(viz_payload)
                        viz_data: dict[str, object] | None = viz_payload
                    elif self._last_viz_payload is not None:
                        # Reuse the last masks/silhouettes but drop the stale
                        # bbox and classification.
                        viz_data = dict(self._last_viz_payload)
                        viz_data["bbox"] = None
                        viz_data["bbox_mask"] = None
                        viz_data["label"] = None
                        viz_data["confidence"] = None
                    else:
                        viz_data = None

                    if viz_data is not None:
                        track_id_val = viz_data.get("track_id", 0)
                        track_id = track_id_val if isinstance(track_id_val, int) else 0
                        mask_raw = cast(
                            NDArray[np.uint8] | None, viz_data.get("mask_raw")
                        )
                        bbox = cast(BBoxXYXY | None, viz_data.get("bbox"))
                        bbox_mask = cast(BBoxXYXY | None, viz_data.get("bbox_mask"))
                        silhouette = cast(
                            NDArray[np.float32] | None, viz_data.get("silhouette")
                        )
                        segmentation_input = cast(
                            NDArray[np.float32] | None,
                            viz_data.get("segmentation_input"),
                        )
                        label = cast(str | None, viz_data.get("label"))
                        confidence = cast(float | None, viz_data.get("confidence"))
                    else:
                        # No detection and no cache - use default values
                        mask_raw = None
                        bbox = None
                        bbox_mask = None
                        track_id = 0
                        silhouette = None
                        segmentation_input = None
                        label = None
                        confidence = None

                    keep_running = self._visualizer.update(
                        frame_u8,
                        bbox,
                        bbox_mask,
                        track_id,
                        mask_raw,
                        silhouette,
                        segmentation_input,
                        label,
                        confidence,
                        ema_fps,
                    )
                    if not keep_running:
                        logger.info("Visualization closed by user.")
                        break

                if frame_count % 100 == 0:
                    elapsed = time.perf_counter() - start_time
                    fps = frame_count / elapsed if elapsed > 0 else 0.0
                    logger.info("Processed %d frames (%.2f FPS)", frame_count, fps)

            return 0
        except KeyboardInterrupt:
            logger.info("Interrupted by user, shutting down cleanly.")
            return 130
        finally:
            self.close()

    def close(self) -> None:
        """Close the visualizer, flush pending exports and close the publisher.

        Idempotent. The pipeline is marked closed up front. The results export
        is attempted even if the silhouette export fails. The publisher is
        closed in a ``finally``, so a failing export cannot leak it. Export
        errors still propagate to the caller.
        """
        if self._closed:
            return
        self._closed = True
        try:
            # Close visualizer if enabled
            if self._visualizer is not None:
                with suppress(Exception):
                    self._visualizer.close()
            try:
                # Export silhouettes if requested
                if self._silhouette_export_path is not None and self._silhouette_buffer:
                    self._export_silhouettes()
            finally:
                # Export results if requested
                if self._result_export_path is not None and self._result_buffer:
                    self._export_results()
        finally:
            close_fn = getattr(self._publisher, "close", None)
            if callable(close_fn):
                with suppress(Exception):
                    _ = close_fn()

    def _export_silhouettes(self) -> None:
        """Export silhouettes to file in specified format.

        Raises:
            ValueError: If the configured format is not ``pickle``/``parquet``.
        """
        if self._silhouette_export_path is None:
            return

        self._silhouette_export_path.parent.mkdir(parents=True, exist_ok=True)

        if self._silhouette_export_format == "pickle":
            import pickle

            with open(self._silhouette_export_path, "wb") as f:
                pickle.dump(self._silhouette_buffer, f)
            logger.info(
                "Exported %d silhouettes to %s",
                len(self._silhouette_buffer),
                self._silhouette_export_path,
            )
        elif self._silhouette_export_format == "parquet":
            self._export_parquet_silhouettes()
        else:
            raise ValueError(
                f"Unsupported silhouette export format: {self._silhouette_export_format}"
            )

    def _visualize_silhouette(
        self,
        silhouette: Float[ndarray, "64 44"],
        frame_idx: int,
        track_id: int,
    ) -> None:
        """Save ``silhouette`` as a grayscale PNG in the visualize directory."""
        if self._silhouette_visualize_dir is None:
            return

        self._silhouette_visualize_dir.mkdir(parents=True, exist_ok=True)

        # Scale float silhouette to 0-255. Clip first so out-of-range values
        # saturate instead of wrapping around in the uint8 cast.
        silhouette_u8 = np.clip(silhouette * 255.0, 0.0, 255.0).astype(np.uint8)

        # Create deterministic filename
        filename = f"silhouette_frame{frame_idx:06d}_track{track_id:04d}.png"
        output_path = self._silhouette_visualize_dir / filename

        # Save using PIL
        from PIL import Image

        Image.fromarray(silhouette_u8).save(output_path)

    def _export_parquet_silhouettes(self) -> None:
        """Export silhouettes to parquet (each silhouette flattened to floats).

        Raises:
            RuntimeError: If ``pyarrow`` is not installed.
        """
        import importlib

        try:
            pa = importlib.import_module("pyarrow")
            pq = importlib.import_module("pyarrow.parquet")
        except ImportError as e:
            raise RuntimeError(
                "Parquet export requires pyarrow. Install with: pip install pyarrow"
            ) from e

        # Convert silhouettes to columnar format
        frames = []
        track_ids = []
        timestamps = []
        silhouettes = []

        for item in self._silhouette_buffer:
            frames.append(item["frame"])
            track_ids.append(item["track_id"])
            timestamps.append(item["timestamp_ns"])
            silhouette_array = cast(ndarray, item["silhouette"])
            silhouettes.append(silhouette_array.flatten().tolist())

        table = pa.table(
            {
                "frame": pa.array(frames, type=pa.int64()),
                "track_id": pa.array(track_ids, type=pa.int64()),
                "timestamp_ns": pa.array(timestamps, type=pa.int64()),
                "silhouette": pa.array(silhouettes, type=pa.list_(pa.float64())),
            }
        )

        pq.write_table(table, self._silhouette_export_path)
        logger.info(
            "Exported %d silhouettes to parquet: %s",
            len(self._silhouette_buffer),
            self._silhouette_export_path,
        )

    def _export_results(self) -> None:
        """Export results to file in specified format.

        ``json`` writes one JSON object per line. Non-JSON values are
        stringified via ``default=str``.

        Raises:
            ValueError: If the format is not ``json``/``pickle``/``parquet``.
        """
        if self._result_export_path is None:
            return

        self._result_export_path.parent.mkdir(parents=True, exist_ok=True)

        if self._result_export_format == "json":
            import json

            with open(self._result_export_path, "w", encoding="utf-8") as f:
                for result in self._result_buffer:
                    f.write(json.dumps(result, ensure_ascii=False, default=str) + "\n")
            logger.info(
                "Exported %d results to JSON: %s",
                len(self._result_buffer),
                self._result_export_path,
            )
        elif self._result_export_format == "pickle":
            import pickle

            with open(self._result_export_path, "wb") as f:
                pickle.dump(self._result_buffer, f)
            logger.info(
                "Exported %d results to pickle: %s",
                len(self._result_buffer),
                self._result_export_path,
            )
        elif self._result_export_format == "parquet":
            self._export_parquet_results()
        else:
            raise ValueError(
                f"Unsupported result export format: {self._result_export_format}"
            )

    def _export_parquet_results(self) -> None:
        """Export results to parquet format.

        ``window`` is written as a ``list<int64>`` column. ``process_frame``
        passes a ``(start, end)`` frame pair, which a scalar int64 column
        cannot hold.

        Raises:
            RuntimeError: If ``pyarrow`` is not installed.
        """
        import importlib

        try:
            pa = importlib.import_module("pyarrow")
            pq = importlib.import_module("pyarrow.parquet")
        except ImportError as e:
            raise RuntimeError(
                "Parquet export requires pyarrow. Install with: pip install pyarrow"
            ) from e

        frames = []
        track_ids = []
        labels = []
        confidences = []
        windows: list[list[int]] = []
        timestamps = []

        for result in self._result_buffer:
            frames.append(result["frame"])
            track_ids.append(result["track_id"])
            labels.append(result["label"])
            confidences.append(result["confidence"])
            window_bounds = cast(tuple[int, int], result["window"])
            windows.append([int(bound) for bound in window_bounds])
            timestamps.append(result["timestamp_ns"])

        table = pa.table(
            {
                "frame": pa.array(frames, type=pa.int64()),
                "track_id": pa.array(track_ids, type=pa.int64()),
                "label": pa.array(labels, type=pa.string()),
                "confidence": pa.array(confidences, type=pa.float64()),
                "window": pa.array(windows, type=pa.list_(pa.int64())),
                "timestamp_ns": pa.array(timestamps, type=pa.int64()),
            }
        )

        pq.write_table(table, self._result_export_path)
        logger.info(
            "Exported %d results to parquet: %s",
            len(self._result_buffer),
            self._result_export_path,
        )


def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
    """Check that the file-based inputs exist before building the pipeline.

    ``cvmmap://`` sources and all-digit sources (presumably camera indices)
    skip the file-existence check.

    Raises:
        ValueError: If the video source, checkpoint, or config file is missing.
    """
    if not (source.startswith("cvmmap://") or source.isdigit()):
        if not Path(source).is_file():
            raise ValueError(f"Video source not found: {source}")

    if not Path(checkpoint).is_file():
        raise ValueError(f"Checkpoint not found: {checkpoint}")

    if not Path(config).is_file():
        raise ValueError(f"Config not found: {config}")


@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option("--source", type=str, required=True)
@click.option("--checkpoint", type=str, required=True)
@click.option(
    "--config",
    type=str,
    default="configs/sconet/sconet_scoliosis1k.yaml",
    show_default=True,
)
@click.option("--device", type=str, default="cuda:0", show_default=True)
@click.option(
    "--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", show_default=True
)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
@click.option(
    "--window-mode",
    type=click.Choice(["manual", "sliding", "chunked"], case_sensitive=False),
    default="manual",
    show_default=True,
    help=(
        "Window scheduling mode: manual uses --stride; "
        "sliding forces stride=1; chunked forces stride=window"
    ),
)
@click.option(
    "--target-fps",
    type=click.FloatRange(min=0.1),
    default=15.0,
    show_default=True,
)
@click.option("--no-target-fps", is_flag=True, default=False)
@click.option("--nats-url", type=str, default=None)
@click.option(
    "--nats-subject",
    type=str,
    default="scoliosis.result",
    show_default=True,
)
@click.option("--max-frames", type=click.IntRange(min=1), default=None)
@click.option(
    "--preprocess-only",
    is_flag=True,
    default=False,
    help="Only preprocess silhouettes, skip classification.",
)
@click.option(
    "--silhouette-export-path",
    type=str,
    default=None,
    help="Path to export silhouettes (required for preprocess-only mode).",
)
@click.option(
    "--silhouette-export-format",
    type=click.Choice(["pickle", "parquet"]),
    default="pickle",
    show_default=True,
    help="Format for silhouette export.",
)
@click.option(
    "--result-export-path",
    type=str,
    default=None,
    help="Path to export inference results.",
)
@click.option(
    "--result-export-format",
    type=click.Choice(["json", "pickle", "parquet"]),
    default="json",
    show_default=True,
    help="Format for result export.",
)
@click.option(
    "--silhouette-visualize-dir",
    type=str,
    default=None,
    help="Directory to save silhouette PNG visualizations.",
)
def main(
    source: str,
    checkpoint: str,
    config: str,
    device: str,
    yolo_model: str,
    window: int,
    stride: int,
    window_mode: str,
    target_fps: float | None,
    no_target_fps: bool,
    nats_url: str | None,
    nats_subject: str,
    max_frames: int | None,
    preprocess_only: bool,
    silhouette_export_path: str | None,
    silhouette_export_format: str,
    result_export_path: str | None,
    result_export_format: str,
    silhouette_visualize_dir: str | None,
) -> None:
    # CLI entry point. A docstring is deliberately avoided here because click
    # would render it as --help text. Exit codes: the pipeline's run() code
    # on completion, 2 for ValueError (bad inputs), 1 for RuntimeError.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    # Validate preprocess-only mode requirements
    if preprocess_only and not silhouette_export_path:
        raise click.UsageError(
            "--silhouette-export-path is required when using --preprocess-only"
        )

    try:
        validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
        effective_stride = resolve_stride(
            window=window,
            stride=stride,
            window_mode=cast(WindowMode, window_mode.lower()),
        )
        if effective_stride != stride:
            logger.info(
                "window_mode=%s overrides stride=%d -> effective_stride=%d",
                window_mode,
                stride,
                effective_stride,
            )
        pipeline = ScoliosisPipeline(
            source=source,
            checkpoint=checkpoint,
            config=config,
            device=device,
            yolo_model=yolo_model,
            window=window,
            stride=effective_stride,
            target_fps=None if no_target_fps else target_fps,
            nats_url=nats_url,
            nats_subject=nats_subject,
            max_frames=max_frames,
            preprocess_only=preprocess_only,
            silhouette_export_path=silhouette_export_path,
            silhouette_export_format=silhouette_export_format,
            silhouette_visualize_dir=silhouette_visualize_dir,
            result_export_path=result_export_path,
            result_export_format=result_export_format,
        )
        raise SystemExit(pipeline.run())
    except ValueError as err:
        click.echo(f"Error: {err}", err=True)
        raise SystemExit(2) from err
    except RuntimeError as err:
        click.echo(f"Runtime error: {err}", err=True)
        raise SystemExit(1) from err