fix(detection): preserve offline frames in the runner

Add an explicit source delivery policy to the detection pipeline so offline and realtime sources can be handled differently without splitting the runner.

Blocking sources now backpressure ingestion until their pending frame is drained, which preserves every offline video frame even when inference is slower than decode. Latest-only sources keep the previous overwrite behavior for realtime feeds such as cvmmap.

The tests now cover both policies: offline sources retain ordered frame delivery under slow inference, while latest-only sources still drop intermediate frames as intended.
This commit is contained in:
2026-03-27 12:02:27 +08:00
parent 061d5b4592
commit 481f6160ce
7 changed files with 137 additions and 11 deletions
+8 -1
View File
@@ -3,11 +3,18 @@ from typing import Protocol
import numpy as np import numpy as np
from pose_tracking_exp.schema.detection import BoxDetections, PoseBatchRequest, PoseDetections, SourceFrame from pose_tracking_exp.schema.detection import (
BoxDetections,
PoseBatchRequest,
PoseDetections,
SourceDeliveryPolicy,
SourceFrame,
)
class FrameSource(Protocol): class FrameSource(Protocol):
source_name: str source_name: str
delivery_policy: SourceDeliveryPolicy
def frames(self) -> AsyncIterator[SourceFrame]: def frames(self) -> AsyncIterator[SourceFrame]:
... ...
+19 -2
View File
@@ -7,7 +7,7 @@ from loguru import logger
from pose_tracking_exp.detection.config import DetectionRunnerConfig from pose_tracking_exp.detection.config import DetectionRunnerConfig
from pose_tracking_exp.detection.protocols import FrameSource, PoseShim, PoseSink from pose_tracking_exp.detection.protocols import FrameSource, PoseShim, PoseSink
from pose_tracking_exp.schema.detection import SourceFrame from pose_tracking_exp.schema.detection import SourceDeliveryPolicy, SourceFrame
PERFORMANCE_WINDOW = 60 PERFORMANCE_WINDOW = 60
@@ -21,6 +21,7 @@ class PendingFrame:
@dataclass(slots=True) @dataclass(slots=True)
class SourceSlot: class SourceSlot:
source_name: str source_name: str
delivery_policy: SourceDeliveryPolicy
pending_frame: PendingFrame | None = None pending_frame: PendingFrame | None = None
last_seen_frame_index: int | None = None last_seen_frame_index: int | None = None
received_frames: int = 0 received_frames: int = 0
@@ -31,6 +32,8 @@ class SourceSlot:
def store_latest_frame(slot: SourceSlot, frame: SourceFrame) -> None: def store_latest_frame(slot: SourceSlot, frame: SourceFrame) -> None:
if slot.delivery_policy != "latest_only":
raise ValueError("store_latest_frame is only valid for latest_only sources.")
slot.received_frames += 1 slot.received_frames += 1
if slot.pending_frame is not None: if slot.pending_frame is not None:
slot.dropped_frames += 1 slot.dropped_frames += 1
@@ -92,7 +95,11 @@ async def run_detection_runner(
batch_size_sma = SimpleMovingAverage(PERFORMANCE_WINDOW) batch_size_sma = SimpleMovingAverage(PERFORMANCE_WINDOW)
scheduler_condition = anyio.Condition() scheduler_condition = anyio.Condition()
slots = { slots = {
source.source_name: SourceSlot(source_name=source.source_name) for source in sources source.source_name: SourceSlot(
source_name=source.source_name,
delivery_policy=source.delivery_policy,
)
for source in sources
} }
inference_limiter = anyio.CapacityLimiter(1) inference_limiter = anyio.CapacityLimiter(1)
@@ -111,6 +118,15 @@ async def run_detection_runner(
previous_frame_index = slot.last_seen_frame_index previous_frame_index = slot.last_seen_frame_index
should_log_init = previous_frame_index is None should_log_init = previous_frame_index is None
slot.last_seen_frame_index = frame.frame_index slot.last_seen_frame_index = frame.frame_index
if slot.delivery_policy == "block":
while slot.pending_frame is not None:
await scheduler_condition.wait()
slot.received_frames += 1
slot.pending_frame = PendingFrame(
source_name=slot.source_name,
frame=frame,
)
else:
store_latest_frame(slot, frame) store_latest_frame(slot, frame)
scheduler_condition.notify_all() scheduler_condition.notify_all()
@@ -156,6 +172,7 @@ async def run_detection_runner(
await scheduler_condition.wait() await scheduler_condition.wait()
batch = take_pending_batch(slots, config.max_batch_frames) batch = take_pending_batch(slots, config.max_batch_frames)
scheduler_condition.notify_all()
start = perf_counter() start = perf_counter()
pose_infos = await to_thread_run_sync( pose_infos = await to_thread_run_sync(
@@ -3,11 +3,12 @@ from typing import Protocol
from anyio.to_thread import run_sync as to_thread_run_sync from anyio.to_thread import run_sync as to_thread_run_sync
from pose_tracking_exp.schema.detection import SourceFrame from pose_tracking_exp.schema.detection import SourceDeliveryPolicy, SourceFrame
class BlockingFrameProducer(Protocol): class BlockingFrameProducer(Protocol):
source_name: str source_name: str
delivery_policy: SourceDeliveryPolicy
def iter_frames(self) -> Iterator[SourceFrame]: def iter_frames(self) -> Iterator[SourceFrame]:
... ...
@@ -21,9 +22,11 @@ class IteratorFrameSource:
def __init__( def __init__(
self, self,
source_name: str, source_name: str,
delivery_policy: SourceDeliveryPolicy,
iterator_factory: Callable[[], Iterator[SourceFrame]], iterator_factory: Callable[[], Iterator[SourceFrame]],
) -> None: ) -> None:
self.source_name = source_name self.source_name = source_name
self.delivery_policy = delivery_policy
self._iterator_factory = iterator_factory self._iterator_factory = iterator_factory
async def frames(self) -> AsyncIterator[SourceFrame]: async def frames(self) -> AsyncIterator[SourceFrame]:
@@ -43,5 +46,6 @@ class IteratorFrameSource:
def wrap_blocking_source(producer: BlockingFrameProducer) -> IteratorFrameSource: def wrap_blocking_source(producer: BlockingFrameProducer) -> IteratorFrameSource:
return IteratorFrameSource( return IteratorFrameSource(
source_name=producer.source_name, source_name=producer.source_name,
delivery_policy=producer.delivery_policy,
iterator_factory=producer.iter_frames, iterator_factory=producer.iter_frames,
) )
@@ -2,10 +2,12 @@ from collections.abc import AsyncIterator
import numpy as np import numpy as np
from pose_tracking_exp.schema.detection import SourceFrame from pose_tracking_exp.schema.detection import SourceDeliveryPolicy, SourceFrame
class CvmmapFrameSource: class CvmmapFrameSource:
delivery_policy: SourceDeliveryPolicy = "latest_only"
def __init__(self, source_name: str) -> None: def __init__(self, source_name: str) -> None:
self.source_name = source_name self.source_name = source_name
@@ -6,7 +6,7 @@ import cv2
import numpy as np import numpy as np
from pose_tracking_exp.detection.sources.adapters import wrap_blocking_source from pose_tracking_exp.detection.sources.adapters import wrap_blocking_source
from pose_tracking_exp.schema.detection import SourceFrame from pose_tracking_exp.schema.detection import SourceDeliveryPolicy, SourceFrame
_DEFAULT_VIDEO_FPS = 30.0 _DEFAULT_VIDEO_FPS = 30.0
@@ -33,6 +33,8 @@ def parse_video_input_specs(specs: Sequence[str]) -> tuple[tuple[str, Path], ...
class VideoFrameSource: class VideoFrameSource:
delivery_policy: SourceDeliveryPolicy = "block"
def __init__( def __init__(
self, self,
video_path: Path, video_path: Path,
@@ -15,6 +15,7 @@ from typing import Literal
import numpy as np import numpy as np
CocoKeypointSchema = Literal["coco17", "coco_wholebody133"] CocoKeypointSchema = Literal["coco17", "coco_wholebody133"]
SourceDeliveryPolicy = Literal["block", "latest_only"]
def expected_keypoint_count(schema: CocoKeypointSchema) -> int: def expected_keypoint_count(schema: CocoKeypointSchema) -> int:
+97 -4
View File
@@ -1,5 +1,7 @@
from collections.abc import AsyncIterator, Sequence from collections.abc import AsyncIterator, Sequence
from pathlib import Path from pathlib import Path
import time
from typing import cast
import anyio import anyio
import numpy as np import numpy as np
@@ -17,7 +19,8 @@ from pose_tracking_exp.detection.runner import (
store_latest_frame, store_latest_frame,
take_pending_batch, take_pending_batch,
) )
from pose_tracking_exp.schema.detection import PoseDetections, SourceFrame from pose_tracking_exp.detection.protocols import FrameSource
from pose_tracking_exp.schema.detection import PoseDetections, SourceDeliveryPolicy, SourceFrame
def test_load_detection_runner_config_from_toml_and_env( def test_load_detection_runner_config_from_toml_and_env(
@@ -61,7 +64,7 @@ def test_resolve_instances_falls_back_to_config_values() -> None:
def test_store_latest_frame_overwrites_pending_frame() -> None: def test_store_latest_frame_overwrites_pending_frame() -> None:
slot = SourceSlot(source_name="front_left") slot = SourceSlot(source_name="front_left", delivery_policy="latest_only")
first = SourceFrame( first = SourceFrame(
source_name="front_left", source_name="front_left",
image_bgr=np.zeros((1, 1, 3), dtype=np.uint8), image_bgr=np.zeros((1, 1, 3), dtype=np.uint8),
@@ -88,6 +91,7 @@ def test_take_pending_batch_collects_at_most_one_frame_per_source() -> None:
slots = { slots = {
"front_left": SourceSlot( "front_left": SourceSlot(
source_name="front_left", source_name="front_left",
delivery_policy="latest_only",
pending_frame=PendingFrame( pending_frame=PendingFrame(
source_name="front_left", source_name="front_left",
frame=SourceFrame( frame=SourceFrame(
@@ -100,6 +104,7 @@ def test_take_pending_batch_collects_at_most_one_frame_per_source() -> None:
), ),
"front_right": SourceSlot( "front_right": SourceSlot(
source_name="front_right", source_name="front_right",
delivery_policy="latest_only",
pending_frame=PendingFrame( pending_frame=PendingFrame(
source_name="front_right", source_name="front_right",
frame=SourceFrame( frame=SourceFrame(
@@ -112,6 +117,7 @@ def test_take_pending_batch_collects_at_most_one_frame_per_source() -> None:
), ),
"rear": SourceSlot( "rear": SourceSlot(
source_name="rear", source_name="rear",
delivery_policy="latest_only",
pending_frame=PendingFrame( pending_frame=PendingFrame(
source_name="rear", source_name="rear",
frame=SourceFrame( frame=SourceFrame(
@@ -133,9 +139,16 @@ def test_take_pending_batch_collects_at_most_one_frame_per_source() -> None:
class StubSource: class StubSource:
def __init__(self, source_name: str, frames: tuple[SourceFrame, ...]) -> None: def __init__(
self,
source_name: str,
frames: tuple[SourceFrame, ...],
*,
delivery_policy: SourceDeliveryPolicy = "latest_only",
) -> None:
self.source_name = source_name self.source_name = source_name
self._frames = frames self._frames = frames
self.delivery_policy = delivery_policy
async def frames(self) -> AsyncIterator[SourceFrame]: async def frames(self) -> AsyncIterator[SourceFrame]:
for frame in self._frames: for frame in self._frames:
@@ -143,7 +156,12 @@ class StubSource:
class StubPoseShim: class StubPoseShim:
def __init__(self, delay_seconds: float = 0.0) -> None:
self._delay_seconds = delay_seconds
def process_many(self, frames: Sequence[SourceFrame]) -> list[PoseDetections]: def process_many(self, frames: Sequence[SourceFrame]) -> list[PoseDetections]:
if self._delay_seconds > 0.0:
time.sleep(self._delay_seconds)
detections: list[PoseDetections] = [] detections: list[PoseDetections] = []
for frame in frames: for frame in frames:
detections.append( detections.append(
@@ -187,6 +205,7 @@ def test_run_detection_runner_publishes_payloads() -> None:
timestamp_unix_ns=100, timestamp_unix_ns=100,
), ),
), ),
delivery_policy="block",
), ),
StubSource( StubSource(
"cam1", "cam1",
@@ -198,6 +217,7 @@ def test_run_detection_runner_publishes_payloads() -> None:
timestamp_unix_ns=200, timestamp_unix_ns=200,
), ),
), ),
delivery_policy="block",
), ),
) )
config = DetectionRunnerConfig( config = DetectionRunnerConfig(
@@ -210,7 +230,7 @@ def test_run_detection_runner_publishes_payloads() -> None:
anyio.run( anyio.run(
run_detection_runner, run_detection_runner,
sources, cast(tuple[FrameSource, ...], sources),
StubPoseShim(), StubPoseShim(),
sink, sink,
config, config,
@@ -221,3 +241,76 @@ def test_run_detection_runner_publishes_payloads() -> None:
("cam0", 1, 100), ("cam0", 1, 100),
("cam1", 2, 200), ("cam1", 2, 200),
] ]
def test_run_detection_runner_blocks_to_preserve_offline_frames() -> None:
sink = StubSink()
source = StubSource(
"cam0",
tuple(
SourceFrame(
source_name="cam0",
image_bgr=np.zeros((2, 3, 3), dtype=np.uint8),
frame_index=frame_index,
timestamp_unix_ns=frame_index * 100,
)
for frame_index in range(3)
),
delivery_policy="block",
)
config = DetectionRunnerConfig(
instances=("cam0",),
pose_config_path=Path(__file__),
yolo_checkpoint=Path(__file__),
pose_checkpoint=Path(__file__),
max_batch_frames=1,
max_batch_wait_ms=0,
)
anyio.run(
run_detection_runner,
cast(tuple[FrameSource, ...], (source,)),
StubPoseShim(delay_seconds=0.01),
sink,
config,
)
assert [item.frame_index for item in sink.messages] == [0, 1, 2]
def test_run_detection_runner_drops_intermediate_latest_only_frames() -> None:
sink = StubSink()
source = StubSource(
"cam0",
tuple(
SourceFrame(
source_name="cam0",
image_bgr=np.zeros((2, 3, 3), dtype=np.uint8),
frame_index=frame_index,
timestamp_unix_ns=frame_index * 100,
)
for frame_index in range(3)
),
delivery_policy="latest_only",
)
config = DetectionRunnerConfig(
instances=("cam0",),
pose_config_path=Path(__file__),
yolo_checkpoint=Path(__file__),
pose_checkpoint=Path(__file__),
max_batch_frames=1,
max_batch_wait_ms=0,
)
anyio.run(
run_detection_runner,
cast(tuple[FrameSource, ...], (source,)),
StubPoseShim(delay_seconds=0.01),
sink,
config,
)
processed = [item.frame_index for item in sink.messages]
assert processed[-1] == 2
assert processed != [0, 1, 2]
assert 0 not in processed