55e8155adc
Accepting "pkl" as an alias for "pickle" avoids runtime export failures for common shorthand CLI usage while preserving existing export behavior.
766 lines
26 KiB
Python
766 lines
26 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from contextlib import suppress
|
|
import logging
|
|
from pathlib import Path
|
|
import time
|
|
from typing import TYPE_CHECKING, Protocol, cast
|
|
|
|
from beartype import beartype
|
|
import click
|
|
import jaxtyping
|
|
from jaxtyping import Float, UInt8
|
|
import numpy as np
|
|
from numpy import ndarray
|
|
from numpy.typing import NDArray
|
|
from ultralytics.models.yolo.model import YOLO
|
|
|
|
from .input import FrameStream, create_source
|
|
from .output import DemoResult, ResultPublisher, create_publisher, create_result
|
|
from .preprocess import BBoxXYXY, frame_to_person_mask, mask_to_silhouette
|
|
from .sconet_demo import ScoNetDemo
|
|
from .window import SilhouetteWindow, select_person
|
|
|
|
if TYPE_CHECKING:
|
|
from .visualizer import OpenCVVisualizer
|
|
|
|
logger = logging.getLogger(__name__)

# Typing shims for ``jaxtyping.jaxtyped``: re-expose it through ``cast`` with an
# explicit "factory returning a decorator" signature so that static checkers
# see a concrete callable type on the decorated methods below.
JaxtypedDecorator = Callable[
    [Callable[..., object]],
    Callable[..., object],
]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
|
|
|
|
|
|
class _BoxesLike(Protocol):
|
|
@property
|
|
def xyxy(self) -> NDArray[np.float32] | object: ...
|
|
|
|
@property
|
|
def id(self) -> NDArray[np.int64] | object | None: ...
|
|
|
|
|
|
class _MasksLike(Protocol):
|
|
@property
|
|
def data(self) -> NDArray[np.float32] | object: ...
|
|
|
|
|
|
class _DetectionResultsLike(Protocol):
|
|
@property
|
|
def boxes(self) -> _BoxesLike: ...
|
|
|
|
@property
|
|
def masks(self) -> _MasksLike: ...
|
|
|
|
|
|
|
|
class _TrackCallable(Protocol):
|
|
def __call__(
|
|
self,
|
|
source: object,
|
|
*,
|
|
persist: bool = True,
|
|
verbose: bool = False,
|
|
device: str | None = None,
|
|
classes: list[int] | None = None,
|
|
) -> object: ...
|
|
|
|
|
|
class ScoliosisPipeline:
    """Frame-by-frame scoliosis screening pipeline.

    Each frame from the configured :class:`FrameStream` is run through the YOLO
    tracker. The selected person is turned into a ``64 x 44`` silhouette, and
    silhouettes are accumulated in a :class:`SilhouetteWindow`. Each window
    that is ready is then classified by :class:`ScoNetDemo`.

    Classification results are published through the :class:`ResultPublisher`.
    Silhouettes and results can also be buffered in memory and exported to
    disk when :meth:`close` runs.
    """

    _detector: object
    _source: FrameStream
    _window: SilhouetteWindow
    _publisher: ResultPublisher
    _classifier: ScoNetDemo
    _device: str
    _closed: bool
    _preprocess_only: bool
    _silhouette_export_path: Path | None
    _silhouette_export_format: str
    _silhouette_buffer: list[dict[str, object]]
    _silhouette_visualize_dir: Path | None
    _result_export_path: Path | None
    _result_export_format: str
    _result_buffer: list[DemoResult]
    _visualizer: OpenCVVisualizer | None
    _last_viz_payload: dict[str, object] | None

    # Formats each exporter understands, after alias normalization.
    _SILHOUETTE_EXPORT_FORMATS: tuple[str, ...] = ("pickle", "parquet")
    _RESULT_EXPORT_FORMATS: tuple[str, ...] = ("json", "pickle", "parquet")

    def __init__(
        self,
        *,
        source: str,
        checkpoint: str,
        config: str,
        device: str,
        yolo_model: str,
        window: int,
        stride: int,
        nats_url: str | None,
        nats_subject: str,
        max_frames: int | None,
        preprocess_only: bool = False,
        silhouette_export_path: str | None = None,
        silhouette_export_format: str = "pickle",
        silhouette_visualize_dir: str | None = None,
        result_export_path: str | None = None,
        result_export_format: str = "json",
        visualize: bool = False,
    ) -> None:
        """Create the detector, frame source, window, publisher and classifier.

        Args:
            source: Source specifier forwarded to ``create_source``.
            checkpoint: ScoNet checkpoint path.
            config: ScoNet config path.
            device: Device string used for YOLO tracking and ScoNet inference.
            yolo_model: YOLO weights forwarded to ``YOLO``.
            window: Window size forwarded to ``SilhouetteWindow``.
            stride: Stride forwarded to ``SilhouetteWindow``.
            nats_url: Optional URL forwarded to ``create_publisher``.
            nats_subject: Subject forwarded to ``create_publisher``.
            max_frames: Optional frame limit forwarded to ``create_source``.
            preprocess_only: Buffer silhouettes and skip classification.
            silhouette_export_path: File written with buffered silhouettes on close.
            silhouette_export_format: ``"pickle"`` (alias ``"pkl"``) or ``"parquet"``.
            silhouette_visualize_dir: Directory receiving per-frame silhouette PNGs.
            result_export_path: File written with buffered results on close.
            result_export_format: ``"json"``, ``"pickle"`` (alias ``"pkl"``) or
                ``"parquet"``.
            visualize: Enable the live ``OpenCVVisualizer``.

        Raises:
            ValueError: If an export path is given together with an unsupported
                export format.
        """
        # Resolve and validate the export settings first. A bad format then
        # fails immediately, instead of discarding a whole run's buffers at
        # close(), and before any heavyweight resource is created.
        self._silhouette_export_path = (
            Path(silhouette_export_path) if silhouette_export_path else None
        )
        self._silhouette_export_format = self._normalize_export_format(
            silhouette_export_format
        )
        self._silhouette_visualize_dir = (
            Path(silhouette_visualize_dir) if silhouette_visualize_dir else None
        )
        self._result_export_path = (
            Path(result_export_path) if result_export_path else None
        )
        self._result_export_format = self._normalize_export_format(result_export_format)
        if self._silhouette_export_path is not None:
            self._require_supported_format(
                "silhouette",
                self._silhouette_export_format,
                self._SILHOUETTE_EXPORT_FORMATS,
            )
        if self._result_export_path is not None:
            self._require_supported_format(
                "result",
                self._result_export_format,
                self._RESULT_EXPORT_FORMATS,
            )

        self._detector = YOLO(yolo_model)
        self._source = create_source(source, max_frames=max_frames)
        self._window = SilhouetteWindow(window_size=window, stride=stride)
        self._publisher = create_publisher(nats_url=nats_url, subject=nats_subject)
        self._classifier = ScoNetDemo(
            cfg_path=config,
            checkpoint_path=checkpoint,
            device=device,
        )
        self._device = device
        self._closed = False
        self._preprocess_only = preprocess_only
        self._silhouette_buffer = []
        self._result_buffer = []

        if visualize:
            from .visualizer import OpenCVVisualizer

            self._visualizer = OpenCVVisualizer()
        else:
            self._visualizer = None
        self._last_viz_payload = None

    @staticmethod
    def _normalize_export_format(fmt: str) -> str:
        """Lower-case *fmt* and map the ``"pkl"`` shorthand to ``"pickle"``."""
        normalized = fmt.strip().lower()
        return "pickle" if normalized == "pkl" else normalized

    @staticmethod
    def _require_supported_format(kind: str, fmt: str, supported: tuple[str, ...]) -> None:
        """Raise ``ValueError`` when *fmt* is not one of *supported*."""
        if fmt not in supported:
            raise ValueError(f"Unsupported {kind} export format: {fmt}")

    @staticmethod
    def _coerce_int(value: object) -> int | None:
        """Return *value* as a Python int if it is an integer, else ``None``.

        NumPy integer scalars are accepted. ``bool`` is rejected even though
        it subclasses ``int``, because a flag is never a valid frame index or
        timestamp.
        """
        if isinstance(value, bool) or not isinstance(value, (int, np.integer)):
            return None
        return int(value)

    @staticmethod
    def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
        """Return ``meta[key]`` as an int, or *fallback* when missing/non-integer."""
        coerced = ScoliosisPipeline._coerce_int(meta.get(key))
        return fallback if coerced is None else coerced

    @staticmethod
    def _extract_timestamp(meta: dict[str, object]) -> int:
        """Return ``meta["timestamp_ns"]``, or ``time.monotonic_ns()`` if absent."""
        coerced = ScoliosisPipeline._coerce_int(meta.get("timestamp_ns"))
        return time.monotonic_ns() if coerced is None else coerced

    @staticmethod
    def _to_mask_u8(mask: ndarray) -> UInt8[ndarray, "h w"]:
        """Binarize *mask* at 0.5 into a ``{0, 255}`` uint8 mask."""
        binary = np.where(np.asarray(mask) > 0.5, np.uint8(255), np.uint8(0)).astype(
            np.uint8
        )
        return cast(UInt8[ndarray, "h w"], binary)

    def _first_result(self, detections: object) -> _DetectionResultsLike | None:
        """Return the first result of a list/tuple, ``None`` if it is empty.

        Any other object is returned unchanged, treated as a single result.
        """
        if isinstance(detections, (list, tuple)):
            return cast(_DetectionResultsLike, detections[0]) if detections else None
        return cast(_DetectionResultsLike, detections)

    @staticmethod
    def _mask_bbox_to_frame(
        result: object,
        mask_shape: tuple[int, ...],
        bbox_mask: BBoxXYXY,
    ) -> BBoxXYXY:
        """Rescale a mask-space bbox into frame space using ``result.orig_shape``.

        The mask-space bbox is returned unchanged when ``orig_shape`` is
        unavailable or when either shape has a non-positive dimension.
        """
        orig_shape = getattr(result, "orig_shape", None)
        if not isinstance(orig_shape, (tuple, list)) or len(orig_shape) < 2:
            return bbox_mask

        frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
        mask_h, mask_w = mask_shape[0], mask_shape[1]
        if mask_w <= 0 or mask_h <= 0 or frame_w <= 0 or frame_h <= 0:
            return bbox_mask

        scale_x = frame_w / mask_w
        scale_y = frame_h / mask_h
        return (
            int(bbox_mask[0] * scale_x),
            int(bbox_mask[1] * scale_y),
            int(bbox_mask[2] * scale_x),
            int(bbox_mask[3] * scale_y),
        )

    def _select_silhouette(
        self,
        result: _DetectionResultsLike,
    ) -> (
        tuple[
            Float[ndarray, "64 44"],
            UInt8[ndarray, "h w"],
            BBoxXYXY,
            int,
        ]
        | None
    ):
        """Pick one person from *result* and build their silhouette.

        The person chosen by ``select_person`` is tried first. If no person is
        selected, or their silhouette cannot be built, ``frame_to_person_mask``
        is used as a fallback and reported with track id ``0``.

        Returns:
            ``(silhouette, mask, bbox_in_frame_space, track_id)``, or ``None``
            when no usable silhouette could be produced.
        """
        selected = select_person(result)
        if selected is not None:
            mask_raw, bbox_mask, bbox_frame, track_id = selected
            silhouette = cast(
                Float[ndarray, "64 44"] | None,
                mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask),
            )
            if silhouette is not None:
                return silhouette, mask_raw, bbox_frame, int(track_id)

        fallback = cast(
            tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
            frame_to_person_mask(result),
        )
        if fallback is None:
            return None

        mask_u8, bbox_mask = fallback
        silhouette = cast(
            Float[ndarray, "64 44"] | None,
            mask_to_silhouette(mask_u8, bbox_mask),
        )
        if silhouette is None:
            return None

        # The fallback bbox is in mask space; the visualizer expects frame space.
        bbox_frame = self._mask_bbox_to_frame(result, mask_u8.shape, bbox_mask)
        # In the fallback path the "raw" mask is the binarized mask itself.
        return silhouette, mask_u8, bbox_frame, 0

    @staticmethod
    def _make_viz_payload(
        mask_raw: object,
        bbox: object,
        silhouette: object,
        track_id: int,
        label: str | None = None,
        confidence: float | None = None,
    ) -> dict[str, object]:
        """Build the per-frame visualization payload returned by process_frame."""
        return {
            "mask_raw": mask_raw,
            "bbox": bbox,
            "silhouette": silhouette,
            "track_id": track_id,
            "label": label,
            "confidence": confidence,
        }

    @jaxtyped(typechecker=beartype)
    def process_frame(
        self,
        frame: UInt8[ndarray, "h w c"],
        metadata: dict[str, object],
    ) -> dict[str, object] | None:
        """Run detection, silhouette extraction and (maybe) classification.

        Args:
            frame: Image array passed to the detector's ``track`` method.
            metadata: Per-frame metadata; ``frame_count`` and ``timestamp_ns``
                are read when they are integers.

        Returns:
            ``None`` when no person/silhouette was found. Otherwise a
            visualization payload dict. When a window was classified, the
            payload also carries the published ``"result"`` plus its label
            and confidence.

        Raises:
            RuntimeError: If the detector has no callable ``track``.
        """
        frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
        timestamp_ns = self._extract_timestamp(metadata)

        track_fn_obj = getattr(self._detector, "track", None)
        if not callable(track_fn_obj):
            raise RuntimeError("YOLO detector does not expose a callable track()")

        track_fn = cast(_TrackCallable, track_fn_obj)
        # classes=[0] restricts tracking to class id 0 (person in COCO-trained YOLO).
        detections = track_fn(
            frame,
            persist=True,
            verbose=False,
            device=self._device,
            classes=[0],
        )
        first = self._first_result(detections)
        if first is None:
            return None

        selected = self._select_silhouette(first)
        if selected is None:
            return None
        silhouette, mask_raw, bbox, track_id = selected

        # Buffer the silhouette when it will be exported (or in preprocess-only mode).
        if self._silhouette_export_path is not None or self._preprocess_only:
            self._silhouette_buffer.append(
                {
                    "frame": frame_idx,
                    "track_id": track_id,
                    "timestamp_ns": timestamp_ns,
                    "silhouette": silhouette.copy(),
                }
            )

        if self._silhouette_visualize_dir is not None:
            self._visualize_silhouette(silhouette, frame_idx, track_id)

        if self._preprocess_only:
            # No classification, but still return a payload for display.
            return self._make_viz_payload(mask_raw, bbox, silhouette, track_id)

        self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
        if not self._window.should_classify():
            return self._make_viz_payload(mask_raw, bbox, silhouette, track_id)

        window_tensor = self._window.get_tensor(device=self._device)
        label, confidence = cast(
            tuple[str, float],
            self._classifier.predict(window_tensor),
        )
        self._window.mark_classified()

        window_start = frame_idx - self._window.window_size + 1
        result = create_result(
            frame=frame_idx,
            track_id=track_id,
            label=label,
            confidence=float(confidence),
            window=(max(0, window_start), frame_idx),
            timestamp_ns=timestamp_ns,
        )
        if self._result_export_path is not None:
            self._result_buffer.append(result)
        self._publisher.publish(result)

        return {
            "result": result,
            **self._make_viz_payload(
                mask_raw,
                bbox,
                silhouette,
                track_id,
                label=label,
                confidence=confidence,
            ),
        }

    @staticmethod
    def _copy_if_possible(value: object) -> object:
        """Return ``value.copy()`` when *value* has a callable copy, else *value*."""
        copy_method = getattr(value, "copy", None)
        return copy_method() if callable(copy_method) else value

    def _update_visualizer(
        self,
        frame_u8: NDArray[np.uint8],
        viz_payload: dict[str, object] | None,
        ema_fps: float,
    ) -> bool:
        """Push one frame to the visualizer; return False if the user closed it.

        Frames without a detection reuse the last valid payload, so the
        overlay does not flicker.
        """
        visualizer = self._visualizer
        if visualizer is None:
            return True

        if viz_payload is not None:
            # Cache a copy so later mutation of the live arrays cannot leak in.
            self._last_viz_payload = {
                key: self._copy_if_possible(value) for key, value in viz_payload.items()
            }
        viz_data = viz_payload if viz_payload is not None else self._last_viz_payload

        if viz_data is not None:
            track_id_val = viz_data.get("track_id", 0)
            track_id = track_id_val if isinstance(track_id_val, int) else 0
            mask_raw = cast(NDArray[np.uint8] | None, viz_data.get("mask_raw"))
            bbox = cast(BBoxXYXY | None, viz_data.get("bbox"))
            silhouette = cast(NDArray[np.float32] | None, viz_data.get("silhouette"))
            label = cast(str | None, viz_data.get("label"))
            confidence = cast(float | None, viz_data.get("confidence"))
        else:
            # No detection yet and nothing cached.
            mask_raw = None
            bbox = None
            track_id = 0
            silhouette = None
            label = None
            confidence = None

        keep_running = visualizer.update(
            frame_u8,
            bbox,
            track_id,
            mask_raw,
            silhouette,
            label,
            confidence,
            ema_fps,
        )
        return bool(keep_running)

    def run(self) -> int:
        """Consume the frame source until exhausted, interrupted, or closed.

        Per-frame processing errors are logged and the frame is skipped.
        :meth:`close` always runs on exit.

        Returns:
            ``0`` on normal completion (including the user closing the
            visualizer), ``130`` on ``KeyboardInterrupt``.
        """
        frame_count = 0
        start_time = time.perf_counter()
        # Exponential moving average of per-frame FPS (alpha=0.1 smoothing).
        ema_fps = 0.0
        alpha = 0.1
        prev_time = start_time

        try:
            for frame, metadata in self._source:
                frame_u8 = np.asarray(frame, dtype=np.uint8)
                frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
                frame_count += 1

                curr_time = time.perf_counter()
                delta = curr_time - prev_time
                prev_time = curr_time
                if delta > 0:
                    instant_fps = 1.0 / delta
                    if ema_fps == 0.0:
                        ema_fps = instant_fps
                    else:
                        ema_fps = alpha * instant_fps + (1 - alpha) * ema_fps

                viz_payload: dict[str, object] | None = None
                try:
                    viz_payload = self.process_frame(frame_u8, metadata)
                except Exception as frame_error:
                    logger.warning(
                        "Skipping frame %d due to processing error: %s",
                        frame_idx,
                        frame_error,
                    )

                if self._visualizer is not None and not self._update_visualizer(
                    frame_u8, viz_payload, ema_fps
                ):
                    logger.info("Visualization closed by user.")
                    break

                if frame_count % 100 == 0:
                    elapsed = time.perf_counter() - start_time
                    fps = frame_count / elapsed if elapsed > 0 else 0.0
                    logger.info("Processed %d frames (%.2f FPS)", frame_count, fps)
            return 0
        except KeyboardInterrupt:
            logger.info("Interrupted by user, shutting down cleanly.")
            return 130
        finally:
            self.close()

    def close(self) -> None:
        """Close the visualizer, run pending exports, and close the publisher.

        Calling it again after a completed close is a no-op. Both exports are
        attempted even if one fails, and the publisher is always closed. The
        first export error is then re-raised, and any further ones are logged.
        """
        if self._closed:
            return

        if self._visualizer is not None:
            with suppress(Exception):
                self._visualizer.close()

        export_errors: list[Exception] = []
        try:
            if self._silhouette_export_path is not None and self._silhouette_buffer:
                try:
                    self._export_silhouettes()
                except Exception as exc:
                    export_errors.append(exc)

            if self._result_export_path is not None and self._result_buffer:
                try:
                    self._export_results()
                except Exception as exc:
                    export_errors.append(exc)
        finally:
            close_fn = getattr(self._publisher, "close", None)
            if callable(close_fn):
                with suppress(Exception):
                    _ = close_fn()
            self._closed = True

        if export_errors:
            for extra_error in export_errors[1:]:
                logger.error("Additional export failure: %s", extra_error)
            raise export_errors[0]

    def _export_silhouettes(self) -> None:
        """Write buffered silhouettes to ``_silhouette_export_path``.

        ``"pickle"`` dumps the buffer list as-is. ``"parquet"`` delegates to
        :meth:`_export_parquet_silhouettes`.

        Raises:
            ValueError: If the configured format is not supported.
        """
        export_path = self._silhouette_export_path
        if export_path is None:
            return
        self._require_supported_format(
            "silhouette",
            self._silhouette_export_format,
            self._SILHOUETTE_EXPORT_FORMATS,
        )

        export_path.parent.mkdir(parents=True, exist_ok=True)

        if self._silhouette_export_format == "pickle":
            import pickle

            with open(export_path, "wb") as f:
                pickle.dump(self._silhouette_buffer, f)
            logger.info(
                "Exported %d silhouettes to %s",
                len(self._silhouette_buffer),
                export_path,
            )
        else:
            self._export_parquet_silhouettes()

    def _visualize_silhouette(
        self,
        silhouette: Float[ndarray, "64 44"],
        frame_idx: int,
        track_id: int,
    ) -> None:
        """Save *silhouette* as a grayscale PNG in ``_silhouette_visualize_dir``.

        Float values are scaled by 255, rounded, and clipped to ``[0, 255]``.
        Rounding avoids the off-by-one that truncation gives for values like
        ``k/255``. Clipping avoids uint8 wrap-around for values outside
        ``[0, 1]``.
        """
        if self._silhouette_visualize_dir is None:
            return

        self._silhouette_visualize_dir.mkdir(parents=True, exist_ok=True)

        scaled = np.rint(np.asarray(silhouette, dtype=np.float64) * 255.0)
        silhouette_u8 = np.clip(scaled, 0, 255).astype(np.uint8)

        # Deterministic filename so reruns overwrite rather than accumulate.
        filename = f"silhouette_frame{frame_idx:06d}_track{track_id:04d}.png"
        output_path = self._silhouette_visualize_dir / filename

        from PIL import Image

        Image.fromarray(silhouette_u8).save(output_path)

    def _export_parquet_silhouettes(self) -> None:
        """Write buffered silhouettes as a parquet table.

        The table has one row per silhouette, and each silhouette is flattened
        into a ``list<float64>`` column.

        Raises:
            RuntimeError: If ``pyarrow`` is not installed.
        """
        import importlib

        try:
            pa = importlib.import_module("pyarrow")
            pq = importlib.import_module("pyarrow.parquet")
        except ImportError as e:
            raise RuntimeError(
                "Parquet export requires pyarrow. Install with: pip install pyarrow"
            ) from e

        buffer = self._silhouette_buffer
        table = pa.table(
            {
                "frame": pa.array([item["frame"] for item in buffer], type=pa.int64()),
                "track_id": pa.array(
                    [item["track_id"] for item in buffer], type=pa.int64()
                ),
                "timestamp_ns": pa.array(
                    [item["timestamp_ns"] for item in buffer], type=pa.int64()
                ),
                "silhouette": pa.array(
                    [
                        np.asarray(cast(ndarray, item["silhouette"])).ravel().tolist()
                        for item in buffer
                    ],
                    type=pa.list_(pa.float64()),
                ),
            }
        )

        pq.write_table(table, self._silhouette_export_path)
        logger.info(
            "Exported %d silhouettes to parquet: %s",
            len(buffer),
            self._silhouette_export_path,
        )

    def _export_results(self) -> None:
        """Write buffered results to ``_result_export_path``.

        ``"json"`` writes one JSON object per line. ``"pickle"`` dumps the
        buffer list. ``"parquet"`` delegates to
        :meth:`_export_parquet_results`.

        Raises:
            ValueError: If the configured format is not supported.
        """
        export_path = self._result_export_path
        if export_path is None:
            return
        self._require_supported_format(
            "result",
            self._result_export_format,
            self._RESULT_EXPORT_FORMATS,
        )

        export_path.parent.mkdir(parents=True, exist_ok=True)

        if self._result_export_format == "json":
            import json

            with open(export_path, "w", encoding="utf-8") as f:
                for result in self._result_buffer:
                    f.write(json.dumps(result, ensure_ascii=False, default=str) + "\n")
            logger.info(
                "Exported %d results to JSON: %s",
                len(self._result_buffer),
                export_path,
            )
        elif self._result_export_format == "pickle":
            import pickle

            with open(export_path, "wb") as f:
                pickle.dump(self._result_buffer, f)
            logger.info(
                "Exported %d results to pickle: %s",
                len(self._result_buffer),
                export_path,
            )
        else:
            self._export_parquet_results()

    def _export_parquet_results(self) -> None:
        """Write buffered results as a parquet table.

        ``process_frame`` builds ``window`` as a ``(start, end)`` pair. When
        every window is a tuple/list it is stored as ``list<int64>``;
        otherwise it is stored as a scalar ``int64`` column.

        Raises:
            RuntimeError: If ``pyarrow`` is not installed.
        """
        import importlib

        try:
            pa = importlib.import_module("pyarrow")
            pq = importlib.import_module("pyarrow.parquet")
        except ImportError as e:
            raise RuntimeError(
                "Parquet export requires pyarrow. Install with: pip install pyarrow"
            ) from e

        buffer = self._result_buffer
        windows = [result["window"] for result in buffer]
        if windows and all(isinstance(w, (tuple, list)) for w in windows):
            window_array = pa.array(
                [list(cast(tuple[int, ...], w)) for w in windows],
                type=pa.list_(pa.int64()),
            )
        else:
            window_array = pa.array(windows, type=pa.int64())

        table = pa.table(
            {
                "frame": pa.array([r["frame"] for r in buffer], type=pa.int64()),
                "track_id": pa.array([r["track_id"] for r in buffer], type=pa.int64()),
                "label": pa.array([r["label"] for r in buffer], type=pa.string()),
                "confidence": pa.array(
                    [r["confidence"] for r in buffer], type=pa.float64()
                ),
                "window": window_array,
                "timestamp_ns": pa.array(
                    [r["timestamp_ns"] for r in buffer], type=pa.int64()
                ),
            }
        )

        pq.write_table(table, self._result_export_path)
        logger.info(
            "Exported %d results to parquet: %s",
            len(buffer),
            self._result_export_path,
        )
