feat(demo): implement ScoNet real-time pipeline runtime
Add the full demo runtime stack for single-person scoliosis inference, including input adapters, silhouette preprocessing, temporal windowing, ScoNet wrapper, result publishing, and click-based CLI orchestration. This commit captures the executable pipeline behavior independently from tests and planning artifacts for clearer review and rollback.
This commit is contained in:
@@ -0,0 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .pipeline import main
|
||||
|
||||
|
||||
# Allow running the demo package directly (`python -m <package>`);
# delegates to the click-based CLI defined in .pipeline.
if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
Input adapters for OpenGait demo.
|
||||
|
||||
Provides generator-based interfaces for video sources:
|
||||
- OpenCV (video files, cameras)
|
||||
- cv-mmap (shared memory streams)
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator, Generator, Iterable
|
||||
from typing import TYPE_CHECKING, Protocol, cast
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type alias for frame stream: (frame_array, metadata_dict)
|
||||
FrameStream = Iterable[tuple[np.ndarray, dict[str, object]]]
|
||||
|
||||
if TYPE_CHECKING:
    # These protocols exist only for the type checker so we never import
    # the optional cvmmap package at module import time. They are NOT
    # defined at runtime — runtime code must reference them only inside
    # quoted annotations or casts.

    # Protocol for cv-mmap metadata to avoid direct import
    class _FrameMetadata(Protocol):
        # 0-based frame index assigned by cv-mmap
        frame_count: int
        # capture timestamp in nanoseconds, as reported by cv-mmap
        timestamp_ns: int

    # Protocol for cv-mmap client
    class _CvMmapClient(Protocol):
        def __aiter__(self) -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]: ...
|
||||
|
||||
|
||||
def opencv_source(
    path: str | int, max_frames: int | None = None
) -> Generator[tuple[np.ndarray, dict[str, object]], None, None]:
    """
    Yield frames from an OpenCV video source as (frame, metadata) pairs.

    Parameters
    ----------
    path : str | int
        Video file path or camera index (e.g., 0 for default camera)
    max_frames : int | None, optional
        Maximum number of frames to yield. None means unlimited.

    Yields
    ------
    tuple[np.ndarray, dict[str, object]]
        The decoded frame plus a metadata dict containing:
        - frame_count: 0-based frame index
        - timestamp_ns: monotonic timestamp in nanoseconds, taken at read time
        - source: the path/int provided
    """
    import time

    import cv2

    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise RuntimeError(f"Failed to open video source: {path}")

    index = 0
    try:
        while True:
            # Stop once the requested frame budget is exhausted.
            if max_frames is not None and index >= max_frames:
                break
            ok, frame = cap.read()
            if not ok:
                break  # end of stream

            # Timestamp each frame ourselves; not all backends expose one.
            info: dict[str, object] = {
                "frame_count": index,
                "timestamp_ns": time.monotonic_ns(),
                "source": path,
            }
            yield frame, info
            index += 1
    finally:
        # Always release the capture, even if the consumer abandons us.
        cap.release()
        logger.debug(f"OpenCV source closed: {path}")
|
||||
|
||||
|
||||
def cvmmap_source(
|
||||
name: str, max_frames: int | None = None
|
||||
) -> Generator[tuple[np.ndarray, dict[str, object]], None, None]:
|
||||
"""
|
||||
Generator that yields frames from a cv-mmap shared memory stream.
|
||||
|
||||
Bridges async cv-mmap client to synchronous generator using asyncio.run().
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
Base name of the cv-mmap source (e.g., "default")
|
||||
max_frames : int | None, optional
|
||||
Maximum number of frames to yield. None means unlimited.
|
||||
|
||||
Yields
|
||||
------
|
||||
tuple[np.ndarray, dict[str, object]]
|
||||
(frame_array, metadata_dict) where metadata includes:
|
||||
- frame_count: frame index from cv-mmap
|
||||
- timestamp_ns: timestamp in nanoseconds from cv-mmap
|
||||
- source: the cv-mmap name
|
||||
|
||||
Raises
|
||||
------
|
||||
ImportError
|
||||
If cvmmap package is not available
|
||||
RuntimeError
|
||||
If cv-mmap stream disconnects or errors
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# Import cvmmap only when function is called
|
||||
# Use try/except for runtime import check
|
||||
try:
|
||||
|
||||
from cvmmap import CvMmapClient as _CvMmapClientReal # pyright: ignore[reportMissingTypeStubs]
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"cvmmap package is required for cv-mmap sources. "
|
||||
+ "Install from: https://github.com/crosstyan/cv-mmap"
|
||||
) from e
|
||||
|
||||
# Cast to protocol type for type checking
|
||||
client: _CvMmapClient = cast("_CvMmapClient", _CvMmapClientReal(name))
|
||||
frame_count = 0
|
||||
|
||||
async def _async_generator() -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]:
|
||||
"""Async generator wrapper."""
|
||||
async for frame, meta in client:
|
||||
yield frame, meta
|
||||
|
||||
# Bridge async to sync using asyncio.run()
|
||||
# We process frames one at a time to keep it simple and robust
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
agen = _async_generator().__aiter__()
|
||||
|
||||
while max_frames is None or frame_count < max_frames:
|
||||
try:
|
||||
frame, meta = loop.run_until_complete(agen.__anext__())
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
metadata: dict[str, object] = {
|
||||
"frame_count": meta.frame_count,
|
||||
"timestamp_ns": meta.timestamp_ns,
|
||||
"source": f"cvmmap://{name}",
|
||||
}
|
||||
|
||||
yield frame, metadata
|
||||
frame_count += 1
|
||||
|
||||
finally:
|
||||
loop.close()
|
||||
logger.debug(f"cv-mmap source closed: {name}")
|
||||
|
||||
|
||||
def create_source(source: str, max_frames: int | None = None) -> FrameStream:
    """
    Build a frame stream from a textual source specification.

    Parameters
    ----------
    source : str
        One of:
        - digits only ('0', '1', ...): OpenCV camera index
        - 'cvmmap://name': cv-mmap shared memory stream
        - anything else: video file path opened via OpenCV
    max_frames : int | None, optional
        Maximum number of frames to yield. None means unlimited.

    Returns
    -------
    FrameStream
        Generator yielding (frame, metadata) tuples.

    Examples
    --------
    >>> for frame, meta in create_source('0'):  # Camera 0
    ...     process(frame)
    >>> for frame, meta in create_source('cvmmap://default'):  # cv-mmap
    ...     process(frame)
    >>> for frame, meta in create_source('/path/to/video.mp4'):
    ...     process(frame)
    """
    prefix = "cvmmap://"
    if source.startswith(prefix):
        return cvmmap_source(source[len(prefix):], max_frames)
    if source.isdigit():
        # All-digit strings are camera indices, not file names.
        return opencv_source(int(source), max_frames)
    # Anything else is treated as a video file path.
    return opencv_source(source, max_frames)
|
||||
@@ -0,0 +1,368 @@
|
||||
"""
|
||||
Output publishers for OpenGait demo results.
|
||||
|
||||
Provides pluggable result publishing:
|
||||
- ConsolePublisher: JSONL to stdout
|
||||
- NatsPublisher: NATS message broker integration
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Protocol, TextIO, cast, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
class ResultPublisher(Protocol):
    """
    Protocol for result publishers.

    runtime_checkable so concrete publishers can be isinstance-checked
    against this structural interface.
    """

    def publish(self, result: dict[str, object]) -> None:
        """
        Publish a result dictionary.

        Parameters
        ----------
        result : dict[str, object]
            Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
        """
        ...
|
||||
|
||||
|
||||
class ConsolePublisher:
    """Write results as JSON Lines to a text stream (stdout by default)."""

    # Destination stream; anything file-like with write()/flush().
    _output: TextIO

    def __init__(self, output: TextIO = sys.stdout) -> None:
        """
        Initialize the console publisher.

        Parameters
        ----------
        output : TextIO
            File-like object to write to (default: sys.stdout)
        """
        self._output = output

    def publish(self, result: dict[str, object]) -> None:
        """
        Serialize *result* to one JSON line and flush it.

        Failures are logged rather than raised so a broken stream cannot
        kill the pipeline.

        Parameters
        ----------
        result : dict[str, object]
            Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
        """
        try:
            encoded = json.dumps(result, ensure_ascii=False, default=str)
            _ = self._output.write(encoded + "\n")
            self._output.flush()
        except Exception as e:
            logger.warning(f"Failed to publish to console: {e}")

    def close(self) -> None:
        """Close the publisher (no-op for console)."""

    def __enter__(self) -> ConsolePublisher:
        """Context manager entry; returns self."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Context manager exit; delegates to close()."""
        self.close()
