Files
OpenGait/opengait-studio/opengait_studio/window.py
T
crosstyan d4e2a59ad2 fix(demo): pace gait windows before buffering
Make the OpenGait-studio demo drop unpaced frames before they grow the silhouette window. Separate source-frame gap tracking from paced-frame stride tracking so runtime scheduling matches the documented demo-window-and-stride behavior.

Add regressions for paced window growth and schedule-frame stride semantics.
2026-03-14 11:31:44 +08:00

389 lines
12 KiB
Python

"""Sliding window / ring buffer manager for real-time gait analysis.
This module provides bounded buffer management for silhouette sequences
with track ID tracking and gap detection.
"""
from collections import deque
from typing import TYPE_CHECKING, Protocol, cast, final
import numpy as np
import torch
from jaxtyping import Float
from numpy import ndarray
from .preprocess import BBoxXYXY
if TYPE_CHECKING:
from numpy.typing import NDArray
# Target silhouette dimensions in pixels; must agree with preprocess.py output.
SIL_HEIGHT: int = 64
SIL_WIDTH: int = 44
# Inputs accepted by _to_numpy: torch tensors (CPU or CUDA) or numpy arrays.
type _ArrayLike = torch.Tensor | ndarray
class _Boxes(Protocol):
    """Structural type for detection boxes (Ultralytics-style results).

    Attributes are declared as properties so any object exposing them
    duck-types against this protocol:
      - xyxy: bounding boxes, expected [N, 4] in XYXY frame coordinates.
      - id: optional per-box track IDs [N]; None when tracking is disabled.
    """

    @property
    def xyxy(self) -> "NDArray[np.float32] | object": ...

    @property
    def id(self) -> "NDArray[np.int64] | object | None": ...
class _Masks(Protocol):
    """Structural type for segmentation masks (Ultralytics-style results).

    data: mask stack, expected [N, H, W] in mask coordinate space.
    """

    @property
    def data(self) -> "NDArray[np.float32] | object": ...
class _DetectionResults(Protocol):
    """Structural type for a single detection result (Ultralytics-style).

    Bundles per-frame boxes and masks; consumed by select_person below.
    """

    @property
    def boxes(self) -> _Boxes: ...

    @property
    def masks(self) -> _Masks: ...
