feat(demo): implement ScoNet real-time pipeline runtime

Add the full demo runtime stack for single-person scoliosis inference, including input adapters, silhouette preprocessing, temporal windowing, ScoNet wrapper, result publishing, and click-based CLI orchestration. This commit captures the executable pipeline behavior independently from tests and planning artifacts for clearer review and rollback.
This commit is contained in:
2026-02-27 09:59:04 +08:00
parent cd754ffcfb
commit b24644f16e
8 changed files with 1785 additions and 0 deletions
View File
+7
View File
@@ -0,0 +1,7 @@
"""Package entry point: runs the demo pipeline when executed as a module."""
from __future__ import annotations
from .pipeline import main
if __name__ == "__main__":
    main()
+203
View File
@@ -0,0 +1,203 @@
"""
Input adapters for OpenGait demo.
Provides generator-based interfaces for video sources:
- OpenCV (video files, cameras)
- cv-mmap (shared memory streams)
"""
from collections.abc import AsyncIterator, Generator, Iterable
from typing import TYPE_CHECKING, Protocol, cast
import logging
import numpy as np
logger = logging.getLogger(__name__)
# Type alias for frame stream: (frame_array, metadata_dict)
FrameStream = Iterable[tuple[np.ndarray, dict[str, object]]]
if TYPE_CHECKING:
    # Protocol for cv-mmap metadata to avoid direct import
    class _FrameMetadata(Protocol):
        # Frame index assigned by the cv-mmap producer.
        frame_count: int
        # Producer-side timestamp in nanoseconds.
        timestamp_ns: int

    # Protocol for cv-mmap client
    class _CvMmapClient(Protocol):
        def __aiter__(self) -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]: ...
def opencv_source(
    path: str | int, max_frames: int | None = None
) -> Generator[tuple[np.ndarray, dict[str, object]], None, None]:
    """
    Generator that yields frames from an OpenCV video source.

    Parameters
    ----------
    path : str | int
        Video file path or camera index (e.g., 0 for default camera)
    max_frames : int | None, optional
        Maximum number of frames to yield. None means unlimited.

    Yields
    ------
    tuple[np.ndarray, dict[str, object]]
        (frame_array, metadata_dict) where metadata includes:
        - frame_count: frame index (0-based)
        - timestamp_ns: monotonic timestamp in nanoseconds
        - source: the path/int provided
    """
    import time

    import cv2

    capture = cv2.VideoCapture(path)
    if not capture.isOpened():
        raise RuntimeError(f"Failed to open video source: {path}")
    emitted = 0
    try:
        while True:
            if max_frames is not None and emitted >= max_frames:
                break
            ok, frame = capture.read()
            if not ok:
                break  # end of stream
            # Timestamps are taken from the monotonic clock at read time.
            yield frame, {
                "frame_count": emitted,
                "timestamp_ns": time.monotonic_ns(),
                "source": path,
            }
            emitted += 1
    finally:
        capture.release()
        logger.debug(f"OpenCV source closed: {path}")
def cvmmap_source(
name: str, max_frames: int | None = None
) -> Generator[tuple[np.ndarray, dict[str, object]], None, None]:
"""
Generator that yields frames from a cv-mmap shared memory stream.
Bridges async cv-mmap client to synchronous generator using asyncio.run().
Parameters
----------
name : str
Base name of the cv-mmap source (e.g., "default")
max_frames : int | None, optional
Maximum number of frames to yield. None means unlimited.
Yields
------
tuple[np.ndarray, dict[str, object]]
(frame_array, metadata_dict) where metadata includes:
- frame_count: frame index from cv-mmap
- timestamp_ns: timestamp in nanoseconds from cv-mmap
- source: the cv-mmap name
Raises
------
ImportError
If cvmmap package is not available
RuntimeError
If cv-mmap stream disconnects or errors
"""
import asyncio
# Import cvmmap only when function is called
# Use try/except for runtime import check
try:
from cvmmap import CvMmapClient as _CvMmapClientReal # pyright: ignore[reportMissingTypeStubs]
except ImportError as e:
raise ImportError(
"cvmmap package is required for cv-mmap sources. "
+ "Install from: https://github.com/crosstyan/cv-mmap"
) from e
# Cast to protocol type for type checking
client: _CvMmapClient = cast("_CvMmapClient", _CvMmapClientReal(name))
frame_count = 0
async def _async_generator() -> AsyncIterator[tuple[np.ndarray, _FrameMetadata]]:
"""Async generator wrapper."""
async for frame, meta in client:
yield frame, meta
# Bridge async to sync using asyncio.run()
# We process frames one at a time to keep it simple and robust
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
agen = _async_generator().__aiter__()
while max_frames is None or frame_count < max_frames:
try:
frame, meta = loop.run_until_complete(agen.__anext__())
except StopAsyncIteration:
break
metadata: dict[str, object] = {
"frame_count": meta.frame_count,
"timestamp_ns": meta.timestamp_ns,
"source": f"cvmmap://{name}",
}
yield frame, metadata
frame_count += 1
finally:
loop.close()
logger.debug(f"cv-mmap source closed: {name}")
def create_source(source: str, max_frames: int | None = None) -> FrameStream:
    """
    Factory function to create a frame source from a string specification.

    Parameters
    ----------
    source : str
        Source specification:
        - '0', '1', etc. -> Camera index (OpenCV)
        - 'cvmmap://name' -> cv-mmap shared memory stream
        - Any other string -> Video file path (OpenCV)
    max_frames : int | None, optional
        Maximum number of frames to yield. None means unlimited.

    Returns
    -------
    FrameStream
        Generator yielding (frame, metadata) tuples

    Examples
    --------
    >>> for frame, meta in create_source('0'):  # Camera 0
    ...     process(frame)
    >>> for frame, meta in create_source('cvmmap://default'):  # cv-mmap
    ...     process(frame)
    >>> for frame, meta in create_source('/path/to/video.mp4'):
    ...     process(frame)
    """
    prefix = "cvmmap://"
    # Shared-memory streams are addressed by the name after the scheme.
    if source.startswith(prefix):
        return cvmmap_source(source.removeprefix(prefix), max_frames)
    # A purely numeric spec addresses a local camera by index; anything
    # else is treated as a video file path.
    if source.isdigit():
        return opencv_source(int(source), max_frames)
    return opencv_source(source, max_frames)
+368
View File
@@ -0,0 +1,368 @@
"""
Output publishers for OpenGait demo results.
Provides pluggable result publishing:
- ConsolePublisher: JSONL to stdout
- NatsPublisher: NATS message broker integration
"""
from __future__ import annotations
import asyncio
import json
import logging
import sys
import threading
import time
from typing import TYPE_CHECKING, Protocol, TextIO, cast, runtime_checkable
if TYPE_CHECKING:
from types import TracebackType
logger = logging.getLogger(__name__)
@runtime_checkable
class ResultPublisher(Protocol):
    """Protocol for result publishers.

    Structural interface satisfied by ConsolePublisher and NatsPublisher;
    runtime_checkable so isinstance() checks against it are allowed.
    """
    def publish(self, result: dict[str, object]) -> None:
        """
        Publish a result dictionary.

        Parameters
        ----------
        result : dict[str, object]
            Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
        """
        ...
class ConsolePublisher:
    """Publisher that emits results as JSON Lines (one object per line)."""

    _output: TextIO

    def __init__(self, output: TextIO = sys.stdout) -> None:
        """
        Initialize console publisher.

        Parameters
        ----------
        output : TextIO
            File-like object to write to (default: sys.stdout)
        """
        self._output = output

    def publish(self, result: dict[str, object]) -> None:
        """
        Serialize *result* to a single JSON line and flush immediately.

        Parameters
        ----------
        result : dict[str, object]
            Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
        """
        try:
            encoded = json.dumps(result, ensure_ascii=False, default=str)
            _ = self._output.write(f"{encoded}\n")
            self._output.flush()
        except Exception as e:
            logger.warning(f"Failed to publish to console: {e}")

    def close(self) -> None:
        """Close the publisher (nothing to release for console output)."""

    def __enter__(self) -> ConsolePublisher:
        """Context manager entry; returns self."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Context manager exit; delegates to close()."""
        self.close()
