feat(demo): implement ScoNet real-time pipeline runtime
Add the full demo runtime stack for single-person scoliosis inference, including input adapters, silhouette preprocessing, temporal windowing, ScoNet wrapper, result publishing, and click-based CLI orchestration. This commit captures the executable pipeline behavior independently from tests and planning artifacts for clearer review and rollback.
This commit is contained in:
@@ -0,0 +1,7 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from .pipeline import main
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: delegate to the click-based CLI defined in .pipeline.
if __name__ == "__main__":
    main()
|
||||||
@@ -0,0 +1,203 @@
|
|||||||
|
"""
|
||||||
|
Input adapters for OpenGait demo.
|
||||||
|
|
||||||
|
Provides generator-based interfaces for video sources:
|
||||||
|
- OpenCV (video files, cameras)
|
||||||
|
- cv-mmap (shared memory streams)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator, Generator, Iterable
|
||||||
|
from typing import TYPE_CHECKING, Protocol, cast
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)

# Type alias for a frame stream: an iterable of (frame_array, metadata_dict)
# pairs. Both sources in this module populate the metadata keys
# "frame_count", "timestamp_ns", and "source".
FrameStream = Iterable[tuple[np.ndarray, dict[str, object]]]
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
    # Structural types for cv-mmap objects, declared only for type checking so
    # this module imports cleanly when cv-mmap is not installed.

    # Protocol for cv-mmap metadata to avoid direct import
    class _FrameMetadata(Protocol):
        # Frame index — presumably assigned by the cv-mmap producer; confirm.
        frame_count: int
        # Capture timestamp in nanoseconds, as provided by cv-mmap.
        timestamp_ns: int

    # Protocol for cv-mmap client: an async iterable of (frame, metadata).
    class _CvMmapClient(Protocol):
        def __aiter__(self) -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
def opencv_source(
|
||||||
|
path: str | int, max_frames: int | None = None
|
||||||
|
) -> Generator[tuple[np.ndarray, dict[str, object]], None, None]:
|
||||||
|
"""
|
||||||
|
Generator that yields frames from an OpenCV video source.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : str | int
|
||||||
|
Video file path or camera index (e.g., 0 for default camera)
|
||||||
|
max_frames : int | None, optional
|
||||||
|
Maximum number of frames to yield. None means unlimited.
|
||||||
|
|
||||||
|
Yields
|
||||||
|
------
|
||||||
|
tuple[np.ndarray, dict[str, object]]
|
||||||
|
(frame_array, metadata_dict) where metadata includes:
|
||||||
|
- frame_count: frame index (0-based)
|
||||||
|
- timestamp_ns: monotonic timestamp in nanoseconds (if available)
|
||||||
|
- source: the path/int provided
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
cap = cv2.VideoCapture(path)
|
||||||
|
if not cap.isOpened():
|
||||||
|
raise RuntimeError(f"Failed to open video source: {path}")
|
||||||
|
|
||||||
|
frame_idx = 0
|
||||||
|
try:
|
||||||
|
while max_frames is None or frame_idx < max_frames:
|
||||||
|
ret, frame = cap.read()
|
||||||
|
if not ret:
|
||||||
|
# End of stream
|
||||||
|
break
|
||||||
|
|
||||||
|
# Get timestamp if available (some backends support this)
|
||||||
|
timestamp_ns = time.monotonic_ns()
|
||||||
|
|
||||||
|
metadata: dict[str, object] = {
|
||||||
|
"frame_count": frame_idx,
|
||||||
|
"timestamp_ns": timestamp_ns,
|
||||||
|
"source": path,
|
||||||
|
}
|
||||||
|
|
||||||
|
yield frame, metadata
|
||||||
|
frame_idx += 1
|
||||||
|
|
||||||
|
finally:
|
||||||
|
cap.release()
|
||||||
|
logger.debug(f"OpenCV source closed: {path}")
|
||||||
|
|
||||||
|
|
||||||
|
def cvmmap_source(
|
||||||
|
name: str, max_frames: int | None = None
|
||||||
|
) -> Generator[tuple[np.ndarray, dict[str, object]], None, None]:
|
||||||
|
"""
|
||||||
|
Generator that yields frames from a cv-mmap shared memory stream.
|
||||||
|
|
||||||
|
Bridges async cv-mmap client to synchronous generator using asyncio.run().
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : str
|
||||||
|
Base name of the cv-mmap source (e.g., "default")
|
||||||
|
max_frames : int | None, optional
|
||||||
|
Maximum number of frames to yield. None means unlimited.
|
||||||
|
|
||||||
|
Yields
|
||||||
|
------
|
||||||
|
tuple[np.ndarray, dict[str, object]]
|
||||||
|
(frame_array, metadata_dict) where metadata includes:
|
||||||
|
- frame_count: frame index from cv-mmap
|
||||||
|
- timestamp_ns: timestamp in nanoseconds from cv-mmap
|
||||||
|
- source: the cv-mmap name
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ImportError
|
||||||
|
If cvmmap package is not available
|
||||||
|
RuntimeError
|
||||||
|
If cv-mmap stream disconnects or errors
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Import cvmmap only when function is called
|
||||||
|
# Use try/except for runtime import check
|
||||||
|
try:
|
||||||
|
|
||||||
|
from cvmmap import CvMmapClient as _CvMmapClientReal # pyright: ignore[reportMissingTypeStubs]
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"cvmmap package is required for cv-mmap sources. "
|
||||||
|
+ "Install from: https://github.com/crosstyan/cv-mmap"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# Cast to protocol type for type checking
|
||||||
|
client: _CvMmapClient = cast("_CvMmapClient", _CvMmapClientReal(name))
|
||||||
|
frame_count = 0
|
||||||
|
|
||||||
|
async def _async_generator() -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]:
|
||||||
|
"""Async generator wrapper."""
|
||||||
|
async for frame, meta in client:
|
||||||
|
yield frame, meta
|
||||||
|
|
||||||
|
# Bridge async to sync using asyncio.run()
|
||||||
|
# We process frames one at a time to keep it simple and robust
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
try:
|
||||||
|
agen = _async_generator().__aiter__()
|
||||||
|
|
||||||
|
while max_frames is None or frame_count < max_frames:
|
||||||
|
try:
|
||||||
|
frame, meta = loop.run_until_complete(agen.__anext__())
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
|
|
||||||
|
metadata: dict[str, object] = {
|
||||||
|
"frame_count": meta.frame_count,
|
||||||
|
"timestamp_ns": meta.timestamp_ns,
|
||||||
|
"source": f"cvmmap://{name}",
|
||||||
|
}
|
||||||
|
|
||||||
|
yield frame, metadata
|
||||||
|
frame_count += 1
|
||||||
|
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
logger.debug(f"cv-mmap source closed: {name}")
|
||||||
|
|
||||||
|
|
||||||
|
def create_source(source: str, max_frames: int | None = None) -> FrameStream:
    """
    Build a (frame, metadata) stream from a textual source specification.

    Parameters
    ----------
    source : str
        Source specification, one of:
        - a decimal string such as '0' or '1' -> camera index (OpenCV)
        - 'cvmmap://<name>' -> cv-mmap shared memory stream
        - any other string -> video file path (OpenCV)
    max_frames : int | None, optional
        Maximum number of frames to yield. None means unlimited.

    Returns
    -------
    FrameStream
        Generator yielding (frame, metadata) tuples.
    """
    cvmmap_prefix = "cvmmap://"
    # Shared-memory streams are routed to the cv-mmap adapter.
    if source.startswith(cvmmap_prefix):
        return cvmmap_source(source.removeprefix(cvmmap_prefix), max_frames)
    # A purely numeric string selects a camera index.
    if source.isdigit():
        return opencv_source(int(source), max_frames)
    # Everything else is treated as a video file path.
    return opencv_source(source, max_frames)
|
||||||
@@ -0,0 +1,368 @@
|
|||||||
|
"""
|
||||||
|
Output publishers for OpenGait demo results.
|
||||||
|
|
||||||
|
Provides pluggable result publishing:
|
||||||
|
- ConsolePublisher: JSONL to stdout
|
||||||
|
- NatsPublisher: NATS message broker integration
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import TYPE_CHECKING, Protocol, TextIO, cast, runtime_checkable
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from types import TracebackType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
class ResultPublisher(Protocol):
    """Protocol for result publishers.

    Runtime-checkable so publishers can be validated with isinstance();
    satisfied structurally by ConsolePublisher and NatsPublisher below.
    """

    def publish(self, result: dict[str, object]) -> None:
        """
        Publish a result dictionary.

        Parameters
        ----------
        result : dict[str, object]
            Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
class ConsolePublisher:
    """Publisher that writes each result as one JSON line (JSONL) to a text stream."""

    # Destination stream; bound once at construction time.
    _output: TextIO

    def __init__(self, output: TextIO = sys.stdout) -> None:
        """
        Initialize console publisher.

        Parameters
        ----------
        output : TextIO
            File-like object to write to (default: sys.stdout)
        """
        self._output = output

    def publish(self, result: dict[str, object]) -> None:
        """
        Serialize *result* to a single JSON line and flush it.

        Non-JSON-serializable values are stringified via ``default=str``.
        Failures are logged, never raised, so publishing stays best-effort.

        Parameters
        ----------
        result : dict[str, object]
            Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
        """
        try:
            encoded = json.dumps(result, ensure_ascii=False, default=str)
            _ = self._output.write(encoded + "\n")
            self._output.flush()
        except Exception as e:
            logger.warning(f"Failed to publish to console: {e}")

    def close(self) -> None:
        """Close the publisher (nothing to release for a console sink)."""

    def __enter__(self) -> ConsolePublisher:
        """Enter the context manager, returning self."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Leave the context manager, closing the publisher."""
        self.close()
