Files
OpenGait/opengait-studio/opengait_studio/pipeline.py
T
crosstyan 00fcda4fe3 feat: extract opengait_studio monorepo module
Move demo implementation into opengait_studio, retire Sports2D runtime integration, and align packaging with root-level monorepo dependency management.
2026-03-07 18:14:13 +08:00

948 lines
33 KiB
Python

from __future__ import annotations
from collections.abc import Callable
from contextlib import suppress
import logging
from pathlib import Path
import time
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, cast
from beartype import beartype
import click
import jaxtyping
from jaxtyping import Float, UInt8
import numpy as np
from numpy import ndarray
from numpy.typing import NDArray
from ultralytics.models.yolo.model import YOLO
from .input import FrameStream, create_source
from .output import DemoResult, ResultPublisher, create_publisher, create_result
from .preprocess import BBoxXYXY, frame_to_person_mask, mask_to_silhouette
from .sconet_demo import ScoNetDemo
from .window import SilhouetteWindow, select_person
if TYPE_CHECKING:
from .visualizer import OpenCVVisualizer
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Typing shim for `jaxtyping.jaxtyped`: the decorator factory is re-exposed through
# `cast` with a plain "callable returning a decorator" type.
# NOTE(review): presumably done so static type checkers accept
# `@jaxtyped(typechecker=...)` on methods - confirm.
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
# The window scheduling modes understood by `resolve_stride`.
WindowMode: TypeAlias = Literal["manual", "sliding", "chunked"]
def resolve_stride(window: int, stride: int, window_mode: WindowMode) -> int:
    """Return the stride that is actually used for a window scheduling mode.

    Args:
        window: Window size, used as the stride for every mode other than
            ``"manual"`` and ``"sliding"`` (i.e. ``"chunked"``).
        stride: Caller-supplied stride, honoured only in ``"manual"`` mode.
        window_mode: Scheduling mode.

    Returns:
        ``stride`` for ``"manual"``, ``1`` for ``"sliding"``, otherwise
        ``window``.
    """
    if window_mode == "sliding":
        return 1
    return stride if window_mode == "manual" else window
class _BoxesLike(Protocol):
    """Structural type declaring the ``xyxy`` and ``id`` read-only attributes."""

    @property
    def xyxy(self) -> NDArray[np.float32] | object:
        """Box data; presumably (x1, y1, x2, y2) corners per detection."""

    @property
    def id(self) -> NDArray[np.int64] | object | None:
        """Identifier data, or ``None``; presumably tracker IDs per detection."""
class _MasksLike(Protocol):
    """Structural type declaring the read-only ``data`` attribute of a mask container."""

    @property
    def data(self) -> NDArray[np.float32] | object:
        """Raw mask data as an array or an array-like object."""
class _DetectionResultsLike(Protocol):
    """Structural type for one detection result exposing ``boxes`` and ``masks``."""

    @property
    def boxes(self) -> _BoxesLike:
        """Container with the detected boxes."""

    @property
    def masks(self) -> _MasksLike:
        """Container with the segmentation masks."""
class _TrackCallable(Protocol):
    """Call signature expected from the detector's ``track`` attribute."""

    def __call__(
        self,
        source: object,
        *,
        persist: bool = True,
        verbose: bool = False,
        device: str | None = None,
        classes: list[int] | None = None,
    ) -> object:
        """Run tracking on ``source`` and return the detector's raw output."""
class _SelectedSilhouette(TypedDict):
    """Result of picking one person from a detection and normalizing the mask.

    Keys:
        silhouette: Normalized ``(64, 44)`` silhouette that is fed into ScoNet.
        mask_raw: Full-resolution binary person mask in mask/image space.
        bbox_frame: Person box ``(x1, y1, x2, y2)`` in frame coordinates,
            used for visualization.
        bbox_mask: Person box ``(x1, y1, x2, y2)`` in mask coordinates,
            used for cropping.
        track_id: Tracking ID reported by the detector; ``0`` when the
            fallback (untracked) path produced the silhouette.
    """

    silhouette: Float[ndarray, "64 44"]
    mask_raw: UInt8[ndarray, "h w"]
    bbox_frame: BBoxXYXY
    bbox_mask: BBoxXYXY
    track_id: int
class _VizPayload(TypedDict, total=False):
    """Per-frame data handed from ``process_frame`` to the visualizer.

    Every key is optional (``total=False``). ``result`` is only present on
    frames where a classification was produced; ``label`` and ``confidence``
    are ``None`` on all other frames.
    """

    result: DemoResult
    mask_raw: UInt8[ndarray, "h w"] | None
    bbox: BBoxXYXY | None
    bbox_mask: BBoxXYXY | None
    silhouette: Float[ndarray, "64 44"] | None
    segmentation_input: NDArray[np.float32] | None
    track_id: int
    label: str | None
    confidence: float | None
    pose: dict[str, object] | None