class _NatsClient(Protocol):
    """Protocol for connected NATS client.

    Minimal structural view of the nats-py client so the real package is
    only imported at runtime inside NatsPublisher.
    """
    async def publish(self, subject: str, payload: bytes) -> object: ...
    async def close(self) -> object: ...
    async def flush(self) -> object: ...
class NatsPublisher:
    """
    Publisher that sends results to NATS message broker.

    This is a sync-friendly wrapper around the async nats-py client.
    Uses a background thread with dedicated event loop to bridge sync
    publish calls to async NATS operations, making it safe to use in
    both sync and async contexts.

    Connection is lazy: the first publish() connects, and failures
    degrade gracefully (the result is dropped with a log message
    rather than raising).
    """

    _nats_url: str
    _subject: str
    _nc: _NatsClient | None
    _connected: bool
    _loop: asyncio.AbstractEventLoop | None
    _thread: threading.Thread | None
    # RLock, not Lock: public methods hold the lock while calling private
    # helpers (_start_background_loop / _stop_background_loop) that acquire
    # it again in the same thread. A non-reentrant Lock would self-deadlock
    # on the very first publish().
    _lock: threading.RLock

    def __init__(self, nats_url: str, subject: str = "scoliosis.result") -> None:
        """
        Initialize NATS publisher.

        Parameters
        ----------
        nats_url : str
            NATS server URL (e.g., "nats://localhost:4222")
        subject : str
            NATS subject to publish to (default: "scoliosis.result")
        """
        self._nats_url = nats_url
        self._subject = subject
        self._nc = None
        self._connected = False
        self._loop = None
        self._thread = None
        self._lock = threading.RLock()

    def _start_background_loop(self) -> bool:
        """
        Start background thread with event loop for async operations.

        Idempotent: returns immediately when the loop is already running.

        Returns
        -------
        bool
            True if loop is running, False otherwise
        """
        with self._lock:
            if self._loop is not None and self._loop.is_running():
                return True
            try:
                loop = asyncio.new_event_loop()
                self._loop = loop

                def run_loop() -> None:
                    asyncio.set_event_loop(loop)
                    loop.run_forever()

                # Daemon thread so a forgotten close() never blocks exit.
                self._thread = threading.Thread(target=run_loop, daemon=True)
                self._thread.start()
                return True
            except Exception as e:
                logger.warning(f"Failed to start background event loop: {e}")
                return False

    def _stop_background_loop(self) -> None:
        """Stop the background event loop and thread."""
        with self._lock:
            if self._loop is not None and self._loop.is_running():
                _ = self._loop.call_soon_threadsafe(self._loop.stop)
            if self._thread is not None and self._thread.is_alive():
                self._thread.join(timeout=2.0)
            self._loop = None
            self._thread = None

    def _ensure_connected(self) -> bool:
        """
        Ensure connection to NATS server.

        Returns
        -------
        bool
            True if connected, False otherwise
        """
        with self._lock:
            if self._connected and self._nc is not None:
                return True
            # Reentrant acquisition: we already hold self._lock here.
            if not self._start_background_loop():
                return False
            try:
                import nats

                async def _connect() -> _NatsClient:
                    nc = await nats.connect(self._nats_url)  # pyright: ignore[reportUnknownMemberType]
                    return cast(_NatsClient, nc)

                # Run connection in background loop
                future = asyncio.run_coroutine_threadsafe(
                    _connect(),
                    self._loop,  # pyright: ignore[reportArgumentType]
                )
                self._nc = future.result(timeout=10.0)
                self._connected = True
                logger.info(f"Connected to NATS at {self._nats_url}")
                return True
            except ImportError:
                logger.warning(
                    "nats-py package not installed. Install with: pip install nats-py"
                )
                return False
            except Exception as e:
                logger.warning(f"Failed to connect to NATS at {self._nats_url}: {e}")
                return False

    def publish(self, result: dict[str, object]) -> None:
        """
        Publish result to NATS subject.

        Parameters
        ----------
        result : dict[str, object]
            Result data with keys: frame, track_id, label, confidence, window, timestamp_ns
        """
        if not self._ensure_connected():
            # Graceful degradation: log warning but don't crash
            logger.debug(
                f"NATS unavailable, dropping result: {result.get('track_id', 'unknown')}"
            )
            return
        try:

            async def _publish() -> None:
                if self._nc is not None:
                    payload = json.dumps(
                        result, ensure_ascii=False, default=str
                    ).encode("utf-8")
                    _ = await self._nc.publish(self._subject, payload)
                    _ = await self._nc.flush()

            # Run publish in background loop
            future = asyncio.run_coroutine_threadsafe(
                _publish(),
                self._loop,  # pyright: ignore[reportArgumentType]
            )
            future.result(timeout=5.0)  # Wait for publish to complete
        except Exception as e:
            logger.warning(f"Failed to publish to NATS: {e}")
            self._connected = False  # Mark for reconnection on next publish

    def close(self) -> None:
        """Close NATS connection and stop the background loop."""
        with self._lock:
            if self._nc is not None and self._connected and self._loop is not None:
                try:

                    async def _close() -> None:
                        if self._nc is not None:
                            _ = await self._nc.close()

                    future = asyncio.run_coroutine_threadsafe(
                        _close(),
                        self._loop,
                    )
                    future.result(timeout=5.0)
                except Exception as e:
                    logger.debug(f"Error closing NATS connection: {e}")
                finally:
                    self._nc = None
                    self._connected = False
            # Safe while holding the lock because _lock is reentrant.
            self._stop_background_loop()

    def __enter__(self) -> NatsPublisher:
        """Context manager entry."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Context manager exit."""
        self.close()
def create_publisher(
    nats_url: str | None,
    subject: str = "scoliosis.result",
) -> ResultPublisher:
    """
    Factory that picks the appropriate publisher implementation.

    Parameters
    ----------
    nats_url : str | None
        NATS server URL. Falsy values (None or "") select console output.
    subject : str
        NATS subject to publish to (default: "scoliosis.result")

    Returns
    -------
    ResultPublisher
        NatsPublisher when a NATS URL is given, otherwise ConsolePublisher

    Examples
    --------
    >>> # Console output (default)
    >>> pub = create_publisher(None)
    >>> pub.publish({"frame": 1, "track_id": 42, "label": "normal", "confidence": 0.95, "window": 30, "timestamp_ns": 1234567890})
    >>>
    >>> # NATS output
    >>> pub = create_publisher("nats://localhost:4222")
    >>> pub.publish({"frame": 1, "track_id": 42, "label": "normal", "confidence": 0.95, "window": 30, "timestamp_ns": 1234567890})
    >>>
    >>> # Context manager usage
    >>> with create_publisher("nats://localhost:4222") as pub:
    ...     pub.publish(result)
    """
    return NatsPublisher(nats_url, subject) if nats_url else ConsolePublisher()
