b24644f16e
Add the full demo runtime stack for single-person scoliosis inference, including input adapters, silhouette preprocessing, temporal windowing, ScoNet wrapper, result publishing, and click-based CLI orchestration. This commit captures the executable pipeline behavior independently from tests and planning artifacts for clearer review and rollback.
296 lines
8.6 KiB
Python
296 lines
8.6 KiB
Python
"""Sliding window / ring buffer manager for real-time gait analysis.
|
|
|
|
This module provides bounded buffer management for silhouette sequences
|
|
with track ID tracking and gap detection.
|
|
"""
|
|
|
|
from collections import deque
|
|
from typing import TYPE_CHECKING, Protocol, final
|
|
|
|
import numpy as np
|
|
import torch
|
|
from jaxtyping import Float
|
|
from numpy import ndarray
|
|
|
|
if TYPE_CHECKING:
|
|
from numpy.typing import NDArray
|
|
|
|
|
|
# Silhouette frame geometry — must match the output of preprocess.py.
SIL_HEIGHT: int = 64  # rows
SIL_WIDTH: int = 44  # columns
|
|
|
|
|
|
class _Boxes(Protocol):
|
|
"""Protocol for boxes with xyxy and id attributes."""
|
|
|
|
@property
|
|
def xyxy(self) -> "NDArray[np.float32] | object": ...
|
|
@property
|
|
def id(self) -> "NDArray[np.int64] | object | None": ...
|
|
|
|
|
|
class _Masks(Protocol):
|
|
"""Protocol for masks with data attribute."""
|
|
|
|
@property
|
|
def data(self) -> "NDArray[np.float32] | object": ...
|
|
|
|
|
|
class _DetectionResults(Protocol):
|
|
"""Protocol for detection results from Ultralytics-style objects."""
|
|
|
|
@property
|
|
def boxes(self) -> _Boxes: ...
|
|
@property
|
|
def masks(self) -> _Masks: ...
|
|
|
|
|
|
@final
class SilhouetteWindow:
    """Bounded sliding window for silhouette sequences.

    Manages a fixed-size buffer of silhouettes with track ID tracking
    and automatic reset on track changes or frame gaps.

    Attributes:
        window_size: Maximum number of frames in the buffer.
        stride: Classification stride (frames between classifications).
        gap_threshold: Maximum allowed frame gap before reset.
    """

    window_size: int
    stride: int
    gap_threshold: int
    # Ring buffers: deque(maxlen=...) drops the oldest entry automatically.
    _buffer: deque[Float[ndarray, "64 44"]]
    _frame_indices: deque[int]
    # Track ID currently represented in the buffer; None when empty/reset.
    _track_id: int | None
    # Frame index at which classification last ran; -1 means "never".
    _last_classified_frame: int
    # Total frames accepted since the last reset (monotonic, not capped).
    _frame_count: int

    def __init__(
        self,
        window_size: int = 30,
        stride: int = 1,
        gap_threshold: int = 15,
    ) -> None:
        """Initialize the silhouette window.

        Args:
            window_size: Maximum buffer size (default 30).
            stride: Frames between classifications (default 1).
            gap_threshold: Max frame gap before reset (default 15).
        """
        self.window_size = window_size
        self.stride = stride
        self.gap_threshold = gap_threshold

        # Bounded storage via deque
        self._buffer = deque(maxlen=window_size)
        self._frame_indices = deque(maxlen=window_size)
        self._track_id = None
        self._last_classified_frame = -1
        self._frame_count = 0

    def push(self, sil: np.ndarray, frame_idx: int, track_id: int) -> None:
        """Push a new silhouette into the window.

        Automatically resets buffer on track ID change or frame gap
        exceeding gap_threshold.

        Args:
            sil: Silhouette array of shape (64, 44), float32.
            frame_idx: Current frame index for gap detection.
            track_id: Track ID for the person.

        Raises:
            ValueError: If `sil` is not shaped (SIL_HEIGHT, SIL_WIDTH).
                Window state is left untouched in that case.
        """
        # Validate BEFORE mutating any state. Previously the buffer could
        # be reset and the track ID overwritten, and only then would a
        # malformed frame raise — destroying the accumulated window and
        # leaving the instance inconsistent.
        sil_array = np.asarray(sil, dtype=np.float32)
        if sil_array.shape != (SIL_HEIGHT, SIL_WIDTH):
            raise ValueError(
                f"Expected silhouette shape ({SIL_HEIGHT}, {SIL_WIDTH}), got {sil_array.shape}"
            )

        # Check for track ID change
        if self._track_id is not None and track_id != self._track_id:
            self.reset()

        # Check for frame gap
        if self._frame_indices:
            last_frame = self._frame_indices[-1]
            gap = frame_idx - last_frame
            if gap > self.gap_threshold:
                self.reset()

        # Update track ID
        self._track_id = track_id

        self._buffer.append(sil_array)
        self._frame_indices.append(frame_idx)
        self._frame_count += 1

    def is_ready(self) -> bool:
        """Check if window has enough frames for classification.

        Returns:
            True if buffer is full (window_size frames).
        """
        return len(self._buffer) >= self.window_size

    def should_classify(self) -> bool:
        """Check if classification should run based on stride.

        Returns:
            True if enough frames have passed since last classification.
        """
        if not self.is_ready():
            return False

        # Never classified yet: run as soon as the window first fills.
        if self._last_classified_frame < 0:
            return True

        current_frame = self._frame_indices[-1]
        frames_since = current_frame - self._last_classified_frame
        return frames_since >= self.stride

    def get_tensor(self, device: str = "cpu") -> torch.Tensor:
        """Get window contents as a tensor for model input.

        Args:
            device: Target device for the tensor (default 'cpu').

        Returns:
            Tensor of shape [1, 1, window_size, 64, 44] with dtype float32.

        Raises:
            ValueError: If buffer is not full.
        """
        if not self.is_ready():
            raise ValueError(
                f"Window not ready: {len(self._buffer)}/{self.window_size} frames"
            )

        # Stack buffer into array [window_size, 64, 44]. Entries are already
        # float32 (coerced in push), so no extra astype copy is needed.
        stacked = np.stack(list(self._buffer), axis=0)

        # Add batch and channel dims: [1, 1, window_size, 64, 44]
        tensor = torch.from_numpy(stacked).unsqueeze(0).unsqueeze(0)

        return tensor.to(device)

    def reset(self) -> None:
        """Reset the window, clearing all buffers and counters."""
        self._buffer.clear()
        self._frame_indices.clear()
        self._track_id = None
        self._last_classified_frame = -1
        self._frame_count = 0

    def mark_classified(self) -> None:
        """Mark current frame as classified, updating stride tracking."""
        if self._frame_indices:
            self._last_classified_frame = self._frame_indices[-1]

    @property
    def current_track_id(self) -> int | None:
        """Current track ID, or None if buffer is empty."""
        return self._track_id

    @property
    def frame_count(self) -> int:
        """Total frames pushed since last reset."""
        return self._frame_count

    @property
    def fill_level(self) -> float:
        """Fill ratio of the buffer (0.0 to 1.0)."""
        return len(self._buffer) / self.window_size