|
||||||
|
|
||||||
|
|
||||||
|
class _NatsClient(Protocol):
    """Protocol for connected NATS client.

    Mirrors the subset of the nats-py client API used by NatsPublisher
    (publish/flush/close) so this module type-checks without importing nats-py.
    """

    async def publish(self, subject: str, payload: bytes) -> object: ...

    async def close(self) -> object: ...

    async def flush(self) -> object: ...
|
||||||
|
|
||||||
|
|
||||||
|
class NatsPublisher:
    """
    Publisher that sends results to NATS message broker.

    This is a sync-friendly wrapper around the async nats-py client.
    Uses a background thread with dedicated event loop to bridge sync
    publish calls to async NATS operations, making it safe to use in
    both sync and async contexts. All failures degrade gracefully to
    warnings so the pipeline never crashes on broker trouble.
    """

    # NATS server URL, e.g. "nats://localhost:4222".
    _nats_url: str
    # Subject every result is published to.
    _subject: str
    # Connected client, or None before the first successful connect.
    _nc: _NatsClient | None
    # True while the cached client is believed to be live.
    _connected: bool
    # Private loop driven by _thread; all async calls are marshalled onto it.
    _loop: asyncio.AbstractEventLoop | None
    _thread: threading.Thread | None
    # BUGFIX: re-entrant lock. _ensure_connected() and close() hold the lock
    # while calling _start_background_loop()/_stop_background_loop(), which
    # acquire it again; a plain threading.Lock deadlocks on that nested
    # acquire the first time publish() runs.
    _lock: threading.RLock

    def __init__(self, nats_url: str, subject: str = "scoliosis.result") -> None:
        """
        Initialize NATS publisher.

        Parameters
        ----------
        nats_url : str
            NATS server URL (e.g., "nats://localhost:4222")
        subject : str
            NATS subject to publish to (default: "scoliosis.result")
        """
        self._nats_url = nats_url
        self._subject = subject
        self._nc = None
        self._connected = False
        self._loop = None
        self._thread = None
        self._lock = threading.RLock()

    def _start_background_loop(self) -> bool:
        """
        Start background thread with event loop for async operations.

        Idempotent: returns immediately when the loop is already running.

        Returns
        -------
        bool
            True if loop is running, False otherwise
        """
        with self._lock:
            if self._loop is not None and self._loop.is_running():
                return True

            try:
                loop = asyncio.new_event_loop()
                self._loop = loop
                started = threading.Event()

                def run_loop() -> None:
                    asyncio.set_event_loop(loop)
                    # Signal readiness from inside the running loop so callers
                    # can rely on loop.is_running() afterwards.
                    loop.call_soon(started.set)
                    loop.run_forever()

                self._thread = threading.Thread(target=run_loop, daemon=True)
                self._thread.start()
                # BUGFIX: wait until run_forever() is actually executing.
                # Previously this returned before the loop started, so a
                # racing second call observed is_running() == False and
                # spawned a duplicate loop/thread.
                return started.wait(timeout=5.0)
            except Exception as e:
                logger.warning(f"Failed to start background event loop: {e}")
                return False

    def _stop_background_loop(self) -> None:
        """Stop the background event loop, join its thread, and release it."""
        with self._lock:
            if self._loop is not None and self._loop.is_running():
                _ = self._loop.call_soon_threadsafe(self._loop.stop)
            if self._thread is not None and self._thread.is_alive():
                self._thread.join(timeout=2.0)
            # BUGFIX: close the loop once its thread has exited so selector
            # resources are released (previously the loop object leaked).
            if self._loop is not None and not self._loop.is_running():
                self._loop.close()
            self._loop = None
            self._thread = None

    def _ensure_connected(self) -> bool:
        """
        Ensure connection to NATS server, connecting lazily on first use.

        Returns
        -------
        bool
            True if connected, False otherwise
        """
        with self._lock:
            if self._connected and self._nc is not None:
                return True

            # Safe under the (re-entrant) lock — see _lock comment above.
            if not self._start_background_loop():
                return False

            try:
                import nats

                async def _connect() -> _NatsClient:
                    nc = await nats.connect(self._nats_url)  # pyright: ignore[reportUnknownMemberType]
                    return cast("_NatsClient", nc)

                # Run connection in background loop
                future = asyncio.run_coroutine_threadsafe(
                    _connect(),
                    self._loop,  # pyright: ignore[reportArgumentType]
                )
                self._nc = future.result(timeout=10.0)
                self._connected = True
                logger.info(f"Connected to NATS at {self._nats_url}")
                return True
            except ImportError:
                logger.warning(
                    "nats-py package not installed. Install with: pip install nats-py"
                )
                return False
            except Exception as e:
                logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}")
                return False

    def publish(self, result: dict[str, object]) -> None:
        """
        Publish result to NATS subject as UTF-8 encoded JSON.

        Drops the result (with a debug log) when NATS is unreachable, and
        marks the connection for lazy reconnection after any publish error.

        Parameters
        ----------
        result : dict[str, object]
            Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
        """
        if not self._ensure_connected():
            # Graceful degradation: log but don't crash the pipeline.
            logger.debug(
                f"NATS unavailable, dropping result: {result.get('track_id', 'unknown')}"
            )
            return

        try:

            async def _publish() -> None:
                if self._nc is not None:
                    payload = json.dumps(
                        result, ensure_ascii=False, default=str
                    ).encode("utf-8")
                    _ = await self._nc.publish(self._subject, payload)
                    _ = await self._nc.flush()

            # Run publish in background loop and wait for completion.
            future = asyncio.run_coroutine_threadsafe(
                _publish(),
                self._loop,  # pyright: ignore[reportArgumentType]
            )
            future.result(timeout=5.0)
        except Exception as e:
            logger.warning(f"Failed to publish to NATS: {e}")
            self._connected = False  # Mark for reconnection on next publish

    def close(self) -> None:
        """Close the NATS connection (if any) and stop the background loop."""
        with self._lock:
            if self._nc is not None and self._connected and self._loop is not None:
                try:

                    async def _close() -> None:
                        if self._nc is not None:
                            _ = await self._nc.close()

                    future = asyncio.run_coroutine_threadsafe(
                        _close(),
                        self._loop,
                    )
                    future.result(timeout=5.0)
                except Exception as e:
                    logger.debug(f"Error closing NATS connection: {e}")
                finally:
                    self._nc = None
                    self._connected = False

            # Safe under the re-entrant lock — see _lock comment above.
            self._stop_background_loop()

    def __enter__(self) -> NatsPublisher:
        """Context manager entry."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Context manager exit; closes the connection."""
        self.close()
|
||||||
|
|
||||||
|
|
||||||
|
def create_publisher(
    nats_url: str | None,
    subject: str = "scoliosis.result",
) -> ResultPublisher:
    """
    Factory that selects the appropriate result publisher.

    Parameters
    ----------
    nats_url : str | None
        NATS server URL. If None or empty, a ConsolePublisher is returned.
    subject : str
        NATS subject to publish to (default: "scoliosis.result")

    Returns
    -------
    ResultPublisher
        NatsPublisher when a NATS URL is given, otherwise ConsolePublisher.

    Examples
    --------
    >>> pub = create_publisher(None)  # JSONL to stdout
    >>> pub = create_publisher("nats://localhost:4222")  # NATS broker
    >>> with create_publisher("nats://localhost:4222") as pub:
    ...     pub.publish(result)
    """
    # Falsy URL (None or "") selects the console fallback.
    if not nats_url:
        return ConsolePublisher()
    return NatsPublisher(nats_url, subject)
