7f073179d7
Align bbox coordinate handling across primary and fallback paths, normalize Both-mode raw mask rendering, and tighten demo result typing to reduce runtime/display inconsistencies.
763 lines
26 KiB
Python
763 lines
26 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from contextlib import suppress
|
|
import logging
|
|
from pathlib import Path
|
|
import time
|
|
from typing import TYPE_CHECKING, Protocol, cast
|
|
|
|
from beartype import beartype
|
|
import click
|
|
import jaxtyping
|
|
from jaxtyping import Float, UInt8
|
|
import numpy as np
|
|
from numpy import ndarray
|
|
from numpy.typing import NDArray
|
|
from ultralytics.models.yolo.model import YOLO
|
|
|
|
from .input import FrameStream, create_source
|
|
from .output import DemoResult, ResultPublisher, create_publisher, create_result
|
|
from .preprocess import BBoxXYXY, frame_to_person_mask, mask_to_silhouette
|
|
from .sconet_demo import ScoNetDemo
|
|
from .window import SilhouetteWindow, select_person
|
|
|
|
if TYPE_CHECKING:
|
|
from .visualizer import OpenCVVisualizer
|
|
|
|
# Typing shims for jaxtyping's decorator factory.
# ``JaxtypedDecorator`` is "a callable that wraps a callable";
# ``JaxtypedFactory`` is "a callable that returns such a decorator".
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]

# ``cast`` is a runtime no-op: ``jaxtyped`` is the very same object as
# ``jaxtyping.jaxtyped``, only presented to static checkers with the
# simplified factory signature declared above.
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)

# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class _BoxesLike(Protocol):
    """Structural type for a detection result's ``boxes`` attribute.

    Declares two read-only attributes, ``xyxy`` and ``id``. Both are typed
    loosely (``| object``) so that non-NumPy containers also satisfy the
    protocol.
    """

    @property
    def xyxy(self) -> NDArray[np.float32] | object:
        """Box coordinates (presumably x1, y1, x2, y2 per box — TODO confirm)."""
        ...

    @property
    def id(self) -> NDArray[np.int64] | object | None:
        """Per-box identifiers, or ``None`` when none are available."""
        ...
|
|
|
|
|
|
class _MasksLike(Protocol):
    """Structural type for a detection result's ``masks`` attribute."""

    @property
    def data(self) -> NDArray[np.float32] | object:
        """Underlying mask storage (array layout not fixed by this protocol)."""
        ...
|
|
|
|
|
|
class _DetectionResultsLike(Protocol):
    """Structural type for one per-frame detection result.

    Exposes the two attributes this module's type annotations rely on:
    ``boxes`` and ``masks``.
    """

    @property
    def boxes(self) -> _BoxesLike:
        """Bounding-box container for the frame."""
        ...

    @property
    def masks(self) -> _MasksLike:
        """Segmentation-mask container for the frame."""
        ...
|
|
|
|
|
|
|
|
class _TrackCallable(Protocol):
    """Call signature expected from the detector's ``track`` method.

    Everything except ``source`` is keyword-only. The return value is left
    as ``object`` because callers normalise it themselves.
    """

    def __call__(
        self,
        source: object,
        *,
        persist: bool = True,
        verbose: bool = False,
        device: str | None = None,
        classes: list[int] | None = None,
    ) -> object:
        """Run tracking on ``source`` and return the raw detector output."""
        ...
