Files
OpenGait/opengait/demo/window.py
T
crosstyan f501119d43 feat(demo): add export and silhouette visualization outputs
Add preprocess-only silhouette export and configurable result exporters so demo runs can be persisted for offline analysis and reproducible evaluation. Include optional parquet support and CLI visualization dumps while updating tests and tracking notes for the verified pipeline/debug workflow.
2026-02-27 17:16:20 +08:00

349 lines
11 KiB
Python

"""Sliding window / ring buffer manager for real-time gait analysis.
This module provides bounded buffer management for silhouette sequences
with track ID tracking and gap detection.
"""
from collections import deque
from typing import TYPE_CHECKING, Protocol, cast, final
import numpy as np
import torch
from jaxtyping import Float
from numpy import ndarray
if TYPE_CHECKING:
from numpy.typing import NDArray
# Silhouette dimensions (height, width) from preprocess.py; push() rejects
# any frame that does not match this exact shape.
SIL_HEIGHT: int = 64
SIL_WIDTH: int = 44
# Type alias for array-like inputs accepted by _to_numpy (torch or numpy).
type _ArrayLike = torch.Tensor | ndarray
class _Boxes(Protocol):
    """Protocol for boxes with xyxy and id attributes.

    Structural type for detector box containers (e.g. Ultralytics-style
    results); only the attributes read by this module are declared.
    """

    @property
    def xyxy(self) -> "NDArray[np.float32] | object":
        """Bounding boxes, nominally [N, 4] as (x1, y1, x2, y2); may be tensor-like."""
        ...

    @property
    def id(self) -> "NDArray[np.int64] | object | None":
        """Per-box track IDs [N], or None when tracking is unavailable."""
        ...
class _Masks(Protocol):
    """Protocol for masks with data attribute."""

    @property
    def data(self) -> "NDArray[np.float32] | object":
        """Segmentation masks, nominally [N, H, W]; may be tensor-like."""
        ...
class _DetectionResults(Protocol):
    """Protocol for detection results from Ultralytics-style objects."""

    @property
    def boxes(self) -> _Boxes:
        """Box container; see _Boxes."""
        ...

    @property
    def masks(self) -> _Masks:
        """Mask container; see _Masks."""
        ...
