"""Tests for detection-runner config loading, slot helpers, and the runner loop.

The runner tests drive ``run_detection_runner`` with in-memory stubs for the
frame source, the pose shim, and the sink.
"""

from collections.abc import AsyncIterator, Sequence
from pathlib import Path
import time
from typing import cast

import anyio
import numpy as np
import pytest

from pose_tracking_exp.detection.config import (
    DetectionRunnerConfig,
    load_detection_runner_config,
    resolve_instances,
)
from pose_tracking_exp.detection.runner import (
    PendingFrame,
    SourceSlot,
    run_detection_runner,
    store_latest_frame,
    take_pending_batch,
)
from pose_tracking_exp.detection.protocols import FrameSource
from pose_tracking_exp.schema.detection import PoseDetections, SourceDeliveryPolicy, SourceFrame


def _blank_frame(
    source_name: str,
    frame_index: int,
    timestamp_unix_ns: int,
    *,
    shape: tuple[int, int, int],
) -> SourceFrame:
    """Build a ``SourceFrame`` whose image is an all-zero ``uint8`` array of *shape*."""
    return SourceFrame(
        source_name=source_name,
        image_bgr=np.zeros(shape, dtype=np.uint8),
        frame_index=frame_index,
        timestamp_unix_ns=timestamp_unix_ns,
    )