|
|
|
|
|
|
class ScoliosisPipeline:
    """Frame-by-frame pipeline: detect/track, extract a silhouette, classify.

    The annotations below declare the instance attributes that ``__init__``
    assigns; they carry no default values.
    """

    # --- collaborating components -------------------------------------
    _detector: object  # YOLO model; ``track`` is looked up dynamically
    _source: FrameStream  # iterable of (frame, metadata) pairs
    _window: SilhouetteWindow  # sliding window of silhouettes
    _publisher: ResultPublisher  # sink for classification results
    _classifier: ScoNetDemo  # window classifier

    # --- runtime state and mode flags ---------------------------------
    _device: str
    _closed: bool  # set once ``close`` has completed
    _preprocess_only: bool  # skip classification when true

    # --- silhouette export / visualization ----------------------------
    _silhouette_export_path: Path | None
    _silhouette_export_format: str
    _silhouette_buffer: list[dict[str, object]]
    _silhouette_visualize_dir: Path | None

    # --- result export -------------------------------------------------
    _result_export_path: Path | None
    _result_export_format: str
    _result_buffer: list[DemoResult]

    # --- live display ---------------------------------------------------
    _visualizer: OpenCVVisualizer | None
    _last_viz_payload: dict[str, object] | None  # last non-empty payload
|
|
|
|
def __init__(
    self,
    *,
    source: str,
    checkpoint: str,
    config: str,
    device: str,
    yolo_model: str,
    window: int,
    stride: int,
    nats_url: str | None,
    nats_subject: str,
    max_frames: int | None,
    preprocess_only: bool = False,
    silhouette_export_path: str | None = None,
    silhouette_export_format: str = "pickle",
    silhouette_visualize_dir: str | None = None,
    result_export_path: str | None = None,
    result_export_format: str = "json",
    visualize: bool = False,
) -> None:
    """Construct every pipeline component from plain configuration values.

    Args:
        source: Frame source specifier, forwarded to ``create_source``.
        checkpoint: Classifier checkpoint path, forwarded to ``ScoNetDemo``.
        config: Classifier config path, forwarded to ``ScoNetDemo``.
        device: Device string used by the classifier and passed to the
            detector's ``track`` call.
        yolo_model: Model specifier handed to ``YOLO``.
        window: Window size for ``SilhouetteWindow``.
        stride: Stride for ``SilhouetteWindow``.
        nats_url: Publisher URL for ``create_publisher`` (may be ``None``).
        nats_subject: Publisher subject for ``create_publisher``.
        max_frames: Optional frame limit forwarded to ``create_source``.
        preprocess_only: When true, frames are never classified.
        silhouette_export_path: Where buffered silhouettes are written on
            ``close``; empty/``None`` disables the export.
        silhouette_export_format: ``"pickle"`` or ``"parquet"``.
        silhouette_visualize_dir: Directory for per-frame silhouette PNGs;
            empty/``None`` disables them.
        result_export_path: Where buffered results are written on
            ``close``; empty/``None`` disables the export.
        result_export_format: ``"json"``, ``"pickle"`` or ``"parquet"``.
        visualize: When true, an ``OpenCVVisualizer`` is created.
    """

    def as_optional_path(raw: str | None) -> Path | None:
        # Empty strings are treated the same as ``None``.
        return Path(raw) if raw else None

    # Components are built in a fixed order: detector, source, window,
    # publisher, classifier.
    self._detector = YOLO(yolo_model)
    self._source = create_source(source, max_frames=max_frames)
    self._window = SilhouetteWindow(window_size=window, stride=stride)
    self._publisher = create_publisher(nats_url=nats_url, subject=nats_subject)
    self._classifier = ScoNetDemo(
        cfg_path=config,
        checkpoint_path=checkpoint,
        device=device,
    )

    # Plain state.
    self._device = device
    self._closed = False
    self._preprocess_only = preprocess_only

    # Silhouette export / visualization settings.
    self._silhouette_export_path = as_optional_path(silhouette_export_path)
    self._silhouette_export_format = silhouette_export_format
    self._silhouette_buffer = []
    self._silhouette_visualize_dir = as_optional_path(silhouette_visualize_dir)

    # Result export settings.
    self._result_export_path = as_optional_path(result_export_path)
    self._result_export_format = result_export_format
    self._result_buffer = []

    # The visualizer module is imported only on demand.
    self._visualizer = None
    if visualize:
        from .visualizer import OpenCVVisualizer

        self._visualizer = OpenCVVisualizer()
    self._last_viz_payload = None