|
||||
|
||||
|
||||
class _NatsClient(Protocol):
    """Structural type for a connected nats-py client, avoiding a hard import."""

    async def publish(self, subject: str, payload: bytes) -> object: ...

    async def close(self) -> object: ...

    async def flush(self) -> object: ...
|
||||
|
||||
|
||||
class NatsPublisher:
    """
    Publisher that sends results to NATS message broker.

    Sync-friendly wrapper around the async nats-py client: a daemon thread
    runs a dedicated event loop, and synchronous publish() calls are bridged
    onto it with asyncio.run_coroutine_threadsafe, making the publisher
    safe to use from both sync and async contexts.
    """

    _nats_url: str
    _subject: str
    _nc: _NatsClient | None
    _connected: bool
    _loop: asyncio.AbstractEventLoop | None
    _thread: threading.Thread | None
    # Reentrant on purpose — see __init__.
    _lock: threading.RLock

    def __init__(self, nats_url: str, subject: str = "scoliosis.result") -> None:
        """
        Initialize NATS publisher.

        Parameters
        ----------
        nats_url : str
            NATS server URL (e.g., "nats://localhost:4222")
        subject : str
            NATS subject to publish to (default: "scoliosis.result")
        """
        self._nats_url = nats_url
        self._subject = subject
        self._nc = None
        self._connected = False
        self._loop = None
        self._thread = None
        # BUGFIX: RLock instead of Lock. _ensure_connected() holds the lock
        # while calling _start_background_loop(), which acquires it again
        # (likewise close() -> _stop_background_loop()); with a
        # non-reentrant Lock the very first publish() deadlocked.
        self._lock = threading.RLock()

    def _start_background_loop(self) -> bool:
        """
        Start the daemon thread running a dedicated event loop (idempotent).

        Returns
        -------
        bool
            True if loop is running, False otherwise
        """
        with self._lock:
            if self._loop is not None and self._loop.is_running():
                return True

            try:
                loop = asyncio.new_event_loop()
                self._loop = loop

                def run_loop() -> None:
                    asyncio.set_event_loop(loop)
                    loop.run_forever()

                # Daemon thread: never blocks interpreter shutdown.
                self._thread = threading.Thread(target=run_loop, daemon=True)
                self._thread.start()
                return True
            except Exception as e:
                logger.warning(f"Failed to start background event loop: {e}")
                return False

    def _stop_background_loop(self) -> None:
        """Stop the background event loop and join its thread."""
        with self._lock:
            if self._loop is not None and self._loop.is_running():
                _ = self._loop.call_soon_threadsafe(self._loop.stop)
            if self._thread is not None and self._thread.is_alive():
                # Bounded join so shutdown cannot hang forever.
                self._thread.join(timeout=2.0)
            self._loop = None
            self._thread = None

    def _ensure_connected(self) -> bool:
        """
        Ensure connection to NATS server.

        Returns
        -------
        bool
            True if connected; False when nats-py is missing, the loop
            could not start, or the connection attempt failed.
        """
        with self._lock:
            if self._connected and self._nc is not None:
                return True

            # Re-acquires self._lock — safe because it is an RLock.
            if not self._start_background_loop():
                return False

            try:
                import nats

                async def _connect() -> _NatsClient:
                    nc = await nats.connect(self._nats_url)  # pyright: ignore[reportUnknownMemberType]
                    return cast(_NatsClient, nc)

                # Run connection in background loop
                future = asyncio.run_coroutine_threadsafe(
                    _connect(),
                    self._loop,  # pyright: ignore[reportArgumentType]
                )
                self._nc = future.result(timeout=10.0)
                self._connected = True
                logger.info(f"Connected to NATS at {self._nats_url}")
                return True
            except ImportError:
                logger.warning(
                    "nats-py package not installed. Install with: pip install nats-py"
                )
                return False
            except Exception as e:
                logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}")
                return False

    def publish(self, result: dict[str, object]) -> None:
        """
        Publish result to NATS subject.

        Degrades gracefully: when NATS is unavailable the result is
        dropped with a debug log instead of raising.

        Parameters
        ----------
        result : dict[str, object]
            Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
        """
        if not self._ensure_connected():
            # Graceful degradation: log warning but don't crash
            logger.debug(
                f"NATS unavailable, dropping result: {result.get('track_id', 'unknown')}"
            )
            return

        try:

            async def _publish() -> None:
                if self._nc is not None:
                    payload = json.dumps(
                        result, ensure_ascii=False, default=str
                    ).encode("utf-8")
                    _ = await self._nc.publish(self._subject, payload)
                    _ = await self._nc.flush()

            # Run publish in background loop
            future = asyncio.run_coroutine_threadsafe(
                _publish(),
                self._loop,  # pyright: ignore[reportArgumentType]
            )
            future.result(timeout=5.0)  # Wait for publish to complete
        except Exception as e:
            logger.warning(f"Failed to publish to NATS: {e}")
            self._connected = False  # Mark for reconnection on next publish

    def close(self) -> None:
        """Close the NATS connection and stop the background loop (idempotent)."""
        with self._lock:
            if self._nc is not None and self._connected and self._loop is not None:
                try:

                    async def _close() -> None:
                        if self._nc is not None:
                            _ = await self._nc.close()

                    future = asyncio.run_coroutine_threadsafe(
                        _close(),
                        self._loop,
                    )
                    future.result(timeout=5.0)
                except Exception as e:
                    logger.debug(f"Error closing NATS connection: {e}")
                finally:
                    self._nc = None
                    self._connected = False

        # Loop teardown outside the connection critical section; the
        # helper takes the (reentrant) lock itself.
        self._stop_background_loop()

    def __enter__(self) -> NatsPublisher:
        """Context manager entry."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Context manager exit; closes the connection."""
        self.close()
|
||||
|
||||
|
||||
def create_publisher(
    nats_url: str | None,
    subject: str = "scoliosis.result",
) -> ResultPublisher:
    """
    Build the publisher matching the configured transport.

    Parameters
    ----------
    nats_url : str | None
        NATS server URL. If None or empty, returns ConsolePublisher.
    subject : str
        NATS subject to publish to (default: "scoliosis.result")

    Returns
    -------
    ResultPublisher
        NatsPublisher when a NATS URL is given, otherwise ConsolePublisher.

    Examples
    --------
    >>> pub = create_publisher(None)  # console JSONL output
    >>> pub.publish({"frame": 1, "track_id": 42, "label": "normal", "confidence": 0.95, "window": 30, "timestamp_ns": 1234567890})
    >>> with create_publisher("nats://localhost:4222") as pub:  # NATS output
    ...     pub.publish(result)
    """
    # Truthiness check: both None and "" fall back to the console.
    return NatsPublisher(nats_url, subject) if nats_url else ConsolePublisher()
|
||||
|
||||
|
||||
def create_result(
|
||||
frame: int,
|
||||
track_id: int,
|
||||
label: str,
|
||||
confidence: float,
|
||||
window: int | tuple[int, int],
|
||||
timestamp_ns: int | None = None,
|
||||
) -> dict[str, object]:
|
||||
"""
|
||||
Create a standardized result dictionary.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
frame : int
|
||||
Frame number
|
||||
track_id : int
|
||||
Track/person identifier
|
||||
label : str
|
||||
Classification label (e.g., "normal", "scoliosis")
|
||||
confidence : float
|
||||
Confidence score (0.0 to 1.0)
|
||||
window : int | tuple[int, int]
|
||||
Frame window as int (end frame) or tuple [start, end] that produced this result
|
||||
Frame window [start, end] that produced this result
|
||||
timestamp_ns : int | None
|
||||
Timestamp in nanoseconds. If None, uses current time.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, object]
|
||||
Standardized result dictionary
|
||||
"""
|
||||
return {
|
||||
"frame": frame,
|
||||
"track_id": track_id,
|
||||
"label": label,
|
||||
"confidence": confidence,
|
||||
"window": window if isinstance(window, int) else window[1],
|
||||
"timestamp_ns": timestamp_ns
|
||||
if timestamp_ns is not None
|
||||
else time.monotonic_ns(),
|
||||
}
|
||||
@@ -0,0 +1,325 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from contextlib import suppress
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import Protocol, cast
|
||||
|
||||
from beartype import beartype
|
||||
import click
|
||||
import jaxtyping
|
||||
from jaxtyping import Float, UInt8
|
||||
import numpy as np
|
||||
from numpy import ndarray
|
||||
from numpy.typing import NDArray
|
||||
from ultralytics.models.yolo.model import YOLO
|
||||
|
||||
from .input import FrameStream, create_source
|
||||
from .output import ResultPublisher, create_publisher, create_result
|
||||
from .preprocess import frame_to_person_mask, mask_to_silhouette
|
||||
from .sconet_demo import ScoNetDemo
|
||||
from .window import SilhouetteWindow, select_person
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Typed facade over jaxtyping.jaxtyped: cast the untyped upstream factory
# to a precise decorator-factory signature so the type checker accepts
# @jaxtyped(typechecker=beartype) without complaints.
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
|
||||
|
||||
|
||||
class _BoxesLike(Protocol):
    """Structural view of a detector's boxes: xyxy coords plus optional track ids."""

    @property
    def xyxy(self) -> NDArray[np.float32] | object: ...

    @property
    def id(self) -> NDArray[np.int64] | object | None: ...
