Files
OpenGait/opengait-studio/opengait_studio/input.py
T
crosstyan 00fcda4fe3 feat: extract opengait_studio monorepo module
Move demo implementation into opengait_studio, retire Sports2D runtime integration, and align packaging with root-level monorepo dependency management.
2026-03-07 18:14:13 +08:00

220 lines
6.6 KiB
Python

"""
Input adapters for OpenGait demo.
Provides generator-based interfaces for video sources:
- OpenCV (video files, cameras)
- cv-mmap (shared memory streams)
"""
from collections.abc import AsyncIterator, Generator, Iterable
from typing import Protocol, cast
import logging
import numpy as np
logger = logging.getLogger(__name__)
# Type alias for frame stream: (frame_array, metadata_dict)
FrameStream = Iterable[tuple[np.ndarray, dict[str, object]]]
# Protocol for cv-mmap metadata (needed at runtime for nested function annotation)
class _FrameMetadata(Protocol):
    """Structural type for the per-frame metadata object produced by a cv-mmap client."""

    # 0-based frame index assigned by the cv-mmap producer
    frame_count: int
    # capture timestamp in nanoseconds, as reported by cv-mmap
    timestamp_ns: int
# Protocol for cv-mmap client (needed at runtime for cast)
class _CvMmapClient(Protocol):
    """Structural type for the cv-mmap client: an async iterable of (frame, metadata)."""

    def __aiter__(self) -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]: ...
def opencv_source(
    path: str | int,
    max_frames: int | None = None,
    fallback_fps: float = 30.0,
) -> Generator[tuple[np.ndarray, dict[str, object]], None, None]:
    """
    Generator that yields frames from an OpenCV video source.

    Parameters
    ----------
    path : str | int
        Video file path or camera index (e.g., 0 for default camera)
    max_frames : int | None, optional
        Maximum number of frames to yield. None means unlimited.
    fallback_fps : float, optional
        Frame rate assumed when synthesizing file timestamps and the
        container does not report a usable FPS. Defaults to 30.0 (the
        previously hard-coded value).

    Yields
    ------
    tuple[np.ndarray, dict[str, object]]
        (frame_array, metadata_dict) where metadata includes:
        - frame_count: frame index (0-based)
        - timestamp_ns: monotonic timestamp in nanoseconds (if available)
        - source: the path/int provided
        - source_fps: reported container FPS (only present when valid)

    Raises
    ------
    RuntimeError
        If the video source cannot be opened.
    """
    import time

    import cv2

    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise RuntimeError(f"Failed to open video source: {path}")
    # Cameras are addressed by int, files/URLs by str; only file sources
    # carry a meaningful container FPS and playback position.
    is_file_source = isinstance(path, str)
    source_fps = float(cap.get(cv2.CAP_PROP_FPS)) if is_file_source else 0.0
    fps_valid = source_fps > 0.0 and np.isfinite(source_fps)
    effective_fps = source_fps if fps_valid else fallback_fps
    fallback_interval_ns = int(1_000_000_000 / effective_fps)
    start_ns = time.monotonic_ns()
    frame_idx = 0
    try:
        while max_frames is None or frame_idx < max_frames:
            ret, frame = cap.read()
            if not ret:
                # End of stream (or camera read failure)
                break
            if is_file_source:
                # Prefer the container's own position; fall back to a
                # synthesized timestamp when the demuxer reports nothing.
                # NOTE(review): POS_MSEC is queried after read(), so on some
                # backends it may reflect the *next* frame's position — verify
                # against the target OpenCV backend if exact stamps matter.
                pos_msec = float(cap.get(cv2.CAP_PROP_POS_MSEC))
                if np.isfinite(pos_msec) and pos_msec > 0.0:
                    timestamp_ns = start_ns + int(pos_msec * 1_000_000)
                else:
                    timestamp_ns = start_ns + frame_idx * fallback_interval_ns
            else:
                # Live cameras: stamp each frame at read time.
                timestamp_ns = time.monotonic_ns()
            metadata: dict[str, object] = {
                "frame_count": frame_idx,
                "timestamp_ns": timestamp_ns,
                "source": path,
            }
            if fps_valid:
                metadata["source_fps"] = source_fps
            yield frame, metadata
            frame_idx += 1
    finally:
        cap.release()
        # Lazy %-args: skip string formatting when DEBUG logging is disabled.
        logger.debug("OpenCV source closed: %s", path)
def cvmmap_source(
name: str, max_frames: int | None = None
) -> Generator[tuple[np.ndarray, dict[str, object]], None, None]:
"""
Generator that yields frames from a cv-mmap shared memory stream.
Bridges async cv-mmap client to synchronous generator using asyncio.run().
Parameters
----------
name : str
Base name of the cv-mmap source (e.g., "default")
max_frames : int | None, optional
Maximum number of frames to yield. None means unlimited.
Yields
------
tuple[np.ndarray, dict[str, object]]
(frame_array, metadata_dict) where metadata includes:
- frame_count: frame index from cv-mmap
- timestamp_ns: timestamp in nanoseconds from cv-mmap
- source: the cv-mmap name
Raises
------
ImportError
If cvmmap package is not available
RuntimeError
If cv-mmap stream disconnects or errors
"""
import asyncio
# Import cvmmap only when function is called
# Use try/except for runtime import check
try:
from cvmmap import CvMmapClient as _CvMmapClientReal # pyright: ignore[reportMissingTypeStubs]
except ImportError as e:
raise ImportError(
"cvmmap package is required for cv-mmap sources. "
+ "Install from: https://github.com/crosstyan/cv-mmap"
) from e
# Cast to protocol type for type checking
client: _CvMmapClient = cast("_CvMmapClient", _CvMmapClientReal(name))
frame_count = 0
async def _async_generator() -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]:
"""Async generator wrapper."""
async for frame, meta in client:
yield frame, meta
# Bridge async to sync using asyncio.run()
# We process frames one at a time to keep it simple and robust
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
agen = _async_generator().__aiter__()
while max_frames is None or frame_count < max_frames:
try:
frame, meta = loop.run_until_complete(agen.__anext__())
except StopAsyncIteration:
break
metadata: dict[str, object] = {
"frame_count": meta.frame_count,
"timestamp_ns": meta.timestamp_ns,
"source": f"cvmmap://{name}",
}
yield frame, metadata
frame_count += 1
finally:
loop.close()
logger.debug(f"cv-mmap source closed: {name}")
def create_source(source: str, max_frames: int | None = None) -> FrameStream:
    """
    Factory function to create a frame source from a string specification.

    Parameters
    ----------
    source : str
        Source specification:
        - '0', '1', etc. -> Camera index (OpenCV)
        - 'cvmmap://name' -> cv-mmap shared memory stream
        - Any other string -> Video file path (OpenCV)
    max_frames : int | None, optional
        Maximum number of frames to yield. None means unlimited.

    Returns
    -------
    FrameStream
        Generator yielding (frame, metadata) tuples

    Examples
    --------
    >>> for frame, meta in create_source('0'):  # Camera 0
    ...     process(frame)
    >>> for frame, meta in create_source('cvmmap://default'):  # cv-mmap
    ...     process(frame)
    >>> for frame, meta in create_source('/path/to/video.mp4'):
    ...     process(frame)
    """
    prefix = "cvmmap://"
    # Shared-memory protocol takes priority over path interpretation.
    if source.startswith(prefix):
        return cvmmap_source(source.removeprefix(prefix), max_frames)
    # An all-digit spec is a camera index; anything else is a file path.
    target: str | int = int(source) if source.isdigit() else source
    return opencv_source(target, max_frames)