|
|
|
|
@staticmethod
|
|
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
|
|
value = meta.get(key)
|
|
if isinstance(value, int):
|
|
return value
|
|
return fallback
|
|
|
|
@staticmethod
|
|
def _extract_timestamp(meta: dict[str, object]) -> int:
|
|
value = meta.get("timestamp_ns")
|
|
if isinstance(value, int):
|
|
return value
|
|
return time.monotonic_ns()
|
|
|
|
@staticmethod
def _to_mask_u8(mask: ndarray) -> UInt8[ndarray, "h w"]:
    """Binarize ``mask`` into a ``uint8`` array holding only 0 and 255.

    Elements strictly greater than 0.5 become 255; everything else
    (including NaN, which never compares greater) becomes 0.
    """
    foreground = np.asarray(mask) > 0.5
    on_value = np.uint8(255)
    off_value = np.uint8(0)
    binary = np.where(foreground, on_value, off_value).astype(np.uint8)
    return cast(UInt8[ndarray, "h w"], binary)
|
|
|
|
def _first_result(self, detections: object) -> _DetectionResultsLike | None:
    """Normalise detector output to a single result object.

    A list or tuple yields its first element, or ``None`` when empty.
    Any other object is returned unchanged (only re-typed for the checker).
    """
    if isinstance(detections, (list, tuple)):
        if not detections:
            return None
        return cast(_DetectionResultsLike, detections[0])
    return cast(_DetectionResultsLike, detections)
|
|
|
|
def _select_silhouette(
    self,
    result: _DetectionResultsLike,
) -> (
    tuple[
        Float[ndarray, "64 44"],
        UInt8[ndarray, "h w"],
        BBoxXYXY,
        int,
    ]
    | None
):
    """Pick one person from ``result`` and build their silhouette.

    Two strategies are tried in order:

    1. ``select_person``: on success returns the silhouette, the selected
       mask, the frame-space bbox and the track id it reported.
    2. ``frame_to_person_mask``: used when step 1 yields nothing or its
       silhouette cannot be built. The mask-space bbox is rescaled to
       frame space using ``result.orig_shape`` when that is a usable
       (height, width) sequence; the track id is reported as 0.

    Returns:
        ``(silhouette, mask, bbox_frame, track_id)`` or ``None`` when
        neither strategy produces a silhouette.
    """
    tracked = select_person(result)
    if tracked is not None:
        raw_mask, mask_box, frame_box, person_id = tracked
        primary = cast(
            Float[ndarray, "64 44"] | None,
            mask_to_silhouette(self._to_mask_u8(raw_mask), mask_box),
        )
        if primary is not None:
            # NOTE(review): the un-normalised ``raw_mask`` is returned here,
            # while the return annotation says UInt8 and the fallback path
            # returns a uint8 mask — confirm the consumer handles both.
            return primary, raw_mask, frame_box, int(person_id)

    detected = cast(
        tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
        frame_to_person_mask(result),
    )
    if detected is None:
        return None

    person_mask, mask_box = detected
    secondary = cast(
        Float[ndarray, "64 44"] | None,
        mask_to_silhouette(person_mask, mask_box),
    )
    if secondary is None:
        return None

    # Default to the mask-space bbox; it is replaced only when valid frame
    # and mask dimensions allow a rescale into frame space.
    frame_box = mask_box
    orig_shape = getattr(result, "orig_shape", None)
    if isinstance(orig_shape, (tuple, list)) and len(orig_shape) >= 2:
        frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
        mask_h, mask_w = person_mask.shape[0], person_mask.shape[1]
        if min(mask_w, mask_h, frame_w, frame_h) > 0:
            scale_x = frame_w / mask_w
            scale_y = frame_h / mask_h
            frame_box = (
                int(mask_box[0] * scale_x),
                int(mask_box[1] * scale_y),
                int(mask_box[2] * scale_x),
                int(mask_box[3] * scale_y),
            )

    # On this path the mask handed back is the uint8 mask itself.
    return secondary, person_mask, frame_box, 0
