b24644f16e
Add the full demo runtime stack for single-person scoliosis inference, including input adapters, silhouette preprocessing, temporal windowing, ScoNet wrapper, result publishing, and click-based CLI orchestration. This commit captures the executable pipeline behavior independently from tests and planning artifacts for clearer review and rollback.
369 lines
11 KiB
Python
369 lines
11 KiB
Python
"""
|
|
Output publishers for OpenGait demo results.
|
|
|
|
Provides pluggable result publishing:
|
|
- ConsolePublisher: JSONL to stdout
|
|
- NatsPublisher: NATS message broker integration
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import sys
|
|
import threading
|
|
import time
|
|
from typing import TYPE_CHECKING, Protocol, TextIO, cast, runtime_checkable
|
|
|
|
if TYPE_CHECKING:
|
|
from types import TracebackType
|
|
|
|
# Module logger: the publishers below report failures here rather than raising.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
@runtime_checkable
class ResultPublisher(Protocol):
    """
    Structural interface implemented by every result sink.

    Any object exposing a ``publish(result)`` method matches. Because the
    protocol is ``runtime_checkable``, ``isinstance`` only verifies that
    such a method exists; it does not check the signature.
    """

    def publish(self, result: dict[str, object]) -> None:
        """
        Emit one result record.

        Parameters
        ----------
        result : dict[str, object]
            Result data with keys: frame, track_id, label, confidence,
            window, timestamp_ns
        """
        ...
|
|
|
|
|
|
class ConsolePublisher:
    """
    Publisher that writes each result as one JSON Lines record.

    By default records go to ``sys.stdout`` as it is at construction time.
    Any other text stream can be supplied instead.
    """

    _output: TextIO

    def __init__(self, output: TextIO | None = None) -> None:
        """
        Initialize console publisher.

        Parameters
        ----------
        output : TextIO | None
            File-like object to write to. ``None`` (the default) means the
            current ``sys.stdout``. It is looked up here rather than when the
            module is imported, so redirection via
            ``contextlib.redirect_stdout`` or test capture is honoured.
        """
        self._output = sys.stdout if output is None else output

    def publish(self, result: dict[str, object]) -> None:
        """
        Publish result as one JSON line and flush immediately.

        Values that are not JSON-serializable are converted with ``str``.
        Serialization and write errors are logged as warnings and never
        raised, so a broken stream cannot stop the caller.

        Parameters
        ----------
        result : dict[str, object]
            Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
        """
        try:
            line = json.dumps(result, ensure_ascii=False, default=str)
            _ = self._output.write(line + "\n")
            # Flush per record so consumers reading a pipe see each result
            # as soon as it is produced.
            self._output.flush()
        except Exception as e:
            logger.warning("Failed to publish to console: %s", e)

    def close(self) -> None:
        """Close the publisher (no-op: the output stream is not closed)."""
        pass

    def __enter__(self) -> ConsolePublisher:
        """Context manager entry."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Context manager exit; calls :meth:`close`."""
        self.close()
