"""
Input adapters for OpenGait demo.

Provides generator-based interfaces for video sources:
- OpenCV (video files, cameras)
- cv-mmap (shared memory streams)
"""


import logging
from collections.abc import AsyncGenerator, AsyncIterator, Generator, Iterable
from typing import Protocol, cast

import numpy as np


# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# A frame stream is any iterable of (frame, metadata) pairs: the frame is a
# NumPy array and the metadata is a dict keyed by str.
FrameStream = Iterable[
    tuple[np.ndarray, dict[str, object]]
]


# Protocol for cv-mmap metadata (needed at runtime for nested function annotation)
|
|
class _FrameMetadata(Protocol):
|
|
frame_count: int
|
|
timestamp_ns: int
|
|
|
|
|
|
# Structural type used as the type-checking target when cvmmap_source casts
# the real cvmmap.CvMmapClient instance. cvmmap_source only needs async
# iteration from the client, so that is all this protocol declares.
class _CvMmapClient(Protocol):
    """Async-iterable source of ``(frame, metadata)`` pairs."""

    def __aiter__(
        self,
    ) -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]:
        """Return an async iterator over incoming frames."""
        ...


def opencv_source(
|
|
path: str | int, max_frames: int | None = None
|
|
) -> Generator[tuple[np.ndarray, dict[str, object]], None, None]:
|
|
"""
|
|
Generator that yields frames from an OpenCV video source.
|
|
|
|
Parameters
|
|
----------
|
|
path : str | int
|
|
Video file path or camera index (e.g., 0 for default camera)
|
|
max_frames : int | None, optional
|
|
Maximum number of frames to yield. None means unlimited.
|
|
|
|
Yields
|
|
------
|
|
tuple[np.ndarray, dict[str, object]]
|
|
(frame_array, metadata_dict) where metadata includes:
|
|
- frame_count: frame index (0-based)
|
|
- timestamp_ns: monotonic timestamp in nanoseconds (if available)
|
|
- source: the path/int provided
|
|
"""
|
|
import time
|
|
|
|
import cv2
|
|
|
|
cap = cv2.VideoCapture(path)
|
|
if not cap.isOpened():
|
|
raise RuntimeError(f"Failed to open video source: {path}")
|
|
|
|
is_file_source = isinstance(path, str)
|
|
source_fps = float(cap.get(cv2.CAP_PROP_FPS)) if is_file_source else 0.0
|
|
fps_valid = source_fps > 0.0 and np.isfinite(source_fps)
|
|
fallback_fps = source_fps if fps_valid else 30.0
|
|
fallback_interval_ns = int(1_000_000_000 / fallback_fps)
|
|
start_ns = time.monotonic_ns()
|
|
|
|
frame_idx = 0
|
|
try:
|
|
while max_frames is None or frame_idx < max_frames:
|
|
ret, frame = cap.read()
|
|
if not ret:
|
|
# End of stream
|
|
break
|
|
|
|
if is_file_source:
|
|
pos_msec = float(cap.get(cv2.CAP_PROP_POS_MSEC))
|
|
if np.isfinite(pos_msec) and pos_msec > 0.0:
|
|
timestamp_ns = start_ns + int(pos_msec * 1_000_000)
|
|
else:
|
|
timestamp_ns = start_ns + frame_idx * fallback_interval_ns
|
|
else:
|
|
timestamp_ns = time.monotonic_ns()
|
|
|
|
metadata: dict[str, object] = {
|
|
"frame_count": frame_idx,
|
|
"timestamp_ns": timestamp_ns,
|
|
"source": path,
|
|
}
|
|
if fps_valid:
|
|
metadata["source_fps"] = source_fps
|
|
|
|
yield frame, metadata
|
|
frame_idx += 1
|
|
|
|
finally:
|
|
cap.release()
|
|
logger.debug(f"OpenCV source closed: {path}")
|
|
|
|
|
|
def cvmmap_source(
|
|
name: str, max_frames: int | None = None
|
|
) -> Generator[tuple[np.ndarray, dict[str, object]], None, None]:
|
|
"""
|
|
Generator that yields frames from a cv-mmap shared memory stream.
|
|
|
|
Bridges async cv-mmap client to synchronous generator using asyncio.run().
|
|
|
|
Parameters
|
|
----------
|
|
name : str
|
|
Base name of the cv-mmap source (e.g., "default")
|
|
max_frames : int | None, optional
|
|
Maximum number of frames to yield. None means unlimited.
|
|
|
|
Yields
|
|
------
|
|
tuple[np.ndarray, dict[str, object]]
|
|
(frame_array, metadata_dict) where metadata includes:
|
|
- frame_count: frame index from cv-mmap
|
|
- timestamp_ns: timestamp in nanoseconds from cv-mmap
|
|
- source: the cv-mmap name
|
|
|
|
Raises
|
|
------
|
|
ImportError
|
|
If cvmmap package is not available
|
|
RuntimeError
|
|
If cv-mmap stream disconnects or errors
|
|
"""
|
|
import asyncio
|
|
|
|
# Import cvmmap only when function is called
|
|
# Use try/except for runtime import check
|
|
try:
|
|
from cvmmap import CvMmapClient as _CvMmapClientReal # pyright: ignore[reportMissingTypeStubs]
|
|
except ImportError as e:
|
|
raise ImportError(
|
|
"cvmmap package is required for cv-mmap sources. "
|
|
+ "Install from: https://github.com/crosstyan/cv-mmap"
|
|
) from e
|
|
|
|
# Cast to protocol type for type checking
|
|
client: _CvMmapClient = cast("_CvMmapClient", _CvMmapClientReal(name))
|
|
frame_count = 0
|
|
|
|
async def _async_generator() -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]:
|
|
"""Async generator wrapper."""
|
|
async for frame, meta in client:
|
|
yield frame, meta
|
|
|
|
# Bridge async to sync using asyncio.run()
|
|
# We process frames one at a time to keep it simple and robust
|
|
loop = asyncio.new_event_loop()
|
|
asyncio.set_event_loop(loop)
|
|
|
|
try:
|
|
agen = _async_generator().__aiter__()
|
|
|
|
while max_frames is None or frame_count < max_frames:
|
|
try:
|
|
frame, meta = loop.run_until_complete(agen.__anext__())
|
|
except StopAsyncIteration:
|
|
break
|
|
|
|
metadata: dict[str, object] = {
|
|
"frame_count": meta.frame_count,
|
|
"timestamp_ns": meta.timestamp_ns,
|
|
"source": f"cvmmap://{name}",
|
|
}
|
|
|
|
yield frame, metadata
|
|
frame_count += 1
|
|
|
|
finally:
|
|
loop.close()
|
|
logger.debug(f"cv-mmap source closed: {name}")
|
|
|
|
|
|
def create_source(source: str, max_frames: int | None = None) -> FrameStream:
    """
    Factory function to create a frame source from a string specification.

    Parameters
    ----------
    source : str
        Source specification:
        - '0', '1', '12', etc. -> Camera index (OpenCV)
        - 'cvmmap://name' -> cv-mmap shared memory stream
        - Any other string -> Video file path (OpenCV)
    max_frames : int | None, optional
        Maximum number of frames to yield. None means unlimited.

    Returns
    -------
    FrameStream
        Generator yielding (frame, metadata) tuples. Both backends are
        generator functions, so the underlying source is opened lazily and
        open/import errors surface on first iteration, not here.

    Examples
    --------
    >>> for frame, meta in create_source('0'):  # Camera 0
    ...     process(frame)
    >>> for frame, meta in create_source('cvmmap://default'):  # cv-mmap
    ...     process(frame)
    >>> for frame, meta in create_source('/path/to/video.mp4'):
    ...     process(frame)
    """
    cvmmap_prefix = "cvmmap://"
    if source.startswith(cvmmap_prefix):
        return cvmmap_source(source[len(cvmmap_prefix):], max_frames)

    # Camera index: one or more decimal digits. Use isdecimal() rather than
    # isdigit(). isdigit() also accepts characters such as '²' that int()
    # rejects, and those used to crash with ValueError; they now fall
    # through to the file-path branch instead.
    if source.isdecimal():
        return opencv_source(int(source), max_frames)

    # Otherwise treat as file path.
    return opencv_source(source, max_frames)