feat: extract opengait_studio monorepo module
Move demo implementation into opengait_studio, retire Sports2D runtime integration, and align packaging with root-level monorepo dependency management.
This commit is contained in:
@@ -0,0 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .pipeline import main
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
Input adapters for OpenGait demo.
|
||||
|
||||
Provides generator-based interfaces for video sources:
|
||||
- OpenCV (video files, cameras)
|
||||
- cv-mmap (shared memory streams)
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator, Generator, Iterable
|
||||
from typing import Protocol, cast
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type alias for frame stream: (frame_array, metadata_dict)
|
||||
FrameStream = Iterable[tuple[np.ndarray, dict[str, object]]]
|
||||
|
||||
|
||||
# Protocol for cv-mmap metadata (needed at runtime for nested function annotation)
|
||||
class _FrameMetadata(Protocol):
|
||||
frame_count: int
|
||||
timestamp_ns: int
|
||||
|
||||
|
||||
# Protocol for cv-mmap client (needed at runtime for cast)
|
||||
class _CvMmapClient(Protocol):
|
||||
def __aiter__(self) -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]: ...
|
||||
|
||||
|
||||
def opencv_source(
|
||||
path: str | int, max_frames: int | None = None
|
||||
) -> Generator[tuple[np.ndarray, dict[str, object]], None, None]:
|
||||
"""
|
||||
Generator that yields frames from an OpenCV video source.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str | int
|
||||
Video file path or camera index (e.g., 0 for default camera)
|
||||
max_frames : int | None, optional
|
||||
Maximum number of frames to yield. None means unlimited.
|
||||
|
||||
Yields
|
||||
------
|
||||
tuple[np.ndarray, dict[str, object]]
|
||||
(frame_array, metadata_dict) where metadata includes:
|
||||
- frame_count: frame index (0-based)
|
||||
- timestamp_ns: monotonic timestamp in nanoseconds (if available)
|
||||
- source: the path/int provided
|
||||
"""
|
||||
import time
|
||||
|
||||
import cv2
|
||||
|
||||
cap = cv2.VideoCapture(path)
|
||||
if not cap.isOpened():
|
||||
raise RuntimeError(f"Failed to open video source: {path}")
|
||||
|
||||
is_file_source = isinstance(path, str)
|
||||
source_fps = float(cap.get(cv2.CAP_PROP_FPS)) if is_file_source else 0.0
|
||||
fps_valid = source_fps > 0.0 and np.isfinite(source_fps)
|
||||
fallback_fps = source_fps if fps_valid else 30.0
|
||||
fallback_interval_ns = int(1_000_000_000 / fallback_fps)
|
||||
start_ns = time.monotonic_ns()
|
||||
|
||||
frame_idx = 0
|
||||
try:
|
||||
while max_frames is None or frame_idx < max_frames:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
# End of stream
|
||||
break
|
||||
|
||||
if is_file_source:
|
||||
pos_msec = float(cap.get(cv2.CAP_PROP_POS_MSEC))
|
||||
if np.isfinite(pos_msec) and pos_msec > 0.0:
|
||||
timestamp_ns = start_ns + int(pos_msec * 1_000_000)
|
||||
else:
|
||||
timestamp_ns = start_ns + frame_idx * fallback_interval_ns
|
||||
else:
|
||||
timestamp_ns = time.monotonic_ns()
|
||||
|
||||
metadata: dict[str, object] = {
|
||||
"frame_count": frame_idx,
|
||||
"timestamp_ns": timestamp_ns,
|
||||
"source": path,
|
||||
}
|
||||
if fps_valid:
|
||||
metadata["source_fps"] = source_fps
|
||||
|
||||
yield frame, metadata
|
||||
frame_idx += 1
|
||||
|
||||
finally:
|
||||
cap.release()
|
||||
logger.debug(f"OpenCV source closed: {path}")
|
||||
|
||||
|
||||
def cvmmap_source(
|
||||
name: str, max_frames: int | None = None
|
||||
) -> Generator[tuple[np.ndarray, dict[str, object]], None, None]:
|
||||
"""
|
||||
Generator that yields frames from a cv-mmap shared memory stream.
|
||||
|
||||
Bridges async cv-mmap client to synchronous generator using asyncio.run().
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
Base name of the cv-mmap source (e.g., "default")
|
||||
max_frames : int | None, optional
|
||||
Maximum number of frames to yield. None means unlimited.
|
||||
|
||||
Yields
|
||||
------
|
||||
tuple[np.ndarray, dict[str, object]]
|
||||
(frame_array, metadata_dict) where metadata includes:
|
||||
- frame_count: frame index from cv-mmap
|
||||
- timestamp_ns: timestamp in nanoseconds from cv-mmap
|
||||
- source: the cv-mmap name
|
||||
|
||||
Raises
|
||||
------
|
||||
ImportError
|
||||
If cvmmap package is not available
|
||||
RuntimeError
|
||||
If cv-mmap stream disconnects or errors
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# Import cvmmap only when function is called
|
||||
# Use try/except for runtime import check
|
||||
try:
|
||||
from cvmmap import CvMmapClient as _CvMmapClientReal # pyright: ignore[reportMissingTypeStubs]
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"cvmmap package is required for cv-mmap sources. "
|
||||
+ "Install from: https://github.com/crosstyan/cv-mmap"
|
||||
) from e
|
||||
|
||||
# Cast to protocol type for type checking
|
||||
client: _CvMmapClient = cast("_CvMmapClient", _CvMmapClientReal(name))
|
||||
frame_count = 0
|
||||
|
||||
async def _async_generator() -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]:
|
||||
"""Async generator wrapper."""
|
||||
async for frame, meta in client:
|
||||
yield frame, meta
|
||||
|
||||
# Bridge async to sync using asyncio.run()
|
||||
# We process frames one at a time to keep it simple and robust
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
agen = _async_generator().__aiter__()
|
||||
|
||||
while max_frames is None or frame_count < max_frames:
|
||||
try:
|
||||
frame, meta = loop.run_until_complete(agen.__anext__())
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
metadata: dict[str, object] = {
|
||||
"frame_count": meta.frame_count,
|
||||
"timestamp_ns": meta.timestamp_ns,
|
||||
"source": f"cvmmap://{name}",
|
||||
}
|
||||
|
||||
yield frame, metadata
|
||||
frame_count += 1
|
||||
|
||||
finally:
|
||||
loop.close()
|
||||
logger.debug(f"cv-mmap source closed: {name}")
|
||||
|
||||
|
||||
def create_source(source: str, max_frames: int | None = None) -> FrameStream:
|
||||
"""
|
||||
Factory function to create a frame source from a string specification.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source : str
|
||||
Source specification:
|
||||
- '0', '1', etc. -> Camera index (OpenCV)
|
||||
- 'cvmmap://name' -> cv-mmap shared memory stream
|
||||
- Any other string -> Video file path (OpenCV)
|
||||
max_frames : int | None, optional
|
||||
Maximum number of frames to yield. None means unlimited.
|
||||
|
||||
Returns
|
||||
-------
|
||||
FrameStream
|
||||
Generator yielding (frame, metadata) tuples
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> for frame, meta in create_source('0'): # Camera 0
|
||||
... process(frame)
|
||||
>>> for frame, meta in create_source('cvmmap://default'): # cv-mmap
|
||||
... process(frame)
|
||||
>>> for frame, meta in create_source('/path/to/video.mp4'):
|
||||
... process(frame)
|
||||
"""
|
||||
# Check for cv-mmap protocol
|
||||
if source.startswith("cvmmap://"):
|
||||
name = source[len("cvmmap://") :]
|
||||
return cvmmap_source(name, max_frames)
|
||||
|
||||
# Check for camera index (single digit string)
|
||||
if source.isdigit():
|
||||
return opencv_source(int(source), max_frames)
|
||||
|
||||
# Otherwise treat as file path
|
||||
return opencv_source(source, max_frames)
|
||||
@@ -0,0 +1,390 @@
|
||||
"""
|
||||
Output publishers for OpenGait demo results.
|
||||
|
||||
Provides pluggable result publishing:
|
||||
- ConsolePublisher: JSONL to stdout
|
||||
- NatsPublisher: NATS message broker integration
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import nats
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Protocol, TextIO, TypedDict, cast, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DemoResult(TypedDict):
|
||||
"""Typed result dictionary for demo pipeline output.
|
||||
|
||||
Contains classification result with frame metadata.
|
||||
"""
|
||||
|
||||
frame: int
|
||||
track_id: int
|
||||
label: str
|
||||
confidence: float
|
||||
window: int
|
||||
timestamp_ns: int
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ResultPublisher(Protocol):
|
||||
"""Protocol for result publishers."""
|
||||
|
||||
def publish(self, result: DemoResult) -> None:
|
||||
"""
|
||||
Publish a result dictionary.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
result : DemoResult
|
||||
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ConsolePublisher:
|
||||
"""Publisher that outputs JSON Lines to stdout."""
|
||||
|
||||
_output: TextIO
|
||||
|
||||
def __init__(self, output: TextIO = sys.stdout) -> None:
|
||||
"""
|
||||
Initialize console publisher.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : TextIO
|
||||
File-like object to write to (default: sys.stdout)
|
||||
"""
|
||||
self._output = output
|
||||
|
||||
def publish(self, result: DemoResult) -> None:
|
||||
"""
|
||||
Publish result as JSON line.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
result : DemoResult
|
||||
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
|
||||
"""
|
||||
try:
|
||||
json_line = json.dumps(result, ensure_ascii=False, default=str)
|
||||
_ = self._output.write(json_line + "\n")
|
||||
self._output.flush()
|
||||
except (OSError, ValueError, TypeError) as e:
|
||||
logger.warning(f"Failed to publish to console: {e}")
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the publisher (no-op for console)."""
|
||||
pass
|
||||
|
||||
def __enter__(self) -> ConsolePublisher:
|
||||
"""Context manager entry."""
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Context manager exit."""
|
||||
self.close()
|
||||
|
||||
|
||||
class _NatsClient(Protocol):
|
||||
"""Protocol for connected NATS client."""
|
||||
|
||||
async def publish(self, subject: str, payload: bytes) -> object: ...
|
||||
|
||||
async def close(self) -> object: ...
|
||||
|
||||
async def flush(self) -> object: ...
|
||||
|
||||
|
||||
class NatsPublisher:
|
||||
"""
|
||||
Publisher that sends results to NATS message broker.
|
||||
|
||||
This is a sync-friendly wrapper around the async nats-py client.
|
||||
Uses a background thread with dedicated event loop to bridge sync
|
||||
publish calls to async NATS operations, making it safe to use in
|
||||
both sync and async contexts.
|
||||
"""
|
||||
|
||||
_nats_url: str
|
||||
_subject: str
|
||||
_nc: _NatsClient | None
|
||||
_connected: bool
|
||||
_loop: asyncio.AbstractEventLoop | None
|
||||
_thread: threading.Thread | None
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, nats_url: str, subject: str = "scoliosis.result") -> None:
|
||||
"""
|
||||
Initialize NATS publisher.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nats_url : str
|
||||
NATS server URL (e.g., "nats://localhost:4222")
|
||||
subject : str
|
||||
NATS subject to publish to (default: "scoliosis.result")
|
||||
"""
|
||||
self._nats_url = nats_url
|
||||
self._subject = subject
|
||||
self._nc = None
|
||||
self._connected = False
|
||||
self._loop = None
|
||||
self._thread = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _start_background_loop(self) -> bool:
|
||||
"""
|
||||
Start background thread with event loop for async operations.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if loop is running, False otherwise
|
||||
"""
|
||||
with self._lock:
|
||||
if self._loop is not None and self._loop.is_running():
|
||||
return True
|
||||
|
||||
try:
|
||||
loop = asyncio.new_event_loop()
|
||||
self._loop = loop
|
||||
|
||||
def run_loop() -> None:
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_forever()
|
||||
|
||||
self._thread = threading.Thread(target=run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
return True
|
||||
except (RuntimeError, OSError) as e:
|
||||
logger.warning(f"Failed to start background event loop: {e}")
|
||||
return False
|
||||
|
||||
def _stop_background_loop(self) -> None:
|
||||
"""Stop the background event loop and thread."""
|
||||
with self._lock:
|
||||
if self._loop is not None and self._loop.is_running():
|
||||
_ = self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
if self._thread is not None and self._thread.is_alive():
|
||||
self._thread.join(timeout=2.0)
|
||||
self._loop = None
|
||||
self._thread = None
|
||||
|
||||
def _ensure_connected(self) -> bool:
|
||||
"""
|
||||
Ensure connection to NATS server.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if connected, False otherwise
|
||||
"""
|
||||
with self._lock:
|
||||
if self._connected and self._nc is not None:
|
||||
return True
|
||||
|
||||
if not self._start_background_loop():
|
||||
return False
|
||||
|
||||
try:
|
||||
|
||||
async def _connect() -> _NatsClient:
|
||||
nc = await nats.connect(self._nats_url) # pyright: ignore[reportUnknownMemberType]
|
||||
return cast(_NatsClient, nc)
|
||||
|
||||
# Run connection in background loop
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
_connect(),
|
||||
self._loop, # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
self._nc = future.result(timeout=10.0)
|
||||
self._connected = True
|
||||
logger.info(f"Connected to NATS at {self._nats_url}")
|
||||
return True
|
||||
except (RuntimeError, OSError, TimeoutError) as e:
|
||||
logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}")
|
||||
return False
|
||||
|
||||
def publish(self, result: DemoResult) -> None:
|
||||
"""
|
||||
Publish result to NATS subject.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
result : DemoResult
|
||||
Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
|
||||
"""
|
||||
if not self._ensure_connected():
|
||||
# Graceful degradation: log warning but don't crash
|
||||
logger.debug(
|
||||
f"NATS unavailable, dropping result: {result.get('track_id', 'unknown')}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
async def _publish() -> None:
|
||||
if self._nc is not None:
|
||||
payload = json.dumps(
|
||||
result, ensure_ascii=False, default=str
|
||||
).encode("utf-8")
|
||||
_ = await self._nc.publish(self._subject, payload)
|
||||
_ = await self._nc.flush()
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
_publish(),
|
||||
self._loop, # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
|
||||
def _on_done(publish_future: object) -> None:
|
||||
fut = cast("asyncio.Future[None]", publish_future)
|
||||
try:
|
||||
exc = fut.exception()
|
||||
except (RuntimeError, OSError) as callback_error:
|
||||
logger.warning(f"NATS publish callback failed: {callback_error}")
|
||||
self._connected = False
|
||||
return
|
||||
if exc is not None:
|
||||
logger.warning(f"Failed to publish to NATS: {exc}")
|
||||
self._connected = False
|
||||
|
||||
future.add_done_callback(_on_done)
|
||||
except (RuntimeError, OSError, ValueError, TypeError) as e:
|
||||
logger.warning(f"Failed to schedule NATS publish: {e}")
|
||||
self._connected = False
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close NATS connection."""
|
||||
with self._lock:
|
||||
if self._nc is not None and self._connected and self._loop is not None:
|
||||
try:
|
||||
|
||||
async def _close() -> None:
|
||||
if self._nc is not None:
|
||||
_ = await self._nc.close()
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
_close(),
|
||||
self._loop,
|
||||
)
|
||||
future.result(timeout=5.0)
|
||||
except (RuntimeError, OSError, TimeoutError) as e:
|
||||
logger.debug(f"Error closing NATS connection: {e}")
|
||||
finally:
|
||||
self._nc = None
|
||||
self._connected = False
|
||||
|
||||
self._stop_background_loop()
|
||||
|
||||
def __enter__(self) -> NatsPublisher:
|
||||
"""Context manager entry."""
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Context manager exit."""
|
||||
self.close()
|
||||
|
||||
|
||||
def create_publisher(
|
||||
nats_url: str | None,
|
||||
subject: str = "scoliosis.result",
|
||||
) -> ResultPublisher:
|
||||
"""
|
||||
Factory function to create appropriate publisher.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nats_url : str | None
|
||||
NATS server URL. If None or empty, returns ConsolePublisher.
|
||||
subject : str
|
||||
NATS subject to publish to (default: "scoliosis.result")
|
||||
|
||||
Returns
|
||||
-------
|
||||
ResultPublisher
|
||||
NatsPublisher if nats_url provided, otherwise ConsolePublisher
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> # Console output (default)
|
||||
>>> pub = create_publisher(None)
|
||||
>>> pub.publish({"frame": 1, "track_id": 42, "label": "normal", "confidence": 0.95, "window": 30, "timestamp_ns": 1234567890})
|
||||
>>>
|
||||
>>> # NATS output
|
||||
>>> pub = create_publisher("nats://localhost:4222")
|
||||
>>> pub.publish({"frame": 1, "track_id": 42, "label": "normal", "confidence": 0.95, "window": 30, "timestamp_ns": 1234567890})
|
||||
>>>
|
||||
>>> # Context manager usage
|
||||
>>> with create_publisher("nats://localhost:4222") as pub:
|
||||
... pub.publish(result)
|
||||
"""
|
||||
if nats_url:
|
||||
return NatsPublisher(nats_url, subject)
|
||||
return ConsolePublisher()
|
||||
|
||||
|
||||
def create_result(
|
||||
frame: int,
|
||||
track_id: int,
|
||||
label: str,
|
||||
confidence: float,
|
||||
window: int | tuple[int, int],
|
||||
timestamp_ns: int | None = None,
|
||||
) -> DemoResult:
|
||||
"""
|
||||
Create a standardized result dictionary.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
frame : int
|
||||
Frame number
|
||||
track_id : int
|
||||
Track/person identifier
|
||||
label : str
|
||||
Classification label (e.g., "normal", "scoliosis")
|
||||
confidence : float
|
||||
Confidence score (0.0 to 1.0)
|
||||
window : int | tuple[int, int]
|
||||
Frame window as int (end frame) or tuple [start, end] that produced this result
|
||||
Frame window [start, end] that produced this result
|
||||
timestamp_ns : int | None
|
||||
Timestamp in nanoseconds. If None, uses current time.
|
||||
|
||||
Returns
|
||||
-------
|
||||
DemoResult
|
||||
Standardized result dictionary
|
||||
"""
|
||||
return {
|
||||
"frame": frame,
|
||||
"track_id": track_id,
|
||||
"label": label,
|
||||
"confidence": confidence,
|
||||
"window": window if isinstance(window, int) else window[1],
|
||||
"timestamp_ns": timestamp_ns
|
||||
if timestamp_ns is not None
|
||||
else time.monotonic_ns(),
|
||||
}
|
||||
@@ -0,0 +1,947 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from contextlib import suppress
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, cast
|
||||
|
||||
from beartype import beartype
|
||||
import click
|
||||
import jaxtyping
|
||||
from jaxtyping import Float, UInt8
|
||||
import numpy as np
|
||||
from numpy import ndarray
|
||||
from numpy.typing import NDArray
|
||||
from ultralytics.models.yolo.model import YOLO
|
||||
|
||||
from .input import FrameStream, create_source
|
||||
from .output import DemoResult, ResultPublisher, create_publisher, create_result
|
||||
from .preprocess import BBoxXYXY, frame_to_person_mask, mask_to_silhouette
|
||||
from .sconet_demo import ScoNetDemo
|
||||
from .window import SilhouetteWindow, select_person
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .visualizer import OpenCVVisualizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
|
||||
JaxtypedFactory = Callable[..., JaxtypedDecorator]
|
||||
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
|
||||
|
||||
WindowMode: TypeAlias = Literal["manual", "sliding", "chunked"]
|
||||
|
||||
|
||||
def resolve_stride(window: int, stride: int, window_mode: WindowMode) -> int:
|
||||
if window_mode == "manual":
|
||||
return stride
|
||||
if window_mode == "sliding":
|
||||
return 1
|
||||
return window
|
||||
|
||||
|
||||
class _BoxesLike(Protocol):
|
||||
@property
|
||||
def xyxy(self) -> NDArray[np.float32] | object: ...
|
||||
|
||||
@property
|
||||
def id(self) -> NDArray[np.int64] | object | None: ...
|
||||
|
||||
|
||||
class _MasksLike(Protocol):
|
||||
@property
|
||||
def data(self) -> NDArray[np.float32] | object: ...
|
||||
|
||||
|
||||
class _DetectionResultsLike(Protocol):
|
||||
@property
|
||||
def boxes(self) -> _BoxesLike: ...
|
||||
|
||||
@property
|
||||
def masks(self) -> _MasksLike: ...
|
||||
|
||||
|
||||
class _TrackCallable(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
source: object,
|
||||
*,
|
||||
persist: bool = True,
|
||||
verbose: bool = False,
|
||||
device: str | None = None,
|
||||
classes: list[int] | None = None,
|
||||
) -> object: ...
|
||||
|
||||
|
||||
class _SelectedSilhouette(TypedDict):
|
||||
"""Selected silhouette payload produced from detector outputs.
|
||||
|
||||
Fields:
|
||||
silhouette: Normalized silhouette tensor fed into ScoNet `(64, 44)`.
|
||||
mask_raw: Full-resolution binary person mask in mask/image space.
|
||||
bbox_frame: Person bbox in frame coordinates `(x1, y1, x2, y2)` for visualization.
|
||||
bbox_mask: Person bbox in mask coordinates `(x1, y1, x2, y2)` for cropping.
|
||||
track_id: Tracking ID from detector, or `0` for fallback path.
|
||||
"""
|
||||
|
||||
silhouette: Float[ndarray, "64 44"]
|
||||
mask_raw: UInt8[ndarray, "h w"]
|
||||
bbox_frame: BBoxXYXY
|
||||
bbox_mask: BBoxXYXY
|
||||
track_id: int
|
||||
|
||||
|
||||
class _VizPayload(TypedDict, total=False):
|
||||
result: DemoResult
|
||||
mask_raw: UInt8[ndarray, "h w"] | None
|
||||
bbox: BBoxXYXY | None
|
||||
bbox_mask: BBoxXYXY | None
|
||||
silhouette: Float[ndarray, "64 44"] | None
|
||||
segmentation_input: NDArray[np.float32] | None
|
||||
track_id: int
|
||||
label: str | None
|
||||
confidence: float | None
|
||||
pose: dict[str, object] | None
|
||||
|
||||
|
||||
class _FramePacer:
|
||||
_interval_ns: int
|
||||
_next_emit_ns: int | None
|
||||
|
||||
def __init__(self, target_fps: float) -> None:
|
||||
if target_fps <= 0:
|
||||
raise ValueError(f"target_fps must be positive, got {target_fps}")
|
||||
self._interval_ns = int(1_000_000_000 / target_fps)
|
||||
self._next_emit_ns = None
|
||||
|
||||
def should_emit(self, timestamp_ns: int) -> bool:
|
||||
if self._next_emit_ns is None:
|
||||
self._next_emit_ns = timestamp_ns + self._interval_ns
|
||||
return True
|
||||
if timestamp_ns >= self._next_emit_ns:
|
||||
while self._next_emit_ns <= timestamp_ns:
|
||||
self._next_emit_ns += self._interval_ns
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class ScoliosisPipeline:
|
||||
_detector: object
|
||||
_source: FrameStream
|
||||
_window: SilhouetteWindow
|
||||
_publisher: ResultPublisher
|
||||
_classifier: ScoNetDemo
|
||||
_device: str
|
||||
_closed: bool
|
||||
_preprocess_only: bool
|
||||
_silhouette_export_path: Path | None
|
||||
_silhouette_export_format: str
|
||||
_silhouette_buffer: list[dict[str, object]]
|
||||
_silhouette_visualize_dir: Path | None
|
||||
_result_export_path: Path | None
|
||||
_result_export_format: str
|
||||
_result_buffer: list[DemoResult]
|
||||
_visualizer: OpenCVVisualizer | None
|
||||
_last_viz_payload: _VizPayload | None
|
||||
_frame_pacer: _FramePacer | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
source: str,
|
||||
checkpoint: str,
|
||||
config: str,
|
||||
device: str,
|
||||
yolo_model: str,
|
||||
window: int,
|
||||
stride: int,
|
||||
nats_url: str | None,
|
||||
nats_subject: str,
|
||||
max_frames: int | None,
|
||||
preprocess_only: bool = False,
|
||||
silhouette_export_path: str | None = None,
|
||||
silhouette_export_format: str = "pickle",
|
||||
silhouette_visualize_dir: str | None = None,
|
||||
result_export_path: str | None = None,
|
||||
result_export_format: str = "json",
|
||||
visualize: bool = False,
|
||||
target_fps: float | None = 15.0,
|
||||
) -> None:
|
||||
self._detector = YOLO(yolo_model)
|
||||
self._source = create_source(source, max_frames=max_frames)
|
||||
self._window = SilhouetteWindow(window_size=window, stride=stride)
|
||||
self._publisher = create_publisher(nats_url=nats_url, subject=nats_subject)
|
||||
self._classifier = ScoNetDemo(
|
||||
cfg_path=config,
|
||||
checkpoint_path=checkpoint,
|
||||
device=device,
|
||||
)
|
||||
self._device = device
|
||||
self._closed = False
|
||||
self._preprocess_only = preprocess_only
|
||||
self._silhouette_export_path = (
|
||||
Path(silhouette_export_path) if silhouette_export_path else None
|
||||
)
|
||||
self._silhouette_export_format = silhouette_export_format
|
||||
# Normalize format alias: pkl -> pickle
|
||||
if self._silhouette_export_format == "pkl":
|
||||
self._silhouette_export_format = "pickle"
|
||||
self._silhouette_buffer = []
|
||||
self._silhouette_visualize_dir = (
|
||||
Path(silhouette_visualize_dir) if silhouette_visualize_dir else None
|
||||
)
|
||||
self._result_export_path = (
|
||||
Path(result_export_path) if result_export_path else None
|
||||
)
|
||||
self._result_export_format = result_export_format
|
||||
self._result_buffer = []
|
||||
if visualize:
|
||||
from .visualizer import OpenCVVisualizer
|
||||
|
||||
self._visualizer = OpenCVVisualizer()
|
||||
else:
|
||||
self._visualizer = None
|
||||
self._last_viz_payload = None
|
||||
self._frame_pacer = _FramePacer(target_fps) if target_fps is not None else None
|
||||
|
||||
@staticmethod
|
||||
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
|
||||
value = meta.get(key)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
return fallback
|
||||
|
||||
@staticmethod
|
||||
def _extract_timestamp(meta: dict[str, object]) -> int:
|
||||
value = meta.get("timestamp_ns")
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
return time.monotonic_ns()
|
||||
|
||||
@staticmethod
|
||||
def _to_mask_u8(mask: NDArray[np.generic]) -> UInt8[ndarray, "h w"]:
|
||||
mask_arr: NDArray[np.floating] = np.asarray(mask, dtype=np.float32) # type: ignore[reportAssignmentType]
|
||||
binary = np.where(mask_arr > 0.5, np.uint8(255), np.uint8(0)).astype(np.uint8)
|
||||
return cast(UInt8[ndarray, "h w"], binary)
|
||||
|
||||
def _first_result(self, detections: object) -> _DetectionResultsLike | None:
|
||||
if isinstance(detections, list):
|
||||
return cast(_DetectionResultsLike, detections[0]) if detections else None
|
||||
if isinstance(detections, tuple):
|
||||
return cast(_DetectionResultsLike, detections[0]) if detections else None
|
||||
return cast(_DetectionResultsLike, detections)
|
||||
|
||||
def _select_silhouette(
|
||||
self,
|
||||
result: _DetectionResultsLike,
|
||||
) -> _SelectedSilhouette | None:
|
||||
selected = select_person(result)
|
||||
if selected is not None:
|
||||
mask_raw, bbox_mask, bbox_frame, track_id = selected
|
||||
silhouette = cast(
|
||||
Float[ndarray, "64 44"] | None,
|
||||
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask),
|
||||
)
|
||||
if silhouette is not None:
|
||||
return {
|
||||
"silhouette": silhouette,
|
||||
"mask_raw": mask_raw,
|
||||
"bbox_frame": bbox_frame,
|
||||
"bbox_mask": bbox_mask,
|
||||
"track_id": int(track_id),
|
||||
}
|
||||
|
||||
fallback = cast(
|
||||
tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
|
||||
frame_to_person_mask(result),
|
||||
)
|
||||
if fallback is None:
|
||||
return None
|
||||
|
||||
mask_u8, bbox_mask = fallback
|
||||
silhouette = cast(
|
||||
Float[ndarray, "64 44"] | None,
|
||||
mask_to_silhouette(mask_u8, bbox_mask),
|
||||
)
|
||||
if silhouette is None:
|
||||
return None
|
||||
# Convert mask-space bbox to frame-space for visualization
|
||||
# Use result.orig_shape to get frame dimensions safely
|
||||
orig_shape = getattr(result, "orig_shape", None)
|
||||
if (
|
||||
orig_shape is not None
|
||||
and isinstance(orig_shape, (tuple, list))
|
||||
and len(orig_shape) >= 2
|
||||
):
|
||||
frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
|
||||
mask_h, mask_w = mask_u8.shape[0], mask_u8.shape[1]
|
||||
if mask_w > 0 and mask_h > 0 and frame_w > 0 and frame_h > 0:
|
||||
scale_x = frame_w / mask_w
|
||||
scale_y = frame_h / mask_h
|
||||
bbox_frame = (
|
||||
int(bbox_mask[0] * scale_x),
|
||||
int(bbox_mask[1] * scale_y),
|
||||
int(bbox_mask[2] * scale_x),
|
||||
int(bbox_mask[3] * scale_y),
|
||||
)
|
||||
else:
|
||||
# Fallback: use mask-space bbox if dimensions invalid
|
||||
bbox_frame = bbox_mask
|
||||
else:
|
||||
# Fallback: use mask-space bbox if orig_shape unavailable
|
||||
bbox_frame = bbox_mask
|
||||
# For fallback case, mask_raw is the same as mask_u8
|
||||
return {
|
||||
"silhouette": silhouette,
|
||||
"mask_raw": mask_u8,
|
||||
"bbox_frame": bbox_frame,
|
||||
"bbox_mask": bbox_mask,
|
||||
"track_id": 0,
|
||||
}
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def process_frame(
|
||||
self,
|
||||
frame: UInt8[ndarray, "h w c"],
|
||||
metadata: dict[str, object],
|
||||
) -> _VizPayload | None:
|
||||
frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
|
||||
timestamp_ns = self._extract_timestamp(metadata)
|
||||
|
||||
track_fn_obj = getattr(self._detector, "track", None)
|
||||
if not callable(track_fn_obj):
|
||||
raise RuntimeError("YOLO detector does not expose a callable track()")
|
||||
|
||||
track_fn = cast(_TrackCallable, track_fn_obj)
|
||||
detections = track_fn(
|
||||
frame,
|
||||
persist=True,
|
||||
verbose=False,
|
||||
device=self._device,
|
||||
classes=[0],
|
||||
)
|
||||
first = self._first_result(detections)
|
||||
if first is None:
|
||||
return None
|
||||
|
||||
selected = self._select_silhouette(first)
|
||||
if selected is None:
|
||||
return None
|
||||
|
||||
silhouette = selected["silhouette"]
|
||||
mask_raw = selected["mask_raw"]
|
||||
bbox = selected["bbox_frame"]
|
||||
bbox_mask = selected["bbox_mask"]
|
||||
track_id = selected["track_id"]
|
||||
pose_data = None
|
||||
|
||||
if self._silhouette_export_path is not None or self._preprocess_only:
|
||||
self._silhouette_buffer.append(
|
||||
{
|
||||
"frame": frame_idx,
|
||||
"track_id": track_id,
|
||||
"timestamp_ns": timestamp_ns,
|
||||
"silhouette": silhouette.copy(),
|
||||
}
|
||||
)
|
||||
|
||||
# Visualize silhouette if requested
|
||||
if self._silhouette_visualize_dir is not None:
|
||||
self._visualize_silhouette(silhouette, frame_idx, track_id)
|
||||
|
||||
if self._preprocess_only:
|
||||
# Return visualization payload for display even in preprocess-only mode
|
||||
return {
|
||||
"mask_raw": mask_raw,
|
||||
"bbox": bbox,
|
||||
"bbox_mask": bbox_mask,
|
||||
"silhouette": silhouette,
|
||||
"segmentation_input": None,
|
||||
"track_id": track_id,
|
||||
"label": None,
|
||||
"confidence": None,
|
||||
"pose": pose_data,
|
||||
}
|
||||
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
|
||||
|
||||
if self._frame_pacer is not None and not self._frame_pacer.should_emit(
|
||||
timestamp_ns
|
||||
):
|
||||
return {
|
||||
"mask_raw": mask_raw,
|
||||
"bbox": bbox,
|
||||
"bbox_mask": bbox_mask,
|
||||
"silhouette": silhouette,
|
||||
"segmentation_input": self._window.buffered_silhouettes,
|
||||
"track_id": track_id,
|
||||
"label": None,
|
||||
"confidence": None,
|
||||
"pose": pose_data,
|
||||
}
|
||||
segmentation_input = self._window.buffered_silhouettes
|
||||
|
||||
if not self._window.should_classify():
|
||||
# Return visualization payload even when not classifying yet
|
||||
return {
|
||||
"mask_raw": mask_raw,
|
||||
"bbox": bbox,
|
||||
"bbox_mask": bbox_mask,
|
||||
"silhouette": silhouette,
|
||||
"segmentation_input": segmentation_input,
|
||||
"track_id": track_id,
|
||||
"label": None,
|
||||
"confidence": None,
|
||||
"pose": pose_data,
|
||||
}
|
||||
window_tensor = self._window.get_tensor(device=self._device)
|
||||
label, confidence = cast(
|
||||
tuple[str, float],
|
||||
self._classifier.predict(window_tensor),
|
||||
)
|
||||
self._window.mark_classified()
|
||||
|
||||
window_start = self._window.window_start_frame
|
||||
result = create_result(
|
||||
frame=frame_idx,
|
||||
track_id=track_id,
|
||||
label=label,
|
||||
confidence=float(confidence),
|
||||
window=(max(0, window_start), frame_idx),
|
||||
timestamp_ns=timestamp_ns,
|
||||
)
|
||||
|
||||
# Store result for export if export path specified
|
||||
if self._result_export_path is not None:
|
||||
self._result_buffer.append(result)
|
||||
|
||||
self._publisher.publish(result)
|
||||
# Return result with visualization payload
|
||||
return {
|
||||
"result": result,
|
||||
"mask_raw": mask_raw,
|
||||
"bbox": bbox,
|
||||
"bbox_mask": bbox_mask,
|
||||
"silhouette": silhouette,
|
||||
"segmentation_input": segmentation_input,
|
||||
"track_id": track_id,
|
||||
"label": label,
|
||||
"confidence": confidence,
|
||||
"pose": pose_data,
|
||||
}
|
||||
|
||||
def run(self) -> int:
|
||||
frame_count = 0
|
||||
start_time = time.perf_counter()
|
||||
# EMA FPS state (alpha=0.1 for smoothing)
|
||||
ema_fps = 0.0
|
||||
alpha = 0.1
|
||||
prev_time = start_time
|
||||
try:
|
||||
for item in self._source:
|
||||
frame, metadata = item
|
||||
frame_u8 = np.asarray(frame, dtype=np.uint8)
|
||||
frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
|
||||
frame_count += 1
|
||||
|
||||
# Compute per-frame EMA FPS
|
||||
curr_time = time.perf_counter()
|
||||
delta = curr_time - prev_time
|
||||
prev_time = curr_time
|
||||
if delta > 0:
|
||||
instant_fps = 1.0 / delta
|
||||
if ema_fps == 0.0:
|
||||
ema_fps = instant_fps
|
||||
else:
|
||||
ema_fps = alpha * instant_fps + (1 - alpha) * ema_fps
|
||||
|
||||
viz_payload = None
|
||||
try:
|
||||
viz_payload = self.process_frame(frame_u8, metadata)
|
||||
except (RuntimeError, ValueError, TypeError, OSError) as frame_error:
|
||||
logger.warning(
|
||||
"Skipping frame %d due to processing error: %s",
|
||||
frame_idx,
|
||||
frame_error,
|
||||
)
|
||||
|
||||
# Update visualizer if enabled
|
||||
if self._visualizer is not None:
|
||||
# Cache valid payload for no-detection frames
|
||||
if viz_payload is not None:
|
||||
# Cache a copy to prevent mutation of original data
|
||||
viz_payload_dict = cast(_VizPayload, viz_payload)
|
||||
cached: _VizPayload = {}
|
||||
for k, v in viz_payload_dict.items():
|
||||
copy_method = cast(
|
||||
Callable[[], object] | None, getattr(v, "copy", None)
|
||||
)
|
||||
if copy_method is not None and callable(copy_method):
|
||||
cached[k] = copy_method()
|
||||
else:
|
||||
cached[k] = v
|
||||
self._last_viz_payload = cached
|
||||
|
||||
if viz_payload is not None:
|
||||
viz_data = viz_payload
|
||||
elif self._last_viz_payload is not None:
|
||||
viz_data = dict(self._last_viz_payload)
|
||||
viz_data["bbox"] = None
|
||||
viz_data["bbox_mask"] = None
|
||||
viz_data["label"] = None
|
||||
viz_data["confidence"] = None
|
||||
viz_data["pose"] = None
|
||||
else:
|
||||
viz_data = None
|
||||
if viz_data is not None:
|
||||
# Cast viz_payload to dict for type checking
|
||||
viz_dict = cast(_VizPayload, viz_data)
|
||||
mask_raw_obj = viz_dict.get("mask_raw")
|
||||
bbox_obj = viz_dict.get("bbox")
|
||||
bbox_mask_obj = viz_dict.get("bbox_mask")
|
||||
silhouette_obj = viz_dict.get("silhouette")
|
||||
segmentation_input_obj = viz_dict.get("segmentation_input")
|
||||
track_id_val = viz_dict.get("track_id", 0)
|
||||
track_id = track_id_val if isinstance(track_id_val, int) else 0
|
||||
label_obj = viz_dict.get("label")
|
||||
confidence_obj = viz_dict.get("confidence")
|
||||
pose_obj = viz_dict.get("pose")
|
||||
|
||||
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
|
||||
bbox = cast(BBoxXYXY | None, bbox_obj)
|
||||
bbox_mask = cast(BBoxXYXY | None, bbox_mask_obj)
|
||||
silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
|
||||
segmentation_input = cast(
|
||||
NDArray[np.float32] | None,
|
||||
segmentation_input_obj,
|
||||
)
|
||||
label = cast(str | None, label_obj)
|
||||
confidence = cast(float | None, confidence_obj)
|
||||
pose_data = cast(dict[str, object] | None, pose_obj)
|
||||
else:
|
||||
# No detection and no cache - use default values
|
||||
mask_raw = None
|
||||
bbox = None
|
||||
bbox_mask = None
|
||||
track_id = 0
|
||||
silhouette = None
|
||||
segmentation_input = None
|
||||
label = None
|
||||
confidence = None
|
||||
pose_data = None
|
||||
|
||||
# Try keyword arg for pose_data (backward compatible with old signatures)
|
||||
try:
|
||||
keep_running = self._visualizer.update(
|
||||
frame_u8,
|
||||
bbox,
|
||||
bbox_mask,
|
||||
track_id,
|
||||
mask_raw,
|
||||
silhouette,
|
||||
segmentation_input,
|
||||
label,
|
||||
confidence,
|
||||
ema_fps,
|
||||
pose_data=pose_data,
|
||||
)
|
||||
except TypeError:
|
||||
# Fallback for legacy visualizers that don't accept pose_data
|
||||
keep_running = self._visualizer.update(
|
||||
frame_u8,
|
||||
bbox,
|
||||
bbox_mask,
|
||||
track_id,
|
||||
mask_raw,
|
||||
silhouette,
|
||||
segmentation_input,
|
||||
label,
|
||||
confidence,
|
||||
ema_fps,
|
||||
)
|
||||
if not keep_running:
|
||||
logger.info("Visualization closed by user.")
|
||||
break
|
||||
|
||||
if frame_count % 100 == 0:
|
||||
elapsed = time.perf_counter() - start_time
|
||||
fps = frame_count / elapsed if elapsed > 0 else 0.0
|
||||
logger.info("Processed %d frames (%.2f FPS)", frame_count, fps)
|
||||
return 0
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Interrupted by user, shutting down cleanly.")
|
||||
return 130
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
# Close visualizer if enabled
|
||||
if self._visualizer is not None:
|
||||
with suppress(Exception):
|
||||
self._visualizer.close()
|
||||
|
||||
# Export silhouettes if requested
|
||||
if self._silhouette_export_path is not None and self._silhouette_buffer:
|
||||
self._export_silhouettes()
|
||||
|
||||
# Export results if requested
|
||||
if self._result_export_path is not None and self._result_buffer:
|
||||
self._export_results()
|
||||
|
||||
close_fn = getattr(self._publisher, "close", None)
|
||||
if callable(close_fn):
|
||||
with suppress(Exception):
|
||||
_ = close_fn()
|
||||
self._closed = True
|
||||
|
||||
def _export_silhouettes(self) -> None:
|
||||
"""Export silhouettes to file in specified format."""
|
||||
if self._silhouette_export_path is None:
|
||||
return
|
||||
|
||||
self._silhouette_export_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self._silhouette_export_format == "pickle":
|
||||
import pickle
|
||||
|
||||
with open(self._silhouette_export_path, "wb") as f:
|
||||
pickle.dump(self._silhouette_buffer, f)
|
||||
logger.info(
|
||||
"Exported %d silhouettes to %s",
|
||||
len(self._silhouette_buffer),
|
||||
self._silhouette_export_path,
|
||||
)
|
||||
elif self._silhouette_export_format == "parquet":
|
||||
self._export_parquet_silhouettes()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported silhouette export format: {self._silhouette_export_format}"
|
||||
)
|
||||
|
||||
def _visualize_silhouette(
|
||||
self,
|
||||
silhouette: Float[ndarray, "64 44"],
|
||||
frame_idx: int,
|
||||
track_id: int,
|
||||
) -> None:
|
||||
"""Save silhouette as PNG image."""
|
||||
if self._silhouette_visualize_dir is None:
|
||||
return
|
||||
|
||||
self._silhouette_visualize_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Convert float silhouette to uint8 (0-255)
|
||||
silhouette_u8 = (silhouette * 255).astype(np.uint8)
|
||||
|
||||
# Create deterministic filename
|
||||
filename = f"silhouette_frame{frame_idx:06d}_track{track_id:04d}.png"
|
||||
output_path = self._silhouette_visualize_dir / filename
|
||||
|
||||
# Save using PIL
|
||||
from PIL import Image
|
||||
|
||||
Image.fromarray(silhouette_u8).save(output_path)
|
||||
|
||||
def _export_parquet_silhouettes(self) -> None:
|
||||
"""Export silhouettes to parquet format."""
|
||||
import importlib
|
||||
|
||||
try:
|
||||
pa = importlib.import_module("pyarrow")
|
||||
pq = importlib.import_module("pyarrow.parquet")
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"Parquet export requires pyarrow. Install with: pip install pyarrow"
|
||||
) from e
|
||||
|
||||
# Convert silhouettes to columnar format
|
||||
frames = []
|
||||
track_ids = []
|
||||
timestamps = []
|
||||
silhouettes = []
|
||||
|
||||
for item in self._silhouette_buffer:
|
||||
frames.append(item["frame"])
|
||||
track_ids.append(item["track_id"])
|
||||
timestamps.append(item["timestamp_ns"])
|
||||
silhouette_array = cast(NDArray[np.float32], item["silhouette"])
|
||||
silhouettes.append(silhouette_array.flatten().tolist())
|
||||
|
||||
table = pa.table(
|
||||
{
|
||||
"frame": pa.array(frames, type=pa.int64()),
|
||||
"track_id": pa.array(track_ids, type=pa.int64()),
|
||||
"timestamp_ns": pa.array(timestamps, type=pa.int64()),
|
||||
"silhouette": pa.array(silhouettes, type=pa.list_(pa.float64())),
|
||||
}
|
||||
)
|
||||
|
||||
pq.write_table(table, self._silhouette_export_path)
|
||||
logger.info(
|
||||
"Exported %d silhouettes to parquet: %s",
|
||||
len(self._silhouette_buffer),
|
||||
self._silhouette_export_path,
|
||||
)
|
||||
|
||||
def _export_results(self) -> None:
|
||||
"""Export results to file in specified format."""
|
||||
if self._result_export_path is None:
|
||||
return
|
||||
|
||||
self._result_export_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self._result_export_format == "json":
|
||||
import json
|
||||
|
||||
with open(self._result_export_path, "w", encoding="utf-8") as f:
|
||||
for result in self._result_buffer:
|
||||
f.write(json.dumps(result, ensure_ascii=False, default=str) + "\n")
|
||||
logger.info(
|
||||
"Exported %d results to JSON: %s",
|
||||
len(self._result_buffer),
|
||||
self._result_export_path,
|
||||
)
|
||||
elif self._result_export_format == "pickle":
|
||||
import pickle
|
||||
|
||||
with open(self._result_export_path, "wb") as f:
|
||||
pickle.dump(self._result_buffer, f)
|
||||
logger.info(
|
||||
"Exported %d results to pickle: %s",
|
||||
len(self._result_buffer),
|
||||
self._result_export_path,
|
||||
)
|
||||
elif self._result_export_format == "parquet":
|
||||
self._export_parquet_results()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported result export format: {self._result_export_format}"
|
||||
)
|
||||
|
||||
def _export_parquet_results(self) -> None:
|
||||
"""Export results to parquet format."""
|
||||
import importlib
|
||||
|
||||
try:
|
||||
pa = importlib.import_module("pyarrow")
|
||||
pq = importlib.import_module("pyarrow.parquet")
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"Parquet export requires pyarrow. Install with: pip install pyarrow"
|
||||
) from e
|
||||
|
||||
frames = []
|
||||
track_ids = []
|
||||
labels = []
|
||||
confidences = []
|
||||
windows = []
|
||||
timestamps = []
|
||||
|
||||
for result in self._result_buffer:
|
||||
frames.append(result["frame"])
|
||||
track_ids.append(result["track_id"])
|
||||
labels.append(result["label"])
|
||||
confidences.append(result["confidence"])
|
||||
windows.append(result["window"])
|
||||
timestamps.append(result["timestamp_ns"])
|
||||
|
||||
table = pa.table(
|
||||
{
|
||||
"frame": pa.array(frames, type=pa.int64()),
|
||||
"track_id": pa.array(track_ids, type=pa.int64()),
|
||||
"label": pa.array(labels, type=pa.string()),
|
||||
"confidence": pa.array(confidences, type=pa.float64()),
|
||||
"window": pa.array(windows, type=pa.int64()),
|
||||
"timestamp_ns": pa.array(timestamps, type=pa.int64()),
|
||||
}
|
||||
)
|
||||
|
||||
pq.write_table(table, self._result_export_path)
|
||||
logger.info(
|
||||
"Exported %d results to parquet: %s",
|
||||
len(self._result_buffer),
|
||||
self._result_export_path,
|
||||
)
|
||||
|
||||
|
||||
def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
|
||||
if source.startswith("cvmmap://") or source.isdigit():
|
||||
pass
|
||||
else:
|
||||
source_path = Path(source)
|
||||
if not source_path.is_file():
|
||||
raise ValueError(f"Video source not found: {source}")
|
||||
|
||||
checkpoint_path = Path(checkpoint)
|
||||
if not checkpoint_path.is_file():
|
||||
raise ValueError(f"Checkpoint not found: {checkpoint}")
|
||||
|
||||
config_path = Path(config)
|
||||
if not config_path.is_file():
|
||||
raise ValueError(f"Config not found: {config}")
|
||||
|
||||
|
||||
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
|
||||
@click.option("--source", type=str, required=True)
|
||||
@click.option("--checkpoint", type=str, required=True)
|
||||
@click.option(
|
||||
"--config",
|
||||
type=str,
|
||||
default="configs/sconet/sconet_scoliosis1k.yaml",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option("--device", type=str, default="cuda:0", show_default=True)
|
||||
@click.option(
|
||||
"--yolo-model", type=str, default="ckpt/yolo11n-seg.pt", show_default=True
|
||||
)
|
||||
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
|
||||
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
|
||||
@click.option(
|
||||
"--window-mode",
|
||||
type=click.Choice(["manual", "sliding", "chunked"], case_sensitive=False),
|
||||
default="manual",
|
||||
show_default=True,
|
||||
help=(
|
||||
"Window scheduling mode: manual uses --stride; "
|
||||
"sliding forces stride=1; chunked forces stride=window"
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--target-fps",
|
||||
type=click.FloatRange(min=0.1),
|
||||
default=15.0,
|
||||
show_default=True,
|
||||
)
|
||||
@click.option("--no-target-fps", is_flag=True, default=False)
|
||||
@click.option("--nats-url", type=str, default=None)
|
||||
@click.option(
|
||||
"--nats-subject",
|
||||
type=str,
|
||||
default="scoliosis.result",
|
||||
show_default=True,
|
||||
)
|
||||
@click.option("--max-frames", type=click.IntRange(min=1), default=None)
|
||||
@click.option(
|
||||
"--preprocess-only",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Only preprocess silhouettes, skip classification.",
|
||||
)
|
||||
@click.option(
|
||||
"--silhouette-export-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to export silhouettes (required for preprocess-only mode).",
|
||||
)
|
||||
@click.option(
|
||||
"--silhouette-export-format",
|
||||
type=click.Choice(["pickle", "parquet"]),
|
||||
default="pickle",
|
||||
show_default=True,
|
||||
help="Format for silhouette export.",
|
||||
)
|
||||
@click.option(
|
||||
"--result-export-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to export inference results.",
|
||||
)
|
||||
@click.option(
|
||||
"--result-export-format",
|
||||
type=click.Choice(["json", "pickle", "parquet"]),
|
||||
default="json",
|
||||
show_default=True,
|
||||
help="Format for result export.",
|
||||
)
|
||||
@click.option(
|
||||
"--silhouette-visualize-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory to save silhouette PNG visualizations.",
|
||||
)
|
||||
@click.option(
|
||||
"--visualize",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Enable real-time visualization.",
|
||||
)
|
||||
def main(
|
||||
source: str,
|
||||
checkpoint: str,
|
||||
config: str,
|
||||
device: str,
|
||||
yolo_model: str,
|
||||
window: int,
|
||||
stride: int,
|
||||
window_mode: str,
|
||||
target_fps: float,
|
||||
no_target_fps: bool,
|
||||
nats_url: str | None,
|
||||
nats_subject: str,
|
||||
max_frames: int | None,
|
||||
preprocess_only: bool,
|
||||
silhouette_export_path: str | None,
|
||||
silhouette_export_format: str,
|
||||
result_export_path: str | None,
|
||||
result_export_format: str,
|
||||
silhouette_visualize_dir: str | None,
|
||||
visualize: bool,
|
||||
) -> None:
|
||||
# Resolve effective target_fps: respect --no-target_fps to disable pacing
|
||||
effective_target_fps = None if no_target_fps else target_fps
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
|
||||
# Validate preprocess-only mode requirements
|
||||
if preprocess_only and not silhouette_export_path:
|
||||
raise click.UsageError(
|
||||
"--silhouette-export-path is required when using --preprocess-only"
|
||||
)
|
||||
|
||||
try:
|
||||
validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
|
||||
effective_stride = resolve_stride(
|
||||
window=window,
|
||||
stride=stride,
|
||||
window_mode=cast(WindowMode, window_mode.lower()),
|
||||
)
|
||||
if effective_stride != stride:
|
||||
logger.info(
|
||||
"window_mode=%s overrides stride=%d -> effective_stride=%d",
|
||||
window_mode,
|
||||
stride,
|
||||
effective_stride,
|
||||
)
|
||||
pipeline = ScoliosisPipeline(
|
||||
source=source,
|
||||
checkpoint=checkpoint,
|
||||
config=config,
|
||||
device=device,
|
||||
yolo_model=yolo_model,
|
||||
window=window,
|
||||
stride=effective_stride,
|
||||
nats_url=nats_url,
|
||||
nats_subject=nats_subject,
|
||||
max_frames=max_frames,
|
||||
preprocess_only=preprocess_only,
|
||||
silhouette_export_path=silhouette_export_path,
|
||||
silhouette_export_format=silhouette_export_format,
|
||||
silhouette_visualize_dir=silhouette_visualize_dir,
|
||||
result_export_path=result_export_path,
|
||||
result_export_format=result_export_format,
|
||||
visualize=visualize,
|
||||
target_fps=effective_target_fps,
|
||||
)
|
||||
raise SystemExit(pipeline.run())
|
||||
except ValueError as err:
|
||||
click.echo(f"Error: {err}", err=True)
|
||||
raise SystemExit(2) from err
|
||||
except RuntimeError as err:
|
||||
click.echo(f"Runtime error: {err}", err=True)
|
||||
raise SystemExit(1) from err
|
||||
@@ -0,0 +1,335 @@
|
||||
from collections.abc import Callable
|
||||
import math
|
||||
from typing import cast
|
||||
|
||||
import cv2
|
||||
from beartype import beartype
|
||||
import jaxtyping
|
||||
from jaxtyping import Float, UInt8
|
||||
import numpy as np
|
||||
from numpy import ndarray
|
||||
from numpy.typing import NDArray
|
||||
|
||||
SIL_HEIGHT = 64
|
||||
SIL_WIDTH = 44
|
||||
SIL_FULL_WIDTH = 64
|
||||
SIDE_CUT = 10
|
||||
MIN_MASK_AREA = 500
|
||||
|
||||
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
|
||||
JaxtypedFactory = Callable[..., JaxtypedDecorator]
|
||||
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
|
||||
|
||||
UInt8Array = NDArray[np.uint8]
|
||||
Float32Array = NDArray[np.float32]
|
||||
|
||||
BBoxXYXY = tuple[int, int, int, int]
|
||||
"""
|
||||
Bounding box in XYXY format: (x1, y1, x2, y2) where (x1,y1) is top-left and (x2,y2) is bottom-right.
|
||||
"""
|
||||
|
||||
|
||||
def _read_attr(container: object, key: str) -> object | None:
|
||||
if isinstance(container, dict):
|
||||
dict_obj = cast(dict[object, object], container)
|
||||
return dict_obj.get(key)
|
||||
try:
|
||||
return cast(object, object.__getattribute__(container, key))
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
def _to_numpy_array(value: object) -> NDArray[np.generic]:
|
||||
current: object = value
|
||||
if isinstance(current, np.ndarray):
|
||||
return current
|
||||
|
||||
detach_obj = _read_attr(current, "detach")
|
||||
if callable(detach_obj):
|
||||
detach_fn = cast(Callable[[], object], detach_obj)
|
||||
current = detach_fn()
|
||||
|
||||
cpu_obj = _read_attr(current, "cpu")
|
||||
if callable(cpu_obj):
|
||||
cpu_fn = cast(Callable[[], object], cpu_obj)
|
||||
current = cpu_fn()
|
||||
|
||||
numpy_obj = _read_attr(current, "numpy")
|
||||
if callable(numpy_obj):
|
||||
numpy_fn = cast(Callable[[], object], numpy_obj)
|
||||
as_numpy = numpy_fn()
|
||||
if isinstance(as_numpy, np.ndarray):
|
||||
return as_numpy
|
||||
|
||||
return cast(NDArray[np.generic], np.asarray(current))
|
||||
|
||||
|
||||
def _fill_binary_holes(mask_u8: UInt8Array) -> UInt8Array:
|
||||
mask_bin = np.where(mask_u8 > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
|
||||
h, w = cast(tuple[int, int], mask_bin.shape)
|
||||
if h <= 2 or w <= 2:
|
||||
return mask_bin
|
||||
|
||||
seed_candidates = [(0, 0), (w - 1, 0), (0, h - 1), (w - 1, h - 1)]
|
||||
seed: tuple[int, int] | None = None
|
||||
for x, y in seed_candidates:
|
||||
if int(mask_bin[y, x]) == 0:
|
||||
seed = (x, y)
|
||||
break
|
||||
if seed is None:
|
||||
return mask_bin
|
||||
|
||||
flood = mask_bin.copy()
|
||||
flood_mask = np.zeros((h + 2, w + 2), dtype=np.uint8)
|
||||
_ = cv2.floodFill(flood, flood_mask, seed, 255)
|
||||
holes = cv2.bitwise_not(flood)
|
||||
filled = cv2.bitwise_or(mask_bin, holes)
|
||||
return cast(UInt8Array, filled)
|
||||
|
||||
|
||||
def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> BBoxXYXY | None:
|
||||
"""Extract bounding box from binary mask in XYXY format.
|
||||
|
||||
Args:
|
||||
mask: Binary mask array of shape (H, W) with dtype uint8.
|
||||
|
||||
Returns:
|
||||
Bounding box as (x1, y1, x2, y2) in XYXY format, or None if mask is empty.
|
||||
"""
|
||||
mask_u8 = np.asarray(mask, dtype=np.uint8)
|
||||
coords = np.argwhere(mask_u8 > 0)
|
||||
if int(coords.size) == 0:
|
||||
return None
|
||||
|
||||
ys = coords[:, 0].astype(np.int64)
|
||||
xs = coords[:, 1].astype(np.int64)
|
||||
x1 = int(np.min(xs))
|
||||
x2 = int(np.max(xs)) + 1
|
||||
y1 = int(np.min(ys))
|
||||
y2 = int(np.max(ys)) + 1
|
||||
if x2 <= x1 or y2 <= y1:
|
||||
return None
|
||||
return (x1, y1, x2, y2)
|
||||
|
||||
|
||||
def _sanitize_bbox(bbox: BBoxXYXY, height: int, width: int) -> BBoxXYXY | None:
|
||||
"""Sanitize bounding box to ensure it's within image bounds.
|
||||
|
||||
Args:
|
||||
bbox: Bounding box in XYXY format (x1, y1, x2, y2).
|
||||
height: Image height.
|
||||
width: Image width.
|
||||
|
||||
Returns:
|
||||
Sanitized bounding box in XYXY format, or None if invalid.
|
||||
"""
|
||||
x1, y1, x2, y2 = bbox
|
||||
x1c = max(0, min(int(x1), width - 1))
|
||||
y1c = max(0, min(int(y1), height - 1))
|
||||
x2c = max(0, min(int(x2), width))
|
||||
y2c = max(0, min(int(y2), height))
|
||||
if x2c <= x1c or y2c <= y1c:
|
||||
return None
|
||||
return (x1c, y1c, x2c, y2c)
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def frame_to_person_mask(
|
||||
result: object, min_area: int = MIN_MASK_AREA
|
||||
) -> tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None:
|
||||
"""Extract person mask and bounding box from detection result.
|
||||
|
||||
Args:
|
||||
result: Detection results object with boxes and masks attributes.
|
||||
min_area: Minimum mask area to consider valid.
|
||||
|
||||
Returns:
|
||||
Tuple of (mask, bbox) where bbox is in XYXY format (x1, y1, x2, y2),
|
||||
or None if no valid detections.
|
||||
"""
|
||||
masks_obj = _read_attr(result, "masks")
|
||||
if masks_obj is None:
|
||||
return None
|
||||
|
||||
masks_data_obj = _read_attr(masks_obj, "data")
|
||||
if masks_data_obj is None:
|
||||
return None
|
||||
|
||||
masks_raw = _to_numpy_array(masks_data_obj)
|
||||
masks_float = np.asarray(masks_raw, dtype=np.float32)
|
||||
if masks_float.ndim == 2:
|
||||
masks_float = masks_float[np.newaxis, ...]
|
||||
if masks_float.ndim != 3:
|
||||
return None
|
||||
mask_count = int(cast(tuple[int, int, int], masks_float.shape)[0])
|
||||
if mask_count <= 0:
|
||||
return None
|
||||
|
||||
box_values: list[tuple[float, float, float, float]] | None = None
|
||||
boxes_obj = _read_attr(result, "boxes")
|
||||
if boxes_obj is not None:
|
||||
xyxy_obj = _read_attr(boxes_obj, "xyxy")
|
||||
if xyxy_obj is not None:
|
||||
xyxy_raw = np.asarray(_to_numpy_array(xyxy_obj), dtype=np.float32)
|
||||
if xyxy_raw.ndim == 1 and int(xyxy_raw.size) >= 4:
|
||||
xyxy_2d = np.asarray(xyxy_raw[:4].reshape(1, 4), dtype=np.float64)
|
||||
x1f = cast(np.float64, xyxy_2d[0, 0])
|
||||
y1f = cast(np.float64, xyxy_2d[0, 1])
|
||||
x2f = cast(np.float64, xyxy_2d[0, 2])
|
||||
y2f = cast(np.float64, xyxy_2d[0, 3])
|
||||
box_values = [
|
||||
(
|
||||
float(x1f),
|
||||
float(y1f),
|
||||
float(x2f),
|
||||
float(y2f),
|
||||
)
|
||||
]
|
||||
elif xyxy_raw.ndim == 2:
|
||||
shape_2d = cast(tuple[int, int], xyxy_raw.shape)
|
||||
if int(shape_2d[1]) >= 4:
|
||||
xyxy_2d = np.asarray(xyxy_raw[:, :4], dtype=np.float64)
|
||||
box_values = []
|
||||
for row_idx in range(int(cast(tuple[int, int], xyxy_2d.shape)[0])):
|
||||
x1f = cast(np.float64, xyxy_2d[row_idx, 0])
|
||||
y1f = cast(np.float64, xyxy_2d[row_idx, 1])
|
||||
x2f = cast(np.float64, xyxy_2d[row_idx, 2])
|
||||
y2f = cast(np.float64, xyxy_2d[row_idx, 3])
|
||||
box_values.append(
|
||||
(
|
||||
float(x1f),
|
||||
float(y1f),
|
||||
float(x2f),
|
||||
float(y2f),
|
||||
)
|
||||
)
|
||||
|
||||
best_area = -1
|
||||
best_mask: UInt8[ndarray, "h w"] | None = None
|
||||
best_bbox: BBoxXYXY | None = None
|
||||
|
||||
for idx in range(mask_count):
|
||||
mask_float = np.asarray(masks_float[idx], dtype=np.float32)
|
||||
if mask_float.ndim != 2:
|
||||
continue
|
||||
mask_binary = np.where(mask_float > 0.5, np.uint8(255), np.uint8(0)).astype(
|
||||
np.uint8
|
||||
)
|
||||
mask_u8 = cast(UInt8[ndarray, "h w"], mask_binary)
|
||||
|
||||
area = int(np.count_nonzero(mask_u8))
|
||||
if area < min_area:
|
||||
continue
|
||||
|
||||
bbox: BBoxXYXY | None = None
|
||||
shape_2d = cast(tuple[int, int], mask_binary.shape)
|
||||
h = int(shape_2d[0])
|
||||
w = int(shape_2d[1])
|
||||
if box_values is not None:
|
||||
box_count = len(box_values)
|
||||
if idx >= box_count:
|
||||
continue
|
||||
row0, row1, row2, row3 = box_values[idx]
|
||||
bbox_candidate = (
|
||||
int(math.floor(row0)),
|
||||
int(math.floor(row1)),
|
||||
int(math.ceil(row2)),
|
||||
int(math.ceil(row3)),
|
||||
)
|
||||
bbox = _sanitize_bbox(bbox_candidate, h, w)
|
||||
|
||||
if bbox is None:
|
||||
bbox = _bbox_from_mask(mask_u8)
|
||||
|
||||
if bbox is None:
|
||||
continue
|
||||
|
||||
if area > best_area:
|
||||
best_area = area
|
||||
best_mask = mask_u8
|
||||
best_bbox = bbox
|
||||
|
||||
if best_mask is None or best_bbox is None:
|
||||
return None
|
||||
|
||||
return best_mask, best_bbox
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def mask_to_silhouette(
|
||||
mask: UInt8[ndarray, "h w"],
|
||||
bbox: BBoxXYXY,
|
||||
) -> Float[ndarray, "64 44"] | None:
|
||||
"""Convert mask to standardized silhouette using bounding box.
|
||||
|
||||
Args:
|
||||
mask: Binary mask array of shape (H, W) with dtype uint8.
|
||||
bbox: Bounding box in XYXY format (x1, y1, x2, y2).
|
||||
|
||||
Returns:
|
||||
Standardized silhouette array of shape (64, 44) with dtype float32,
|
||||
or None if conversion fails.
|
||||
"""
|
||||
mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
|
||||
mask_u8 = _fill_binary_holes(mask_u8)
|
||||
if int(np.count_nonzero(mask_u8)) < MIN_MASK_AREA:
|
||||
return None
|
||||
|
||||
mask_shape = cast(tuple[int, int], mask_u8.shape)
|
||||
h = int(mask_shape[0])
|
||||
w = int(mask_shape[1])
|
||||
bbox_sanitized = _sanitize_bbox(bbox, h, w)
|
||||
if bbox_sanitized is None:
|
||||
return None
|
||||
|
||||
x1, y1, x2, y2 = bbox_sanitized
|
||||
cropped = mask_u8[y1:y2, x1:x2]
|
||||
if cropped.size == 0:
|
||||
return None
|
||||
|
||||
cropped_u8 = np.asarray(cropped, dtype=np.uint8)
|
||||
row_sums = np.sum(cropped_u8, axis=1, dtype=np.int64)
|
||||
row_nonzero = np.nonzero(row_sums > 0)[0].astype(np.int64)
|
||||
if int(row_nonzero.size) == 0:
|
||||
return None
|
||||
top = int(cast(np.int64, row_nonzero[0]))
|
||||
bottom = int(cast(np.int64, row_nonzero[-1])) + 1
|
||||
tightened = cropped[top:bottom, :]
|
||||
if tightened.size == 0:
|
||||
return None
|
||||
|
||||
tight_shape = cast(tuple[int, int], tightened.shape)
|
||||
tight_h = int(tight_shape[0])
|
||||
tight_w = int(tight_shape[1])
|
||||
if tight_h <= 0 or tight_w <= 0:
|
||||
return None
|
||||
|
||||
resized_w = max(1, int(SIL_HEIGHT * (tight_w / tight_h)))
|
||||
resized = np.asarray(
|
||||
cv2.resize(tightened, (resized_w, SIL_HEIGHT), interpolation=cv2.INTER_CUBIC),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
|
||||
if resized_w >= SIL_FULL_WIDTH:
|
||||
start = (resized_w - SIL_FULL_WIDTH) // 2
|
||||
normalized_64 = resized[:, start : start + SIL_FULL_WIDTH]
|
||||
else:
|
||||
pad_left = (SIL_FULL_WIDTH - resized_w) // 2
|
||||
pad_right = SIL_FULL_WIDTH - resized_w - pad_left
|
||||
normalized_64 = np.pad(
|
||||
resized,
|
||||
((0, 0), (pad_left, pad_right)),
|
||||
mode="constant",
|
||||
constant_values=0,
|
||||
)
|
||||
|
||||
silhouette = np.asarray(
|
||||
normalized_64[:, SIDE_CUT : SIL_FULL_WIDTH - SIDE_CUT], dtype=np.float32
|
||||
)
|
||||
if silhouette.shape != (SIL_HEIGHT, SIL_WIDTH):
|
||||
return None
|
||||
|
||||
silhouette_norm = np.clip(silhouette / np.float32(255.0), 0.0, 1.0).astype(
|
||||
np.float32
|
||||
)
|
||||
return cast(Float[ndarray, "64 44"], silhouette_norm)
|
||||
@@ -0,0 +1,312 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Protocol, cast, override
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from beartype import beartype
|
||||
from einops import rearrange
|
||||
from jaxtyping import Float
|
||||
import jaxtyping
|
||||
from torch import Tensor
|
||||
|
||||
from opengait.modeling.backbones.resnet import ResNet9
|
||||
from opengait.modeling.modules import (
|
||||
HorizontalPoolingPyramid,
|
||||
PackSequenceWrapper as TemporalPool,
|
||||
SeparateBNNecks,
|
||||
SeparateFCs,
|
||||
)
|
||||
from opengait.utils import common as common_utils
|
||||
|
||||
|
||||
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
|
||||
JaxtypedFactory = Callable[..., JaxtypedDecorator]
|
||||
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
|
||||
ConfigLoader = Callable[[str], dict[str, object]]
|
||||
config_loader = cast(ConfigLoader, common_utils.config_loader)
|
||||
|
||||
|
||||
class TemporalPoolLike(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
seqs: Tensor,
|
||||
seqL: object,
|
||||
dim: int = 2,
|
||||
options: dict[str, int] | None = None,
|
||||
) -> object: ...
|
||||
|
||||
|
||||
class HppLike(Protocol):
|
||||
def __call__(self, x: Tensor) -> Tensor: ...
|
||||
|
||||
|
||||
class FCsLike(Protocol):
|
||||
def __call__(self, x: Tensor) -> Tensor: ...
|
||||
|
||||
|
||||
class BNNecksLike(Protocol):
|
||||
def __call__(self, x: Tensor) -> tuple[Tensor, Tensor]: ...
|
||||
|
||||
|
||||
class ScoNetDemo(nn.Module):
|
||||
LABEL_MAP: ClassVar[dict[int, str]] = {0: "negative", 1: "neutral", 2: "positive"}
|
||||
cfg_path: str
|
||||
cfg: dict[str, object]
|
||||
backbone: ResNet9
|
||||
temporal_pool: TemporalPoolLike
|
||||
hpp: HppLike
|
||||
fcs: FCsLike
|
||||
bn_necks: BNNecksLike
|
||||
device: torch.device
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def __init__(
|
||||
self,
|
||||
cfg_path: str | Path = "configs/sconet/sconet_scoliosis1k.yaml",
|
||||
checkpoint_path: str | Path | None = None,
|
||||
device: str | torch.device | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
resolved_cfg = self._resolve_path(cfg_path)
|
||||
self.cfg_path = str(resolved_cfg)
|
||||
|
||||
self.cfg = config_loader(self.cfg_path)
|
||||
|
||||
model_cfg = self._extract_model_cfg(self.cfg)
|
||||
backbone_cfg = self._extract_dict(model_cfg, "backbone_cfg")
|
||||
|
||||
if backbone_cfg.get("type") != "ResNet9":
|
||||
raise ValueError(
|
||||
"ScoNetDemo currently supports backbone type ResNet9 only."
|
||||
)
|
||||
|
||||
self.backbone = ResNet9(
|
||||
block=self._extract_str(backbone_cfg, "block"),
|
||||
channels=self._extract_int_list(backbone_cfg, "channels"),
|
||||
in_channel=self._extract_int(backbone_cfg, "in_channel", default=1),
|
||||
layers=self._extract_int_list(backbone_cfg, "layers"),
|
||||
strides=self._extract_int_list(backbone_cfg, "strides"),
|
||||
maxpool=self._extract_bool(backbone_cfg, "maxpool", default=True),
|
||||
)
|
||||
|
||||
fcs_cfg = self._extract_dict(model_cfg, "SeparateFCs")
|
||||
bn_cfg = self._extract_dict(model_cfg, "SeparateBNNecks")
|
||||
bin_num = self._extract_int_list(model_cfg, "bin_num")
|
||||
|
||||
self.temporal_pool = cast(TemporalPoolLike, TemporalPool(torch.max))
|
||||
self.hpp = cast(HppLike, HorizontalPoolingPyramid(bin_num=bin_num))
|
||||
self.fcs = cast(
|
||||
FCsLike,
|
||||
SeparateFCs(
|
||||
parts_num=self._extract_int(fcs_cfg, "parts_num"),
|
||||
in_channels=self._extract_int(fcs_cfg, "in_channels"),
|
||||
out_channels=self._extract_int(fcs_cfg, "out_channels"),
|
||||
norm=self._extract_bool(fcs_cfg, "norm", default=False),
|
||||
),
|
||||
)
|
||||
self.bn_necks = cast(
|
||||
BNNecksLike,
|
||||
SeparateBNNecks(
|
||||
parts_num=self._extract_int(bn_cfg, "parts_num"),
|
||||
in_channels=self._extract_int(bn_cfg, "in_channels"),
|
||||
class_num=self._extract_int(bn_cfg, "class_num"),
|
||||
norm=self._extract_bool(bn_cfg, "norm", default=True),
|
||||
parallel_BN1d=self._extract_bool(bn_cfg, "parallel_BN1d", default=True),
|
||||
),
|
||||
)
|
||||
|
||||
self.device = (
|
||||
torch.device(device) if device is not None else torch.device("cpu")
|
||||
)
|
||||
_ = self.to(self.device)
|
||||
|
||||
if checkpoint_path is not None:
|
||||
_ = self.load_checkpoint(checkpoint_path)
|
||||
|
||||
_ = self.eval()
|
||||
|
||||
@staticmethod
|
||||
def _resolve_path(path: str | Path) -> Path:
|
||||
candidate = Path(path)
|
||||
if candidate.is_file():
|
||||
return candidate
|
||||
if candidate.is_absolute():
|
||||
return candidate
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
return repo_root / candidate
|
||||
|
||||
@staticmethod
|
||||
def _extract_model_cfg(cfg: dict[str, object]) -> dict[str, object]:
|
||||
model_cfg_obj = cfg.get("model_cfg")
|
||||
if not isinstance(model_cfg_obj, dict):
|
||||
raise TypeError("model_cfg must be a dictionary.")
|
||||
return cast(dict[str, object], model_cfg_obj)
|
||||
|
||||
@staticmethod
|
||||
def _extract_dict(container: dict[str, object], key: str) -> dict[str, object]:
|
||||
value = container.get(key)
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(f"{key} must be a dictionary.")
|
||||
return cast(dict[str, object], value)
|
||||
|
||||
@staticmethod
|
||||
def _extract_str(container: dict[str, object], key: str) -> str:
|
||||
value = container.get(key)
|
||||
if not isinstance(value, str):
|
||||
raise TypeError(f"{key} must be a string.")
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _extract_int(
|
||||
container: dict[str, object], key: str, default: int | None = None
|
||||
) -> int:
|
||||
value = container.get(key, default)
|
||||
if not isinstance(value, int):
|
||||
raise TypeError(f"{key} must be an int.")
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _extract_bool(
|
||||
container: dict[str, object], key: str, default: bool | None = None
|
||||
) -> bool:
|
||||
value = container.get(key, default)
|
||||
if not isinstance(value, bool):
|
||||
raise TypeError(f"{key} must be a bool.")
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _extract_int_list(container: dict[str, object], key: str) -> list[int]:
|
||||
value = container.get(key)
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"{key} must be a list[int].")
|
||||
values = cast(list[object], value)
|
||||
if not all(isinstance(v, int) for v in values):
|
||||
raise TypeError(f"{key} must be a list[int].")
|
||||
return cast(list[int], values)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_state_dict(
|
||||
state_dict_obj: dict[object, object],
|
||||
) -> dict[str, Tensor]:
|
||||
prefix_remap: tuple[tuple[str, str], ...] = (
|
||||
("Backbone.forward_block.", "backbone."),
|
||||
("FCs.", "fcs."),
|
||||
("BNNecks.", "bn_necks."),
|
||||
)
|
||||
cleaned_state_dict: dict[str, Tensor] = {}
|
||||
for key_obj, value_obj in state_dict_obj.items():
|
||||
if not isinstance(key_obj, str):
|
||||
raise TypeError("Checkpoint state_dict keys must be strings.")
|
||||
if not isinstance(value_obj, Tensor):
|
||||
raise TypeError("Checkpoint state_dict values must be torch.Tensor.")
|
||||
key = key_obj[7:] if key_obj.startswith("module.") else key_obj
|
||||
for source_prefix, target_prefix in prefix_remap:
|
||||
if key.startswith(source_prefix):
|
||||
key = f"{target_prefix}{key[len(source_prefix) :]}"
|
||||
break
|
||||
if key in cleaned_state_dict:
|
||||
raise RuntimeError(
|
||||
f"Checkpoint key normalization collision detected for key '{key}'."
|
||||
)
|
||||
cleaned_state_dict[key] = value_obj
|
||||
return cleaned_state_dict
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def load_checkpoint(
|
||||
self,
|
||||
checkpoint_path: str | Path,
|
||||
map_location: str | torch.device | None = None,
|
||||
strict: bool = True,
|
||||
) -> None:
|
||||
resolved_ckpt = self._resolve_path(checkpoint_path)
|
||||
checkpoint_obj = cast(
|
||||
object,
|
||||
torch.load(
|
||||
str(resolved_ckpt),
|
||||
map_location=map_location if map_location is not None else self.device,
|
||||
),
|
||||
)
|
||||
|
||||
state_dict_obj: object = checkpoint_obj
|
||||
if isinstance(checkpoint_obj, dict) and "model" in checkpoint_obj:
|
||||
state_dict_obj = cast(dict[str, object], checkpoint_obj)["model"]
|
||||
|
||||
if not isinstance(state_dict_obj, dict):
|
||||
raise TypeError("Unsupported checkpoint format.")
|
||||
|
||||
cleaned_state_dict = self._normalize_state_dict(
|
||||
cast(dict[object, object], state_dict_obj)
|
||||
)
|
||||
try:
|
||||
_ = self.load_state_dict(cleaned_state_dict, strict=strict)
|
||||
except RuntimeError as exc:
|
||||
raise RuntimeError(
|
||||
f"Failed to load ScoNetDemo checkpoint after key normalization from '{resolved_ckpt}'."
|
||||
) from exc
|
||||
_ = self.eval()
|
||||
|
||||
def _prepare_sils(self, sils: Tensor) -> Tensor:
|
||||
if sils.ndim == 4:
|
||||
sils = sils.unsqueeze(1)
|
||||
elif sils.ndim == 5 and sils.shape[1] != 1 and sils.shape[2] == 1:
|
||||
sils = rearrange(sils, "b s c h w -> b c s h w")
|
||||
|
||||
if sils.ndim != 5 or sils.shape[1] != 1:
|
||||
raise ValueError("Expected sils shape [B, 1, S, H, W] or [B, S, H, W].")
|
||||
|
||||
return sils.float().to(self.device)
|
||||
|
||||
def _forward_backbone(self, sils: Tensor) -> Tensor:
|
||||
batch, channels, seq, height, width = sils.shape
|
||||
framewise = sils.transpose(1, 2).reshape(batch * seq, channels, height, width)
|
||||
frame_feats = cast(Tensor, self.backbone(framewise))
|
||||
_, out_channels, out_h, out_w = frame_feats.shape
|
||||
return (
|
||||
frame_feats.reshape(batch, seq, out_channels, out_h, out_w)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
@override
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def forward(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> dict[str, Tensor]:
|
||||
with torch.inference_mode():
|
||||
prepared_sils = self._prepare_sils(sils)
|
||||
outs = self._forward_backbone(prepared_sils)
|
||||
|
||||
pooled_obj = self.temporal_pool(outs, None, options={"dim": 2})
|
||||
if (
|
||||
not isinstance(pooled_obj, tuple)
|
||||
or not pooled_obj
|
||||
or not isinstance(pooled_obj[0], Tensor)
|
||||
):
|
||||
raise TypeError("TemporalPool output is invalid.")
|
||||
pooled = pooled_obj[0]
|
||||
|
||||
feat = self.hpp(pooled)
|
||||
embed_1 = self.fcs(feat)
|
||||
_, logits = self.bn_necks(embed_1)
|
||||
|
||||
mean_logits = logits.mean(dim=-1)
|
||||
pred_ids = torch.argmax(mean_logits, dim=-1)
|
||||
probs = torch.softmax(mean_logits, dim=-1)
|
||||
confidence = torch.gather(
|
||||
probs, dim=-1, index=pred_ids.unsqueeze(-1)
|
||||
).squeeze(-1)
|
||||
|
||||
return {"logits": logits, "label": pred_ids, "confidence": confidence}
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
def predict(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> tuple[str, float]:
|
||||
outputs = cast(dict[str, Tensor], self.forward(sils))
|
||||
labels = outputs["label"]
|
||||
confidence = outputs["confidence"]
|
||||
|
||||
if labels.numel() != 1:
|
||||
raise ValueError("predict expects batch size 1.")
|
||||
|
||||
label_id = int(labels.item())
|
||||
return self.LABEL_MAP[label_id], float(confidence.item())
|
||||
@@ -0,0 +1,767 @@
|
||||
"""OpenCV-based visualizer for demo pipeline.
|
||||
|
||||
Provides real-time visualization of detection, segmentation, and classification results
|
||||
with interactive mode switching for mask display.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from .preprocess import BBoxXYXY
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Window names
|
||||
MAIN_WINDOW = "Scoliosis Detection"
|
||||
SEG_WINDOW = "Normalized Silhouette"
|
||||
RAW_WINDOW = "Raw Mask"
|
||||
WINDOW_SEG_INPUT = "Segmentation Input"
|
||||
|
||||
# Silhouette dimensions (from preprocess.py)
|
||||
SIL_HEIGHT = 64
|
||||
SIL_WIDTH = 44
|
||||
|
||||
# Display dimensions for upscaled silhouette
|
||||
DISPLAY_HEIGHT = 256
|
||||
DISPLAY_WIDTH = 176
|
||||
RAW_STATS_PAD = 54
|
||||
MODE_LABEL_PAD = 26
|
||||
|
||||
# Colors (BGR)
|
||||
COLOR_GREEN = (0, 255, 0)
|
||||
COLOR_WHITE = (255, 255, 255)
|
||||
COLOR_BLACK = (0, 0, 0)
|
||||
COLOR_DARK_GRAY = (56, 56, 56)
|
||||
COLOR_RED = (0, 0, 255)
|
||||
COLOR_YELLOW = (0, 255, 255)
|
||||
# Type alias for image arrays (NDArray or cv2.Mat)
|
||||
COLOR_CYAN = (255, 255, 0)
|
||||
COLOR_ORANGE = (0, 165, 255)
|
||||
COLOR_MAGENTA = (255, 0, 255)
|
||||
ImageArray = NDArray[np.uint8]
|
||||
|
||||
# COCO-format skeleton connections (17 keypoints)
|
||||
# Connections are pairs of keypoint indices
|
||||
SKELETON_CONNECTIONS: list[tuple[int, int]] = [
|
||||
(0, 1), # nose -> left_eye
|
||||
(0, 2), # nose -> right_eye
|
||||
(1, 3), # left_eye -> left_ear
|
||||
(2, 4), # right_eye -> right_ear
|
||||
(5, 6), # left_shoulder -> right_shoulder
|
||||
(5, 7), # left_shoulder -> left_elbow
|
||||
(7, 9), # left_elbow -> left_wrist
|
||||
(6, 8), # right_shoulder -> right_elbow
|
||||
(8, 10), # right_elbow -> right_wrist
|
||||
(11, 12), # left_hip -> right_hip
|
||||
(5, 11), # left_shoulder -> left_hip
|
||||
(6, 12), # right_shoulder -> right_hip
|
||||
(11, 13), # left_hip -> left_knee
|
||||
(13, 15), # left_knee -> left_ankle
|
||||
(12, 14), # right_hip -> right_knee
|
||||
(14, 16), # right_knee -> right_ankle
|
||||
]
|
||||
|
||||
# Keypoint names for COCO format (17 keypoints)
|
||||
KEYPOINT_NAMES: list[str] = [
|
||||
"nose", "left_eye", "right_eye", "left_ear", "right_ear",
|
||||
"left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
|
||||
"left_wrist", "right_wrist", "left_hip", "right_hip",
|
||||
"left_knee", "right_knee", "left_ankle", "right_ankle"
|
||||
]
|
||||
|
||||
# Joints where angles are typically calculated (for scoliosis/ gait analysis)
|
||||
ANGLE_JOINTS: list[tuple[int, int, int]] = [
|
||||
(5, 7, 9), # left_shoulder -> left_elbow -> left_wrist
|
||||
(6, 8, 10), # right_shoulder -> right_elbow -> right_wrist
|
||||
(7, 5, 11), # left_elbow -> left_shoulder -> left_hip
|
||||
(8, 6, 12), # right_elbow -> right_shoulder -> right_hip
|
||||
(5, 11, 13), # left_shoulder -> left_hip -> left_knee
|
||||
(6, 12, 14), # right_shoulder -> right_hip -> right_knee
|
||||
(11, 13, 15),# left_hip -> left_knee -> left_ankle
|
||||
(12, 14, 16),# right_hip -> right_knee -> right_ankle
|
||||
]
|
||||
|
||||
|
||||
|
||||
class OpenCVVisualizer:
|
||||
def __init__(self) -> None:
|
||||
self.show_raw_window: bool = False
|
||||
self.show_raw_debug: bool = False
|
||||
self._windows_created: bool = False
|
||||
self._raw_window_created: bool = False
|
||||
|
||||
def _ensure_windows(self) -> None:
|
||||
if not self._windows_created:
|
||||
cv2.namedWindow(MAIN_WINDOW, cv2.WINDOW_NORMAL)
|
||||
cv2.namedWindow(SEG_WINDOW, cv2.WINDOW_NORMAL)
|
||||
cv2.namedWindow(WINDOW_SEG_INPUT, cv2.WINDOW_NORMAL)
|
||||
self._windows_created = True
|
||||
|
||||
def _ensure_raw_window(self) -> None:
|
||||
if not self._raw_window_created:
|
||||
cv2.namedWindow(RAW_WINDOW, cv2.WINDOW_NORMAL)
|
||||
self._raw_window_created = True
|
||||
|
||||
def _hide_raw_window(self) -> None:
|
||||
if self._raw_window_created:
|
||||
cv2.destroyWindow(RAW_WINDOW)
|
||||
self._raw_window_created = False
|
||||
|
||||
def _draw_bbox(
|
||||
self,
|
||||
frame: ImageArray,
|
||||
bbox: BBoxXYXY | None,
|
||||
) -> None:
|
||||
"""Draw bounding box on frame if present.
|
||||
|
||||
Args:
|
||||
frame: Input frame (H, W, 3) uint8 - modified in place
|
||||
bbox: Bounding box in XYXY format as (x1, y1, x2, y2) or None
|
||||
"""
|
||||
if bbox is None:
|
||||
return
|
||||
|
||||
x1, y1, x2, y2 = bbox
|
||||
# Draw rectangle with green color, thickness 2
|
||||
_ = cv2.rectangle(frame, (x1, y1), (x2, y2), COLOR_GREEN, 2)
|
||||
|
||||
def _draw_text_overlay(
|
||||
self,
|
||||
frame: ImageArray,
|
||||
track_id: int,
|
||||
fps: float,
|
||||
label: str | None,
|
||||
confidence: float | None,
|
||||
) -> None:
|
||||
"""Draw text overlay with track info, FPS, label, and confidence.
|
||||
|
||||
Args:
|
||||
frame: Input frame (H, W, 3) uint8 - modified in place
|
||||
track_id: Tracking ID
|
||||
fps: Current FPS
|
||||
label: Classification label or None
|
||||
confidence: Classification confidence or None
|
||||
"""
|
||||
# Prepare text lines
|
||||
lines: list[str] = []
|
||||
lines.append(f"ID: {track_id}")
|
||||
lines.append(f"FPS: {fps:.1f}")
|
||||
|
||||
if label is not None:
|
||||
if confidence is not None:
|
||||
lines.append(f"{label}: {confidence:.2%}")
|
||||
else:
|
||||
lines.append(label)
|
||||
|
||||
# Draw text with background for readability
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
font_scale = 0.6
|
||||
thickness = 1
|
||||
line_height = 25
|
||||
margin = 10
|
||||
|
||||
for i, text in enumerate(lines):
|
||||
y_pos = margin + (i + 1) * line_height
|
||||
|
||||
# Draw background rectangle
|
||||
(text_width, text_height), _ = cv2.getTextSize(
|
||||
text, font, font_scale, thickness
|
||||
)
|
||||
_ = cv2.rectangle(
|
||||
frame,
|
||||
(margin, y_pos - text_height - 5),
|
||||
(margin + text_width + 10, y_pos + 5),
|
||||
COLOR_BLACK,
|
||||
-1,
|
||||
)
|
||||
|
||||
# Draw text
|
||||
_ = cv2.putText(
|
||||
frame,
|
||||
text,
|
||||
(margin + 5, y_pos),
|
||||
font,
|
||||
font_scale,
|
||||
COLOR_WHITE,
|
||||
thickness,
|
||||
)
|
||||
|
||||
def _draw_pose_skeleton(
|
||||
self,
|
||||
frame: ImageArray,
|
||||
pose_data: dict[str, object] | None,
|
||||
) -> None:
|
||||
"""Draw pose skeleton on frame.
|
||||
|
||||
Args:
|
||||
frame: Input frame (H, W, 3) uint8 - modified in place
|
||||
pose_data: Pose data dictionary from Sports2D or similar
|
||||
Expected format: {'keypoints': [[x1, y1], [x2, y2], ...],
|
||||
'confidence': [c1, c2, ...],
|
||||
'angles': {'joint_name': angle, ...}}
|
||||
"""
|
||||
if pose_data is None:
|
||||
return
|
||||
|
||||
keypoints_obj = pose_data.get('keypoints')
|
||||
if keypoints_obj is None:
|
||||
return
|
||||
|
||||
# Convert keypoints to numpy array
|
||||
keypoints = np.asarray(keypoints_obj, dtype=np.float32)
|
||||
if keypoints.size == 0:
|
||||
return
|
||||
|
||||
h, w = frame.shape[:2]
|
||||
|
||||
# Get confidence scores if available
|
||||
confidence_obj = pose_data.get('confidence')
|
||||
confidences = (
|
||||
np.asarray(confidence_obj, dtype=np.float32)
|
||||
if confidence_obj is not None
|
||||
else np.ones(len(keypoints), dtype=np.float32)
|
||||
)
|
||||
|
||||
# Draw skeleton connections
|
||||
for connection in SKELETON_CONNECTIONS:
|
||||
idx1, idx2 = connection
|
||||
if idx1 < len(keypoints) and idx2 < len(keypoints):
|
||||
# Check confidence threshold (0.3)
|
||||
if confidences[idx1] > 0.3 and confidences[idx2] > 0.3:
|
||||
pt1 = (int(keypoints[idx1][0]), int(keypoints[idx1][1]))
|
||||
pt2 = (int(keypoints[idx2][0]), int(keypoints[idx2][1]))
|
||||
# Clip to frame bounds
|
||||
pt1 = (max(0, min(w - 1, pt1[0])), max(0, min(h - 1, pt1[1])))
|
||||
pt2 = (max(0, min(w - 1, pt2[0])), max(0, min(h - 1, pt2[1])))
|
||||
_ = cv2.line(frame, pt1, pt2, COLOR_CYAN, 2)
|
||||
|
||||
# Draw keypoints
|
||||
for i, (kp, conf) in enumerate(zip(keypoints, confidences)):
|
||||
if conf > 0.3 and i < len(keypoints):
|
||||
x, y = int(kp[0]), int(kp[1])
|
||||
# Clip to frame bounds
|
||||
x = max(0, min(w - 1, x))
|
||||
y = max(0, min(h - 1, y))
|
||||
# Draw keypoint as circle
|
||||
_ = cv2.circle(frame, (x, y), 4, COLOR_MAGENTA, -1)
|
||||
_ = cv2.circle(frame, (x, y), 4, COLOR_WHITE, 1)
|
||||
|
||||
def _draw_pose_angles(
|
||||
self,
|
||||
frame: ImageArray,
|
||||
pose_data: dict[str, object] | None,
|
||||
) -> None:
|
||||
"""Draw pose angles as text overlay.
|
||||
|
||||
Args:
|
||||
frame: Input frame (H, W, 3) uint8 - modified in place
|
||||
pose_data: Pose data dictionary with 'angles' key
|
||||
"""
|
||||
if pose_data is None:
|
||||
return
|
||||
|
||||
angles_obj = pose_data.get('angles')
|
||||
if angles_obj is None:
|
||||
return
|
||||
|
||||
angles = cast(dict[str, float], angles_obj)
|
||||
if not angles:
|
||||
return
|
||||
|
||||
# Draw angles in top-right corner
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
font_scale = 0.45
|
||||
thickness = 1
|
||||
line_height = 20
|
||||
margin = 10
|
||||
h, w = frame.shape[:2]
|
||||
|
||||
# Filter and format angles
|
||||
angle_texts: list[tuple[str, float]] = []
|
||||
for name, angle in angles.items():
|
||||
# Only show angles that are reasonable (0-180 degrees)
|
||||
if 0 <= angle <= 180:
|
||||
angle_texts.append((str(name), float(angle)))
|
||||
|
||||
# Sort by name for consistent display
|
||||
angle_texts.sort(key=lambda x: x[0])
|
||||
|
||||
# Draw from top-right
|
||||
for i, (name, angle) in enumerate(angle_texts[:8]): # Limit to 8 angles
|
||||
text = f"{name}: {angle:.1f}"
|
||||
(text_width, text_height), _ = cv2.getTextSize(
|
||||
text, font, font_scale, thickness
|
||||
)
|
||||
x_pos = w - margin - text_width - 10
|
||||
y_pos = margin + (i + 1) * line_height
|
||||
|
||||
# Draw background rectangle
|
||||
_ = cv2.rectangle(
|
||||
frame,
|
||||
(x_pos - 4, y_pos - text_height - 4),
|
||||
(x_pos + text_width + 4, y_pos + 4),
|
||||
COLOR_BLACK,
|
||||
-1,
|
||||
)
|
||||
# Draw text in orange
|
||||
_ = cv2.putText(
|
||||
frame,
|
||||
text,
|
||||
(x_pos, y_pos),
|
||||
font,
|
||||
font_scale,
|
||||
COLOR_ORANGE,
|
||||
thickness,
|
||||
)
|
||||
|
||||
def _prepare_main_frame(
|
||||
self,
|
||||
frame: ImageArray,
|
||||
bbox: BBoxXYXY | None,
|
||||
track_id: int,
|
||||
fps: float,
|
||||
label: str | None,
|
||||
confidence: float | None,
|
||||
pose_data: dict[str, object] | None = None,
|
||||
) -> ImageArray:
|
||||
"""Prepare main display frame with bbox and text overlay.
|
||||
|
||||
Args:
|
||||
frame: Input frame (H, W, C) uint8
|
||||
bbox: Bounding box in XYXY format (x1, y1, x2, y2) or None
|
||||
track_id: Tracking ID
|
||||
fps: Current FPS
|
||||
label: Classification label or None
|
||||
confidence: Classification confidence or None
|
||||
pose_data: Pose data dictionary or None
|
||||
|
||||
Returns:
|
||||
Processed frame ready for display
|
||||
"""
|
||||
# Ensure BGR format (convert grayscale if needed)
|
||||
if len(frame.shape) == 2:
|
||||
display_frame = cast(ImageArray, cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR))
|
||||
elif frame.shape[2] == 1:
|
||||
display_frame = cast(ImageArray, cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR))
|
||||
elif frame.shape[2] == 3:
|
||||
display_frame = frame.copy()
|
||||
elif frame.shape[2] == 4:
|
||||
display_frame = cast(ImageArray, cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR))
|
||||
else:
|
||||
display_frame = frame.copy()
|
||||
|
||||
# Draw bbox and text (modifies in place)
|
||||
self._draw_bbox(display_frame, bbox)
|
||||
self._draw_text_overlay(display_frame, track_id, fps, label, confidence)
|
||||
|
||||
# Draw pose skeleton and angles if available
|
||||
self._draw_pose_skeleton(display_frame, pose_data)
|
||||
self._draw_pose_angles(display_frame, pose_data)
|
||||
|
||||
return display_frame
|
||||
|
||||
def _upscale_silhouette(
|
||||
self,
|
||||
silhouette: NDArray[np.float32] | NDArray[np.uint8],
|
||||
) -> ImageArray:
|
||||
"""Upscale silhouette to display size.
|
||||
|
||||
Args:
|
||||
silhouette: Input silhouette (64, 44) float32 [0,1] or uint8 [0,255]
|
||||
|
||||
Returns:
|
||||
Upscaled silhouette (256, 176) uint8
|
||||
"""
|
||||
# Normalize to uint8 if needed
|
||||
if silhouette.dtype == np.float32 or silhouette.dtype == np.float64:
|
||||
sil_u8 = (silhouette * 255).astype(np.uint8)
|
||||
else:
|
||||
sil_u8 = silhouette.astype(np.uint8)
|
||||
|
||||
# Upscale using nearest neighbor to preserve pixelation
|
||||
upscaled = cast(
|
||||
ImageArray,
|
||||
cv2.resize(
|
||||
sil_u8,
|
||||
(DISPLAY_WIDTH, DISPLAY_HEIGHT),
|
||||
interpolation=cv2.INTER_NEAREST,
|
||||
),
|
||||
)
|
||||
|
||||
return upscaled
|
||||
|
||||
def _normalize_mask_for_display(self, mask: NDArray[np.generic]) -> ImageArray:
|
||||
mask_array = np.asarray(mask)
|
||||
if mask_array.dtype == np.bool_:
|
||||
bool_scaled = np.where(mask_array, np.uint8(255), np.uint8(0)).astype(
|
||||
np.uint8
|
||||
)
|
||||
return cast(ImageArray, bool_scaled)
|
||||
|
||||
if mask_array.dtype == np.uint8:
|
||||
mask_array = cast(ImageArray, mask_array)
|
||||
max_u8 = int(np.max(mask_array)) if mask_array.size > 0 else 0
|
||||
if max_u8 <= 1:
|
||||
scaled_u8 = np.where(mask_array > 0, np.uint8(255), np.uint8(0)).astype(
|
||||
np.uint8
|
||||
)
|
||||
return cast(ImageArray, scaled_u8)
|
||||
return cast(ImageArray, mask_array)
|
||||
|
||||
if np.issubdtype(mask_array.dtype, np.integer):
|
||||
max_int = float(np.max(mask_array)) if mask_array.size > 0 else 0.0
|
||||
if max_int <= 1.0:
|
||||
return cast(
|
||||
ImageArray, (mask_array.astype(np.float32) * 255.0).astype(np.uint8)
|
||||
)
|
||||
clipped = np.clip(mask_array, 0, 255).astype(np.uint8)
|
||||
return cast(ImageArray, clipped)
|
||||
|
||||
mask_float = np.asarray(mask_array, dtype=np.float32)
|
||||
max_val = float(np.max(mask_float)) if mask_float.size > 0 else 0.0
|
||||
if max_val <= 0.0:
|
||||
return np.zeros(mask_float.shape, dtype=np.uint8)
|
||||
|
||||
normalized = np.clip((mask_float / max_val) * 255.0, 0.0, 255.0).astype(
|
||||
np.uint8
|
||||
)
|
||||
return cast(ImageArray, normalized)
|
||||
|
||||
def _draw_raw_stats(self, image: ImageArray, mask_raw: ImageArray | None) -> None:
|
||||
if mask_raw is None:
|
||||
return
|
||||
|
||||
mask = np.asarray(mask_raw)
|
||||
if mask.size == 0:
|
||||
return
|
||||
|
||||
stats = [
|
||||
f"raw: {mask.dtype}",
|
||||
f"min/max: {float(mask.min()):.3f}/{float(mask.max()):.3f}",
|
||||
f"nnz: {int(np.count_nonzero(mask))}",
|
||||
]
|
||||
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
font_scale = 0.45
|
||||
thickness = 1
|
||||
line_h = 18
|
||||
x0 = 8
|
||||
y0 = 20
|
||||
|
||||
for i, txt in enumerate(stats):
|
||||
y = y0 + i * line_h
|
||||
(tw, th), _ = cv2.getTextSize(txt, font, font_scale, thickness)
|
||||
_ = cv2.rectangle(
|
||||
image, (x0 - 4, y - th - 4), (x0 + tw + 4, y + 4), COLOR_BLACK, -1
|
||||
)
|
||||
_ = cv2.putText(
|
||||
image, txt, (x0, y), font, font_scale, COLOR_YELLOW, thickness
|
||||
)
|
||||
|
||||
def _prepare_segmentation_view(
|
||||
self,
|
||||
mask_raw: ImageArray | None,
|
||||
silhouette: NDArray[np.float32] | None,
|
||||
bbox: BBoxXYXY | None,
|
||||
) -> ImageArray:
|
||||
_ = mask_raw
|
||||
_ = bbox
|
||||
return self._prepare_normalized_view(silhouette)
|
||||
|
||||
def _fit_gray_to_display(
|
||||
self,
|
||||
gray: ImageArray,
|
||||
out_h: int = DISPLAY_HEIGHT,
|
||||
out_w: int = DISPLAY_WIDTH,
|
||||
) -> ImageArray:
|
||||
src_h, src_w = gray.shape[:2]
|
||||
if src_h <= 0 or src_w <= 0:
|
||||
return np.zeros((out_h, out_w), dtype=np.uint8)
|
||||
|
||||
scale = min(out_w / src_w, out_h / src_h)
|
||||
new_w = max(1, int(round(src_w * scale)))
|
||||
new_h = max(1, int(round(src_h * scale)))
|
||||
|
||||
resized = cast(
|
||||
ImageArray,
|
||||
cv2.resize(gray, (new_w, new_h), interpolation=cv2.INTER_NEAREST),
|
||||
)
|
||||
canvas = np.zeros((out_h, out_w), dtype=np.uint8)
|
||||
x0 = (out_w - new_w) // 2
|
||||
y0 = (out_h - new_h) // 2
|
||||
canvas[y0 : y0 + new_h, x0 : x0 + new_w] = resized
|
||||
return cast(ImageArray, canvas)
|
||||
|
||||
def _crop_mask_to_bbox(
|
||||
self,
|
||||
mask_gray: ImageArray,
|
||||
bbox: BBoxXYXY | None,
|
||||
) -> ImageArray:
|
||||
if bbox is None:
|
||||
return mask_gray
|
||||
|
||||
h, w = mask_gray.shape[:2]
|
||||
x1, y1, x2, y2 = bbox
|
||||
x1c = max(0, min(w, int(x1)))
|
||||
x2c = max(0, min(w, int(x2)))
|
||||
y1c = max(0, min(h, int(y1)))
|
||||
y2c = max(0, min(h, int(y2)))
|
||||
|
||||
if x2c <= x1c or y2c <= y1c:
|
||||
return mask_gray
|
||||
|
||||
cropped = mask_gray[y1c:y2c, x1c:x2c]
|
||||
if cropped.size == 0:
|
||||
return mask_gray
|
||||
return cast(ImageArray, cropped)
|
||||
|
||||
def _prepare_segmentation_input_view(
|
||||
self,
|
||||
silhouettes: NDArray[np.float32] | None,
|
||||
) -> ImageArray:
|
||||
if silhouettes is None or silhouettes.size == 0:
|
||||
placeholder = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH, 3), dtype=np.uint8)
|
||||
self._draw_mode_indicator(placeholder, "Input Silhouettes (No Data)")
|
||||
return placeholder
|
||||
|
||||
n_frames = int(silhouettes.shape[0])
|
||||
tiles_per_row = int(np.ceil(np.sqrt(n_frames)))
|
||||
rows = int(np.ceil(n_frames / tiles_per_row))
|
||||
|
||||
tile_h = DISPLAY_HEIGHT
|
||||
tile_w = DISPLAY_WIDTH
|
||||
grid = np.zeros((rows * tile_h, tiles_per_row * tile_w), dtype=np.uint8)
|
||||
|
||||
for idx in range(n_frames):
|
||||
sil = silhouettes[idx]
|
||||
tile = self._upscale_silhouette(sil)
|
||||
r = idx // tiles_per_row
|
||||
c = idx % tiles_per_row
|
||||
y0, y1 = r * tile_h, (r + 1) * tile_h
|
||||
x0, x1 = c * tile_w, (c + 1) * tile_w
|
||||
grid[y0:y1, x0:x1] = tile
|
||||
|
||||
grid_bgr = cast(ImageArray, cv2.cvtColor(grid, cv2.COLOR_GRAY2BGR))
|
||||
|
||||
for idx in range(n_frames):
|
||||
r = idx // tiles_per_row
|
||||
c = idx % tiles_per_row
|
||||
y0 = r * tile_h
|
||||
x0 = c * tile_w
|
||||
cv2.putText(
|
||||
grid_bgr,
|
||||
str(idx),
|
||||
(x0 + 8, y0 + 22),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.6,
|
||||
(0, 255, 255),
|
||||
2,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
|
||||
return grid_bgr
|
||||
|
||||
def _prepare_raw_view(
|
||||
self,
|
||||
mask_raw: ImageArray | None,
|
||||
bbox: BBoxXYXY | None = None,
|
||||
) -> ImageArray:
|
||||
"""Prepare raw mask view.
|
||||
|
||||
Args:
|
||||
mask_raw: Raw binary mask or None
|
||||
|
||||
Returns:
|
||||
Displayable image with mode indicator
|
||||
"""
|
||||
if mask_raw is None:
|
||||
# Create placeholder
|
||||
placeholder = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH, 3), dtype=np.uint8)
|
||||
self._draw_mode_indicator(placeholder, "Raw Mask (No Data)")
|
||||
return placeholder
|
||||
|
||||
# Ensure single channel
|
||||
if len(mask_raw.shape) == 3:
|
||||
mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
|
||||
else:
|
||||
mask_gray = cast(ImageArray, mask_raw)
|
||||
|
||||
mask_gray = self._normalize_mask_for_display(mask_gray)
|
||||
mask_gray = self._crop_mask_to_bbox(mask_gray, bbox)
|
||||
|
||||
debug_pad = RAW_STATS_PAD if self.show_raw_debug else 0
|
||||
content_h = max(1, DISPLAY_HEIGHT - debug_pad - MODE_LABEL_PAD)
|
||||
mask_resized = self._fit_gray_to_display(
|
||||
mask_gray, out_h=content_h, out_w=DISPLAY_WIDTH
|
||||
)
|
||||
full_mask = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
|
||||
full_mask[debug_pad : debug_pad + content_h, :] = mask_resized
|
||||
|
||||
# Convert to BGR for display
|
||||
mask_bgr = cast(ImageArray, cv2.cvtColor(full_mask, cv2.COLOR_GRAY2BGR))
|
||||
if self.show_raw_debug:
|
||||
self._draw_raw_stats(mask_bgr, mask_raw)
|
||||
self._draw_mode_indicator(mask_bgr, "Raw Mask")
|
||||
|
||||
return mask_bgr
|
||||
|
||||
def _prepare_normalized_view(
|
||||
self,
|
||||
silhouette: NDArray[np.float32] | None,
|
||||
) -> ImageArray:
|
||||
"""Prepare normalized silhouette view.
|
||||
|
||||
Args:
|
||||
silhouette: Normalized silhouette (64, 44) or None
|
||||
|
||||
Returns:
|
||||
Displayable image with mode indicator
|
||||
"""
|
||||
if silhouette is None:
|
||||
# Create placeholder
|
||||
placeholder = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH, 3), dtype=np.uint8)
|
||||
self._draw_mode_indicator(placeholder, "Normalized (No Data)")
|
||||
return placeholder
|
||||
|
||||
# Upscale and convert
|
||||
upscaled = self._upscale_silhouette(silhouette)
|
||||
content_h = max(1, DISPLAY_HEIGHT - MODE_LABEL_PAD)
|
||||
sil_compact = self._fit_gray_to_display(
|
||||
upscaled, out_h=content_h, out_w=DISPLAY_WIDTH
|
||||
)
|
||||
sil_canvas = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
|
||||
sil_canvas[:content_h, :] = sil_compact
|
||||
sil_bgr = cast(ImageArray, cv2.cvtColor(sil_canvas, cv2.COLOR_GRAY2BGR))
|
||||
self._draw_mode_indicator(sil_bgr, "Normalized")
|
||||
|
||||
return sil_bgr
|
||||
|
||||
def _draw_mode_indicator(self, image: ImageArray, label: str) -> None:
|
||||
h, w = image.shape[:2]
|
||||
|
||||
mode_text = label
|
||||
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
font_scale = 0.5
|
||||
thickness = 1
|
||||
|
||||
# Get text size for background
|
||||
(text_width, text_height), _ = cv2.getTextSize(
|
||||
mode_text, font, font_scale, thickness
|
||||
)
|
||||
|
||||
x_pos = 14
|
||||
y_pos = h - 8
|
||||
y_top = max(0, h - MODE_LABEL_PAD)
|
||||
|
||||
_ = cv2.rectangle(
|
||||
image,
|
||||
(0, y_top),
|
||||
(w, h),
|
||||
COLOR_DARK_GRAY,
|
||||
-1,
|
||||
)
|
||||
_ = cv2.rectangle(
|
||||
image,
|
||||
(x_pos - 6, y_pos - text_height - 6),
|
||||
(x_pos + text_width + 8, y_pos + 6),
|
||||
COLOR_DARK_GRAY,
|
||||
-1,
|
||||
)
|
||||
|
||||
# Draw text
|
||||
_ = cv2.putText(
|
||||
image,
|
||||
mode_text,
|
||||
(x_pos, y_pos),
|
||||
font,
|
||||
font_scale,
|
||||
COLOR_YELLOW,
|
||||
thickness,
|
||||
)
|
||||
|
||||
def update(
|
||||
self,
|
||||
frame: ImageArray,
|
||||
bbox: BBoxXYXY | None,
|
||||
bbox_mask: BBoxXYXY | None,
|
||||
track_id: int,
|
||||
mask_raw: ImageArray | None,
|
||||
silhouette: NDArray[np.float32] | None,
|
||||
segmentation_input: NDArray[np.float32] | None,
|
||||
label: str | None,
|
||||
confidence: float | None,
|
||||
fps: float,
|
||||
pose_data: dict[str, object] | None = None,
|
||||
) -> bool:
|
||||
"""Update visualization with new frame data.
|
||||
|
||||
Args:
|
||||
frame: Input frame (H, W, C) uint8
|
||||
bbox: Bounding box in XYXY format (x1, y1, x2, y2) or None
|
||||
track_id: Tracking ID
|
||||
mask_raw: Raw binary mask (H, W) uint8 or None
|
||||
silhouette: Normalized silhouette (64, 44) float32 [0,1] or None
|
||||
label: Classification label or None
|
||||
confidence: Classification confidence [0,1] or None
|
||||
fps: Current FPS
|
||||
pose_data: Pose data dictionary or None
|
||||
|
||||
Returns:
|
||||
False if user requested quit (pressed 'q'), True otherwise
|
||||
"""
|
||||
self._ensure_windows()
|
||||
|
||||
# Prepare and show main window
|
||||
main_display = self._prepare_main_frame(
|
||||
frame, bbox, track_id, fps, label, confidence, pose_data
|
||||
)
|
||||
cv2.imshow(MAIN_WINDOW, main_display)
|
||||
|
||||
# Prepare and show segmentation window
|
||||
seg_display = self._prepare_segmentation_view(mask_raw, silhouette, bbox)
|
||||
cv2.imshow(SEG_WINDOW, seg_display)
|
||||
|
||||
if self.show_raw_window:
|
||||
self._ensure_raw_window()
|
||||
raw_display = self._prepare_raw_view(mask_raw, bbox_mask)
|
||||
cv2.imshow(RAW_WINDOW, raw_display)
|
||||
|
||||
seg_input_display = self._prepare_segmentation_input_view(segmentation_input)
|
||||
cv2.imshow(WINDOW_SEG_INPUT, seg_input_display)
|
||||
|
||||
# Handle keyboard input
|
||||
key = cv2.waitKey(1) & 0xFF
|
||||
|
||||
if key == ord("q"):
|
||||
return False
|
||||
elif key == ord("r"):
|
||||
self.show_raw_window = not self.show_raw_window
|
||||
if self.show_raw_window:
|
||||
self._ensure_raw_window()
|
||||
logger.debug("Raw mask window enabled")
|
||||
else:
|
||||
self._hide_raw_window()
|
||||
logger.debug("Raw mask window disabled")
|
||||
elif key == ord("d"):
|
||||
self.show_raw_debug = not self.show_raw_debug
|
||||
logger.debug(
|
||||
"Raw mask debug overlay %s",
|
||||
"enabled" if self.show_raw_debug else "disabled",
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def close(self) -> None:
|
||||
if self._windows_created:
|
||||
self._hide_raw_window()
|
||||
cv2.destroyAllWindows()
|
||||
self._windows_created = False
|
||||
self._raw_window_created = False
|
||||
@@ -0,0 +1,375 @@
|
||||
"""Sliding window / ring buffer manager for real-time gait analysis.
|
||||
|
||||
This module provides bounded buffer management for silhouette sequences
|
||||
with track ID tracking and gap detection.
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, Protocol, cast, final
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from jaxtyping import Float
|
||||
from numpy import ndarray
|
||||
|
||||
from .preprocess import BBoxXYXY
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
|
||||
# Silhouette dimensions from preprocess.py
|
||||
SIL_HEIGHT: int = 64
|
||||
SIL_WIDTH: int = 44
|
||||
|
||||
# Type alias for array-like inputs
|
||||
type _ArrayLike = torch.Tensor | ndarray
|
||||
|
||||
|
||||
class _Boxes(Protocol):
|
||||
"""Protocol for boxes with xyxy and id attributes."""
|
||||
|
||||
@property
|
||||
def xyxy(self) -> "NDArray[np.float32] | object": ...
|
||||
@property
|
||||
def id(self) -> "NDArray[np.int64] | object | None": ...
|
||||
|
||||
|
||||
class _Masks(Protocol):
|
||||
"""Protocol for masks with data attribute."""
|
||||
|
||||
@property
|
||||
def data(self) -> "NDArray[np.float32] | object": ...
|
||||
|
||||
|
||||
class _DetectionResults(Protocol):
|
||||
"""Protocol for detection results from Ultralytics-style objects."""
|
||||
|
||||
@property
|
||||
def boxes(self) -> _Boxes: ...
|
||||
@property
|
||||
def masks(self) -> _Masks: ...
|
||||
|
||||
|
||||
@final
|
||||
class SilhouetteWindow:
|
||||
"""Bounded sliding window for silhouette sequences.
|
||||
|
||||
Manages a fixed-size buffer of silhouettes with track ID tracking
|
||||
and automatic reset on track changes or frame gaps.
|
||||
|
||||
Attributes:
|
||||
window_size: Maximum number of frames in the buffer.
|
||||
stride: Classification stride (frames between classifications).
|
||||
gap_threshold: Maximum allowed frame gap before reset.
|
||||
"""
|
||||
|
||||
window_size: int
|
||||
stride: int
|
||||
gap_threshold: int
|
||||
_buffer: deque[Float[ndarray, "64 44"]]
|
||||
_frame_indices: deque[int]
|
||||
_track_id: int | None
|
||||
_last_classified_frame: int
|
||||
_frame_count: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
window_size: int = 30,
|
||||
stride: int = 1,
|
||||
gap_threshold: int = 15,
|
||||
) -> None:
|
||||
"""Initialize the silhouette window.
|
||||
|
||||
Args:
|
||||
window_size: Maximum buffer size (default 30).
|
||||
stride: Frames between classifications (default 1).
|
||||
gap_threshold: Max frame gap before reset (default 15).
|
||||
"""
|
||||
self.window_size = window_size
|
||||
self.stride = stride
|
||||
self.gap_threshold = gap_threshold
|
||||
|
||||
# Bounded storage via deque
|
||||
self._buffer = deque(maxlen=window_size)
|
||||
self._frame_indices = deque(maxlen=window_size)
|
||||
self._track_id = None
|
||||
self._last_classified_frame = -1
|
||||
self._frame_count = 0
|
||||
|
||||
def push(self, sil: np.ndarray, frame_idx: int, track_id: int) -> None:
|
||||
"""Push a new silhouette into the window.
|
||||
|
||||
Automatically resets buffer on track ID change or frame gap
|
||||
exceeding gap_threshold.
|
||||
|
||||
Args:
|
||||
sil: Silhouette array of shape (64, 44), float32.
|
||||
frame_idx: Current frame index for gap detection.
|
||||
track_id: Track ID for the person.
|
||||
"""
|
||||
# Check for track ID change
|
||||
if self._track_id is not None and track_id != self._track_id:
|
||||
self.reset()
|
||||
|
||||
# Check for frame gap
|
||||
if self._frame_indices:
|
||||
last_frame = self._frame_indices[-1]
|
||||
gap = frame_idx - last_frame
|
||||
if gap > self.gap_threshold:
|
||||
self.reset()
|
||||
|
||||
# Update track ID
|
||||
self._track_id = track_id
|
||||
|
||||
# Validate and append silhouette
|
||||
sil_array = np.asarray(sil, dtype=np.float32)
|
||||
if sil_array.shape != (SIL_HEIGHT, SIL_WIDTH):
|
||||
raise ValueError(
|
||||
f"Expected silhouette shape ({SIL_HEIGHT}, {SIL_WIDTH}), got {sil_array.shape}"
|
||||
)
|
||||
|
||||
self._buffer.append(sil_array)
|
||||
self._frame_indices.append(frame_idx)
|
||||
self._frame_count += 1
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""Check if window has enough frames for classification.
|
||||
|
||||
Returns:
|
||||
True if buffer is full (window_size frames).
|
||||
"""
|
||||
return len(self._buffer) >= self.window_size
|
||||
|
||||
def should_classify(self) -> bool:
|
||||
"""Check if classification should run based on stride.
|
||||
|
||||
Returns:
|
||||
True if enough frames have passed since last classification.
|
||||
"""
|
||||
if not self.is_ready():
|
||||
return False
|
||||
|
||||
if self._last_classified_frame < 0:
|
||||
return True
|
||||
|
||||
current_frame = self._frame_indices[-1]
|
||||
frames_since = current_frame - self._last_classified_frame
|
||||
return frames_since >= self.stride
|
||||
|
||||
def get_tensor(self, device: str = "cpu") -> torch.Tensor:
|
||||
"""Get window contents as a tensor for model input.
|
||||
|
||||
Args:
|
||||
device: Target device for the tensor (default 'cpu').
|
||||
|
||||
Returns:
|
||||
Tensor of shape [1, 1, window_size, 64, 44] with dtype float32.
|
||||
|
||||
Raises:
|
||||
ValueError: If buffer is not full.
|
||||
"""
|
||||
if not self.is_ready():
|
||||
raise ValueError(
|
||||
f"Window not ready: {len(self._buffer)}/{self.window_size} frames"
|
||||
)
|
||||
|
||||
# Stack buffer into array [window_size, 64, 44]
|
||||
stacked = np.stack(list(self._buffer), axis=0)
|
||||
|
||||
# Add batch and channel dims: [1, 1, window_size, 64, 44]
|
||||
tensor = torch.from_numpy(stacked.astype(np.float32))
|
||||
tensor = tensor.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
return tensor.to(device)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the window, clearing all buffers and counters."""
|
||||
self._buffer.clear()
|
||||
self._frame_indices.clear()
|
||||
self._track_id = None
|
||||
self._last_classified_frame = -1
|
||||
self._frame_count = 0
|
||||
|
||||
def mark_classified(self) -> None:
|
||||
"""Mark current frame as classified, updating stride tracking."""
|
||||
if self._frame_indices:
|
||||
self._last_classified_frame = self._frame_indices[-1]
|
||||
|
||||
@property
|
||||
def current_track_id(self) -> int | None:
|
||||
"""Current track ID, or None if buffer is empty."""
|
||||
return self._track_id
|
||||
|
||||
@property
|
||||
def frame_count(self) -> int:
|
||||
"""Total frames pushed since last reset."""
|
||||
return self._frame_count
|
||||
|
||||
@property
|
||||
def fill_level(self) -> float:
|
||||
"""Fill ratio of the buffer (0.0 to 1.0)."""
|
||||
return len(self._buffer) / self.window_size
|
||||
|
||||
@property
|
||||
def window_start_frame(self) -> int:
|
||||
if not self._frame_indices:
|
||||
raise ValueError("Window is empty")
|
||||
return int(self._frame_indices[0])
|
||||
|
||||
@property
|
||||
def buffered_silhouettes(self) -> Float[ndarray, "n 64 44"]:
|
||||
if not self._buffer:
|
||||
return np.empty((0, SIL_HEIGHT, SIL_WIDTH), dtype=np.float32)
|
||||
return cast(
|
||||
Float[ndarray, "n 64 44"],
|
||||
np.stack(list(self._buffer), axis=0).astype(np.float32, copy=True),
|
||||
)
|
||||
|
||||
|
||||
def _to_numpy(obj: _ArrayLike) -> ndarray:
|
||||
"""Safely convert array-like object to numpy array.
|
||||
|
||||
Handles torch tensors (CPU or CUDA) by detaching and moving to CPU first.
|
||||
Falls back to np.asarray for other array-like objects.
|
||||
|
||||
Args:
|
||||
obj: Array-like object (numpy array, torch tensor, or similar).
|
||||
|
||||
Returns:
|
||||
Numpy array representation of the input.
|
||||
"""
|
||||
# Handle torch tensors (including CUDA tensors)
|
||||
detach_fn = getattr(obj, "detach", None)
|
||||
if detach_fn is not None and callable(detach_fn):
|
||||
# It's a torch tensor
|
||||
tensor = detach_fn()
|
||||
cpu_fn = getattr(tensor, "cpu", None)
|
||||
if cpu_fn is not None and callable(cpu_fn):
|
||||
tensor = cpu_fn()
|
||||
numpy_fn = getattr(tensor, "numpy", None)
|
||||
if numpy_fn is not None and callable(numpy_fn):
|
||||
return cast(ndarray, numpy_fn())
|
||||
# Fall back to np.asarray for other array-like objects
|
||||
return cast(ndarray, np.asarray(obj))
|
||||
|
||||
|
||||
def select_person(
|
||||
results: _DetectionResults,
|
||||
) -> tuple[ndarray, BBoxXYXY, BBoxXYXY, int] | None:
|
||||
"""Select the person with largest bounding box from detection results.
|
||||
|
||||
Args:
|
||||
results: Detection results object with boxes and masks attributes.
|
||||
Expected to have:
|
||||
- boxes.xyxy: array of bounding boxes [N, 4] in frame coordinates (XYXY format)
|
||||
- masks.data: array of masks [N, H, W] in mask coordinates
|
||||
- boxes.id: optional track IDs [N]
|
||||
|
||||
Returns:
|
||||
Tuple of (mask, bbox_mask, bbox_frame, track_id) for the largest person,
|
||||
or None if no valid detections or track IDs unavailable.
|
||||
- mask: the person's segmentation mask
|
||||
- bbox_mask: bounding box in mask coordinate space (XYXY format: x1, y1, x2, y2)
|
||||
- bbox_frame: bounding box in frame coordinate space (XYXY format: x1, y1, x2, y2)
|
||||
- track_id: the person's track ID
|
||||
"""
|
||||
# Check for track IDs
|
||||
boxes_obj: _Boxes | object = getattr(results, "boxes", None)
|
||||
if boxes_obj is None:
|
||||
return None
|
||||
|
||||
track_ids_obj: ndarray | object | None = getattr(boxes_obj, "id", None)
|
||||
if track_ids_obj is None:
|
||||
return None
|
||||
|
||||
track_ids: ndarray = _to_numpy(cast(ndarray, track_ids_obj))
|
||||
if track_ids.size == 0:
|
||||
return None
|
||||
|
||||
# Get bounding boxes
|
||||
xyxy_obj: ndarray | object = getattr(boxes_obj, "xyxy", None)
|
||||
if xyxy_obj is None:
|
||||
return None
|
||||
|
||||
bboxes: ndarray = _to_numpy(cast(ndarray, xyxy_obj))
|
||||
if bboxes.ndim == 1:
|
||||
bboxes = bboxes.reshape(1, -1)
|
||||
|
||||
if bboxes.shape[0] == 0:
|
||||
return None
|
||||
|
||||
# Get masks
|
||||
masks_obj: _Masks | object = getattr(results, "masks", None)
|
||||
if masks_obj is None:
|
||||
return None
|
||||
|
||||
masks_data: ndarray | object = getattr(masks_obj, "data", None)
|
||||
if masks_data is None:
|
||||
return None
|
||||
|
||||
masks: ndarray = _to_numpy(cast(ndarray, masks_data))
|
||||
if masks.ndim == 2:
|
||||
masks = masks[np.newaxis, ...]
|
||||
|
||||
if masks.shape[0] != bboxes.shape[0]:
|
||||
return None
|
||||
|
||||
# Find largest bbox by area
|
||||
best_idx: int = -1
|
||||
best_area: float = -1.0
|
||||
|
||||
for i in range(int(bboxes.shape[0])):
|
||||
row: "NDArray[np.float32]" = bboxes[i][:4]
|
||||
x1f: float = float(row[0])
|
||||
y1f: float = float(row[1])
|
||||
x2f: float = float(row[2])
|
||||
y2f: float = float(row[3])
|
||||
area: float = (x2f - x1f) * (y2f - y1f)
|
||||
if area > best_area:
|
||||
best_area = area
|
||||
best_idx = i
|
||||
|
||||
if best_idx < 0:
|
||||
return None
|
||||
|
||||
# Extract mask and bbox
|
||||
mask: "NDArray[np.float32]" = masks[best_idx]
|
||||
mask_shape = mask.shape
|
||||
mask_h, mask_w = int(mask_shape[0]), int(mask_shape[1])
|
||||
|
||||
# Get original image dimensions from results (YOLO provides this)
|
||||
orig_shape = getattr(results, "orig_shape", None)
|
||||
# Validate orig_shape is a sequence of at least 2 numeric values
|
||||
if (
|
||||
orig_shape is not None
|
||||
and isinstance(orig_shape, (tuple, list))
|
||||
and len(orig_shape) >= 2
|
||||
):
|
||||
frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
|
||||
# Scale bbox from frame space to mask space
|
||||
scale_x = mask_w / frame_w if frame_w > 0 else 1.0
|
||||
scale_y = mask_h / frame_h if frame_h > 0 else 1.0
|
||||
bbox_mask = (
|
||||
int(float(bboxes[best_idx][0]) * scale_x),
|
||||
int(float(bboxes[best_idx][1]) * scale_y),
|
||||
int(float(bboxes[best_idx][2]) * scale_x),
|
||||
int(float(bboxes[best_idx][3]) * scale_y),
|
||||
)
|
||||
bbox_frame = (
|
||||
int(float(bboxes[best_idx][0])),
|
||||
int(float(bboxes[best_idx][1])),
|
||||
int(float(bboxes[best_idx][2])),
|
||||
int(float(bboxes[best_idx][3])),
|
||||
)
|
||||
else:
|
||||
# Fallback: use bbox as-is for both (assume same coordinate space)
|
||||
bbox_mask = (
|
||||
int(float(bboxes[best_idx][0])),
|
||||
int(float(bboxes[best_idx][1])),
|
||||
int(float(bboxes[best_idx][2])),
|
||||
int(float(bboxes[best_idx][3])),
|
||||
)
|
||||
bbox_frame = bbox_mask
|
||||
track_id = int(track_ids[best_idx]) if best_idx < len(track_ids) else best_idx
|
||||
|
||||
return mask, bbox_mask, bbox_frame, track_id
|
||||
Reference in New Issue
Block a user