class _FramePacer:
    """Rate limiter deciding which timestamped frames may be emitted.

    The first timestamp is always accepted. Afterwards a timestamp is
    accepted once it reaches the next deadline, and the deadline is then
    advanced by whole intervals until it lies strictly in the future.
    """

    _interval_ns: int
    _next_emit_ns: int | None

    def __init__(self, target_fps: float) -> None:
        """Create a pacer emitting at most ``target_fps`` frames per second.

        Raises:
            ValueError: If ``target_fps`` is not a positive number.
        """
        # `not x > 0` (rather than `x <= 0`) also rejects NaN with this message.
        if not target_fps > 0:
            raise ValueError(f"target_fps must be positive, got {target_fps}")
        # Rates above 1e9 fps (or inf) would truncate to a 0 ns interval, which
        # can never move the deadline forward; clamp to the 1 ns resolution.
        self._interval_ns = max(1, int(1_000_000_000 / target_fps))
        self._next_emit_ns = None

    def should_emit(self, timestamp_ns: int) -> bool:
        """Return ``True`` when the frame stamped ``timestamp_ns`` is due.

        Calling this method updates the internal deadline whenever it
        returns ``True``.
        """
        deadline = self._next_emit_ns
        if deadline is None:
            self._next_emit_ns = timestamp_ns + self._interval_ns
            return True
        if timestamp_ns < deadline:
            return False
        # Skip every interval that has already elapsed in one step, so a large
        # timestamp gap costs O(1) instead of one loop iteration per interval.
        elapsed_intervals = (timestamp_ns - deadline) // self._interval_ns + 1
        self._next_emit_ns = deadline + elapsed_intervals * self._interval_ns
        return True
