"""Sliding window / ring buffer manager for real-time gait analysis.

This module provides bounded buffer management for silhouette sequences
with track ID tracking and gap detection.
"""
|
|
|
|
from collections import deque
|
|
from typing import TYPE_CHECKING, Protocol, cast, final
|
|
|
|
import numpy as np
|
|
import torch
|
|
from jaxtyping import Float
|
|
from numpy import ndarray
|
|
|
|
from .preprocess import BBoxXYXY
|
|
|
|
if TYPE_CHECKING:
|
|
from numpy.typing import NDArray
|
|
|
|
# Silhouette dimensions from preprocess.py
|
|
SIL_HEIGHT: int = 64
|
|
SIL_WIDTH: int = 44
|
|
|
|
# Type alias for array-like inputs
|
|
type _ArrayLike = torch.Tensor | ndarray
|
|
|
|
|
|
class _Boxes(Protocol):
    """Structural type for a box container exposing ``xyxy`` and ``id``."""

    @property
    def xyxy(self) -> "NDArray[np.float32] | object":
        """Bounding boxes in XYXY order."""

    @property
    def id(self) -> "NDArray[np.int64] | object | None":
        """Track IDs, or ``None`` when no IDs are available."""
|
|
|
|
|
|
class _Masks(Protocol):
    """Structural type for a mask container exposing ``data``."""

    @property
    def data(self) -> "NDArray[np.float32] | object":
        """Raw mask array."""
|
|
|
|
|
|
class _DetectionResults(Protocol):
    """Structural type for Ultralytics-style detection results.

    Any object exposing ``boxes`` and ``masks`` properties satisfies it.
    """

    @property
    def boxes(self) -> _Boxes:
        """Box container for the detections."""

    @property
    def masks(self) -> _Masks:
        """Mask container for the detections."""