|
|
|
|
|
|
def select_person(
    results: _DetectionResults,
) -> tuple[ndarray, tuple[int, int, int, int], int] | None:
    """Select the person with largest bounding box from detection results.

    Args:
        results: Detection results object with boxes and masks attributes.
            Expected to have:
            - boxes.xyxy: array of bounding boxes [N, 4]
            - masks.data: array of masks [N, H, W]
            - boxes.id: optional track IDs [N]

    Returns:
        Tuple of (mask, bbox, track_id) for the largest person,
        or None if no valid detections or track IDs unavailable.
    """
    # Bail out early unless boxes with track IDs are present.
    boxes_attr = getattr(results, "boxes", None)
    if boxes_attr is None:
        return None

    raw_ids = getattr(boxes_attr, "id", None)
    if raw_ids is None:
        return None
    ids = np.asarray(raw_ids)
    if ids.size == 0:
        return None

    # Bounding boxes, normalized to shape [N, 4].
    raw_xyxy = getattr(boxes_attr, "xyxy", None)
    if raw_xyxy is None:
        return None
    box_rows = np.asarray(raw_xyxy)
    if box_rows.ndim == 1:
        box_rows = box_rows.reshape(1, -1)
    if box_rows.shape[0] == 0:
        return None

    # Masks, normalized to shape [N, H, W] and aligned with the boxes.
    masks_attr = getattr(results, "masks", None)
    if masks_attr is None:
        return None
    raw_masks = getattr(masks_attr, "data", None)
    if raw_masks is None:
        return None
    mask_stack = np.asarray(raw_masks)
    if mask_stack.ndim == 2:
        mask_stack = mask_stack[np.newaxis, ...]
    if mask_stack.shape[0] != box_rows.shape[0]:
        return None

    # Pick the detection with the largest bbox area. Strict ">" keeps the
    # earliest index on ties, and the -1 sentinel survives if no comparison
    # ever succeeds.
    winner = -1
    winner_area = -1.0
    for idx in range(int(box_rows.shape[0])):
        x1 = float(box_rows[idx][0])
        y1 = float(box_rows[idx][1])
        x2 = float(box_rows[idx][2])
        y2 = float(box_rows[idx][3])
        area = (x2 - x1) * (y2 - y1)
        if area > winner_area:
            winner_area = area
            winner = idx

    if winner < 0:
        return None

    # Assemble the result for the winning detection.
    winner_bbox = (
        int(float(box_rows[winner][0])),
        int(float(box_rows[winner][1])),
        int(float(box_rows[winner][2])),
        int(float(box_rows[winner][3])),
    )
    winner_id = int(ids[winner]) if winner < len(ids) else winner

    return mask_stack[winner], winner_bbox, winner_id
|