# Frame-by-frame pipeline: pulls frames from a source, runs the YOLO detector,
# turns the selected person into a silhouette, classifies windows of silhouettes
# with ScoNetDemo, and publishes / exports / visualizes the outcome.
class ScoliosisPipeline:
# Detector instance created as YOLO(yolo_model); its `track` attribute is looked up per frame.
_detector: object
# Frame source returned by create_source(); iterated in run().
_source: FrameStream
# Sliding buffer of silhouettes that decides when a classification is due.
_window: SilhouetteWindow
# Sink for classification results returned by create_publisher().
_publisher: ResultPublisher
# Classifier whose predict() maps a window tensor to (label, confidence).
_classifier: ScoNetDemo
# Device string forwarded to the detector, the window tensor and the classifier.
_device: str
# Set by close() so that teardown runs only once.
_closed: bool
# When True, process_frame() stops after silhouette extraction (no classification).
_preprocess_only: bool
# Destination file for buffered silhouettes, or None to disable that export.
_silhouette_export_path: Path | None
# "pickle" or "parquet" ("pkl" is normalized to "pickle" in __init__).
_silhouette_export_format: str
# One dict per processed frame: frame, track_id, timestamp_ns, silhouette.
_silhouette_buffer: list[dict[str, object]]
# Directory receiving one PNG per silhouette, or None to disable.
_silhouette_visualize_dir: Path | None
# Destination file for buffered classification results, or None to disable.
_result_export_path: Path | None
# "json", "pickle" or "parquet".
_result_export_format: str
# Classification results collected for export at close().
_result_buffer: list[DemoResult]
# Live visualizer, created only when visualize=True.
_visualizer: OpenCVVisualizer | None
# Copy of the most recent non-None payload, reused on frames without a payload.
_last_viz_payload: _VizPayload | None
# Optional rate limiter consulted in process_frame(); None disables pacing.
_frame_pacer: _FramePacer | None
def __init__(
    self,
    *,
    source: str,
    checkpoint: str,
    config: str,
    device: str,
    yolo_model: str,
    window: int,
    stride: int,
    nats_url: str | None,
    nats_subject: str,
    max_frames: int | None,
    preprocess_only: bool = False,
    silhouette_export_path: str | None = None,
    silhouette_export_format: str = "pickle",
    silhouette_visualize_dir: str | None = None,
    result_export_path: str | None = None,
    result_export_format: str = "json",
    visualize: bool = False,
    target_fps: float | None = 15.0,
) -> None:
    """Build every pipeline component and record export/visualization options.

    Args:
        source: Frame source specifier passed to ``create_source``.
        checkpoint: Classifier checkpoint path passed to ``ScoNetDemo``.
        config: Classifier config path passed to ``ScoNetDemo``.
        device: Device string used by the detector and the classifier.
        yolo_model: Model specifier passed to ``YOLO``.
        window: Window size passed to ``SilhouetteWindow``.
        stride: Stride passed to ``SilhouetteWindow``.
        nats_url: Publisher URL passed to ``create_publisher`` (may be ``None``).
        nats_subject: Publisher subject passed to ``create_publisher``.
        max_frames: Optional frame limit passed to ``create_source``.
        preprocess_only: Stop after silhouette extraction; never classify.
        silhouette_export_path: File that receives buffered silhouettes.
        silhouette_export_format: ``"pickle"`` (alias ``"pkl"``) or ``"parquet"``.
        silhouette_visualize_dir: Directory for per-frame silhouette PNGs.
        result_export_path: File that receives classification results.
        result_export_format: ``"json"``, ``"pickle"`` (alias ``"pkl"``) or
            ``"parquet"``.
        visualize: Create an ``OpenCVVisualizer`` for live display.
        target_fps: Pacing rate for ``_FramePacer``; ``None`` disables pacing.

    Raises:
        ValueError: If an export path is combined with an unsupported export
            format, or if ``target_fps`` is not positive.
    """
    # "pkl" is accepted as an alias of "pickle" for both exporters.
    if silhouette_export_format == "pkl":
        silhouette_export_format = "pickle"
    if result_export_format == "pkl":
        result_export_format = "pickle"

    # Fail fast: a bad format would otherwise only surface in close(), after
    # the models were loaded and the whole stream had been processed. The
    # format is only checked when the matching export is actually enabled.
    if silhouette_export_path and silhouette_export_format not in (
        "pickle",
        "parquet",
    ):
        raise ValueError(
            f"Unsupported silhouette export format: {silhouette_export_format}"
        )
    if result_export_path and result_export_format not in (
        "json",
        "pickle",
        "parquet",
    ):
        raise ValueError(
            f"Unsupported result export format: {result_export_format}"
        )
    self._frame_pacer = _FramePacer(target_fps) if target_fps is not None else None

    self._detector = YOLO(yolo_model)
    self._source = create_source(source, max_frames=max_frames)
    self._window = SilhouetteWindow(window_size=window, stride=stride)
    self._publisher = create_publisher(nats_url=nats_url, subject=nats_subject)
    self._classifier = ScoNetDemo(
        cfg_path=config,
        checkpoint_path=checkpoint,
        device=device,
    )
    self._device = device
    self._closed = False
    self._preprocess_only = preprocess_only

    # Empty strings are treated like None: the corresponding feature is off.
    self._silhouette_export_path = (
        Path(silhouette_export_path) if silhouette_export_path else None
    )
    self._silhouette_export_format = silhouette_export_format
    self._silhouette_buffer = []
    self._silhouette_visualize_dir = (
        Path(silhouette_visualize_dir) if silhouette_visualize_dir else None
    )
    self._result_export_path = (
        Path(result_export_path) if result_export_path else None
    )
    self._result_export_format = result_export_format
    self._result_buffer = []

    if visualize:
        # Imported lazily so the visualizer module is only needed on demand.
        from .visualizer import OpenCVVisualizer

        self._visualizer = OpenCVVisualizer()
    else:
        self._visualizer = None
    self._last_viz_payload = None
@staticmethod
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
    """Return ``meta[key]`` when it is an ``int`` instance, else ``fallback``.

    Missing keys and values of any other type (strings, floats, NumPy
    integer scalars, ...) yield ``fallback``.
    """
    candidate = meta.get(key)
    return candidate if isinstance(candidate, int) else fallback
@staticmethod
def _extract_timestamp(meta: dict[str, object]) -> int:
    """Return ``meta["timestamp_ns"]`` when it is an ``int``.

    Otherwise the current ``time.monotonic_ns()`` reading is returned.
    """
    stamp = meta.get("timestamp_ns")
    return stamp if isinstance(stamp, int) else time.monotonic_ns()
@staticmethod
def _to_mask_u8(mask: NDArray[np.generic]) -> UInt8[ndarray, "h w"]:
    """Binarize ``mask`` into a ``uint8`` array containing only 0 and 255.

    The input is converted to ``float32`` and thresholded: values strictly
    greater than ``0.5`` become 255, everything else becomes 0.
    """
    as_float: NDArray[np.floating] = np.asarray(mask, dtype=np.float32)  # type: ignore[reportAssignmentType]
    foreground, background = np.uint8(255), np.uint8(0)
    thresholded = np.where(as_float > 0.5, foreground, background)
    return cast(UInt8[ndarray, "h w"], thresholded.astype(np.uint8))
def _first_result(self, detections: object) -> _DetectionResultsLike | None:
    """Return the first detection result from the detector output.

    A list or tuple yields its first element, or ``None`` when it is empty.
    Any other object is treated as a single result and returned as is.
    """
    if not isinstance(detections, (list, tuple)):
        return cast(_DetectionResultsLike, detections)
    if not detections:
        return None
    return cast(_DetectionResultsLike, detections[0])
