From 481f6160ce2315dffc48c828162ea99ba5d2e45b Mon Sep 17 00:00:00 2001 From: crosstyan Date: Fri, 27 Mar 2026 12:02:27 +0800 Subject: [PATCH] fix(detection): preserve offline frames in the runner Add an explicit source delivery policy to the detection pipeline so offline and realtime sources can be handled differently without splitting the runner. Blocking sources now backpressure ingestion until their pending frame is drained, which preserves every offline video frame even when inference is slower than decode. Latest-only sources keep the previous overwrite behavior for realtime feeds such as cvmmap. The tests now cover both policies: offline sources retain ordered frame delivery under slow inference, while latest-only sources still drop intermediate frames as intended. --- src/pose_tracking_exp/detection/protocols.py | 9 +- src/pose_tracking_exp/detection/runner.py | 23 +++- .../detection/sources/adapters.py | 6 +- .../detection/sources/cvmmap.py | 4 +- .../detection/sources/video.py | 4 +- src/pose_tracking_exp/schema/detection.py | 1 + tests/test_detection_runner.py | 101 +++++++++++++++++- 7 files changed, 137 insertions(+), 11 deletions(-) diff --git a/src/pose_tracking_exp/detection/protocols.py b/src/pose_tracking_exp/detection/protocols.py index c9a3eac..1078b39 100644 --- a/src/pose_tracking_exp/detection/protocols.py +++ b/src/pose_tracking_exp/detection/protocols.py @@ -3,11 +3,18 @@ from typing import Protocol import numpy as np -from pose_tracking_exp.schema.detection import BoxDetections, PoseBatchRequest, PoseDetections, SourceFrame +from pose_tracking_exp.schema.detection import ( + BoxDetections, + PoseBatchRequest, + PoseDetections, + SourceDeliveryPolicy, + SourceFrame, +) class FrameSource(Protocol): source_name: str + delivery_policy: SourceDeliveryPolicy def frames(self) -> AsyncIterator[SourceFrame]: ... diff --git a/src/pose_tracking_exp/detection/runner.py b/src/pose_tracking_exp/detection/runner.py index 66752dd..9548f82 100644 --- a/src/pose_tracking_exp/detection/runner.py +++ b/src/pose_tracking_exp/detection/runner.py @@ -7,7 +7,7 @@ from loguru import logger from pose_tracking_exp.detection.config import DetectionRunnerConfig from pose_tracking_exp.detection.protocols import FrameSource, PoseShim, PoseSink -from pose_tracking_exp.schema.detection import SourceFrame +from pose_tracking_exp.schema.detection import SourceDeliveryPolicy, SourceFrame PERFORMANCE_WINDOW = 60 @@ -21,6 +21,7 @@ class PendingFrame: @dataclass(slots=True) class SourceSlot: source_name: str + delivery_policy: SourceDeliveryPolicy pending_frame: PendingFrame | None = None last_seen_frame_index: int | None = None received_frames: int = 0 @@ -31,6 +32,8 @@ class SourceSlot: def store_latest_frame(slot: SourceSlot, frame: SourceFrame) -> None: + if slot.delivery_policy != "latest_only": + raise ValueError("store_latest_frame is only valid for latest_only sources.") slot.received_frames += 1 if slot.pending_frame is not None: slot.dropped_frames += 1 @@ -92,7 +95,11 @@ async def run_detection_runner( batch_size_sma = SimpleMovingAverage(PERFORMANCE_WINDOW) scheduler_condition = anyio.Condition() slots = { - source.source_name: SourceSlot(source_name=source.source_name) for source in sources + source.source_name: SourceSlot( + source_name=source.source_name, + delivery_policy=source.delivery_policy, + ) + for source in sources } inference_limiter = anyio.CapacityLimiter(1) @@ -111,7 +118,16 @@ async def run_detection_runner( previous_frame_index = slot.last_seen_frame_index should_log_init = previous_frame_index is None slot.last_seen_frame_index = frame.frame_index - store_latest_frame(slot, frame) + if slot.delivery_policy == "block": + while slot.pending_frame is not None: + await scheduler_condition.wait() + slot.received_frames += 1 + slot.pending_frame = PendingFrame( + source_name=slot.source_name, + frame=frame, + ) + else: + store_latest_frame(slot, frame) scheduler_condition.notify_all() if should_log_init: @@ -156,6 +172,7 @@ async def run_detection_runner( await scheduler_condition.wait() batch = take_pending_batch(slots, config.max_batch_frames) + scheduler_condition.notify_all() start = perf_counter() pose_infos = await to_thread_run_sync( diff --git a/src/pose_tracking_exp/detection/sources/adapters.py b/src/pose_tracking_exp/detection/sources/adapters.py index 9c5c746..0caa2d4 100644 --- a/src/pose_tracking_exp/detection/sources/adapters.py +++ b/src/pose_tracking_exp/detection/sources/adapters.py @@ -3,11 +3,12 @@ from typing import Protocol from anyio.to_thread import run_sync as to_thread_run_sync -from pose_tracking_exp.schema.detection import SourceFrame +from pose_tracking_exp.schema.detection import SourceDeliveryPolicy, SourceFrame class BlockingFrameProducer(Protocol): source_name: str + delivery_policy: SourceDeliveryPolicy def iter_frames(self) -> Iterator[SourceFrame]: ... @@ -21,9 +22,11 @@ class IteratorFrameSource: def __init__( self, source_name: str, + delivery_policy: SourceDeliveryPolicy, iterator_factory: Callable[[], Iterator[SourceFrame]], ) -> None: self.source_name = source_name + self.delivery_policy = delivery_policy self._iterator_factory = iterator_factory async def frames(self) -> AsyncIterator[SourceFrame]: @@ -43,5 +46,6 @@ class IteratorFrameSource: def wrap_blocking_source(producer: BlockingFrameProducer) -> IteratorFrameSource: return IteratorFrameSource( source_name=producer.source_name, + delivery_policy=producer.delivery_policy, iterator_factory=producer.iter_frames, ) diff --git a/src/pose_tracking_exp/detection/sources/cvmmap.py b/src/pose_tracking_exp/detection/sources/cvmmap.py index 2fa81c7..7b67b11 100644 --- a/src/pose_tracking_exp/detection/sources/cvmmap.py +++ b/src/pose_tracking_exp/detection/sources/cvmmap.py @@ -2,10 +2,12 @@ from collections.abc import AsyncIterator import numpy as np -from pose_tracking_exp.schema.detection import SourceFrame +from pose_tracking_exp.schema.detection import SourceDeliveryPolicy, SourceFrame class CvmmapFrameSource: + delivery_policy: SourceDeliveryPolicy = "latest_only" + def __init__(self, source_name: str) -> None: self.source_name = source_name diff --git a/src/pose_tracking_exp/detection/sources/video.py b/src/pose_tracking_exp/detection/sources/video.py index 53d8753..eadac17 100644 --- a/src/pose_tracking_exp/detection/sources/video.py +++ b/src/pose_tracking_exp/detection/sources/video.py @@ -6,7 +6,7 @@ import cv2 import numpy as np from pose_tracking_exp.detection.sources.adapters import wrap_blocking_source -from pose_tracking_exp.schema.detection import SourceFrame +from pose_tracking_exp.schema.detection import SourceDeliveryPolicy, SourceFrame _DEFAULT_VIDEO_FPS = 30.0 @@ -33,6 +33,8 @@ def parse_video_input_specs(specs: Sequence[str]) -> tuple[tuple[str, Path], ... class VideoFrameSource: + delivery_policy: SourceDeliveryPolicy = "block" + def __init__( self, video_path: Path, diff --git a/src/pose_tracking_exp/schema/detection.py b/src/pose_tracking_exp/schema/detection.py index 7a8bee2..a230343 100644 --- a/src/pose_tracking_exp/schema/detection.py +++ b/src/pose_tracking_exp/schema/detection.py @@ -15,6 +15,7 @@ from typing import Literal import numpy as np CocoKeypointSchema = Literal["coco17", "coco_wholebody133"] +SourceDeliveryPolicy = Literal["block", "latest_only"] def expected_keypoint_count(schema: CocoKeypointSchema) -> int: diff --git a/tests/test_detection_runner.py b/tests/test_detection_runner.py index 37d1e26..9582a6e 100644 --- a/tests/test_detection_runner.py +++ b/tests/test_detection_runner.py @@ -1,5 +1,7 @@ from collections.abc import AsyncIterator, Sequence from pathlib import Path +import time +from typing import cast import anyio import numpy as np @@ -17,7 +19,8 @@ from pose_tracking_exp.detection.runner import ( store_latest_frame, take_pending_batch, ) -from pose_tracking_exp.schema.detection import PoseDetections, SourceFrame +from pose_tracking_exp.detection.protocols import FrameSource +from pose_tracking_exp.schema.detection import PoseDetections, SourceDeliveryPolicy, SourceFrame def test_load_detection_runner_config_from_toml_and_env( @@ -61,7 +64,7 @@ def test_resolve_instances_falls_back_to_config_values() -> None: def test_store_latest_frame_overwrites_pending_frame() -> None: - slot = SourceSlot(source_name="front_left") + slot = SourceSlot(source_name="front_left", delivery_policy="latest_only") first = SourceFrame( source_name="front_left", image_bgr=np.zeros((1, 1, 3), dtype=np.uint8), @@ -88,6 +91,7 @@ def test_take_pending_batch_collects_at_most_one_frame_per_source() -> None: slots = { "front_left": SourceSlot( source_name="front_left", + delivery_policy="latest_only", pending_frame=PendingFrame( source_name="front_left", frame=SourceFrame( @@ -100,6 +104,7 @@ def test_take_pending_batch_collects_at_most_one_frame_per_source() -> None: ), "front_right": SourceSlot( source_name="front_right", + delivery_policy="latest_only", pending_frame=PendingFrame( source_name="front_right", frame=SourceFrame( @@ -112,6 +117,7 @@ def test_take_pending_batch_collects_at_most_one_frame_per_source() -> None: ), "rear": SourceSlot( source_name="rear", + delivery_policy="latest_only", pending_frame=PendingFrame( source_name="rear", frame=SourceFrame( @@ -133,9 +139,16 @@ def test_take_pending_batch_collects_at_most_one_frame_per_source() -> None: class StubSource: - def __init__(self, source_name: str, frames: tuple[SourceFrame, ...]) -> None: + def __init__( + self, + source_name: str, + frames: tuple[SourceFrame, ...], + *, + delivery_policy: SourceDeliveryPolicy = "latest_only", + ) -> None: self.source_name = source_name self._frames = frames + self.delivery_policy = delivery_policy async def frames(self) -> AsyncIterator[SourceFrame]: for frame in self._frames: @@ -143,7 +156,12 @@ class StubSource: class StubPoseShim: + def __init__(self, delay_seconds: float = 0.0) -> None: + self._delay_seconds = delay_seconds + def process_many(self, frames: Sequence[SourceFrame]) -> list[PoseDetections]: + if self._delay_seconds > 0.0: + time.sleep(self._delay_seconds) detections: list[PoseDetections] = [] for frame in frames: detections.append( @@ -187,6 +205,7 @@ def test_run_detection_runner_publishes_payloads() -> None: timestamp_unix_ns=100, ), ), + delivery_policy="block", ), StubSource( "cam1", @@ -198,6 +217,7 @@ def test_run_detection_runner_publishes_payloads() -> None: timestamp_unix_ns=200, ), ), + delivery_policy="block", ), ) config = DetectionRunnerConfig( @@ -210,7 +230,7 @@ def test_run_detection_runner_publishes_payloads() -> None: anyio.run( run_detection_runner, - sources, + cast(tuple[FrameSource, ...], sources), StubPoseShim(), sink, config, @@ -221,3 +241,76 @@ def test_run_detection_runner_publishes_payloads() -> None: ("cam0", 1, 100), ("cam1", 2, 200), ] + + +def test_run_detection_runner_blocks_to_preserve_offline_frames() -> None: + sink = StubSink() + source = StubSource( + "cam0", + tuple( + SourceFrame( + source_name="cam0", + image_bgr=np.zeros((2, 3, 3), dtype=np.uint8), + frame_index=frame_index, + timestamp_unix_ns=frame_index * 100, + ) + for frame_index in range(3) + ), + delivery_policy="block", + ) + config = DetectionRunnerConfig( + instances=("cam0",), + pose_config_path=Path(__file__), + yolo_checkpoint=Path(__file__), + pose_checkpoint=Path(__file__), + max_batch_frames=1, + max_batch_wait_ms=0, + ) + + anyio.run( + run_detection_runner, + cast(tuple[FrameSource, ...], (source,)), + StubPoseShim(delay_seconds=0.01), + sink, + config, + ) + + assert [item.frame_index for item in sink.messages] == [0, 1, 2] + + +def test_run_detection_runner_drops_intermediate_latest_only_frames() -> None: + sink = StubSink() + source = StubSource( + "cam0", + tuple( + SourceFrame( + source_name="cam0", + image_bgr=np.zeros((2, 3, 3), dtype=np.uint8), + frame_index=frame_index, + timestamp_unix_ns=frame_index * 100, + ) + for frame_index in range(3) + ), + delivery_policy="latest_only", + ) + config = DetectionRunnerConfig( + instances=("cam0",), + pose_config_path=Path(__file__), + yolo_checkpoint=Path(__file__), + pose_checkpoint=Path(__file__), + max_batch_frames=1, + max_batch_wait_ms=0, + ) + + anyio.run( + run_detection_runner, + cast(tuple[FrameSource, ...], (source,)), + StubPoseShim(delay_seconds=0.01), + sink, + config, + ) + + processed = [item.frame_index for item in sink.messages] + assert processed[-1] == 2 + assert processed != [0, 1, 2] + assert 0 not in processed