|
||||
|
||||
|
||||
class _MasksLike(Protocol):
    """Structural view of segmentation masks exposing the raw mask data."""

    @property
    def data(self) -> NDArray[np.float32] | object: ...
|
||||
|
||||
|
||||
class _DetectionResultsLike(Protocol):
    """Structural view of one detection result: boxes plus masks."""

    @property
    def boxes(self) -> _BoxesLike: ...

    @property
    def masks(self) -> _MasksLike: ...
|
||||
|
||||
|
||||
class _TrackCallable(Protocol):
    """Callable signature of the detector's track() method (ultralytics is untyped)."""

    def __call__(
        self,
        source: object,
        *,
        persist: bool = True,
        verbose: bool = False,
        device: str | None = None,
        classes: list[int] | None = None,
    ) -> object: ...
|
||||
|
||||
|
||||
class ScoliosisPipeline:
    """
    End-to-end single-person scoliosis inference pipeline.

    Wires together: frame source -> YOLO segmentation/tracking ->
    silhouette extraction -> temporal window -> ScoNet classification ->
    result publishing.
    """

    # YOLO segmentation model; typed as object because ultralytics is untyped.
    _detector: object
    # Generator of (frame, metadata) tuples from create_source().
    _source: FrameStream
    # Temporal silhouette buffer; decides when a classification should fire.
    _window: SilhouetteWindow
    # Destination for classification results (console or NATS).
    _publisher: ResultPublisher
    # ScoNet model wrapper.
    _classifier: ScoNetDemo
    # Device string passed to both tracking and classification (e.g. "cuda:0").
    _device: str
    # Guards close() against double invocation.
    _closed: bool

    def __init__(
        self,
        *,
        source: str,
        checkpoint: str,
        config: str,
        device: str,
        yolo_model: str,
        window: int,
        stride: int,
        nats_url: str | None,
        nats_subject: str,
        max_frames: int | None,
    ) -> None:
        """
        Construct all pipeline stages.

        Parameters
        ----------
        source : str
            Source spec accepted by create_source() (camera index,
            cvmmap:// URL, or video file path).
        checkpoint : str
            Path to the ScoNet checkpoint file.
        config : str
            Path to the ScoNet YAML config.
        device : str
            Device used for YOLO tracking and ScoNet inference.
        yolo_model : str
            YOLO segmentation weights file.
        window : int
            Number of silhouettes per classification window.
        stride : int
            Stride between successive classifications.
        nats_url : str | None
            NATS server URL; None selects console output.
        nats_subject : str
            NATS subject for published results.
        max_frames : int | None
            Optional cap on frames read from the source.
        """
        self._detector = YOLO(yolo_model)
        self._source = create_source(source, max_frames=max_frames)
        self._window = SilhouetteWindow(window_size=window, stride=stride)
        self._publisher = create_publisher(nats_url=nats_url, subject=nats_subject)
        self._classifier = ScoNetDemo(
            cfg_path=config,
            checkpoint_path=checkpoint,
            device=device,
        )
        self._device = device
        self._closed = False

    @staticmethod
    def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
        """Read an int from metadata, returning *fallback* when missing or not an int."""
        value = meta.get(key)
        if isinstance(value, int):
            return value
        return fallback

    @staticmethod
    def _extract_timestamp(meta: dict[str, object]) -> int:
        """Read timestamp_ns from metadata, falling back to current monotonic time."""
        value = meta.get("timestamp_ns")
        if isinstance(value, int):
            return value
        return time.monotonic_ns()

    @staticmethod
    def _to_mask_u8(mask: ndarray) -> UInt8[ndarray, "h w"]:
        """Binarize a mask at 0.5 into a uint8 mask of {0, 255}."""
        binary = np.where(np.asarray(mask) > 0.5, np.uint8(255), np.uint8(0)).astype(
            np.uint8
        )
        return cast(UInt8[ndarray, "h w"], binary)

    def _first_result(self, detections: object) -> _DetectionResultsLike | None:
        """Unwrap track() output (list/tuple/single result) to the first result, if any."""
        if isinstance(detections, list):
            return cast(_DetectionResultsLike, detections[0]) if detections else None
        if isinstance(detections, tuple):
            return cast(_DetectionResultsLike, detections[0]) if detections else None
        # Anything else is assumed to already be a single result object.
        return cast(_DetectionResultsLike, detections)

    def _select_silhouette(
        self,
        result: _DetectionResultsLike,
    ) -> tuple[Float[ndarray, "64 44"], int] | None:
        """
        Extract a normalized 64x44 silhouette and its track id from a detection.

        Tries select_person() first; when that yields nothing usable, falls
        back to frame_to_person_mask() with track id 0. Returns None when
        no usable silhouette exists in this frame.
        """
        selected = select_person(result)
        if selected is not None:
            mask_raw, bbox, track_id = selected
            silhouette = cast(
                Float[ndarray, "64 44"] | None,
                mask_to_silhouette(self._to_mask_u8(mask_raw), bbox),
            )
            if silhouette is not None:
                return silhouette, int(track_id)

        # Fallback path: untracked person mask, reported as track id 0.
        fallback = cast(
            tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None,
            frame_to_person_mask(result),
        )
        if fallback is None:
            return None

        mask_u8, bbox = fallback
        silhouette = cast(
            Float[ndarray, "64 44"] | None,
            mask_to_silhouette(mask_u8, bbox),
        )
        if silhouette is None:
            return None
        return silhouette, 0

    @jaxtyped(typechecker=beartype)
    def process_frame(
        self,
        frame: UInt8[ndarray, "h w c"],
        metadata: dict[str, object],
    ) -> dict[str, object] | None:
        """
        Run one frame through detection, windowing and (possibly) classification.

        Returns the published result dict when a classification fired on
        this frame, otherwise None.
        """
        frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
        timestamp_ns = self._extract_timestamp(metadata)

        # track() is resolved dynamically because ultralytics is untyped.
        track_fn_obj = getattr(self._detector, "track", None)
        if not callable(track_fn_obj):
            raise RuntimeError("YOLO detector does not expose a callable track()")

        track_fn = cast(_TrackCallable, track_fn_obj)
        detections = track_fn(
            frame,
            persist=True,  # keep track ids stable across frames
            verbose=False,
            device=self._device,
            classes=[0],  # class index 0 — presumably 'person'; confirm for the model used
        )
        first = self._first_result(detections)
        if first is None:
            return None

        selected = self._select_silhouette(first)
        if selected is None:
            return None

        silhouette, track_id = selected
        self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)

        # Only classify when the window says enough new frames accumulated.
        if not self._window.should_classify():
            return None

        window_tensor = self._window.get_tensor(device=self._device)
        label, confidence = cast(
            tuple[str, float],
            self._classifier.predict(window_tensor),
        )
        self._window.mark_classified()

        # Window is reported as an inclusive (start, end) frame range.
        window_start = frame_idx - self._window.window_size + 1
        result = create_result(
            frame=frame_idx,
            track_id=track_id,
            label=label,
            confidence=float(confidence),
            window=(max(0, window_start), frame_idx),
            timestamp_ns=timestamp_ns,
        )
        self._publisher.publish(result)
        return result

    def run(self) -> int:
        """
        Drain the source, processing every frame; returns a process exit code.

        Per-frame errors are logged and skipped so one bad frame cannot
        kill the stream. Returns 0 on normal exhaustion, 130 on Ctrl-C.
        """
        frame_count = 0
        start_time = time.perf_counter()
        try:
            for item in self._source:
                frame, metadata = item
                frame_u8 = np.asarray(frame, dtype=np.uint8)
                frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
                frame_count += 1
                try:
                    _ = self.process_frame(frame_u8, metadata)
                except Exception as frame_error:
                    # Best-effort streaming: skip the bad frame, keep going.
                    logger.warning(
                        "Skipping frame %d due to processing error: %s",
                        frame_idx,
                        frame_error,
                    )
                # Periodic throughput report.
                if frame_count % 100 == 0:
                    elapsed = time.perf_counter() - start_time
                    fps = frame_count / elapsed if elapsed > 0 else 0.0
                    logger.info("Processed %d frames (%.2f FPS)", frame_count, fps)
            return 0
        except KeyboardInterrupt:
            logger.info("Interrupted by user, shutting down cleanly.")
            return 130
        finally:
            self.close()

    def close(self) -> None:
        """Release the publisher exactly once; safe to call repeatedly."""
        if self._closed:
            return
        close_fn = getattr(self._publisher, "close", None)
        if callable(close_fn):
            # close() is best-effort; teardown errors must not propagate.
            with suppress(Exception):
                _ = close_fn()
        self._closed = True
