From 1f8f959ad7fbd60dc4b67f93c10b8606875e9256 Mon Sep 17 00:00:00 2001 From: crosstyan Date: Sat, 28 Feb 2026 22:13:36 +0800 Subject: [PATCH] fix(demo): correct window start metadata and test unpacking Use buffered frame indices for emitted window bounds to stay accurate across detection gaps, and align select_person tests with the 4-field return contract introduced for frame-space bbox support. --- opengait/demo/pipeline.py | 28 +++++++++++++++++++--------- opengait/demo/window.py | 6 ++++++ tests/demo/test_window.py | 15 +++++++-------- 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/opengait/demo/pipeline.py b/opengait/demo/pipeline.py index f1c5908..54029a1 100644 --- a/opengait/demo/pipeline.py +++ b/opengait/demo/pipeline.py @@ -53,7 +53,6 @@ class _DetectionResultsLike(Protocol): def masks(self) -> _MasksLike: ... - class _TrackCallable(Protocol): def __call__( self, @@ -209,7 +208,11 @@ class ScoliosisPipeline: # Convert mask-space bbox to frame-space for visualization # Use result.orig_shape to get frame dimensions safely orig_shape = getattr(result, "orig_shape", None) - if orig_shape is not None and isinstance(orig_shape, (tuple, list)) and len(orig_shape) >= 2: + if ( + orig_shape is not None + and isinstance(orig_shape, (tuple, list)) + and len(orig_shape) >= 2 + ): frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1]) mask_h, mask_w = mask_u8.shape[0], mask_u8.shape[1] if mask_w > 0 and mask_h > 0 and frame_w > 0 and frame_h > 0: @@ -230,7 +233,6 @@ class ScoliosisPipeline: # For fallback case, mask_raw is the same as mask_u8 return silhouette, mask_u8, bbox_frame, 0 - @jaxtyped(typechecker=beartype) def process_frame( self, @@ -308,7 +310,7 @@ class ScoliosisPipeline: ) self._window.mark_classified() - window_start = frame_idx - self._window.window_size + 1 + window_start = self._window.window_start_frame result = create_result( frame=frame_idx, track_id=track_id, @@ -377,16 +379,22 @@ class ScoliosisPipeline: viz_payload_dict = cast(dict[str, object], viz_payload) cached: dict[str, object] = {} for k, v in viz_payload_dict.items(): - copy_method = cast(Callable[[], object] | None, getattr(v, "copy", None)) + copy_method = cast( + Callable[[], object] | None, getattr(v, "copy", None) + ) if copy_method is not None: cached[k] = copy_method() else: cached[k] = v self._last_viz_payload = cached - + # Use cached payload if current is None - viz_data = viz_payload if viz_payload is not None else self._last_viz_payload - + viz_data = ( + viz_payload + if viz_payload is not None + else self._last_viz_payload + ) + if viz_data is not None: # Cast viz_payload to dict for type checking viz_dict = cast(dict[str, object], viz_data) @@ -658,7 +666,9 @@ def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None: show_default=True, ) @click.option("--device", type=str, default="cuda:0", show_default=True) -@click.option("--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", show_default=True) +@click.option( + "--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", show_default=True +) @click.option("--window", type=click.IntRange(min=1), default=30, show_default=True) @click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True) @click.option("--nats-url", type=str, default=None) diff --git a/opengait/demo/window.py b/opengait/demo/window.py index 6443271..34f44db 100644 --- a/opengait/demo/window.py +++ b/opengait/demo/window.py @@ -210,6 +210,12 @@ class SilhouetteWindow: """Fill ratio of the buffer (0.0 to 1.0).""" return len(self._buffer) / self.window_size + @property + def window_start_frame(self) -> int: + if not self._frame_indices: + raise ValueError("Window is empty") + return int(self._frame_indices[0]) + def _to_numpy(obj: _ArrayLike) -> ndarray: """Safely convert array-like object to numpy array. diff --git a/tests/demo/test_window.py b/tests/demo/test_window.py index a69a89a..0b67268 100644 --- a/tests/demo/test_window.py +++ b/tests/demo/test_window.py @@ -234,7 +234,7 @@ class TestSelectPerson: result = select_person(results) assert result is not None - mask, bbox, tid = result + mask, bbox, _, tid = result assert mask.shape == (100, 100) assert bbox == (10, 10, 50, 90) assert tid == 42 @@ -257,7 +257,7 @@ class TestSelectPerson: result = select_person(results) assert result is not None - mask, bbox, tid = result + mask, bbox, _, tid = result assert bbox == (0, 0, 30, 30) # Largest box assert tid == 2 # Corresponding track ID @@ -327,7 +327,7 @@ class TestSelectPerson: result = select_person(results) assert result is not None - _, bbox, tid = result + _, bbox, _, tid = result assert bbox == (10, 10, 50, 90) assert tid == 1 @@ -341,11 +341,10 @@ class TestSelectPerson: result = select_person(results) assert result is not None - mask, _, _ = result + mask, _, _, _ = result # Should be 2D (extracted from expanded 3D) assert mask.shape == (100, 100) - def test_select_person_tensor_cpu_inputs(self) -> None: """Tensor-backed inputs (CPU) should work correctly.""" boxes = torch.tensor([[10.0, 10.0, 50.0, 90.0]], dtype=torch.float32) @@ -356,7 +355,7 @@ class TestSelectPerson: result = select_person(results) assert result is not None - mask, bbox, tid = result + mask, bbox, _, tid = result assert mask.shape == (100, 100) assert bbox == (10, 10, 50, 90) assert tid == 42 @@ -372,7 +371,7 @@ class TestSelectPerson: result = select_person(results) assert result is not None - mask, bbox, tid = result + mask, bbox, _, tid = result assert mask.shape == (100, 100) assert bbox == (10, 10, 50, 90) assert tid == 42 @@ -394,6 +393,6 @@ class TestSelectPerson: result = select_person(results) assert result is not None - _, bbox, tid = result + _, bbox, _, tid = result assert bbox == (0, 0, 30, 30) # Largest box assert tid == 2 # Corresponding track ID