"""Sliding window / ring buffer manager for real-time gait analysis. This module provides bounded buffer management for silhouette sequences with track ID tracking and gap detection. """ from collections import deque from typing import TYPE_CHECKING, Protocol, cast, final import numpy as np import torch from jaxtyping import Float from numpy import ndarray if TYPE_CHECKING: from numpy.typing import NDArray # Silhouette dimensions from preprocess.py SIL_HEIGHT: int = 64 SIL_WIDTH: int = 44 # Type alias for array-like inputs type _ArrayLike = torch.Tensor | ndarray class _Boxes(Protocol): """Protocol for boxes with xyxy and id attributes.""" @property def xyxy(self) -> "NDArray[np.float32] | object": ... @property def id(self) -> "NDArray[np.int64] | object | None": ... class _Masks(Protocol): """Protocol for masks with data attribute.""" @property def data(self) -> "NDArray[np.float32] | object": ... class _DetectionResults(Protocol): """Protocol for detection results from Ultralytics-style objects.""" @property def boxes(self) -> _Boxes: ... @property def masks(self) -> _Masks: ... @final class SilhouetteWindow: """Bounded sliding window for silhouette sequences. Manages a fixed-size buffer of silhouettes with track ID tracking and automatic reset on track changes or frame gaps. Attributes: window_size: Maximum number of frames in the buffer. stride: Classification stride (frames between classifications). gap_threshold: Maximum allowed frame gap before reset. """ window_size: int stride: int gap_threshold: int _buffer: deque[Float[ndarray, "64 44"]] _frame_indices: deque[int] _track_id: int | None _last_classified_frame: int _frame_count: int def __init__( self, window_size: int = 30, stride: int = 1, gap_threshold: int = 15, ) -> None: """Initialize the silhouette window. Args: window_size: Maximum buffer size (default 30). stride: Frames between classifications (default 1). gap_threshold: Max frame gap before reset (default 15). """ self.window_size = window_size self.stride = stride self.gap_threshold = gap_threshold # Bounded storage via deque self._buffer = deque(maxlen=window_size) self._frame_indices = deque(maxlen=window_size) self._track_id = None self._last_classified_frame = -1 self._frame_count = 0 def push(self, sil: np.ndarray, frame_idx: int, track_id: int) -> None: """Push a new silhouette into the window. Automatically resets buffer on track ID change or frame gap exceeding gap_threshold. Args: sil: Silhouette array of shape (64, 44), float32. frame_idx: Current frame index for gap detection. track_id: Track ID for the person. """ # Check for track ID change if self._track_id is not None and track_id != self._track_id: self.reset() # Check for frame gap if self._frame_indices: last_frame = self._frame_indices[-1] gap = frame_idx - last_frame if gap > self.gap_threshold: self.reset() # Update track ID self._track_id = track_id # Validate and append silhouette sil_array = np.asarray(sil, dtype=np.float32) if sil_array.shape != (SIL_HEIGHT, SIL_WIDTH): raise ValueError( f"Expected silhouette shape ({SIL_HEIGHT}, {SIL_WIDTH}), got {sil_array.shape}" ) self._buffer.append(sil_array) self._frame_indices.append(frame_idx) self._frame_count += 1 def is_ready(self) -> bool: """Check if window has enough frames for classification. Returns: True if buffer is full (window_size frames). """ return len(self._buffer) >= self.window_size def should_classify(self) -> bool: """Check if classification should run based on stride. Returns: True if enough frames have passed since last classification. """ if not self.is_ready(): return False if self._last_classified_frame < 0: return True current_frame = self._frame_indices[-1] frames_since = current_frame - self._last_classified_frame return frames_since >= self.stride def get_tensor(self, device: str = "cpu") -> torch.Tensor: """Get window contents as a tensor for model input. Args: device: Target device for the tensor (default 'cpu'). Returns: Tensor of shape [1, 1, window_size, 64, 44] with dtype float32. Raises: ValueError: If buffer is not full. """ if not self.is_ready(): raise ValueError( f"Window not ready: {len(self._buffer)}/{self.window_size} frames" ) # Stack buffer into array [window_size, 64, 44] stacked = np.stack(list(self._buffer), axis=0) # Add batch and channel dims: [1, 1, window_size, 64, 44] tensor = torch.from_numpy(stacked.astype(np.float32)) tensor = tensor.unsqueeze(0).unsqueeze(0) return tensor.to(device) def reset(self) -> None: """Reset the window, clearing all buffers and counters.""" self._buffer.clear() self._frame_indices.clear() self._track_id = None self._last_classified_frame = -1 self._frame_count = 0 def mark_classified(self) -> None: """Mark current frame as classified, updating stride tracking.""" if self._frame_indices: self._last_classified_frame = self._frame_indices[-1] @property def current_track_id(self) -> int | None: """Current track ID, or None if buffer is empty.""" return self._track_id @property def frame_count(self) -> int: """Total frames pushed since last reset.""" return self._frame_count @property def fill_level(self) -> float: """Fill ratio of the buffer (0.0 to 1.0).""" return len(self._buffer) / self.window_size def _to_numpy(obj: _ArrayLike) -> ndarray: """Safely convert array-like object to numpy array. Handles torch tensors (CPU or CUDA) by detaching and moving to CPU first. Falls back to np.asarray for other array-like objects. Args: obj: Array-like object (numpy array, torch tensor, or similar). Returns: Numpy array representation of the input. """ # Handle torch tensors (including CUDA tensors) detach_fn = getattr(obj, "detach", None) if detach_fn is not None and callable(detach_fn): # It's a torch tensor tensor = detach_fn() cpu_fn = getattr(tensor, "cpu", None) if cpu_fn is not None and callable(cpu_fn): tensor = cpu_fn() numpy_fn = getattr(tensor, "numpy", None) if numpy_fn is not None and callable(numpy_fn): return cast(ndarray, numpy_fn()) # Fall back to np.asarray for other array-like objects return cast(ndarray, np.asarray(obj)) def select_person( results: _DetectionResults, ) -> tuple[ndarray, tuple[int, int, int, int], int] | None: """Select the person with largest bounding box from detection results. Args: results: Detection results object with boxes and masks attributes. Expected to have: - boxes.xyxy: array of bounding boxes [N, 4] - masks.data: array of masks [N, H, W] - boxes.id: optional track IDs [N] Returns: Tuple of (mask, bbox, track_id) for the largest person, or None if no valid detections or track IDs unavailable. """ # Check for track IDs boxes_obj: _Boxes | object = getattr(results, "boxes", None) if boxes_obj is None: return None track_ids_obj: ndarray | object | None = getattr(boxes_obj, "id", None) if track_ids_obj is None: return None track_ids: ndarray = _to_numpy(cast(ndarray, track_ids_obj)) if track_ids.size == 0: return None # Get bounding boxes xyxy_obj: ndarray | object = getattr(boxes_obj, "xyxy", None) if xyxy_obj is None: return None bboxes: ndarray = _to_numpy(cast(ndarray, xyxy_obj)) if bboxes.ndim == 1: bboxes = bboxes.reshape(1, -1) if bboxes.shape[0] == 0: return None # Get masks masks_obj: _Masks | object = getattr(results, "masks", None) if masks_obj is None: return None masks_data: ndarray | object = getattr(masks_obj, "data", None) if masks_data is None: return None masks: ndarray = _to_numpy(cast(ndarray, masks_data)) if masks.ndim == 2: masks = masks[np.newaxis, ...] if masks.shape[0] != bboxes.shape[0]: return None # Find largest bbox by area best_idx: int = -1 best_area: float = -1.0 for i in range(int(bboxes.shape[0])): row: "NDArray[np.float32]" = bboxes[i][:4] x1f: float = float(row[0]) y1f: float = float(row[1]) x2f: float = float(row[2]) y2f: float = float(row[3]) area: float = (x2f - x1f) * (y2f - y1f) if area > best_area: best_area = area best_idx = i if best_idx < 0: return None # Extract mask and bbox mask: "NDArray[np.float32]" = masks[best_idx] mask_shape = mask.shape mask_h, mask_w = int(mask_shape[0]), int(mask_shape[1]) # Get original image dimensions from results (YOLO provides this) orig_shape = getattr(results, "orig_shape", None) # Validate orig_shape is a sequence of at least 2 numeric values if ( orig_shape is not None and isinstance(orig_shape, (tuple, list)) and len(orig_shape) >= 2 ): frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1]) # Scale bbox from frame space to mask space scale_x = mask_w / frame_w if frame_w > 0 else 1.0 scale_y = mask_h / frame_h if frame_h > 0 else 1.0 bbox = ( int(float(bboxes[best_idx][0]) * scale_x), int(float(bboxes[best_idx][1]) * scale_y), int(float(bboxes[best_idx][2]) * scale_x), int(float(bboxes[best_idx][3]) * scale_y), ) else: # Fallback: use bbox as-is (assume same coordinate space) bbox = ( int(float(bboxes[best_idx][0])), int(float(bboxes[best_idx][1])), int(float(bboxes[best_idx][2])), int(float(bboxes[best_idx][3])), ) track_id = int(track_ids[best_idx]) if best_idx < len(track_ids) else best_idx return mask, bbox, track_id