@final
class SilhouetteWindow:
"""Bounded sliding window for silhouette sequences.
Manages a fixed-size buffer of silhouettes with track ID tracking
and automatic reset on track changes or frame gaps.
Attributes:
window_size: Maximum number of frames in the buffer.
stride: Classification stride (frames between classifications).
gap_threshold: Maximum allowed frame gap before reset.
"""
window_size: int
stride: int
gap_threshold: int
_buffer: deque[Float[ndarray, "64 44"]]
_frame_indices: deque[int]
_track_id: int | None
_last_classified_frame: int
_frame_count: int
def __init__(
self,
window_size: int = 30,
stride: int = 1,
gap_threshold: int = 15,
) -> None:
"""Initialize the silhouette window.
Args:
window_size: Maximum buffer size (default 30).
stride: Frames between classifications (default 1).
gap_threshold: Max frame gap before reset (default 15).
"""
self.window_size = window_size
self.stride = stride
self.gap_threshold = gap_threshold
# Bounded storage via deque
self._buffer = deque(maxlen=window_size)
self._frame_indices = deque(maxlen=window_size)
self._track_id = None
self._last_classified_frame = -1
self._frame_count = 0
def push(self, sil: np.ndarray, frame_idx: int, track_id: int) -> None:
"""Push a new silhouette into the window.
Automatically resets buffer on track ID change or frame gap
exceeding gap_threshold.
Args:
sil: Silhouette array of shape (64, 44), float32.
frame_idx: Current frame index for gap detection.
track_id: Track ID for the person.
"""
# Check for track ID change
if self._track_id is not None and track_id != self._track_id:
self.reset()
# Check for frame gap
if self._frame_indices:
last_frame = self._frame_indices[-1]
gap = frame_idx - last_frame
if gap > self.gap_threshold:
self.reset()
# Update track ID
self._track_id = track_id
# Validate and append silhouette
sil_array = np.asarray(sil, dtype=np.float32)
if sil_array.shape != (SIL_HEIGHT, SIL_WIDTH):
raise ValueError(
f"Expected silhouette shape ({SIL_HEIGHT}, {SIL_WIDTH}), got {sil_array.shape}"
)
self._buffer.append(sil_array)
self._frame_indices.append(frame_idx)
self._frame_count += 1
def is_ready(self) -> bool:
"""Check if window has enough frames for classification.
Returns:
True if buffer is full (window_size frames).
"""
return len(self._buffer) >= self.window_size
def should_classify(self) -> bool:
"""Check if classification should run based on stride.
Returns:
True if enough frames have passed since last classification.
"""
if not self.is_ready():
return False
if self._last_classified_frame < 0:
return True
current_frame = self._frame_indices[-1]
frames_since = current_frame - self._last_classified_frame
return frames_since >= self.stride
def get_tensor(self, device: str = "cpu") -> torch.Tensor:
"""Get window contents as a tensor for model input.
Args:
device: Target device for the tensor (default 'cpu').
Returns:
Tensor of shape [1, 1, window_size, 64, 44] with dtype float32.
Raises:
ValueError: If buffer is not full.
"""
if not self.is_ready():
raise ValueError(
f"Window not ready: {len(self._buffer)}/{self.window_size} frames"
)
# Stack buffer into array [window_size, 64, 44]
stacked = np.stack(list(self._buffer), axis=0)
# Add batch and channel dims: [1, 1, window_size, 64, 44]
tensor = torch.from_numpy(stacked.astype(np.float32))
tensor = tensor.unsqueeze(0).unsqueeze(0)
return tensor.to(device)
def reset(self) -> None:
"""Reset the window, clearing all buffers and counters."""
self._buffer.clear()
self._frame_indices.clear()
self._track_id = None
self._last_classified_frame = -1
self._frame_count = 0
def mark_classified(self) -> None:
"""Mark current frame as classified, updating stride tracking."""
if self._frame_indices:
self._last_classified_frame = self._frame_indices[-1]
@property
def current_track_id(self) -> int | None:
"""Current track ID, or None if buffer is empty."""
return self._track_id
@property
def frame_count(self) -> int:
"""Total frames pushed since last reset."""
return self._frame_count
@property
def fill_level(self) -> float:
"""Fill ratio of the buffer (0.0 to 1.0)."""
return len(self._buffer) / self.window_size
def _to_numpy(obj: _ArrayLike) -> ndarray:
"""Safely convert array-like object to numpy array.
Handles torch tensors (CPU or CUDA) by detaching and moving to CPU first.
Falls back to np.asarray for other array-like objects.
Args:
obj: Array-like object (numpy array, torch tensor, or similar).
Returns:
Numpy array representation of the input.
"""
# Handle torch tensors (including CUDA tensors)
detach_fn = getattr(obj, "detach", None)
if detach_fn is not None and callable(detach_fn):
# It's a torch tensor
tensor = detach_fn()
cpu_fn = getattr(tensor, "cpu", None)
if cpu_fn is not None and callable(cpu_fn):
tensor = cpu_fn()
numpy_fn = getattr(tensor, "numpy", None)
if numpy_fn is not None and callable(numpy_fn):
return cast(ndarray, numpy_fn())
# Fall back to np.asarray for other array-like objects
return cast(ndarray, np.asarray(obj))
def select_person(
    results: "_DetectionResults",
) -> tuple[ndarray, tuple[int, int, int, int], int] | None:
    """Select the person with largest bounding box from detection results.

    Args:
        results: Detection results object with boxes and masks attributes.
            Expected to have:
            - boxes.xyxy: array of bounding boxes [N, 4]
            - masks.data: array of masks [N, H, W]
            - boxes.id: optional track IDs [N]

    Returns:
        Tuple of (mask, bbox, track_id) for the largest person,
        or None if no valid detections or track IDs unavailable.
    """
    # Track IDs are mandatory: without them the sliding window cannot
    # associate silhouettes with a person across frames.
    boxes_obj: _Boxes | object = getattr(results, "boxes", None)
    if boxes_obj is None:
        return None
    track_ids_obj: ndarray | object | None = getattr(boxes_obj, "id", None)
    if track_ids_obj is None:
        return None
    track_ids: ndarray = _to_numpy(cast(ndarray, track_ids_obj))
    if track_ids.size == 0:
        return None
    # Get bounding boxes.
    xyxy_obj: ndarray | object = getattr(boxes_obj, "xyxy", None)
    if xyxy_obj is None:
        return None
    bboxes: ndarray = _to_numpy(cast(ndarray, xyxy_obj))
    if bboxes.ndim == 1:
        bboxes = bboxes.reshape(1, -1)
    # Reject empty AND malformed boxes. Checking the column count matters:
    # an empty 1-D xyxy reshapes to (1, 0), which passed the old
    # shape[0] == 0 guard and crashed later on row indexing.
    if bboxes.shape[0] == 0 or bboxes.shape[-1] < 4:
        return None
    # Get masks; a single 2-D mask is promoted to [1, H, W].
    masks_obj: _Masks | object = getattr(results, "masks", None)
    if masks_obj is None:
        return None
    masks_data: ndarray | object = getattr(masks_obj, "data", None)
    if masks_data is None:
        return None
    masks: ndarray = _to_numpy(cast(ndarray, masks_data))
    if masks.ndim == 2:
        masks = masks[np.newaxis, ...]
    # Boxes and masks must correspond row-for-row to index safely.
    if masks.shape[0] != bboxes.shape[0]:
        return None
    # Vectorized largest-area selection; argmax keeps the first maximum,
    # matching the original scan order. Assumes finite coordinates.
    areas = (bboxes[:, 2] - bboxes[:, 0]) * (bboxes[:, 3] - bboxes[:, 1])
    best_idx: int = int(np.argmax(areas))
    # Extract mask and bbox for the winner.
    mask: ndarray = masks[best_idx]
    mask_h, mask_w = int(mask.shape[0]), int(mask.shape[1])
    x1, y1, x2, y2 = (float(v) for v in bboxes[best_idx][:4])
    # Original image dimensions from results (YOLO-style objects carry
    # orig_shape); validated as a sequence of at least two values.
    orig_shape = getattr(results, "orig_shape", None)
    if (
        orig_shape is not None
        and isinstance(orig_shape, (tuple, list))
        and len(orig_shape) >= 2
    ):
        frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
        # Boxes live in frame space while masks may be at model
        # resolution; rescale so the bbox indexes into the mask.
        scale_x = mask_w / frame_w if frame_w > 0 else 1.0
        scale_y = mask_h / frame_h if frame_h > 0 else 1.0
        bbox = (
            int(x1 * scale_x),
            int(y1 * scale_y),
            int(x2 * scale_x),
            int(y2 * scale_y),
        )
    else:
        # Fallback: use bbox as-is (assume same coordinate space).
        bbox = (int(x1), int(y1), int(x2), int(y2))
    # Fall back to the row index if track IDs are shorter than the boxes.
    track_id = int(track_ids[best_idx]) if best_idx < len(track_ids) else best_idx
    return mask, bbox, track_id