def _select_silhouette(
    self,
    result: _DetectionResultsLike,
) -> _SelectedSilhouette | None:
    """Build the silhouette payload for one detection result.

    The tracked person returned by ``select_person`` is tried first. When
    that yields no person or no silhouette, ``frame_to_person_mask`` is used
    as a fallback and the payload carries ``track_id`` 0.

    Returns:
        The selected silhouette payload, or ``None`` when neither path
        produces a silhouette.
    """
    tracked = select_person(result)
    if tracked is not None:
        person_mask, tracked_bbox_mask, tracked_bbox_frame, person_id = tracked
        tracked_silhouette = cast(
            Float[ndarray, "64 44"] | None,
            mask_to_silhouette(self._to_mask_u8(person_mask), tracked_bbox_mask),
        )
        if tracked_silhouette is not None:
            # The mask is returned exactly as received from select_person;
            # only the silhouette is derived from the uint8-converted copy.
            return {
                "silhouette": tracked_silhouette,
                "mask_raw": person_mask,
                "bbox_frame": tracked_bbox_frame,
                "bbox_mask": tracked_bbox_mask,
                "track_id": int(person_id),
            }

    untracked = cast(
        tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
        frame_to_person_mask(result),
    )
    if untracked is None:
        return None
    mask_u8, bbox_mask = untracked
    silhouette = cast(
        Float[ndarray, "64 44"] | None,
        mask_to_silhouette(mask_u8, bbox_mask),
    )
    if silhouette is None:
        return None

    # Default to the mask-space box; it is only rescaled into frame space
    # when both the frame shape and the mask shape are known and non-empty.
    bbox_frame = bbox_mask
    orig_shape = getattr(result, "orig_shape", None)
    if isinstance(orig_shape, (tuple, list)) and len(orig_shape) >= 2:
        frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
        mask_h, mask_w = mask_u8.shape[0], mask_u8.shape[1]
        if mask_w > 0 and mask_h > 0 and frame_w > 0 and frame_h > 0:
            scale_x = frame_w / mask_w
            scale_y = frame_h / mask_h
            bbox_frame = (
                int(bbox_mask[0] * scale_x),
                int(bbox_mask[1] * scale_y),
                int(bbox_mask[2] * scale_x),
                int(bbox_mask[3] * scale_y),
            )

    # On the fallback path the raw mask is the uint8 mask itself.
    return {
        "silhouette": silhouette,
        "mask_raw": mask_u8,
        "bbox_frame": bbox_frame,
        "bbox_mask": bbox_mask,
        "track_id": 0,
    }