|
||||||
|
|
||||||
|
|
||||||
|
def create_result(
|
||||||
|
frame: int,
|
||||||
|
track_id: int,
|
||||||
|
label: str,
|
||||||
|
confidence: float,
|
||||||
|
window: int | tuple[int, int],
|
||||||
|
timestamp_ns: int | None = None,
|
||||||
|
) -> dict[str, object]:
|
||||||
|
"""
|
||||||
|
Create a standardized result dictionary.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
frame : int
|
||||||
|
Frame number
|
||||||
|
track_id : int
|
||||||
|
Track/person identifier
|
||||||
|
label : str
|
||||||
|
Classification label (e.g., "normal", "scoliosis")
|
||||||
|
confidence : float
|
||||||
|
Confidence score (0.0 to 1.0)
|
||||||
|
window : int | tuple[int, int]
|
||||||
|
Frame window as int (end frame) or tuple [start, end] that produced this result
|
||||||
|
Frame window [start, end] that produced this result
|
||||||
|
timestamp_ns : int | None
|
||||||
|
Timestamp in nanoseconds. If None, uses current time.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict[str, object]
|
||||||
|
Standardized result dictionary
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"frame": frame,
|
||||||
|
"track_id": track_id,
|
||||||
|
"label": label,
|
||||||
|
"confidence": confidence,
|
||||||
|
"window": window if isinstance(window, int) else window[1],
|
||||||
|
"timestamp_ns": timestamp_ns
|
||||||
|
if timestamp_ns is not None
|
||||||
|
else time.monotonic_ns(),
|
||||||
|
}
|
||||||
@@ -0,0 +1,325 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from contextlib import suppress
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
import time
|
||||||
|
from typing import Protocol, cast
|
||||||
|
|
||||||
|
from beartype import beartype
|
||||||
|
import click
|
||||||
|
import jaxtyping
|
||||||
|
from jaxtyping import Float, UInt8
|
||||||
|
import numpy as np
|
||||||
|
from numpy import ndarray
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from ultralytics.models.yolo.model import YOLO
|
||||||
|
|
||||||
|
from .input import FrameStream, create_source
|
||||||
|
from .output import ResultPublisher, create_publisher, create_result
|
||||||
|
from .preprocess import frame_to_person_mask, mask_to_silhouette
|
||||||
|
from .sconet_demo import ScoNetDemo
|
||||||
|
from .window import SilhouetteWindow, select_person
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)

# jaxtyping.jaxtyped is opaque to the type checker; cast it to an explicit
# decorator-factory alias so @jaxtyped(typechecker=beartype) type-checks.
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
|
||||||
|
|
||||||
|
|
||||||
|
class _BoxesLike(Protocol):
    """Structural type for the ``boxes`` attribute of a YOLO result.

    ``xyxy`` holds per-detection corner coordinates; ``id`` holds tracker IDs
    and may be None when tracking has not assigned any.
    """

    @property
    def xyxy(self) -> NDArray[np.float32] | object: ...

    @property
    def id(self) -> NDArray[np.int64] | object | None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class _MasksLike(Protocol):
    """Structural type for the ``masks`` attribute of a YOLO result."""

    @property
    def data(self) -> NDArray[np.float32] | object: ...
|
||||||
|
|
||||||
|
|
||||||
|
class _DetectionResultsLike(Protocol):
    """Structural type for one ultralytics detection result.

    Declared here so the pipeline avoids depending on ultralytics' untyped
    result classes directly.
    """

    @property
    def boxes(self) -> _BoxesLike: ...

    @property
    def masks(self) -> _MasksLike: ...
|
||||||
|
|
||||||
|
|
||||||
|
class _TrackCallable(Protocol):
    """Structural type for YOLO.track() as invoked by this pipeline."""

    def __call__(
        self,
        source: object,
        *,
        persist: bool = True,
        verbose: bool = False,
        device: str | None = None,
        classes: list[int] | None = None,
    ) -> object: ...
|
||||||
|
|
||||||
|
|
||||||
|
class ScoliosisPipeline:
    """End-to-end single-person scoliosis pipeline.

    Per frame: YOLO segmentation + tracking -> person silhouette extraction ->
    fixed-size temporal window -> ScoNet classification -> result publishing.
    Construction is eager: the YOLO model and ScoNet checkpoint are loaded in
    ``__init__``.
    """

    # YOLO segmentation model; kept as object because ultralytics is untyped.
    _detector: object
    # Iterable of (frame, metadata) pairs from input.create_source().
    _source: FrameStream
    # Temporal buffer of silhouettes that decides when to classify.
    _window: SilhouetteWindow
    # Result sink (console JSONL or NATS).
    _publisher: ResultPublisher
    # ScoNet wrapper performing the actual classification.
    _classifier: ScoNetDemo
    # Torch/ultralytics device string, e.g. "cuda:0" or "cpu".
    _device: str
    # Guards close() against double invocation.
    _closed: bool

    def __init__(
        self,
        *,
        source: str,
        checkpoint: str,
        config: str,
        device: str,
        yolo_model: str,
        window: int,
        stride: int,
        nats_url: str | None,
        nats_subject: str,
        max_frames: int | None,
    ) -> None:
        """Build all pipeline stages; see module CLI options for parameters."""
        self._detector = YOLO(yolo_model)
        self._source = create_source(source, max_frames=max_frames)
        self._window = SilhouetteWindow(window_size=window, stride=stride)
        self._publisher = create_publisher(nats_url=nats_url, subject=nats_subject)
        self._classifier = ScoNetDemo(
            cfg_path=config,
            checkpoint_path=checkpoint,
            device=device,
        )
        self._device = device
        self._closed = False

    @staticmethod
    def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
        """Return meta[key] when it is an int, else *fallback*."""
        value = meta.get(key)
        if isinstance(value, int):
            return value
        return fallback

    @staticmethod
    def _extract_timestamp(meta: dict[str, object]) -> int:
        """Return the source timestamp, falling back to monotonic host time."""
        value = meta.get("timestamp_ns")
        if isinstance(value, int):
            return value
        return time.monotonic_ns()

    @staticmethod
    def _to_mask_u8(mask: ndarray) -> UInt8[ndarray, "h w"]:
        """Binarize a float mask at 0.5 into a 0/255 uint8 mask."""
        binary = np.where(np.asarray(mask) > 0.5, np.uint8(255), np.uint8(0)).astype(
            np.uint8
        )
        return cast(UInt8[ndarray, "h w"], binary)

    def _first_result(self, detections: object) -> _DetectionResultsLike | None:
        """Unwrap YOLO's list/tuple return shape to its first result (or None)."""
        if isinstance(detections, list):
            return cast(_DetectionResultsLike, detections[0]) if detections else None
        if isinstance(detections, tuple):
            return cast(_DetectionResultsLike, detections[0]) if detections else None
        return cast(_DetectionResultsLike, detections)

    def _select_silhouette(
        self,
        result: _DetectionResultsLike,
    ) -> tuple[Float[ndarray, "64 44"], int] | None:
        """Extract a normalized 64x44 silhouette and a track id from *result*.

        Tries the tracked person from select_person() first; falls back to
        frame_to_person_mask() with track id 0 when tracking yields nothing
        usable. Returns None when no silhouette can be produced.
        """
        selected = select_person(result)
        if selected is not None:
            # select_person() yields (raw mask, bbox, track id) per this unpack.
            mask_raw, bbox, track_id = selected
            silhouette = cast(
                Float[ndarray, "64 44"] | None,
                mask_to_silhouette(self._to_mask_u8(mask_raw), bbox),
            )
            if silhouette is not None:
                return silhouette, int(track_id)

        # Tracking produced nothing usable — fall back to any person mask.
        fallback = cast(
            tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None,
            frame_to_person_mask(result),
        )
        if fallback is None:
            return None

        mask_u8, bbox = fallback
        silhouette = cast(
            Float[ndarray, "64 44"] | None,
            mask_to_silhouette(mask_u8, bbox),
        )
        if silhouette is None:
            return None
        # Track id 0 marks the untracked fallback path.
        return silhouette, 0

    @jaxtyped(typechecker=beartype)
    def process_frame(
        self,
        frame: UInt8[ndarray, "h w c"],
        metadata: dict[str, object],
    ) -> dict[str, object] | None:
        """Run detection on one frame; classify and publish when the window fills.

        Returns the published result dict, or None when the frame did not
        trigger a classification (no person, or window not yet due).
        """
        frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
        timestamp_ns = self._extract_timestamp(metadata)

        # ultralytics is untyped; verify track() exists before casting.
        track_fn_obj = getattr(self._detector, "track", None)
        if not callable(track_fn_obj):
            raise RuntimeError("YOLO detector does not expose a callable track()")

        track_fn = cast(_TrackCallable, track_fn_obj)
        # classes=[0] restricts detection to the COCO "person" class.
        detections = track_fn(
            frame,
            persist=True,
            verbose=False,
            device=self._device,
            classes=[0],
        )
        first = self._first_result(detections)
        if first is None:
            return None

        selected = self._select_silhouette(first)
        if selected is None:
            return None

        silhouette, track_id = selected
        self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)

        # The window decides (via its stride) when a classification is due.
        if not self._window.should_classify():
            return None

        window_tensor = self._window.get_tensor(device=self._device)
        label, confidence = cast(
            tuple[str, float],
            self._classifier.predict(window_tensor),
        )
        self._window.mark_classified()

        # Inclusive window start, clamped at frame 0 for the warm-up period.
        window_start = frame_idx - self._window.window_size + 1
        result = create_result(
            frame=frame_idx,
            track_id=track_id,
            label=label,
            confidence=float(confidence),
            window=(max(0, window_start), frame_idx),
            timestamp_ns=timestamp_ns,
        )
        self._publisher.publish(result)
        return result

    def run(self) -> int:
        """Drain the source, processing every frame; returns a process exit code.

        Per-frame errors are logged and skipped; Ctrl-C returns 130 (SIGINT
        convention). The publisher is always closed on the way out.
        """
        frame_count = 0
        start_time = time.perf_counter()
        try:
            for item in self._source:
                frame, metadata = item
                # Normalize whatever the source yields to uint8 for YOLO.
                frame_u8 = np.asarray(frame, dtype=np.uint8)
                frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
                frame_count += 1
                try:
                    _ = self.process_frame(frame_u8, metadata)
                except Exception as frame_error:
                    # A bad frame must not kill the stream; log and continue.
                    logger.warning(
                        "Skipping frame %d due to processing error: %s",
                        frame_idx,
                        frame_error,
                    )
                # Periodic throughput report.
                if frame_count % 100 == 0:
                    elapsed = time.perf_counter() - start_time
                    fps = frame_count / elapsed if elapsed > 0 else 0.0
                    logger.info("Processed %d frames (%.2f FPS)", frame_count, fps)
            return 0
        except KeyboardInterrupt:
            logger.info("Interrupted by user, shutting down cleanly.")
            return 130
        finally:
            self.close()

    def close(self) -> None:
        """Release the publisher; idempotent and exception-safe."""
        if self._closed:
            return
        close_fn = getattr(self._publisher, "close", None)
        if callable(close_fn):
            with suppress(Exception):
                _ = close_fn()
        self._closed = True
