chore: update demo runtime, tests, and agent docs

This commit is contained in:
2026-03-02 12:33:17 +08:00
parent 1f8f959ad7
commit cbb3284c13
14 changed files with 1491 additions and 236 deletions
+105 -5
View File
@@ -5,7 +5,7 @@ from contextlib import suppress
import logging
from pathlib import Path
import time
from typing import TYPE_CHECKING, Protocol, cast
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, cast
from beartype import beartype
import click
@@ -31,6 +31,16 @@ JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
WindowMode: TypeAlias = Literal["manual", "sliding", "chunked"]
def resolve_stride(window: int, stride: int, window_mode: WindowMode) -> int:
if window_mode == "manual":
return stride
if window_mode == "sliding":
return 1
return window
class _BoxesLike(Protocol):
@property
@@ -65,6 +75,27 @@ class _TrackCallable(Protocol):
) -> object: ...
class _FramePacer:
_interval_ns: int
_next_emit_ns: int | None
def __init__(self, target_fps: float) -> None:
if target_fps <= 0:
raise ValueError(f"target_fps must be positive, got {target_fps}")
self._interval_ns = int(1_000_000_000 / target_fps)
self._next_emit_ns = None
def should_emit(self, timestamp_ns: int) -> bool:
if self._next_emit_ns is None:
self._next_emit_ns = timestamp_ns + self._interval_ns
return True
if timestamp_ns >= self._next_emit_ns:
while self._next_emit_ns <= timestamp_ns:
self._next_emit_ns += self._interval_ns
return True
return False
class ScoliosisPipeline:
_detector: object
_source: FrameStream
@@ -83,6 +114,7 @@ class ScoliosisPipeline:
_result_buffer: list[DemoResult]
_visualizer: OpenCVVisualizer | None
_last_viz_payload: dict[str, object] | None
_frame_pacer: _FramePacer | None
def __init__(
self,
@@ -104,6 +136,7 @@ class ScoliosisPipeline:
result_export_path: str | None = None,
result_export_format: str = "json",
visualize: bool = False,
target_fps: float | None = 15.0,
) -> None:
self._detector = YOLO(yolo_model)
self._source = create_source(source, max_frames=max_frames)
@@ -140,6 +173,7 @@ class ScoliosisPipeline:
else:
self._visualizer = None
self._last_viz_payload = None
self._frame_pacer = _FramePacer(target_fps) if target_fps is not None else None
@staticmethod
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
@@ -177,6 +211,7 @@ class ScoliosisPipeline:
Float[ndarray, "64 44"],
UInt8[ndarray, "h w"],
BBoxXYXY,
BBoxXYXY,
int,
]
| None
@@ -189,7 +224,7 @@ class ScoliosisPipeline:
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask),
)
if silhouette is not None:
return silhouette, mask_raw, bbox_frame, int(track_id)
return silhouette, mask_raw, bbox_frame, bbox_mask, int(track_id)
fallback = cast(
tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
@@ -231,7 +266,7 @@ class ScoliosisPipeline:
# Fallback: use mask-space bbox if orig_shape unavailable
bbox_frame = bbox_mask
# For fallback case, mask_raw is the same as mask_u8
return silhouette, mask_u8, bbox_frame, 0
return silhouette, mask_u8, bbox_frame, bbox_mask, 0
@jaxtyped(typechecker=beartype)
def process_frame(
@@ -262,7 +297,7 @@ class ScoliosisPipeline:
if selected is None:
return None
silhouette, mask_raw, bbox, track_id = selected
silhouette, mask_raw, bbox, bbox_mask, track_id = selected
# Store silhouette for export if in preprocess-only mode or if export requested
if self._silhouette_export_path is not None or self._preprocess_only:
@@ -284,20 +319,39 @@ class ScoliosisPipeline:
return {
"mask_raw": mask_raw,
"bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette,
"segmentation_input": None,
"track_id": track_id,
"label": None,
"confidence": None,
}
if self._frame_pacer is not None and not self._frame_pacer.should_emit(
timestamp_ns
):
return {
"mask_raw": mask_raw,
"bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette,
"segmentation_input": self._window.buffered_silhouettes,
"track_id": track_id,
"label": None,
"confidence": None,
}
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
segmentation_input = self._window.buffered_silhouettes
if not self._window.should_classify():
# Return visualization payload even when not classifying yet
return {
"mask_raw": mask_raw,
"bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette,
"segmentation_input": segmentation_input,
"track_id": track_id,
"label": None,
"confidence": None,
@@ -330,7 +384,9 @@ class ScoliosisPipeline:
"result": result,
"mask_raw": mask_raw,
"bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette,
"segmentation_input": segmentation_input,
"track_id": track_id,
"label": label,
"confidence": confidence,
@@ -400,7 +456,9 @@ class ScoliosisPipeline:
viz_dict = cast(dict[str, object], viz_data)
mask_raw_obj = viz_dict.get("mask_raw")
bbox_obj = viz_dict.get("bbox")
bbox_mask_obj = viz_dict.get("bbox_mask")
silhouette_obj = viz_dict.get("silhouette")
segmentation_input_obj = viz_dict.get("segmentation_input")
track_id_val = viz_dict.get("track_id", 0)
track_id = track_id_val if isinstance(track_id_val, int) else 0
label_obj = viz_dict.get("label")
@@ -409,24 +467,33 @@ class ScoliosisPipeline:
# Cast extracted values to expected types
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
bbox = cast(BBoxXYXY | None, bbox_obj)
bbox_mask = cast(BBoxXYXY | None, bbox_mask_obj)
silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
segmentation_input = cast(
NDArray[np.float32] | None,
segmentation_input_obj,
)
label = cast(str | None, label_obj)
confidence = cast(float | None, confidence_obj)
else:
# No detection and no cache - use default values
mask_raw = None
bbox = None
bbox_mask = None
track_id = 0
silhouette = None
segmentation_input = None
label = None
confidence = None
keep_running = self._visualizer.update(
frame_u8,
bbox,
bbox_mask,
track_id,
mask_raw,
silhouette,
segmentation_input,
label,
confidence,
ema_fps,
@@ -671,6 +738,23 @@ def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
@click.option(
"--window-mode",
type=click.Choice(["manual", "sliding", "chunked"], case_sensitive=False),
default="manual",
show_default=True,
help=(
"Window scheduling mode: manual uses --stride; "
"sliding forces stride=1; chunked forces stride=window"
),
)
@click.option(
"--target-fps",
type=click.FloatRange(min=0.1),
default=15.0,
show_default=True,
)
@click.option("--no-target-fps", is_flag=True, default=False)
@click.option("--nats-url", type=str, default=None)
@click.option(
"--nats-subject",
@@ -725,6 +809,9 @@ def main(
yolo_model: str,
window: int,
stride: int,
window_mode: str,
target_fps: float | None,
no_target_fps: bool,
nats_url: str | None,
nats_subject: str,
max_frames: int | None,
@@ -748,6 +835,18 @@ def main(
try:
validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
effective_stride = resolve_stride(
window=window,
stride=stride,
window_mode=cast(WindowMode, window_mode.lower()),
)
if effective_stride != stride:
logger.info(
"window_mode=%s overrides stride=%d -> effective_stride=%d",
window_mode,
stride,
effective_stride,
)
pipeline = ScoliosisPipeline(
source=source,
checkpoint=checkpoint,
@@ -755,7 +854,8 @@ def main(
device=device,
yolo_model=yolo_model,
window=window,
stride=stride,
stride=effective_stride,
target_fps=None if no_target_fps else target_fps,
nats_url=nats_url,
nats_subject=nats_subject,
max_frames=max_frames,