|
|
|
|
|
|
@final
class SilhouetteWindow:
    """Bounded sliding window for silhouette sequences.

    Manages a fixed-size buffer of silhouettes with track ID tracking
    and automatic reset on track changes or frame gaps.

    Attributes:
        window_size: Maximum number of frames in the buffer.
        stride: Classification stride (frames between classifications).
        gap_threshold: Maximum allowed frame gap before reset.
    """

    window_size: int
    stride: int
    gap_threshold: int
    _buffer: deque[Float[ndarray, "64 44"]]
    _frame_indices: deque[int]
    _track_id: int | None
    _last_classified_frame: int
    _frame_count: int

    def __init__(
        self,
        window_size: int = 30,
        stride: int = 1,
        gap_threshold: int = 15,
    ) -> None:
        """Initialize the silhouette window.

        Args:
            window_size: Maximum buffer size (default 30). Must be at least 1.
            stride: Frames between classifications (default 1).
            gap_threshold: Max frame gap before reset (default 15).

        Raises:
            ValueError: If window_size is smaller than 1.
        """
        # A zero-sized window can never be stacked into a tensor and makes
        # fill_level divide by zero, so reject it at construction time.
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")

        self.window_size = window_size
        self.stride = stride
        self.gap_threshold = gap_threshold

        # Bounded storage via deque: the oldest entry is dropped automatically
        # once window_size entries are held.
        self._buffer = deque(maxlen=window_size)
        self._frame_indices = deque(maxlen=window_size)
        self._track_id = None
        self._last_classified_frame = -1
        self._frame_count = 0

    def push(self, sil: np.ndarray, frame_idx: int, track_id: int) -> None:
        """Push a new silhouette into the window.

        Automatically resets the buffer on a track ID change or when the
        frame gap since the last pushed frame exceeds gap_threshold.

        Args:
            sil: Silhouette array of shape (64, 44); converted to float32.
            frame_idx: Current frame index for gap detection.
            track_id: Track ID for the person.

        Raises:
            ValueError: If the silhouette does not have shape (64, 44). The
                window state is left untouched in that case.
        """
        # Validate before touching any state so a rejected frame cannot wipe
        # the buffer or change the tracked ID. np.array copies, which keeps
        # buffered history independent of a frame buffer the caller may reuse.
        sil_array = np.array(sil, dtype=np.float32)
        if sil_array.shape != (SIL_HEIGHT, SIL_WIDTH):
            raise ValueError(
                f"Expected silhouette shape ({SIL_HEIGHT}, {SIL_WIDTH}), got {sil_array.shape}"
            )

        track_changed = self._track_id is not None and track_id != self._track_id
        gap_exceeded = (
            bool(self._frame_indices)
            and frame_idx - self._frame_indices[-1] > self.gap_threshold
        )
        if track_changed or gap_exceeded:
            self.reset()

        self._track_id = track_id
        self._buffer.append(sil_array)
        self._frame_indices.append(frame_idx)
        self._frame_count += 1

    def is_ready(self) -> bool:
        """Check if window has enough frames for classification.

        Returns:
            True if buffer is full (window_size frames).
        """
        return len(self._buffer) >= self.window_size

    def should_classify(self) -> bool:
        """Check if classification should run based on stride.

        Returns:
            True if the window is full and either nothing has been classified
            since the last reset or at least ``stride`` frames have passed
            since the last classified frame.
        """
        if not self.is_ready():
            return False

        # A negative marker means no classification since the last reset.
        if self._last_classified_frame < 0:
            return True

        frames_since = self._frame_indices[-1] - self._last_classified_frame
        return frames_since >= self.stride

    def get_tensor(self, device: str = "cpu") -> torch.Tensor:
        """Get window contents as a tensor for model input.

        Args:
            device: Target device for the tensor (default 'cpu').

        Returns:
            Tensor of shape [1, 1, window_size, 64, 44] with dtype float32.

        Raises:
            ValueError: If buffer is not full.
        """
        if not self.is_ready():
            raise ValueError(
                f"Window not ready: {len(self._buffer)}/{self.window_size} frames"
            )

        # np.stack allocates a fresh [window_size, 64, 44] array, so the tensor
        # built on top of it never aliases the buffered frames. Entries are
        # already float32, hence copy=False avoids a second allocation.
        stacked = np.stack(list(self._buffer), axis=0).astype(np.float32, copy=False)

        # Add batch and channel dims: [1, 1, window_size, 64, 44]
        tensor = torch.from_numpy(stacked).unsqueeze(0).unsqueeze(0)
        return tensor.to(device)

    def reset(self) -> None:
        """Reset the window, clearing all buffers and counters."""
        self._buffer.clear()
        self._frame_indices.clear()
        self._track_id = None
        self._last_classified_frame = -1
        self._frame_count = 0

    def mark_classified(self) -> None:
        """Mark current frame as classified, updating stride tracking.

        Does nothing when the window is empty.
        """
        if self._frame_indices:
            self._last_classified_frame = self._frame_indices[-1]

    @property
    def current_track_id(self) -> int | None:
        """Current track ID, or None before the first push / after a reset."""
        return self._track_id

    @property
    def frame_count(self) -> int:
        """Total frames pushed since last reset."""
        return self._frame_count

    @property
    def fill_level(self) -> float:
        """Fill ratio of the buffer (0.0 to 1.0)."""
        return len(self._buffer) / self.window_size

    @property
    def window_start_frame(self) -> int:
        """Frame index of the oldest silhouette currently buffered.

        Raises:
            ValueError: If the window is empty.
        """
        if not self._frame_indices:
            raise ValueError("Window is empty")
        return int(self._frame_indices[0])

    @property
    def buffered_silhouettes(self) -> Float[ndarray, "n 64 44"]:
        """Snapshot of the buffered silhouettes, oldest first.

        Returns:
            A new float32 array of shape [n, 64, 44]; n is 0 when the window
            is empty. Modifying it does not affect the buffer.
        """
        if not self._buffer:
            return np.empty((0, SIL_HEIGHT, SIL_WIDTH), dtype=np.float32)
        # np.stack already returns a freshly allocated array, so no extra copy
        # is needed to keep the result independent of the buffer.
        return cast(
            Float[ndarray, "n 64 44"],
            np.stack(list(self._buffer), axis=0).astype(np.float32, copy=False),
        )
|
|
|
|
|
|
def _to_numpy(obj: _ArrayLike) -> ndarray:
    """Convert a tensor-like or array-like object to a numpy array.

    Objects with a callable ``detach`` attribute are treated as torch-style
    tensors: they are detached, moved to the CPU when a ``cpu`` method exists,
    and converted through their ``numpy`` method. Anything else, or a
    tensor-like object without a usable ``numpy`` method, goes through
    ``np.asarray``.

    Args:
        obj: Array-like object (numpy array, torch tensor, or similar).

    Returns:
        Numpy array representation of the input.
    """
    detach = getattr(obj, "detach", None)
    if not callable(detach):
        # Not tensor-like: let numpy handle the conversion directly.
        return cast(ndarray, np.asarray(obj))

    detached = detach()
    to_cpu = getattr(detached, "cpu", None)
    host_side = to_cpu() if callable(to_cpu) else detached

    as_numpy = getattr(host_side, "numpy", None)
    if callable(as_numpy):
        return cast(ndarray, as_numpy())

    # Tensor-like but no numpy(): fall back to converting the original object.
    return cast(ndarray, np.asarray(obj))