|
||||
|
||||
|
||||
def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
    """
    Validate CLI inputs before building the pipeline.

    Parameters
    ----------
    source : str
        Video source spec. Camera indices (all digits) and cvmmap:// URLs
        are accepted as-is; anything else must be an existing file.
    checkpoint : str
        Path to the ScoNet checkpoint file; must exist.
    config : str
        Path to the ScoNet YAML config file; must exist.

    Raises
    ------
    ValueError
        If any required file does not exist.
    """
    # Guard clause instead of the former empty `pass`/`else` branch:
    # only file-backed sources need an existence check.
    if not (source.startswith("cvmmap://") or source.isdigit()):
        if not Path(source).is_file():
            raise ValueError(f"Video source not found: {source}")

    if not Path(checkpoint).is_file():
        raise ValueError(f"Checkpoint not found: {checkpoint}")

    if not Path(config).is_file():
        raise ValueError(f"Config not found: {config}")
|
||||
|
||||
|
||||
# CLI entry point; the options mirror ScoliosisPipeline's constructor.
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option("--source", type=str, required=True)
@click.option("--checkpoint", type=str, required=True)
@click.option(
    "--config",
    type=str,
    default="configs/sconet/sconet_scoliosis1k.yaml",
    show_default=True,
)
@click.option("--device", type=str, default="cuda:0", show_default=True)
@click.option("--yolo-model", type=str, default="yolo11n-seg.pt", show_default=True)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--nats-url", type=str, default=None)
@click.option(
    "--nats-subject",
    type=str,
    default="scoliosis.result",
    show_default=True,
)
@click.option("--max-frames", type=click.IntRange(min=1), default=None)
def main(
    source: str,
    checkpoint: str,
    config: str,
    device: str,
    yolo_model: str,
    window: int,
    stride: int,
    nats_url: str | None,
    nats_subject: str,
    max_frames: int | None,
) -> None:
    """Run the real-time scoliosis inference pipeline."""
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    try:
        validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
        pipeline = ScoliosisPipeline(
            source=source,
            checkpoint=checkpoint,
            config=config,
            device=device,
            yolo_model=yolo_model,
            window=window,
            stride=stride,
            nats_url=nats_url,
            nats_subject=nats_subject,
            max_frames=max_frames,
        )
        # Process exit code comes from run(): 0 = clean, 130 = Ctrl-C.
        raise SystemExit(pipeline.run())
    except ValueError as err:
        # Invalid user input (missing files): exit code 2.
        click.echo(f"Error: {err}", err=True)
        raise SystemExit(2) from err
    except RuntimeError as err:
        # Pipeline/runtime failure: exit code 1.
        click.echo(f"Runtime error: {err}", err=True)
        raise SystemExit(1) from err
||||
@@ -0,0 +1,270 @@
|
||||
from collections.abc import Callable
|
||||
import math
|
||||
from typing import cast
|
||||
|
||||
import cv2
|
||||
from beartype import beartype
|
||||
import jaxtyping
|
||||
from jaxtyping import Float, UInt8
|
||||
import numpy as np
|
||||
from numpy import ndarray
|
||||
from numpy.typing import NDArray
|
||||
|
||||
# Silhouette geometry for ScoNet preprocessing. Presumably the 64-wide
# silhouette is trimmed by SIDE_CUT on each side to reach SIL_WIDTH
# (64 - 2*10 = 44) — confirm against mask_to_silhouette's implementation.
SIL_HEIGHT = 64
SIL_WIDTH = 44
SIL_FULL_WIDTH = 64
SIDE_CUT = 10
# Masks with fewer than this many nonzero pixels are treated as noise.
MIN_MASK_AREA = 500

# Typed facade over jaxtyping.jaxtyped so the type checker sees a precise
# decorator-factory signature instead of the untyped upstream one.
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)

# Convenience aliases for the numpy dtypes used throughout this module.
UInt8Array = NDArray[np.uint8]
Float32Array = NDArray[np.float32]
|
||||
|
||||
|
||||
def _read_attr(container: object, key: str) -> object | None:
|
||||
if isinstance(container, dict):
|
||||
dict_obj = cast(dict[object, object], container)
|
||||
return dict_obj.get(key)
|
||||
try:
|
||||
return cast(object, object.__getattribute__(container, key))
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
def _to_numpy_array(value: object) -> NDArray[np.generic]:
    """
    Best-effort conversion of *value* to a numpy array.

    ndarrays pass through unchanged; tensor-like objects are unwrapped via
    the optional detach() -> cpu() -> numpy() chain; anything else goes
    through np.asarray.
    """
    current: object = value
    if isinstance(current, np.ndarray):
        return current

    # Tensor-like objects: apply each unwrapping step only if present.
    for step in ("detach", "cpu"):
        step_obj = _read_attr(current, step)
        if callable(step_obj):
            current = cast(Callable[[], object], step_obj)()

    numpy_obj = _read_attr(current, "numpy")
    if callable(numpy_obj):
        converted = cast(Callable[[], object], numpy_obj)()
        if isinstance(converted, np.ndarray):
            return converted

    # Last resort: let numpy coerce whatever is left.
    return cast(NDArray[np.generic], np.asarray(current))
|
||||
|
||||
|
||||
def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> tuple[int, int, int, int] | None:
    """
    Tight (x1, y1, x2, y2) bounding box of the nonzero pixels of *mask*.

    x2/y2 are exclusive. Returns None when the mask is empty or the box
    would have no area.
    """
    mask_u8 = np.asarray(mask, dtype=np.uint8)
    coords = np.argwhere(mask_u8 > 0)
    if int(coords.size) == 0:
        return None

    rows = coords[:, 0].astype(np.int64)
    cols = coords[:, 1].astype(np.int64)
    left = int(np.min(cols))
    right = int(np.max(cols)) + 1
    top = int(np.min(rows))
    bottom = int(np.max(rows)) + 1
    if right <= left or bottom <= top:
        return None
    return (left, top, right, bottom)