|
|
|
|
|
|
@jaxtyped(typechecker=beartype)
def process_frame(
    self,
    frame: UInt8[ndarray, "h w c"],
    metadata: dict[str, object],
) -> dict[str, object] | None:
    """Process one frame: track, extract a silhouette, classify when due.

    Args:
        frame: Image array passed unchanged to the detector's ``track``.
        metadata: Per-frame metadata; ``frame_count`` and ``timestamp_ns``
            are read from it.

    Returns:
        ``None`` when the detector yields no result or no silhouette can
        be selected. Otherwise a payload dict with keys ``mask_raw``,
        ``bbox``, ``silhouette``, ``track_id``, ``label`` and
        ``confidence``. ``label``/``confidence`` are ``None`` unless a
        classification ran on this frame, in which case the dict also
        carries ``result``.

    Raises:
        RuntimeError: If the detector has no callable ``track`` attribute.
    """
    frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
    timestamp_ns = self._extract_timestamp(metadata)

    tracker = getattr(self._detector, "track", None)
    if not callable(tracker):
        raise RuntimeError("YOLO detector does not expose a callable track()")

    # classes=[0] restricts tracking to class index 0 (presumably "person"
    # in the loaded model — TODO confirm).
    detections = cast(_TrackCallable, tracker)(
        frame,
        persist=True,
        verbose=False,
        device=self._device,
        classes=[0],
    )

    first = self._first_result(detections)
    if first is None:
        return None
    picked = self._select_silhouette(first)
    if picked is None:
        return None
    silhouette, mask_raw, bbox, track_id = picked

    # Buffer the silhouette when an export path is set or when running in
    # preprocess-only mode.
    wants_buffering = self._preprocess_only or self._silhouette_export_path is not None
    if wants_buffering:
        record: dict[str, object] = {
            "frame": frame_idx,
            "track_id": track_id,
            "timestamp_ns": timestamp_ns,
            "silhouette": silhouette.copy(),
        }
        self._silhouette_buffer.append(record)

    if self._silhouette_visualize_dir is not None:
        self._visualize_silhouette(silhouette, frame_idx, track_id)

    # Fields shared by every returned payload.
    overlay: dict[str, object] = {
        "mask_raw": mask_raw,
        "bbox": bbox,
        "silhouette": silhouette,
        "track_id": track_id,
    }

    if self._preprocess_only:
        # Display payload only; the window and classifier are not touched.
        return {**overlay, "label": None, "confidence": None}

    self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
    if not self._window.should_classify():
        # Window not ready yet: still hand back something to display.
        return {**overlay, "label": None, "confidence": None}

    label, confidence = cast(
        tuple[str, float],
        self._classifier.predict(self._window.get_tensor(device=self._device)),
    )
    self._window.mark_classified()

    # The reported window ends at the current frame; its start is clamped
    # at 0.
    first_frame = max(0, frame_idx - self._window.window_size + 1)
    result = create_result(
        frame=frame_idx,
        track_id=track_id,
        label=label,
        confidence=float(confidence),
        window=(first_frame, frame_idx),
        timestamp_ns=timestamp_ns,
    )

    if self._result_export_path is not None:
        self._result_buffer.append(result)
    self._publisher.publish(result)

    return {
        "result": result,
        **overlay,
        "label": label,
        "confidence": confidence,
    }
