fix(demo): pace gait windows before buffering
Make the OpenGait-studio demo drop unpaced frames before they grow the silhouette window. Separate source-frame gap tracking from paced-frame stride tracking so runtime scheduling matches the documented demo-window-and-stride behavior. Add regressions for paced window growth and schedule-frame stride semantics.
This commit is contained in:
@@ -148,6 +148,7 @@ class ScoliosisPipeline:
|
|||||||
_visualizer: OpenCVVisualizer | None
|
_visualizer: OpenCVVisualizer | None
|
||||||
_last_viz_payload: _VizPayload | None
|
_last_viz_payload: _VizPayload | None
|
||||||
_frame_pacer: _FramePacer | None
|
_frame_pacer: _FramePacer | None
|
||||||
|
_paced_frame_idx: int
|
||||||
_visualizer_accepts_pose_data: bool | None
|
_visualizer_accepts_pose_data: bool | None
|
||||||
_visualizer_signature_owner: object | None
|
_visualizer_signature_owner: object | None
|
||||||
|
|
||||||
@@ -209,6 +210,7 @@ class ScoliosisPipeline:
|
|||||||
self._visualizer = None
|
self._visualizer = None
|
||||||
self._last_viz_payload = None
|
self._last_viz_payload = None
|
||||||
self._frame_pacer = _FramePacer(target_fps) if target_fps is not None else None
|
self._frame_pacer = _FramePacer(target_fps) if target_fps is not None else None
|
||||||
|
self._paced_frame_idx = -1
|
||||||
self._visualizer_accepts_pose_data = None
|
self._visualizer_accepts_pose_data = None
|
||||||
self._visualizer_signature_owner = None
|
self._visualizer_signature_owner = None
|
||||||
|
|
||||||
@@ -393,8 +395,6 @@ class ScoliosisPipeline:
|
|||||||
"confidence": None,
|
"confidence": None,
|
||||||
"pose": pose_data,
|
"pose": pose_data,
|
||||||
}
|
}
|
||||||
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
|
|
||||||
|
|
||||||
if self._frame_pacer is not None and not self._frame_pacer.should_emit(
|
if self._frame_pacer is not None and not self._frame_pacer.should_emit(
|
||||||
timestamp_ns
|
timestamp_ns
|
||||||
):
|
):
|
||||||
@@ -409,6 +409,13 @@ class ScoliosisPipeline:
|
|||||||
"confidence": None,
|
"confidence": None,
|
||||||
"pose": pose_data,
|
"pose": pose_data,
|
||||||
}
|
}
|
||||||
|
self._paced_frame_idx += 1
|
||||||
|
self._window.push(
|
||||||
|
silhouette,
|
||||||
|
frame_idx=frame_idx,
|
||||||
|
track_id=track_id,
|
||||||
|
schedule_frame_idx=self._paced_frame_idx,
|
||||||
|
)
|
||||||
segmentation_input = self._window.buffered_silhouettes
|
segmentation_input = self._window.buffered_silhouettes
|
||||||
|
|
||||||
if not self._window.should_classify():
|
if not self._window.should_classify():
|
||||||
|
|||||||
@@ -67,7 +67,8 @@ class SilhouetteWindow:
|
|||||||
stride: int
|
stride: int
|
||||||
gap_threshold: int
|
gap_threshold: int
|
||||||
_buffer: deque[Float[ndarray, "64 44"]]
|
_buffer: deque[Float[ndarray, "64 44"]]
|
||||||
_frame_indices: deque[int]
|
_source_frame_indices: deque[int]
|
||||||
|
_schedule_frame_indices: deque[int]
|
||||||
_track_id: int | None
|
_track_id: int | None
|
||||||
_last_classified_frame: int
|
_last_classified_frame: int
|
||||||
_frame_count: int
|
_frame_count: int
|
||||||
@@ -91,12 +92,20 @@ class SilhouetteWindow:
|
|||||||
|
|
||||||
# Bounded storage via deque
|
# Bounded storage via deque
|
||||||
self._buffer = deque(maxlen=window_size)
|
self._buffer = deque(maxlen=window_size)
|
||||||
self._frame_indices = deque(maxlen=window_size)
|
self._source_frame_indices = deque(maxlen=window_size)
|
||||||
|
self._schedule_frame_indices = deque(maxlen=window_size)
|
||||||
self._track_id = None
|
self._track_id = None
|
||||||
self._last_classified_frame = -1
|
self._last_classified_frame = -1
|
||||||
self._frame_count = 0
|
self._frame_count = 0
|
||||||
|
|
||||||
def push(self, sil: np.ndarray, frame_idx: int, track_id: int) -> None:
|
def push(
|
||||||
|
self,
|
||||||
|
sil: np.ndarray,
|
||||||
|
frame_idx: int,
|
||||||
|
track_id: int,
|
||||||
|
*,
|
||||||
|
schedule_frame_idx: int | None = None,
|
||||||
|
) -> None:
|
||||||
"""Push a new silhouette into the window.
|
"""Push a new silhouette into the window.
|
||||||
|
|
||||||
Automatically resets buffer on track ID change or frame gap
|
Automatically resets buffer on track ID change or frame gap
|
||||||
@@ -112,8 +121,8 @@ class SilhouetteWindow:
|
|||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
# Check for frame gap
|
# Check for frame gap
|
||||||
if self._frame_indices:
|
if self._source_frame_indices:
|
||||||
last_frame = self._frame_indices[-1]
|
last_frame = self._source_frame_indices[-1]
|
||||||
gap = frame_idx - last_frame
|
gap = frame_idx - last_frame
|
||||||
if gap > self.gap_threshold:
|
if gap > self.gap_threshold:
|
||||||
self.reset()
|
self.reset()
|
||||||
@@ -129,7 +138,10 @@ class SilhouetteWindow:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._buffer.append(sil_array)
|
self._buffer.append(sil_array)
|
||||||
self._frame_indices.append(frame_idx)
|
self._source_frame_indices.append(frame_idx)
|
||||||
|
self._schedule_frame_indices.append(
|
||||||
|
frame_idx if schedule_frame_idx is None else schedule_frame_idx
|
||||||
|
)
|
||||||
self._frame_count += 1
|
self._frame_count += 1
|
||||||
|
|
||||||
def is_ready(self) -> bool:
|
def is_ready(self) -> bool:
|
||||||
@@ -152,7 +164,7 @@ class SilhouetteWindow:
|
|||||||
if self._last_classified_frame < 0:
|
if self._last_classified_frame < 0:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
current_frame = self._frame_indices[-1]
|
current_frame = self._schedule_frame_indices[-1]
|
||||||
frames_since = current_frame - self._last_classified_frame
|
frames_since = current_frame - self._last_classified_frame
|
||||||
return frames_since >= self.stride
|
return frames_since >= self.stride
|
||||||
|
|
||||||
@@ -185,15 +197,16 @@ class SilhouetteWindow:
|
|||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
"""Reset the window, clearing all buffers and counters."""
|
"""Reset the window, clearing all buffers and counters."""
|
||||||
self._buffer.clear()
|
self._buffer.clear()
|
||||||
self._frame_indices.clear()
|
self._source_frame_indices.clear()
|
||||||
|
self._schedule_frame_indices.clear()
|
||||||
self._track_id = None
|
self._track_id = None
|
||||||
self._last_classified_frame = -1
|
self._last_classified_frame = -1
|
||||||
self._frame_count = 0
|
self._frame_count = 0
|
||||||
|
|
||||||
def mark_classified(self) -> None:
|
def mark_classified(self) -> None:
|
||||||
"""Mark current frame as classified, updating stride tracking."""
|
"""Mark current frame as classified, updating stride tracking."""
|
||||||
if self._frame_indices:
|
if self._schedule_frame_indices:
|
||||||
self._last_classified_frame = self._frame_indices[-1]
|
self._last_classified_frame = self._schedule_frame_indices[-1]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_track_id(self) -> int | None:
|
def current_track_id(self) -> int | None:
|
||||||
@@ -212,9 +225,9 @@ class SilhouetteWindow:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def window_start_frame(self) -> int:
|
def window_start_frame(self) -> int:
|
||||||
if not self._frame_indices:
|
if not self._source_frame_indices:
|
||||||
raise ValueError("Window is empty")
|
raise ValueError("Window is empty")
|
||||||
return int(self._frame_indices[0])
|
return int(self._source_frame_indices[0])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def buffered_silhouettes(self) -> Float[ndarray, "n 64 44"]:
|
def buffered_silhouettes(self) -> Float[ndarray, "n 64 44"]:
|
||||||
|
|||||||
@@ -927,6 +927,88 @@ def test_frame_pacer_emission_count_24_to_15() -> None:
|
|||||||
assert 60 <= emitted <= 65
|
assert 60 <= emitted <= 65
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_pacing_skips_window_growth_until_emitted() -> None:
|
||||||
|
from opengait_studio.pipeline import ScoliosisPipeline
|
||||||
|
|
||||||
|
with (
|
||||||
|
mock.patch("opengait_studio.pipeline.YOLO") as mock_yolo,
|
||||||
|
mock.patch("opengait_studio.pipeline.create_source") as mock_source,
|
||||||
|
mock.patch("opengait_studio.pipeline.create_publisher") as mock_publisher,
|
||||||
|
mock.patch("opengait_studio.pipeline.ScoNetDemo") as mock_classifier,
|
||||||
|
mock.patch("opengait_studio.pipeline.select_person") as mock_select_person,
|
||||||
|
mock.patch("opengait_studio.pipeline.mask_to_silhouette") as mock_mask_to_sil,
|
||||||
|
):
|
||||||
|
mock_detector = mock.MagicMock()
|
||||||
|
mock_box = mock.MagicMock()
|
||||||
|
mock_box.xyxy = np.array([[100, 100, 200, 300]], dtype=np.float32)
|
||||||
|
mock_box.id = np.array([1], dtype=np.int64)
|
||||||
|
mock_mask = mock.MagicMock()
|
||||||
|
mock_mask.data = np.random.rand(1, 480, 640).astype(np.float32)
|
||||||
|
mock_result = mock.MagicMock()
|
||||||
|
mock_result.boxes = mock_box
|
||||||
|
mock_result.masks = mock_mask
|
||||||
|
mock_detector.track.return_value = [mock_result]
|
||||||
|
mock_yolo.return_value = mock_detector
|
||||||
|
mock_source.return_value = []
|
||||||
|
mock_publisher.return_value = mock.MagicMock()
|
||||||
|
|
||||||
|
mock_model = mock.MagicMock()
|
||||||
|
mock_model.predict.return_value = ("neutral", 0.7)
|
||||||
|
mock_classifier.return_value = mock_model
|
||||||
|
|
||||||
|
dummy_mask = np.random.randint(0, 256, (480, 640), dtype=np.uint8)
|
||||||
|
dummy_bbox_mask = (100, 100, 200, 300)
|
||||||
|
dummy_bbox_frame = (100, 100, 200, 300)
|
||||||
|
dummy_silhouette = np.random.rand(64, 44).astype(np.float32)
|
||||||
|
mock_select_person.return_value = (
|
||||||
|
dummy_mask,
|
||||||
|
dummy_bbox_mask,
|
||||||
|
dummy_bbox_frame,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
mock_mask_to_sil.return_value = dummy_silhouette
|
||||||
|
|
||||||
|
pipeline = ScoliosisPipeline(
|
||||||
|
source="dummy.mp4",
|
||||||
|
checkpoint="dummy.pt",
|
||||||
|
config=str(CONFIG_PATH) if CONFIG_PATH.exists() else "dummy.yaml",
|
||||||
|
device="cpu",
|
||||||
|
yolo_model="dummy.pt",
|
||||||
|
window=2,
|
||||||
|
stride=1,
|
||||||
|
nats_url=None,
|
||||||
|
nats_subject="test",
|
||||||
|
max_frames=None,
|
||||||
|
target_fps=15.0,
|
||||||
|
)
|
||||||
|
frame = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||||
|
|
||||||
|
first = pipeline.process_frame(
|
||||||
|
frame,
|
||||||
|
{"frame_count": 0, "timestamp_ns": 1_000_000_000},
|
||||||
|
)
|
||||||
|
second = pipeline.process_frame(
|
||||||
|
frame,
|
||||||
|
{"frame_count": 1, "timestamp_ns": 1_033_000_000},
|
||||||
|
)
|
||||||
|
third = pipeline.process_frame(
|
||||||
|
frame,
|
||||||
|
{"frame_count": 2, "timestamp_ns": 1_067_000_000},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert first is not None
|
||||||
|
assert second is not None
|
||||||
|
assert third is not None
|
||||||
|
assert first["segmentation_input"] is not None
|
||||||
|
assert second["segmentation_input"] is not None
|
||||||
|
assert third["segmentation_input"] is not None
|
||||||
|
assert first["segmentation_input"].shape[0] == 1
|
||||||
|
assert second["segmentation_input"].shape[0] == 1
|
||||||
|
assert second["label"] is None
|
||||||
|
assert third["segmentation_input"].shape[0] == 2
|
||||||
|
assert third["label"] == "neutral"
|
||||||
|
|
||||||
|
|
||||||
def test_frame_pacer_requires_positive_target_fps() -> None:
|
def test_frame_pacer_requires_positive_target_fps() -> None:
|
||||||
from opengait_studio.pipeline import _FramePacer
|
from opengait_studio.pipeline import _FramePacer
|
||||||
|
|
||||||
|
|||||||
@@ -144,6 +144,21 @@ class TestSilhouetteWindow:
|
|||||||
window.push(sil, frame_idx=7, track_id=1)
|
window.push(sil, frame_idx=7, track_id=1)
|
||||||
assert window.should_classify()
|
assert window.should_classify()
|
||||||
|
|
||||||
|
def test_should_classify_uses_schedule_frame_idx(self) -> None:
|
||||||
|
"""Stride should be measured in paced/scheduled frames, not source frames."""
|
||||||
|
window = SilhouetteWindow(window_size=2, stride=2, gap_threshold=50)
|
||||||
|
sil = np.ones((64, 44), dtype=np.float32)
|
||||||
|
|
||||||
|
window.push(sil, frame_idx=0, track_id=1, schedule_frame_idx=0)
|
||||||
|
window.push(sil, frame_idx=10, track_id=1, schedule_frame_idx=1)
|
||||||
|
|
||||||
|
assert window.should_classify()
|
||||||
|
|
||||||
|
window.mark_classified()
|
||||||
|
window.push(sil, frame_idx=20, track_id=1, schedule_frame_idx=2)
|
||||||
|
|
||||||
|
assert not window.should_classify()
|
||||||
|
|
||||||
def test_should_classify_not_ready(self) -> None:
|
def test_should_classify_not_ready(self) -> None:
|
||||||
"""should_classify should return False when window not ready."""
|
"""should_classify should return False when window not ready."""
|
||||||
window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
|
window = SilhouetteWindow(window_size=5, stride=1, gap_threshold=10)
|
||||||
|
|||||||
Reference in New Issue
Block a user