|
|
|
|
|
|
class _NatsClient(Protocol):
|
|
"""Protocol for connected NATS client."""
|
|
|
|
async def publish(self, subject: str, payload: bytes) -> object: ...
|
|
|
|
async def close(self) -> object: ...
|
|
|
|
async def flush(self) -> object: ...
|
|
|
|
|
|
class NatsPublisher:
    """
    Publisher that sends results to NATS message broker.

    This is a sync-friendly wrapper around the async nats-py client. A
    background daemon thread runs a dedicated event loop. Each synchronous
    call submits a coroutine to that loop and blocks until it finishes or
    times out. Calling ``publish`` from inside a coroutine therefore works,
    but blocks the caller's own event loop for that time.

    The connection is opened lazily on the first ``publish``. Failures
    never propagate: they are logged, the result is dropped, and the next
    ``publish`` reconnects. While the server is unreachable, every
    ``publish`` retries the connection and may block for up to the connect
    timeout.

    All state changes happen under one re-entrant lock. Concurrent callers
    therefore share a single connection, and ``close`` waits for an
    in-flight publish.
    """

    # Seconds to wait for connect, publish+flush, and close respectively.
    _CONNECT_TIMEOUT_S = 10.0
    _PUBLISH_TIMEOUT_S = 5.0
    _CLOSE_TIMEOUT_S = 5.0
    # Seconds to wait for the loop thread to exit after stop is requested.
    _THREAD_JOIN_TIMEOUT_S = 2.0

    _nats_url: str
    _subject: str
    _nc: _NatsClient | None
    _connected: bool
    _loop: asyncio.AbstractEventLoop | None
    _thread: threading.Thread | None
    # Re-entrant: the public methods hold it while calling the private
    # helpers, which acquire it again. A plain Lock would deadlock there.
    _lock: threading.RLock

    def __init__(self, nats_url: str, subject: str = "scoliosis.result") -> None:
        """
        Initialize NATS publisher (no connection is made yet).

        Parameters
        ----------
        nats_url : str
            NATS server URL (e.g., "nats://localhost:4222")
        subject : str
            NATS subject to publish to (default: "scoliosis.result")
        """
        self._nats_url = nats_url
        self._subject = subject
        self._nc = None
        self._connected = False
        self._loop = None
        self._thread = None
        self._lock = threading.RLock()

    def _start_background_loop(self) -> bool:
        """
        Start background thread with event loop for async operations.

        Idempotent: returns immediately if the loop thread is already alive.

        Returns
        -------
        bool
            True if the loop thread is running, False if it could not be started
        """
        with self._lock:
            # Test the thread, not ``loop.is_running()``. Just after
            # ``start()`` the loop may not be running yet, and treating that
            # as "not started" would spawn a second loop and thread.
            if (
                self._loop is not None
                and self._thread is not None
                and self._thread.is_alive()
            ):
                return True

            # Release a loop left behind by a thread that has already exited.
            stale = self._loop
            if stale is not None and not stale.is_closed() and not stale.is_running():
                stale.close()
            self._loop = None
            self._thread = None

            try:
                loop = asyncio.new_event_loop()
            except Exception as e:
                logger.warning("Failed to start background event loop: %s", e)
                return False

            def run_loop() -> None:
                asyncio.set_event_loop(loop)
                loop.run_forever()

            thread = threading.Thread(
                target=run_loop, name="nats-publisher-loop", daemon=True
            )
            try:
                thread.start()
            except Exception as e:
                loop.close()
                logger.warning("Failed to start background event loop: %s", e)
                return False

            self._loop = loop
            self._thread = thread
            return True

    def _stop_background_loop(self) -> None:
        """Stop the background event loop, join its thread and close the loop."""
        with self._lock:
            loop = self._loop
            thread = self._thread
            self._loop = None
            self._thread = None

            if loop is None or loop.is_closed():
                return

            # stop() must run on the loop's own thread. Queueing it also works
            # when run_forever() has not started yet, so a freshly started
            # thread is not left running forever.
            _ = loop.call_soon_threadsafe(loop.stop)
            if thread is not None and thread.is_alive():
                thread.join(timeout=self._THREAD_JOIN_TIMEOUT_S)

            if thread is None or not thread.is_alive():
                # Release the loop's resources once nothing is running it.
                loop.close()
            else:
                logger.debug("NATS event loop thread did not stop in time")

    def _discard_client(self) -> None:
        """
        Close the current client (best effort) and forget it.

        Also used before reconnecting, so a client abandoned after a failed
        publish does not leak its connection.
        """
        with self._lock:
            nc = self._nc
            loop = self._loop
            thread = self._thread
            self._nc = None
            self._connected = False

            # The client can only be closed on the loop it was created on,
            # and only while that loop's thread is still running it.
            if nc is None or loop is None or thread is None or not thread.is_alive():
                return

            async def _close() -> None:
                _ = await nc.close()

            future = asyncio.run_coroutine_threadsafe(_close(), loop)
            try:
                future.result(timeout=self._CLOSE_TIMEOUT_S)
            except Exception as e:
                _ = future.cancel()
                logger.debug("Error closing NATS connection: %s", e)

    def _ensure_connected(self) -> bool:
        """
        Ensure connection to NATS server, (re)connecting if needed.

        Returns
        -------
        bool
            True if connected, False otherwise
        """
        with self._lock:
            if self._connected and self._nc is not None:
                return True

            # Close any client left over from a failed publish before
            # replacing it.
            self._discard_client()

            try:
                import nats
            except ImportError:
                logger.warning(
                    "nats-py package not installed. Install with: pip install nats-py"
                )
                return False

            if not self._start_background_loop():
                return False
            loop = self._loop
            if loop is None:
                return False

            async def _connect() -> _NatsClient:
                nc = await nats.connect(self._nats_url)  # pyright: ignore[reportUnknownMemberType]
                return cast(_NatsClient, nc)

            future = asyncio.run_coroutine_threadsafe(_connect(), loop)
            try:
                client = future.result(timeout=self._CONNECT_TIMEOUT_S)
            except Exception as e:
                # Cancel so that a connect finishing after the timeout does
                # not leave an orphaned client running on the loop.
                _ = future.cancel()
                logger.warning("Failed to connect to NATS at %s: %s", self._nats_url, e)
                return False

            self._nc = client
            self._connected = True
            logger.info("Connected to NATS at %s", self._nats_url)
            return True

    def publish(self, result: dict[str, object]) -> None:
        """
        Publish result to NATS subject and wait for the flush.

        Never raises. If NATS is unavailable or the publish fails, the
        result is dropped and a message is logged.

        Parameters
        ----------
        result : dict[str, object]
            Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
        """
        with self._lock:
            if not self._ensure_connected():
                # Graceful degradation. Use lazy formatting: this runs for
                # every frame while NATS is down.
                logger.debug(
                    "NATS unavailable, dropping result: %s",
                    result.get("track_id", "unknown"),
                )
                return

            try:
                payload = json.dumps(
                    result, ensure_ascii=False, default=str
                ).encode("utf-8")
            except Exception as e:
                # A bad payload says nothing about the connection, so keep it.
                logger.warning("Failed to publish to NATS: %s", e)
                return

            nc = self._nc
            loop = self._loop
            if nc is None or loop is None:
                return

            async def _publish() -> None:
                _ = await nc.publish(self._subject, payload)
                _ = await nc.flush()

            future = asyncio.run_coroutine_threadsafe(_publish(), loop)
            try:
                future.result(timeout=self._PUBLISH_TIMEOUT_S)
            except Exception as e:
                _ = future.cancel()
                logger.warning("Failed to publish to NATS: %s", e)
                # Mark for reconnection; the stale client is closed on the
                # next _ensure_connected() or close().
                self._connected = False

    def close(self) -> None:
        """Close the NATS connection and stop the background loop; idempotent."""
        with self._lock:
            # Close whenever a client exists, even after a failed publish
            # cleared ``_connected``, so its connection is not leaked.
            self._discard_client()
            self._stop_background_loop()

    def __enter__(self) -> NatsPublisher:
        """Context manager entry."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Context manager exit; calls :meth:`close`."""
        self.close()