|
|
|
|
def run(self) -> int:
    """Consume the frame source until it ends, then shut down.

    Each frame is passed to ``process_frame``; a failure there is logged
    and the frame skipped. When a visualizer is configured it is updated
    on every frame, reusing the last non-empty payload for frames that
    produced none, and the loop stops if the visualizer asks to.

    Returns:
        ``0`` on normal completion, ``130`` after ``KeyboardInterrupt``.
        ``close`` is always called before returning.
    """
    processed = 0
    started_at = time.perf_counter()
    # Exponential moving average of the per-frame FPS; 0.0 means "no
    # sample yet".
    smoothing = 0.1
    smoothed_fps = 0.0
    last_tick = started_at

    def snapshot(value: object) -> object:
        # Copy anything exposing ``copy`` so the cached payload cannot be
        # mutated through the original object; keep other values as-is.
        copier = cast(Callable[[], object] | None, getattr(value, "copy", None))
        return value if copier is None else copier()

    try:
        for frame, metadata in self._source:
            frame_u8 = np.asarray(frame, dtype=np.uint8)
            frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
            processed += 1

            now = time.perf_counter()
            frame_seconds = now - last_tick
            last_tick = now
            if frame_seconds > 0:
                instant_fps = 1.0 / frame_seconds
                if smoothed_fps == 0.0:
                    smoothed_fps = instant_fps
                else:
                    smoothed_fps = (
                        smoothing * instant_fps + (1 - smoothing) * smoothed_fps
                    )

            try:
                viz_payload = self.process_frame(frame_u8, metadata)
            except Exception as frame_error:
                viz_payload = None
                logger.warning(
                    "Skipping frame %d due to processing error: %s",
                    frame_idx,
                    frame_error,
                )

            if self._visualizer is not None:
                if viz_payload is not None:
                    self._last_viz_payload = {
                        key: snapshot(value) for key, value in viz_payload.items()
                    }

                # Fall back to the cached payload when this frame has none.
                viz_data = (
                    viz_payload if viz_payload is not None else self._last_viz_payload
                )

                # Defaults cover "no detection and nothing cached".
                mask_raw = None
                bbox = None
                silhouette = None
                label = None
                confidence = None
                track_id = 0
                if viz_data is not None:
                    mask_raw = cast(NDArray[np.uint8] | None, viz_data.get("mask_raw"))
                    bbox = cast(BBoxXYXY | None, viz_data.get("bbox"))
                    silhouette = cast(
                        NDArray[np.float32] | None, viz_data.get("silhouette")
                    )
                    label = cast(str | None, viz_data.get("label"))
                    confidence = cast(float | None, viz_data.get("confidence"))
                    raw_track = viz_data.get("track_id", 0)
                    track_id = raw_track if isinstance(raw_track, int) else 0

                keep_running = self._visualizer.update(
                    frame_u8,
                    bbox,
                    track_id,
                    mask_raw,
                    silhouette,
                    label,
                    confidence,
                    smoothed_fps,
                )
                if not keep_running:
                    logger.info("Visualization closed by user.")
                    break

            # Progress line every 100 frames, using the overall average FPS.
            if processed % 100 == 0:
                elapsed = time.perf_counter() - started_at
                average_fps = processed / elapsed if elapsed > 0 else 0.0
                logger.info("Processed %d frames (%.2f FPS)", processed, average_fps)
        return 0
    except KeyboardInterrupt:
        logger.info("Interrupted by user, shutting down cleanly.")
        return 130
    finally:
        self.close()
|
|
|
|
def close(self) -> None:
    """Flush pending exports and release resources; repeat calls are no-ops.

    Steps, in order:

    1. Close the visualizer, if any (errors suppressed).
    2. Export buffered silhouettes when an export path is set and the
       buffer is non-empty.
    3. Export buffered results under the same conditions.
    4. Call the publisher's ``close`` if it has one (errors suppressed).

    Raises:
        ValueError, RuntimeError (or any other export error): propagated
        from the export helpers. Even then the result export is still
        attempted after a failed silhouette export, the publisher is
        still closed, and the pipeline is marked closed, so one failing
        export neither loses the other one nor leaks the publisher.
    """
    if self._closed:
        return

    try:
        if self._visualizer is not None:
            with suppress(Exception):
                self._visualizer.close()

        try:
            if self._silhouette_export_path is not None and self._silhouette_buffer:
                self._export_silhouettes()
        finally:
            # Runs even if the silhouette export raised, so buffered
            # results are not dropped because of an unrelated failure.
            if self._result_export_path is not None and self._result_buffer:
                self._export_results()
    finally:
        # Publisher shutdown is best-effort and must happen regardless of
        # export outcome.
        close_fn = getattr(self._publisher, "close", None)
        if callable(close_fn):
            with suppress(Exception):
                _ = close_fn()
        self._closed = True