|
||||||
|
|
||||||
|
|
||||||
|
def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
    """Validate CLI inputs before constructing the pipeline.

    Parameters
    ----------
    source : str
        Video source spec. File paths must exist on disk; camera indices
        (pure digits) and ``cvmmap://`` streams have no path to check.
    checkpoint : str
        Path to the ScoNet checkpoint file; must exist.
    config : str
        Path to the ScoNet YAML config; must exist.

    Raises
    ------
    ValueError
        If any file-backed input does not exist.
    """
    # Replaces the previous `if cond: pass / else: ...` anti-pattern with a
    # direct guard; behavior is unchanged.
    is_non_file_source = source.startswith("cvmmap://") or source.isdigit()
    if not is_non_file_source and not Path(source).is_file():
        raise ValueError(f"Video source not found: {source}")

    if not Path(checkpoint).is_file():
        raise ValueError(f"Checkpoint not found: {checkpoint}")

    if not Path(config).is_file():
        raise ValueError(f"Config not found: {config}")
|
||||||
|
|
||||||
|
|
||||||
|
# CLI entry point. NOTE: no docstring on main() on purpose — click would
# surface it as the command help text, which would change program output.
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option("--source", type=str, required=True)
@click.option("--checkpoint", type=str, required=True)
@click.option(
    "--config",
    type=str,
    default="configs/sconet/sconet_scoliosis1k.yaml",
    show_default=True,
)
@click.option("--device", type=str, default="cuda:0", show_default=True)
@click.option("--yolo-model", type=str, default="yolo11n-seg.pt", show_default=True)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--nats-url", type=str, default=None)
@click.option(
    "--nats-subject",
    type=str,
    default="scoliosis.result",
    show_default=True,
)
@click.option("--max-frames", type=click.IntRange(min=1), default=None)
def main(
    source: str,
    checkpoint: str,
    config: str,
    device: str,
    yolo_model: str,
    window: int,
    stride: int,
    nats_url: str | None,
    nats_subject: str,
    max_frames: int | None,
) -> None:
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    try:
        # Fail fast on missing files before loading any models.
        validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
        pipeline = ScoliosisPipeline(
            source=source,
            checkpoint=checkpoint,
            config=config,
            device=device,
            yolo_model=yolo_model,
            window=window,
            stride=stride,
            nats_url=nats_url,
            nats_subject=nats_subject,
            max_frames=max_frames,
        )
        # SystemExit is a BaseException, so it bypasses the handlers below;
        # run()'s return value becomes the process exit code.
        raise SystemExit(pipeline.run())
    except ValueError as err:
        # Input validation failures -> exit code 2 (usage-style error).
        click.echo(f"Error: {err}", err=True)
        raise SystemExit(2) from err
    except RuntimeError as err:
        # Pipeline-level failures (e.g. unopenable source) -> exit code 1.
        click.echo(f"Runtime error: {err}", err=True)
        raise SystemExit(1) from err
|
||||||
@@ -0,0 +1,270 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
import math
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
from beartype import beartype
|
||||||
|
import jaxtyping
|
||||||
|
from jaxtyping import Float, UInt8
|
||||||
|
import numpy as np
|
||||||
|
from numpy import ndarray
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
# Output silhouette geometry: a mask crop is resized to height 64, normalized
# to a 64-column canvas, then SIDE_CUT columns are trimmed from each side,
# leaving the 64x44 model input (64 - 2*10 = 44).
SIL_HEIGHT = 64
SIL_WIDTH = 44
SIL_FULL_WIDTH = 64
SIDE_CUT = 10
# Minimum number of nonzero mask pixels for a detection to be usable.
MIN_MASK_AREA = 500


# jaxtyping.jaxtyped is untyped from the checker's perspective; cast it to a
# decorator-factory alias so `@jaxtyped(typechecker=beartype)` type-checks.
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)

# Convenience dtype-specific ndarray aliases.
UInt8Array = NDArray[np.uint8]
Float32Array = NDArray[np.float32]
|
||||||
|
|
||||||
|
|
||||||
|
def _read_attr(container: object, key: str) -> object | None:
|
||||||
|
if isinstance(container, dict):
|
||||||
|
dict_obj = cast(dict[object, object], container)
|
||||||
|
return dict_obj.get(key)
|
||||||
|
try:
|
||||||
|
return cast(object, object.__getattribute__(container, key))
|
||||||
|
except AttributeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _to_numpy_array(value: object) -> NDArray[np.generic]:
    """Coerce *value* (possibly a torch-like tensor) into a numpy array.

    Walks the conventional tensor conversion chain — ``detach()``, then
    ``cpu()``, then ``numpy()`` — invoking each hop only when the current
    object exposes a callable of that name, and finally falls back to
    ``np.asarray``.
    """
    obj: object = value
    if isinstance(obj, np.ndarray):
        return obj

    # Detach from any autograd graph and move off-device, when supported.
    for hop in ("detach", "cpu"):
        candidate = _read_attr(obj, hop)
        if callable(candidate):
            obj = cast(Callable[[], object], candidate)()

    to_numpy = _read_attr(obj, "numpy")
    if callable(to_numpy):
        converted = cast(Callable[[], object], to_numpy)()
        if isinstance(converted, np.ndarray):
            return converted

    return cast(NDArray[np.generic], np.asarray(obj))
