diff --git a/opengait/demo/__init__.py b/opengait/demo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/opengait/demo/__main__.py b/opengait/demo/__main__.py new file mode 100644 index 0000000..2590b2a --- /dev/null +++ b/opengait/demo/__main__.py @@ -0,0 +1,7 @@ +from __future__ import annotations + +from .pipeline import main + + +if __name__ == "__main__": + main() diff --git a/opengait/demo/input.py b/opengait/demo/input.py new file mode 100644 index 0000000..a96f540 --- /dev/null +++ b/opengait/demo/input.py @@ -0,0 +1,203 @@ +""" +Input adapters for OpenGait demo. + +Provides generator-based interfaces for video sources: +- OpenCV (video files, cameras) +- cv-mmap (shared memory streams) +""" + +from collections.abc import AsyncIterator, Generator, Iterable +from typing import TYPE_CHECKING, Protocol, cast +import logging + +import numpy as np + +logger = logging.getLogger(__name__) + +# Type alias for frame stream: (frame_array, metadata_dict) +FrameStream = Iterable[tuple[np.ndarray, dict[str, object]]] + +if TYPE_CHECKING: + # Protocol for cv-mmap metadata to avoid direct import + class _FrameMetadata(Protocol): + frame_count: int + timestamp_ns: int + + # Protocol for cv-mmap client + class _CvMmapClient(Protocol): + def __aiter__(self) -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]: ... + + +def opencv_source( + path: str | int, max_frames: int | None = None +) -> Generator[tuple[np.ndarray, dict[str, object]], None, None]: + """ + Generator that yields frames from an OpenCV video source. + + Parameters + ---------- + path : str | int + Video file path or camera index (e.g., 0 for default camera) + max_frames : int | None, optional + Maximum number of frames to yield. None means unlimited. 
+ + Yields + ------ + tuple[np.ndarray, dict[str, object]] + (frame_array, metadata_dict) where metadata includes: + - frame_count: frame index (0-based) + - timestamp_ns: monotonic timestamp in nanoseconds (if available) + - source: the path/int provided + """ + import time + + import cv2 + + cap = cv2.VideoCapture(path) + if not cap.isOpened(): + raise RuntimeError(f"Failed to open video source: {path}") + + frame_idx = 0 + try: + while max_frames is None or frame_idx < max_frames: + ret, frame = cap.read() + if not ret: + # End of stream + break + + # Get timestamp if available (some backends support this) + timestamp_ns = time.monotonic_ns() + + metadata: dict[str, object] = { + "frame_count": frame_idx, + "timestamp_ns": timestamp_ns, + "source": path, + } + + yield frame, metadata + frame_idx += 1 + + finally: + cap.release() + logger.debug(f"OpenCV source closed: {path}") + + +def cvmmap_source( + name: str, max_frames: int | None = None +) -> Generator[tuple[np.ndarray, dict[str, object]], None, None]: + """ + Generator that yields frames from a cv-mmap shared memory stream. + + Bridges async cv-mmap client to synchronous generator using asyncio.run(). + + Parameters + ---------- + name : str + Base name of the cv-mmap source (e.g., "default") + max_frames : int | None, optional + Maximum number of frames to yield. None means unlimited. 
+ + Yields + ------ + tuple[np.ndarray, dict[str, object]] + (frame_array, metadata_dict) where metadata includes: + - frame_count: frame index from cv-mmap + - timestamp_ns: timestamp in nanoseconds from cv-mmap + - source: the cv-mmap name + + Raises + ------ + ImportError + If cvmmap package is not available + RuntimeError + If cv-mmap stream disconnects or errors + """ + import asyncio + + # Import cvmmap only when function is called + # Use try/except for runtime import check + try: + + from cvmmap import CvMmapClient as _CvMmapClientReal # pyright: ignore[reportMissingTypeStubs] + except ImportError as e: + raise ImportError( + "cvmmap package is required for cv-mmap sources. " + + "Install from: https://github.com/crosstyan/cv-mmap" + ) from e + + # Cast to protocol type for type checking + client: _CvMmapClient = cast("_CvMmapClient", _CvMmapClientReal(name)) + frame_count = 0 + + async def _async_generator() -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]: + """Async generator wrapper.""" + async for frame, meta in client: + yield frame, meta + + # Bridge async to sync using asyncio.run() + # We process frames one at a time to keep it simple and robust + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + agen = _async_generator().__aiter__() + + while max_frames is None or frame_count < max_frames: + try: + frame, meta = loop.run_until_complete(agen.__anext__()) + except StopAsyncIteration: + break + + metadata: dict[str, object] = { + "frame_count": meta.frame_count, + "timestamp_ns": meta.timestamp_ns, + "source": f"cvmmap://{name}", + } + + yield frame, metadata + frame_count += 1 + + finally: + loop.close() + logger.debug(f"cv-mmap source closed: {name}") + + +def create_source(source: str, max_frames: int | None = None) -> FrameStream: + """ + Factory function to create a frame source from a string specification. + + Parameters + ---------- + source : str + Source specification: + - '0', '1', etc. 
-> Camera index (OpenCV) + - 'cvmmap://name' -> cv-mmap shared memory stream + - Any other string -> Video file path (OpenCV) + max_frames : int | None, optional + Maximum number of frames to yield. None means unlimited. + + Returns + ------- + FrameStream + Generator yielding (frame, metadata) tuples + + Examples + -------- + >>> for frame, meta in create_source('0'): # Camera 0 + ... process(frame) + >>> for frame, meta in create_source('cvmmap://default'): # cv-mmap + ... process(frame) + >>> for frame, meta in create_source('/path/to/video.mp4'): + ... process(frame) + """ + # Check for cv-mmap protocol + if source.startswith("cvmmap://"): + name = source[len("cvmmap://") :] + return cvmmap_source(name, max_frames) + + # Check for camera index (single digit string) + if source.isdigit(): + return opencv_source(int(source), max_frames) + + # Otherwise treat as file path + return opencv_source(source, max_frames) diff --git a/opengait/demo/output.py b/opengait/demo/output.py new file mode 100644 index 0000000..5c9180e --- /dev/null +++ b/opengait/demo/output.py @@ -0,0 +1,368 @@ +""" +Output publishers for OpenGait demo results. + +Provides pluggable result publishing: +- ConsolePublisher: JSONL to stdout +- NatsPublisher: NATS message broker integration +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import sys +import threading +import time +from typing import TYPE_CHECKING, Protocol, TextIO, cast, runtime_checkable + +if TYPE_CHECKING: + from types import TracebackType + +logger = logging.getLogger(__name__) + + +@runtime_checkable +class ResultPublisher(Protocol): + """Protocol for result publishers.""" + + def publish(self, result: dict[str, object]) -> None: + """ + Publish a result dictionary. + + Parameters + ---------- + result : dict[str, object] + Result data with keys: frame, track_id, label, confidence, window, timestamp_ns + """ + ... 
+ + +class ConsolePublisher: + """Publisher that outputs JSON Lines to stdout.""" + + _output: TextIO + + def __init__(self, output: TextIO = sys.stdout) -> None: + """ + Initialize console publisher. + + Parameters + ---------- + output : TextIO + File-like object to write to (default: sys.stdout) + """ + self._output = output + + def publish(self, result: dict[str, object]) -> None: + """ + Publish result as JSON line. + + Parameters + ---------- + result : dict[str, object] + Result data with keys: frame, track_id, label, confidence, window, timestamp_ns + """ + try: + json_line = json.dumps(result, ensure_ascii=False, default=str) + _ = self._output.write(json_line + "\n") + self._output.flush() + except Exception as e: + logger.warning(f"Failed to publish to console: {e}") + + def close(self) -> None: + """Close the publisher (no-op for console).""" + pass + + def __enter__(self) -> ConsolePublisher: + """Context manager entry.""" + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Context manager exit.""" + self.close() + + +class _NatsClient(Protocol): + """Protocol for connected NATS client.""" + + async def publish(self, subject: str, payload: bytes) -> object: ... + + async def close(self) -> object: ... + + async def flush(self) -> object: ... + + +class NatsPublisher: + """ + Publisher that sends results to NATS message broker. + + This is a sync-friendly wrapper around the async nats-py client. + Uses a background thread with dedicated event loop to bridge sync + publish calls to async NATS operations, making it safe to use in + both sync and async contexts. 
+ """ + + _nats_url: str + _subject: str + _nc: _NatsClient | None + _connected: bool + _loop: asyncio.AbstractEventLoop | None + _thread: threading.Thread | None + _lock: threading.Lock + + def __init__(self, nats_url: str, subject: str = "scoliosis.result") -> None: + """ + Initialize NATS publisher. + + Parameters + ---------- + nats_url : str + NATS server URL (e.g., "nats://localhost:4222") + subject : str + NATS subject to publish to (default: "scoliosis.result") + """ + self._nats_url = nats_url + self._subject = subject + self._nc = None + self._connected = False + self._loop = None + self._thread = None + self._lock = threading.Lock() + + def _start_background_loop(self) -> bool: + """ + Start background thread with event loop for async operations. + + Returns + ------- + bool + True if loop is running, False otherwise + """ + with self._lock: + if self._loop is not None and self._loop.is_running(): + return True + + try: + loop = asyncio.new_event_loop() + self._loop = loop + + def run_loop() -> None: + asyncio.set_event_loop(loop) + loop.run_forever() + + self._thread = threading.Thread(target=run_loop, daemon=True) + self._thread.start() + return True + except Exception as e: + logger.warning(f"Failed to start background event loop: {e}") + return False + + def _stop_background_loop(self) -> None: + """Stop the background event loop and thread.""" + with self._lock: + if self._loop is not None and self._loop.is_running(): + _ = self._loop.call_soon_threadsafe(self._loop.stop) + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=2.0) + self._loop = None + self._thread = None + + def _ensure_connected(self) -> bool: + """ + Ensure connection to NATS server. 
+ + Returns + ------- + bool + True if connected, False otherwise + """ + with self._lock: + if self._connected and self._nc is not None: + return True + + if not self._start_background_loop(): + return False + + try: + import nats + + async def _connect() -> _NatsClient: + nc = await nats.connect(self._nats_url) # pyright: ignore[reportUnknownMemberType] + return cast(_NatsClient, nc) + + # Run connection in background loop + future = asyncio.run_coroutine_threadsafe( + _connect(), + self._loop, # pyright: ignore[reportArgumentType] + ) + self._nc = future.result(timeout=10.0) + self._connected = True + logger.info(f"Connected to NATS at {self._nats_url}") + return True + except ImportError: + logger.warning( + "nats-py package not installed. Install with: pip install nats-py" + ) + return False + except Exception as e: + logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}") + return False + + def publish(self, result: dict[str, object]) -> None: + """ + Publish result to NATS subject. 
+ + Parameters + ---------- + result : dict[str, object] + Result data with keys: frame, track_id, label, confidence, window, timestamp_ns + """ + if not self._ensure_connected(): + # Graceful degradation: log warning but don't crash + logger.debug( + f"NATS unavailable, dropping result: {result.get('track_id', 'unknown')}" + ) + return + + try: + + async def _publish() -> None: + if self._nc is not None: + payload = json.dumps( + result, ensure_ascii=False, default=str + ).encode("utf-8") + _ = await self._nc.publish(self._subject, payload) + _ = await self._nc.flush() + # Run publish in background loop + future = asyncio.run_coroutine_threadsafe( + _publish(), + self._loop, # pyright: ignore[reportArgumentType] + ) + future.result(timeout=5.0) # Wait for publish to complete + except Exception as e: + logger.warning(f"Failed to publish to NATS: {e}") + self._connected = False # Mark for reconnection on next publish + + def close(self) -> None: + """Close NATS connection.""" + with self._lock: + if self._nc is not None and self._connected and self._loop is not None: + try: + + async def _close() -> None: + if self._nc is not None: + _ = await self._nc.close() + + future = asyncio.run_coroutine_threadsafe( + _close(), + self._loop, + ) + future.result(timeout=5.0) + except Exception as e: + logger.debug(f"Error closing NATS connection: {e}") + finally: + self._nc = None + self._connected = False + + self._stop_background_loop() + + def __enter__(self) -> NatsPublisher: + """Context manager entry.""" + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Context manager exit.""" + self.close() + + +def create_publisher( + nats_url: str | None, + subject: str = "scoliosis.result", +) -> ResultPublisher: + """ + Factory function to create appropriate publisher. + + Parameters + ---------- + nats_url : str | None + NATS server URL. 
If None or empty, returns ConsolePublisher. + subject : str + NATS subject to publish to (default: "scoliosis.result") + + Returns + ------- + ResultPublisher + NatsPublisher if nats_url provided, otherwise ConsolePublisher + + Examples + -------- + >>> # Console output (default) + >>> pub = create_publisher(None) + >>> pub.publish({"frame": 1, "track_id": 42, "label": "normal", "confidence": 0.95, "window": 30, "timestamp_ns": 1234567890}) + >>> + >>> # NATS output + >>> pub = create_publisher("nats://localhost:4222") + >>> pub.publish({"frame": 1, "track_id": 42, "label": "normal", "confidence": 0.95, "window": 30, "timestamp_ns": 1234567890}) + >>> + >>> # Context manager usage + >>> with create_publisher("nats://localhost:4222") as pub: + ... pub.publish(result) + """ + if nats_url: + return NatsPublisher(nats_url, subject) + return ConsolePublisher() + + +def create_result( + frame: int, + track_id: int, + label: str, + confidence: float, + window: int | tuple[int, int], + timestamp_ns: int | None = None, +) -> dict[str, object]: + """ + Create a standardized result dictionary. + + Parameters + ---------- + frame : int + Frame number + track_id : int + Track/person identifier + label : str + Classification label (e.g., "normal", "scoliosis") + confidence : float + Confidence score (0.0 to 1.0) + window : int | tuple[int, int] + Frame window as int (end frame) or tuple [start, end] that produced this result + Frame window [start, end] that produced this result + timestamp_ns : int | None + Timestamp in nanoseconds. If None, uses current time. 
+ + Returns + ------- + dict[str, object] + Standardized result dictionary + """ + return { + "frame": frame, + "track_id": track_id, + "label": label, + "confidence": confidence, + "window": window if isinstance(window, int) else window[1], + "timestamp_ns": timestamp_ns + if timestamp_ns is not None + else time.monotonic_ns(), + } diff --git a/opengait/demo/pipeline.py b/opengait/demo/pipeline.py new file mode 100644 index 0000000..28d5a61 --- /dev/null +++ b/opengait/demo/pipeline.py @@ -0,0 +1,325 @@ +from __future__ import annotations + +from collections.abc import Callable +from contextlib import suppress +import logging +from pathlib import Path +import time +from typing import Protocol, cast + +from beartype import beartype +import click +import jaxtyping +from jaxtyping import Float, UInt8 +import numpy as np +from numpy import ndarray +from numpy.typing import NDArray +from ultralytics.models.yolo.model import YOLO + +from .input import FrameStream, create_source +from .output import ResultPublisher, create_publisher, create_result +from .preprocess import frame_to_person_mask, mask_to_silhouette +from .sconet_demo import ScoNetDemo +from .window import SilhouetteWindow, select_person + +logger = logging.getLogger(__name__) + +JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]] +JaxtypedFactory = Callable[..., JaxtypedDecorator] +jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped) + + +class _BoxesLike(Protocol): + @property + def xyxy(self) -> NDArray[np.float32] | object: ... + + @property + def id(self) -> NDArray[np.int64] | object | None: ... + + +class _MasksLike(Protocol): + @property + def data(self) -> NDArray[np.float32] | object: ... + + +class _DetectionResultsLike(Protocol): + @property + def boxes(self) -> _BoxesLike: ... + + @property + def masks(self) -> _MasksLike: ... 
+ + +class _TrackCallable(Protocol): + def __call__( + self, + source: object, + *, + persist: bool = True, + verbose: bool = False, + device: str | None = None, + classes: list[int] | None = None, + ) -> object: ... + + +class ScoliosisPipeline: + _detector: object + _source: FrameStream + _window: SilhouetteWindow + _publisher: ResultPublisher + _classifier: ScoNetDemo + _device: str + _closed: bool + + def __init__( + self, + *, + source: str, + checkpoint: str, + config: str, + device: str, + yolo_model: str, + window: int, + stride: int, + nats_url: str | None, + nats_subject: str, + max_frames: int | None, + ) -> None: + self._detector = YOLO(yolo_model) + self._source = create_source(source, max_frames=max_frames) + self._window = SilhouetteWindow(window_size=window, stride=stride) + self._publisher = create_publisher(nats_url=nats_url, subject=nats_subject) + self._classifier = ScoNetDemo( + cfg_path=config, + checkpoint_path=checkpoint, + device=device, + ) + self._device = device + self._closed = False + + @staticmethod + def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int: + value = meta.get(key) + if isinstance(value, int): + return value + return fallback + + @staticmethod + def _extract_timestamp(meta: dict[str, object]) -> int: + value = meta.get("timestamp_ns") + if isinstance(value, int): + return value + return time.monotonic_ns() + + @staticmethod + def _to_mask_u8(mask: ndarray) -> UInt8[ndarray, "h w"]: + binary = np.where(np.asarray(mask) > 0.5, np.uint8(255), np.uint8(0)).astype( + np.uint8 + ) + return cast(UInt8[ndarray, "h w"], binary) + + def _first_result(self, detections: object) -> _DetectionResultsLike | None: + if isinstance(detections, list): + return cast(_DetectionResultsLike, detections[0]) if detections else None + if isinstance(detections, tuple): + return cast(_DetectionResultsLike, detections[0]) if detections else None + return cast(_DetectionResultsLike, detections) + + def _select_silhouette( + self, 
+ result: _DetectionResultsLike, + ) -> tuple[Float[ndarray, "64 44"], int] | None: + selected = select_person(result) + if selected is not None: + mask_raw, bbox, track_id = selected + silhouette = cast( + Float[ndarray, "64 44"] | None, + mask_to_silhouette(self._to_mask_u8(mask_raw), bbox), + ) + if silhouette is not None: + return silhouette, int(track_id) + + fallback = cast( + tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None, + frame_to_person_mask(result), + ) + if fallback is None: + return None + + mask_u8, bbox = fallback + silhouette = cast( + Float[ndarray, "64 44"] | None, + mask_to_silhouette(mask_u8, bbox), + ) + if silhouette is None: + return None + return silhouette, 0 + + @jaxtyped(typechecker=beartype) + def process_frame( + self, + frame: UInt8[ndarray, "h w c"], + metadata: dict[str, object], + ) -> dict[str, object] | None: + frame_idx = self._extract_int(metadata, "frame_count", fallback=0) + timestamp_ns = self._extract_timestamp(metadata) + + track_fn_obj = getattr(self._detector, "track", None) + if not callable(track_fn_obj): + raise RuntimeError("YOLO detector does not expose a callable track()") + + track_fn = cast(_TrackCallable, track_fn_obj) + detections = track_fn( + frame, + persist=True, + verbose=False, + device=self._device, + classes=[0], + ) + first = self._first_result(detections) + if first is None: + return None + + selected = self._select_silhouette(first) + if selected is None: + return None + + silhouette, track_id = selected + self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id) + + if not self._window.should_classify(): + return None + + window_tensor = self._window.get_tensor(device=self._device) + label, confidence = cast( + tuple[str, float], + self._classifier.predict(window_tensor), + ) + self._window.mark_classified() + + window_start = frame_idx - self._window.window_size + 1 + result = create_result( + frame=frame_idx, + track_id=track_id, + label=label, + 
confidence=float(confidence), + window=(max(0, window_start), frame_idx), + timestamp_ns=timestamp_ns, + ) + self._publisher.publish(result) + return result + + def run(self) -> int: + frame_count = 0 + start_time = time.perf_counter() + try: + for item in self._source: + frame, metadata = item + frame_u8 = np.asarray(frame, dtype=np.uint8) + frame_idx = self._extract_int(metadata, "frame_count", fallback=0) + frame_count += 1 + try: + _ = self.process_frame(frame_u8, metadata) + except Exception as frame_error: + logger.warning( + "Skipping frame %d due to processing error: %s", + frame_idx, + frame_error, + ) + if frame_count % 100 == 0: + elapsed = time.perf_counter() - start_time + fps = frame_count / elapsed if elapsed > 0 else 0.0 + logger.info("Processed %d frames (%.2f FPS)", frame_count, fps) + return 0 + except KeyboardInterrupt: + logger.info("Interrupted by user, shutting down cleanly.") + return 130 + finally: + self.close() + + def close(self) -> None: + if self._closed: + return + close_fn = getattr(self._publisher, "close", None) + if callable(close_fn): + with suppress(Exception): + _ = close_fn() + self._closed = True + + +def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None: + if source.startswith("cvmmap://") or source.isdigit(): + pass + else: + source_path = Path(source) + if not source_path.is_file(): + raise ValueError(f"Video source not found: {source}") + + checkpoint_path = Path(checkpoint) + if not checkpoint_path.is_file(): + raise ValueError(f"Checkpoint not found: {checkpoint}") + + config_path = Path(config) + if not config_path.is_file(): + raise ValueError(f"Config not found: {config}") + + +@click.command(context_settings={"help_option_names": ["-h", "--help"]}) +@click.option("--source", type=str, required=True) +@click.option("--checkpoint", type=str, required=True) +@click.option( + "--config", + type=str, + default="configs/sconet/sconet_scoliosis1k.yaml", + show_default=True, +) 
+@click.option("--device", type=str, default="cuda:0", show_default=True) +@click.option("--yolo-model", type=str, default="yolo11n-seg.pt", show_default=True) +@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True) +@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True) +@click.option("--nats-url", type=str, default=None) +@click.option( + "--nats-subject", + type=str, + default="scoliosis.result", + show_default=True, +) +@click.option("--max-frames", type=click.IntRange(min=1), default=None) +def main( + source: str, + checkpoint: str, + config: str, + device: str, + yolo_model: str, + window: int, + stride: int, + nats_url: str | None, + nats_subject: str, + max_frames: int | None, +) -> None: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + + try: + validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config) + pipeline = ScoliosisPipeline( + source=source, + checkpoint=checkpoint, + config=config, + device=device, + yolo_model=yolo_model, + window=window, + stride=stride, + nats_url=nats_url, + nats_subject=nats_subject, + max_frames=max_frames, + ) + raise SystemExit(pipeline.run()) + except ValueError as err: + click.echo(f"Error: {err}", err=True) + raise SystemExit(2) from err + except RuntimeError as err: + click.echo(f"Runtime error: {err}", err=True) + raise SystemExit(1) from err diff --git a/opengait/demo/preprocess.py b/opengait/demo/preprocess.py new file mode 100644 index 0000000..74b9df9 --- /dev/null +++ b/opengait/demo/preprocess.py @@ -0,0 +1,270 @@ +from collections.abc import Callable +import math +from typing import cast + +import cv2 +from beartype import beartype +import jaxtyping +from jaxtyping import Float, UInt8 +import numpy as np +from numpy import ndarray +from numpy.typing import NDArray + +SIL_HEIGHT = 64 +SIL_WIDTH = 44 +SIL_FULL_WIDTH = 64 +SIDE_CUT = 10 +MIN_MASK_AREA = 500 + 
+JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]] +JaxtypedFactory = Callable[..., JaxtypedDecorator] +jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped) + +UInt8Array = NDArray[np.uint8] +Float32Array = NDArray[np.float32] + + +def _read_attr(container: object, key: str) -> object | None: + if isinstance(container, dict): + dict_obj = cast(dict[object, object], container) + return dict_obj.get(key) + try: + return cast(object, object.__getattribute__(container, key)) + except AttributeError: + return None + + +def _to_numpy_array(value: object) -> NDArray[np.generic]: + current: object = value + if isinstance(current, np.ndarray): + return current + + detach_obj = _read_attr(current, "detach") + if callable(detach_obj): + detach_fn = cast(Callable[[], object], detach_obj) + current = detach_fn() + + cpu_obj = _read_attr(current, "cpu") + if callable(cpu_obj): + cpu_fn = cast(Callable[[], object], cpu_obj) + current = cpu_fn() + + numpy_obj = _read_attr(current, "numpy") + if callable(numpy_obj): + numpy_fn = cast(Callable[[], object], numpy_obj) + as_numpy = numpy_fn() + if isinstance(as_numpy, np.ndarray): + return as_numpy + + return cast(NDArray[np.generic], np.asarray(current)) + + +def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> tuple[int, int, int, int] | None: + mask_u8 = np.asarray(mask, dtype=np.uint8) + coords = np.argwhere(mask_u8 > 0) + if int(coords.size) == 0: + return None + + ys = coords[:, 0].astype(np.int64) + xs = coords[:, 1].astype(np.int64) + x1 = int(np.min(xs)) + x2 = int(np.max(xs)) + 1 + y1 = int(np.min(ys)) + y2 = int(np.max(ys)) + 1 + if x2 <= x1 or y2 <= y1: + return None + return (x1, y1, x2, y2) + + +def _sanitize_bbox( + bbox: tuple[int, int, int, int], height: int, width: int +) -> tuple[int, int, int, int] | None: + x1, y1, x2, y2 = bbox + x1c = max(0, min(int(x1), width - 1)) + y1c = max(0, min(int(y1), height - 1)) + x2c = max(0, min(int(x2), width)) + y2c = max(0, min(int(y2), height)) + if x2c 
<= x1c or y2c <= y1c: + return None + return (x1c, y1c, x2c, y2c) + + +@jaxtyped(typechecker=beartype) +def frame_to_person_mask( + result: object, min_area: int = MIN_MASK_AREA +) -> tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None: + masks_obj = _read_attr(result, "masks") + if masks_obj is None: + return None + + masks_data_obj = _read_attr(masks_obj, "data") + if masks_data_obj is None: + return None + + masks_raw = _to_numpy_array(masks_data_obj) + masks_float = np.asarray(masks_raw, dtype=np.float32) + if masks_float.ndim == 2: + masks_float = masks_float[np.newaxis, ...] + if masks_float.ndim != 3: + return None + mask_count = int(cast(tuple[int, int, int], masks_float.shape)[0]) + if mask_count <= 0: + return None + + box_values: list[tuple[float, float, float, float]] | None = None + boxes_obj = _read_attr(result, "boxes") + if boxes_obj is not None: + xyxy_obj = _read_attr(boxes_obj, "xyxy") + if xyxy_obj is not None: + xyxy_raw = np.asarray(_to_numpy_array(xyxy_obj), dtype=np.float32) + if xyxy_raw.ndim == 1 and int(xyxy_raw.size) >= 4: + xyxy_2d = np.asarray(xyxy_raw[:4].reshape(1, 4), dtype=np.float64) + x1f = cast(np.float64, xyxy_2d[0, 0]) + y1f = cast(np.float64, xyxy_2d[0, 1]) + x2f = cast(np.float64, xyxy_2d[0, 2]) + y2f = cast(np.float64, xyxy_2d[0, 3]) + box_values = [ + ( + float(x1f), + float(y1f), + float(x2f), + float(y2f), + ) + ] + elif xyxy_raw.ndim == 2: + shape_2d = cast(tuple[int, int], xyxy_raw.shape) + if int(shape_2d[1]) >= 4: + xyxy_2d = np.asarray(xyxy_raw[:, :4], dtype=np.float64) + box_values = [] + for row_idx in range(int(cast(tuple[int, int], xyxy_2d.shape)[0])): + x1f = cast(np.float64, xyxy_2d[row_idx, 0]) + y1f = cast(np.float64, xyxy_2d[row_idx, 1]) + x2f = cast(np.float64, xyxy_2d[row_idx, 2]) + y2f = cast(np.float64, xyxy_2d[row_idx, 3]) + box_values.append( + ( + float(x1f), + float(y1f), + float(x2f), + float(y2f), + ) + ) + + best_area = -1 + best_mask: UInt8[ndarray, "h w"] | None = None + best_bbox: 
tuple[int, int, int, int] | None = None + + for idx in range(mask_count): + mask_float = np.asarray(masks_float[idx], dtype=np.float32) + if mask_float.ndim != 2: + continue + mask_binary = np.where(mask_float > 0.5, np.uint8(255), np.uint8(0)).astype( + np.uint8 + ) + mask_u8 = cast(UInt8[ndarray, "h w"], mask_binary) + + area = int(np.count_nonzero(mask_u8)) + if area < min_area: + continue + + bbox: tuple[int, int, int, int] | None = None + shape_2d = cast(tuple[int, int], mask_binary.shape) + h = int(shape_2d[0]) + w = int(shape_2d[1]) + if box_values is not None: + box_count = len(box_values) + if idx >= box_count: + continue + row0, row1, row2, row3 = box_values[idx] + bbox_candidate = ( + int(math.floor(row0)), + int(math.floor(row1)), + int(math.ceil(row2)), + int(math.ceil(row3)), + ) + bbox = _sanitize_bbox(bbox_candidate, h, w) + + if bbox is None: + bbox = _bbox_from_mask(mask_u8) + + if bbox is None: + continue + + if area > best_area: + best_area = area + best_mask = mask_u8 + best_bbox = bbox + + if best_mask is None or best_bbox is None: + return None + + return best_mask, best_bbox + + +@jaxtyped(typechecker=beartype) +def mask_to_silhouette( + mask: UInt8[ndarray, "h w"], + bbox: tuple[int, int, int, int], +) -> Float[ndarray, "64 44"] | None: + mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8) + if int(np.count_nonzero(mask_u8)) < MIN_MASK_AREA: + return None + + mask_shape = cast(tuple[int, int], mask_u8.shape) + h = int(mask_shape[0]) + w = int(mask_shape[1]) + bbox_sanitized = _sanitize_bbox(bbox, h, w) + if bbox_sanitized is None: + return None + + x1, y1, x2, y2 = bbox_sanitized + cropped = mask_u8[y1:y2, x1:x2] + if cropped.size == 0: + return None + + cropped_u8 = np.asarray(cropped, dtype=np.uint8) + row_sums = np.sum(cropped_u8, axis=1, dtype=np.int64) + row_nonzero = np.nonzero(row_sums > 0)[0].astype(np.int64) + if int(row_nonzero.size) == 0: + return None + top = int(cast(np.int64, row_nonzero[0])) + bottom = 
int(cast(np.int64, row_nonzero[-1])) + 1 + tightened = cropped[top:bottom, :] + if tightened.size == 0: + return None + + tight_shape = cast(tuple[int, int], tightened.shape) + tight_h = int(tight_shape[0]) + tight_w = int(tight_shape[1]) + if tight_h <= 0 or tight_w <= 0: + return None + + resized_w = max(1, int(SIL_HEIGHT * (tight_w / tight_h))) + resized = np.asarray( + cv2.resize(tightened, (resized_w, SIL_HEIGHT), interpolation=cv2.INTER_CUBIC), + dtype=np.uint8, + ) + + if resized_w >= SIL_FULL_WIDTH: + start = (resized_w - SIL_FULL_WIDTH) // 2 + normalized_64 = resized[:, start : start + SIL_FULL_WIDTH] + else: + pad_left = (SIL_FULL_WIDTH - resized_w) // 2 + pad_right = SIL_FULL_WIDTH - resized_w - pad_left + normalized_64 = np.pad( + resized, + ((0, 0), (pad_left, pad_right)), + mode="constant", + constant_values=0, + ) + + silhouette = np.asarray( + normalized_64[:, SIDE_CUT : SIL_FULL_WIDTH - SIDE_CUT], dtype=np.float32 + ) + if silhouette.shape != (SIL_HEIGHT, SIL_WIDTH): + return None + + silhouette_norm = np.clip(silhouette / np.float32(255.0), 0.0, 1.0).astype( + np.float32 + ) + return cast(Float[ndarray, "64 44"], silhouette_norm) diff --git a/opengait/demo/sconet_demo.py b/opengait/demo/sconet_demo.py new file mode 100644 index 0000000..377649c --- /dev/null +++ b/opengait/demo/sconet_demo.py @@ -0,0 +1,317 @@ +from __future__ import annotations + +from collections.abc import Callable +from pathlib import Path +import sys +from typing import ClassVar, Protocol, cast, override + +import torch +import torch.nn as nn +from beartype import beartype +from einops import rearrange +from jaxtyping import Float +import jaxtyping +from torch import Tensor + +_OPENGAIT_PACKAGE_ROOT = Path(__file__).resolve().parents[1] +if str(_OPENGAIT_PACKAGE_ROOT) not in sys.path: + sys.path.insert(0, str(_OPENGAIT_PACKAGE_ROOT)) + +from opengait.modeling.backbones.resnet import ResNet9 +from opengait.modeling.modules import ( + HorizontalPoolingPyramid, + 
    PackSequenceWrapper as TemporalPool,
    SeparateBNNecks,
    SeparateFCs,
)
from opengait.utils import common as common_utils