|
|
|
|
|
|
def create_publisher(
    nats_url: str | None,
    subject: str = "scoliosis.result",
) -> ConsolePublisher | NatsPublisher:
    """
    Factory function to create appropriate publisher.

    Parameters
    ----------
    nats_url : str | None
        NATS server URL. If None or empty, returns ConsolePublisher.
    subject : str
        NATS subject to publish to (default: "scoliosis.result")

    Returns
    -------
    ConsolePublisher | NatsPublisher
        NatsPublisher if ``nats_url`` is non-empty, otherwise ConsolePublisher.
        Both satisfy ``ResultPublisher``. Both also provide ``close()`` and
        context-manager support, which the ``ResultPublisher`` protocol does
        not declare; hence the concrete return type.

    Examples
    --------
    >>> isinstance(create_publisher(None), ConsolePublisher)
    True
    >>> isinstance(create_publisher(""), ConsolePublisher)
    True
    >>> pub = create_publisher("nats://localhost:4222")  # no connection yet
    >>> isinstance(pub, NatsPublisher)
    True
    >>> with create_publisher("nats://localhost:4222") as pub:  # doctest: +SKIP
    ...     pub.publish(result)
    """
    if nats_url:
        return NatsPublisher(nats_url, subject)
    return ConsolePublisher()
|
|
|
|
|
|
def create_result(
|
|
frame: int,
|
|
track_id: int,
|
|
label: str,
|
|
confidence: float,
|
|
window: int | tuple[int, int],
|
|
timestamp_ns: int | None = None,
|
|
) -> dict[str, object]:
|
|
"""
|
|
Create a standardized result dictionary.
|
|
|
|
Parameters
|
|
----------
|
|
frame : int
|
|
Frame number
|
|
track_id : int
|
|
Track/person identifier
|
|
label : str
|
|
Classification label (e.g., "normal", "scoliosis")
|
|
confidence : float
|
|
Confidence score (0.0 to 1.0)
|
|
window : int | tuple[int, int]
|
|
Frame window as int (end frame) or tuple [start, end] that produced this result
|
|
Frame window [start, end] that produced this result
|
|
timestamp_ns : int | None
|
|
Timestamp in nanoseconds. If None, uses current time.
|
|
|
|
Returns
|
|
-------
|
|
dict[str, object]
|
|
Standardized result dictionary
|
|
"""
|
|
return {
|
|
"frame": frame,
|
|
"track_id": track_id,
|
|
"label": label,
|
|
"confidence": confidence,
|
|
"window": window if isinstance(window, int) else window[1],
|
|
"timestamp_ns": timestamp_ns
|
|
if timestamp_ns is not None
|
|
else time.monotonic_ns(),
|
|
}
|