|
||||||
|
|
||||||
|
|
||||||
|
def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> tuple[int, int, int, int] | None:
    """Return the tight (x1, y1, x2, y2) box around nonzero mask pixels.

    The right/bottom edges are exclusive. Returns None for an all-zero mask
    or a degenerate box.
    """
    binary = np.asarray(mask, dtype=np.uint8) > 0
    rows, cols = np.nonzero(binary)
    if rows.size == 0:
        return None

    y1, y2 = int(rows.min()), int(rows.max()) + 1
    x1, x2 = int(cols.min()), int(cols.max()) + 1
    if x2 <= x1 or y2 <= y1:
        return None
    return (x1, y1, x2, y2)
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_bbox(
|
||||||
|
bbox: tuple[int, int, int, int], height: int, width: int
|
||||||
|
) -> tuple[int, int, int, int] | None:
|
||||||
|
x1, y1, x2, y2 = bbox
|
||||||
|
x1c = max(0, min(int(x1), width - 1))
|
||||||
|
y1c = max(0, min(int(y1), height - 1))
|
||||||
|
x2c = max(0, min(int(x2), width))
|
||||||
|
y2c = max(0, min(int(y2), height))
|
||||||
|
if x2c <= x1c or y2c <= y1c:
|
||||||
|
return None
|
||||||
|
return (x1c, y1c, x2c, y2c)
|
||||||
|
|
||||||
|
|
||||||
|
@jaxtyped(typechecker=beartype)
def frame_to_person_mask(
    result: object, min_area: int = MIN_MASK_AREA
) -> tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None:
    """Pick the largest-area person mask from a detection result.

    Args:
        result: Detection output exposing ``masks.data`` (N, H, W floats) and
            optionally ``boxes.xyxy``; everything is accessed duck-typed via
            ``_read_attr``, so both attribute-style results (e.g. Ultralytics)
            and plain dicts work.
        min_area: Minimum nonzero-pixel count for a mask to be considered.

    Returns:
        ``(mask, bbox)`` where ``mask`` is a binarized uint8 (0/255) array and
        ``bbox`` is an integer ``(x1, y1, x2, y2)`` box, or ``None`` when no
        usable detection is found.
    """
    masks_obj = _read_attr(result, "masks")
    if masks_obj is None:
        return None

    masks_data_obj = _read_attr(masks_obj, "data")
    if masks_data_obj is None:
        return None

    # Normalize mask data to a float32 stack of shape (N, H, W).
    masks_raw = _to_numpy_array(masks_data_obj)
    masks_float = np.asarray(masks_raw, dtype=np.float32)
    if masks_float.ndim == 2:
        masks_float = masks_float[np.newaxis, ...]
    if masks_float.ndim != 3:
        return None
    mask_count = int(cast(tuple[int, int, int], masks_float.shape)[0])
    if mask_count <= 0:
        return None

    # Collect per-detection xyxy boxes, if present, as float 4-tuples.
    # A 1-D array is treated as a single box; a 2-D array as one box per row.
    box_values: list[tuple[float, float, float, float]] | None = None
    boxes_obj = _read_attr(result, "boxes")
    if boxes_obj is not None:
        xyxy_obj = _read_attr(boxes_obj, "xyxy")
        if xyxy_obj is not None:
            xyxy_raw = np.asarray(_to_numpy_array(xyxy_obj), dtype=np.float32)
            if xyxy_raw.ndim == 1 and int(xyxy_raw.size) >= 4:
                xyxy_2d = np.asarray(xyxy_raw[:4].reshape(1, 4), dtype=np.float64)
                x1f = cast(np.float64, xyxy_2d[0, 0])
                y1f = cast(np.float64, xyxy_2d[0, 1])
                x2f = cast(np.float64, xyxy_2d[0, 2])
                y2f = cast(np.float64, xyxy_2d[0, 3])
                box_values = [
                    (
                        float(x1f),
                        float(y1f),
                        float(x2f),
                        float(y2f),
                    )
                ]
            elif xyxy_raw.ndim == 2:
                shape_2d = cast(tuple[int, int], xyxy_raw.shape)
                if int(shape_2d[1]) >= 4:
                    xyxy_2d = np.asarray(xyxy_raw[:, :4], dtype=np.float64)
                    box_values = []
                    for row_idx in range(int(cast(tuple[int, int], xyxy_2d.shape)[0])):
                        x1f = cast(np.float64, xyxy_2d[row_idx, 0])
                        y1f = cast(np.float64, xyxy_2d[row_idx, 1])
                        x2f = cast(np.float64, xyxy_2d[row_idx, 2])
                        y2f = cast(np.float64, xyxy_2d[row_idx, 3])
                        box_values.append(
                            (
                                float(x1f),
                                float(y1f),
                                float(x2f),
                                float(y2f),
                            )
                        )

    # Scan all candidate masks and keep the one with the largest pixel area.
    best_area = -1
    best_mask: UInt8[ndarray, "h w"] | None = None
    best_bbox: tuple[int, int, int, int] | None = None

    for idx in range(mask_count):
        mask_float = np.asarray(masks_float[idx], dtype=np.float32)
        if mask_float.ndim != 2:
            continue
        # Binarize at 0.5 to a 0/255 uint8 mask.
        mask_binary = np.where(mask_float > 0.5, np.uint8(255), np.uint8(0)).astype(
            np.uint8
        )
        mask_u8 = cast(UInt8[ndarray, "h w"], mask_binary)

        area = int(np.count_nonzero(mask_u8))
        if area < min_area:
            continue

        bbox: tuple[int, int, int, int] | None = None
        shape_2d = cast(tuple[int, int], mask_binary.shape)
        h = int(shape_2d[0])
        w = int(shape_2d[1])
        if box_values is not None:
            box_count = len(box_values)
            # NOTE(review): when boxes exist but fewer boxes than masks were
            # provided, the extra masks are skipped rather than falling back
            # to the mask-derived bbox — confirm this is intended.
            if idx >= box_count:
                continue
            row0, row1, row2, row3 = box_values[idx]
            # Expand the float box outward (floor mins / ceil maxes) before
            # clamping it to the mask bounds.
            bbox_candidate = (
                int(math.floor(row0)),
                int(math.floor(row1)),
                int(math.ceil(row2)),
                int(math.ceil(row3)),
            )
            bbox = _sanitize_bbox(bbox_candidate, h, w)

        # Fall back to the tight bbox derived from the mask itself.
        if bbox is None:
            bbox = _bbox_from_mask(mask_u8)

        if bbox is None:
            continue

        if area > best_area:
            best_area = area
            best_mask = mask_u8
            best_bbox = bbox

    if best_mask is None or best_bbox is None:
        return None

    return best_mask, best_bbox
|
||||||
|
|
||||||
|
|
||||||
|
@jaxtyped(typechecker=beartype)
def mask_to_silhouette(
    mask: UInt8[ndarray, "h w"],
    bbox: tuple[int, int, int, int],
) -> Float[ndarray, "64 44"] | None:
    """Convert a binary person mask and its bbox into a normalized silhouette.

    Pipeline: binarize -> crop to the sanitized bbox -> drop empty top/bottom
    rows -> resize to height SIL_HEIGHT preserving aspect ratio -> center-crop
    or zero-pad to width SIL_FULL_WIDTH -> cut SIDE_CUT columns from each side
    -> scale to [0, 1] float32 of shape (SIL_HEIGHT, SIL_WIDTH).

    Args:
        mask: Grayscale mask; any nonzero pixel counts as foreground.
        bbox: Candidate ``(x1, y1, x2, y2)`` crop region (exclusive right/bottom).

    Returns:
        A (64, 44) float32 silhouette in [0, 1], or ``None`` whenever any
        stage degenerates (mask too small, empty crop, unexpected final shape).
    """
    mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
    if int(np.count_nonzero(mask_u8)) < MIN_MASK_AREA:
        return None

    mask_shape = cast(tuple[int, int], mask_u8.shape)
    h = int(mask_shape[0])
    w = int(mask_shape[1])
    bbox_sanitized = _sanitize_bbox(bbox, h, w)
    if bbox_sanitized is None:
        return None

    x1, y1, x2, y2 = bbox_sanitized
    cropped = mask_u8[y1:y2, x1:x2]
    if cropped.size == 0:
        return None

    # Tighten vertically: keep only the span of rows that contain foreground.
    cropped_u8 = np.asarray(cropped, dtype=np.uint8)
    row_sums = np.sum(cropped_u8, axis=1, dtype=np.int64)
    row_nonzero = np.nonzero(row_sums > 0)[0].astype(np.int64)
    if int(row_nonzero.size) == 0:
        return None
    top = int(cast(np.int64, row_nonzero[0]))
    bottom = int(cast(np.int64, row_nonzero[-1])) + 1
    tightened = cropped[top:bottom, :]
    if tightened.size == 0:
        return None

    tight_shape = cast(tuple[int, int], tightened.shape)
    tight_h = int(tight_shape[0])
    tight_w = int(tight_shape[1])
    if tight_h <= 0 or tight_w <= 0:
        return None

    # Resize to the canonical height while preserving the aspect ratio;
    # cv2.resize takes (width, height).
    resized_w = max(1, int(SIL_HEIGHT * (tight_w / tight_h)))
    resized = np.asarray(
        cv2.resize(tightened, (resized_w, SIL_HEIGHT), interpolation=cv2.INTER_CUBIC),
        dtype=np.uint8,
    )

    # Normalize the width to SIL_FULL_WIDTH: center-crop when wider,
    # symmetric zero-pad when narrower.
    if resized_w >= SIL_FULL_WIDTH:
        start = (resized_w - SIL_FULL_WIDTH) // 2
        normalized_64 = resized[:, start : start + SIL_FULL_WIDTH]
    else:
        pad_left = (SIL_FULL_WIDTH - resized_w) // 2
        pad_right = SIL_FULL_WIDTH - resized_w - pad_left
        normalized_64 = np.pad(
            resized,
            ((0, 0), (pad_left, pad_right)),
            mode="constant",
            constant_values=0,
        )

    # Trim SIDE_CUT columns from each side to reach the model's 64x44 input.
    silhouette = np.asarray(
        normalized_64[:, SIDE_CUT : SIL_FULL_WIDTH - SIDE_CUT], dtype=np.float32
    )
    if silhouette.shape != (SIL_HEIGHT, SIL_WIDTH):
        return None

    # Rescale 0/255 values into [0, 1].
    silhouette_norm = np.clip(silhouette / np.float32(255.0), 0.0, 1.0).astype(
        np.float32
    )
    return cast(Float[ndarray, "64 44"], silhouette_norm)