|
|
|
|
def _export_silhouettes(self) -> None:
    """Write the buffered silhouettes to ``_silhouette_export_path``.

    Does nothing when no export path is configured. The parent directory
    is created first. ``"pickle"`` dumps the whole buffer in one call;
    ``"parquet"`` delegates to ``_export_parquet_silhouettes``.

    Raises:
        ValueError: If the configured format is neither of the above.
    """
    target = self._silhouette_export_path
    if target is None:
        return

    target.parent.mkdir(parents=True, exist_ok=True)

    export_format = self._silhouette_export_format
    if export_format == "parquet":
        self._export_parquet_silhouettes()
        return
    if export_format != "pickle":
        raise ValueError(
            f"Unsupported silhouette export format: {self._silhouette_export_format}"
        )

    import pickle

    with target.open("wb") as f:
        pickle.dump(self._silhouette_buffer, f)
    logger.info(
        "Exported %d silhouettes to %s",
        len(self._silhouette_buffer),
        self._silhouette_export_path,
    )
|
|
|
|
def _visualize_silhouette(
    self,
    silhouette: Float[ndarray, "64 44"],
    frame_idx: int,
    track_id: int,
) -> None:
    """Save one silhouette as a PNG inside ``_silhouette_visualize_dir``.

    Does nothing when no directory is configured. The file name is derived
    only from ``frame_idx`` and ``track_id``, so the same pair always maps
    to the same file.
    """
    out_dir = self._silhouette_visualize_dir
    if out_dir is None:
        return

    out_dir.mkdir(parents=True, exist_ok=True)

    # Scale by 255 and convert to uint8 (assumes values lie in [0, 1];
    # nothing is clipped here — TODO confirm upstream range).
    pixels = (silhouette * 255).astype(np.uint8)

    destination = out_dir / f"silhouette_frame{frame_idx:06d}_track{track_id:04d}.png"

    # Pillow is imported lazily, only when an image is actually written.
    from PIL import Image

    Image.fromarray(pixels).save(destination)
|
|
|
|
def _export_parquet_silhouettes(self) -> None:
    """Write the buffered silhouettes as a parquet table.

    Columns: ``frame``, ``track_id``, ``timestamp_ns`` (all int64) and
    ``silhouette`` (each array flattened to a list of float64).

    Raises:
        RuntimeError: If ``pyarrow`` cannot be imported.
    """
    import importlib

    try:
        pa = importlib.import_module("pyarrow")
        pq = importlib.import_module("pyarrow.parquet")
    except ImportError as e:
        raise RuntimeError(
            "Parquet export requires pyarrow. Install with: pip install pyarrow"
        ) from e

    # One tuple per buffered entry, then split into columns.
    rows = [
        (
            item["frame"],
            item["track_id"],
            item["timestamp_ns"],
            cast(ndarray, item["silhouette"]).flatten().tolist(),
        )
        for item in self._silhouette_buffer
    ]
    frames = [row[0] for row in rows]
    track_ids = [row[1] for row in rows]
    timestamps = [row[2] for row in rows]
    silhouettes = [row[3] for row in rows]

    table = pa.table(
        {
            "frame": pa.array(frames, type=pa.int64()),
            "track_id": pa.array(track_ids, type=pa.int64()),
            "timestamp_ns": pa.array(timestamps, type=pa.int64()),
            "silhouette": pa.array(silhouettes, type=pa.list_(pa.float64())),
        }
    )

    pq.write_table(table, self._silhouette_export_path)
    logger.info(
        "Exported %d silhouettes to parquet: %s",
        len(self._silhouette_buffer),
        self._silhouette_export_path,
    )