|
||||
|
||||
|
||||
def _sanitize_bbox(
|
||||
bbox: tuple[int, int, int, int], height: int, width: int
|
||||
) -> tuple[int, int, int, int] | None:
|
||||
x1, y1, x2, y2 = bbox
|
||||
x1c = max(0, min(int(x1), width - 1))
|
||||
y1c = max(0, min(int(y1), height - 1))
|
||||
x2c = max(0, min(int(x2), width))
|
||||
y2c = max(0, min(int(y2), height))
|
||||
if x2c <= x1c or y2c <= y1c:
|
||||
return None
|
||||
return (x1c, y1c, x2c, y2c)
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
def frame_to_person_mask(
    result: object, min_area: int = MIN_MASK_AREA
) -> tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None:
    """Pick the largest sufficiently-big person mask from a detector result.

    ``result`` is duck-typed (Ultralytics-style, read via ``_read_attr``):
    per-instance masks come from ``result.masks.data`` and, when present,
    boxes from ``result.boxes.xyxy``. Masks smaller than ``min_area`` pixels
    are ignored. Returns ``(mask, bbox)`` where ``mask`` is uint8 with values
    {0, 255} and ``bbox`` is (x1, y1, x2, y2) with exclusive x2/y2, or
    ``None`` when nothing usable is found.
    """
    masks_obj = _read_attr(result, "masks")
    if masks_obj is None:
        return None

    masks_data_obj = _read_attr(masks_obj, "data")
    if masks_data_obj is None:
        return None

    # Accept tensors or arrays; promote a single [H, W] mask to [1, H, W].
    masks_raw = _to_numpy_array(masks_data_obj)
    masks_float = np.asarray(masks_raw, dtype=np.float32)
    if masks_float.ndim == 2:
        masks_float = masks_float[np.newaxis, ...]
    if masks_float.ndim != 3:
        return None
    mask_count = int(cast(tuple[int, int, int], masks_float.shape)[0])
    if mask_count <= 0:
        return None

    # Parse detector boxes (if any) into per-instance float xyxy tuples;
    # these take precedence over boxes derived from the mask itself.
    box_values: list[tuple[float, float, float, float]] | None = None
    boxes_obj = _read_attr(result, "boxes")
    if boxes_obj is not None:
        xyxy_obj = _read_attr(boxes_obj, "xyxy")
        if xyxy_obj is not None:
            xyxy_raw = np.asarray(_to_numpy_array(xyxy_obj), dtype=np.float32)
            if xyxy_raw.ndim == 1 and int(xyxy_raw.size) >= 4:
                # A single flat box.
                xyxy_2d = np.asarray(xyxy_raw[:4].reshape(1, 4), dtype=np.float64)
                x1f = cast(np.float64, xyxy_2d[0, 0])
                y1f = cast(np.float64, xyxy_2d[0, 1])
                x2f = cast(np.float64, xyxy_2d[0, 2])
                y2f = cast(np.float64, xyxy_2d[0, 3])
                box_values = [
                    (
                        float(x1f),
                        float(y1f),
                        float(x2f),
                        float(y2f),
                    )
                ]
            elif xyxy_raw.ndim == 2:
                shape_2d = cast(tuple[int, int], xyxy_raw.shape)
                if int(shape_2d[1]) >= 4:
                    xyxy_2d = np.asarray(xyxy_raw[:, :4], dtype=np.float64)
                    box_values = []
                    for row_idx in range(int(cast(tuple[int, int], xyxy_2d.shape)[0])):
                        x1f = cast(np.float64, xyxy_2d[row_idx, 0])
                        y1f = cast(np.float64, xyxy_2d[row_idx, 1])
                        x2f = cast(np.float64, xyxy_2d[row_idx, 2])
                        y2f = cast(np.float64, xyxy_2d[row_idx, 3])
                        box_values.append(
                            (
                                float(x1f),
                                float(y1f),
                                float(x2f),
                                float(y2f),
                            )
                        )

    # Scan all masks and keep the largest one that passes min_area and has a
    # usable bounding box.
    best_area = -1
    best_mask: UInt8[ndarray, "h w"] | None = None
    best_bbox: tuple[int, int, int, int] | None = None

    for idx in range(mask_count):
        mask_float = np.asarray(masks_float[idx], dtype=np.float32)
        if mask_float.ndim != 2:
            continue
        # Binarize at 0.5 into uint8 {0, 255}.
        mask_binary = np.where(mask_float > 0.5, np.uint8(255), np.uint8(0)).astype(
            np.uint8
        )
        mask_u8 = cast(UInt8[ndarray, "h w"], mask_binary)

        area = int(np.count_nonzero(mask_u8))
        if area < min_area:
            continue

        bbox: tuple[int, int, int, int] | None = None
        shape_2d = cast(tuple[int, int], mask_binary.shape)
        h = int(shape_2d[0])
        w = int(shape_2d[1])
        if box_values is not None:
            box_count = len(box_values)
            # NOTE(review): when boxes were provided, masks without a matching
            # box row are skipped entirely rather than falling back to the
            # mask-derived box below.
            if idx >= box_count:
                continue
            row0, row1, row2, row3 = box_values[idx]
            # Floor the origin, ceil the extent so the int box fully covers
            # the float detection, then clamp into the image.
            bbox_candidate = (
                int(math.floor(row0)),
                int(math.floor(row1)),
                int(math.ceil(row2)),
                int(math.ceil(row3)),
            )
            bbox = _sanitize_bbox(bbox_candidate, h, w)

        # Fall back to the tight box around the mask pixels.
        if bbox is None:
            bbox = _bbox_from_mask(mask_u8)

        if bbox is None:
            continue

        if area > best_area:
            best_area = area
            best_mask = mask_u8
            best_bbox = bbox

    if best_mask is None or best_bbox is None:
        return None

    return best_mask, best_bbox
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
def mask_to_silhouette(
    mask: UInt8[ndarray, "h w"],
    bbox: tuple[int, int, int, int],
) -> Float[ndarray, "64 44"] | None:
    """Convert a person mask plus box into a normalized 64x44 silhouette.

    Pipeline: binarize, crop to the sanitized box, tighten vertically to the
    occupied rows, resize to height SIL_HEIGHT keeping aspect ratio, center
    onto a SIL_FULL_WIDTH canvas, trim SIDE_CUT columns per side, and scale
    into [0, 1]. Returns None whenever an intermediate stage degenerates.
    """
    binary = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
    if int(np.count_nonzero(binary)) < MIN_MASK_AREA:
        return None

    full_shape = cast(tuple[int, int], binary.shape)
    clipped = _sanitize_bbox(bbox, int(full_shape[0]), int(full_shape[1]))
    if clipped is None:
        return None

    left, top, right, bottom = clipped
    crop = binary[top:bottom, left:right]
    if crop.size == 0:
        return None

    # Tighten vertically to the rows that actually contain silhouette pixels.
    occupancy = np.sum(np.asarray(crop, dtype=np.uint8), axis=1, dtype=np.int64)
    occupied_rows = np.nonzero(occupancy > 0)[0].astype(np.int64)
    if int(occupied_rows.size) == 0:
        return None
    first_row = int(cast(np.int64, occupied_rows[0]))
    last_row = int(cast(np.int64, occupied_rows[-1])) + 1
    tight = crop[first_row:last_row, :]
    if tight.size == 0:
        return None

    tight_shape = cast(tuple[int, int], tight.shape)
    rows = int(tight_shape[0])
    cols = int(tight_shape[1])
    if rows <= 0 or cols <= 0:
        return None

    # Scale to the canonical height, preserving aspect ratio (width truncates).
    target_w = max(1, int(SIL_HEIGHT * (cols / rows)))
    scaled = np.asarray(
        cv2.resize(tight, (target_w, SIL_HEIGHT), interpolation=cv2.INTER_CUBIC),
        dtype=np.uint8,
    )

    # Center onto a SIL_FULL_WIDTH canvas: pad when narrow, crop when wide.
    if target_w < SIL_FULL_WIDTH:
        lead = (SIL_FULL_WIDTH - target_w) // 2
        trail = SIL_FULL_WIDTH - target_w - lead
        canvas = np.pad(
            scaled,
            ((0, 0), (lead, trail)),
            mode="constant",
            constant_values=0,
        )
    else:
        offset = (target_w - SIL_FULL_WIDTH) // 2
        canvas = scaled[:, offset : offset + SIL_FULL_WIDTH]

    trimmed = np.asarray(
        canvas[:, SIDE_CUT : SIL_FULL_WIDTH - SIDE_CUT], dtype=np.float32
    )
    if trimmed.shape != (SIL_HEIGHT, SIL_WIDTH):
        return None

    unit = np.clip(trimmed / np.float32(255.0), 0.0, 1.0).astype(np.float32)
    return cast(Float[ndarray, "64 44"], unit)
|
||||
@@ -0,0 +1,317 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import ClassVar, Protocol, cast, override
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from beartype import beartype
|
||||
from einops import rearrange
|
||||
from jaxtyping import Float
|
||||
import jaxtyping
|
||||
from torch import Tensor
|
||||
|
||||
_OPENGAIT_PACKAGE_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(_OPENGAIT_PACKAGE_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_OPENGAIT_PACKAGE_ROOT))
|
||||
|
||||
from opengait.modeling.backbones.resnet import ResNet9
|
||||
from opengait.modeling.modules import (
|
||||
HorizontalPoolingPyramid,
|
||||
PackSequenceWrapper as TemporalPool,
|
||||
SeparateBNNecks,
|
||||
SeparateFCs,
|
||||
)
|
||||
from opengait.utils import common as common_utils
|
||||
|
||||
|
||||
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
|
||||
JaxtypedFactory = Callable[..., JaxtypedDecorator]
|
||||
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
|
||||
ConfigLoader = Callable[[str], dict[str, object]]
|
||||
config_loader = cast(ConfigLoader, common_utils.config_loader)
|
||||
|
||||
|
||||
class TemporalPoolLike(Protocol):
    """Callable pooling a feature sequence over its temporal dimension.

    Structural type for OpenGait's ``PackSequenceWrapper`` so the attribute
    can be typed without naming the concrete class.
    """

    def __call__(
        self,
        seqs: Tensor,
        seqL: object,
        dim: int = 2,
        options: dict[str, int] | None = None,
    ) -> object: ...