|
|
|
|
|
|
def select_person(
    results: _DetectionResults,
) -> tuple[ndarray, BBoxXYXY, BBoxXYXY, int] | None:
    """Select the person with largest bounding box from detection results.

    Args:
        results: Detection results object with boxes and masks attributes.
            Expected to have:
            - boxes.xyxy: array of bounding boxes [N, 4] in frame coordinates (XYXY format)
            - masks.data: array of masks [N, H, W] in mask coordinates
            - boxes.id: optional track IDs [N]
            An optional ``orig_shape`` attribute (tuple or list, height then
            width) is used to scale the box into mask coordinates.

    Returns:
        Tuple of (mask, bbox_mask, bbox_frame, track_id) for the largest person,
        or None if no valid detections or track IDs unavailable (this includes
        malformed box/mask arrays and a box/mask count mismatch).
        - mask: the person's segmentation mask
        - bbox_mask: bounding box in mask coordinate space (XYXY format: x1, y1, x2, y2)
        - bbox_frame: bounding box in frame coordinate space (XYXY format: x1, y1, x2, y2)
        - track_id: the person's track ID
    """
    boxes_obj: _Boxes | object = getattr(results, "boxes", None)
    if boxes_obj is None:
        return None

    # Track IDs are mandatory: without them the caller cannot follow a person.
    track_ids_obj: ndarray | object | None = getattr(boxes_obj, "id", None)
    if track_ids_obj is None:
        return None

    # Flatten so a squeezed (0-d) ID from a single detection still supports
    # len() and indexing, mirroring the single-detection handling below.
    track_ids: ndarray = _to_numpy(cast(ndarray, track_ids_obj)).reshape(-1)
    if track_ids.size == 0:
        return None

    xyxy_obj: ndarray | object = getattr(boxes_obj, "xyxy", None)
    if xyxy_obj is None:
        return None

    bboxes: ndarray = _to_numpy(cast(ndarray, xyxy_obj))
    if bboxes.ndim == 1:
        # A single box given as a flat vector.
        bboxes = bboxes.reshape(1, -1)
    # Reject malformed box arrays up front instead of failing on indexing.
    if bboxes.ndim != 2 or bboxes.shape[0] == 0 or bboxes.shape[1] < 4:
        return None

    masks_obj: _Masks | object = getattr(results, "masks", None)
    if masks_obj is None:
        return None

    masks_data: ndarray | object = getattr(masks_obj, "data", None)
    if masks_data is None:
        return None

    masks: ndarray = _to_numpy(cast(ndarray, masks_data))
    if masks.ndim == 2:
        # A single mask given without the leading detection axis.
        masks = masks[np.newaxis, ...]
    if masks.ndim < 3 or masks.shape[0] != bboxes.shape[0]:
        return None

    # Pick the first box with the strictly largest area. Boxes whose area is
    # NaN or not above -1.0 never win, so best_idx can remain -1.
    best_idx: int = -1
    best_area: float = -1.0
    for idx, (bx1, by1, bx2, by2) in enumerate(bboxes[:, :4].tolist()):
        area = (bx2 - bx1) * (by2 - by1)
        if area > best_area:
            best_area = area
            best_idx = idx

    if best_idx < 0:
        return None

    mask: "NDArray[np.float32]" = masks[best_idx]
    mask_h, mask_w = int(mask.shape[0]), int(mask.shape[1])

    x1, y1, x2, y2 = (float(v) for v in bboxes[best_idx, :4])
    # int() truncates toward zero; frame-space coordinates are kept unscaled.
    bbox_frame = (int(x1), int(y1), int(x2), int(y2))

    # Original image dimensions, when the results object carries them.
    # Only the container type and length are checked here; the entries are
    # passed to int() as-is.
    orig_shape = getattr(results, "orig_shape", None)
    if isinstance(orig_shape, (tuple, list)) and len(orig_shape) >= 2:
        frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
        # Scale bbox from frame space to mask space; a non-positive frame
        # dimension falls back to a scale of 1.0.
        scale_x = mask_w / frame_w if frame_w > 0 else 1.0
        scale_y = mask_h / frame_h if frame_h > 0 else 1.0
        bbox_mask = (
            int(x1 * scale_x),
            int(y1 * scale_y),
            int(x2 * scale_x),
            int(y2 * scale_y),
        )
    else:
        # Fallback: use bbox as-is for both (assume same coordinate space).
        bbox_mask = bbox_frame

    # If fewer IDs than boxes were supplied, the box index stands in as the ID.
    track_id = int(track_ids[best_idx]) if best_idx < len(track_ids) else best_idx

    return mask, bbox_mask, bbox_frame, track_id
|