|
|
|
|
def _export_results(self) -> None:
    """Write the buffered results to ``_result_export_path``.

    Does nothing when no export path is configured. The parent directory
    is created first. Formats:

    * ``"json"``: one JSON document per line (UTF-8); values that are not
      JSON-serialisable are stringified via ``default=str``.
    * ``"pickle"``: the whole buffer in a single dump.
    * ``"parquet"``: delegated to ``_export_parquet_results``.

    Raises:
        ValueError: For any other format string.
    """
    target = self._result_export_path
    if target is None:
        return

    target.parent.mkdir(parents=True, exist_ok=True)

    export_format = self._result_export_format
    if export_format == "json":
        import json

        with target.open("w", encoding="utf-8") as f:
            f.writelines(
                json.dumps(entry, ensure_ascii=False, default=str) + "\n"
                for entry in self._result_buffer
            )
        logger.info(
            "Exported %d results to JSON: %s",
            len(self._result_buffer),
            self._result_export_path,
        )
        return

    if export_format == "pickle":
        import pickle

        with target.open("wb") as f:
            pickle.dump(self._result_buffer, f)
        logger.info(
            "Exported %d results to pickle: %s",
            len(self._result_buffer),
            self._result_export_path,
        )
        return

    if export_format == "parquet":
        self._export_parquet_results()
        return

    raise ValueError(
        f"Unsupported result export format: {self._result_export_format}"
    )
|
|
|
|
def _export_parquet_results(self) -> None:
    """Write the buffered classification results as a parquet table.

    Columns: ``frame``, ``track_id``, ``window``, ``timestamp_ns`` (int64),
    ``label`` (string) and ``confidence`` (float64).

    Raises:
        RuntimeError: If ``pyarrow`` cannot be imported.
    """
    import importlib

    try:
        pa = importlib.import_module("pyarrow")
        pq = importlib.import_module("pyarrow.parquet")
    except ImportError as e:
        raise RuntimeError(
            "Parquet export requires pyarrow. Install with: pip install pyarrow"
        ) from e

    # One tuple per result, read in a fixed field order, then split into
    # columns.
    rows = [
        (
            entry["frame"],
            entry["track_id"],
            entry["label"],
            entry["confidence"],
            entry["window"],
            entry["timestamp_ns"],
        )
        for entry in self._result_buffer
    ]
    frames = [row[0] for row in rows]
    track_ids = [row[1] for row in rows]
    labels = [row[2] for row in rows]
    confidences = [row[3] for row in rows]
    windows = [row[4] for row in rows]
    timestamps = [row[5] for row in rows]

    # NOTE(review): ``process_frame`` passes ``window`` to ``create_result``
    # as a (start, end) tuple; if the stored value is still a tuple, an
    # int64 column cannot hold it — confirm what ``create_result`` stores.
    table = pa.table(
        {
            "frame": pa.array(frames, type=pa.int64()),
            "track_id": pa.array(track_ids, type=pa.int64()),
            "label": pa.array(labels, type=pa.string()),
            "confidence": pa.array(confidences, type=pa.float64()),
            "window": pa.array(windows, type=pa.int64()),
            "timestamp_ns": pa.array(timestamps, type=pa.int64()),
        }
    )

    pq.write_table(table, self._result_export_path)
    logger.info(
        "Exported %d results to parquet: %s",
        len(self._result_buffer),
        self._result_export_path,
    )
|
|
|
|
|
|
def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
    """Check that the file-based inputs exist before building the pipeline.

    ``source`` is exempt from the file check when it starts with
    ``cvmmap://`` or consists only of digits; otherwise it must be an
    existing file. ``checkpoint`` and ``config`` must always be existing
    files. Checks run in that order and stop at the first failure.

    Raises:
        ValueError: Naming the first input that is not an existing file.
    """
    skips_file_check = source.startswith("cvmmap://") or source.isdigit()
    if not skips_file_check and not Path(source).is_file():
        raise ValueError(f"Video source not found: {source}")

    if not Path(checkpoint).is_file():
        raise ValueError(f"Checkpoint not found: {checkpoint}")

    if not Path(config).is_file():
        raise ValueError(f"Config not found: {config}")
