Files
OpenGait/opengait/demo/window.py
T
crosstyan b24644f16e feat(demo): implement ScoNet real-time pipeline runtime
Add the full demo runtime stack for single-person scoliosis inference, including input adapters, silhouette preprocessing, temporal windowing, ScoNet wrapper, result publishing, and click-based CLI orchestration. This commit captures the executable pipeline behavior independently from tests and planning artifacts for clearer review and rollback.
2026-02-27 09:59:04 +08:00

296 lines
8.6 KiB
Python

"""Sliding window / ring buffer manager for real-time gait analysis.
This module provides bounded buffer management for silhouette sequences
with track ID tracking and gap detection.
"""
from collections import deque
from typing import TYPE_CHECKING, Protocol, final
import numpy as np
import torch
from jaxtyping import Float
from numpy import ndarray
if TYPE_CHECKING:
from numpy.typing import NDArray
# Silhouette dimensions from preprocess.py
SIL_HEIGHT: int = 64
SIL_WIDTH: int = 44
class _Boxes(Protocol):
    """Structural type for a boxes container exposing ``xyxy`` coordinates
    and optional per-box track ``id`` values (Ultralytics-style)."""

    @property
    def xyxy(self) -> "NDArray[np.float32] | object": ...

    @property
    def id(self) -> "NDArray[np.int64] | object | None": ...
class _Masks(Protocol):
    """Structural type for a masks container exposing raw mask ``data``."""

    @property
    def data(self) -> "NDArray[np.float32] | object": ...
class _DetectionResults(Protocol):
    """Structural type for an Ultralytics-style detection result object,
    carrying a ``boxes`` container and a ``masks`` container."""

    @property
    def boxes(self) -> "_Boxes": ...

    @property
    def masks(self) -> "_Masks": ...
