feat: extract opengait_studio monorepo module
Move demo implementation into opengait_studio, retire Sports2D runtime integration, and align packaging with root-level monorepo dependency management.
This commit is contained in:
@@ -0,0 +1,390 @@
|
||||
"""
|
||||
Output publishers for OpenGait demo results.
|
||||
|
||||
Provides pluggable result publishing:
|
||||
- ConsolePublisher: JSONL to stdout
|
||||
- NatsPublisher: NATS message broker integration
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import nats
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Protocol, TextIO, TypedDict, cast, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DemoResult(TypedDict):
|
||||
"""Typed result dictionary for demo pipeline output.
|
||||
|
||||
Contains classification result with frame metadata.
|
||||
"""
|
||||
|
||||
frame: int
|
||||
track_id: int
|
||||
label: str
|
||||
confidence: float
|
||||
window: int
|
||||
timestamp_ns: int
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ResultPublisher(Protocol):
|
||||
"""Protocol for result publishers."""
|
||||
|
||||
def publish(self, result: DemoResult) -> None:
|
||||
"""
|
||||
Publish a result dictionary.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
result : DemoResult
|
||||
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ConsolePublisher:
|
||||
"""Publisher that outputs JSON Lines to stdout."""
|
||||
|
||||
_output: TextIO
|
||||
|
||||
def __init__(self, output: TextIO = sys.stdout) -> None:
|
||||
"""
|
||||
Initialize console publisher.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : TextIO
|
||||
File-like object to write to (default: sys.stdout)
|
||||
"""
|
||||
self._output = output
|
||||
|
||||
def publish(self, result: DemoResult) -> None:
|
||||
"""
|
||||
Publish result as JSON line.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
result : DemoResult
|
||||
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
|
||||
"""
|
||||
try:
|
||||
json_line = json.dumps(result, ensure_ascii=False, default=str)
|
||||
_ = self._output.write(json_line + "\n")
|
||||
self._output.flush()
|
||||
except (OSError, ValueError, TypeError) as e:
|
||||
logger.warning(f"Failed to publish to console: {e}")
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the publisher (no-op for console)."""
|
||||
pass
|
||||
|
||||
def __enter__(self) -> ConsolePublisher:
|
||||
"""Context manager entry."""
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Context manager exit."""
|
||||
self.close()
|
||||
|
||||
|
||||
class _NatsClient(Protocol):
|
||||
"""Protocol for connected NATS client."""
|
||||
|
||||
async def publish(self, subject: str, payload: bytes) -> object: ...
|
||||
|
||||
async def close(self) -> object: ...
|
||||
|
||||
async def flush(self) -> object: ...
|
||||
|
||||
|
||||
class NatsPublisher:
|
||||
"""
|
||||
Publisher that sends results to NATS message broker.
|
||||
|
||||
This is a sync-friendly wrapper around the async nats-py client.
|
||||
Uses a background thread with dedicated event loop to bridge sync
|
||||
publish calls to async NATS operations, making it safe to use in
|
||||
both sync and async contexts.
|
||||
"""
|
||||
|
||||
_nats_url: str
|
||||
_subject: str
|
||||
_nc: _NatsClient | None
|
||||
_connected: bool
|
||||
_loop: asyncio.AbstractEventLoop | None
|
||||
_thread: threading.Thread | None
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, nats_url: str, subject: str = "scoliosis.result") -> None:
|
||||
"""
|
||||
Initialize NATS publisher.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nats_url : str
|
||||
NATS server URL (e.g., "nats://localhost:4222")
|
||||
subject : str
|
||||
NATS subject to publish to (default: "scoliosis.result")
|
||||
"""
|
||||
self._nats_url = nats_url
|
||||
self._subject = subject
|
||||
self._nc = None
|
||||
self._connected = False
|
||||
self._loop = None
|
||||
self._thread = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _start_background_loop(self) -> bool:
|
||||
"""
|
||||
Start background thread with event loop for async operations.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if loop is running, False otherwise
|
||||
"""
|
||||
with self._lock:
|
||||
if self._loop is not None and self._loop.is_running():
|
||||
return True
|
||||
|
||||
try:
|
||||
loop = asyncio.new_event_loop()
|
||||
self._loop = loop
|
||||
|
||||
def run_loop() -> None:
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_forever()
|
||||
|
||||
self._thread = threading.Thread(target=run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
return True
|
||||
except (RuntimeError, OSError) as e:
|
||||
logger.warning(f"Failed to start background event loop: {e}")
|
||||
return False
|
||||
|
||||
def _stop_background_loop(self) -> None:
|
||||
"""Stop the background event loop and thread."""
|
||||
with self._lock:
|
||||
if self._loop is not None and self._loop.is_running():
|
||||
_ = self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
if self._thread is not None and self._thread.is_alive():
|
||||
self._thread.join(timeout=2.0)
|
||||
self._loop = None
|
||||
self._thread = None
|
||||
|
||||
def _ensure_connected(self) -> bool:
|
||||
"""
|
||||
Ensure connection to NATS server.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if connected, False otherwise
|
||||
"""
|
||||
with self._lock:
|
||||
if self._connected and self._nc is not None:
|
||||
return True
|
||||
|
||||
if not self._start_background_loop():
|
||||
return False
|
||||
|
||||
try:
|
||||
|
||||
async def _connect() -> _NatsClient:
|
||||
nc = await nats.connect(self._nats_url) # pyright: ignore[reportUnknownMemberType]
|
||||
return cast(_NatsClient, nc)
|
||||
|
||||
# Run connection in background loop
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
_connect(),
|
||||
self._loop, # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
self._nc = future.result(timeout=10.0)
|
||||
self._connected = True
|
||||
logger.info(f"Connected to NATS at {self._nats_url}")
|
||||
return True
|
||||
except (RuntimeError, OSError, TimeoutError) as e:
|
||||
logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}")
|
||||
return False
|
||||
|
||||
def publish(self, result: DemoResult) -> None:
|
||||
"""
|
||||
Publish result to NATS subject.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
result : DemoResult
|
||||
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
|
||||
"""
|
||||
if not self._ensure_connected():
|
||||
# Graceful degradation: log warning but don't crash
|
||||
logger.debug(
|
||||
f"NATS unavailable, dropping result: {result.get('track_id', 'unknown')}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
async def _publish() -> None:
|
||||
if self._nc is not None:
|
||||
payload = json.dumps(
|
||||
result, ensure_ascii=False, default=str
|
||||
).encode("utf-8")
|
||||
_ = await self._nc.publish(self._subject, payload)
|
||||
_ = await self._nc.flush()
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
_publish(),
|
||||
self._loop, # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
|
||||
def _on_done(publish_future: object) -> None:
|
||||
fut = cast("asyncio.Future[None]", publish_future)
|
||||
try:
|
||||
exc = fut.exception()
|
||||
except (RuntimeError, OSError) as callback_error:
|
||||
logger.warning(f"NATS publish callback failed: {callback_error}")
|
||||
self._connected = False
|
||||
return
|
||||
if exc is not None:
|
||||
logger.warning(f"Failed to publish to NATS: {exc}")
|
||||
self._connected = False
|
||||
|
||||
future.add_done_callback(_on_done)
|
||||
except (RuntimeError, OSError, ValueError, TypeError) as e:
|
||||
logger.warning(f"Failed to schedule NATS publish: {e}")
|
||||
self._connected = False
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close NATS connection."""
|
||||
with self._lock:
|
||||
if self._nc is not None and self._connected and self._loop is not None:
|
||||
try:
|
||||
|
||||
async def _close() -> None:
|
||||
if self._nc is not None:
|
||||
_ = await self._nc.close()
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
_close(),
|
||||
self._loop,
|
||||
)
|
||||
future.result(timeout=5.0)
|
||||
except (RuntimeError, OSError, TimeoutError) as e:
|
||||
logger.debug(f"Error closing NATS connection: {e}")
|
||||
finally:
|
||||
self._nc = None
|
||||
self._connected = False
|
||||
|
||||
self._stop_background_loop()
|
||||
|
||||
def __enter__(self) -> NatsPublisher:
|
||||
"""Context manager entry."""
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Context manager exit."""
|
||||
self.close()
|
||||
|
||||
|
||||
def create_publisher(
|
||||
nats_url: str | None,
|
||||
subject: str = "scoliosis.result",
|
||||
) -> ResultPublisher:
|
||||
"""
|
||||
Factory function to create appropriate publisher.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nats_url : str | None
|
||||
NATS server URL. If None or empty, returns ConsolePublisher.
|
||||
subject : str
|
||||
NATS subject to publish to (default: "scoliosis.result")
|
||||
|
||||
Returns
|
||||
-------
|
||||
ResultPublisher
|
||||
NatsPublisher if nats_url provided, otherwise ConsolePublisher
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> # Console output (default)
|
||||
>>> pub = create_publisher(None)
|
||||
>>> pub.publish({"frame": 1, "track_id": 42, "label": "normal", "confidence": 0.95, "window": 30, "timestamp_ns": 1234567890})
|
||||
>>>
|
||||
>>> # NATS output
|
||||
>>> pub = create_publisher("nats://localhost:4222")
|
||||
>>> pub.publish({"frame": 1, "track_id": 42, "label": "normal", "confidence": 0.95, "window": 30, "timestamp_ns": 1234567890})
|
||||
>>>
|
||||
>>> # Context manager usage
|
||||
>>> with create_publisher("nats://localhost:4222") as pub:
|
||||
... pub.publish(result)
|
||||
"""
|
||||
if nats_url:
|
||||
return NatsPublisher(nats_url, subject)
|
||||
return ConsolePublisher()
|
||||
|
||||
|
||||
def create_result(
|
||||
frame: int,
|
||||
track_id: int,
|
||||
label: str,
|
||||
confidence: float,
|
||||
window: int | tuple[int, int],
|
||||
timestamp_ns: int | None = None,
|
||||
) -> DemoResult:
|
||||
"""
|
||||
Create a standardized result dictionary.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
frame : int
|
||||
Frame number
|
||||
track_id : int
|
||||
Track/person identifier
|
||||
label : str
|
||||
Classification label (e.g., "normal", "scoliosis")
|
||||
confidence : float
|
||||
Confidence score (0.0 to 1.0)
|
||||
window : int | tuple[int, int]
|
||||
Frame window as int (end frame) or tuple [start, end] that produced this result
|
||||
Frame window [start, end] that produced this result
|
||||
timestamp_ns : int | None
|
||||
Timestamp in nanoseconds. If None, uses current time.
|
||||
|
||||
Returns
|
||||
-------
|
||||
DemoResult
|
||||
Standardized result dictionary
|
||||
"""
|
||||
return {
|
||||
"frame": frame,
|
||||
"track_id": track_id,
|
||||
"label": label,
|
||||
"confidence": confidence,
|
||||
"window": window if isinstance(window, int) else window[1],
|
||||
"timestamp_ns": timestamp_ns
|
||||
if timestamp_ns is not None
|
||||
else time.monotonic_ns(),
|
||||
}
|
||||
Reference in New Issue
Block a user