|
|
|
|
|
|
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option("--source", type=str, required=True)
@click.option("--checkpoint", type=str, required=True)
@click.option(
    "--config",
    type=str,
    default="configs/sconet/sconet_scoliosis1k.yaml",
    show_default=True,
)
@click.option("--device", type=str, default="cuda:0", show_default=True)
@click.option("--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", show_default=True)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--nats-url", type=str, default=None)
@click.option(
    "--nats-subject",
    type=str,
    default="scoliosis.result",
    show_default=True,
)
@click.option("--max-frames", type=click.IntRange(min=1), default=None)
@click.option(
    "--preprocess-only",
    is_flag=True,
    default=False,
    help="Only preprocess silhouettes, skip classification.",
)
@click.option(
    "--silhouette-export-path",
    type=str,
    default=None,
    help="Path to export silhouettes (required for preprocess-only mode).",
)
@click.option(
    "--silhouette-export-format",
    type=click.Choice(["pickle", "parquet"]),
    default="pickle",
    show_default=True,
    help="Format for silhouette export.",
)
@click.option(
    "--result-export-path",
    type=str,
    default=None,
    help="Path to export inference results.",
)
@click.option(
    "--result-export-format",
    type=click.Choice(["json", "pickle", "parquet"]),
    default="json",
    show_default=True,
    help="Format for result export.",
)
@click.option(
    "--silhouette-visualize-dir",
    type=str,
    default=None,
    help="Directory to save silhouette PNG visualizations.",
)
def main(
    source: str,
    checkpoint: str,
    config: str,
    device: str,
    yolo_model: str,
    window: int,
    stride: int,
    nats_url: str | None,
    nats_subject: str,
    max_frames: int | None,
    preprocess_only: bool,
    silhouette_export_path: str | None,
    silhouette_export_format: str,
    result_export_path: str | None,
    result_export_format: str,
    silhouette_visualize_dir: str | None,
) -> None:
    """Command-line entry point: validate inputs, build and run the pipeline.

    Always terminates by raising ``SystemExit``:

    * the pipeline's own return code when it runs to completion,
    * ``2`` when a ``ValueError`` is raised (e.g. a missing input file),
    * ``1`` when a ``RuntimeError`` is raised.

    Raises:
        click.UsageError: If ``--preprocess-only`` is given without
            ``--silhouette-export-path``.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    # Preprocess-only mode has nothing to show for itself without an
    # export target.
    if preprocess_only and not silhouette_export_path:
        raise click.UsageError(
            "--silhouette-export-path is required when using --preprocess-only"
        )

    try:
        validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
        pipeline = ScoliosisPipeline(
            source=source,
            checkpoint=checkpoint,
            config=config,
            device=device,
            yolo_model=yolo_model,
            window=window,
            stride=stride,
            nats_url=nats_url,
            nats_subject=nats_subject,
            max_frames=max_frames,
            preprocess_only=preprocess_only,
            silhouette_export_path=silhouette_export_path,
            silhouette_export_format=silhouette_export_format,
            silhouette_visualize_dir=silhouette_visualize_dir,
            result_export_path=result_export_path,
            result_export_format=result_export_format,
        )
        exit_code = pipeline.run()
    except ValueError as err:
        click.echo(f"Error: {err}", err=True)
        raise SystemExit(2) from err
    except RuntimeError as err:
        click.echo(f"Runtime error: {err}", err=True)
        raise SystemExit(1) from err

    # Neither handler above catches SystemExit, so raising it after the
    # try block is equivalent to raising it inside.
    raise SystemExit(exit_code)
|