@final
class SilhouetteWindow:
"""Bounded sliding window for silhouette sequences.
Manages a fixed-size buffer of silhouettes with track ID tracking
and automatic reset on track changes or frame gaps.
Attributes:
window_size: Maximum number of frames in the buffer.
stride: Classification stride (frames between classifications).
gap_threshold: Maximum allowed frame gap before reset.
"""
window_size: int
stride: int
gap_threshold: int
_buffer: deque[Float[ndarray, "64 44"]]
_frame_indices: deque[int]
_track_id: int | None
_last_classified_frame: int
_frame_count: int
def __init__(
self,
window_size: int = 30,
stride: int = 1,
gap_threshold: int = 15,
) -> None:
"""Initialize the silhouette window.
Args:
window_size: Maximum buffer size (default 30).
stride: Frames between classifications (default 1).
gap_threshold: Max frame gap before reset (default 15).
"""
self.window_size = window_size
self.stride = stride
self.gap_threshold = gap_threshold
# Bounded storage via deque
self._buffer = deque(maxlen=window_size)
self._frame_indices = deque(maxlen=window_size)
self._track_id = None
self._last_classified_frame = -1
self._frame_count = 0
def push(self, sil: np.ndarray, frame_idx: int, track_id: int) -> None:
"""Push a new silhouette into the window.
Automatically resets buffer on track ID change or frame gap
exceeding gap_threshold.
Args:
sil: Silhouette array of shape (64, 44), float32.
frame_idx: Current frame index for gap detection.
track_id: Track ID for the person.
"""
# Check for track ID change
if self._track_id is not None and track_id != self._track_id:
self.reset()
# Check for frame gap
if self._frame_indices:
last_frame = self._frame_indices[-1]
gap = frame_idx - last_frame
if gap > self.gap_threshold:
self.reset()
# Update track ID
self._track_id = track_id
# Validate and append silhouette
sil_array = np.asarray(sil, dtype=np.float32)
if sil_array.shape != (SIL_HEIGHT, SIL_WIDTH):
raise ValueError(
f"Expected silhouette shape ({SIL_HEIGHT}, {SIL_WIDTH}), got {sil_array.shape}"
)
self._buffer.append(sil_array)
self._frame_indices.append(frame_idx)
self._frame_count += 1
def is_ready(self) -> bool:
"""Check if window has enough frames for classification.
Returns:
True if buffer is full (window_size frames).
"""
return len(self._buffer) >= self.window_size
def should_classify(self) -> bool:
"""Check if classification should run based on stride.
Returns:
True if enough frames have passed since last classification.
"""
if not self.is_ready():
return False
if self._last_classified_frame < 0:
return True
current_frame = self._frame_indices[-1]
frames_since = current_frame - self._last_classified_frame
return frames_since >= self.stride
def get_tensor(self, device: str = "cpu") -> torch.Tensor:
"""Get window contents as a tensor for model input.
Args:
device: Target device for the tensor (default 'cpu').
Returns:
Tensor of shape [1, 1, window_size, 64, 44] with dtype float32.
Raises:
ValueError: If buffer is not full.
"""
if not self.is_ready():
raise ValueError(
f"Window not ready: {len(self._buffer)}/{self.window_size} frames"
)
# Stack buffer into array [window_size, 64, 44]
stacked = np.stack(list(self._buffer), axis=0)
# Add batch and channel dims: [1, 1, window_size, 64, 44]
tensor = torch.from_numpy(stacked.astype(np.float32))
tensor = tensor.unsqueeze(0).unsqueeze(0)
return tensor.to(device)
def reset(self) -> None:
"""Reset the window, clearing all buffers and counters."""
self._buffer.clear()
self._frame_indices.clear()
self._track_id = None
self._last_classified_frame = -1
self._frame_count = 0
def mark_classified(self) -> None:
"""Mark current frame as classified, updating stride tracking."""
if self._frame_indices:
self._last_classified_frame = self._frame_indices[-1]
@property
def current_track_id(self) -> int | None:
"""Current track ID, or None if buffer is empty."""
return self._track_id
@property
def frame_count(self) -> int:
"""Total frames pushed since last reset."""
return self._frame_count
@property
def fill_level(self) -> float:
"""Fill ratio of the buffer (0.0 to 1.0)."""
return len(self._buffer) / self.window_size
def select_person(
results: _DetectionResults,
) -> tuple[ndarray, tuple[int, int, int, int], int] | None:
"""Select the person with largest bounding box from detection results.
Args:
results: Detection results object with boxes and masks attributes.
Expected to have:
- boxes.xyxy: array of bounding boxes [N, 4]
- masks.data: array of masks [N, H, W]
- boxes.id: optional track IDs [N]
Returns:
Tuple of (mask, bbox, track_id) for the largest person,
or None if no valid detections or track IDs unavailable.
"""
# Check for track IDs
boxes_obj: _Boxes | object = getattr(results, "boxes", None)
if boxes_obj is None:
return None
track_ids_obj: ndarray | object | None = getattr(boxes_obj, "id", None)
if track_ids_obj is None:
return None
track_ids: ndarray = np.asarray(track_ids_obj)
if track_ids.size == 0:
return None
# Get bounding boxes
xyxy_obj: ndarray | object = getattr(boxes_obj, "xyxy", None)
if xyxy_obj is None:
return None
bboxes: ndarray = np.asarray(xyxy_obj)
if bboxes.ndim == 1:
bboxes = bboxes.reshape(1, -1)
if bboxes.shape[0] == 0:
return None
# Get masks
masks_obj: _Masks | object = getattr(results, "masks", None)
if masks_obj is None:
return None
masks_data: ndarray | object = getattr(masks_obj, "data", None)
if masks_data is None:
return None
masks: ndarray = np.asarray(masks_data)
if masks.ndim == 2:
masks = masks[np.newaxis, ...]
if masks.shape[0] != bboxes.shape[0]:
return None
# Find largest bbox by area
best_idx: int = -1
best_area: float = -1.0
for i in range(int(bboxes.shape[0])):
row: "NDArray[np.float32]" = bboxes[i][:4]
x1f: float = float(row[0])
y1f: float = float(row[1])
x2f: float = float(row[2])
y2f: float = float(row[3])
area: float = (x2f - x1f) * (y2f - y1f)
if area > best_area:
best_area = area
best_idx = i
if best_idx < 0:
return None
# Extract mask and bbox
mask: "NDArray[np.float32]" = masks[best_idx]
bbox = (
int(float(bboxes[best_idx][0])),
int(float(bboxes[best_idx][1])),
int(float(bboxes[best_idx][2])),
int(float(bboxes[best_idx][3])),
)
track_id = int(track_ids[best_idx]) if best_idx < len(track_ids) else best_idx
return mask, bbox, track_id