# Process a single frame end to end.
#   1. Run the detector's track() on the frame (class 0 only) and take the first result.
#   2. Select a person silhouette; return None when there is no result or no silhouette.
#   3. Optionally buffer the silhouette for export and save it as PNG.
#   4. In preprocess-only mode return a payload without classification.
#   5. Otherwise push the silhouette into the window and, when the frame pacer and the
#      window both allow it, classify the window, publish the result and buffer it.
# Returns a visualization payload (with "result" only on classified frames) or None.
# Raises RuntimeError when the detector has no callable `track` attribute.
@jaxtyped(typechecker=beartype)
def process_frame(
self,
frame: UInt8[ndarray, "h w c"],
metadata: dict[str, object],
) -> _VizPayload | None:
# Frame index defaults to 0 and the timestamp to a monotonic clock reading
# when the metadata does not carry integer values for them.
frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
timestamp_ns = self._extract_timestamp(metadata)
# The detector is stored as a plain object, so `track` is resolved dynamically.
track_fn_obj = getattr(self._detector, "track", None)
if not callable(track_fn_obj):
raise RuntimeError("YOLO detector does not expose a callable track()")
track_fn = cast(_TrackCallable, track_fn_obj)
# classes=[0] restricts detection to class index 0.
# NOTE(review): presumably the "person" class of the YOLO model - confirm.
detections = track_fn(
frame,
persist=True,
verbose=False,
device=self._device,
classes=[0],
)
first = self._first_result(detections)
if first is None:
return None
selected = self._select_silhouette(first)
if selected is None:
return None
silhouette = selected["silhouette"]
mask_raw = selected["mask_raw"]
bbox = selected["bbox_frame"]
bbox_mask = selected["bbox_mask"]
track_id = selected["track_id"]
# Pose data is never produced here; the key is kept in every payload as None.
pose_data = None
# Buffer a copy so later mutation of the silhouette cannot change the export.
if self._silhouette_export_path is not None or self._preprocess_only:
self._silhouette_buffer.append(
{
"frame": frame_idx,
"track_id": track_id,
"timestamp_ns": timestamp_ns,
"silhouette": silhouette.copy(),
}
)
# Visualize silhouette if requested
# NOTE(review): indentation is lost in this view; presumably this check is
# independent of the export check above - confirm against the repository.
if self._silhouette_visualize_dir is not None:
self._visualize_silhouette(silhouette, frame_idx, track_id)
if self._preprocess_only:
# Return visualization payload for display even in preprocess-only mode
return {
"mask_raw": mask_raw,
"bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette,
"segmentation_input": None,
"track_id": track_id,
"label": None,
"confidence": None,
"pose": pose_data,
}
# Every frame is pushed into the window, including frames the pacer skips below.
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
# Pacer says this frame is not due: return a payload without classification.
if self._frame_pacer is not None and not self._frame_pacer.should_emit(
timestamp_ns
):
return {
"mask_raw": mask_raw,
"bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette,
"segmentation_input": self._window.buffered_silhouettes,
"track_id": track_id,
"label": None,
"confidence": None,
"pose": pose_data,
}
# Read the buffered silhouettes before classification marks the window.
segmentation_input = self._window.buffered_silhouettes
if not self._window.should_classify():
# Return visualization payload even when not classifying yet
return {
"mask_raw": mask_raw,
"bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette,
"segmentation_input": segmentation_input,
"track_id": track_id,
"label": None,
"confidence": None,
"pose": pose_data,
}
window_tensor = self._window.get_tensor(device=self._device)
label, confidence = cast(
tuple[str, float],
self._classifier.predict(window_tensor),
)
self._window.mark_classified()
window_start = self._window.window_start_frame
# The reported window start is clamped so it is never negative.
result = create_result(
frame=frame_idx,
track_id=track_id,
label=label,
confidence=float(confidence),
window=(max(0, window_start), frame_idx),
timestamp_ns=timestamp_ns,
)
# Store result for export if export path specified
if self._result_export_path is not None:
self._result_buffer.append(result)
self._publisher.publish(result)
# Return result with visualization payload
# Note: the payload carries `confidence` as returned by the classifier,
# while the result above stores it converted with float().
return {
"result": result,
"mask_raw": mask_raw,
"bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette,
"segmentation_input": segmentation_input,
"track_id": track_id,
"label": label,
"confidence": confidence,
"pose": pose_data,
}
def run(self) -> int:
    """Consume the frame source and drive processing and visualization.

    Each frame is converted to ``uint8`` and passed to ``process_frame``.
    ``RuntimeError``, ``ValueError``, ``TypeError`` and ``OSError`` raised by
    ``process_frame`` are logged and the frame is skipped. When a visualizer
    is attached it is updated on every frame: with the fresh payload, or with
    the last cached payload (boxes, label, confidence and pose cleared) when
    the frame produced none.

    Returns:
        ``0`` when the source is exhausted or the visualizer asks to stop,
        ``130`` after a ``KeyboardInterrupt``. ``close()`` always runs before
        this method returns or propagates an exception.
    """
    import inspect

    frame_count = 0
    start_time = time.perf_counter()
    # Exponential moving average of the per-frame FPS (alpha=0.1 smoothing).
    ema_fps = 0.0
    alpha = 0.1
    prev_time = start_time

    visualizer = self._visualizer
    # Decide ONCE whether update() takes `pose_data`. Probing with
    # try/except TypeError on every call would also swallow TypeErrors
    # raised inside update() and invoke it a second time for the frame.
    pass_pose_data = True
    if visualizer is not None:
        try:
            update_params = inspect.signature(visualizer.update).parameters
        except (TypeError, ValueError):
            # Signature not introspectable: assume the current interface.
            pass_pose_data = True
        else:
            pass_pose_data = any(
                param.name == "pose_data"
                or param.kind is inspect.Parameter.VAR_KEYWORD
                for param in update_params.values()
            )

    try:
        for frame, metadata in self._source:
            frame_u8 = np.asarray(frame, dtype=np.uint8)
            frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
            frame_count += 1

            curr_time = time.perf_counter()
            delta = curr_time - prev_time
            prev_time = curr_time
            if delta > 0:
                instant_fps = 1.0 / delta
                # The first sample seeds the average instead of blending with 0.
                if ema_fps == 0.0:
                    ema_fps = instant_fps
                else:
                    ema_fps = alpha * instant_fps + (1 - alpha) * ema_fps

            viz_payload = None
            try:
                viz_payload = self.process_frame(frame_u8, metadata)
            except (RuntimeError, ValueError, TypeError, OSError) as frame_error:
                logger.warning(
                    "Skipping frame %d due to processing error: %s",
                    frame_idx,
                    frame_error,
                )

            if visualizer is not None:
                if viz_payload is not None:
                    # Cache copies so later mutation of the payload's arrays
                    # or dicts cannot alter what is replayed afterwards.
                    snapshot: _VizPayload = {}
                    for key, value in viz_payload.items():
                        copier = getattr(value, "copy", None)
                        snapshot[key] = copier() if callable(copier) else value
                    self._last_viz_payload = snapshot
                    view = viz_payload
                elif self._last_viz_payload is not None:
                    # No payload for this frame: replay the cached imagery but
                    # drop everything that describes a current detection.
                    view = {
                        **self._last_viz_payload,
                        "bbox": None,
                        "bbox_mask": None,
                        "label": None,
                        "confidence": None,
                        "pose": None,
                    }
                else:
                    # Nothing seen yet: every field falls back to None / 0.
                    view = {}

                track_id_val = view.get("track_id", 0)
                update_args = (
                    frame_u8,
                    view.get("bbox"),
                    view.get("bbox_mask"),
                    track_id_val if isinstance(track_id_val, int) else 0,
                    view.get("mask_raw"),
                    view.get("silhouette"),
                    view.get("segmentation_input"),
                    view.get("label"),
                    view.get("confidence"),
                    ema_fps,
                )
                if pass_pose_data:
                    keep_running = visualizer.update(
                        *update_args, pose_data=view.get("pose")
                    )
                else:
                    # Legacy visualizers whose update() has no pose_data.
                    keep_running = visualizer.update(*update_args)
                if not keep_running:
                    logger.info("Visualization closed by user.")
                    break

            if frame_count % 100 == 0:
                elapsed = time.perf_counter() - start_time
                fps = frame_count / elapsed if elapsed > 0 else 0.0
                logger.info("Processed %d frames (%.2f FPS)", frame_count, fps)
        return 0
    except KeyboardInterrupt:
        logger.info("Interrupted by user, shutting down cleanly.")
        return 130
    finally:
        self.close()