def test_load_detection_runner_config_from_toml_and_env(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """Values come from the TOML file, except ``device`` which the env var overrides."""
    toml_lines = [
        'instances = ["front_left", "front_right"]',
        'device = "cuda:1"',
        'nats_host = "nats://localhost:4222"',
        'yolo_checkpoint = "checkpoint/yolo/yolo11_mix_epoch10.pt"',
        'pose_checkpoint = "checkpoint/dwpose/best_coco-wholebody_AP_epoch_50.pth"',
        "bbox_area_threshold = 2500",
        "max_batch_frames = 6",
        "max_batch_wait_ms = 3",
    ]
    config_path = tmp_path / "runner.toml"
    config_path.write_text("\n".join(toml_lines), encoding="utf-8")

    # The file says "cuda:1"; the environment variable is expected to win.
    monkeypatch.setenv("POSE_TRACKING_EXP_DETECTION_DEVICE", "cpu")

    loaded = load_detection_runner_config(config_path)

    assert loaded.instances == ("front_left", "front_right")
    assert loaded.device == "cpu"
    assert loaded.nats_host == "nats://localhost:4222"
    assert loaded.bbox_area_threshold == 2500
    assert loaded.max_batch_frames == 6
    assert loaded.max_batch_wait_ms == 3


def test_resolve_instances_prefers_cli_values() -> None:
    """Non-empty CLI instances are returned instead of the config instances."""
    cli_instances = ("cli_a", "cli_b")
    resolved = resolve_instances(cli_instances, ("cfg_a",))
    assert resolved == ("cli_a", "cli_b")


def test_resolve_instances_falls_back_to_config_values() -> None:
    """An empty CLI tuple yields the config instances."""
    config_instances = ("cfg_a", "cfg_b")
    resolved = resolve_instances((), config_instances)
    assert resolved == ("cfg_a", "cfg_b")


def test_store_latest_frame_overwrites_pending_frame() -> None:
    """Storing twice keeps only the second frame and counts the first as dropped."""
    slot = SourceSlot(source_name="front_left", delivery_policy="latest_only")
    first = _blank_frame("front_left", 1, 100, shape=(1, 1, 3))
    second = SourceFrame(
        source_name="front_left",
        image_bgr=np.ones((1, 1, 3), dtype=np.uint8),
        frame_index=2,
        timestamp_unix_ns=200,
    )

    for incoming in (first, second):
        store_latest_frame(slot, incoming)

    assert slot.received_frames == 2
    assert slot.dropped_frames == 1
    assert slot.pending_frame is not None
    assert slot.pending_frame.frame is second


def test_take_pending_batch_collects_at_most_one_frame_per_source() -> None:
    """With three pending slots and a limit of two, the third slot stays pending."""
    # (source name, frame index, timestamp) in the insertion order of the slots dict.
    pending_specs = (
        ("front_left", 11, 110),
        ("front_right", 22, 220),
        ("rear", 33, 330),
    )
    slots = {
        name: SourceSlot(
            source_name=name,
            delivery_policy="latest_only",
            pending_frame=PendingFrame(
                source_name=name,
                frame=_blank_frame(name, index, stamp, shape=(1, 1, 3)),
            ),
        )
        for name, index, stamp in pending_specs
    }

    batch = take_pending_batch(slots, max_batch_frames=2)

    taken_names = [entry.source_name for entry in batch]
    assert taken_names == ["front_left", "front_right"]
    assert slots["front_left"].pending_frame is None
    assert slots["front_right"].pending_frame is None
    assert slots["rear"].pending_frame is not None


class StubSource:
    """In-memory frame source that replays a fixed tuple of frames."""

    def __init__(
        self,
        source_name: str,
        frames: tuple[SourceFrame, ...],
        *,
        delivery_policy: SourceDeliveryPolicy = "latest_only",
    ) -> None:
        self.delivery_policy = delivery_policy
        self.source_name = source_name
        # Stored under a private name so it does not shadow the ``frames`` method.
        self._frames = frames

    async def frames(self) -> AsyncIterator[SourceFrame]:
        """Yield the stored frames in order, then finish."""
        for queued in self._frames:
            yield queued


class StubPoseShim:
    """Pose shim stand-in that returns one fixed detection per input frame."""

    def __init__(self, delay_seconds: float = 0.0) -> None:
        self._delay_seconds = delay_seconds

    def process_many(self, frames: Sequence[SourceFrame]) -> list[PoseDetections]:
        """Optionally sleep, then return a canned ``PoseDetections`` for each frame."""
        if self._delay_seconds > 0.0:
            # Blocking sleep that stands in for slow inference.
            time.sleep(self._delay_seconds)
        return [self._fake_detection(frame) for frame in frames]

    @staticmethod
    def _fake_detection(frame: SourceFrame) -> PoseDetections:
        """Build a single-box, 133-keypoint detection echoing the frame's identity."""
        image_shape = frame.image_bgr.shape
        return PoseDetections(
            source_name=frame.source_name,
            frame_index=frame.frame_index,
            # Shape axes 1 and 0 are swapped into the ``source_size`` tuple.
            source_size=(image_shape[1], image_shape[0]),
            boxes_xyxy=np.asarray([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
            box_scores=np.asarray([1.0], dtype=np.float32),
            keypoints_xy=np.zeros((1, 133, 2), dtype=np.float32),
            keypoint_scores=np.ones((1, 133), dtype=np.float32),
            timestamp_unix_ns=frame.timestamp_unix_ns,
            keypoint_schema="coco_wholebody133",
        )


class StubSink:
    """Sink that records published detections and whether it was closed."""

    def __init__(self) -> None:
        self.closed = False
        self.messages: list[PoseDetections] = []

    async def publish_pose(self, detections: PoseDetections) -> None:
        """Append the detections to ``messages``."""
        self.messages.append(detections)

    async def aclose(self) -> None:
        """Mark the sink as closed."""
        self.closed = True


def _drive_runner(
    sources: tuple[StubSource, ...],
    pose_shim: StubPoseShim,
    sink: StubSink,
    config: DetectionRunnerConfig,
) -> None:
    """Run ``run_detection_runner`` to completion under ``anyio.run`` with the stubs."""
    anyio.run(
        run_detection_runner,
        cast(tuple[FrameSource, ...], sources),
        pose_shim,
        sink,
        config,
    )


def _cam0_source(delivery_policy: SourceDeliveryPolicy) -> StubSource:
    """Return a ``cam0`` source holding frames 0, 1, 2 with timestamps ``index * 100``."""
    queued = tuple(
        _blank_frame("cam0", index, index * 100, shape=(2, 3, 3)) for index in range(3)
    )
    return StubSource("cam0", queued, delivery_policy=delivery_policy)


def _one_frame_batch_config() -> DetectionRunnerConfig:
    """Config for a single ``cam0`` instance with one-frame batches and zero wait."""
    # This test file doubles as a placeholder for every required path field.
    placeholder = Path(__file__)
    return DetectionRunnerConfig(
        instances=("cam0",),
        pose_config_path=placeholder,
        yolo_checkpoint=placeholder,
        pose_checkpoint=placeholder,
        max_batch_frames=1,
        max_batch_wait_ms=0,
    )


def test_run_detection_runner_publishes_payloads() -> None:
    """One frame per ``block`` source is published, and the sink is closed at the end."""
    # (source name, frame index, timestamp) for each camera's single frame.
    camera_specs = (("cam0", 1, 100), ("cam1", 2, 200))
    sources = tuple(
        StubSource(
            name,
            (_blank_frame(name, index, stamp, shape=(2, 3, 3)),),
            delivery_policy="block",
        )
        for name, index, stamp in camera_specs
    )
    placeholder = Path(__file__)
    config = DetectionRunnerConfig(
        instances=("cam0", "cam1"),
        pose_config_path=placeholder,
        yolo_checkpoint=placeholder,
        pose_checkpoint=placeholder,
        max_batch_frames=2,
    )
    sink = StubSink()

    _drive_runner(sources, StubPoseShim(), sink, config)

    assert sink.closed is True
    published = [
        (message.source_name, message.frame_index, message.timestamp_unix_ns)
        for message in sink.messages
    ]
    assert published == [
        ("cam0", 1, 100),
        ("cam1", 2, 200),
    ]


def test_run_detection_runner_blocks_to_preserve_offline_frames() -> None:
    """With the ``block`` policy and a slow shim, all three frames are published in order."""
    sink = StubSink()

    _drive_runner(
        (_cam0_source("block"),),
        StubPoseShim(delay_seconds=0.01),
        sink,
        _one_frame_batch_config(),
    )

    published_indices = [message.frame_index for message in sink.messages]
    assert published_indices == [0, 1, 2]


def test_run_detection_runner_drops_intermediate_latest_only_frames() -> None:
    """With ``latest_only`` and a slow shim, frame 0 is skipped and frame 2 comes last."""
    sink = StubSink()

    _drive_runner(
        (_cam0_source("latest_only"),),
        StubPoseShim(delay_seconds=0.01),
        sink,
        _one_frame_batch_config(),
    )

    processed = [message.frame_index for message in sink.messages]
    assert processed[-1] == 2
    assert processed != [0, 1, 2]
    assert 0 not in processed