967a10c10e
Stabilize studio publish/visualization flow and tighten export behavior while aligning project dependencies with the monorepo runtime expectations.
978 lines
34 KiB
Python
978 lines
34 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
import copy
|
|
from contextlib import suppress
|
|
import inspect
|
|
import logging
|
|
from pathlib import Path
|
|
import time
|
|
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, cast
|
|
|
|
from beartype import beartype
|
|
import click
|
|
import jaxtyping
|
|
from jaxtyping import Float, UInt8
|
|
import numpy as np
|
|
from numpy import ndarray
|
|
from numpy.typing import NDArray
|
|
from ultralytics.models.yolo.model import YOLO
|
|
|
|
from .input import FrameStream, create_source
|
|
from .output import DemoResult, ResultPublisher, create_publisher, create_result
|
|
from .preprocess import BBoxXYXY, frame_to_person_mask, mask_to_silhouette
|
|
from .sconet_demo import ScoNetDemo
|
|
from .window import SilhouetteWindow, select_person
|
|
|
|
if TYPE_CHECKING:
    # Needed only for the ``_visualizer`` annotation; at runtime the visualizer module
    # is imported lazily inside ``ScoliosisPipeline.__init__`` when ``visualize`` is set.
    from .visualizer import OpenCVVisualizer


# Module-level logger shared by the pipeline and the CLI entry point.
logger = logging.getLogger(__name__)


# ``jaxtyping.jaxtyped`` is re-bound through ``cast`` with an explicit callable type,
# presumably so static type checkers accept ``@jaxtyped(typechecker=beartype)`` on
# methods; at runtime ``jaxtyped`` is the very same object as ``jaxtyping.jaxtyped``.
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)