def close(self) -> None:
    """Tear the pipeline down once: close the UI, flush exports, close the publisher.

    The call is idempotent. A failing export no longer prevents the other
    export from being attempted or the publisher from being closed; the
    export error is re-raised after the remaining cleanup has run.

    Raises:
        Whatever ``_export_silhouettes`` / ``_export_results`` raise
        (for example ``ValueError`` for an unsupported format).
    """
    if self._closed:
        return
    # Mark as closed first so a failing export cannot trigger a second
    # teardown (and a second export attempt) on a later close() call.
    self._closed = True
    try:
        if self._visualizer is not None:
            # Best effort: a broken window must not block the exports.
            with suppress(Exception):
                self._visualizer.close()
        try:
            if self._silhouette_export_path is not None and self._silhouette_buffer:
                self._export_silhouettes()
        finally:
            if self._result_export_path is not None and self._result_buffer:
                self._export_results()
    finally:
        # Not every publisher has close(); errors while closing are ignored.
        close_fn = getattr(self._publisher, "close", None)
        if callable(close_fn):
            with suppress(Exception):
                _ = close_fn()
def _export_silhouettes(self) -> None:
    """Write the buffered silhouettes to the configured export file.

    Does nothing when no export path is set. The parent directory is created
    before the format is dispatched.

    Raises:
        ValueError: If the configured format is neither ``"pickle"`` nor
            ``"parquet"``.
    """
    destination = self._silhouette_export_path
    if destination is None:
        return
    destination.parent.mkdir(parents=True, exist_ok=True)

    export_format = self._silhouette_export_format
    if export_format == "parquet":
        self._export_parquet_silhouettes()
        return
    if export_format != "pickle":
        raise ValueError(
            f"Unsupported silhouette export format: {self._silhouette_export_format}"
        )

    import pickle

    with destination.open("wb") as f:
        pickle.dump(self._silhouette_buffer, f)
    logger.info(
        "Exported %d silhouettes to %s",
        len(self._silhouette_buffer),
        self._silhouette_export_path,
    )
def _visualize_silhouette(
    self,
    silhouette: Float[ndarray, "64 44"],
    frame_idx: int,
    track_id: int,
) -> None:
    """Write ``silhouette`` as a PNG into the silhouette visualization directory.

    Does nothing when no directory is configured. The file name is derived
    from ``frame_idx`` and ``track_id``, so the same pair always maps to the
    same file.
    """
    out_dir = self._silhouette_visualize_dir
    if out_dir is None:
        return
    out_dir.mkdir(parents=True, exist_ok=True)

    # Scale by 255 and cast to uint8 for the image writer.
    pixels = (silhouette * 255).astype(np.uint8)
    target = out_dir / f"silhouette_frame{frame_idx:06d}_track{track_id:04d}.png"

    # PIL is imported lazily; it is only needed when this feature is used.
    from PIL import Image

    Image.fromarray(pixels).save(target)