# Typed views over dynamically-typed OpenGait helpers so the rest of this
# module type-checks cleanly.
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
ConfigLoader = Callable[[str], dict[str, object]]
config_loader = cast(ConfigLoader, common_utils.config_loader)


class TemporalPoolLike(Protocol):
    """Structural type of PackSequenceWrapper's call interface."""

    def __call__(
        self,
        seqs: Tensor,
        seqL: object,
        dim: int = 2,
        options: dict[str, int] | None = None,
    ) -> object: ...


class HppLike(Protocol):
    """Structural type of HorizontalPoolingPyramid's call interface."""

    def __call__(self, x: Tensor) -> Tensor: ...


class FCsLike(Protocol):
    """Structural type of SeparateFCs' call interface."""

    def __call__(self, x: Tensor) -> Tensor: ...


class BNNecksLike(Protocol):
    """Structural type of SeparateBNNecks' call interface (features, logits)."""

    def __call__(self, x: Tensor) -> tuple[Tensor, Tensor]: ...


class ScoNetDemo(nn.Module):
    """Standalone inference wrapper for the ScoNet scoliosis classifier.

    Rebuilds the model from an OpenGait YAML config (ResNet9 backbone +
    horizontal pooling pyramid + SeparateFCs + SeparateBNNecks), loads a
    training checkpoint with key remapping, and exposes forward/predict
    for 64x44 silhouette sequences.
    """

    # Class-id -> human-readable label for predict(); assumes the checkpoint
    # was trained with this 3-class ordering — TODO confirm against training cfg.
    LABEL_MAP: ClassVar[dict[int, str]] = {0: "negative", 1: "neutral", 2: "positive"}
    cfg_path: str
    cfg: dict[str, object]
    backbone: ResNet9
    temporal_pool: TemporalPoolLike
    hpp: HppLike
    fcs: FCsLike
    bn_necks: BNNecksLike
    device: torch.device

    @jaxtyped(typechecker=beartype)
    def __init__(
        self,
        cfg_path: str | Path = "configs/sconet/sconet_scoliosis1k.yaml",
        checkpoint_path: str | Path | None = None,
        device: str | torch.device | None = None,
    ) -> None:
        """Build the demo model from an OpenGait YAML config.

        Args:
            cfg_path: Config path; resolved against the repo root when it is
                not an existing file or an absolute path (see _resolve_path).
            checkpoint_path: Optional checkpoint to load immediately.
            device: Target device; defaults to CPU when None.

        Raises:
            ValueError: If the configured backbone type is not ResNet9.
            TypeError: If a required config entry is missing or mistyped.
        """
        super().__init__()
        resolved_cfg = self._resolve_path(cfg_path)
        self.cfg_path = str(resolved_cfg)

        self.cfg = config_loader(self.cfg_path)

        model_cfg = self._extract_model_cfg(self.cfg)
        backbone_cfg = self._extract_dict(model_cfg, "backbone_cfg")

        if backbone_cfg.get("type") != "ResNet9":
            raise ValueError(
                "ScoNetDemo currently supports backbone type ResNet9 only."
            )

        self.backbone = ResNet9(
            block=self._extract_str(backbone_cfg, "block"),
            channels=self._extract_int_list(backbone_cfg, "channels"),
            in_channel=self._extract_int(backbone_cfg, "in_channel", default=1),
            layers=self._extract_int_list(backbone_cfg, "layers"),
            strides=self._extract_int_list(backbone_cfg, "strides"),
            maxpool=self._extract_bool(backbone_cfg, "maxpool", default=True),
        )

        fcs_cfg = self._extract_dict(model_cfg, "SeparateFCs")
        bn_cfg = self._extract_dict(model_cfg, "SeparateBNNecks")
        bin_num = self._extract_int_list(model_cfg, "bin_num")

        # Max-pool over the temporal dimension (dim=2 at call time).
        self.temporal_pool = cast(TemporalPoolLike, TemporalPool(torch.max))
        self.hpp = cast(HppLike, HorizontalPoolingPyramid(bin_num=bin_num))
        self.fcs = cast(
            FCsLike,
            SeparateFCs(
                parts_num=self._extract_int(fcs_cfg, "parts_num"),
                in_channels=self._extract_int(fcs_cfg, "in_channels"),
                out_channels=self._extract_int(fcs_cfg, "out_channels"),
                norm=self._extract_bool(fcs_cfg, "norm", default=False),
            ),
        )
        self.bn_necks = cast(
            BNNecksLike,
            SeparateBNNecks(
                parts_num=self._extract_int(bn_cfg, "parts_num"),
                in_channels=self._extract_int(bn_cfg, "in_channels"),
                class_num=self._extract_int(bn_cfg, "class_num"),
                norm=self._extract_bool(bn_cfg, "norm", default=True),
                parallel_BN1d=self._extract_bool(bn_cfg, "parallel_BN1d", default=True),
            ),
        )

        self.device = (
            torch.device(device) if device is not None else torch.device("cpu")
        )
        _ = self.to(self.device)

        if checkpoint_path is not None:
            _ = self.load_checkpoint(checkpoint_path)

        # Demo is inference-only; start in eval mode.
        _ = self.eval()

    @staticmethod
    def _resolve_path(path: str | Path) -> Path:
        """Return path as-is if it exists or is absolute, else join it to the repo root."""
        candidate = Path(path)
        if candidate.is_file():
            return candidate
        if candidate.is_absolute():
            return candidate
        # Repo root is two levels above this file (opengait/demo/ -> repo).
        repo_root = Path(__file__).resolve().parents[2]
        return repo_root / candidate

    @staticmethod
    def _extract_model_cfg(cfg: dict[str, object]) -> dict[str, object]:
        """Return cfg['model_cfg'], enforcing that it is a dict."""
        model_cfg_obj = cfg.get("model_cfg")
        if not isinstance(model_cfg_obj, dict):
            raise TypeError("model_cfg must be a dictionary.")
        return cast(dict[str, object], model_cfg_obj)

    @staticmethod
    def _extract_dict(container: dict[str, object], key: str) -> dict[str, object]:
        """Return container[key] as a dict or raise TypeError."""
        value = container.get(key)
        if not isinstance(value, dict):
            raise TypeError(f"{key} must be a dictionary.")
        return cast(dict[str, object], value)

    @staticmethod
    def _extract_str(container: dict[str, object], key: str) -> str:
        """Return container[key] as a str or raise TypeError."""
        value = container.get(key)
        if not isinstance(value, str):
            raise TypeError(f"{key} must be a string.")
        return value

    @staticmethod
    def _extract_int(
        container: dict[str, object], key: str, default: int | None = None
    ) -> int:
        """Return container[key] (or default) as an int or raise TypeError."""
        value = container.get(key, default)
        if not isinstance(value, int):
            raise TypeError(f"{key} must be an int.")
        return value

    @staticmethod
    def _extract_bool(
        container: dict[str, object], key: str, default: bool | None = None
    ) -> bool:
        """Return container[key] (or default) as a bool or raise TypeError."""
        value = container.get(key, default)
        if not isinstance(value, bool):
            raise TypeError(f"{key} must be a bool.")
        return value

    @staticmethod
    def _extract_int_list(container: dict[str, object], key: str) -> list[int]:
        """Return container[key] as a list[int] or raise TypeError."""
        value = container.get(key)
        if not isinstance(value, list):
            raise TypeError(f"{key} must be a list[int].")
        values = cast(list[object], value)
        if not all(isinstance(v, int) for v in values):
            raise TypeError(f"{key} must be a list[int].")
        return cast(list[int], values)

    @staticmethod
    def _normalize_state_dict(
        state_dict_obj: dict[object, object],
    ) -> dict[str, Tensor]:
        """Remap checkpoint keys onto this module's attribute names.

        Strips a leading 'module.' prefix (DataParallel-style checkpoints),
        then rewrites the training-time prefixes (Backbone.forward_block.,
        FCs., BNNecks.) to the demo attribute names.

        Raises:
            TypeError: On non-str keys or non-Tensor values.
            RuntimeError: If two source keys normalize to the same target key.
        """
        prefix_remap: tuple[tuple[str, str], ...] = (
            ("Backbone.forward_block.", "backbone."),
            ("FCs.", "fcs."),
            ("BNNecks.", "bn_necks."),
        )
        cleaned_state_dict: dict[str, Tensor] = {}
        for key_obj, value_obj in state_dict_obj.items():
            if not isinstance(key_obj, str):
                raise TypeError("Checkpoint state_dict keys must be strings.")
            if not isinstance(value_obj, Tensor):
                raise TypeError("Checkpoint state_dict values must be torch.Tensor.")
            key = key_obj[7:] if key_obj.startswith("module.") else key_obj
            for source_prefix, target_prefix in prefix_remap:
                if key.startswith(source_prefix):
                    key = f"{target_prefix}{key[len(source_prefix) :]}"
                    break
            if key in cleaned_state_dict:
                raise RuntimeError(
                    f"Checkpoint key normalization collision detected for key '{key}'."
                )
            cleaned_state_dict[key] = value_obj
        return cleaned_state_dict

    @jaxtyped(typechecker=beartype)
    def load_checkpoint(
        self,
        checkpoint_path: str | Path,
        map_location: str | torch.device | None = None,
        strict: bool = True,
    ) -> None:
        """Load an OpenGait checkpoint into this module.

        Accepts either a raw state_dict or a training checkpoint dict with a
        'model' entry. Keys are normalized via _normalize_state_dict before
        load_state_dict. Leaves the module in eval mode.

        Raises:
            TypeError: If the checkpoint is not a dict-like state_dict.
            RuntimeError: If load_state_dict fails after key normalization.
        """
        resolved_ckpt = self._resolve_path(checkpoint_path)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints (consider weights_only=True where available).
        checkpoint_obj = cast(
            object,
            torch.load(
                str(resolved_ckpt),
                map_location=map_location if map_location is not None else self.device,
            ),
        )

        state_dict_obj: object = checkpoint_obj
        if isinstance(checkpoint_obj, dict) and "model" in checkpoint_obj:
            state_dict_obj = cast(dict[str, object], checkpoint_obj)["model"]

        if not isinstance(state_dict_obj, dict):
            raise TypeError("Unsupported checkpoint format.")

        cleaned_state_dict = self._normalize_state_dict(
            cast(dict[object, object], state_dict_obj)
        )
        try:
            _ = self.load_state_dict(cleaned_state_dict, strict=strict)
        except RuntimeError as exc:
            raise RuntimeError(
                f"Failed to load ScoNetDemo checkpoint after key normalization from '{resolved_ckpt}'."
            ) from exc
        _ = self.eval()

    def _prepare_sils(self, sils: Tensor) -> Tensor:
        """Coerce input to [B, 1, S, H, W] float on self.device.

        Accepts [B, S, H, W] (channel dim inserted) or [B, S, 1, H, W]
        (channel moved in front of time via rearrange).
        """
        if sils.ndim == 4:
            sils = sils.unsqueeze(1)
        elif sils.ndim == 5 and sils.shape[1] != 1 and sils.shape[2] == 1:
            sils = rearrange(sils, "b s c h w -> b c s h w")

        if sils.ndim != 5 or sils.shape[1] != 1:
            raise ValueError("Expected sils shape [B, 1, S, H, W] or [B, S, H, W].")

        return sils.float().to(self.device)

    def _forward_backbone(self, sils: Tensor) -> Tensor:
        """Run the backbone per frame and restore [B, C', S, H', W'] layout."""
        batch, channels, seq, height, width = sils.shape
        # Fold time into the batch dim so the 2D backbone sees single frames.
        framewise = sils.transpose(1, 2).reshape(batch * seq, channels, height, width)
        frame_feats = cast(Tensor, self.backbone(framewise))
        _, out_channels, out_h, out_w = frame_feats.shape
        return (
            frame_feats.reshape(batch, seq, out_channels, out_h, out_w)
            .transpose(1, 2)
            .contiguous()
        )

    @override
    @jaxtyped(typechecker=beartype)
    def forward(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> dict[str, Tensor]:
        """Classify silhouette sequences.

        Returns a dict with:
            logits: raw per-part logits from SeparateBNNecks,
            label: argmax class id per sample (logits averaged over the
                last, presumably part, dimension — TODO confirm),
            confidence: softmax probability of the predicted class.
        """
        with torch.inference_mode():
            prepared_sils = self._prepare_sils(sils)
            outs = self._forward_backbone(prepared_sils)

            # Temporal max-pool over the sequence dim (dim=2).
            pooled_obj = self.temporal_pool(outs, None, options={"dim": 2})
            if (
                not isinstance(pooled_obj, tuple)
                or not pooled_obj
                or not isinstance(pooled_obj[0], Tensor)
            ):
                raise TypeError("TemporalPool output is invalid.")
            pooled = pooled_obj[0]

            feat = self.hpp(pooled)
            embed_1 = self.fcs(feat)
            _, logits = self.bn_necks(embed_1)

            mean_logits = logits.mean(dim=-1)
            pred_ids = torch.argmax(mean_logits, dim=-1)
            probs = torch.softmax(mean_logits, dim=-1)
            confidence = torch.gather(
                probs, dim=-1, index=pred_ids.unsqueeze(-1)
            ).squeeze(-1)

            return {"logits": logits, "label": pred_ids, "confidence": confidence}

    @jaxtyped(typechecker=beartype)
    def predict(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> tuple[str, float]:
        """Return (label_name, confidence) for a single-sample batch.

        Raises:
            ValueError: If the batch size is not exactly 1.
        """
        outputs = cast(dict[str, Tensor], self.forward(sils))
        labels = outputs["label"]
        confidence = outputs["confidence"]

        if labels.numel() != 1:
            raise ValueError("predict expects batch size 1.")

        label_id = int(labels.item())
        return self.LABEL_MAP[label_id], float(confidence.item())
diff --git a/opengait/demo/window.py b/opengait/demo/window.py
new file mode 100644
index 0000000..c93eda3
--- /dev/null
+++ b/opengait/demo/window.py
@@ -0,0 +1,295 @@
"""Sliding window / ring buffer manager for real-time gait analysis.

This module provides bounded buffer management for silhouette sequences
with track ID tracking and gap detection.
"""

from collections import deque
from typing import TYPE_CHECKING, Protocol, final

import numpy as np
import torch
from jaxtyping import Float
from numpy import ndarray

if TYPE_CHECKING:
    from numpy.typing import NDArray


# Silhouette dimensions from preprocess.py
SIL_HEIGHT: int = 64
SIL_WIDTH: int = 44


class _Boxes(Protocol):
    """Protocol for boxes with xyxy and id attributes."""

    @property
    def xyxy(self) -> "NDArray[np.float32] | object": ...
    @property
    def id(self) -> "NDArray[np.int64] | object | None": ...


class _Masks(Protocol):
    """Protocol for masks with data attribute."""

    @property
    def data(self) -> "NDArray[np.float32] | object": ...


class _DetectionResults(Protocol):
    """Protocol for detection results from Ultralytics-style objects."""

    @property
    def boxes(self) -> _Boxes: ...
    @property
    def masks(self) -> _Masks: ...


@final
class SilhouetteWindow:
    """Bounded sliding window for silhouette sequences.

    Manages a fixed-size buffer of silhouettes with track ID tracking
    and automatic reset on track changes or frame gaps.

    Attributes:
        window_size: Maximum number of frames in the buffer.
        stride: Classification stride (frames between classifications).
        gap_threshold: Maximum allowed frame gap before reset.
    """

    window_size: int
    stride: int
    gap_threshold: int
    _buffer: deque[Float[ndarray, "64 44"]]
    _frame_indices: deque[int]
    _track_id: int | None
    _last_classified_frame: int
    _frame_count: int

    def __init__(
        self,
        window_size: int = 30,
        stride: int = 1,
        gap_threshold: int = 15,
    ) -> None:
        """Initialize the silhouette window.

        Args:
            window_size: Maximum buffer size (default 30).
            stride: Frames between classifications (default 1).
            gap_threshold: Max frame gap before reset (default 15).
        """
        self.window_size = window_size
        self.stride = stride
        self.gap_threshold = gap_threshold

        # Bounded storage via deque; maxlen makes appends drop the oldest
        # frame automatically once the window is full.
        self._buffer = deque(maxlen=window_size)
        self._frame_indices = deque(maxlen=window_size)
        self._track_id = None
        self._last_classified_frame = -1
        self._frame_count = 0

    def push(self, sil: np.ndarray, frame_idx: int, track_id: int) -> None:
        """Push a new silhouette into the window.

        Automatically resets buffer on track ID change or frame gap
        exceeding gap_threshold.

        Args:
            sil: Silhouette array of shape (64, 44), float32.
            frame_idx: Current frame index for gap detection.
            track_id: Track ID for the person.
        """
        # Check for track ID change
        if self._track_id is not None and track_id != self._track_id:
            self.reset()

        # Check for frame gap
        if self._frame_indices:
            last_frame = self._frame_indices[-1]
            gap = frame_idx - last_frame
            if gap > self.gap_threshold:
                self.reset()

        # Update track ID
        self._track_id = track_id

        # Validate and append silhouette
        sil_array = np.asarray(sil, dtype=np.float32)
        if sil_array.shape != (SIL_HEIGHT, SIL_WIDTH):
            raise ValueError(
                f"Expected silhouette shape ({SIL_HEIGHT}, {SIL_WIDTH}), got {sil_array.shape}"
            )

        self._buffer.append(sil_array)
        self._frame_indices.append(frame_idx)
        self._frame_count += 1

    def is_ready(self) -> bool:
        """Check if window has enough frames for classification.

        Returns:
            True if buffer is full (window_size frames).
        """
        return len(self._buffer) >= self.window_size

    def should_classify(self) -> bool:
        """Check if classification should run based on stride.

        Returns:
            True if enough frames have passed since last classification.
        """
        if not self.is_ready():
            return False

        # First classification fires as soon as the window is full.
        if self._last_classified_frame < 0:
            return True

        current_frame = self._frame_indices[-1]
        frames_since = current_frame - self._last_classified_frame
        return frames_since >= self.stride

    def get_tensor(self, device: str = "cpu") -> torch.Tensor:
        """Get window contents as a tensor for model input.

        Args:
            device: Target device for the tensor (default 'cpu').

        Returns:
            Tensor of shape [1, 1, window_size, 64, 44] with dtype float32.

        Raises:
            ValueError: If buffer is not full.
        """
        if not self.is_ready():
            raise ValueError(
                f"Window not ready: {len(self._buffer)}/{self.window_size} frames"
            )

        # Stack buffer into array [window_size, 64, 44]
        stacked = np.stack(list(self._buffer), axis=0)

        # Add batch and channel dims: [1, 1, window_size, 64, 44]
        tensor = torch.from_numpy(stacked.astype(np.float32))
        tensor = tensor.unsqueeze(0).unsqueeze(0)

        return tensor.to(device)

    def reset(self) -> None:
        """Reset the window, clearing all buffers and counters."""
        self._buffer.clear()
        self._frame_indices.clear()
        self._track_id = None
        self._last_classified_frame = -1
        self._frame_count = 0

    def mark_classified(self) -> None:
        """Mark current frame as classified, updating stride tracking."""
        if self._frame_indices:
            self._last_classified_frame = self._frame_indices[-1]

    @property
    def current_track_id(self) -> int | None:
        """Current track ID, or None if buffer is empty."""
        return self._track_id

    @property
    def frame_count(self) -> int:
        """Total frames pushed since last reset."""
        return self._frame_count

    @property
    def fill_level(self) -> float:
        """Fill ratio of the buffer (0.0 to 1.0)."""
        return len(self._buffer) / self.window_size


def select_person(
    results: _DetectionResults,
) -> tuple[ndarray, tuple[int, int, int, int], int] | None:
    """Select the person with largest bounding box from detection results.

    Args:
        results: Detection results object with boxes and masks attributes.
            Expected to have:
            - boxes.xyxy: array of bounding boxes [N, 4]
            - masks.data: array of masks [N, H, W]
            - boxes.id: optional track IDs [N]

    Returns:
        Tuple of (mask, bbox, track_id) for the largest person,
        or None if no valid detections or track IDs unavailable.
    """
    # Check for track IDs; without them we cannot attribute the detection
    # to a person, so bail out early.
    boxes_obj: _Boxes | object = getattr(results, "boxes", None)
    if boxes_obj is None:
        return None

    track_ids_obj: ndarray | object | None = getattr(boxes_obj, "id", None)
    if track_ids_obj is None:
        return None

    track_ids: ndarray = np.asarray(track_ids_obj)
    if track_ids.size == 0:
        return None

    # Get bounding boxes
    xyxy_obj: ndarray | object = getattr(boxes_obj, "xyxy", None)
    if xyxy_obj is None:
        return None

    bboxes: ndarray = np.asarray(xyxy_obj)
    if bboxes.ndim == 1:
        bboxes = bboxes.reshape(1, -1)

    if bboxes.shape[0] == 0:
        return None

    # Get masks
    masks_obj: _Masks | object = getattr(results, "masks", None)
    if masks_obj is None:
        return None

    masks_data: ndarray | object = getattr(masks_obj, "data", None)
    if masks_data is None:
        return None

    masks: ndarray = np.asarray(masks_data)
    if masks.ndim == 2:
        masks = masks[np.newaxis, ...]

    # Masks and boxes must describe the same detections, one-to-one.
    if masks.shape[0] != bboxes.shape[0]:
        return None

    # Find largest bbox by area
    best_idx: int = -1
    best_area: float = -1.0

    for i in range(int(bboxes.shape[0])):
        row: "NDArray[np.float32]" = bboxes[i][:4]
        x1f: float = float(row[0])
        y1f: float = float(row[1])
        x2f: float = float(row[2])
        y2f: float = float(row[3])
        area: float = (x2f - x1f) * (y2f - y1f)
        if area > best_area:
            best_area = area
            best_idx = i

    if best_idx < 0:
        return None

    # Extract mask and bbox (bbox coordinates truncated to ints).
    mask: "NDArray[np.float32]" = masks[best_idx]
    bbox = (
        int(float(bboxes[best_idx][0])),
        int(float(bboxes[best_idx][1])),
        int(float(bboxes[best_idx][2])),
        int(float(bboxes[best_idx][3])),
    )
    # Fall back to the detection index when the tracker produced fewer ids
    # than detections.
    track_id = int(track_ids[best_idx]) if best_idx < len(track_ids) else best_idx

    return mask, bbox, track_id