481f6160ce
Add an explicit source delivery policy to the detection pipeline so offline and realtime sources can be handled differently without splitting the runner. Blocking sources now backpressure ingestion until their pending frame is drained, which preserves every offline video frame even when inference is slower than decode. Latest-only sources keep the previous overwrite behavior for realtime feeds such as cvmmap. The tests now cover both policies: offline sources retain ordered frame delivery under slow inference, while latest-only sources still drop intermediate frames as intended.
317 lines
9.6 KiB
Python
317 lines
9.6 KiB
Python
from collections.abc import AsyncIterator, Sequence
|
|
from pathlib import Path
|
|
import time
|
|
from typing import cast
|
|
|
|
import anyio
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from pose_tracking_exp.detection.config import (
|
|
DetectionRunnerConfig,
|
|
load_detection_runner_config,
|
|
resolve_instances,
|
|
)
|
|
from pose_tracking_exp.detection.runner import (
|
|
PendingFrame,
|
|
SourceSlot,
|
|
run_detection_runner,
|
|
store_latest_frame,
|
|
take_pending_batch,
|
|
)
|
|
from pose_tracking_exp.detection.protocols import FrameSource
|
|
from pose_tracking_exp.schema.detection import PoseDetections, SourceDeliveryPolicy, SourceFrame
|
|
|
|
|
|
def test_load_detection_runner_config_from_toml_and_env(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """Values come from the TOML file, except the device, which the env var overrides."""
    toml_lines = (
        'instances = ["front_left", "front_right"]',
        'device = "cuda:1"',
        'nats_host = "nats://localhost:4222"',
        'yolo_checkpoint = "checkpoint/yolo/yolo11_mix_epoch10.pt"',
        'pose_checkpoint = "checkpoint/dwpose/best_coco-wholebody_AP_epoch_50.pth"',
        "bbox_area_threshold = 2500",
        "max_batch_frames = 6",
        "max_batch_wait_ms = 3",
    )
    config_path = tmp_path / "runner.toml"
    config_path.write_text("\n".join(toml_lines), encoding="utf-8")

    # The file says "cuda:1"; the environment is expected to take precedence.
    monkeypatch.setenv("POSE_TRACKING_EXP_DETECTION_DEVICE", "cpu")
    loaded = load_detection_runner_config(config_path)

    assert loaded.instances == ("front_left", "front_right")
    assert loaded.device == "cpu"
    assert loaded.nats_host == "nats://localhost:4222"
    assert loaded.bbox_area_threshold == 2500
    assert loaded.max_batch_frames == 6
    assert loaded.max_batch_wait_ms == 3
|
|
|
|
|
|
def test_resolve_instances_prefers_cli_values() -> None:
    """When the first (CLI) tuple is non-empty it is returned as-is."""
    cli_instances = ("cli_a", "cli_b")
    config_instances = ("cfg_a",)

    resolved = resolve_instances(cli_instances, config_instances)

    assert resolved == ("cli_a", "cli_b")
|
|
|
|
|
|
def test_resolve_instances_falls_back_to_config_values() -> None:
    """When the first (CLI) tuple is empty the configured tuple is returned."""
    cli_instances: tuple[str, ...] = ()
    config_instances = ("cfg_a", "cfg_b")

    resolved = resolve_instances(cli_instances, config_instances)

    assert resolved == ("cfg_a", "cfg_b")
|
|
|
|
|
|
def test_store_latest_frame_overwrites_pending_frame() -> None:
    """Storing twice on a latest-only slot keeps the newer frame and counts one drop."""
    older = SourceFrame(
        source_name="front_left",
        image_bgr=np.zeros((1, 1, 3), dtype=np.uint8),
        frame_index=1,
        timestamp_unix_ns=100,
    )
    newer = SourceFrame(
        source_name="front_left",
        image_bgr=np.ones((1, 1, 3), dtype=np.uint8),
        frame_index=2,
        timestamp_unix_ns=200,
    )
    slot = SourceSlot(source_name="front_left", delivery_policy="latest_only")

    # Store in arrival order; the second call lands while the first is still pending.
    for arriving in (older, newer):
        store_latest_frame(slot, arriving)

    assert slot.received_frames == 2
    assert slot.dropped_frames == 1
    pending = slot.pending_frame
    assert pending is not None
    assert pending.frame is newer
|
|
|
|
|
|
def test_take_pending_batch_collects_at_most_one_frame_per_source() -> None:
    """A batch capped at two frames drains the first two slots and leaves the third."""
    # (source name, frame index, timestamp) for each slot, in dict insertion order.
    pending_specs = (
        ("front_left", 11, 110),
        ("front_right", 22, 220),
        ("rear", 33, 330),
    )
    slots = {
        name: SourceSlot(
            source_name=name,
            delivery_policy="latest_only",
            pending_frame=PendingFrame(
                source_name=name,
                frame=SourceFrame(
                    source_name=name,
                    image_bgr=np.zeros((1, 1, 3), dtype=np.uint8),
                    frame_index=frame_index,
                    timestamp_unix_ns=timestamp_unix_ns,
                ),
            ),
        )
        for name, frame_index, timestamp_unix_ns in pending_specs
    }

    batch = take_pending_batch(slots, max_batch_frames=2)

    batched_sources = [entry.source_name for entry in batch]
    assert batched_sources == ["front_left", "front_right"]
    # Slots that contributed to the batch are emptied; the remaining one is untouched.
    assert slots["front_left"].pending_frame is None
    assert slots["front_right"].pending_frame is None
    assert slots["rear"].pending_frame is not None
|
|
|
|
|
|
class StubSource:
    """In-memory frame source that replays a fixed tuple of frames."""

    def __init__(
        self,
        source_name: str,
        frames: tuple[SourceFrame, ...],
        *,
        delivery_policy: SourceDeliveryPolicy = "latest_only",
    ) -> None:
        """Remember the source name, the frames to replay, and the delivery policy."""
        self.delivery_policy = delivery_policy
        self.source_name = source_name
        self._frames = frames

    async def frames(self) -> AsyncIterator[SourceFrame]:
        """Yield the stored frames one at a time, in the order they were given."""
        for queued_frame in self._frames:
            yield queued_frame
|
|
|
|
|
|
class StubPoseShim:
    """Pose-shim stand-in that returns one fixed detection per input frame."""

    def __init__(self, delay_seconds: float = 0.0) -> None:
        """Store the optional blocking delay applied on every ``process_many`` call."""
        self._delay_seconds = delay_seconds

    def process_many(self, frames: Sequence[SourceFrame]) -> list[PoseDetections]:
        """Return a synthetic single-box detection for each frame, in input order.

        Each result copies the frame's source name, index and timestamp, and
        reports ``source_size`` as ``(shape[1], shape[0])`` of the frame image.
        """
        pause = self._delay_seconds
        if pause > 0.0:
            # Blocking sleep, used by the tests to imitate slow inference.
            time.sleep(pause)
        return [
            PoseDetections(
                source_name=item.source_name,
                frame_index=item.frame_index,
                source_size=(item.image_bgr.shape[1], item.image_bgr.shape[0]),
                boxes_xyxy=np.asarray([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
                box_scores=np.asarray([1.0], dtype=np.float32),
                keypoints_xy=np.zeros((1, 133, 2), dtype=np.float32),
                keypoint_scores=np.ones((1, 133), dtype=np.float32),
                timestamp_unix_ns=item.timestamp_unix_ns,
                keypoint_schema="coco_wholebody133",
            )
            for item in frames
        ]
|
|
|
|
|
|
class StubSink:
    """Sink stand-in that records published detections and whether it was closed."""

    def __init__(self) -> None:
        """Start with no recorded messages and the closed flag cleared."""
        self.closed = False
        self.messages: list[PoseDetections] = []

    async def publish_pose(self, detections: PoseDetections) -> None:
        """Append the detections to the recorded message list."""
        self.messages.append(detections)

    async def aclose(self) -> None:
        """Set the closed flag so tests can check that shutdown happened."""
        self.closed = True
|
|
|
|
|
|
def test_run_detection_runner_publishes_payloads() -> None:
    """Two single-frame blocking sources each produce one published payload."""
    # (source name, frame index, timestamp) for the single frame of each source.
    frame_specs = (("cam0", 1, 100), ("cam1", 2, 200))
    sources = tuple(
        StubSource(
            name,
            (
                SourceFrame(
                    source_name=name,
                    image_bgr=np.zeros((2, 3, 3), dtype=np.uint8),
                    frame_index=frame_index,
                    timestamp_unix_ns=timestamp_unix_ns,
                ),
            ),
            delivery_policy="block",
        )
        for name, frame_index, timestamp_unix_ns in frame_specs
    )
    # The path fields just need some Path value here; this file's own path is used.
    placeholder_path = Path(__file__)
    config = DetectionRunnerConfig(
        instances=("cam0", "cam1"),
        pose_config_path=placeholder_path,
        yolo_checkpoint=placeholder_path,
        pose_checkpoint=placeholder_path,
        max_batch_frames=2,
    )
    sink = StubSink()

    anyio.run(
        run_detection_runner,
        cast(tuple[FrameSource, ...], sources),
        StubPoseShim(),
        sink,
        config,
    )

    assert sink.closed is True
    published = [
        (message.source_name, message.frame_index, message.timestamp_unix_ns)
        for message in sink.messages
    ]
    assert published == [
        ("cam0", 1, 100),
        ("cam1", 2, 200),
    ]
|
|
|
|
|
|
def test_run_detection_runner_blocks_to_preserve_offline_frames() -> None:
    """A blocking source gets every frame published, in order, despite slow inference."""
    offline_frames = tuple(
        SourceFrame(
            source_name="cam0",
            image_bgr=np.zeros((2, 3, 3), dtype=np.uint8),
            frame_index=index,
            timestamp_unix_ns=index * 100,
        )
        for index in range(3)
    )
    source = StubSource("cam0", offline_frames, delivery_policy="block")
    placeholder_path = Path(__file__)
    config = DetectionRunnerConfig(
        instances=("cam0",),
        pose_config_path=placeholder_path,
        yolo_checkpoint=placeholder_path,
        pose_checkpoint=placeholder_path,
        max_batch_frames=1,
        max_batch_wait_ms=0,
    )
    sink = StubSink()
    # The shim sleeps on every batch so processing is slower than frame production.
    slow_shim = StubPoseShim(delay_seconds=0.01)

    anyio.run(
        run_detection_runner,
        cast(tuple[FrameSource, ...], (source,)),
        slow_shim,
        sink,
        config,
    )

    delivered = [message.frame_index for message in sink.messages]
    assert delivered == [0, 1, 2]
|
|
|
|
|
|
def test_run_detection_runner_drops_intermediate_latest_only_frames() -> None:
    """A latest-only source ends on the newest frame and skips earlier ones."""
    realtime_frames = tuple(
        SourceFrame(
            source_name="cam0",
            image_bgr=np.zeros((2, 3, 3), dtype=np.uint8),
            frame_index=index,
            timestamp_unix_ns=index * 100,
        )
        for index in range(3)
    )
    source = StubSource("cam0", realtime_frames, delivery_policy="latest_only")
    placeholder_path = Path(__file__)
    config = DetectionRunnerConfig(
        instances=("cam0",),
        pose_config_path=placeholder_path,
        yolo_checkpoint=placeholder_path,
        pose_checkpoint=placeholder_path,
        max_batch_frames=1,
        max_batch_wait_ms=0,
    )
    sink = StubSink()
    # The shim sleeps on every batch so processing is slower than frame production.
    slow_shim = StubPoseShim(delay_seconds=0.01)

    anyio.run(
        run_detection_runner,
        cast(tuple[FrameSource, ...], (source,)),
        slow_shim,
        sink,
        config,
    )

    processed = [message.frame_index for message in sink.messages]
    assert processed[-1] == 2
    assert processed != [0, 1, 2]
    # NOTE(review): presumes frame 0 is overwritten before the runner takes its
    # first batch — confirm against the runner's task scheduling.
    assert 0 not in processed
|