from __future__ import annotations import argparse import logging import sys from typing import cast from .pipeline import ScoliosisPipeline, WindowMode, resolve_stride def _positive_float(value: str) -> float: parsed = float(value) if parsed <= 0: raise argparse.ArgumentTypeError("target-fps must be positive") return parsed if __name__ == "__main__": parser = argparse.ArgumentParser(description="Scoliosis Detection Pipeline") parser.add_argument( "--source", type=str, required=True, help="Video source path or camera ID" ) parser.add_argument( "--checkpoint", type=str, required=True, help="Model checkpoint path" ) parser.add_argument( "--config", type=str, default="configs/sconet/sconet_scoliosis1k.yaml", help="Config file path", ) parser.add_argument("--device", type=str, default="cuda:0", help="Device to run on") parser.add_argument( "--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", help="YOLO model name" ) parser.add_argument( "--window", type=int, default=30, help="Window size for classification" ) parser.add_argument("--stride", type=int, default=30, help="Stride for window") parser.add_argument( "--target-fps", type=_positive_float, default=15.0, help="Target FPS for temporal downsampling before windowing", ) parser.add_argument( "--window-mode", type=str, choices=["manual", "sliding", "chunked"], default="manual", help="Window scheduling mode: manual uses --stride; sliding uses stride=1; chunked uses stride=window", ) parser.add_argument( "--no-target-fps", action="store_true", help="Disable temporal downsampling and use all frames", ) parser.add_argument( "--nats-url", type=str, default=None, help="NATS URL for result publishing" ) parser.add_argument( "--nats-subject", type=str, default="scoliosis.result", help="NATS subject" ) parser.add_argument( "--max-frames", type=int, default=None, help="Maximum frames to process" ) parser.add_argument( "--preprocess-only", action="store_true", help="Only preprocess silhouettes" ) parser.add_argument( "--silhouette-export-path", type=str, default=None, help="Path to export silhouettes", ) parser.add_argument( "--silhouette-export-format", type=str, default="pickle", help="Export format" ) parser.add_argument( "--silhouette-visualize-dir", type=str, default=None, help="Directory for silhouette visualizations", ) parser.add_argument( "--result-export-path", type=str, default=None, help="Path to export results" ) parser.add_argument( "--result-export-format", type=str, default="json", help="Result export format" ) parser.add_argument( "--visualize", action="store_true", help="Enable real-time visualization" ) args = parser.parse_args() logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s", ) # Validate preprocess-only mode requires silhouette export path if args.preprocess_only and not args.silhouette_export_path: print( "Error: --silhouette-export-path is required when using --preprocess-only", file=sys.stderr, ) raise SystemExit(2) try: # Import here to avoid circular imports from .pipeline import validate_runtime_inputs validate_runtime_inputs( source=args.source, checkpoint=args.checkpoint, config=args.config ) effective_stride = resolve_stride( window=cast(int, args.window), stride=cast(int, args.stride), window_mode=cast(WindowMode, args.window_mode), ) pipeline = ScoliosisPipeline( source=cast(str, args.source), checkpoint=cast(str, args.checkpoint), config=cast(str, args.config), device=cast(str, args.device), yolo_model=cast(str, args.yolo_model), window=cast(int, args.window), stride=effective_stride, target_fps=(None if args.no_target_fps else cast(float, args.target_fps)), nats_url=cast(str | None, args.nats_url), nats_subject=cast(str, args.nats_subject), max_frames=cast(int | None, args.max_frames), preprocess_only=cast(bool, args.preprocess_only), silhouette_export_path=cast(str | None, args.silhouette_export_path), silhouette_export_format=cast(str, args.silhouette_export_format), silhouette_visualize_dir=cast(str | None, args.silhouette_visualize_dir), result_export_path=cast(str | None, args.result_export_path), result_export_format=cast(str, args.result_export_format), visualize=cast(bool, args.visualize), ) raise SystemExit(pipeline.run()) except ValueError as err: print(f"Error: {err}", file=sys.stderr) raise SystemExit(2) from err except RuntimeError as err: print(f"Runtime error: {err}", file=sys.stderr) raise SystemExit(1) from err