|
||||||
@@ -0,0 +1,317 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
from typing import ClassVar, Protocol, cast, override
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from beartype import beartype
|
||||||
|
from einops import rearrange
|
||||||
|
from jaxtyping import Float
|
||||||
|
import jaxtyping
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
# Make the opengait package importable regardless of the working directory:
# prepend the package root (one directory above this file) to sys.path once.
_OPENGAIT_PACKAGE_ROOT = Path(__file__).resolve().parents[1]
if str(_OPENGAIT_PACKAGE_ROOT) not in sys.path:
    sys.path.insert(0, str(_OPENGAIT_PACKAGE_ROOT))
|
||||||
|
|
||||||
|
from opengait.modeling.backbones.resnet import ResNet9
|
||||||
|
from opengait.modeling.modules import (
|
||||||
|
HorizontalPoolingPyramid,
|
||||||
|
PackSequenceWrapper as TemporalPool,
|
||||||
|
SeparateBNNecks,
|
||||||
|
SeparateFCs,
|
||||||
|
)
|
||||||
|
from opengait.utils import common as common_utils
|
||||||
|
|
||||||
|
|
||||||
|
# jaxtyping.jaxtyped and opengait's config_loader are untyped; cast them to
# precise callable aliases so decorated/called sites type-check cleanly.
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
ConfigLoader = Callable[[str], dict[str, object]]
config_loader = cast(ConfigLoader, common_utils.config_loader)
|
||||||
|
|
||||||
|
|
||||||
|
class TemporalPoolLike(Protocol):
    """Structural type for the temporal pooling wrapper (PackSequenceWrapper).

    ``seqs`` is pooled along ``dim``; ``seqL`` carries per-sequence lengths
    (None for fixed-length batches). The return type is ``object`` because
    the wrapped reducer decides the exact container.
    """

    def __call__(
        self,
        seqs: Tensor,
        seqL: object,
        dim: int = 2,
        options: dict[str, int] | None = None,
    ) -> object: ...
|
||||||
|
|
||||||
|
|
||||||
|
class HppLike(Protocol):
    """Structural type for HorizontalPoolingPyramid: Tensor in, Tensor out."""

    def __call__(self, x: Tensor) -> Tensor: ...
|
||||||
|
|
||||||
|
|
||||||
|
class FCsLike(Protocol):
    """Structural type for SeparateFCs: Tensor in, Tensor out."""

    def __call__(self, x: Tensor) -> Tensor: ...
|
||||||
|
|
||||||
|
|
||||||
|
class BNNecksLike(Protocol):
    """Structural type for SeparateBNNecks.

    Returns a pair of tensors — presumably (feature embedding, class logits);
    confirm against opengait.modeling.modules.SeparateBNNecks.
    """

    def __call__(self, x: Tensor) -> tuple[Tensor, Tensor]: ...
