fix(demo): correct window start metadata and test unpacking
Use buffered frame indices for emitted window bounds to stay accurate across detection gaps, and align select_person tests with the 4-field return contract introduced for frame-space bbox support.
This commit is contained in:
@@ -53,7 +53,6 @@ class _DetectionResultsLike(Protocol):
|
|||||||
def masks(self) -> _MasksLike: ...
|
def masks(self) -> _MasksLike: ...
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class _TrackCallable(Protocol):
|
class _TrackCallable(Protocol):
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -209,7 +208,11 @@ class ScoliosisPipeline:
|
|||||||
# Convert mask-space bbox to frame-space for visualization
|
# Convert mask-space bbox to frame-space for visualization
|
||||||
# Use result.orig_shape to get frame dimensions safely
|
# Use result.orig_shape to get frame dimensions safely
|
||||||
orig_shape = getattr(result, "orig_shape", None)
|
orig_shape = getattr(result, "orig_shape", None)
|
||||||
if orig_shape is not None and isinstance(orig_shape, (tuple, list)) and len(orig_shape) >= 2:
|
if (
|
||||||
|
orig_shape is not None
|
||||||
|
and isinstance(orig_shape, (tuple, list))
|
||||||
|
and len(orig_shape) >= 2
|
||||||
|
):
|
||||||
frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
|
frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
|
||||||
mask_h, mask_w = mask_u8.shape[0], mask_u8.shape[1]
|
mask_h, mask_w = mask_u8.shape[0], mask_u8.shape[1]
|
||||||
if mask_w > 0 and mask_h > 0 and frame_w > 0 and frame_h > 0:
|
if mask_w > 0 and mask_h > 0 and frame_w > 0 and frame_h > 0:
|
||||||
@@ -230,7 +233,6 @@ class ScoliosisPipeline:
|
|||||||
# For fallback case, mask_raw is the same as mask_u8
|
# For fallback case, mask_raw is the same as mask_u8
|
||||||
return silhouette, mask_u8, bbox_frame, 0
|
return silhouette, mask_u8, bbox_frame, 0
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
def process_frame(
|
def process_frame(
|
||||||
self,
|
self,
|
||||||
@@ -308,7 +310,7 @@ class ScoliosisPipeline:
|
|||||||
)
|
)
|
||||||
self._window.mark_classified()
|
self._window.mark_classified()
|
||||||
|
|
||||||
window_start = frame_idx - self._window.window_size + 1
|
window_start = self._window.window_start_frame
|
||||||
result = create_result(
|
result = create_result(
|
||||||
frame=frame_idx,
|
frame=frame_idx,
|
||||||
track_id=track_id,
|
track_id=track_id,
|
||||||
@@ -377,16 +379,22 @@ class ScoliosisPipeline:
|
|||||||
viz_payload_dict = cast(dict[str, object], viz_payload)
|
viz_payload_dict = cast(dict[str, object], viz_payload)
|
||||||
cached: dict[str, object] = {}
|
cached: dict[str, object] = {}
|
||||||
for k, v in viz_payload_dict.items():
|
for k, v in viz_payload_dict.items():
|
||||||
copy_method = cast(Callable[[], object] | None, getattr(v, "copy", None))
|
copy_method = cast(
|
||||||
|
Callable[[], object] | None, getattr(v, "copy", None)
|
||||||
|
)
|
||||||
if copy_method is not None:
|
if copy_method is not None:
|
||||||
cached[k] = copy_method()
|
cached[k] = copy_method()
|
||||||
else:
|
else:
|
||||||
cached[k] = v
|
cached[k] = v
|
||||||
self._last_viz_payload = cached
|
self._last_viz_payload = cached
|
||||||
|
|
||||||
# Use cached payload if current is None
|
# Use cached payload if current is None
|
||||||
viz_data = viz_payload if viz_payload is not None else self._last_viz_payload
|
viz_data = (
|
||||||
|
viz_payload
|
||||||
|
if viz_payload is not None
|
||||||
|
else self._last_viz_payload
|
||||||
|
)
|
||||||
|
|
||||||
if viz_data is not None:
|
if viz_data is not None:
|
||||||
# Cast viz_payload to dict for type checking
|
# Cast viz_payload to dict for type checking
|
||||||
viz_dict = cast(dict[str, object], viz_data)
|
viz_dict = cast(dict[str, object], viz_data)
|
||||||
@@ -658,7 +666,9 @@ def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
|
|||||||
show_default=True,
|
show_default=True,
|
||||||
)
|
)
|
||||||
@click.option("--device", type=str, default="cuda:0", show_default=True)
|
@click.option("--device", type=str, default="cuda:0", show_default=True)
|
||||||
@click.option("--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", show_default=True)
|
@click.option(
|
||||||
|
"--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", show_default=True
|
||||||
|
)
|
||||||
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
|
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
|
||||||
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
|
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
|
||||||
@click.option("--nats-url", type=str, default=None)
|
@click.option("--nats-url", type=str, default=None)
|
||||||
|
|||||||
@@ -210,6 +210,12 @@ class SilhouetteWindow:
|
|||||||
"""Fill ratio of the buffer (0.0 to 1.0)."""
|
"""Fill ratio of the buffer (0.0 to 1.0)."""
|
||||||
return len(self._buffer) / self.window_size
|
return len(self._buffer) / self.window_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def window_start_frame(self) -> int:
|
||||||
|
if not self._frame_indices:
|
||||||
|
raise ValueError("Window is empty")
|
||||||
|
return int(self._frame_indices[0])
|
||||||
|
|
||||||
|
|
||||||
def _to_numpy(obj: _ArrayLike) -> ndarray:
|
def _to_numpy(obj: _ArrayLike) -> ndarray:
|
||||||
"""Safely convert array-like object to numpy array.
|
"""Safely convert array-like object to numpy array.
|
||||||
|
|||||||
@@ -234,7 +234,7 @@ class TestSelectPerson:
|
|||||||
result = select_person(results)
|
result = select_person(results)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
mask, bbox, tid = result
|
mask, bbox, _, tid = result
|
||||||
assert mask.shape == (100, 100)
|
assert mask.shape == (100, 100)
|
||||||
assert bbox == (10, 10, 50, 90)
|
assert bbox == (10, 10, 50, 90)
|
||||||
assert tid == 42
|
assert tid == 42
|
||||||
@@ -257,7 +257,7 @@ class TestSelectPerson:
|
|||||||
result = select_person(results)
|
result = select_person(results)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
mask, bbox, tid = result
|
mask, bbox, _, tid = result
|
||||||
assert bbox == (0, 0, 30, 30) # Largest box
|
assert bbox == (0, 0, 30, 30) # Largest box
|
||||||
assert tid == 2 # Corresponding track ID
|
assert tid == 2 # Corresponding track ID
|
||||||
|
|
||||||
@@ -327,7 +327,7 @@ class TestSelectPerson:
|
|||||||
result = select_person(results)
|
result = select_person(results)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
_, bbox, tid = result
|
_, bbox, _, tid = result
|
||||||
assert bbox == (10, 10, 50, 90)
|
assert bbox == (10, 10, 50, 90)
|
||||||
assert tid == 1
|
assert tid == 1
|
||||||
|
|
||||||
@@ -341,11 +341,10 @@ class TestSelectPerson:
|
|||||||
result = select_person(results)
|
result = select_person(results)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
mask, _, _ = result
|
mask, _, _, _ = result
|
||||||
# Should be 2D (extracted from expanded 3D)
|
# Should be 2D (extracted from expanded 3D)
|
||||||
assert mask.shape == (100, 100)
|
assert mask.shape == (100, 100)
|
||||||
|
|
||||||
|
|
||||||
def test_select_person_tensor_cpu_inputs(self) -> None:
|
def test_select_person_tensor_cpu_inputs(self) -> None:
|
||||||
"""Tensor-backed inputs (CPU) should work correctly."""
|
"""Tensor-backed inputs (CPU) should work correctly."""
|
||||||
boxes = torch.tensor([[10.0, 10.0, 50.0, 90.0]], dtype=torch.float32)
|
boxes = torch.tensor([[10.0, 10.0, 50.0, 90.0]], dtype=torch.float32)
|
||||||
@@ -356,7 +355,7 @@ class TestSelectPerson:
|
|||||||
result = select_person(results)
|
result = select_person(results)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
mask, bbox, tid = result
|
mask, bbox, _, tid = result
|
||||||
assert mask.shape == (100, 100)
|
assert mask.shape == (100, 100)
|
||||||
assert bbox == (10, 10, 50, 90)
|
assert bbox == (10, 10, 50, 90)
|
||||||
assert tid == 42
|
assert tid == 42
|
||||||
@@ -372,7 +371,7 @@ class TestSelectPerson:
|
|||||||
result = select_person(results)
|
result = select_person(results)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
mask, bbox, tid = result
|
mask, bbox, _, tid = result
|
||||||
assert mask.shape == (100, 100)
|
assert mask.shape == (100, 100)
|
||||||
assert bbox == (10, 10, 50, 90)
|
assert bbox == (10, 10, 50, 90)
|
||||||
assert tid == 42
|
assert tid == 42
|
||||||
@@ -394,6 +393,6 @@ class TestSelectPerson:
|
|||||||
result = select_person(results)
|
result = select_person(results)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
_, bbox, tid = result
|
_, bbox, _, tid = result
|
||||||
assert bbox == (0, 0, 30, 30) # Largest box
|
assert bbox == (0, 0, 30, 30) # Largest box
|
||||||
assert tid == 2 # Corresponding track ID
|
assert tid == 2 # Corresponding track ID
|
||||||
|
|||||||
Reference in New Issue
Block a user