fix(demo): correct window start metadata and test unpacking

Use buffered frame indices for emitted window bounds to stay accurate across detection gaps, and align select_person tests with the 4-field return contract introduced for frame-space bbox support.
This commit is contained in:
2026-02-28 22:13:36 +08:00
parent ce64b559ec
commit 1f8f959ad7
3 changed files with 32 additions and 17 deletions
+19 -9
View File
@@ -53,7 +53,6 @@ class _DetectionResultsLike(Protocol):
def masks(self) -> _MasksLike: ...
class _TrackCallable(Protocol):
def __call__(
self,
@@ -209,7 +208,11 @@ class ScoliosisPipeline:
# Convert mask-space bbox to frame-space for visualization
# Use result.orig_shape to get frame dimensions safely
orig_shape = getattr(result, "orig_shape", None)
if orig_shape is not None and isinstance(orig_shape, (tuple, list)) and len(orig_shape) >= 2:
if (
orig_shape is not None
and isinstance(orig_shape, (tuple, list))
and len(orig_shape) >= 2
):
frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
mask_h, mask_w = mask_u8.shape[0], mask_u8.shape[1]
if mask_w > 0 and mask_h > 0 and frame_w > 0 and frame_h > 0:
@@ -230,7 +233,6 @@ class ScoliosisPipeline:
# For fallback case, mask_raw is the same as mask_u8
return silhouette, mask_u8, bbox_frame, 0
@jaxtyped(typechecker=beartype)
def process_frame(
self,
@@ -308,7 +310,7 @@ class ScoliosisPipeline:
)
self._window.mark_classified()
window_start = frame_idx - self._window.window_size + 1
window_start = self._window.window_start_frame
result = create_result(
frame=frame_idx,
track_id=track_id,
@@ -377,16 +379,22 @@ class ScoliosisPipeline:
viz_payload_dict = cast(dict[str, object], viz_payload)
cached: dict[str, object] = {}
for k, v in viz_payload_dict.items():
copy_method = cast(Callable[[], object] | None, getattr(v, "copy", None))
copy_method = cast(
Callable[[], object] | None, getattr(v, "copy", None)
)
if copy_method is not None:
cached[k] = copy_method()
else:
cached[k] = v
self._last_viz_payload = cached
# Use cached payload if current is None
viz_data = viz_payload if viz_payload is not None else self._last_viz_payload
viz_data = (
viz_payload
if viz_payload is not None
else self._last_viz_payload
)
if viz_data is not None:
# Cast viz_payload to dict for type checking
viz_dict = cast(dict[str, object], viz_data)
@@ -658,7 +666,9 @@ def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
show_default=True,
)
@click.option("--device", type=str, default="cuda:0", show_default=True)
@click.option("--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", show_default=True)
@click.option(
"--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", show_default=True
)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--nats-url", type=str, default=None)
+6
View File
@@ -210,6 +210,12 @@ class SilhouetteWindow:
"""Fill ratio of the buffer (0.0 to 1.0)."""
return len(self._buffer) / self.window_size
@property
def window_start_frame(self) -> int:
if not self._frame_indices:
raise ValueError("Window is empty")
return int(self._frame_indices[0])
def _to_numpy(obj: _ArrayLike) -> ndarray:
"""Safely convert array-like object to numpy array.