def create_result(
    frame: int,
    track_id: int,
    label: str,
    confidence: float,
    window: int | tuple[int, int],
    timestamp_ns: int | None = None,
) -> dict[str, object]:
    """
    Create a standardized result dictionary.

    Parameters
    ----------
    frame : int
        Frame number
    track_id : int
        Track/person identifier
    label : str
        Classification label (e.g., "normal", "scoliosis")
    confidence : float
        Confidence score (0.0 to 1.0)
    window : int | tuple[int, int]
        Window that produced this result, either as the end frame (int)
        or as a (start, end) tuple; only the end frame is published.
    timestamp_ns : int | None
        Timestamp in nanoseconds. If None, uses the current monotonic
        clock (not wall-clock time).

    Returns
    -------
    dict[str, object]
        Standardized result dictionary with keys: frame, track_id, label,
        confidence, window (end frame as int), timestamp_ns
    """
    return {
        "frame": frame,
        "track_id": track_id,
        "label": label,
        "confidence": confidence,
        # A (start, end) tuple is collapsed to its end frame so the
        # published schema always carries a single int.
        "window": window if isinstance(window, int) else window[1],
        "timestamp_ns": timestamp_ns
        if timestamp_ns is not None
        else time.monotonic_ns(),
    }
+325
View File
@@ -0,0 +1,325 @@
from __future__ import annotations
from collections.abc import Callable
from contextlib import suppress
import logging
from pathlib import Path
import time
from typing import Protocol, cast
from beartype import beartype
import click
import jaxtyping
from jaxtyping import Float, UInt8
import numpy as np
from numpy import ndarray
from numpy.typing import NDArray
from ultralytics.models.yolo.model import YOLO
from .input import FrameStream, create_source
from .output import ResultPublisher, create_publisher, create_result
from .preprocess import frame_to_person_mask, mask_to_silhouette
from .sconet_demo import ScoNetDemo
from .window import SilhouetteWindow, select_person
logger = logging.getLogger(__name__)
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
class _BoxesLike(Protocol):
    """Structural view of a detector's boxes: xyxy coordinates and track ids."""
    @property
    def xyxy(self) -> NDArray[np.float32] | object: ...
    @property
    def id(self) -> NDArray[np.int64] | object | None: ...
class _MasksLike(Protocol):
    """Structural view of a detector's masks: the raw mask stack."""
    @property
    def data(self) -> NDArray[np.float32] | object: ...
class _DetectionResultsLike(Protocol):
    """Structural view of a single detection result (boxes + masks)."""
    @property
    def boxes(self) -> _BoxesLike: ...
    @property
    def masks(self) -> _MasksLike: ...
class _TrackCallable(Protocol):
    """Call signature of YOLO.track() as used by the pipeline."""
    def __call__(
        self,
        source: object,
        *,
        persist: bool = True,
        verbose: bool = False,
        device: str | None = None,
        classes: list[int] | None = None,
    ) -> object: ...
class ScoliosisPipeline:
    """End-to-end single-person scoliosis inference pipeline.

    Wires together: frame source -> YOLO person tracking/segmentation ->
    silhouette normalization -> temporal window buffering -> ScoNet
    classification -> result publishing.
    """

    # YOLO segmentation/tracking model (kept as object to avoid leaking
    # ultralytics types into the public surface).
    _detector: object
    # Iterable of (frame, metadata) tuples from create_source().
    _source: FrameStream
    # Sliding buffer of normalized silhouettes.
    _window: SilhouetteWindow
    # Destination for classification results (console or NATS).
    _publisher: ResultPublisher
    # ScoNet model wrapper used for window classification.
    _classifier: ScoNetDemo
    # Torch device string, e.g. "cuda:0" or "cpu".
    _device: str
    # Guards close() against double release.
    _closed: bool

    def __init__(
        self,
        *,
        source: str,
        checkpoint: str,
        config: str,
        device: str,
        yolo_model: str,
        window: int,
        stride: int,
        nats_url: str | None,
        nats_subject: str,
        max_frames: int | None,
    ) -> None:
        """Build all pipeline stages from CLI-level settings (keyword-only)."""
        self._detector = YOLO(yolo_model)
        self._source = create_source(source, max_frames=max_frames)
        self._window = SilhouetteWindow(window_size=window, stride=stride)
        self._publisher = create_publisher(nats_url=nats_url, subject=nats_subject)
        self._classifier = ScoNetDemo(
            cfg_path=config,
            checkpoint_path=checkpoint,
            device=device,
        )
        self._device = device
        self._closed = False

    @staticmethod
    def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
        """Read an int from metadata, returning *fallback* when absent or mistyped."""
        value = meta.get(key)
        if isinstance(value, int):
            return value
        return fallback

    @staticmethod
    def _extract_timestamp(meta: dict[str, object]) -> int:
        """Read timestamp_ns from metadata, defaulting to the monotonic clock."""
        value = meta.get("timestamp_ns")
        if isinstance(value, int):
            return value
        return time.monotonic_ns()

    @staticmethod
    def _to_mask_u8(mask: ndarray) -> UInt8[ndarray, "h w"]:
        """Binarize a soft/float mask into uint8 {0, 255} at a 0.5 threshold."""
        binary = np.where(np.asarray(mask) > 0.5, np.uint8(255), np.uint8(0)).astype(
            np.uint8
        )
        return cast(UInt8[ndarray, "h w"], binary)

    def _first_result(self, detections: object) -> _DetectionResultsLike | None:
        """Unwrap YOLO's return (list/tuple/single object) to the first result."""
        if isinstance(detections, list):
            return cast(_DetectionResultsLike, detections[0]) if detections else None
        if isinstance(detections, tuple):
            return cast(_DetectionResultsLike, detections[0]) if detections else None
        return cast(_DetectionResultsLike, detections)

    def _select_silhouette(
        self,
        result: _DetectionResultsLike,
    ) -> tuple[Float[ndarray, "64 44"], int] | None:
        """Pick one person from *result* and normalize to a 64x44 silhouette.

        Prefers the tracked person chosen by select_person(); falls back to
        frame_to_person_mask() with track id 0 when tracking yields nothing
        usable. Returns None when no silhouette can be produced.
        """
        selected = select_person(result)
        if selected is not None:
            mask_raw, bbox, track_id = selected
            silhouette = cast(
                Float[ndarray, "64 44"] | None,
                mask_to_silhouette(self._to_mask_u8(mask_raw), bbox),
            )
            if silhouette is not None:
                return silhouette, int(track_id)
        # Fallback path: untracked selection, reported with track id 0.
        fallback = cast(
            tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None,
            frame_to_person_mask(result),
        )
        if fallback is None:
            return None
        mask_u8, bbox = fallback
        silhouette = cast(
            Float[ndarray, "64 44"] | None,
            mask_to_silhouette(mask_u8, bbox),
        )
        if silhouette is None:
            return None
        return silhouette, 0

    @jaxtyped(typechecker=beartype)
    def process_frame(
        self,
        frame: UInt8[ndarray, "h w c"],
        metadata: dict[str, object],
    ) -> dict[str, object] | None:
        """Process a single (h, w, c) frame; returns the published result or None.

        None means "nothing to report yet": no detection, no usable
        silhouette, or the window is not yet due for classification.
        """
        frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
        timestamp_ns = self._extract_timestamp(metadata)
        track_fn_obj = getattr(self._detector, "track", None)
        if not callable(track_fn_obj):
            raise RuntimeError("YOLO detector does not expose a callable track()")
        track_fn = cast(_TrackCallable, track_fn_obj)
        # classes=[0] restricts detection to class 0 (person in COCO-trained
        # YOLO models); persist=True keeps track ids stable across frames.
        detections = track_fn(
            frame,
            persist=True,
            verbose=False,
            device=self._device,
            classes=[0],
        )
        first = self._first_result(detections)
        if first is None:
            return None
        selected = self._select_silhouette(first)
        if selected is None:
            return None
        silhouette, track_id = selected
        self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
        if not self._window.should_classify():
            return None
        window_tensor = self._window.get_tensor(device=self._device)
        label, confidence = cast(
            tuple[str, float],
            self._classifier.predict(window_tensor),
        )
        self._window.mark_classified()
        window_start = frame_idx - self._window.window_size + 1
        result = create_result(
            frame=frame_idx,
            track_id=track_id,
            label=label,
            confidence=float(confidence),
            window=(max(0, window_start), frame_idx),
            timestamp_ns=timestamp_ns,
        )
        self._publisher.publish(result)
        return result

    def run(self) -> int:
        """Consume the source until exhaustion; returns a process exit code.

        0 on normal completion, 130 on Ctrl-C. Per-frame failures are
        logged and skipped so one bad frame cannot kill the stream.
        """
        frame_count = 0
        start_time = time.perf_counter()
        try:
            for item in self._source:
                frame, metadata = item
                frame_u8 = np.asarray(frame, dtype=np.uint8)
                frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
                frame_count += 1
                try:
                    _ = self.process_frame(frame_u8, metadata)
                except Exception as frame_error:
                    logger.warning(
                        "Skipping frame %d due to processing error: %s",
                        frame_idx,
                        frame_error,
                    )
                # Periodic throughput report.
                if frame_count % 100 == 0:
                    elapsed = time.perf_counter() - start_time
                    fps = frame_count / elapsed if elapsed > 0 else 0.0
                    logger.info("Processed %d frames (%.2f FPS)", frame_count, fps)
            return 0
        except KeyboardInterrupt:
            logger.info("Interrupted by user, shutting down cleanly.")
            return 130
        finally:
            self.close()

    def close(self) -> None:
        """Release the publisher exactly once; later calls are no-ops."""
        if self._closed:
            return
        close_fn = getattr(self._publisher, "close", None)
        if callable(close_fn):
            with suppress(Exception):
                _ = close_fn()
        self._closed = True