@final
class SilhouetteWindow:
    """Bounded sliding window for silhouette sequences.

    Manages a fixed-size buffer of silhouettes with track ID tracking and
    automatic reset on track changes or frame gaps. Two parallel frame-index
    deques are maintained alongside the silhouette buffer:

    - ``_source_frame_indices``: raw source-video frame indices, used ONLY
      for gap detection (a stretch of dropped source frames triggers reset).
    - ``_schedule_frame_indices``: paced/scheduling-clock indices, used ONLY
      for stride bookkeeping in ``should_classify``/``mark_classified``.

    All three deques share the same ``maxlen``, so they evict in lockstep
    and stay index-aligned.

    Attributes:
        window_size: Maximum number of frames in the buffer.
        stride: Classification stride, measured on the schedule clock
            (frames between classifications).
        gap_threshold: Maximum allowed source-frame gap before reset.
    """

    window_size: int
    stride: int
    gap_threshold: int
    # Ring buffer of (64, 44) float32 silhouettes; deque maxlen evicts oldest.
    _buffer: deque[Float[ndarray, "64 44"]]
    # Source-frame index per buffered silhouette (gap detection only).
    _source_frame_indices: deque[int]
    # Schedule-frame index per buffered silhouette (stride tracking only).
    _schedule_frame_indices: deque[int]
    # Track ID currently owning the buffer; None after reset.
    _track_id: int | None
    # Schedule-frame index of the last classification; -1 means "never".
    _last_classified_frame: int
    # Total pushes since last reset (monotonic; not capped by window_size).
    _frame_count: int

    def __init__(
        self,
        window_size: int = 30,
        stride: int = 1,
        gap_threshold: int = 15,
    ) -> None:
        """Initialize the silhouette window.

        Args:
            window_size: Maximum buffer size (default 30).
            stride: Schedule-frames between classifications (default 1).
            gap_threshold: Max source-frame gap before reset (default 15).
        """
        self.window_size = window_size
        self.stride = stride
        self.gap_threshold = gap_threshold
        # Bounded storage via deque; shared maxlen keeps all three aligned.
        self._buffer = deque(maxlen=window_size)
        self._source_frame_indices = deque(maxlen=window_size)
        self._schedule_frame_indices = deque(maxlen=window_size)
        self._track_id = None
        self._last_classified_frame = -1
        self._frame_count = 0

    def push(
        self,
        sil: np.ndarray,
        frame_idx: int,
        track_id: int,
        *,
        schedule_frame_idx: int | None = None,
    ) -> None:
        """Push a new silhouette into the window.

        Automatically resets the buffer on a track ID change or when the
        source-frame gap exceeds ``gap_threshold``.

        Args:
            sil: Silhouette array of shape (64, 44), float32.
            frame_idx: Source frame index, used for gap detection.
            track_id: Track ID for the person.
            schedule_frame_idx: Paced-frame index used for stride tracking.
                Defaults to ``frame_idx`` when the caller does no pacing.

        Raises:
            ValueError: If ``sil`` does not have shape (64, 44).
        """
        # Check for track ID change
        if self._track_id is not None and track_id != self._track_id:
            self.reset()
        # Check for frame gap — measured on the source clock, not the
        # schedule clock, so dropped/unpaced source frames trigger reset.
        if self._source_frame_indices:
            last_frame = self._source_frame_indices[-1]
            gap = frame_idx - last_frame
            if gap > self.gap_threshold:
                self.reset()
        # Update track ID
        self._track_id = track_id
        # Validate and append silhouette
        sil_array = np.asarray(sil, dtype=np.float32)
        if sil_array.shape != (SIL_HEIGHT, SIL_WIDTH):
            raise ValueError(
                f"Expected silhouette shape ({SIL_HEIGHT}, {SIL_WIDTH}), got {sil_array.shape}"
            )
        self._buffer.append(sil_array)
        self._source_frame_indices.append(frame_idx)
        self._schedule_frame_indices.append(
            frame_idx if schedule_frame_idx is None else schedule_frame_idx
        )
        self._frame_count += 1

    def is_ready(self) -> bool:
        """Check if window has enough frames for classification.

        Returns:
            True if buffer is full (window_size frames).
        """
        return len(self._buffer) >= self.window_size

    def should_classify(self) -> bool:
        """Check if classification should run based on stride.

        Stride is measured on the schedule clock (see ``push``), so paced
        callers get stride semantics independent of dropped source frames.

        Returns:
            True if the window is full and at least ``stride`` schedule
            frames have elapsed since the last classification (always True
            for a full window that has never been classified).
        """
        if not self.is_ready():
            return False
        if self._last_classified_frame < 0:
            return True
        current_frame = self._schedule_frame_indices[-1]
        frames_since = current_frame - self._last_classified_frame
        return frames_since >= self.stride

    def get_tensor(self, device: str = "cpu") -> torch.Tensor:
        """Get window contents as a tensor for model input.

        Args:
            device: Target device for the tensor (default 'cpu').

        Returns:
            Tensor of shape [1, 1, window_size, 64, 44] with dtype float32.

        Raises:
            ValueError: If buffer is not full.
        """
        if not self.is_ready():
            raise ValueError(
                f"Window not ready: {len(self._buffer)}/{self.window_size} frames"
            )
        # Stack buffer into array [window_size, 64, 44]
        stacked = np.stack(list(self._buffer), axis=0)
        # Add batch and channel dims: [1, 1, window_size, 64, 44]
        tensor = torch.from_numpy(stacked.astype(np.float32))
        tensor = tensor.unsqueeze(0).unsqueeze(0)
        return tensor.to(device)

    def reset(self) -> None:
        """Reset the window, clearing all buffers and counters."""
        self._buffer.clear()
        self._source_frame_indices.clear()
        self._schedule_frame_indices.clear()
        self._track_id = None
        self._last_classified_frame = -1
        self._frame_count = 0

    def mark_classified(self) -> None:
        """Mark the newest schedule frame as classified (stride tracking)."""
        if self._schedule_frame_indices:
            self._last_classified_frame = self._schedule_frame_indices[-1]

    @property
    def current_track_id(self) -> int | None:
        """Current track ID, or None after a reset (buffer empty)."""
        return self._track_id

    @property
    def frame_count(self) -> int:
        """Total frames pushed since last reset."""
        return self._frame_count

    @property
    def fill_level(self) -> float:
        """Fill ratio of the buffer (0.0 to 1.0)."""
        return len(self._buffer) / self.window_size

    @property
    def window_start_frame(self) -> int:
        """Source-frame index of the oldest buffered silhouette.

        Raises:
            ValueError: If the window is empty.
        """
        if not self._source_frame_indices:
            raise ValueError("Window is empty")
        return int(self._source_frame_indices[0])

    @property
    def buffered_silhouettes(self) -> Float[ndarray, "n 64 44"]:
        """Copy of the buffered silhouettes as an (n, 64, 44) float32 array.

        Returns an empty (0, 64, 44) array when the buffer is empty.
        """
        if not self._buffer:
            return np.empty((0, SIL_HEIGHT, SIL_WIDTH), dtype=np.float32)
        return cast(
            Float[ndarray, "n 64 44"],
            np.stack(list(self._buffer), axis=0).astype(np.float32, copy=True),
        )
def _to_numpy(obj: _ArrayLike) -> ndarray:
    """Convert an array-like object to a numpy array.

    Torch tensors (CPU or CUDA) are detached and moved to host memory
    before conversion; anything else — or a tensor-like object missing a
    ``numpy()`` method — falls through to ``np.asarray``.

    Args:
        obj: Array-like object (numpy array, torch tensor, or similar).

    Returns:
        Numpy array representation of the input.
    """
    # Duck-type the torch-tensor path: detach -> cpu -> numpy, each step
    # taken only when the corresponding callable attribute exists.
    detach = getattr(obj, "detach", None)
    if callable(detach):
        candidate = detach()
        to_cpu = getattr(candidate, "cpu", None)
        if callable(to_cpu):
            candidate = to_cpu()
        as_numpy = getattr(candidate, "numpy", None)
        if callable(as_numpy):
            return cast(ndarray, as_numpy())
    # Generic fallback for plain arrays and other array-likes.
    return cast(ndarray, np.asarray(obj))
