chore: update demo runtime, tests, and agent docs
This commit is contained in:
+105
-5
@@ -5,7 +5,7 @@ from contextlib import suppress
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Protocol, cast
|
||||
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, cast
|
||||
|
||||
from beartype import beartype
|
||||
import click
|
||||
@@ -31,6 +31,16 @@ JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
|
||||
JaxtypedFactory = Callable[..., JaxtypedDecorator]
|
||||
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
|
||||
|
||||
WindowMode: TypeAlias = Literal["manual", "sliding", "chunked"]
|
||||
|
||||
|
||||
def resolve_stride(window: int, stride: int, window_mode: WindowMode) -> int:
|
||||
if window_mode == "manual":
|
||||
return stride
|
||||
if window_mode == "sliding":
|
||||
return 1
|
||||
return window
|
||||
|
||||
|
||||
class _BoxesLike(Protocol):
|
||||
@property
|
||||
@@ -65,6 +75,27 @@ class _TrackCallable(Protocol):
|
||||
) -> object: ...
|
||||
|
||||
|
||||
class _FramePacer:
|
||||
_interval_ns: int
|
||||
_next_emit_ns: int | None
|
||||
|
||||
def __init__(self, target_fps: float) -> None:
|
||||
if target_fps <= 0:
|
||||
raise ValueError(f"target_fps must be positive, got {target_fps}")
|
||||
self._interval_ns = int(1_000_000_000 / target_fps)
|
||||
self._next_emit_ns = None
|
||||
|
||||
def should_emit(self, timestamp_ns: int) -> bool:
|
||||
if self._next_emit_ns is None:
|
||||
self._next_emit_ns = timestamp_ns + self._interval_ns
|
||||
return True
|
||||
if timestamp_ns >= self._next_emit_ns:
|
||||
while self._next_emit_ns <= timestamp_ns:
|
||||
self._next_emit_ns += self._interval_ns
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class ScoliosisPipeline:
|
||||
_detector: object
|
||||
_source: FrameStream
|
||||
@@ -83,6 +114,7 @@ class ScoliosisPipeline:
|
||||
_result_buffer: list[DemoResult]
|
||||
_visualizer: OpenCVVisualizer | None
|
||||
_last_viz_payload: dict[str, object] | None
|
||||
_frame_pacer: _FramePacer | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -104,6 +136,7 @@ class ScoliosisPipeline:
|
||||
result_export_path: str | None = None,
|
||||
result_export_format: str = "json",
|
||||
visualize: bool = False,
|
||||
target_fps: float | None = 15.0,
|
||||
) -> None:
|
||||
self._detector = YOLO(yolo_model)
|
||||
self._source = create_source(source, max_frames=max_frames)
|
||||
@@ -140,6 +173,7 @@ class ScoliosisPipeline:
|
||||
else:
|
||||
self._visualizer = None
|
||||
self._last_viz_payload = None
|
||||
self._frame_pacer = _FramePacer(target_fps) if target_fps is not None else None
|
||||
|
||||
@staticmethod
|
||||
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
|
||||
@@ -177,6 +211,7 @@ class ScoliosisPipeline:
|
||||
Float[ndarray, "64 44"],
|
||||
UInt8[ndarray, "h w"],
|
||||
BBoxXYXY,
|
||||
BBoxXYXY,
|
||||
int,
|
||||
]
|
||||
| None
|
||||
@@ -189,7 +224,7 @@ class ScoliosisPipeline:
|
||||
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask),
|
||||
)
|
||||
if silhouette is not None:
|
||||
return silhouette, mask_raw, bbox_frame, int(track_id)
|
||||
return silhouette, mask_raw, bbox_frame, bbox_mask, int(track_id)
|
||||
|
||||
fallback = cast(
|
||||
tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
|
||||
@@ -231,7 +266,7 @@ class ScoliosisPipeline:
|
||||
# Fallback: use mask-space bbox if orig_shape unavailable
|
||||
bbox_frame = bbox_mask
|
||||
# For fallback case, mask_raw is the same as mask_u8
|
||||
return silhouette, mask_u8, bbox_frame, 0
|
||||
return silhouette, mask_u8, bbox_frame, bbox_mask, 0
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def process_frame(
|
||||
@@ -262,7 +297,7 @@ class ScoliosisPipeline:
|
||||
if selected is None:
|
||||
return None
|
||||
|
||||
silhouette, mask_raw, bbox, track_id = selected
|
||||
silhouette, mask_raw, bbox, bbox_mask, track_id = selected
|
||||
|
||||
# Store silhouette for export if in preprocess-only mode or if export requested
|
||||
if self._silhouette_export_path is not None or self._preprocess_only:
|
||||
@@ -284,20 +319,39 @@ class ScoliosisPipeline:
|
||||
return {
|
||||
"mask_raw": mask_raw,
|
||||
"bbox": bbox,
|
||||
"bbox_mask": bbox_mask,
|
||||
"silhouette": silhouette,
|
||||
"segmentation_input": None,
|
||||
"track_id": track_id,
|
||||
"label": None,
|
||||
"confidence": None,
|
||||
}
|
||||
|
||||
if self._frame_pacer is not None and not self._frame_pacer.should_emit(
|
||||
timestamp_ns
|
||||
):
|
||||
return {
|
||||
"mask_raw": mask_raw,
|
||||
"bbox": bbox,
|
||||
"bbox_mask": bbox_mask,
|
||||
"silhouette": silhouette,
|
||||
"segmentation_input": self._window.buffered_silhouettes,
|
||||
"track_id": track_id,
|
||||
"label": None,
|
||||
"confidence": None,
|
||||
}
|
||||
|
||||
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
|
||||
segmentation_input = self._window.buffered_silhouettes
|
||||
|
||||
if not self._window.should_classify():
|
||||
# Return visualization payload even when not classifying yet
|
||||
return {
|
||||
"mask_raw": mask_raw,
|
||||
"bbox": bbox,
|
||||
"bbox_mask": bbox_mask,
|
||||
"silhouette": silhouette,
|
||||
"segmentation_input": segmentation_input,
|
||||
"track_id": track_id,
|
||||
"label": None,
|
||||
"confidence": None,
|
||||
@@ -330,7 +384,9 @@ class ScoliosisPipeline:
|
||||
"result": result,
|
||||
"mask_raw": mask_raw,
|
||||
"bbox": bbox,
|
||||
"bbox_mask": bbox_mask,
|
||||
"silhouette": silhouette,
|
||||
"segmentation_input": segmentation_input,
|
||||
"track_id": track_id,
|
||||
"label": label,
|
||||
"confidence": confidence,
|
||||
@@ -400,7 +456,9 @@ class ScoliosisPipeline:
|
||||
viz_dict = cast(dict[str, object], viz_data)
|
||||
mask_raw_obj = viz_dict.get("mask_raw")
|
||||
bbox_obj = viz_dict.get("bbox")
|
||||
bbox_mask_obj = viz_dict.get("bbox_mask")
|
||||
silhouette_obj = viz_dict.get("silhouette")
|
||||
segmentation_input_obj = viz_dict.get("segmentation_input")
|
||||
track_id_val = viz_dict.get("track_id", 0)
|
||||
track_id = track_id_val if isinstance(track_id_val, int) else 0
|
||||
label_obj = viz_dict.get("label")
|
||||
@@ -409,24 +467,33 @@ class ScoliosisPipeline:
|
||||
# Cast extracted values to expected types
|
||||
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
|
||||
bbox = cast(BBoxXYXY | None, bbox_obj)
|
||||
bbox_mask = cast(BBoxXYXY | None, bbox_mask_obj)
|
||||
silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
|
||||
segmentation_input = cast(
|
||||
NDArray[np.float32] | None,
|
||||
segmentation_input_obj,
|
||||
)
|
||||
label = cast(str | None, label_obj)
|
||||
confidence = cast(float | None, confidence_obj)
|
||||
else:
|
||||
# No detection and no cache - use default values
|
||||
mask_raw = None
|
||||
bbox = None
|
||||
bbox_mask = None
|
||||
track_id = 0
|
||||
silhouette = None
|
||||
segmentation_input = None
|
||||
label = None
|
||||
confidence = None
|
||||
|
||||
keep_running = self._visualizer.update(
|
||||
frame_u8,
|
||||
bbox,
|
||||
bbox_mask,
|
||||
track_id,
|
||||
mask_raw,
|
||||
silhouette,
|
||||
segmentation_input,
|
||||
label,
|
||||
confidence,
|
||||
ema_fps,
|
||||
@@ -671,6 +738,23 @@ def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
|
||||
)
|
||||
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
|
||||
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
|
||||
@click.option(
|
||||
"--window-mode",
|
||||
type=click.Choice(["manual", "sliding", "chunked"], case_sensitive=False),
|
||||
default="manual",
|
||||
show_default=True,
|
||||
help=(
|
||||
"Window scheduling mode: manual uses --stride; "
|
||||
"sliding forces stride=1; chunked forces stride=window"
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--target-fps",
|
||||
type=click.FloatRange(min=0.1),
|
||||
default=15.0,
|
||||
show_default=True,
|
||||
)
|
||||
@click.option("--no-target-fps", is_flag=True, default=False)
|
||||
@click.option("--nats-url", type=str, default=None)
|
||||
@click.option(
|
||||
"--nats-subject",
|
||||
@@ -725,6 +809,9 @@ def main(
|
||||
yolo_model: str,
|
||||
window: int,
|
||||
stride: int,
|
||||
window_mode: str,
|
||||
target_fps: float | None,
|
||||
no_target_fps: bool,
|
||||
nats_url: str | None,
|
||||
nats_subject: str,
|
||||
max_frames: int | None,
|
||||
@@ -748,6 +835,18 @@ def main(
|
||||
|
||||
try:
|
||||
validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
|
||||
effective_stride = resolve_stride(
|
||||
window=window,
|
||||
stride=stride,
|
||||
window_mode=cast(WindowMode, window_mode.lower()),
|
||||
)
|
||||
if effective_stride != stride:
|
||||
logger.info(
|
||||
"window_mode=%s overrides stride=%d -> effective_stride=%d",
|
||||
window_mode,
|
||||
stride,
|
||||
effective_stride,
|
||||
)
|
||||
pipeline = ScoliosisPipeline(
|
||||
source=source,
|
||||
checkpoint=checkpoint,
|
||||
@@ -755,7 +854,8 @@ def main(
|
||||
device=device,
|
||||
yolo_model=yolo_model,
|
||||
window=window,
|
||||
stride=stride,
|
||||
stride=effective_stride,
|
||||
target_fps=None if no_target_fps else target_fps,
|
||||
nats_url=nats_url,
|
||||
nats_subject=nats_subject,
|
||||
max_frames=max_frames,
|
||||
|
||||
Reference in New Issue
Block a user