diff --git a/src/pose_tracking_exp/detection/protocols.py b/src/pose_tracking_exp/detection/protocols.py index c9a3eac..1078b39 100644 --- a/src/pose_tracking_exp/detection/protocols.py +++ b/src/pose_tracking_exp/detection/protocols.py @@ -3,11 +3,18 @@ from typing import Protocol import numpy as np -from pose_tracking_exp.schema.detection import BoxDetections, PoseBatchRequest, PoseDetections, SourceFrame +from pose_tracking_exp.schema.detection import ( + BoxDetections, + PoseBatchRequest, + PoseDetections, + SourceDeliveryPolicy, + SourceFrame, +) class FrameSource(Protocol): source_name: str + delivery_policy: SourceDeliveryPolicy def frames(self) -> AsyncIterator[SourceFrame]: ... diff --git a/src/pose_tracking_exp/detection/runner.py b/src/pose_tracking_exp/detection/runner.py index 66752dd..9548f82 100644 --- a/src/pose_tracking_exp/detection/runner.py +++ b/src/pose_tracking_exp/detection/runner.py @@ -7,7 +7,7 @@ from loguru import logger from pose_tracking_exp.detection.config import DetectionRunnerConfig from pose_tracking_exp.detection.protocols import FrameSource, PoseShim, PoseSink -from pose_tracking_exp.schema.detection import SourceFrame +from pose_tracking_exp.schema.detection import SourceDeliveryPolicy, SourceFrame PERFORMANCE_WINDOW = 60 @@ -21,6 +21,7 @@ class PendingFrame: @dataclass(slots=True) class SourceSlot: source_name: str + delivery_policy: SourceDeliveryPolicy pending_frame: PendingFrame | None = None last_seen_frame_index: int | None = None received_frames: int = 0 @@ -31,6 +32,8 @@ class SourceSlot: def store_latest_frame(slot: SourceSlot, frame: SourceFrame) -> None: + if slot.delivery_policy != "latest_only": + raise ValueError("store_latest_frame is only valid for latest_only sources.") slot.received_frames += 1 if slot.pending_frame is not None: slot.dropped_frames += 1 @@ -92,7 +95,11 @@ async def run_detection_runner( batch_size_sma = SimpleMovingAverage(PERFORMANCE_WINDOW) scheduler_condition = anyio.Condition() slots = { - source.source_name: SourceSlot(source_name=source.source_name) for source in sources + source.source_name: SourceSlot( + source_name=source.source_name, + delivery_policy=source.delivery_policy, + ) + for source in sources } inference_limiter = anyio.CapacityLimiter(1) @@ -111,7 +118,16 @@ async def run_detection_runner( previous_frame_index = slot.last_seen_frame_index should_log_init = previous_frame_index is None slot.last_seen_frame_index = frame.frame_index - store_latest_frame(slot, frame) + if slot.delivery_policy == "block": + while slot.pending_frame is not None: + await scheduler_condition.wait() + slot.received_frames += 1 + slot.pending_frame = PendingFrame( + source_name=slot.source_name, + frame=frame, + ) + else: + store_latest_frame(slot, frame) scheduler_condition.notify_all() if should_log_init: @@ -156,6 +172,7 @@ async def run_detection_runner( await scheduler_condition.wait() batch = take_pending_batch(slots, config.max_batch_frames) + scheduler_condition.notify_all() start = perf_counter() pose_infos = await to_thread_run_sync( diff --git a/src/pose_tracking_exp/detection/sources/adapters.py b/src/pose_tracking_exp/detection/sources/adapters.py index 9c5c746..0caa2d4 100644 --- a/src/pose_tracking_exp/detection/sources/adapters.py +++ b/src/pose_tracking_exp/detection/sources/adapters.py @@ -3,11 +3,12 @@ from typing import Protocol from anyio.to_thread import run_sync as to_thread_run_sync -from pose_tracking_exp.schema.detection import SourceFrame +from pose_tracking_exp.schema.detection import SourceDeliveryPolicy, SourceFrame class BlockingFrameProducer(Protocol): source_name: str + delivery_policy: SourceDeliveryPolicy def iter_frames(self) -> Iterator[SourceFrame]: ... @@ -21,9 +22,11 @@ class IteratorFrameSource: def __init__( self, source_name: str, + delivery_policy: SourceDeliveryPolicy, iterator_factory: Callable[[], Iterator[SourceFrame]], ) -> None: self.source_name = source_name + self.delivery_policy = delivery_policy self._iterator_factory = iterator_factory async def frames(self) -> AsyncIterator[SourceFrame]: @@ -43,5 +46,6 @@ class IteratorFrameSource: def wrap_blocking_source(producer: BlockingFrameProducer) -> IteratorFrameSource: return IteratorFrameSource( source_name=producer.source_name, + delivery_policy=producer.delivery_policy, iterator_factory=producer.iter_frames, ) diff --git a/src/pose_tracking_exp/detection/sources/cvmmap.py b/src/pose_tracking_exp/detection/sources/cvmmap.py index 2fa81c7..7b67b11 100644 --- a/src/pose_tracking_exp/detection/sources/cvmmap.py +++ b/src/pose_tracking_exp/detection/sources/cvmmap.py @@ -2,10 +2,12 @@ from collections.abc import AsyncIterator import numpy as np -from pose_tracking_exp.schema.detection import SourceFrame +from pose_tracking_exp.schema.detection import SourceDeliveryPolicy, SourceFrame class CvmmapFrameSource: + delivery_policy: SourceDeliveryPolicy = "latest_only" + def __init__(self, source_name: str) -> None: self.source_name = source_name diff --git a/src/pose_tracking_exp/detection/sources/video.py b/src/pose_tracking_exp/detection/sources/video.py index 53d8753..eadac17 100644 --- a/src/pose_tracking_exp/detection/sources/video.py +++ b/src/pose_tracking_exp/detection/sources/video.py @@ -6,7 +6,7 @@ import cv2 import numpy as np from pose_tracking_exp.detection.sources.adapters import wrap_blocking_source -from pose_tracking_exp.schema.detection import SourceFrame +from pose_tracking_exp.schema.detection import SourceDeliveryPolicy, SourceFrame _DEFAULT_VIDEO_FPS = 30.0 @@ -33,6 +33,8 @@ def parse_video_input_specs(specs: Sequence[str]) -> tuple[tuple[str, Path], ... class VideoFrameSource: + delivery_policy: SourceDeliveryPolicy = "block" + def __init__( self, video_path: Path, diff --git a/src/pose_tracking_exp/schema/detection.py b/src/pose_tracking_exp/schema/detection.py index 7a8bee2..a230343 100644 --- a/src/pose_tracking_exp/schema/detection.py +++ b/src/pose_tracking_exp/schema/detection.py @@ -15,6 +15,7 @@ from typing import Literal import numpy as np CocoKeypointSchema = Literal["coco17", "coco_wholebody133"] +SourceDeliveryPolicy = Literal["block", "latest_only"] def expected_keypoint_count(schema: CocoKeypointSchema) -> int: diff --git a/tests/test_detection_runner.py b/tests/test_detection_runner.py index 37d1e26..9582a6e 100644 --- a/tests/test_detection_runner.py +++ b/tests/test_detection_runner.py @@ -1,5 +1,7 @@ from collections.abc import AsyncIterator, Sequence from pathlib import Path +import time +from typing import cast import anyio import numpy as np @@ -17,7 +19,8 @@ from pose_tracking_exp.detection.runner import ( store_latest_frame, take_pending_batch, ) -from pose_tracking_exp.schema.detection import PoseDetections, SourceFrame +from pose_tracking_exp.detection.protocols import FrameSource +from pose_tracking_exp.schema.detection import PoseDetections, SourceDeliveryPolicy, SourceFrame def test_load_detection_runner_config_from_toml_and_env( @@ -61,7 +64,7 @@ def test_resolve_instances_falls_back_to_config_values() -> None: def test_store_latest_frame_overwrites_pending_frame() -> None: - slot = SourceSlot(source_name="front_left") + slot = SourceSlot(source_name="front_left", delivery_policy="latest_only") first = SourceFrame( source_name="front_left", image_bgr=np.zeros((1, 1, 3), dtype=np.uint8), @@ -88,6 +91,7 @@ def test_take_pending_batch_collects_at_most_one_frame_per_source() -> None: slots = { "front_left": SourceSlot( source_name="front_left", + delivery_policy="latest_only", pending_frame=PendingFrame( source_name="front_left", frame=SourceFrame( @@ -100,6 +104,7 @@ def test_take_pending_batch_collects_at_most_one_frame_per_source() -> None: ), "front_right": SourceSlot( source_name="front_right", + delivery_policy="latest_only", pending_frame=PendingFrame( source_name="front_right", frame=SourceFrame( @@ -112,6 +117,7 @@ def test_take_pending_batch_collects_at_most_one_frame_per_source() -> None: ), "rear": SourceSlot( source_name="rear", + delivery_policy="latest_only", pending_frame=PendingFrame( source_name="rear", frame=SourceFrame( @@ -133,9 +139,16 @@ def test_take_pending_batch_collects_at_most_one_frame_per_source() -> None: class StubSource: - def __init__(self, source_name: str, frames: tuple[SourceFrame, ...]) -> None: + def __init__( + self, + source_name: str, + frames: tuple[SourceFrame, ...], + *, + delivery_policy: SourceDeliveryPolicy = "latest_only", + ) -> None: self.source_name = source_name self._frames = frames + self.delivery_policy = delivery_policy async def frames(self) -> AsyncIterator[SourceFrame]: for frame in self._frames: @@ -143,7 +156,12 @@ class StubSource: class StubPoseShim: + def __init__(self, delay_seconds: float = 0.0) -> None: + self._delay_seconds = delay_seconds + def process_many(self, frames: Sequence[SourceFrame]) -> list[PoseDetections]: + if self._delay_seconds > 0.0: + time.sleep(self._delay_seconds) detections: list[PoseDetections] = [] for frame in frames: detections.append( @@ -187,6 +205,7 @@ def test_run_detection_runner_publishes_payloads() -> None: timestamp_unix_ns=100, ), ), + delivery_policy="block", ), StubSource( "cam1", @@ -198,6 +217,7 @@ def test_run_detection_runner_publishes_payloads() -> None: timestamp_unix_ns=200, ), ), + delivery_policy="block", ), ) config = DetectionRunnerConfig( @@ -210,7 +230,7 @@ def test_run_detection_runner_publishes_payloads() -> None: anyio.run( run_detection_runner, - sources, + cast(tuple[FrameSource, ...], sources), StubPoseShim(), sink, config, @@ -221,3 +241,76 @@ def test_run_detection_runner_publishes_payloads() -> None: ("cam0", 1, 100), ("cam1", 2, 200), ] + + +def test_run_detection_runner_blocks_to_preserve_offline_frames() -> None: + sink = StubSink() + source = StubSource( + "cam0", + tuple( + SourceFrame( + source_name="cam0", + image_bgr=np.zeros((2, 3, 3), dtype=np.uint8), + frame_index=frame_index, + timestamp_unix_ns=frame_index * 100, + ) + for frame_index in range(3) + ), + delivery_policy="block", + ) + config = DetectionRunnerConfig( + instances=("cam0",), + pose_config_path=Path(__file__), + yolo_checkpoint=Path(__file__), + pose_checkpoint=Path(__file__), + max_batch_frames=1, + max_batch_wait_ms=0, + ) + + anyio.run( + run_detection_runner, + cast(tuple[FrameSource, ...], (source,)), + StubPoseShim(delay_seconds=0.01), + sink, + config, + ) + + assert [item.frame_index for item in sink.messages] == [0, 1, 2] + + +def test_run_detection_runner_drops_intermediate_latest_only_frames() -> None: + sink = StubSink() + source = StubSource( + "cam0", + tuple( + SourceFrame( + source_name="cam0", + image_bgr=np.zeros((2, 3, 3), dtype=np.uint8), + frame_index=frame_index, + timestamp_unix_ns=frame_index * 100, + ) + for frame_index in range(3) + ), + delivery_policy="latest_only", + ) + config = DetectionRunnerConfig( + instances=("cam0",), + pose_config_path=Path(__file__), + yolo_checkpoint=Path(__file__), + pose_checkpoint=Path(__file__), + max_batch_frames=1, + max_batch_wait_ms=0, + ) + + anyio.run( + run_detection_runner, + cast(tuple[FrameSource, ...], (source,)), + StubPoseShim(delay_seconds=0.01), + sink, + config, + ) + + processed = [item.frame_index for item in sink.messages] + assert processed[-1] == 2 + assert processed != [0, 1, 2] + assert 0 not in processed