# Window scheduling modes understood by ``resolve_stride`` (and the ``--window-mode``
# CLI option).
WindowMode: TypeAlias = Literal["manual", "sliding", "chunked"]
|
|
|
|
|
|
def resolve_stride(window: int, stride: int, window_mode: WindowMode) -> int:
|
|
if window_mode == "manual":
|
|
return stride
|
|
if window_mode == "sliding":
|
|
return 1
|
|
return window
|
|
|
|
|
|
class _BoxesLike(Protocol):
|
|
@property
|
|
def xyxy(self) -> NDArray[np.float32] | object: ...
|
|
|
|
@property
|
|
def id(self) -> NDArray[np.int64] | object | None: ...
|
|
|
|
|
|
class _MasksLike(Protocol):
|
|
@property
|
|
def data(self) -> NDArray[np.float32] | object: ...
|
|
|
|
|
|
class _DetectionResultsLike(Protocol):
|
|
@property
|
|
def boxes(self) -> _BoxesLike: ...
|
|
|
|
@property
|
|
def masks(self) -> _MasksLike: ...
|
|
|
|
|
|
class _TrackCallable(Protocol):
|
|
def __call__(
|
|
self,
|
|
source: object,
|
|
*,
|
|
persist: bool = True,
|
|
verbose: bool = False,
|
|
device: str | None = None,
|
|
classes: list[int] | None = None,
|
|
) -> object: ...
|
|
|
|
|
|
class _SelectedSilhouette(TypedDict):
|
|
"""Selected silhouette payload produced from detector outputs.
|
|
|
|
Fields:
|
|
silhouette: Normalized silhouette tensor fed into ScoNet `(64, 44)`.
|
|
mask_raw: Full-resolution binary person mask in mask/image space.
|
|
bbox_frame: Person bbox in frame coordinates `(x1, y1, x2, y2)` for visualization.
|
|
bbox_mask: Person bbox in mask coordinates `(x1, y1, x2, y2)` for cropping.
|
|
track_id: Tracking ID from detector, or `0` for fallback path.
|
|
"""
|
|
|
|
silhouette: Float[ndarray, "64 44"]
|
|
mask_raw: UInt8[ndarray, "h w"]
|
|
bbox_frame: BBoxXYXY
|
|
bbox_mask: BBoxXYXY
|
|
track_id: int
|
|
|
|
|
|
class _VizPayload(TypedDict, total=False):
|
|
result: DemoResult
|
|
mask_raw: UInt8[ndarray, "h w"] | None
|
|
bbox: BBoxXYXY | None
|
|
bbox_mask: BBoxXYXY | None
|
|
silhouette: Float[ndarray, "64 44"] | None
|
|
segmentation_input: NDArray[np.float32] | None
|
|
track_id: int
|
|
label: str | None
|
|
confidence: float | None
|
|
pose: dict[str, object] | None
|
|
|
|
|
|
class _FramePacer:
|
|
_interval_ns: int
|
|
_next_emit_ns: int | None
|
|
|
|
def __init__(self, target_fps: float) -> None:
|
|
if target_fps <= 0:
|
|
raise ValueError(f"target_fps must be positive, got {target_fps}")
|
|
self._interval_ns = int(1_000_000_000 / target_fps)
|
|
self._next_emit_ns = None
|
|
|
|
def should_emit(self, timestamp_ns: int) -> bool:
|
|
if self._next_emit_ns is None:
|
|
self._next_emit_ns = timestamp_ns + self._interval_ns
|
|
return True
|
|
if timestamp_ns >= self._next_emit_ns:
|
|
while self._next_emit_ns <= timestamp_ns:
|
|
self._next_emit_ns += self._interval_ns
|
|
return True
|
|
return False
|
|
|
|
|
|
class ScoliosisPipeline:
    """Frame-by-frame scoliosis screening over a video/camera source.

    For each frame, the pipeline:

    1. Runs YOLO segmentation + tracking (``track`` with ``classes=[0]``).
    2. Selects one person's silhouette.
    3. Pushes the silhouette into a :class:`SilhouetteWindow`.
    4. When the frame pacer allows and the window is ready, classifies the window with
       :class:`ScoNetDemo` and publishes the result.

    Optional side outputs: silhouette export (pickle/parquet), per-silhouette PNGs,
    result export (json/pickle/parquet), and a live OpenCV visualizer.
    """

    # Formats each exporter understands. They are checked in ``__init__`` so a typo
    # fails immediately instead of after the whole stream has been processed.
    _SILHOUETTE_EXPORT_FORMATS = frozenset({"pickle", "parquet"})
    _RESULT_EXPORT_FORMATS = frozenset({"json", "pickle", "parquet"})

    _detector: object  # Ultralytics YOLO model (only its ``track`` method is used)
    _source: FrameStream
    _window: SilhouetteWindow
    _publisher: ResultPublisher
    _classifier: ScoNetDemo
    _device: str
    _closed: bool  # set once ``close()`` has run
    _preprocess_only: bool  # skip windowing/classification entirely
    _silhouette_export_path: Path | None
    _silhouette_export_format: str
    _silhouette_buffer: list[dict[str, object]]
    _silhouette_visualize_dir: Path | None
    _result_export_path: Path | None
    _result_export_format: str
    _result_buffer: list[DemoResult]
    _visualizer: OpenCVVisualizer | None
    _last_viz_payload: _VizPayload | None  # copy reused on frames without detections
    _frame_pacer: _FramePacer | None
    _visualizer_accepts_pose_data: bool | None  # cached signature probe result
    _visualizer_signature_owner: object | None  # visualizer the cache belongs to

    def __init__(
        self,
        *,
        source: str,
        checkpoint: str,
        config: str,
        device: str,
        yolo_model: str,
        window: int,
        stride: int,
        nats_url: str | None,
        nats_subject: str,
        max_frames: int | None,
        preprocess_only: bool = False,
        silhouette_export_path: str | None = None,
        silhouette_export_format: str = "pickle",
        silhouette_visualize_dir: str | None = None,
        result_export_path: str | None = None,
        result_export_format: str = "json",
        visualize: bool = False,
        target_fps: float | None = 15.0,
    ) -> None:
        """Create the detector, frame source, window, publisher and classifier.

        Args:
            source: Frame source spec passed to ``create_source``.
            checkpoint: ScoNet checkpoint path.
            config: ScoNet config path.
            device: Torch-style device string used for detection and classification.
            yolo_model: YOLO weights path/name.
            window: Classification window length in frames.
            stride: Frames between consecutive classifications.
            nats_url: NATS server URL, or ``None`` (passed to ``create_publisher``).
            nats_subject: Subject results are published on.
            max_frames: Optional cap on frames read from the source.
            preprocess_only: Only extract silhouettes; never classify.
            silhouette_export_path: Where buffered silhouettes are written on close.
            silhouette_export_format: ``"pickle"`` (alias ``"pkl"``) or ``"parquet"``.
            silhouette_visualize_dir: Directory for per-silhouette PNGs.
            result_export_path: Where buffered results are written on close.
            result_export_format: ``"json"``, ``"pickle"`` (alias ``"pkl"``) or
                ``"parquet"``.
            visualize: Open an OpenCV visualizer window.
            target_fps: Cap on classification/publish rate; ``None`` disables pacing.

        Raises:
            ValueError: if an export path is given with an unsupported format, or if
                ``target_fps`` is not a positive number.
        """
        # Validate cheap configuration before allocating any heavyweight resource.
        silhouette_format = self._normalize_export_format(silhouette_export_format)
        result_format = self._normalize_export_format(result_export_format)
        if (
            silhouette_export_path
            and silhouette_format not in self._SILHOUETTE_EXPORT_FORMATS
        ):
            raise ValueError(
                f"Unsupported silhouette export format: {silhouette_format}"
            )
        if result_export_path and result_format not in self._RESULT_EXPORT_FORMATS:
            raise ValueError(f"Unsupported result export format: {result_format}")
        frame_pacer = _FramePacer(target_fps) if target_fps is not None else None

        self._detector = YOLO(yolo_model)
        self._source = create_source(source, max_frames=max_frames)
        self._window = SilhouetteWindow(window_size=window, stride=stride)
        self._publisher = create_publisher(nats_url=nats_url, subject=nats_subject)
        self._classifier = ScoNetDemo(
            cfg_path=config,
            checkpoint_path=checkpoint,
            device=device,
        )
        self._device = device
        self._closed = False
        self._preprocess_only = preprocess_only

        self._silhouette_export_path = (
            Path(silhouette_export_path) if silhouette_export_path else None
        )
        self._silhouette_export_format = silhouette_format
        self._silhouette_buffer = []
        self._silhouette_visualize_dir = (
            Path(silhouette_visualize_dir) if silhouette_visualize_dir else None
        )
        self._result_export_path = (
            Path(result_export_path) if result_export_path else None
        )
        self._result_export_format = result_format
        self._result_buffer = []

        if visualize:
            # Imported lazily so OpenCV GUI dependencies are only needed when requested.
            from .visualizer import OpenCVVisualizer

            self._visualizer = OpenCVVisualizer()
        else:
            self._visualizer = None
        self._last_viz_payload = None
        self._frame_pacer = frame_pacer
        self._visualizer_accepts_pose_data = None
        self._visualizer_signature_owner = None

    @staticmethod
    def _normalize_export_format(fmt: str) -> str:
        """Map the ``"pkl"`` alias to ``"pickle"``; return other formats unchanged."""
        return "pickle" if fmt == "pkl" else fmt

    def _detect_visualizer_pose_kwarg(self) -> bool:
        """Return whether the visualizer's ``update`` accepts a ``pose_data`` keyword.

        The answer is cached and tied to the current visualizer object by identity, so
        the signature is inspected at most once per visualizer instance.
        """
        visualizer = self._visualizer
        if visualizer is None:
            return False
        if (
            self._visualizer_signature_owner is visualizer
            and self._visualizer_accepts_pose_data is not None
        ):
            return self._visualizer_accepts_pose_data

        update_fn = getattr(visualizer, "update", None)
        if not callable(update_fn):
            accepts_pose_data = False
        else:
            try:
                signature = inspect.signature(update_fn)
                accepts_pose_data = "pose_data" in signature.parameters
            except (ValueError, TypeError):
                # Some callables expose no introspectable signature.
                accepts_pose_data = False

        self._visualizer_signature_owner = visualizer
        self._visualizer_accepts_pose_data = accepts_pose_data
        return accepts_pose_data

    @staticmethod
    def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
        """Return ``meta[key]`` if it is an ``int``, else ``fallback``."""
        value = meta.get(key)
        if isinstance(value, int):
            return value
        return fallback

    @staticmethod
    def _extract_timestamp(meta: dict[str, object]) -> int:
        """Return ``meta["timestamp_ns"]`` if it is an ``int``, else ``time.monotonic_ns()``."""
        value = meta.get("timestamp_ns")
        if isinstance(value, int):
            return value
        return time.monotonic_ns()

    @staticmethod
    def _to_mask_u8(mask: NDArray[np.generic]) -> UInt8[ndarray, "h w"]:
        """Binarize ``mask`` to ``uint8``: values > 0.5 become 255, everything else 0."""
        mask_arr: NDArray[np.floating] = np.asarray(mask, dtype=np.float32)  # type: ignore[reportAssignmentType]
        binary = np.where(mask_arr > 0.5, np.uint8(255), np.uint8(0)).astype(np.uint8)
        return cast(UInt8[ndarray, "h w"], binary)

    def _first_result(self, detections: object) -> _DetectionResultsLike | None:
        """Return the first detector result.

        A list/tuple yields its first element (``None`` if empty); any other object is
        returned as-is.
        """
        if isinstance(detections, (list, tuple)):
            return cast(_DetectionResultsLike, detections[0]) if detections else None
        return cast(_DetectionResultsLike, detections)

    @staticmethod
    def _mask_bbox_to_frame(
        result: object,
        mask_shape: tuple[int, ...],
        bbox_mask: BBoxXYXY,
    ) -> BBoxXYXY:
        """Rescale a mask-space bbox to frame space using ``result.orig_shape``.

        ``orig_shape`` is read as ``(height, width, ...)``. The mask-space bbox is
        returned unchanged when ``orig_shape`` is missing or malformed, or when any
        dimension is non-positive.
        """
        orig_shape = getattr(result, "orig_shape", None)
        if not isinstance(orig_shape, (tuple, list)) or len(orig_shape) < 2:
            return bbox_mask
        frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
        mask_h, mask_w = mask_shape[0], mask_shape[1]
        if mask_w <= 0 or mask_h <= 0 or frame_w <= 0 or frame_h <= 0:
            return bbox_mask
        scale_x = frame_w / mask_w
        scale_y = frame_h / mask_h
        return (
            int(bbox_mask[0] * scale_x),
            int(bbox_mask[1] * scale_y),
            int(bbox_mask[2] * scale_x),
            int(bbox_mask[3] * scale_y),
        )

    def _select_silhouette(
        self,
        result: _DetectionResultsLike,
    ) -> _SelectedSilhouette | None:
        """Pick one person's silhouette from a detector result.

        :func:`select_person` is tried first. If it yields nothing, or its mask cannot
        be turned into a silhouette, :func:`frame_to_person_mask` is used as a fallback.
        The fallback's bbox is in mask space and is rescaled to frame space when
        possible, and its ``track_id`` is ``0``. Returns ``None`` if neither path
        produces a silhouette.
        """
        selected = select_person(result)
        if selected is not None:
            mask_raw, bbox_mask, bbox_frame, track_id = selected
            silhouette = cast(
                Float[ndarray, "64 44"] | None,
                mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask),
            )
            if silhouette is not None:
                return {
                    "silhouette": silhouette,
                    "mask_raw": mask_raw,
                    "bbox_frame": bbox_frame,
                    "bbox_mask": bbox_mask,
                    "track_id": int(track_id),
                }

        fallback = cast(
            tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
            frame_to_person_mask(result),
        )
        if fallback is None:
            return None

        mask_u8, bbox_mask = fallback
        silhouette = cast(
            Float[ndarray, "64 44"] | None,
            mask_to_silhouette(mask_u8, bbox_mask),
        )
        if silhouette is None:
            return None

        # In the fallback path the binary mask doubles as the raw mask.
        return {
            "silhouette": silhouette,
            "mask_raw": mask_u8,
            "bbox_frame": self._mask_bbox_to_frame(result, mask_u8.shape, bbox_mask),
            "bbox_mask": bbox_mask,
            "track_id": 0,
        }

    @staticmethod
    def _make_viz_payload(
        *,
        mask_raw: UInt8[ndarray, "h w"],
        bbox: BBoxXYXY,
        bbox_mask: BBoxXYXY,
        silhouette: Float[ndarray, "64 44"],
        segmentation_input: NDArray[np.float32] | None,
        track_id: int,
        label: str | None = None,
        confidence: float | None = None,
        pose: dict[str, object] | None = None,
    ) -> _VizPayload:
        """Assemble the visualization payload (without ``result``) for one frame."""
        return {
            "mask_raw": mask_raw,
            "bbox": bbox,
            "bbox_mask": bbox_mask,
            "silhouette": silhouette,
            "segmentation_input": segmentation_input,
            "track_id": track_id,
            "label": label,
            "confidence": confidence,
            "pose": pose,
        }

    @jaxtyped(typechecker=beartype)
    def process_frame(
        self,
        frame: UInt8[ndarray, "h w c"],
        metadata: dict[str, object],
    ) -> _VizPayload | None:
        """Detect, extract a silhouette, and (when due) classify one frame.

        Args:
            frame: ``uint8`` image of shape ``(h, w, c)``.
            metadata: Source metadata. ``frame_count`` and ``timestamp_ns`` are used
                when they are ints; the fallbacks are ``0`` and
                ``time.monotonic_ns()``.

        Returns:
            ``None`` when no silhouette could be extracted; otherwise a visualization
            payload. ``result`` is included, and ``label`` / ``confidence`` are set,
            only on frames whose window was classified and published.

        Raises:
            RuntimeError: if the detector does not expose a callable ``track``.
        """
        frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
        timestamp_ns = self._extract_timestamp(metadata)

        track_fn_obj = getattr(self._detector, "track", None)
        if not callable(track_fn_obj):
            raise RuntimeError("YOLO detector does not expose a callable track()")
        track_fn = cast(_TrackCallable, track_fn_obj)
        # classes=[0] restricts tracking to class id 0 (presumably COCO "person").
        detections = track_fn(
            frame,
            persist=True,
            verbose=False,
            device=self._device,
            classes=[0],
        )

        first = self._first_result(detections)
        if first is None:
            return None
        selected = self._select_silhouette(first)
        if selected is None:
            return None

        silhouette = selected["silhouette"]
        mask_raw = selected["mask_raw"]
        bbox = selected["bbox_frame"]
        bbox_mask = selected["bbox_mask"]
        track_id = selected["track_id"]
        # Always None here; forwarded only to visualizers whose update() accepts it.
        pose_data: dict[str, object] | None = None

        if self._silhouette_export_path is not None or self._preprocess_only:
            self._silhouette_buffer.append(
                {
                    "frame": frame_idx,
                    "track_id": track_id,
                    "timestamp_ns": timestamp_ns,
                    "silhouette": silhouette.copy(),
                }
            )

        if self._silhouette_visualize_dir is not None:
            self._visualize_silhouette(silhouette, frame_idx, track_id)

        if self._preprocess_only:
            # Still return a payload so the live view works in preprocess-only mode.
            return self._make_viz_payload(
                mask_raw=mask_raw,
                bbox=bbox,
                bbox_mask=bbox_mask,
                silhouette=silhouette,
                segmentation_input=None,
                track_id=track_id,
                pose=pose_data,
            )

        self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
        segmentation_input = self._window.buffered_silhouettes

        pacer = self._frame_pacer
        if (pacer is not None and not pacer.should_emit(timestamp_ns)) or (
            not self._window.should_classify()
        ):
            # Either rate-limited or the window is not ready yet: visualize only.
            return self._make_viz_payload(
                mask_raw=mask_raw,
                bbox=bbox,
                bbox_mask=bbox_mask,
                silhouette=silhouette,
                segmentation_input=segmentation_input,
                track_id=track_id,
                pose=pose_data,
            )

        window_tensor = self._window.get_tensor(device=self._device)
        label, raw_confidence = cast(
            tuple[str, float],
            self._classifier.predict(window_tensor),
        )
        # Normalize once so the published result and the payload agree on the type.
        confidence = float(raw_confidence)
        self._window.mark_classified()

        window_start = self._window.window_start_frame
        result = create_result(
            frame=frame_idx,
            track_id=track_id,
            label=label,
            confidence=confidence,
            window=(max(0, window_start), frame_idx),
            timestamp_ns=timestamp_ns,
        )
        if self._result_export_path is not None:
            self._result_buffer.append(result)
        self._publisher.publish(result)

        payload = self._make_viz_payload(
            mask_raw=mask_raw,
            bbox=bbox,
            bbox_mask=bbox_mask,
            silhouette=silhouette,
            segmentation_input=segmentation_input,
            track_id=track_id,
            label=label,
            confidence=confidence,
            pose=pose_data,
        )
        return cast(_VizPayload, {"result": result, **payload})

    @staticmethod
    def _snapshot_viz_payload(payload: _VizPayload) -> _VizPayload:
        """Return a copy of ``payload`` whose top-level values are detached from it.

        The copying rules are:

        - ``pose`` dicts are deep-copied.
        - Any other value with a callable ``copy()`` (NumPy arrays, dicts) is copied
          through that method.
        - Everything else (ints, strings, tuples, ``None``) is shared.
        """
        snapshot: dict[str, object] = {}
        for key, value in payload.items():
            if key == "pose" and isinstance(value, dict):
                snapshot[key] = copy.deepcopy(value)
                continue
            copy_method = getattr(value, "copy", None)
            snapshot[key] = copy_method() if callable(copy_method) else value
        return cast(_VizPayload, snapshot)

    def _update_visualizer(
        self,
        frame_u8: NDArray[np.uint8],
        viz_payload: _VizPayload | None,
        ema_fps: float,
    ) -> bool:
        """Send one frame to the visualizer; return ``False`` when it asks to stop.

        On frames without a fresh payload, the last cached payload is reused with
        ``bbox``, ``bbox_mask``, ``label``, ``confidence`` and ``pose`` cleared.
        The last mask and silhouette are therefore still passed on.
        """
        visualizer = self._visualizer
        if visualizer is None:
            return True

        viz_data: dict[str, object]
        if viz_payload is not None:
            # Cache a detached copy so later in-place changes cannot alter the cache.
            self._last_viz_payload = self._snapshot_viz_payload(viz_payload)
            viz_data = cast(dict[str, object], viz_payload)
        elif self._last_viz_payload is not None:
            viz_data = dict(self._last_viz_payload)
            for stale_key in ("bbox", "bbox_mask", "label", "confidence", "pose"):
                viz_data[stale_key] = None
        else:
            # No detection yet and nothing cached: every overlay field is None.
            viz_data = {}

        track_id_val = viz_data.get("track_id", 0)
        track_id = track_id_val if isinstance(track_id_val, int) else 0
        update_args = (
            frame_u8,
            viz_data.get("bbox"),
            viz_data.get("bbox_mask"),
            track_id,
            viz_data.get("mask_raw"),
            viz_data.get("silhouette"),
            viz_data.get("segmentation_input"),
            viz_data.get("label"),
            viz_data.get("confidence"),
            ema_fps,
        )
        if self._detect_visualizer_pose_kwarg():
            keep_running = visualizer.update(
                *update_args, pose_data=viz_data.get("pose")
            )
        else:
            keep_running = visualizer.update(*update_args)
        return bool(keep_running)

    def run(self) -> int:
        """Process the source until it is exhausted, interrupted, or the view closes.

        ``RuntimeError``/``ValueError``/``OSError`` raised while processing a frame
        are logged and that frame is skipped. :meth:`close` always runs on exit.

        Returns:
            ``0`` on normal completion (including the user closing the visualizer),
            or ``130`` on ``KeyboardInterrupt``.
        """
        frame_count = 0
        start_time = time.perf_counter()
        ema_fps = 0.0
        alpha = 0.1  # EMA smoothing factor for the displayed FPS
        prev_time = start_time
        try:
            for frame, metadata in self._source:
                frame_u8 = np.asarray(frame, dtype=np.uint8)
                frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
                frame_count += 1

                curr_time = time.perf_counter()
                delta = curr_time - prev_time
                prev_time = curr_time
                if delta > 0:
                    instant_fps = 1.0 / delta
                    if ema_fps == 0.0:
                        ema_fps = instant_fps
                    else:
                        ema_fps = alpha * instant_fps + (1 - alpha) * ema_fps

                viz_payload: _VizPayload | None = None
                try:
                    viz_payload = self.process_frame(frame_u8, metadata)
                except (RuntimeError, ValueError, OSError) as frame_error:
                    logger.warning(
                        "Skipping frame %d due to processing error: %s",
                        frame_idx,
                        frame_error,
                    )

                if self._visualizer is not None and not self._update_visualizer(
                    frame_u8, viz_payload, ema_fps
                ):
                    logger.info("Visualization closed by user.")
                    break

                if frame_count % 100 == 0:
                    elapsed = time.perf_counter() - start_time
                    fps = frame_count / elapsed if elapsed > 0 else 0.0
                    logger.info("Processed %d frames (%.2f FPS)", frame_count, fps)
            return 0
        except KeyboardInterrupt:
            logger.info("Interrupted by user, shutting down cleanly.")
            return 130
        finally:
            self.close()

    def close(self) -> None:
        """Close the visualizer, flush exports and close the publisher (idempotent).

        Visualizer and publisher close failures are logged, not raised. The publisher
        is always closed and the pipeline always marked closed, even when an export
        fails. In that case the first export error is re-raised after cleanup.
        """
        if self._closed:
            return
        try:
            if self._visualizer is not None:
                try:
                    self._visualizer.close()
                except Exception:  # best effort: a broken window must not block exports
                    logger.warning("Failed to close visualizer", exc_info=True)
            self._flush_exports()
        finally:
            self._close_publisher()
            self._closed = True

    def _flush_exports(self) -> None:
        """Write every non-empty export buffer; raise the first failure afterwards.

        Each export is attempted independently, so a failing silhouette export does
        not prevent the result export (and vice versa).
        """
        jobs: list[tuple[str, Callable[[], None]]] = []
        if self._silhouette_export_path is not None and self._silhouette_buffer:
            jobs.append(("silhouettes", self._export_silhouettes))
        if self._result_export_path is not None and self._result_buffer:
            jobs.append(("results", self._export_results))

        errors: list[Exception] = []
        for name, export in jobs:
            try:
                export()
            except Exception as exc:  # re-raised below once all exports were tried
                logger.error("Failed to export %s: %s", name, exc)
                errors.append(exc)
        if errors:
            raise errors[0]

    def _close_publisher(self) -> None:
        """Call the publisher's ``close()`` if it has one; log (not raise) failures."""
        close_fn = getattr(self._publisher, "close", None)
        if not callable(close_fn):
            return
        try:
            _ = close_fn()
        except Exception:  # best effort during shutdown
            logger.warning("Failed to close result publisher", exc_info=True)

    def _export_silhouettes(self) -> None:
        """Write buffered silhouettes to ``_silhouette_export_path``.

        ``"pickle"`` dumps the list of per-frame dicts; ``"parquet"`` writes a
        columnar table (see :meth:`_export_parquet_silhouettes`).

        Raises:
            ValueError: for an unsupported export format.
        """
        if self._silhouette_export_path is None:
            return

        self._silhouette_export_path.parent.mkdir(parents=True, exist_ok=True)

        if self._silhouette_export_format == "pickle":
            import pickle

            with open(self._silhouette_export_path, "wb") as f:
                pickle.dump(self._silhouette_buffer, f)
            logger.info(
                "Exported %d silhouettes to %s",
                len(self._silhouette_buffer),
                self._silhouette_export_path,
            )
        elif self._silhouette_export_format == "parquet":
            self._export_parquet_silhouettes()
        else:
            raise ValueError(
                f"Unsupported silhouette export format: {self._silhouette_export_format}"
            )

    def _visualize_silhouette(
        self,
        silhouette: Float[ndarray, "64 44"],
        frame_idx: int,
        track_id: int,
    ) -> None:
        """Save ``silhouette`` as an 8-bit PNG under ``_silhouette_visualize_dir``.

        File names are ``silhouette_frame{frame:06d}_track{track:04d}.png``, so they
        are deterministic per (frame, track) pair.
        """
        if self._silhouette_visualize_dir is None:
            return

        self._silhouette_visualize_dir.mkdir(parents=True, exist_ok=True)

        # The *255 scaling assumes values in [0, 1]; clip first so out-of-range values
        # saturate instead of wrapping around in the uint8 cast.
        silhouette_u8 = (np.clip(silhouette, 0.0, 1.0) * 255).astype(np.uint8)

        filename = f"silhouette_frame{frame_idx:06d}_track{track_id:04d}.png"
        output_path = self._silhouette_visualize_dir / filename

        # PIL is imported lazily; it is only needed when PNG dumps are requested.
        from PIL import Image

        Image.fromarray(silhouette_u8).save(output_path)

    def _export_parquet_silhouettes(self) -> None:
        """Write buffered silhouettes as a parquet table.

        Columns: ``frame``, ``track_id``, ``timestamp_ns`` (int64) and ``silhouette``
        (flattened list of float32).

        Raises:
            RuntimeError: if ``pyarrow`` is not installed.
        """
        import importlib

        try:
            pa = importlib.import_module("pyarrow")
            pq = importlib.import_module("pyarrow.parquet")
        except ImportError as e:
            raise RuntimeError(
                "Parquet export requires pyarrow. Install with: pip install pyarrow"
            ) from e

        frames = []
        track_ids = []
        timestamps = []
        silhouettes = []
        for item in self._silhouette_buffer:
            frames.append(item["frame"])
            track_ids.append(item["track_id"])
            timestamps.append(item["timestamp_ns"])
            silhouette_array = cast(NDArray[np.float32], item["silhouette"])
            silhouettes.append(silhouette_array.flatten().tolist())

        table = pa.table(
            {
                "frame": pa.array(frames, type=pa.int64()),
                "track_id": pa.array(track_ids, type=pa.int64()),
                "timestamp_ns": pa.array(timestamps, type=pa.int64()),
                "silhouette": pa.array(silhouettes, type=pa.list_(pa.float32())),
            }
        )

        pq.write_table(table, self._silhouette_export_path)
        logger.info(
            "Exported %d silhouettes to parquet: %s",
            len(self._silhouette_buffer),
            self._silhouette_export_path,
        )

    def _export_results(self) -> None:
        """Write buffered results to ``_result_export_path``.

        ``"json"`` writes one JSON object per line (non-JSON values stringified),
        ``"pickle"`` dumps the list, and ``"parquet"`` writes a table (see
        :meth:`_export_parquet_results`).

        Raises:
            ValueError: for an unsupported export format.
        """
        if self._result_export_path is None:
            return

        self._result_export_path.parent.mkdir(parents=True, exist_ok=True)

        if self._result_export_format == "json":
            import json

            with open(self._result_export_path, "w", encoding="utf-8") as f:
                for result in self._result_buffer:
                    f.write(json.dumps(result, ensure_ascii=False, default=str) + "\n")
            logger.info(
                "Exported %d results to JSON: %s",
                len(self._result_buffer),
                self._result_export_path,
            )
        elif self._result_export_format == "pickle":
            import pickle

            with open(self._result_export_path, "wb") as f:
                pickle.dump(self._result_buffer, f)
            logger.info(
                "Exported %d results to pickle: %s",
                len(self._result_buffer),
                self._result_export_path,
            )
        elif self._result_export_format == "parquet":
            self._export_parquet_results()
        else:
            raise ValueError(
                f"Unsupported result export format: {self._result_export_format}"
            )

    def _export_parquet_results(self) -> None:
        """Write buffered results as a parquet table.

        Raises:
            RuntimeError: if ``pyarrow`` is not installed.
        """
        import importlib

        try:
            pa = importlib.import_module("pyarrow")
            pq = importlib.import_module("pyarrow.parquet")
        except ImportError as e:
            raise RuntimeError(
                "Parquet export requires pyarrow. Install with: pip install pyarrow"
            ) from e

        frames = []
        track_ids = []
        labels = []
        confidences = []
        windows = []
        timestamps = []
        for result in self._result_buffer:
            frames.append(result["frame"])
            track_ids.append(result["track_id"])
            labels.append(result["label"])
            confidences.append(result["confidence"])
            # NOTE(review): int() assumes create_result stores "window" as a scalar,
            # although process_frame passes a (start, end) tuple — confirm in output.py.
            windows.append(int(result["window"]))
            timestamps.append(result["timestamp_ns"])

        table = pa.table(
            {
                "frame": pa.array(frames, type=pa.int64()),
                "track_id": pa.array(track_ids, type=pa.int64()),
                "label": pa.array(labels, type=pa.string()),
                "confidence": pa.array(confidences, type=pa.float64()),
                "window": pa.array(windows, type=pa.int64()),
                "timestamp_ns": pa.array(timestamps, type=pa.int64()),
            }
        )

        pq.write_table(table, self._result_export_path)
        logger.info(
            "Exported %d results to parquet: %s",
            len(self._result_buffer),
            self._result_export_path,
        )
