fix(detection): preserve offline frames in the runner

Add an explicit source delivery policy to the detection pipeline so offline and realtime sources can be handled differently without splitting the runner.

Blocking sources now backpressure ingestion until their pending frame is drained, which preserves every offline video frame even when inference is slower than decode. Latest-only sources keep the previous overwrite behavior for realtime feeds such as cvmmap.

The tests now cover both policies: offline sources retain ordered frame delivery under slow inference, while latest-only sources still drop intermediate frames as intended.
This commit is contained in:
2026-03-27 12:02:27 +08:00
parent 061d5b4592
commit 481f6160ce
7 changed files with 137 additions and 11 deletions
+8 -1
View File
@@ -3,11 +3,18 @@ from typing import Protocol
import numpy as np
from pose_tracking_exp.schema.detection import BoxDetections, PoseBatchRequest, PoseDetections, SourceFrame
from pose_tracking_exp.schema.detection import (
BoxDetections,
PoseBatchRequest,
PoseDetections,
SourceDeliveryPolicy,
SourceFrame,
)
class FrameSource(Protocol):
source_name: str
delivery_policy: SourceDeliveryPolicy
def frames(self) -> AsyncIterator[SourceFrame]:
...
+20 -3
View File
@@ -7,7 +7,7 @@ from loguru import logger
from pose_tracking_exp.detection.config import DetectionRunnerConfig
from pose_tracking_exp.detection.protocols import FrameSource, PoseShim, PoseSink
from pose_tracking_exp.schema.detection import SourceFrame
from pose_tracking_exp.schema.detection import SourceDeliveryPolicy, SourceFrame
PERFORMANCE_WINDOW = 60
@@ -21,6 +21,7 @@ class PendingFrame:
@dataclass(slots=True)
class SourceSlot:
source_name: str
delivery_policy: SourceDeliveryPolicy
pending_frame: PendingFrame | None = None
last_seen_frame_index: int | None = None
received_frames: int = 0
@@ -31,6 +32,8 @@ class SourceSlot:
def store_latest_frame(slot: SourceSlot, frame: SourceFrame) -> None:
if slot.delivery_policy != "latest_only":
raise ValueError("store_latest_frame is only valid for latest_only sources.")
slot.received_frames += 1
if slot.pending_frame is not None:
slot.dropped_frames += 1
@@ -92,7 +95,11 @@ async def run_detection_runner(
batch_size_sma = SimpleMovingAverage(PERFORMANCE_WINDOW)
scheduler_condition = anyio.Condition()
slots = {
source.source_name: SourceSlot(source_name=source.source_name) for source in sources
source.source_name: SourceSlot(
source_name=source.source_name,
delivery_policy=source.delivery_policy,
)
for source in sources
}
inference_limiter = anyio.CapacityLimiter(1)
@@ -111,7 +118,16 @@ async def run_detection_runner(
previous_frame_index = slot.last_seen_frame_index
should_log_init = previous_frame_index is None
slot.last_seen_frame_index = frame.frame_index
store_latest_frame(slot, frame)
if slot.delivery_policy == "block":
while slot.pending_frame is not None:
await scheduler_condition.wait()
slot.received_frames += 1
slot.pending_frame = PendingFrame(
source_name=slot.source_name,
frame=frame,
)
else:
store_latest_frame(slot, frame)
scheduler_condition.notify_all()
if should_log_init:
@@ -156,6 +172,7 @@ async def run_detection_runner(
await scheduler_condition.wait()
batch = take_pending_batch(slots, config.max_batch_frames)
scheduler_condition.notify_all()
start = perf_counter()
pose_infos = await to_thread_run_sync(
@@ -3,11 +3,12 @@ from typing import Protocol
from anyio.to_thread import run_sync as to_thread_run_sync
from pose_tracking_exp.schema.detection import SourceFrame
from pose_tracking_exp.schema.detection import SourceDeliveryPolicy, SourceFrame
class BlockingFrameProducer(Protocol):
source_name: str
delivery_policy: SourceDeliveryPolicy
def iter_frames(self) -> Iterator[SourceFrame]:
...
@@ -21,9 +22,11 @@ class IteratorFrameSource:
def __init__(
self,
source_name: str,
delivery_policy: SourceDeliveryPolicy,
iterator_factory: Callable[[], Iterator[SourceFrame]],
) -> None:
self.source_name = source_name
self.delivery_policy = delivery_policy
self._iterator_factory = iterator_factory
async def frames(self) -> AsyncIterator[SourceFrame]:
@@ -43,5 +46,6 @@ class IteratorFrameSource:
def wrap_blocking_source(producer: BlockingFrameProducer) -> IteratorFrameSource:
return IteratorFrameSource(
source_name=producer.source_name,
delivery_policy=producer.delivery_policy,
iterator_factory=producer.iter_frames,
)
@@ -2,10 +2,12 @@ from collections.abc import AsyncIterator
import numpy as np
from pose_tracking_exp.schema.detection import SourceFrame
from pose_tracking_exp.schema.detection import SourceDeliveryPolicy, SourceFrame
class CvmmapFrameSource:
delivery_policy: SourceDeliveryPolicy = "latest_only"
def __init__(self, source_name: str) -> None:
self.source_name = source_name
@@ -6,7 +6,7 @@ import cv2
import numpy as np
from pose_tracking_exp.detection.sources.adapters import wrap_blocking_source
from pose_tracking_exp.schema.detection import SourceFrame
from pose_tracking_exp.schema.detection import SourceDeliveryPolicy, SourceFrame
_DEFAULT_VIDEO_FPS = 30.0
@@ -33,6 +33,8 @@ def parse_video_input_specs(specs: Sequence[str]) -> tuple[tuple[str, Path], ...
class VideoFrameSource:
delivery_policy: SourceDeliveryPolicy = "block"
def __init__(
self,
video_path: Path,
@@ -15,6 +15,7 @@ from typing import Literal
import numpy as np
CocoKeypointSchema = Literal["coco17", "coco_wholebody133"]
SourceDeliveryPolicy = Literal["block", "latest_only"]
def expected_keypoint_count(schema: CocoKeypointSchema) -> int:
+97 -4
View File
@@ -1,5 +1,7 @@
from collections.abc import AsyncIterator, Sequence
from pathlib import Path
import time
from typing import cast
import anyio
import numpy as np
@@ -17,7 +19,8 @@ from pose_tracking_exp.detection.runner import (
store_latest_frame,
take_pending_batch,
)
from pose_tracking_exp.schema.detection import PoseDetections, SourceFrame
from pose_tracking_exp.detection.protocols import FrameSource
from pose_tracking_exp.schema.detection import PoseDetections, SourceDeliveryPolicy, SourceFrame
def test_load_detection_runner_config_from_toml_and_env(
@@ -61,7 +64,7 @@ def test_resolve_instances_falls_back_to_config_values() -> None:
def test_store_latest_frame_overwrites_pending_frame() -> None:
slot = SourceSlot(source_name="front_left")
slot = SourceSlot(source_name="front_left", delivery_policy="latest_only")
first = SourceFrame(
source_name="front_left",
image_bgr=np.zeros((1, 1, 3), dtype=np.uint8),
@@ -88,6 +91,7 @@ def test_take_pending_batch_collects_at_most_one_frame_per_source() -> None:
slots = {
"front_left": SourceSlot(
source_name="front_left",
delivery_policy="latest_only",
pending_frame=PendingFrame(
source_name="front_left",
frame=SourceFrame(
@@ -100,6 +104,7 @@ def test_take_pending_batch_collects_at_most_one_frame_per_source() -> None:
),
"front_right": SourceSlot(
source_name="front_right",
delivery_policy="latest_only",
pending_frame=PendingFrame(
source_name="front_right",
frame=SourceFrame(
@@ -112,6 +117,7 @@ def test_take_pending_batch_collects_at_most_one_frame_per_source() -> None:
),
"rear": SourceSlot(
source_name="rear",
delivery_policy="latest_only",
pending_frame=PendingFrame(
source_name="rear",
frame=SourceFrame(
@@ -133,9 +139,16 @@ def test_take_pending_batch_collects_at_most_one_frame_per_source() -> None:
class StubSource:
def __init__(self, source_name: str, frames: tuple[SourceFrame, ...]) -> None:
def __init__(
self,
source_name: str,
frames: tuple[SourceFrame, ...],
*,
delivery_policy: SourceDeliveryPolicy = "latest_only",
) -> None:
self.source_name = source_name
self._frames = frames
self.delivery_policy = delivery_policy
async def frames(self) -> AsyncIterator[SourceFrame]:
for frame in self._frames:
@@ -143,7 +156,12 @@ class StubSource:
class StubPoseShim:
def __init__(self, delay_seconds: float = 0.0) -> None:
self._delay_seconds = delay_seconds
def process_many(self, frames: Sequence[SourceFrame]) -> list[PoseDetections]:
if self._delay_seconds > 0.0:
time.sleep(self._delay_seconds)
detections: list[PoseDetections] = []
for frame in frames:
detections.append(
@@ -187,6 +205,7 @@ def test_run_detection_runner_publishes_payloads() -> None:
timestamp_unix_ns=100,
),
),
delivery_policy="block",
),
StubSource(
"cam1",
@@ -198,6 +217,7 @@ def test_run_detection_runner_publishes_payloads() -> None:
timestamp_unix_ns=200,
),
),
delivery_policy="block",
),
)
config = DetectionRunnerConfig(
@@ -210,7 +230,7 @@ def test_run_detection_runner_publishes_payloads() -> None:
anyio.run(
run_detection_runner,
sources,
cast(tuple[FrameSource, ...], sources),
StubPoseShim(),
sink,
config,
@@ -221,3 +241,76 @@ def test_run_detection_runner_publishes_payloads() -> None:
("cam0", 1, 100),
("cam1", 2, 200),
]
def test_run_detection_runner_blocks_to_preserve_offline_frames() -> None:
sink = StubSink()
source = StubSource(
"cam0",
tuple(
SourceFrame(
source_name="cam0",
image_bgr=np.zeros((2, 3, 3), dtype=np.uint8),
frame_index=frame_index,
timestamp_unix_ns=frame_index * 100,
)
for frame_index in range(3)
),
delivery_policy="block",
)
config = DetectionRunnerConfig(
instances=("cam0",),
pose_config_path=Path(__file__),
yolo_checkpoint=Path(__file__),
pose_checkpoint=Path(__file__),
max_batch_frames=1,
max_batch_wait_ms=0,
)
anyio.run(
run_detection_runner,
cast(tuple[FrameSource, ...], (source,)),
StubPoseShim(delay_seconds=0.01),
sink,
config,
)
assert [item.frame_index for item in sink.messages] == [0, 1, 2]
def test_run_detection_runner_drops_intermediate_latest_only_frames() -> None:
sink = StubSink()
source = StubSource(
"cam0",
tuple(
SourceFrame(
source_name="cam0",
image_bgr=np.zeros((2, 3, 3), dtype=np.uint8),
frame_index=frame_index,
timestamp_unix_ns=frame_index * 100,
)
for frame_index in range(3)
),
delivery_policy="latest_only",
)
config = DetectionRunnerConfig(
instances=("cam0",),
pose_config_path=Path(__file__),
yolo_checkpoint=Path(__file__),
pose_checkpoint=Path(__file__),
max_batch_frames=1,
max_batch_wait_ms=0,
)
anyio.run(
run_detection_runner,
cast(tuple[FrameSource, ...], (source,)),
StubPoseShim(delay_seconds=0.01),
sink,
config,
)
processed = [item.frame_index for item in sink.messages]
assert processed[-1] == 2
assert processed != [0, 1, 2]
assert 0 not in processed