|
||||||
|
|
||||||
|
|
||||||
|
class ScoNetDemo(nn.Module):
    """Inference-only wrapper around the ScoNet architecture for the demo.

    Rebuilds the network (ResNet9 backbone -> temporal max pooling ->
    horizontal pooling pyramid -> SeparateFCs -> SeparateBNNecks) from an
    OpenGait YAML config, optionally loads and key-normalizes a training
    checkpoint, and exposes ``forward``/``predict`` for classification.
    """

    # Class-id -> human-readable label for predict().
    LABEL_MAP: ClassVar[dict[int, str]] = {0: "negative", 1: "neutral", 2: "positive"}
    cfg_path: str  # resolved config path actually loaded
    cfg: dict[str, object]  # parsed YAML config
    backbone: ResNet9
    temporal_pool: TemporalPoolLike
    hpp: HppLike
    fcs: FCsLike
    bn_necks: BNNecksLike
    device: torch.device

    @jaxtyped(typechecker=beartype)
    def __init__(
        self,
        cfg_path: str | Path = "configs/sconet/sconet_scoliosis1k.yaml",
        checkpoint_path: str | Path | None = None,
        device: str | torch.device | None = None,
    ) -> None:
        """Build the model from *cfg_path*, optionally loading a checkpoint.

        Args:
            cfg_path: OpenGait YAML config; relative paths are resolved
                against the repository root.
            checkpoint_path: Optional checkpoint to load after construction.
            device: Target device; defaults to CPU when None.

        Raises:
            ValueError: If the config's backbone type is not ResNet9.
            TypeError: If a required config entry is missing or mistyped.
        """
        super().__init__()
        resolved_cfg = self._resolve_path(cfg_path)
        self.cfg_path = str(resolved_cfg)

        self.cfg = config_loader(self.cfg_path)

        model_cfg = self._extract_model_cfg(self.cfg)
        backbone_cfg = self._extract_dict(model_cfg, "backbone_cfg")

        if backbone_cfg.get("type") != "ResNet9":
            raise ValueError(
                "ScoNetDemo currently supports backbone type ResNet9 only."
            )

        self.backbone = ResNet9(
            block=self._extract_str(backbone_cfg, "block"),
            channels=self._extract_int_list(backbone_cfg, "channels"),
            in_channel=self._extract_int(backbone_cfg, "in_channel", default=1),
            layers=self._extract_int_list(backbone_cfg, "layers"),
            strides=self._extract_int_list(backbone_cfg, "strides"),
            maxpool=self._extract_bool(backbone_cfg, "maxpool", default=True),
        )

        fcs_cfg = self._extract_dict(model_cfg, "SeparateFCs")
        bn_cfg = self._extract_dict(model_cfg, "SeparateBNNecks")
        bin_num = self._extract_int_list(model_cfg, "bin_num")

        # Pool over the temporal axis with max; HPP pools spatially into bins.
        self.temporal_pool = cast(TemporalPoolLike, TemporalPool(torch.max))
        self.hpp = cast(HppLike, HorizontalPoolingPyramid(bin_num=bin_num))
        self.fcs = cast(
            FCsLike,
            SeparateFCs(
                parts_num=self._extract_int(fcs_cfg, "parts_num"),
                in_channels=self._extract_int(fcs_cfg, "in_channels"),
                out_channels=self._extract_int(fcs_cfg, "out_channels"),
                norm=self._extract_bool(fcs_cfg, "norm", default=False),
            ),
        )
        self.bn_necks = cast(
            BNNecksLike,
            SeparateBNNecks(
                parts_num=self._extract_int(bn_cfg, "parts_num"),
                in_channels=self._extract_int(bn_cfg, "in_channels"),
                class_num=self._extract_int(bn_cfg, "class_num"),
                norm=self._extract_bool(bn_cfg, "norm", default=True),
                parallel_BN1d=self._extract_bool(bn_cfg, "parallel_BN1d", default=True),
            ),
        )

        self.device = (
            torch.device(device) if device is not None else torch.device("cpu")
        )
        _ = self.to(self.device)

        if checkpoint_path is not None:
            _ = self.load_checkpoint(checkpoint_path)

        # Demo model is inference-only; lock eval mode.
        _ = self.eval()

    @staticmethod
    def _resolve_path(path: str | Path) -> Path:
        """Resolve a path: existing files and absolute paths pass through;
        other relative paths are anchored at the repo root (two levels up)."""
        candidate = Path(path)
        if candidate.is_file():
            return candidate
        if candidate.is_absolute():
            return candidate
        repo_root = Path(__file__).resolve().parents[2]
        return repo_root / candidate

    @staticmethod
    def _extract_model_cfg(cfg: dict[str, object]) -> dict[str, object]:
        """Return cfg['model_cfg'], insisting it is a dict."""
        model_cfg_obj = cfg.get("model_cfg")
        if not isinstance(model_cfg_obj, dict):
            raise TypeError("model_cfg must be a dictionary.")
        return cast(dict[str, object], model_cfg_obj)

    @staticmethod
    def _extract_dict(container: dict[str, object], key: str) -> dict[str, object]:
        """Return container[key] as a dict, raising TypeError otherwise."""
        value = container.get(key)
        if not isinstance(value, dict):
            raise TypeError(f"{key} must be a dictionary.")
        return cast(dict[str, object], value)

    @staticmethod
    def _extract_str(container: dict[str, object], key: str) -> str:
        """Return container[key] as a str, raising TypeError otherwise."""
        value = container.get(key)
        if not isinstance(value, str):
            raise TypeError(f"{key} must be a string.")
        return value

    @staticmethod
    def _extract_int(
        container: dict[str, object], key: str, default: int | None = None
    ) -> int:
        """Return container.get(key, default) as an int, raising otherwise.

        NOTE(review): isinstance(True, int) is True, so a bool value would
        pass this check — confirm configs never use booleans here.
        """
        value = container.get(key, default)
        if not isinstance(value, int):
            raise TypeError(f"{key} must be an int.")
        return value

    @staticmethod
    def _extract_bool(
        container: dict[str, object], key: str, default: bool | None = None
    ) -> bool:
        """Return container.get(key, default) as a bool, raising otherwise."""
        value = container.get(key, default)
        if not isinstance(value, bool):
            raise TypeError(f"{key} must be a bool.")
        return value

    @staticmethod
    def _extract_int_list(container: dict[str, object], key: str) -> list[int]:
        """Return container[key] as a list of ints, raising otherwise."""
        value = container.get(key)
        if not isinstance(value, list):
            raise TypeError(f"{key} must be a list[int].")
        values = cast(list[object], value)
        if not all(isinstance(v, int) for v in values):
            raise TypeError(f"{key} must be a list[int].")
        return cast(list[int], values)

    @staticmethod
    def _normalize_state_dict(
        state_dict_obj: dict[object, object],
    ) -> dict[str, Tensor]:
        """Rename checkpoint keys to this module's attribute layout.

        Strips a leading ``module.`` (DataParallel artifact) and remaps the
        OpenGait training prefixes to the demo attribute names. Raises on
        non-str keys, non-Tensor values, or a post-remap key collision.
        """
        prefix_remap: tuple[tuple[str, str], ...] = (
            ("Backbone.forward_block.", "backbone."),
            ("FCs.", "fcs."),
            ("BNNecks.", "bn_necks."),
        )
        cleaned_state_dict: dict[str, Tensor] = {}
        for key_obj, value_obj in state_dict_obj.items():
            if not isinstance(key_obj, str):
                raise TypeError("Checkpoint state_dict keys must be strings.")
            if not isinstance(value_obj, Tensor):
                raise TypeError("Checkpoint state_dict values must be torch.Tensor.")
            key = key_obj[7:] if key_obj.startswith("module.") else key_obj
            for source_prefix, target_prefix in prefix_remap:
                if key.startswith(source_prefix):
                    key = f"{target_prefix}{key[len(source_prefix) :]}"
                    break
            if key in cleaned_state_dict:
                raise RuntimeError(
                    f"Checkpoint key normalization collision detected for key '{key}'."
                )
            cleaned_state_dict[key] = value_obj
        return cleaned_state_dict

    @jaxtyped(typechecker=beartype)
    def load_checkpoint(
        self,
        checkpoint_path: str | Path,
        map_location: str | torch.device | None = None,
        strict: bool = True,
    ) -> None:
        """Load a (possibly wrapped) checkpoint into this module.

        Accepts either a raw state dict or a dict with a ``model`` entry;
        keys are normalized via ``_normalize_state_dict`` before loading.

        Args:
            checkpoint_path: Checkpoint file; resolved like config paths.
            map_location: Device override for torch.load (defaults to
                ``self.device``).
            strict: Passed to ``load_state_dict``.

        Raises:
            TypeError: On an unrecognized checkpoint structure.
            RuntimeError: When loading fails even after key normalization.
        """
        resolved_ckpt = self._resolve_path(checkpoint_path)
        checkpoint_obj = cast(
            object,
            torch.load(
                str(resolved_ckpt),
                map_location=map_location if map_location is not None else self.device,
            ),
        )

        # Unwrap OpenGait-style checkpoints that nest the weights under "model".
        state_dict_obj: object = checkpoint_obj
        if isinstance(checkpoint_obj, dict) and "model" in checkpoint_obj:
            state_dict_obj = cast(dict[str, object], checkpoint_obj)["model"]

        if not isinstance(state_dict_obj, dict):
            raise TypeError("Unsupported checkpoint format.")

        cleaned_state_dict = self._normalize_state_dict(
            cast(dict[object, object], state_dict_obj)
        )
        try:
            _ = self.load_state_dict(cleaned_state_dict, strict=strict)
        except RuntimeError as exc:
            raise RuntimeError(
                f"Failed to load ScoNetDemo checkpoint after key normalization from '{resolved_ckpt}'."
            ) from exc
        _ = self.eval()

    def _prepare_sils(self, sils: Tensor) -> Tensor:
        """Normalize input silhouettes to float [B, 1, S, H, W] on self.device.

        Accepts [B, S, H, W] (channel inserted) or [B, S, 1, H, W]
        (channel moved in front of the sequence axis); anything else raises.
        """
        if sils.ndim == 4:
            sils = sils.unsqueeze(1)
        elif sils.ndim == 5 and sils.shape[1] != 1 and sils.shape[2] == 1:
            sils = rearrange(sils, "b s c h w -> b c s h w")

        if sils.ndim != 5 or sils.shape[1] != 1:
            raise ValueError("Expected sils shape [B, 1, S, H, W] or [B, S, H, W].")

        return sils.float().to(self.device)

    def _forward_backbone(self, sils: Tensor) -> Tensor:
        """Run the backbone per frame and restore the sequence layout.

        Folds the sequence axis into the batch so the 2-D backbone sees
        [B*S, C, H, W], then unfolds back to [B, C', S, H', W'].
        """
        batch, channels, seq, height, width = sils.shape
        framewise = sils.transpose(1, 2).reshape(batch * seq, channels, height, width)
        frame_feats = cast(Tensor, self.backbone(framewise))
        _, out_channels, out_h, out_w = frame_feats.shape
        return (
            frame_feats.reshape(batch, seq, out_channels, out_h, out_w)
            .transpose(1, 2)
            .contiguous()
        )

    @override
    @jaxtyped(typechecker=beartype)
    def forward(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> dict[str, Tensor]:
        """Classify a batch of silhouette sequences.

        Returns a dict with per-part ``logits``, argmax ``label`` ids, and
        the softmax ``confidence`` of the chosen label. Runs entirely under
        ``torch.inference_mode``.
        """
        with torch.inference_mode():
            prepared_sils = self._prepare_sils(sils)
            outs = self._forward_backbone(prepared_sils)

            # TemporalPool (max over the sequence axis) returns a tuple whose
            # first element is the pooled tensor; validate before unpacking.
            pooled_obj = self.temporal_pool(outs, None, options={"dim": 2})
            if (
                not isinstance(pooled_obj, tuple)
                or not pooled_obj
                or not isinstance(pooled_obj[0], Tensor)
            ):
                raise TypeError("TemporalPool output is invalid.")
            pooled = pooled_obj[0]

            feat = self.hpp(pooled)
            embed_1 = self.fcs(feat)
            _, logits = self.bn_necks(embed_1)

            # Average part-wise logits (last axis), then pick the best class
            # and its softmax probability as the confidence.
            mean_logits = logits.mean(dim=-1)
            pred_ids = torch.argmax(mean_logits, dim=-1)
            probs = torch.softmax(mean_logits, dim=-1)
            confidence = torch.gather(
                probs, dim=-1, index=pred_ids.unsqueeze(-1)
            ).squeeze(-1)

            return {"logits": logits, "label": pred_ids, "confidence": confidence}

    @jaxtyped(typechecker=beartype)
    def predict(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> tuple[str, float]:
        """Classify a single sequence and return (label_name, confidence).

        Raises:
            ValueError: If the batch does not contain exactly one sequence.
        """
        outputs = cast(dict[str, Tensor], self.forward(sils))
        labels = outputs["label"]
        confidence = outputs["confidence"]

        if labels.numel() != 1:
            raise ValueError("predict expects batch size 1.")

        label_id = int(labels.item())
        return self.LABEL_MAP[label_id], float(confidence.item())
|
||||||
@@ -0,0 +1,295 @@
|
|||||||
|
"""Sliding window / ring buffer manager for real-time gait analysis.
|
||||||
|
|
||||||
|
This module provides bounded buffer management for silhouette sequences
|
||||||
|
with track ID tracking and gap detection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
from typing import TYPE_CHECKING, Protocol, final
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from jaxtyping import Float
|
||||||
|
from numpy import ndarray
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
|
||||||
|
# Silhouette dimensions from preprocess.py
SIL_HEIGHT: int = 64  # rows of a normalized silhouette
SIL_WIDTH: int = 44  # columns after side trimming
|
||||||
|
|
||||||
|
|
||||||
|
class _Boxes(Protocol):
    """Protocol for boxes with xyxy and id attributes.

    Mirrors an Ultralytics-style ``Boxes`` container: ``xyxy`` holds [N, 4]
    corner coordinates and ``id`` holds optional per-box track IDs.
    """

    @property
    def xyxy(self) -> "NDArray[np.float32] | object": ...
    @property
    def id(self) -> "NDArray[np.int64] | object | None": ...
|
||||||
|
|
||||||
|
|
||||||
|
class _Masks(Protocol):
    """Protocol for masks with data attribute.

    ``data`` is expected to be an [N, H, W] stack of per-detection masks.
    """

    @property
    def data(self) -> "NDArray[np.float32] | object": ...
|
||||||
|
|
||||||
|
|
||||||
|
class _DetectionResults(Protocol):
    """Protocol for detection results from Ultralytics-style objects.

    Bundles the ``boxes`` and ``masks`` containers consumed by
    ``select_person``.
    """

    @property
    def boxes(self) -> _Boxes: ...
    @property
    def masks(self) -> _Masks: ...
|
||||||
|
|
||||||
|
|
||||||
|
@final
class SilhouetteWindow:
    """Fixed-capacity sliding window of per-frame silhouettes.

    Holds up to ``window_size`` of the most recent silhouettes for a single
    tracked person. The buffer is cleared automatically whenever the track ID
    changes or the incoming frame index jumps by more than ``gap_threshold``,
    and ``stride`` throttles how often classification re-runs.

    Attributes:
        window_size: Capacity of the buffer in frames.
        stride: Minimum frame distance between classifications.
        gap_threshold: Largest tolerated frame-index jump before a reset.
    """

    window_size: int
    stride: int
    gap_threshold: int
    _buffer: deque[Float[ndarray, "64 44"]]
    _frame_indices: deque[int]
    _track_id: int | None
    _last_classified_frame: int
    _frame_count: int

    def __init__(
        self,
        window_size: int = 30,
        stride: int = 1,
        gap_threshold: int = 15,
    ) -> None:
        """Create an empty window.

        Args:
            window_size: Buffer capacity in frames (default 30).
            stride: Frames between classifications (default 1).
            gap_threshold: Largest tolerated frame gap (default 15).
        """
        self.window_size = window_size
        self.stride = stride
        self.gap_threshold = gap_threshold

        # deque(maxlen=...) silently evicts the oldest entry once full.
        self._buffer = deque(maxlen=window_size)
        self._frame_indices = deque(maxlen=window_size)
        self._track_id = None
        self._last_classified_frame = -1
        self._frame_count = 0

    def push(self, sil: np.ndarray, frame_idx: int, track_id: int) -> None:
        """Append one silhouette for ``track_id`` observed at ``frame_idx``.

        The buffer is reset before appending when the track ID differs from
        the buffered one, or when the gap since the previous frame exceeds
        ``gap_threshold``.

        Args:
            sil: float32 silhouette of shape (64, 44).
            frame_idx: Monotonic frame index used for gap detection.
            track_id: Identity of the tracked person.

        Raises:
            ValueError: If ``sil`` does not have shape (64, 44).
        """
        if self._track_id is not None and track_id != self._track_id:
            # A different person took over: discard the stale sequence.
            self.reset()

        if self._frame_indices and frame_idx - self._frame_indices[-1] > self.gap_threshold:
            # Too many frames missed; the sequence is no longer contiguous.
            self.reset()

        self._track_id = track_id

        silhouette = np.asarray(sil, dtype=np.float32)
        if silhouette.shape != (SIL_HEIGHT, SIL_WIDTH):
            raise ValueError(
                f"Expected silhouette shape ({SIL_HEIGHT}, {SIL_WIDTH}), got {silhouette.shape}"
            )

        self._buffer.append(silhouette)
        self._frame_indices.append(frame_idx)
        self._frame_count += 1

    def is_ready(self) -> bool:
        """Return True once the buffer holds ``window_size`` frames."""
        return len(self._buffer) >= self.window_size

    def should_classify(self) -> bool:
        """Return True when a (re-)classification is due.

        The first classification fires as soon as the window is full; later
        ones wait until at least ``stride`` frames have elapsed since the
        last ``mark_classified`` call.
        """
        if not self.is_ready():
            return False

        if self._last_classified_frame < 0:
            return True

        return self._frame_indices[-1] - self._last_classified_frame >= self.stride

    def get_tensor(self, device: str = "cpu") -> torch.Tensor:
        """Materialize the window as a model-ready tensor.

        Args:
            device: Target torch device (default 'cpu').

        Returns:
            float32 tensor of shape [1, 1, window_size, 64, 44].

        Raises:
            ValueError: If the window is not yet full.
        """
        if not self.is_ready():
            raise ValueError(
                f"Window not ready: {len(self._buffer)}/{self.window_size} frames"
            )

        # [window_size, 64, 44] -> [1, 1, window_size, 64, 44]
        stacked_frames = np.stack(list(self._buffer), axis=0).astype(np.float32)
        return torch.from_numpy(stacked_frames)[None, None, ...].to(device)

    def reset(self) -> None:
        """Drop all buffered frames and tracking state."""
        self._buffer.clear()
        self._frame_indices.clear()
        self._track_id = None
        self._last_classified_frame = -1
        self._frame_count = 0

    def mark_classified(self) -> None:
        """Record the newest buffered frame as the last classified one."""
        if self._frame_indices:
            self._last_classified_frame = self._frame_indices[-1]

    @property
    def current_track_id(self) -> int | None:
        """Track ID currently buffered, or None after a reset."""
        return self._track_id

    @property
    def frame_count(self) -> int:
        """Number of frames pushed since the last reset."""
        return self._frame_count

    @property
    def fill_level(self) -> float:
        """Buffer occupancy as a ratio in [0.0, 1.0]."""
        return len(self._buffer) / self.window_size
|
||||||
|
|
||||||
|
|
||||||
|
def select_person(
    results: _DetectionResults,
) -> tuple[ndarray, tuple[int, int, int, int], int] | None:
    """Pick the person with the largest bounding box from detection results.

    Args:
        results: Detection results object with boxes and masks attributes.
            Expected to have:
            - boxes.xyxy: array of bounding boxes [N, 4]
            - masks.data: array of masks [N, H, W]
            - boxes.id: optional track IDs [N]

    Returns:
        Tuple of (mask, bbox, track_id) for the largest detection, or None
        when boxes, masks, or track IDs are missing or inconsistent.
    """
    # Without track IDs there is nothing stable to follow across frames.
    boxes_obj = getattr(results, "boxes", None)
    if boxes_obj is None:
        return None

    raw_ids = getattr(boxes_obj, "id", None)
    if raw_ids is None:
        return None
    track_ids = np.asarray(raw_ids)
    if track_ids.size == 0:
        return None

    # Bounding boxes, normalized to 2-D [N, 4].
    raw_xyxy = getattr(boxes_obj, "xyxy", None)
    if raw_xyxy is None:
        return None
    bboxes = np.asarray(raw_xyxy)
    if bboxes.ndim == 1:
        bboxes = bboxes.reshape(1, -1)
    if bboxes.shape[0] == 0:
        return None

    # Segmentation masks, normalized to 3-D [N, H, W]; must pair 1:1 with boxes.
    masks_obj = getattr(results, "masks", None)
    if masks_obj is None:
        return None
    raw_masks = getattr(masks_obj, "data", None)
    if raw_masks is None:
        return None
    masks = np.asarray(raw_masks)
    if masks.ndim == 2:
        masks = masks[np.newaxis, ...]
    if masks.shape[0] != bboxes.shape[0]:
        return None

    # Scan for the detection whose box covers the largest area.
    best_idx = -1
    best_area = -1.0
    for idx in range(int(bboxes.shape[0])):
        x1, y1, x2, y2 = (float(v) for v in bboxes[idx][:4])
        area = (x2 - x1) * (y2 - y1)
        if area > best_area:
            best_area, best_idx = area, idx

    if best_idx < 0:
        return None

    # Integer pixel bbox for the winner.
    x1i, y1i, x2i, y2i = (int(float(v)) for v in bboxes[best_idx][:4])
    bbox = (x1i, y1i, x2i, y2i)
    # Fall back to the row index when the ID array is shorter than the boxes.
    track_id = int(track_ids[best_idx]) if best_idx < len(track_ids) else best_idx

    return masks[best_idx], bbox, track_id
Reference in New Issue
Block a user