def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
    """Validate CLI inputs, raising ValueError for any missing file.

    Camera indices and cvmmap:// streams are not filesystem paths, so only
    file-style sources are checked for existence.
    """
    is_stream = source.startswith("cvmmap://") or source.isdigit()
    if not is_stream and not Path(source).is_file():
        raise ValueError(f"Video source not found: {source}")
    if not Path(checkpoint).is_file():
        raise ValueError(f"Checkpoint not found: {checkpoint}")
    if not Path(config).is_file():
        raise ValueError(f"Config not found: {config}")
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option("--source", type=str, required=True)
@click.option("--checkpoint", type=str, required=True)
@click.option(
    "--config",
    type=str,
    default="configs/sconet/sconet_scoliosis1k.yaml",
    show_default=True,
)
@click.option("--device", type=str, default="cuda:0", show_default=True)
@click.option("--yolo-model", type=str, default="yolo11n-seg.pt", show_default=True)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--nats-url", type=str, default=None)
@click.option(
    "--nats-subject",
    type=str,
    default="scoliosis.result",
    show_default=True,
)
@click.option("--max-frames", type=click.IntRange(min=1), default=None)
def main(
    source: str,
    checkpoint: str,
    config: str,
    device: str,
    yolo_model: str,
    window: int,
    stride: int,
    nats_url: str | None,
    nats_subject: str,
    max_frames: int | None,
) -> None:
    """Run the scoliosis demo pipeline from the command line.

    Exit codes: whatever ScoliosisPipeline.run() returns on the success
    path (0 normal, 130 on Ctrl-C), 2 for invalid inputs (ValueError),
    1 for runtime failures (RuntimeError).
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    try:
        # Fail fast on missing files before loading any models.
        validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
        pipeline = ScoliosisPipeline(
            source=source,
            checkpoint=checkpoint,
            config=config,
            device=device,
            yolo_model=yolo_model,
            window=window,
            stride=stride,
            nats_url=nats_url,
            nats_subject=nats_subject,
            max_frames=max_frames,
        )
        # Propagate the pipeline's exit code to the shell.
        raise SystemExit(pipeline.run())
    except ValueError as err:
        click.echo(f"Error: {err}", err=True)
        raise SystemExit(2) from err
    except RuntimeError as err:
        click.echo(f"Runtime error: {err}", err=True)
        raise SystemExit(1) from err
+270
View File
@@ -0,0 +1,270 @@
from collections.abc import Callable
import math
from typing import cast
import cv2
from beartype import beartype
import jaxtyping
from jaxtyping import Float, UInt8
import numpy as np
from numpy import ndarray
from numpy.typing import NDArray
# Target silhouette size: 64 rows x 44 columns.
SIL_HEIGHT = 64
SIL_WIDTH = 44
# Intermediate width before side margins are cut (64 - 2 * SIDE_CUT == 44).
SIL_FULL_WIDTH = 64
# Pixels trimmed from each side of the width-normalized silhouette.
SIDE_CUT = 10
# Masks with fewer nonzero pixels than this are rejected as noise.
MIN_MASK_AREA = 500
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
# jaxtyping.jaxtyped re-cast to a typed factory so the decorator usage
# below type-checks cleanly.
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
# Convenience aliases for common array dtypes.
UInt8Array = NDArray[np.uint8]
Float32Array = NDArray[np.float32]
def _read_attr(container: object, key: str) -> object | None:
if isinstance(container, dict):
dict_obj = cast(dict[object, object], container)
return dict_obj.get(key)
try:
return cast(object, object.__getattribute__(container, key))
except AttributeError:
return None
def _to_numpy_array(value: object) -> NDArray[np.generic]:
current: object = value
if isinstance(current, np.ndarray):
return current
detach_obj = _read_attr(current, "detach")
if callable(detach_obj):
detach_fn = cast(Callable[[], object], detach_obj)
current = detach_fn()
cpu_obj = _read_attr(current, "cpu")
if callable(cpu_obj):
cpu_fn = cast(Callable[[], object], cpu_obj)
current = cpu_fn()
numpy_obj = _read_attr(current, "numpy")
if callable(numpy_obj):
numpy_fn = cast(Callable[[], object], numpy_obj)
as_numpy = numpy_fn()
if isinstance(as_numpy, np.ndarray):
return as_numpy
return cast(NDArray[np.generic], np.asarray(current))
def _bbox_from_mask(mask: UInt8[ndarray, "h w"]) -> tuple[int, int, int, int] | None:
mask_u8 = np.asarray(mask, dtype=np.uint8)
coords = np.argwhere(mask_u8 > 0)
if int(coords.size) == 0:
return None
ys = coords[:, 0].astype(np.int64)
xs = coords[:, 1].astype(np.int64)
x1 = int(np.min(xs))
x2 = int(np.max(xs)) + 1
y1 = int(np.min(ys))
y2 = int(np.max(ys)) + 1
if x2 <= x1 or y2 <= y1:
return None
return (x1, y1, x2, y2)
def _sanitize_bbox(
bbox: tuple[int, int, int, int], height: int, width: int
) -> tuple[int, int, int, int] | None:
x1, y1, x2, y2 = bbox
x1c = max(0, min(int(x1), width - 1))
y1c = max(0, min(int(y1), height - 1))
x2c = max(0, min(int(x2), width))
y2c = max(0, min(int(y2), height))
if x2c <= x1c or y2c <= y1c:
return None
return (x1c, y1c, x2c, y2c)
@jaxtyped(typechecker=beartype)
def frame_to_person_mask(
    result: object, min_area: int = MIN_MASK_AREA
) -> tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None:
    """Select the largest qualifying mask from a detection result.

    Parameters
    ----------
    result : object
        Detection result exposing ``.masks.data`` (N x H x W mask stack)
        and optionally ``.boxes.xyxy`` (N x 4 boxes); attributes are read
        duck-typed via _read_attr, so objects and dicts both work.
    min_area : int
        Minimum nonzero-pixel count for a mask to be considered.

    Returns
    -------
    tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None
        (binary 0/255 mask, (x1, y1, x2, y2)) for the largest qualifying
        mask, or None when nothing usable is present.
    """
    masks_obj = _read_attr(result, "masks")
    if masks_obj is None:
        return None
    masks_data_obj = _read_attr(masks_obj, "data")
    if masks_data_obj is None:
        return None
    masks_raw = _to_numpy_array(masks_data_obj)
    masks_float = np.asarray(masks_raw, dtype=np.float32)
    # Promote a single H x W mask to a 1 x H x W stack.
    if masks_float.ndim == 2:
        masks_float = masks_float[np.newaxis, ...]
    if masks_float.ndim != 3:
        return None
    mask_count = int(cast(tuple[int, int, int], masks_float.shape)[0])
    if mask_count <= 0:
        return None
    # Parse detector boxes (when present) into per-mask float xyxy tuples.
    box_values: list[tuple[float, float, float, float]] | None = None
    boxes_obj = _read_attr(result, "boxes")
    if boxes_obj is not None:
        xyxy_obj = _read_attr(boxes_obj, "xyxy")
        if xyxy_obj is not None:
            xyxy_raw = np.asarray(_to_numpy_array(xyxy_obj), dtype=np.float32)
            if xyxy_raw.ndim == 1 and int(xyxy_raw.size) >= 4:
                # A flat vector is treated as a single box.
                xyxy_2d = np.asarray(xyxy_raw[:4].reshape(1, 4), dtype=np.float64)
                x1f = cast(np.float64, xyxy_2d[0, 0])
                y1f = cast(np.float64, xyxy_2d[0, 1])
                x2f = cast(np.float64, xyxy_2d[0, 2])
                y2f = cast(np.float64, xyxy_2d[0, 3])
                box_values = [
                    (
                        float(x1f),
                        float(y1f),
                        float(x2f),
                        float(y2f),
                    )
                ]
            elif xyxy_raw.ndim == 2:
                shape_2d = cast(tuple[int, int], xyxy_raw.shape)
                if int(shape_2d[1]) >= 4:
                    xyxy_2d = np.asarray(xyxy_raw[:, :4], dtype=np.float64)
                    box_values = []
                    for row_idx in range(int(cast(tuple[int, int], xyxy_2d.shape)[0])):
                        x1f = cast(np.float64, xyxy_2d[row_idx, 0])
                        y1f = cast(np.float64, xyxy_2d[row_idx, 1])
                        x2f = cast(np.float64, xyxy_2d[row_idx, 2])
                        y2f = cast(np.float64, xyxy_2d[row_idx, 3])
                        box_values.append(
                            (
                                float(x1f),
                                float(y1f),
                                float(x2f),
                                float(y2f),
                            )
                        )
    # Scan all masks and keep the largest one that passes the area gate.
    best_area = -1
    best_mask: UInt8[ndarray, "h w"] | None = None
    best_bbox: tuple[int, int, int, int] | None = None
    for idx in range(mask_count):
        mask_float = np.asarray(masks_float[idx], dtype=np.float32)
        if mask_float.ndim != 2:
            continue
        # Threshold soft masks at 0.5 into a {0, 255} uint8 image.
        mask_binary = np.where(mask_float > 0.5, np.uint8(255), np.uint8(0)).astype(
            np.uint8
        )
        mask_u8 = cast(UInt8[ndarray, "h w"], mask_binary)
        area = int(np.count_nonzero(mask_u8))
        if area < min_area:
            continue
        bbox: tuple[int, int, int, int] | None = None
        shape_2d = cast(tuple[int, int], mask_binary.shape)
        h = int(shape_2d[0])
        w = int(shape_2d[1])
        if box_values is not None:
            box_count = len(box_values)
            # NOTE(review): a mask with no matching box row is skipped
            # outright rather than falling back to the mask-derived bbox.
            if idx >= box_count:
                continue
            row0, row1, row2, row3 = box_values[idx]
            # Expand fractional box coordinates outward to whole pixels.
            bbox_candidate = (
                int(math.floor(row0)),
                int(math.floor(row1)),
                int(math.ceil(row2)),
                int(math.ceil(row3)),
            )
            bbox = _sanitize_bbox(bbox_candidate, h, w)
        if bbox is None:
            # No (usable) detector box: derive the bbox from the mask itself.
            bbox = _bbox_from_mask(mask_u8)
        if bbox is None:
            continue
        if area > best_area:
            best_area = area
            best_mask = mask_u8
            best_bbox = bbox
    if best_mask is None or best_bbox is None:
        return None
    return best_mask, best_bbox
@jaxtyped(typechecker=beartype)
def mask_to_silhouette(
    mask: UInt8[ndarray, "h w"],
    bbox: tuple[int, int, int, int],
) -> Float[ndarray, "64 44"] | None:
    """Normalize a mask crop into a 64x44 float32 silhouette in [0, 1].

    Steps: binarize, crop to *bbox*, trim empty rows, resize to height 64
    preserving aspect ratio, center to width 64 (crop or zero-pad), cut
    SIDE_CUT columns from each side, and scale to [0, 1]. Returns None
    when any stage degenerates (tiny mask, empty crop, unexpected shape).
    """
    # Re-binarize defensively so any nonzero input counts as foreground.
    mask_u8 = np.where(mask > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
    if int(np.count_nonzero(mask_u8)) < MIN_MASK_AREA:
        return None
    mask_shape = cast(tuple[int, int], mask_u8.shape)
    h = int(mask_shape[0])
    w = int(mask_shape[1])
    bbox_sanitized = _sanitize_bbox(bbox, h, w)
    if bbox_sanitized is None:
        return None
    x1, y1, x2, y2 = bbox_sanitized
    cropped = mask_u8[y1:y2, x1:x2]
    if cropped.size == 0:
        return None
    cropped_u8 = np.asarray(cropped, dtype=np.uint8)
    # Tighten vertically: drop fully-empty rows above and below the subject.
    row_sums = np.sum(cropped_u8, axis=1, dtype=np.int64)
    row_nonzero = np.nonzero(row_sums > 0)[0].astype(np.int64)
    if int(row_nonzero.size) == 0:
        return None
    top = int(cast(np.int64, row_nonzero[0]))
    bottom = int(cast(np.int64, row_nonzero[-1])) + 1
    tightened = cropped[top:bottom, :]
    if tightened.size == 0:
        return None
    tight_shape = cast(tuple[int, int], tightened.shape)
    tight_h = int(tight_shape[0])
    tight_w = int(tight_shape[1])
    if tight_h <= 0 or tight_w <= 0:
        return None
    # Resize to the target height, keeping aspect ratio (width at least 1).
    resized_w = max(1, int(SIL_HEIGHT * (tight_w / tight_h)))
    resized = np.asarray(
        cv2.resize(tightened, (resized_w, SIL_HEIGHT), interpolation=cv2.INTER_CUBIC),
        dtype=np.uint8,
    )
    if resized_w >= SIL_FULL_WIDTH:
        # Too wide: center-crop to the full 64-pixel width.
        start = (resized_w - SIL_FULL_WIDTH) // 2
        normalized_64 = resized[:, start : start + SIL_FULL_WIDTH]
    else:
        # Too narrow: center with symmetric zero padding.
        pad_left = (SIL_FULL_WIDTH - resized_w) // 2
        pad_right = SIL_FULL_WIDTH - resized_w - pad_left
        normalized_64 = np.pad(
            resized,
            ((0, 0), (pad_left, pad_right)),
            mode="constant",
            constant_values=0,
        )
    # Cut the side margins to reach the final 44-pixel width.
    silhouette = np.asarray(
        normalized_64[:, SIDE_CUT : SIL_FULL_WIDTH - SIDE_CUT], dtype=np.float32
    )
    if silhouette.shape != (SIL_HEIGHT, SIL_WIDTH):
        return None
    silhouette_norm = np.clip(silhouette / np.float32(255.0), 0.0, 1.0).astype(
        np.float32
    )
    return cast(Float[ndarray, "64 44"], silhouette_norm)
+317
View File
@@ -0,0 +1,317 @@
from __future__ import annotations
from collections.abc import Callable
from pathlib import Path
import sys
from typing import ClassVar, Protocol, cast, override
import torch
import torch.nn as nn
from beartype import beartype
from einops import rearrange
from jaxtyping import Float
import jaxtyping
from torch import Tensor
_OPENGAIT_PACKAGE_ROOT = Path(__file__).resolve().parents[1]
if str(_OPENGAIT_PACKAGE_ROOT) not in sys.path:
sys.path.insert(0, str(_OPENGAIT_PACKAGE_ROOT))
from opengait.modeling.backbones.resnet import ResNet9
from opengait.modeling.modules import (
HorizontalPoolingPyramid,
PackSequenceWrapper as TemporalPool,
SeparateBNNecks,
SeparateFCs,
)
from opengait.utils import common as common_utils
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
ConfigLoader = Callable[[str], dict[str, object]]
config_loader = cast(ConfigLoader, common_utils.config_loader)
class TemporalPoolLike(Protocol):
    """Callable view of the temporal pooling wrapper (PackSequenceWrapper)."""
    def __call__(
        self,
        seqs: Tensor,
        seqL: object,
        dim: int = 2,
        options: dict[str, int] | None = None,
    ) -> object: ...
class HppLike(Protocol):
    """Callable view of HorizontalPoolingPyramid."""
    def __call__(self, x: Tensor) -> Tensor: ...
class FCsLike(Protocol):
    """Callable view of SeparateFCs."""
    def __call__(self, x: Tensor) -> Tensor: ...
class BNNecksLike(Protocol):
    """Callable view of SeparateBNNecks; returns a (Tensor, Tensor) pair."""
    def __call__(self, x: Tensor) -> tuple[Tensor, Tensor]: ...
class ScoNetDemo(nn.Module):
    """Standalone inference wrapper for the ScoNet scoliosis classifier.

    Rebuilds the ScoNet architecture (ResNet9 backbone, temporal max
    pooling, horizontal pooling pyramid, SeparateFCs / SeparateBNNecks
    heads) from an OpenGait YAML config, optionally loads a training
    checkpoint, and exposes ``forward``/``predict`` for silhouette
    sequences of shape [batch, 1, seq, 64, 44].
    """

    # Class-id -> human-readable label returned by predict().
    LABEL_MAP: ClassVar[dict[int, str]] = {0: "negative", 1: "neutral", 2: "positive"}

    cfg_path: str
    cfg: dict[str, object]
    backbone: ResNet9
    temporal_pool: TemporalPoolLike
    hpp: HppLike
    fcs: FCsLike
    bn_necks: BNNecksLike
    device: torch.device

    @jaxtyped(typechecker=beartype)
    def __init__(
        self,
        cfg_path: str | Path = "configs/sconet/sconet_scoliosis1k.yaml",
        checkpoint_path: str | Path | None = None,
        device: str | torch.device | None = None,
    ) -> None:
        """Build the model from ``cfg_path`` and optionally load weights.

        Args:
            cfg_path: OpenGait YAML config; a relative path that does not
                exist as given is resolved against the repository root.
            checkpoint_path: Optional checkpoint loaded after construction.
            device: Target device; defaults to CPU when omitted.

        Raises:
            ValueError: If the configured backbone type is not ResNet9.
            TypeError: If required config entries are missing or mistyped.
        """
        super().__init__()
        resolved_cfg = self._resolve_path(cfg_path)
        self.cfg_path = str(resolved_cfg)
        self.cfg = config_loader(self.cfg_path)
        model_cfg = self._extract_model_cfg(self.cfg)

        backbone_cfg = self._extract_dict(model_cfg, "backbone_cfg")
        if backbone_cfg.get("type") != "ResNet9":
            raise ValueError(
                "ScoNetDemo currently supports backbone type ResNet9 only."
            )
        self.backbone = ResNet9(
            block=self._extract_str(backbone_cfg, "block"),
            channels=self._extract_int_list(backbone_cfg, "channels"),
            in_channel=self._extract_int(backbone_cfg, "in_channel", default=1),
            layers=self._extract_int_list(backbone_cfg, "layers"),
            strides=self._extract_int_list(backbone_cfg, "strides"),
            maxpool=self._extract_bool(backbone_cfg, "maxpool", default=True),
        )

        fcs_cfg = self._extract_dict(model_cfg, "SeparateFCs")
        bn_cfg = self._extract_dict(model_cfg, "SeparateBNNecks")
        bin_num = self._extract_int_list(model_cfg, "bin_num")
        # Sequence dimension is collapsed with an element-wise max over time.
        self.temporal_pool = cast(TemporalPoolLike, TemporalPool(torch.max))
        self.hpp = cast(HppLike, HorizontalPoolingPyramid(bin_num=bin_num))
        self.fcs = cast(
            FCsLike,
            SeparateFCs(
                parts_num=self._extract_int(fcs_cfg, "parts_num"),
                in_channels=self._extract_int(fcs_cfg, "in_channels"),
                out_channels=self._extract_int(fcs_cfg, "out_channels"),
                norm=self._extract_bool(fcs_cfg, "norm", default=False),
            ),
        )
        self.bn_necks = cast(
            BNNecksLike,
            SeparateBNNecks(
                parts_num=self._extract_int(bn_cfg, "parts_num"),
                in_channels=self._extract_int(bn_cfg, "in_channels"),
                class_num=self._extract_int(bn_cfg, "class_num"),
                norm=self._extract_bool(bn_cfg, "norm", default=True),
                parallel_BN1d=self._extract_bool(bn_cfg, "parallel_BN1d", default=True),
            ),
        )

        self.device = (
            torch.device(device) if device is not None else torch.device("cpu")
        )
        _ = self.to(self.device)
        if checkpoint_path is not None:
            _ = self.load_checkpoint(checkpoint_path)
        # Demo runtime is inference-only; keep BN/dropout in eval mode.
        _ = self.eval()

    @staticmethod
    def _resolve_path(path: str | Path) -> Path:
        """Return ``path`` unchanged if it exists or is absolute; otherwise
        anchor it at the repository root (two directories above this file)."""
        candidate = Path(path)
        if candidate.is_file():
            return candidate
        if candidate.is_absolute():
            return candidate
        repo_root = Path(__file__).resolve().parents[2]
        return repo_root / candidate

    @staticmethod
    def _extract_model_cfg(cfg: dict[str, object]) -> dict[str, object]:
        """Return the required ``model_cfg`` dict from the loaded config."""
        model_cfg_obj = cfg.get("model_cfg")
        if not isinstance(model_cfg_obj, dict):
            raise TypeError("model_cfg must be a dictionary.")
        return cast(dict[str, object], model_cfg_obj)

    @staticmethod
    def _extract_dict(container: dict[str, object], key: str) -> dict[str, object]:
        """Return ``container[key]`` as a dict or raise TypeError."""
        value = container.get(key)
        if not isinstance(value, dict):
            raise TypeError(f"{key} must be a dictionary.")
        return cast(dict[str, object], value)

    @staticmethod
    def _extract_str(container: dict[str, object], key: str) -> str:
        """Return ``container[key]`` as a str or raise TypeError."""
        value = container.get(key)
        if not isinstance(value, str):
            raise TypeError(f"{key} must be a string.")
        return value

    @staticmethod
    def _extract_int(
        container: dict[str, object], key: str, default: int | None = None
    ) -> int:
        """Return ``container[key]`` (or ``default``) as an int.

        Bools are rejected explicitly: ``bool`` is a subclass of ``int`` in
        Python, so a YAML ``true``/``false`` would otherwise pass silently
        where an integer is required.
        """
        value = container.get(key, default)
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError(f"{key} must be an int.")
        return value

    @staticmethod
    def _extract_bool(
        container: dict[str, object], key: str, default: bool | None = None
    ) -> bool:
        """Return ``container[key]`` (or ``default``) as a bool."""
        value = container.get(key, default)
        if not isinstance(value, bool):
            raise TypeError(f"{key} must be a bool.")
        return value

    @staticmethod
    def _extract_int_list(container: dict[str, object], key: str) -> list[int]:
        """Return ``container[key]`` as a list of ints.

        Bool elements are rejected: ``isinstance(True, int)`` is True, so
        they would otherwise slip through the element check.
        """
        value = container.get(key)
        if not isinstance(value, list):
            raise TypeError(f"{key} must be a list[int].")
        values = cast(list[object], value)
        if not all(isinstance(v, int) and not isinstance(v, bool) for v in values):
            raise TypeError(f"{key} must be a list[int].")
        return cast(list[int], values)

    @staticmethod
    def _normalize_state_dict(
        state_dict_obj: dict[object, object],
    ) -> dict[str, Tensor]:
        """Map training-time checkpoint keys onto this module's attributes.

        Strips any DataParallel ``module.`` prefix, then remaps OpenGait
        training names (``Backbone.forward_block.``, ``FCs.``,
        ``BNNecks.``) to the demo attribute names.

        Raises:
            TypeError: If keys are not str or values are not tensors.
            RuntimeError: If two source keys normalize to the same name.
        """
        prefix_remap: tuple[tuple[str, str], ...] = (
            ("Backbone.forward_block.", "backbone."),
            ("FCs.", "fcs."),
            ("BNNecks.", "bn_necks."),
        )
        cleaned_state_dict: dict[str, Tensor] = {}
        for key_obj, value_obj in state_dict_obj.items():
            if not isinstance(key_obj, str):
                raise TypeError("Checkpoint state_dict keys must be strings.")
            if not isinstance(value_obj, Tensor):
                raise TypeError("Checkpoint state_dict values must be torch.Tensor.")
            key = key_obj[7:] if key_obj.startswith("module.") else key_obj
            for source_prefix, target_prefix in prefix_remap:
                if key.startswith(source_prefix):
                    key = f"{target_prefix}{key[len(source_prefix) :]}"
                    break
            if key in cleaned_state_dict:
                raise RuntimeError(
                    f"Checkpoint key normalization collision detected for key '{key}'."
                )
            cleaned_state_dict[key] = value_obj
        return cleaned_state_dict

    @jaxtyped(typechecker=beartype)
    def load_checkpoint(
        self,
        checkpoint_path: str | Path,
        map_location: str | torch.device | None = None,
        strict: bool = True,
    ) -> None:
        """Load and normalize a ScoNet checkpoint into this module.

        Accepts either a raw state dict or an OpenGait checkpoint dict
        carrying the weights under the "model" key.

        Args:
            checkpoint_path: Checkpoint file; resolved like ``cfg_path``.
            map_location: Device remap for torch.load; defaults to
                ``self.device``.
            strict: Passed through to ``load_state_dict``.

        Raises:
            TypeError: If the checkpoint is not a state-dict-like mapping.
            RuntimeError: If the normalized keys do not match this module.
        """
        resolved_ckpt = self._resolve_path(checkpoint_path)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints (weights_only=True would be safer on torch
        # versions that support it).
        checkpoint_obj = cast(
            object,
            torch.load(
                str(resolved_ckpt),
                map_location=map_location if map_location is not None else self.device,
            ),
        )
        state_dict_obj: object = checkpoint_obj
        if isinstance(checkpoint_obj, dict) and "model" in checkpoint_obj:
            state_dict_obj = cast(dict[str, object], checkpoint_obj)["model"]
        if not isinstance(state_dict_obj, dict):
            raise TypeError("Unsupported checkpoint format.")
        cleaned_state_dict = self._normalize_state_dict(
            cast(dict[object, object], state_dict_obj)
        )
        try:
            _ = self.load_state_dict(cleaned_state_dict, strict=strict)
        except RuntimeError as exc:
            raise RuntimeError(
                f"Failed to load ScoNetDemo checkpoint after key normalization from '{resolved_ckpt}'."
            ) from exc
        _ = self.eval()

    def _prepare_sils(self, sils: Tensor) -> Tensor:
        """Normalize input layout to [B, 1, S, H, W] on ``self.device``.

        Accepts [B, S, H, W] (channel dim inserted) or a [B, S, 1, H, W]
        layout (axes swapped to channel-first). Kept for direct callers;
        ``forward``'s jaxtyped signature already enforces the 5-D shape.
        """
        if sils.ndim == 4:
            sils = sils.unsqueeze(1)
        elif sils.ndim == 5 and sils.shape[1] != 1 and sils.shape[2] == 1:
            sils = rearrange(sils, "b s c h w -> b c s h w")
        if sils.ndim != 5 or sils.shape[1] != 1:
            raise ValueError("Expected sils shape [B, 1, S, H, W] or [B, S, H, W].")
        return sils.float().to(self.device)

    def _forward_backbone(self, sils: Tensor) -> Tensor:
        """Apply the frame-level 2D backbone while preserving the sequence.

        Flattens [B, C, S, H, W] to [B*S, C, H, W] for the backbone, then
        restores [B, C', S, H', W'].
        """
        batch, channels, seq, height, width = sils.shape
        framewise = sils.transpose(1, 2).reshape(batch * seq, channels, height, width)
        frame_feats = cast(Tensor, self.backbone(framewise))
        _, out_channels, out_h, out_w = frame_feats.shape
        return (
            frame_feats.reshape(batch, seq, out_channels, out_h, out_w)
            .transpose(1, 2)
            .contiguous()
        )

    @override
    @jaxtyped(typechecker=beartype)
    def forward(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> dict[str, Tensor]:
        """Classify silhouette sequences.

        Args:
            sils: Float silhouettes, shape [batch, 1, seq, 64, 44].

        Returns:
            Dict with:
            - "logits": raw per-part logits from the BNNeck head,
            - "label": argmax class ids after averaging logits over the
              last axis (assumed to be the parts dimension — TODO confirm
              against SeparateBNNecks' output layout),
            - "confidence": softmax probability of the predicted class.
        """
        with torch.inference_mode():
            prepared_sils = self._prepare_sils(sils)
            outs = self._forward_backbone(prepared_sils)
            # seqL=None pools the entire sequence; torch.max with a dim
            # yields a (values, indices)-style tuple, so keep element 0.
            pooled_obj = self.temporal_pool(outs, None, options={"dim": 2})
            if (
                not isinstance(pooled_obj, tuple)
                or not pooled_obj
                or not isinstance(pooled_obj[0], Tensor)
            ):
                raise TypeError("TemporalPool output is invalid.")
            pooled = pooled_obj[0]
            feat = self.hpp(pooled)
            embed_1 = self.fcs(feat)
            _, logits = self.bn_necks(embed_1)
            # Average part-wise logits before taking the class decision.
            mean_logits = logits.mean(dim=-1)
            pred_ids = torch.argmax(mean_logits, dim=-1)
            probs = torch.softmax(mean_logits, dim=-1)
            confidence = torch.gather(
                probs, dim=-1, index=pred_ids.unsqueeze(-1)
            ).squeeze(-1)
            return {"logits": logits, "label": pred_ids, "confidence": confidence}

    @jaxtyped(typechecker=beartype)
    def predict(self, sils: Float[Tensor, "batch 1 seq 64 44"]) -> tuple[str, float]:
        """Classify one sequence and return ``(label_name, confidence)``.

        Raises:
            ValueError: If the batch size is not exactly 1.
        """
        outputs = cast(dict[str, Tensor], self.forward(sils))
        labels = outputs["label"]
        confidence = outputs["confidence"]
        if labels.numel() != 1:
            raise ValueError("predict expects batch size 1.")
        label_id = int(labels.item())
        return self.LABEL_MAP[label_id], float(confidence.item())
+295
View File
@@ -0,0 +1,295 @@
"""Sliding window / ring buffer manager for real-time gait analysis.
This module provides bounded buffer management for silhouette sequences
with track ID tracking and gap detection.
"""
from collections import deque
from typing import TYPE_CHECKING, Protocol, final
import numpy as np
import torch
from jaxtyping import Float
from numpy import ndarray
if TYPE_CHECKING:
from numpy.typing import NDArray
# Silhouette dimensions from preprocess.py
SIL_HEIGHT: int = 64
SIL_WIDTH: int = 44
class _Boxes(Protocol):
"""Protocol for boxes with xyxy and id attributes."""
@property
def xyxy(self) -> "NDArray[np.float32] | object": ...
@property
def id(self) -> "NDArray[np.int64] | object | None": ...
class _Masks(Protocol):
"""Protocol for masks with data attribute."""
@property
def data(self) -> "NDArray[np.float32] | object": ...
class _DetectionResults(Protocol):
"""Protocol for detection results from Ultralytics-style objects."""
@property
def boxes(self) -> _Boxes: ...
@property
def masks(self) -> _Masks: ...
@final
class SilhouetteWindow:
"""Bounded sliding window for silhouette sequences.
Manages a fixed-size buffer of silhouettes with track ID tracking
and automatic reset on track changes or frame gaps.
Attributes:
window_size: Maximum number of frames in the buffer.
stride: Classification stride (frames between classifications).
gap_threshold: Maximum allowed frame gap before reset.
"""
window_size: int
stride: int
gap_threshold: int
_buffer: deque[Float[ndarray, "64 44"]]
_frame_indices: deque[int]
_track_id: int | None
_last_classified_frame: int
_frame_count: int
def __init__(
self,
window_size: int = 30,
stride: int = 1,
gap_threshold: int = 15,
) -> None:
"""Initialize the silhouette window.
Args:
window_size: Maximum buffer size (default 30).
stride: Frames between classifications (default 1).
gap_threshold: Max frame gap before reset (default 15).
"""
self.window_size = window_size
self.stride = stride
self.gap_threshold = gap_threshold
# Bounded storage via deque
self._buffer = deque(maxlen=window_size)
self._frame_indices = deque(maxlen=window_size)
self._track_id = None
self._last_classified_frame = -1
self._frame_count = 0
def push(self, sil: np.ndarray, frame_idx: int, track_id: int) -> None:
"""Push a new silhouette into the window.
Automatically resets buffer on track ID change or frame gap
exceeding gap_threshold.
Args:
sil: Silhouette array of shape (64, 44), float32.
frame_idx: Current frame index for gap detection.
track_id: Track ID for the person.
"""
# Check for track ID change
if self._track_id is not None and track_id != self._track_id:
self.reset()
# Check for frame gap
if self._frame_indices:
last_frame = self._frame_indices[-1]
gap = frame_idx - last_frame
if gap > self.gap_threshold:
self.reset()
# Update track ID
self._track_id = track_id
# Validate and append silhouette
sil_array = np.asarray(sil, dtype=np.float32)
if sil_array.shape != (SIL_HEIGHT, SIL_WIDTH):
raise ValueError(
f"Expected silhouette shape ({SIL_HEIGHT}, {SIL_WIDTH}), got {sil_array.shape}"
)
self._buffer.append(sil_array)
self._frame_indices.append(frame_idx)
self._frame_count += 1
def is_ready(self) -> bool:
"""Check if window has enough frames for classification.
Returns:
True if buffer is full (window_size frames).
"""
return len(self._buffer) >= self.window_size
def should_classify(self) -> bool:
"""Check if classification should run based on stride.
Returns:
True if enough frames have passed since last classification.
"""
if not self.is_ready():
return False
if self._last_classified_frame < 0:
return True
current_frame = self._frame_indices[-1]
frames_since = current_frame - self._last_classified_frame
return frames_since >= self.stride
def get_tensor(self, device: str = "cpu") -> torch.Tensor:
"""Get window contents as a tensor for model input.
Args:
device: Target device for the tensor (default 'cpu').
Returns:
Tensor of shape [1, 1, window_size, 64, 44] with dtype float32.
Raises:
ValueError: If buffer is not full.
"""
if not self.is_ready():
raise ValueError(
f"Window not ready: {len(self._buffer)}/{self.window_size} frames"
)
# Stack buffer into array [window_size, 64, 44]
stacked = np.stack(list(self._buffer), axis=0)
# Add batch and channel dims: [1, 1, window_size, 64, 44]
tensor = torch.from_numpy(stacked.astype(np.float32))
tensor = tensor.unsqueeze(0).unsqueeze(0)
return tensor.to(device)
def reset(self) -> None:
"""Reset the window, clearing all buffers and counters."""
self._buffer.clear()
self._frame_indices.clear()
self._track_id = None
self._last_classified_frame = -1
self._frame_count = 0
def mark_classified(self) -> None:
"""Mark current frame as classified, updating stride tracking."""
if self._frame_indices:
self._last_classified_frame = self._frame_indices[-1]
@property
def current_track_id(self) -> int | None:
"""Current track ID, or None if buffer is empty."""
return self._track_id
@property
def frame_count(self) -> int:
"""Total frames pushed since last reset."""
return self._frame_count
@property
def fill_level(self) -> float:
"""Fill ratio of the buffer (0.0 to 1.0)."""
return len(self._buffer) / self.window_size
def select_person(
results: _DetectionResults,
) -> tuple[ndarray, tuple[int, int, int, int], int] | None:
"""Select the person with largest bounding box from detection results.
Args:
results: Detection results object with boxes and masks attributes.
Expected to have:
- boxes.xyxy: array of bounding boxes [N, 4]
- masks.data: array of masks [N, H, W]
- boxes.id: optional track IDs [N]
Returns:
Tuple of (mask, bbox, track_id) for the largest person,
or None if no valid detections or track IDs unavailable.
"""
# Check for track IDs
boxes_obj: _Boxes | object = getattr(results, "boxes", None)
if boxes_obj is None:
return None
track_ids_obj: ndarray | object | None = getattr(boxes_obj, "id", None)
if track_ids_obj is None:
return None
track_ids: ndarray = np.asarray(track_ids_obj)
if track_ids.size == 0:
return None
# Get bounding boxes
xyxy_obj: ndarray | object = getattr(boxes_obj, "xyxy", None)
if xyxy_obj is None:
return None
bboxes: ndarray = np.asarray(xyxy_obj)
if bboxes.ndim == 1:
bboxes = bboxes.reshape(1, -1)
if bboxes.shape[0] == 0:
return None
# Get masks
masks_obj: _Masks | object = getattr(results, "masks", None)
if masks_obj is None:
return None
masks_data: ndarray | object = getattr(masks_obj, "data", None)
if masks_data is None:
return None
masks: ndarray = np.asarray(masks_data)
if masks.ndim == 2:
masks = masks[np.newaxis, ...]
if masks.shape[0] != bboxes.shape[0]:
return None
# Find largest bbox by area
best_idx: int = -1
best_area: float = -1.0
for i in range(int(bboxes.shape[0])):
row: "NDArray[np.float32]" = bboxes[i][:4]
x1f: float = float(row[0])
y1f: float = float(row[1])
x2f: float = float(row[2])
y2f: float = float(row[3])
area: float = (x2f - x1f) * (y2f - y1f)
if area > best_area:
best_area = area
best_idx = i
if best_idx < 0:
return None
# Extract mask and bbox
mask: "NDArray[np.float32]" = masks[best_idx]
bbox = (
int(float(bboxes[best_idx][0])),
int(float(bboxes[best_idx][1])),
int(float(bboxes[best_idx][2])),
int(float(bboxes[best_idx][3])),
)
track_id = int(track_ids[best_idx]) if best_idx < len(track_ids) else best_idx
return mask, bbox, track_id