150 lines
5.2 KiB
Python
150 lines
5.2 KiB
Python
from __future__ import annotations
|
|
|
|
import argparse
|
|
import logging
|
|
import sys
|
|
from typing import cast
|
|
|
|
from .pipeline import ScoliosisPipeline, WindowMode, resolve_stride
|
|
|
|
|
|
def _positive_float(value: str) -> float:
    """Parse a command-line string as a strictly positive, finite float.

    Used as the argparse ``type=`` converter for ``--target-fps``.

    Args:
        value: Raw option string supplied on the command line.

    Returns:
        The parsed number, guaranteed to satisfy ``0 < parsed < inf``.

    Raises:
        argparse.ArgumentTypeError: If *value* is not numeric, is zero or
            negative, or is NaN / infinite.
    """
    try:
        parsed = float(value)
    except ValueError:
        # Surface a readable message instead of argparse's generic
        # "invalid _positive_float value" fallback.
        raise argparse.ArgumentTypeError(
            f"target-fps must be a number, got {value!r}"
        ) from None
    if parsed <= 0:
        raise argparse.ArgumentTypeError("target-fps must be positive")
    # float() accepts "nan" and "inf". Every ordering comparison with NaN is
    # False, so NaN slips past the `<= 0` check above; this comparison is
    # False for both NaN and +inf and therefore rejects them.
    if not parsed < float("inf"):
        raise argparse.ArgumentTypeError("target-fps must be a finite number")
    return parsed
|
|
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the scoliosis detection pipeline.

    Returns:
        A parser declaring every option consumed by ``main``. Options are
        registered in a fixed order, which is also the order shown by
        ``--help``.
    """
    parser = argparse.ArgumentParser(description="Scoliosis Detection Pipeline")

    # Required inputs and model/runtime selection.
    parser.add_argument(
        "--source", type=str, required=True, help="Video source path or camera ID"
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="Model checkpoint path"
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/sconet/sconet_scoliosis1k.yaml",
        help="Config file path",
    )
    parser.add_argument("--device", type=str, default="cuda:0", help="Device to run on")
    parser.add_argument(
        "--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", help="YOLO model name"
    )

    # Temporal windowing and frame-rate control.
    parser.add_argument(
        "--window", type=int, default=30, help="Window size for classification"
    )
    parser.add_argument("--stride", type=int, default=30, help="Stride for window")
    parser.add_argument(
        "--target-fps",
        type=_positive_float,
        default=15.0,
        help="Target FPS for temporal downsampling before windowing",
    )
    parser.add_argument(
        "--window-mode",
        type=str,
        choices=["manual", "sliding", "chunked"],
        default="manual",
        help="Window scheduling mode: manual uses --stride; sliding uses stride=1; chunked uses stride=window",
    )
    parser.add_argument(
        "--no-target-fps",
        action="store_true",
        help="Disable temporal downsampling and use all frames",
    )

    # Result publishing.
    parser.add_argument(
        "--nats-url", type=str, default=None, help="NATS URL for result publishing"
    )
    parser.add_argument(
        "--nats-subject", type=str, default="scoliosis.result", help="NATS subject"
    )

    # Run limits, export and visualization.
    parser.add_argument(
        "--max-frames", type=int, default=None, help="Maximum frames to process"
    )
    parser.add_argument(
        "--preprocess-only", action="store_true", help="Only preprocess silhouettes"
    )
    parser.add_argument(
        "--silhouette-export-path",
        type=str,
        default=None,
        help="Path to export silhouettes",
    )
    parser.add_argument(
        "--silhouette-export-format", type=str, default="pickle", help="Export format"
    )
    parser.add_argument(
        "--silhouette-visualize-dir",
        type=str,
        default=None,
        help="Directory for silhouette visualizations",
    )
    parser.add_argument(
        "--result-export-path", type=str, default=None, help="Path to export results"
    )
    parser.add_argument(
        "--result-export-format", type=str, default="json", help="Result export format"
    )
    parser.add_argument(
        "--visualize", action="store_true", help="Enable real-time visualization"
    )
    return parser


def main(argv: list[str] | None = None) -> None:
    """Parse command-line options, build the pipeline and run it.

    This function always terminates by raising ``SystemExit``.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Raises:
        SystemExit: With code 2 when ``--preprocess-only`` is given without
            ``--silhouette-export-path`` or when a ``ValueError`` is raised
            during validation, construction or the run; with code 1 when a
            ``RuntimeError`` is raised; otherwise with whatever
            ``ScoliosisPipeline.run()`` returns.
    """
    args = _build_parser().parse_args(argv)

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    # Preprocess-only mode needs somewhere to write the silhouettes.
    if args.preprocess_only and not args.silhouette_export_path:
        print(
            "Error: --silhouette-export-path is required when using --preprocess-only",
            file=sys.stderr,
        )
        raise SystemExit(2)

    try:
        # NOTE(review): `.pipeline` is already imported at module top, so this
        # deferred import does not actually avoid an import cycle; consider
        # folding it into the top-level import.
        from .pipeline import validate_runtime_inputs

        validate_runtime_inputs(
            source=args.source, checkpoint=args.checkpoint, config=args.config
        )

        # The stride actually used is derived from the window mode; the raw
        # --stride value is only one of the inputs to resolve_stride.
        effective_stride = resolve_stride(
            window=cast(int, args.window),
            stride=cast(int, args.stride),
            window_mode=cast(WindowMode, args.window_mode),
        )

        pipeline = ScoliosisPipeline(
            source=cast(str, args.source),
            checkpoint=cast(str, args.checkpoint),
            config=cast(str, args.config),
            device=cast(str, args.device),
            yolo_model=cast(str, args.yolo_model),
            window=cast(int, args.window),
            stride=effective_stride,
            # --no-target-fps takes precedence over any --target-fps value.
            target_fps=(None if args.no_target_fps else cast(float, args.target_fps)),
            nats_url=cast(str | None, args.nats_url),
            nats_subject=cast(str, args.nats_subject),
            max_frames=cast(int | None, args.max_frames),
            preprocess_only=cast(bool, args.preprocess_only),
            silhouette_export_path=cast(str | None, args.silhouette_export_path),
            silhouette_export_format=cast(str, args.silhouette_export_format),
            silhouette_visualize_dir=cast(str | None, args.silhouette_visualize_dir),
            result_export_path=cast(str | None, args.result_export_path),
            result_export_format=cast(str, args.result_export_format),
            visualize=cast(bool, args.visualize),
        )
        # SystemExit is not a ValueError/RuntimeError, so it propagates past
        # the handlers below and carries run()'s result as the exit status.
        raise SystemExit(pipeline.run())
    except ValueError as err:
        print(f"Error: {err}", file=sys.stderr)
        raise SystemExit(2) from err
    except RuntimeError as err:
        print(f"Runtime error: {err}", file=sys.stderr)
        raise SystemExit(1) from err


if __name__ == "__main__":
    main()
|