|
||||
|
||||
|
||||
class HppLike(Protocol):
    """Callable mapping pooled features to part features.

    Structural type for ``HorizontalPoolingPyramid``.
    """

    def __call__(self, x: Tensor) -> Tensor: ...
|
||||
|
||||
|
||||
class FCsLike(Protocol):
    """Callable applying per-part projections (structural type for ``SeparateFCs``)."""

    def __call__(self, x: Tensor) -> Tensor: ...
|
||||
|
||||
|
||||
class BNNecksLike(Protocol):
    """Callable returning a (features, logits) pair (structural type for ``SeparateBNNecks``)."""

    def __call__(self, x: Tensor) -> tuple[Tensor, Tensor]: ...
|
||||
|
||||
|
||||
class ScoNetDemo(nn.Module):
    """Standalone ScoNet wrapper for demo-time scoliosis classification.

    Rebuilds the ScoNet architecture (ResNet9 backbone, temporal max pooling,
    horizontal pooling pyramid, separate FCs and BN necks) from an OpenGait
    YAML config and optionally loads a training checkpoint whose parameter
    names are remapped onto this module's attribute layout.
    """

    # Class index -> human-readable screening label.
    LABEL_MAP: ClassVar[dict[int, str]] = {0: "negative", 1: "neutral", 2: "positive"}
    cfg_path: str  # resolved config path actually loaded
    cfg: dict[str, object]  # full parsed YAML config
    backbone: ResNet9
    temporal_pool: TemporalPoolLike
    hpp: HppLike
    fcs: FCsLike
    bn_necks: BNNecksLike
    device: torch.device

    @jaxtyped(typechecker=beartype)
    def __init__(
        self,
        cfg_path: str | Path = "configs/sconet/sconet_scoliosis1k.yaml",
        checkpoint_path: str | Path | None = None,
        device: str | torch.device | None = None,
    ) -> None:
        """Build the model from a config file and optionally load weights.

        Parameters
        ----------
        cfg_path : str | Path
            OpenGait YAML config; relative paths are resolved via
            ``_resolve_path``.
        checkpoint_path : str | Path | None
            Optional checkpoint to load after construction.
        device : str | torch.device | None
            Target device; defaults to CPU.

        Raises
        ------
        ValueError
            If the config requests a backbone other than ResNet9.
        TypeError
            If required config entries are missing or mistyped.
        """
        super().__init__()
        resolved_cfg = self._resolve_path(cfg_path)
        self.cfg_path = str(resolved_cfg)

        self.cfg = config_loader(self.cfg_path)

        model_cfg = self._extract_model_cfg(self.cfg)
        backbone_cfg = self._extract_dict(model_cfg, "backbone_cfg")

        # Only the ResNet9 backbone layout is reconstructed here.
        if backbone_cfg.get("type") != "ResNet9":
            raise ValueError(
                "ScoNetDemo currently supports backbone type ResNet9 only."
            )

        self.backbone = ResNet9(
            block=self._extract_str(backbone_cfg, "block"),
            channels=self._extract_int_list(backbone_cfg, "channels"),
            in_channel=self._extract_int(backbone_cfg, "in_channel", default=1),
            layers=self._extract_int_list(backbone_cfg, "layers"),
            strides=self._extract_int_list(backbone_cfg, "strides"),
            maxpool=self._extract_bool(backbone_cfg, "maxpool", default=True),
        )

        fcs_cfg = self._extract_dict(model_cfg, "SeparateFCs")
        bn_cfg = self._extract_dict(model_cfg, "SeparateBNNecks")
        bin_num = self._extract_int_list(model_cfg, "bin_num")

        # Temporal pooling takes the elementwise max over the sequence dim.
        self.temporal_pool = cast(TemporalPoolLike, TemporalPool(torch.max))
        self.hpp = cast(HppLike, HorizontalPoolingPyramid(bin_num=bin_num))
        self.fcs = cast(
            FCsLike,
            SeparateFCs(
                parts_num=self._extract_int(fcs_cfg, "parts_num"),
                in_channels=self._extract_int(fcs_cfg, "in_channels"),
                out_channels=self._extract_int(fcs_cfg, "out_channels"),
                norm=self._extract_bool(fcs_cfg, "norm", default=False),
            ),
        )
        self.bn_necks = cast(
            BNNecksLike,
            SeparateBNNecks(
                parts_num=self._extract_int(bn_cfg, "parts_num"),
                in_channels=self._extract_int(bn_cfg, "in_channels"),
                class_num=self._extract_int(bn_cfg, "class_num"),
                norm=self._extract_bool(bn_cfg, "norm", default=True),
                parallel_BN1d=self._extract_bool(bn_cfg, "parallel_BN1d", default=True),
            ),
        )

        self.device = (
            torch.device(device) if device is not None else torch.device("cpu")
        )
        _ = self.to(self.device)

        if checkpoint_path is not None:
            _ = self.load_checkpoint(checkpoint_path)

        # Demo model is inference-only; start in eval mode.
        _ = self.eval()

    @staticmethod
    def _resolve_path(path: str | Path) -> Path:
        """Resolve *path*: existing files and absolute paths pass through;
        other relative paths are joined onto the directory two levels above
        this file (assumed repository root — TODO confirm layout)."""
        candidate = Path(path)
        if candidate.is_file():
            return candidate
        if candidate.is_absolute():
            return candidate
        repo_root = Path(__file__).resolve().parents[2]
        return repo_root / candidate

    @staticmethod
    def _extract_model_cfg(cfg: dict[str, object]) -> dict[str, object]:
        """Return the ``model_cfg`` sub-dict or raise TypeError."""
        model_cfg_obj = cfg.get("model_cfg")
        if not isinstance(model_cfg_obj, dict):
            raise TypeError("model_cfg must be a dictionary.")
        return cast(dict[str, object], model_cfg_obj)

    @staticmethod
    def _extract_dict(container: dict[str, object], key: str) -> dict[str, object]:
        """Return ``container[key]`` as a dict or raise TypeError."""
        value = container.get(key)
        if not isinstance(value, dict):
            raise TypeError(f"{key} must be a dictionary.")
        return cast(dict[str, object], value)

    @staticmethod
    def _extract_str(container: dict[str, object], key: str) -> str:
        """Return ``container[key]`` as a str or raise TypeError."""
        value = container.get(key)
        if not isinstance(value, str):
            raise TypeError(f"{key} must be a string.")
        return value

    @staticmethod
    def _extract_int(
        container: dict[str, object], key: str, default: int | None = None
    ) -> int:
        """Return ``container[key]`` (or *default* when absent) as an int.

        NOTE(review): ``isinstance(value, int)`` also accepts bools, since
        ``bool`` subclasses ``int`` — confirm config values cannot be bools.
        """
        value = container.get(key, default)
        if not isinstance(value, int):
            raise TypeError(f"{key} must be an int.")
        return value

    @staticmethod
    def _extract_bool(
        container: dict[str, object], key: str, default: bool | None = None
    ) -> bool:
        """Return ``container[key]`` (or *default* when absent) as a bool."""
        value = container.get(key, default)
        if not isinstance(value, bool):
            raise TypeError(f"{key} must be a bool.")
        return value

    @staticmethod
    def _extract_int_list(container: dict[str, object], key: str) -> list[int]:
        """Return ``container[key]`` as a list of ints or raise TypeError."""
        value = container.get(key)
        if not isinstance(value, list):
            raise TypeError(f"{key} must be a list[int].")
        values = cast(list[object], value)
        if not all(isinstance(v, int) for v in values):
            raise TypeError(f"{key} must be a list[int].")
        return cast(list[int], values)

    @staticmethod
    def _normalize_state_dict(
        state_dict_obj: dict[object, object],
    ) -> dict[str, Tensor]:
        """Remap training-time parameter names onto this module's attributes.

        Strips a leading ``module.`` prefix (presumably from DataParallel —
        confirm against the training setup), then rewrites OpenGait prefixes
        to the local attribute names. Raises on non-str keys, non-tensor
        values, or a collision after remapping.
        """
        prefix_remap: tuple[tuple[str, str], ...] = (
            ("Backbone.forward_block.", "backbone."),
            ("FCs.", "fcs."),
            ("BNNecks.", "bn_necks."),
        )
        cleaned_state_dict: dict[str, Tensor] = {}
        for key_obj, value_obj in state_dict_obj.items():
            if not isinstance(key_obj, str):
                raise TypeError("Checkpoint state_dict keys must be strings.")
            if not isinstance(value_obj, Tensor):
                raise TypeError("Checkpoint state_dict values must be torch.Tensor.")
            # 7 == len("module.").
            key = key_obj[7:] if key_obj.startswith("module.") else key_obj
            for source_prefix, target_prefix in prefix_remap:
                if key.startswith(source_prefix):
                    key = f"{target_prefix}{key[len(source_prefix) :]}"
                    break
            if key in cleaned_state_dict:
                raise RuntimeError(
                    f"Checkpoint key normalization collision detected for key '{key}'."
                )
            cleaned_state_dict[key] = value_obj
        return cleaned_state_dict

    @jaxtyped(typechecker=beartype)
    def load_checkpoint(
        self,
        checkpoint_path: str | Path,
        map_location: str | torch.device | None = None,
        strict: bool = True,
    ) -> None:
        """Load a checkpoint, accepting either a raw state dict or an
        OpenGait-style dict with a ``"model"`` entry.

        Parameters
        ----------
        checkpoint_path : str | Path
            Checkpoint file; resolved via ``_resolve_path``.
        map_location : str | torch.device | None
            Passed to ``torch.load``; defaults to the model device.
        strict : bool
            Forwarded to ``load_state_dict``.

        Raises
        ------
        TypeError
            If the checkpoint is not a dict of tensors.
        RuntimeError
            If loading fails even after key normalization.
        """
        resolved_ckpt = self._resolve_path(checkpoint_path)
        checkpoint_obj = cast(
            object,
            torch.load(
                str(resolved_ckpt),
                map_location=map_location if map_location is not None else self.device,
            ),
        )

        # OpenGait checkpoints wrap the weights under a "model" key.
        state_dict_obj: object = checkpoint_obj
        if isinstance(checkpoint_obj, dict) and "model" in checkpoint_obj:
            state_dict_obj = cast(dict[str, object], checkpoint_obj)["model"]

        if not isinstance(state_dict_obj, dict):
            raise TypeError("Unsupported checkpoint format.")

        cleaned_state_dict = self._normalize_state_dict(
            cast(dict[object, object], state_dict_obj)
        )
        try:
            _ = self.load_state_dict(cleaned_state_dict, strict=strict)
        except RuntimeError as exc:
            raise RuntimeError(
                f"Failed to load ScoNetDemo checkpoint after key normalization from '{resolved_ckpt}'."
            ) from exc
        _ = self.eval()

    def _prepare_sils(self, sils: Tensor) -> Tensor:
        """Coerce input to [B, 1, S, H, W] float32 on the model device.

        Accepts [B, S, H, W] (inserts the channel dim) and
        [B, S, 1, H, W] (moves the singleton channel forward);
        raises ValueError for any other layout.
        """
        if sils.ndim == 4:
            sils = sils.unsqueeze(1)
        elif sils.ndim == 5 and sils.shape[1] != 1 and sils.shape[2] == 1:
            sils = rearrange(sils, "b s c h w -> b c s h w")

        if sils.ndim != 5 or sils.shape[1] != 1:
            raise ValueError("Expected sils shape [B, 1, S, H, W] or [B, S, H, W].")

        return sils.float().to(self.device)

    def _forward_backbone(self, sils: Tensor) -> Tensor:
        """Apply the backbone frame-by-frame.

        Folds the sequence dim into the batch dim, runs the backbone, and
        restores the layout to [B, C_out, S, H', W'].
        """
        batch, channels, seq, height, width = sils.shape
        framewise = sils.transpose(1, 2).reshape(batch * seq, channels, height, width)
        frame_feats = cast(Tensor, self.backbone(framewise))
        _, out_channels, out_h, out_w = frame_feats.shape
        return (
            frame_feats.reshape(batch, seq, out_channels, out_h, out_w)
            .transpose(1, 2)
            .contiguous()
        )

    @override
    @jaxtyped(typechecker=beartype)
    def forward(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> dict[str, Tensor]:
        """Classify a silhouette sequence.

        Runs under ``torch.inference_mode``. Returns a dict with raw
        ``logits``, argmax ``label`` ids, and softmax ``confidence`` of the
        predicted class.
        """
        with torch.inference_mode():
            prepared_sils = self._prepare_sils(sils)
            outs = self._forward_backbone(prepared_sils)

            # TemporalPool (PackSequenceWrapper) returns a tuple; only the
            # pooled tensor in slot 0 is used.
            pooled_obj = self.temporal_pool(outs, None, options={"dim": 2})
            if (
                not isinstance(pooled_obj, tuple)
                or not pooled_obj
                or not isinstance(pooled_obj[0], Tensor)
            ):
                raise TypeError("TemporalPool output is invalid.")
            pooled = pooled_obj[0]

            feat = self.hpp(pooled)
            embed_1 = self.fcs(feat)
            _, logits = self.bn_necks(embed_1)

            # Average logits over the last dim before argmax/softmax —
            # assumed to index horizontal parts; TODO confirm against the
            # SeparateBNNecks output layout.
            mean_logits = logits.mean(dim=-1)
            pred_ids = torch.argmax(mean_logits, dim=-1)
            probs = torch.softmax(mean_logits, dim=-1)
            confidence = torch.gather(
                probs, dim=-1, index=pred_ids.unsqueeze(-1)
            ).squeeze(-1)

            return {"logits": logits, "label": pred_ids, "confidence": confidence}

    @jaxtyped(typechecker=beartype)
    def predict(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> tuple[str, float]:
        """Classify a single sequence and return (label_name, confidence).

        Raises
        ------
        ValueError
            If the batch contains more than one sequence.
        """
        # Calls forward() directly (not self(...)); grad is already disabled
        # inside forward via inference_mode.
        outputs = cast(dict[str, Tensor], self.forward(sils))
        labels = outputs["label"]
        confidence = outputs["confidence"]

        if labels.numel() != 1:
            raise ValueError("predict expects batch size 1.")

        label_id = int(labels.item())
        return self.LABEL_MAP[label_id], float(confidence.item())
|
||||
@@ -0,0 +1,295 @@
|
||||
"""Sliding window / ring buffer manager for real-time gait analysis.
|
||||
|
||||
This module provides bounded buffer management for silhouette sequences
|
||||
with track ID tracking and gap detection.
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, Protocol, final
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from jaxtyping import Float
|
||||
from numpy import ndarray
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
|
||||
|
||||
# Silhouette dimensions from preprocess.py
|
||||
SIL_HEIGHT: int = 64
|
||||
SIL_WIDTH: int = 44
|
||||
|
||||
|
||||
class _Boxes(Protocol):
    """Protocol for boxes with xyxy and id attributes.

    Structural stand-in for an Ultralytics-style ``Boxes`` object so this
    module carries no hard dependency on the detector package.
    """

    # Bounding boxes, consumed via np.asarray as (x1, y1, x2, y2) rows.
    @property
    def xyxy(self) -> "NDArray[np.float32] | object": ...
    # Optional per-box track IDs; None when the tracker is not running.
    @property
    def id(self) -> "NDArray[np.int64] | object | None": ...
|
||||
|
||||
|
||||
class _Masks(Protocol):
    """Protocol for masks with a data attribute.

    ``data`` is consumed via np.asarray as per-instance masks [N, H, W].
    """

    @property
    def data(self) -> "NDArray[np.float32] | object": ...
|
||||
|
||||
|
||||
class _DetectionResults(Protocol):
    """Protocol for detection results from Ultralytics-style objects.

    Bundles the per-frame boxes and masks used by ``select_person``.
    """

    @property
    def boxes(self) -> _Boxes: ...
    @property
    def masks(self) -> _Masks: ...
|
||||
|
||||
|
||||
@final
|
||||
class SilhouetteWindow:
|
||||
"""Bounded sliding window for silhouette sequences.
|
||||
|
||||
Manages a fixed-size buffer of silhouettes with track ID tracking
|
||||
and automatic reset on track changes or frame gaps.
|
||||
|
||||
Attributes:
|
||||
window_size: Maximum number of frames in the buffer.
|
||||
stride: Classification stride (frames between classifications).
|
||||
gap_threshold: Maximum allowed frame gap before reset.
|
||||
"""
|
||||
|
||||
window_size: int
|
||||
stride: int
|
||||
gap_threshold: int
|
||||
_buffer: deque[Float[ndarray, "64 44"]]
|
||||
_frame_indices: deque[int]
|
||||
_track_id: int | None
|
||||
_last_classified_frame: int
|
||||
_frame_count: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
window_size: int = 30,
|
||||
stride: int = 1,
|
||||
gap_threshold: int = 15,
|
||||
) -> None:
|
||||
"""Initialize the silhouette window.
|
||||
|
||||
Args:
|
||||
window_size: Maximum buffer size (default 30).
|
||||
stride: Frames between classifications (default 1).
|
||||
gap_threshold: Max frame gap before reset (default 15).
|
||||
"""
|
||||
self.window_size = window_size
|
||||
self.stride = stride
|
||||
self.gap_threshold = gap_threshold
|
||||
|
||||
# Bounded storage via deque
|
||||
self._buffer = deque(maxlen=window_size)
|
||||
self._frame_indices = deque(maxlen=window_size)
|
||||
self._track_id = None
|
||||
self._last_classified_frame = -1
|
||||
self._frame_count = 0
|
||||
|
||||
def push(self, sil: np.ndarray, frame_idx: int, track_id: int) -> None:
|
||||
"""Push a new silhouette into the window.
|
||||
|
||||
Automatically resets buffer on track ID change or frame gap
|
||||
exceeding gap_threshold.
|
||||
|
||||
Args:
|
||||
sil: Silhouette array of shape (64, 44), float32.
|
||||
frame_idx: Current frame index for gap detection.
|
||||
track_id: Track ID for the person.
|
||||
"""
|
||||
# Check for track ID change
|
||||
if self._track_id is not None and track_id != self._track_id:
|
||||
self.reset()
|
||||
|
||||
# Check for frame gap
|
||||
if self._frame_indices:
|
||||
last_frame = self._frame_indices[-1]
|
||||
gap = frame_idx - last_frame
|
||||
if gap > self.gap_threshold:
|
||||
self.reset()
|
||||
|
||||
# Update track ID
|
||||
self._track_id = track_id
|
||||
|
||||
# Validate and append silhouette
|
||||
sil_array = np.asarray(sil, dtype=np.float32)
|
||||
if sil_array.shape != (SIL_HEIGHT, SIL_WIDTH):
|
||||
raise ValueError(
|
||||
f"Expected silhouette shape ({SIL_HEIGHT}, {SIL_WIDTH}), got {sil_array.shape}"
|
||||
)
|
||||
|
||||
self._buffer.append(sil_array)
|
||||
self._frame_indices.append(frame_idx)
|
||||
self._frame_count += 1
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""Check if window has enough frames for classification.
|
||||
|
||||
Returns:
|
||||
True if buffer is full (window_size frames).
|
||||
"""
|
||||
return len(self._buffer) >= self.window_size
|
||||
|
||||
def should_classify(self) -> bool:
|
||||
"""Check if classification should run based on stride.
|
||||
|
||||
Returns:
|
||||
True if enough frames have passed since last classification.
|
||||
"""
|
||||
if not self.is_ready():
|
||||
return False
|
||||
|
||||
if self._last_classified_frame < 0:
|
||||
return True
|
||||
|
||||
current_frame = self._frame_indices[-1]
|
||||
frames_since = current_frame - self._last_classified_frame
|
||||
return frames_since >= self.stride
|
||||
|
||||
def get_tensor(self, device: str = "cpu") -> torch.Tensor:
|
||||
"""Get window contents as a tensor for model input.
|
||||
|
||||
Args:
|
||||
device: Target device for the tensor (default 'cpu').
|
||||
|
||||
Returns:
|
||||
Tensor of shape [1, 1, window_size, 64, 44] with dtype float32.
|
||||
|
||||
Raises:
|
||||
ValueError: If buffer is not full.
|
||||
"""
|
||||
if not self.is_ready():
|
||||
raise ValueError(
|
||||
f"Window not ready: {len(self._buffer)}/{self.window_size} frames"
|
||||
)
|
||||
|
||||
# Stack buffer into array [window_size, 64, 44]
|
||||
stacked = np.stack(list(self._buffer), axis=0)
|
||||
|
||||
# Add batch and channel dims: [1, 1, window_size, 64, 44]
|
||||
tensor = torch.from_numpy(stacked.astype(np.float32))
|
||||
tensor = tensor.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
return tensor.to(device)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the window, clearing all buffers and counters."""
|
||||
self._buffer.clear()
|
||||
self._frame_indices.clear()
|
||||
self._track_id = None
|
||||
self._last_classified_frame = -1
|
||||
self._frame_count = 0
|
||||
|
||||
def mark_classified(self) -> None:
|
||||
"""Mark current frame as classified, updating stride tracking."""
|
||||
if self._frame_indices:
|
||||
self._last_classified_frame = self._frame_indices[-1]
|
||||
|
||||
@property
|
||||
def current_track_id(self) -> int | None:
|
||||
"""Current track ID, or None if buffer is empty."""
|
||||
return self._track_id
|
||||
|
||||
@property
|
||||
def frame_count(self) -> int:
|
||||
"""Total frames pushed since last reset."""
|
||||
return self._frame_count
|
||||
|
||||
@property
|
||||
def fill_level(self) -> float:
|
||||
"""Fill ratio of the buffer (0.0 to 1.0)."""
|
||||
return len(self._buffer) / self.window_size
|
||||
|
||||
|
||||
def select_person(
|
||||
results: _DetectionResults,
|
||||
) -> tuple[ndarray, tuple[int, int, int, int], int] | None:
|
||||
"""Select the person with largest bounding box from detection results.
|
||||
|
||||
Args:
|
||||
results: Detection results object with boxes and masks attributes.
|
||||
Expected to have:
|
||||
- boxes.xyxy: array of bounding boxes [N, 4]
|
||||
- masks.data: array of masks [N, H, W]
|
||||
- boxes.id: optional track IDs [N]
|
||||
|
||||
Returns:
|
||||
Tuple of (mask, bbox, track_id) for the largest person,
|
||||
or None if no valid detections or track IDs unavailable.
|
||||
"""
|
||||
# Check for track IDs
|
||||
boxes_obj: _Boxes | object = getattr(results, "boxes", None)
|
||||
if boxes_obj is None:
|
||||
return None
|
||||
|
||||
track_ids_obj: ndarray | object | None = getattr(boxes_obj, "id", None)
|
||||
if track_ids_obj is None:
|
||||
return None
|
||||
|
||||
track_ids: ndarray = np.asarray(track_ids_obj)
|
||||
if track_ids.size == 0:
|
||||
return None
|
||||
|
||||
# Get bounding boxes
|
||||
xyxy_obj: ndarray | object = getattr(boxes_obj, "xyxy", None)
|
||||
if xyxy_obj is None:
|
||||
return None
|
||||
|
||||
bboxes: ndarray = np.asarray(xyxy_obj)
|
||||
if bboxes.ndim == 1:
|
||||
bboxes = bboxes.reshape(1, -1)
|
||||
|
||||
if bboxes.shape[0] == 0:
|
||||
return None
|
||||
|
||||
# Get masks
|
||||
masks_obj: _Masks | object = getattr(results, "masks", None)
|
||||
if masks_obj is None:
|
||||
return None
|
||||
|
||||
masks_data: ndarray | object = getattr(masks_obj, "data", None)
|
||||
if masks_data is None:
|
||||
return None
|
||||
|
||||
masks: ndarray = np.asarray(masks_data)
|
||||
if masks.ndim == 2:
|
||||
masks = masks[np.newaxis, ...]
|
||||
|
||||
if masks.shape[0] != bboxes.shape[0]:
|
||||
return None
|
||||
|
||||
# Find largest bbox by area
|
||||
best_idx: int = -1
|
||||
best_area: float = -1.0
|
||||
|
||||
for i in range(int(bboxes.shape[0])):
|
||||
row: "NDArray[np.float32]" = bboxes[i][:4]
|
||||
x1f: float = float(row[0])
|
||||
y1f: float = float(row[1])
|
||||
x2f: float = float(row[2])
|
||||
y2f: float = float(row[3])
|
||||
area: float = (x2f - x1f) * (y2f - y1f)
|
||||
if area > best_area:
|
||||
best_area = area
|
||||
best_idx = i
|
||||
|
||||
if best_idx < 0:
|
||||
return None
|
||||
|
||||
# Extract mask and bbox
|
||||
mask: "NDArray[np.float32]" = masks[best_idx]
|
||||
bbox = (
|
||||
int(float(bboxes[best_idx][0])),
|
||||
int(float(bboxes[best_idx][1])),
|
||||
int(float(bboxes[best_idx][2])),
|
||||
int(float(bboxes[best_idx][3])),
|
||||
)
|
||||
track_id = int(track_ids[best_idx]) if best_idx < len(track_ids) else best_idx
|
||||
|
||||
return mask, bbox, track_id
|
||||
Reference in New Issue
Block a user