|
|
|
|
|
|
def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
    """Check that the files the pipeline needs exist before anything is loaded.

    A ``cvmmap://`` URL or an all-digit source (a device index) is accepted
    without a filesystem check. Any other source must be an existing file.
    The checks run in the order source, checkpoint, config.

    Raises:
        ValueError: Naming the first missing input.
    """
    is_live_source = source.startswith("cvmmap://") or source.isdigit()
    if not is_live_source and not Path(source).is_file():
        raise ValueError(f"Video source not found: {source}")

    if not Path(checkpoint).is_file():
        raise ValueError(f"Checkpoint not found: {checkpoint}")

    if not Path(config).is_file():
        raise ValueError(f"Config not found: {config}")
|
|
|
|
|
|
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option("--source", type=str, required=True)
@click.option("--checkpoint", type=str, required=True)
@click.option(
    "--config",
    type=str,
    default="configs/sconet/sconet_scoliosis1k.yaml",
    show_default=True,
)
@click.option("--device", type=str, default="cuda:0", show_default=True)
@click.option("--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", show_default=True)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--nats-url", type=str, default=None)
@click.option(
    "--nats-subject",
    type=str,
    default="scoliosis.result",
    show_default=True,
)
@click.option("--max-frames", type=click.IntRange(min=1), default=None)
@click.option(
    "--preprocess-only",
    is_flag=True,
    default=False,
    help="Only preprocess silhouettes, skip classification.",
)
@click.option(
    "--silhouette-export-path",
    type=str,
    default=None,
    help="Path to export silhouettes (required for preprocess-only mode).",
)
@click.option(
    "--silhouette-export-format",
    # "pkl" is accepted as shorthand; the pipeline normalizes it to "pickle".
    type=click.Choice(["pickle", "pkl", "parquet"]),
    default="pickle",
    show_default=True,
    help="Format for silhouette export ('pkl' is an alias for 'pickle').",
)
@click.option(
    "--result-export-path",
    type=str,
    default=None,
    help="Path to export inference results.",
)
@click.option(
    "--result-export-format",
    type=click.Choice(["json", "pickle", "pkl", "parquet"]),
    default="json",
    show_default=True,
    help="Format for result export ('pkl' is an alias for 'pickle').",
)
@click.option(
    "--silhouette-visualize-dir",
    type=str,
    default=None,
    help="Directory to save silhouette PNG visualizations.",
)
@click.option(
    "--visualize",
    is_flag=True,
    default=False,
    help="Show the live OpenCV visualization window.",
)
def main(
    source: str,
    checkpoint: str,
    config: str,
    device: str,
    yolo_model: str,
    window: int,
    stride: int,
    nats_url: str | None,
    nats_subject: str,
    max_frames: int | None,
    preprocess_only: bool,
    silhouette_export_path: str | None,
    silhouette_export_format: str,
    result_export_path: str | None,
    result_export_format: str,
    silhouette_visualize_dir: str | None,
    visualize: bool = False,
) -> None:
    """CLI entry point: validate inputs, build the pipeline, and run it.

    Exit codes: the pipeline's own return code on success (``130`` when
    interrupted), ``2`` on ``ValueError``, ``1`` on ``RuntimeError``.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    # Preprocess-only mode has no classification output, so the silhouettes
    # must be written somewhere.
    if preprocess_only and not silhouette_export_path:
        raise click.UsageError(
            "--silhouette-export-path is required when using --preprocess-only"
        )

    try:
        validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
        pipeline = ScoliosisPipeline(
            source=source,
            checkpoint=checkpoint,
            config=config,
            device=device,
            yolo_model=yolo_model,
            window=window,
            stride=stride,
            nats_url=nats_url,
            nats_subject=nats_subject,
            max_frames=max_frames,
            preprocess_only=preprocess_only,
            silhouette_export_path=silhouette_export_path,
            silhouette_export_format=silhouette_export_format,
            silhouette_visualize_dir=silhouette_visualize_dir,
            result_export_path=result_export_path,
            result_export_format=result_export_format,
            visualize=visualize,
        )
        raise SystemExit(pipeline.run())
    except ValueError as err:
        click.echo(f"Error: {err}", err=True)
        raise SystemExit(2) from err
    except RuntimeError as err:
        click.echo(f"Runtime error: {err}", err=True)
        raise SystemExit(1) from err
|