def _export_parquet_silhouettes(self) -> None:
    """Write the buffered silhouettes as a four-column parquet table.

    Columns are ``frame``, ``track_id`` and ``timestamp_ns`` (int64) plus
    ``silhouette``, a list of float64 holding the flattened array.

    Raises:
        RuntimeError: If ``pyarrow`` cannot be imported.
    """
    import importlib

    try:
        pa = importlib.import_module("pyarrow")
        pq = importlib.import_module("pyarrow.parquet")
    except ImportError as e:
        raise RuntimeError(
            "Parquet export requires pyarrow. Install with: pip install pyarrow"
        ) from e

    records = self._silhouette_buffer
    # Each silhouette is flattened into a plain list of floats per row.
    flattened = [
        cast(NDArray[np.float32], record["silhouette"]).flatten().tolist()
        for record in records
    ]
    table = pa.table(
        {
            "frame": pa.array(
                [record["frame"] for record in records], type=pa.int64()
            ),
            "track_id": pa.array(
                [record["track_id"] for record in records], type=pa.int64()
            ),
            "timestamp_ns": pa.array(
                [record["timestamp_ns"] for record in records], type=pa.int64()
            ),
            "silhouette": pa.array(flattened, type=pa.list_(pa.float64())),
        }
    )
    pq.write_table(table, self._silhouette_export_path)
    logger.info(
        "Exported %d silhouettes to parquet: %s",
        len(self._silhouette_buffer),
        self._silhouette_export_path,
    )
def _export_results(self) -> None:
    """Write the buffered classification results to the configured file.

    Does nothing when no export path is set. The ``"json"`` format writes
    one JSON document per line; values that JSON cannot encode are
    stringified via ``default=str``.

    Raises:
        ValueError: If the configured format is not ``"json"``, ``"pickle"``
            or ``"parquet"``.
    """
    destination = self._result_export_path
    if destination is None:
        return
    destination.parent.mkdir(parents=True, exist_ok=True)

    export_format = self._result_export_format
    if export_format == "parquet":
        self._export_parquet_results()
        return

    if export_format == "json":
        import json

        with open(destination, "w", encoding="utf-8") as f:
            f.writelines(
                json.dumps(result, ensure_ascii=False, default=str) + "\n"
                for result in self._result_buffer
            )
        logger.info(
            "Exported %d results to JSON: %s",
            len(self._result_buffer),
            self._result_export_path,
        )
        return

    if export_format == "pickle":
        import pickle

        with open(destination, "wb") as f:
            pickle.dump(self._result_buffer, f)
        logger.info(
            "Exported %d results to pickle: %s",
            len(self._result_buffer),
            self._result_export_path,
        )
        return

    raise ValueError(
        f"Unsupported result export format: {self._result_export_format}"
    )
def _export_parquet_results(self) -> None:
    """Write the buffered classification results as a parquet table.

    Columns: ``frame``, ``track_id``, ``window`` and ``timestamp_ns`` as
    int64, ``label`` as string and ``confidence`` as float64.

    Raises:
        RuntimeError: If ``pyarrow`` cannot be imported.
    """
    import importlib

    try:
        pa = importlib.import_module("pyarrow")
        pq = importlib.import_module("pyarrow.parquet")
    except ImportError as e:
        raise RuntimeError(
            "Parquet export requires pyarrow. Install with: pip install pyarrow"
        ) from e

    records = self._result_buffer

    def column(key: str) -> list[object]:
        """Collect one field across all buffered results."""
        return [record[key] for record in records]

    # NOTE(review): "window" is declared int64 here, so each result's
    # "window" value is assumed to be a single integer - confirm against
    # what create_result stores.
    table = pa.table(
        {
            "frame": pa.array(column("frame"), type=pa.int64()),
            "track_id": pa.array(column("track_id"), type=pa.int64()),
            "label": pa.array(column("label"), type=pa.string()),
            "confidence": pa.array(column("confidence"), type=pa.float64()),
            "window": pa.array(column("window"), type=pa.int64()),
            "timestamp_ns": pa.array(column("timestamp_ns"), type=pa.int64()),
        }
    )
    pq.write_table(table, self._result_export_path)
    logger.info(
        "Exported %d results to parquet: %s",
        len(self._result_buffer),
        self._result_export_path,
    )