|
|
|
|
|
|
def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
    """Fail fast when a required input file does not exist.

    ``source`` is not checked on disk when it starts with ``cvmmap://`` or is an
    all-digit string (presumably a camera index); otherwise it must be an existing
    file. ``checkpoint`` and ``config`` must always be existing files.

    Raises:
        ValueError: naming the first missing input, checked in the order source,
            checkpoint, config.
    """
    is_live_source = source.startswith("cvmmap://") or source.isdigit()
    if not is_live_source and not Path(source).is_file():
        raise ValueError(f"Video source not found: {source}")

    for kind, candidate in (("Checkpoint", checkpoint), ("Config", config)):
        if not Path(candidate).is_file():
            raise ValueError(f"{kind} not found: {candidate}")
|
|
|
|
|
|
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option("--source", type=str, required=True)
@click.option("--checkpoint", type=str, required=True)
@click.option(
    "--config",
    type=str,
    default="configs/sconet/sconet_scoliosis1k.yaml",
    show_default=True,
)
@click.option("--device", type=str, default="cuda:0", show_default=True)
@click.option(
    "--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", show_default=True
)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
@click.option(
    "--window-mode",
    type=click.Choice(["manual", "sliding", "chunked"], case_sensitive=False),
    default="manual",
    show_default=True,
    help=(
        "Window scheduling mode: manual uses --stride; "
        "sliding forces stride=1; chunked forces stride=window"
    ),
)
@click.option(
    "--target-fps", type=click.FloatRange(min=0.1), default=15.0, show_default=True
)
@click.option("--no-target-fps", is_flag=True, default=False)
@click.option("--nats-url", type=str, default=None)
@click.option(
    "--nats-subject", type=str, default="scoliosis.result", show_default=True
)
@click.option("--max-frames", type=click.IntRange(min=1), default=None)
@click.option(
    "--preprocess-only",
    is_flag=True,
    default=False,
    help="Only preprocess silhouettes, skip classification.",
)
@click.option(
    "--silhouette-export-path",
    type=str,
    default=None,
    help="Path to export silhouettes (required for preprocess-only mode).",
)
@click.option(
    "--silhouette-export-format",
    type=click.Choice(["pickle", "parquet"]),
    default="pickle",
    show_default=True,
    help="Format for silhouette export.",
)
@click.option(
    "--result-export-path",
    type=str,
    default=None,
    help="Path to export inference results.",
)
@click.option(
    "--result-export-format",
    type=click.Choice(["json", "pickle", "parquet"]),
    default="json",
    show_default=True,
    help="Format for result export.",
)
@click.option(
    "--silhouette-visualize-dir",
    type=str,
    default=None,
    help="Directory to save silhouette PNG visualizations.",
)
@click.option(
    "--visualize",
    is_flag=True,
    default=False,
    help="Enable real-time visualization.",
)
def main(
    source: str,
    checkpoint: str,
    config: str,
    device: str,
    yolo_model: str,
    window: int,
    stride: int,
    window_mode: str,
    target_fps: float,
    no_target_fps: bool,
    nats_url: str | None,
    nats_subject: str,
    max_frames: int | None,
    preprocess_only: bool,
    silhouette_export_path: str | None,
    silhouette_export_format: str,
    result_export_path: str | None,
    result_export_format: str,
    silhouette_visualize_dir: str | None,
    visualize: bool,
) -> None:
    """Command-line entry point: validate inputs, build the pipeline, and run it.

    Exits with the code returned by ``ScoliosisPipeline.run``. It exits with ``2``
    when a ``ValueError`` escapes (bad inputs or configuration) and with ``1`` for a
    ``RuntimeError``. ``--preprocess-only`` without ``--silhouette-export-path`` is
    rejected as a usage error.
    """
    # ``--no-target-fps`` turns pacing off entirely (the pipeline receives None).
    pacing_fps = None if no_target_fps else target_fps

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    if preprocess_only and not silhouette_export_path:
        raise click.UsageError(
            "--silhouette-export-path is required when using --preprocess-only"
        )

    try:
        validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)

        normalized_mode = cast(WindowMode, window_mode.lower())
        resolved_stride = resolve_stride(
            window=window, stride=stride, window_mode=normalized_mode
        )
        if resolved_stride != stride:
            logger.info(
                "window_mode=%s overrides stride=%d -> effective_stride=%d",
                window_mode,
                stride,
                resolved_stride,
            )

        pipeline = ScoliosisPipeline(
            # Inputs and models.
            source=source,
            checkpoint=checkpoint,
            config=config,
            device=device,
            yolo_model=yolo_model,
            max_frames=max_frames,
            # Windowing and pacing.
            window=window,
            stride=resolved_stride,
            target_fps=pacing_fps,
            # Publishing.
            nats_url=nats_url,
            nats_subject=nats_subject,
            # Preprocessing / exports / visualization.
            preprocess_only=preprocess_only,
            silhouette_export_path=silhouette_export_path,
            silhouette_export_format=silhouette_export_format,
            silhouette_visualize_dir=silhouette_visualize_dir,
            result_export_path=result_export_path,
            result_export_format=result_export_format,
            visualize=visualize,
        )
        exit_code = pipeline.run()
    except ValueError as err:
        click.echo(f"Error: {err}", err=True)
        raise SystemExit(2) from err
    except RuntimeError as err:
        click.echo(f"Runtime error: {err}", err=True)
        raise SystemExit(1) from err

    raise SystemExit(exit_code)
|