def _bbox_to_int_xyxy(
    row: "NDArray[np.float32]",
    scale_x: float = 1.0,
    scale_y: float = 1.0,
) -> BBoxXYXY:
    """Convert one bbox row to an int (x1, y1, x2, y2) tuple, optionally scaled."""
    return (
        int(float(row[0]) * scale_x),
        int(float(row[1]) * scale_y),
        int(float(row[2]) * scale_x),
        int(float(row[3]) * scale_y),
    )


def select_person(
    results: _DetectionResults,
) -> tuple[ndarray, BBoxXYXY, BBoxXYXY, int] | None:
    """Select the person with largest bounding box from detection results.

    Args:
        results: Detection results object with boxes and masks attributes.
            Expected to have:
            - boxes.xyxy: array of bounding boxes [N, 4] in frame coordinates (XYXY format)
            - masks.data: array of masks [N, H, W] in mask coordinates
            - boxes.id: optional track IDs [N]

    Returns:
        Tuple of (mask, bbox_mask, bbox_frame, track_id) for the largest person,
        or None if no valid detections or track IDs unavailable.
        - mask: the person's segmentation mask
        - bbox_mask: bounding box in mask coordinate space (XYXY format: x1, y1, x2, y2)
        - bbox_frame: bounding box in frame coordinate space (XYXY format: x1, y1, x2, y2)
        - track_id: the person's track ID
    """
    # Track IDs are mandatory: without them the silhouette window cannot
    # associate detections across frames.
    boxes_obj: _Boxes | object = getattr(results, "boxes", None)
    if boxes_obj is None:
        return None
    track_ids_obj: ndarray | object | None = getattr(boxes_obj, "id", None)
    if track_ids_obj is None:
        return None
    track_ids: ndarray = _to_numpy(cast(ndarray, track_ids_obj))
    if track_ids.size == 0:
        return None
    # Get bounding boxes
    xyxy_obj: ndarray | object = getattr(boxes_obj, "xyxy", None)
    if xyxy_obj is None:
        return None
    bboxes: ndarray = _to_numpy(cast(ndarray, xyxy_obj))
    if bboxes.ndim == 1:
        bboxes = bboxes.reshape(1, -1)
    if bboxes.shape[0] == 0:
        return None
    # Get masks
    masks_obj: _Masks | object = getattr(results, "masks", None)
    if masks_obj is None:
        return None
    masks_data: ndarray | object = getattr(masks_obj, "data", None)
    if masks_data is None:
        return None
    masks: ndarray = _to_numpy(cast(ndarray, masks_data))
    if masks.ndim == 2:
        masks = masks[np.newaxis, ...]
    if masks.shape[0] != bboxes.shape[0]:
        return None
    # Find largest bbox by area — vectorized instead of a per-row Python loop.
    coords = bboxes[:, :4].astype(np.float64)
    areas = (coords[:, 2] - coords[:, 0]) * (coords[:, 3] - coords[:, 1])
    best_idx: int = int(np.argmax(areas))
    # Preserve the original acceptance gate: a best box whose area is <= -1.0
    # (heavily inverted/degenerate coordinates) is rejected, not returned.
    if float(areas[best_idx]) <= -1.0:
        return None
    # Extract mask and bbox
    mask: "NDArray[np.float32]" = masks[best_idx]
    mask_h, mask_w = int(mask.shape[0]), int(mask.shape[1])
    best_row: "NDArray[np.float32]" = bboxes[best_idx]
    # Get original image dimensions from results (YOLO provides this)
    orig_shape = getattr(results, "orig_shape", None)
    # Validate orig_shape is a sequence of at least 2 numeric values
    if (
        orig_shape is not None
        and isinstance(orig_shape, (tuple, list))
        and len(orig_shape) >= 2
    ):
        frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
        # Scale bbox from frame space to mask space; guard against zero dims.
        scale_x = mask_w / frame_w if frame_w > 0 else 1.0
        scale_y = mask_h / frame_h if frame_h > 0 else 1.0
        bbox_mask = _bbox_to_int_xyxy(best_row, scale_x, scale_y)
        bbox_frame = _bbox_to_int_xyxy(best_row)
    else:
        # Fallback: use bbox as-is for both (assume same coordinate space)
        bbox_mask = _bbox_to_int_xyxy(best_row)
        bbox_frame = bbox_mask
    # Fall back to the index itself if the ID array is shorter than the boxes.
    track_id = int(track_ids[best_idx]) if best_idx < len(track_ids) else best_idx
    return mask, bbox_mask, bbox_frame, track_id