def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
    """Check that the files the pipeline needs exist before it is built.

    A ``source`` that starts with ``cvmmap://`` or consists only of digits
    is not checked on disk; every other ``source`` must be an existing file.
    ``checkpoint`` and ``config`` must always be existing files.

    Raises:
        ValueError: Naming the first input that is missing.
    """
    skips_file_check = source.startswith("cvmmap://") or source.isdigit()
    if not skips_file_check and not Path(source).is_file():
        raise ValueError(f"Video source not found: {source}")
    if not Path(checkpoint).is_file():
        raise ValueError(f"Checkpoint not found: {checkpoint}")
    if not Path(config).is_file():
        raise ValueError(f"Config not found: {config}")
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option("--source", type=str, required=True)
@click.option("--checkpoint", type=str, required=True)
@click.option(
    "--config",
    type=str,
    default="configs/sconet/sconet_scoliosis1k.yaml",
    show_default=True,
)
@click.option("--device", type=str, default="cuda:0", show_default=True)
@click.option(
    "--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", show_default=True
)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
@click.option(
    "--window-mode",
    type=click.Choice(["manual", "sliding", "chunked"], case_sensitive=False),
    default="manual",
    show_default=True,
    help=(
        "Window scheduling mode: manual uses --stride; "
        "sliding forces stride=1; chunked forces stride=window"
    ),
)
@click.option(
    "--target-fps",
    type=click.FloatRange(min=0.1),
    default=15.0,
    show_default=True,
)
@click.option("--no-target-fps", is_flag=True, default=False)
@click.option("--nats-url", type=str, default=None)
@click.option(
    "--nats-subject",
    type=str,
    default="scoliosis.result",
    show_default=True,
)
@click.option("--max-frames", type=click.IntRange(min=1), default=None)
@click.option(
    "--preprocess-only",
    is_flag=True,
    default=False,
    help="Only preprocess silhouettes, skip classification.",
)
@click.option(
    "--silhouette-export-path",
    type=str,
    default=None,
    help="Path to export silhouettes (required for preprocess-only mode).",
)
@click.option(
    "--silhouette-export-format",
    type=click.Choice(["pickle", "parquet"]),
    default="pickle",
    show_default=True,
    help="Format for silhouette export.",
)
@click.option(
    "--result-export-path",
    type=str,
    default=None,
    help="Path to export inference results.",
)
@click.option(
    "--result-export-format",
    type=click.Choice(["json", "pickle", "parquet"]),
    default="json",
    show_default=True,
    help="Format for result export.",
)
@click.option(
    "--silhouette-visualize-dir",
    type=str,
    default=None,
    help="Directory to save silhouette PNG visualizations.",
)
@click.option(
    "--visualize",
    is_flag=True,
    default=False,
    help="Enable real-time visualization.",
)
def main(
    source: str,
    checkpoint: str,
    config: str,
    device: str,
    yolo_model: str,
    window: int,
    stride: int,
    window_mode: str,
    target_fps: float,
    no_target_fps: bool,
    nats_url: str | None,
    nats_subject: str,
    max_frames: int | None,
    preprocess_only: bool,
    silhouette_export_path: str | None,
    silhouette_export_format: str,
    result_export_path: str | None,
    result_export_format: str,
    silhouette_visualize_dir: str | None,
    visualize: bool,
) -> None:
    # CLI entry point: validates the inputs, builds a ScoliosisPipeline and
    # exits with the code returned by its run() method.
    # Exit codes set here: 2 for a ValueError (bad input), 1 for a RuntimeError.
    # (Kept as comments on purpose: click would show a docstring as help text.)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    # Preprocess-only runs are pointless without somewhere to write the output.
    if preprocess_only and not silhouette_export_path:
        raise click.UsageError(
            "--silhouette-export-path is required when using --preprocess-only"
        )

    # --no-target-fps disables pacing entirely, whatever --target-fps says.
    pacing_fps = None if no_target_fps else target_fps

    try:
        validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)

        # click matches the choice case-insensitively, so normalize before use.
        mode = cast(WindowMode, window_mode.lower())
        effective_stride = resolve_stride(
            window=window,
            stride=stride,
            window_mode=mode,
        )
        if effective_stride != stride:
            logger.info(
                "window_mode=%s overrides stride=%d -> effective_stride=%d",
                window_mode,
                stride,
                effective_stride,
            )

        pipeline = ScoliosisPipeline(
            source=source,
            checkpoint=checkpoint,
            config=config,
            device=device,
            yolo_model=yolo_model,
            window=window,
            stride=effective_stride,
            nats_url=nats_url,
            nats_subject=nats_subject,
            max_frames=max_frames,
            preprocess_only=preprocess_only,
            silhouette_export_path=silhouette_export_path,
            silhouette_export_format=silhouette_export_format,
            silhouette_visualize_dir=silhouette_visualize_dir,
            result_export_path=result_export_path,
            result_export_format=result_export_format,
            visualize=visualize,
            target_fps=pacing_fps,
        )
        exit_code = pipeline.run()
        raise SystemExit(exit_code)
    except ValueError as err:
        click.echo(f"Error: {err}", err=True)
        raise SystemExit(2) from err
    except RuntimeError as err:
        click.echo(f"Runtime error: {err}", err=True)
        raise SystemExit(1) from err