File: OpenGait/scripts/export_positive_batches.py (Python, 395 lines, 12 KiB)
#!/usr/bin/env python3
"""
Export all positive labeled batches from Scoliosis1K dataset as time windows.
Creates grid visualizations similar to visualizer._prepare_segmentation_input_view()
for all positive class samples, arranged in sliding time windows.
Optimized UI with:
- Subject ID and batch info footer
- Dual frame counts (window-relative and sequence-relative)
- Clean layout with proper spacing
"""
from __future__ import annotations
import json
import pickle
from pathlib import Path
from typing import Final
import cv2
import numpy as np
from numpy.typing import NDArray
# Constants matching visualizer.py (upscaled display tile vs. raw silhouette size)
DISPLAY_HEIGHT: Final = 256  # upscaled tile height in px
DISPLAY_WIDTH: Final = 176  # upscaled tile width in px
SIL_HEIGHT: Final = 64  # raw silhouette height as stored in the .pkl files
SIL_WIDTH: Final = 44  # raw silhouette width as stored in the .pkl files
# Footer settings
FOOTER_HEIGHT: Final = 80  # Height for metadata footer (px)
FOOTER_BG_COLOR: Final = (40, 40, 40)  # Dark gray background (BGR)
TEXT_COLOR: Final = (255, 255, 255)  # White text (BGR)
ACCENT_COLOR: Final = (0, 255, 255)  # NOTE(review): in OpenCV's BGR order this is yellow, not cyan — confirm intended hue
FONT: Final = cv2.FONT_HERSHEY_SIMPLEX
FONT_SCALE: Final = 0.6  # default label scale for cv2.putText
FONT_THICKNESS: Final = 2  # default stroke thickness for cv2.putText
def upscale_silhouette(
    silhouette: NDArray[np.float32] | NDArray[np.uint8],
) -> NDArray[np.uint8]:
    """Upscale a single silhouette frame to display size.

    Args:
        silhouette: 2-D mask, either floating point (expected range [0, 1])
            or integer (expected range [0, 255]).

    Returns:
        uint8 image of shape (DISPLAY_HEIGHT, DISPLAY_WIDTH).
    """
    # BUGFIX: cover ALL float dtypes (the old check missed e.g. float16) and
    # clip before scaling so out-of-range floats cannot wrap around during
    # the uint8 cast (e.g. 1.004 * 255 -> 256 -> 0 instead of 255).
    if np.issubdtype(silhouette.dtype, np.floating):
        sil_u8 = (np.clip(silhouette, 0.0, 1.0) * 255).astype(np.uint8)
    else:
        sil_u8 = silhouette.astype(np.uint8)
    # INTER_NEAREST keeps hard mask edges instead of introducing gray halos.
    upscaled = cv2.resize(
        sil_u8, (DISPLAY_WIDTH, DISPLAY_HEIGHT), interpolation=cv2.INTER_NEAREST
    )
    return upscaled
def _tile_silhouettes(
    silhouettes: NDArray[np.float32],
    tile_height: int,
    tile_width: int,
) -> tuple[NDArray[np.uint8], int]:
    """Tile all frames row-major into one grayscale grid.

    Returns the grid image and the number of tiles per row (near-square
    layout: ceil(sqrt(n)) columns).
    """
    n_frames = int(silhouettes.shape[0])
    tiles_per_row = int(np.ceil(np.sqrt(n_frames)))
    rows = int(np.ceil(n_frames / tiles_per_row))
    grid = np.zeros((rows * tile_height, tiles_per_row * tile_width), dtype=np.uint8)
    for idx in range(n_frames):
        tile = upscale_silhouette(silhouettes[idx])
        if tile.shape != (tile_height, tile_width):
            # BUGFIX: upscale_silhouette always emits DISPLAY_HEIGHT x
            # DISPLAY_WIDTH, so non-default tile sizes used to crash on a
            # shape-mismatched assignment; resize to the requested tile size.
            tile = cv2.resize(
                tile, (tile_width, tile_height), interpolation=cv2.INTER_NEAREST
            )
        r, c = divmod(idx, tiles_per_row)
        grid[
            r * tile_height : (r + 1) * tile_height,
            c * tile_width : (c + 1) * tile_width,
        ] = tile
    return grid, tiles_per_row


def _annotate_tiles(
    grid_bgr: NDArray[np.uint8],
    n_frames: int,
    tiles_per_row: int,
    tile_height: int,
    tile_width: int,
    start_frame: int,
) -> None:
    """Draw frame numbers on each tile in place: window-relative (top-left,
    accent color) and sequence-relative (bottom-left, gray)."""
    for idx in range(n_frames):
        r, c = divmod(idx, tiles_per_row)
        y0 = r * tile_height
        x0 = c * tile_width
        # Window-relative frame number (top-left)
        cv2.putText(
            grid_bgr,
            f"{idx}",
            (x0 + 8, y0 + 22),
            FONT,
            FONT_SCALE,
            ACCENT_COLOR,
            FONT_THICKNESS,
            cv2.LINE_AA,
        )
        # Sequence-relative frame number (bottom-left of tile)
        seq_frame = start_frame + idx
        cv2.putText(
            grid_bgr,
            f"#{seq_frame}",
            (x0 + 8, y0 + tile_height - 10),
            FONT,
            0.45,  # slightly smaller font
            (180, 180, 180),  # light gray
            1,
            cv2.LINE_AA,
        )


def _render_footer(
    grid_width: int,
    subject_id: str,
    view_name: str,
    window_idx: int,
    start_frame: int,
    end_frame: int,
    n_frames: int,
    n_frames_total: int,
) -> NDArray[np.uint8]:
    """Render the metadata footer strip: subject/view, window frame range,
    and progress through the full sequence."""
    footer = np.full((FOOTER_HEIGHT, grid_width, 3), FOOTER_BG_COLOR, dtype=np.uint8)
    # Line 1: Subject ID and view
    line1 = f"Subject: {subject_id} | View: {view_name}"
    cv2.putText(footer, line1, (15, 25), FONT, 0.7, TEXT_COLOR, FONT_THICKNESS, cv2.LINE_AA)
    # Line 2: Window batch frame range (end_frame is exclusive, hence - 1)
    line2 = f"Window {window_idx}: frames [{start_frame:03d} - {end_frame - 1:03d}] ({n_frames} frames)"
    cv2.putText(footer, line2, (15, 50), FONT, 0.7, ACCENT_COLOR, FONT_THICKNESS, cv2.LINE_AA)
    # Line 3: Progress within sequence
    progress_pct = (end_frame / n_frames_total) * 100
    line3 = f"Sequence: {n_frames_total} frames total | Progress: {progress_pct:.1f}%"
    cv2.putText(footer, line3, (15, 72), FONT, 0.6, (200, 200, 200), 1, cv2.LINE_AA)
    return footer


def create_optimized_visualization(
    silhouettes: NDArray[np.float32],
    subject_id: str,
    view_name: str,
    window_idx: int,
    start_frame: int,
    end_frame: int,
    n_frames_total: int,
    tile_height: int = DISPLAY_HEIGHT,
    tile_width: int = DISPLAY_WIDTH,
) -> NDArray[np.uint8]:
    """
    Create optimized visualization with grid and metadata footer.

    Args:
        silhouettes: Array of shape (n_frames, 64, 44) float32
        subject_id: Subject identifier
        view_name: View identifier (e.g., "000_180")
        window_idx: Window index within sequence
        start_frame: Starting frame index in sequence
        end_frame: Ending frame index in sequence (exclusive)
        n_frames_total: Total frames in the sequence
        tile_height: Height of each tile in the grid
        tile_width: Width of each tile in the grid

    Returns:
        Combined BGR image with grid visualization and metadata footer
    """
    n_frames = int(silhouettes.shape[0])
    grid, tiles_per_row = _tile_silhouettes(silhouettes, tile_height, tile_width)
    # Convert to BGR so annotations and footer can use color
    grid_bgr = cv2.cvtColor(grid, cv2.COLOR_GRAY2BGR)
    _annotate_tiles(grid_bgr, n_frames, tiles_per_row, tile_height, tile_width, start_frame)
    footer = _render_footer(
        grid_bgr.shape[1],
        subject_id,
        view_name,
        window_idx,
        start_frame,
        end_frame,
        n_frames,
        n_frames_total,
    )
    # Stack grid on top of footer
    combined = np.vstack([grid_bgr, footer])
    return combined
def load_pkl_sequence(pkl_path: Path) -> NDArray[np.float32]:
    """Read a pickled silhouette sequence and return it as a float32 array.

    Accepts either a single ndarray or a list of per-frame arrays (which
    are stacked along a new leading axis). Raises ValueError for anything
    else.
    """
    with pkl_path.open("rb") as fh:
        payload = pickle.load(fh)
    # A list means one array per frame: stack into (N, H, W).
    if isinstance(payload, list):
        frames = [np.array(frame) for frame in payload]
        return np.stack(frames).astype(np.float32)
    if isinstance(payload, np.ndarray):
        return payload.astype(np.float32)
    raise ValueError(f"Unexpected data type in {pkl_path}: {type(payload)}")
def create_windows(
    sequence: NDArray[np.float32],
    window_size: int = 30,
    stride: int = 30,
) -> list[NDArray[np.float32]]:
    """Chop a silhouette sequence into fixed-size sliding windows.

    Args:
        sequence: Array of shape (N, 64, 44)
        window_size: Number of frames per window
        stride: Offset between consecutive window starts

    Returns:
        List of slices of ``sequence``, each of shape (window_size, 64, 44);
        empty when the sequence is shorter than ``window_size``.
    """
    last_start = sequence.shape[0] - window_size
    return [
        sequence[offset : offset + window_size]
        for offset in range(0, last_start + 1, stride)
    ]
def export_positive_batches(
    dataset_root: Path,
    output_dir: Path,
    window_size: int = 30,
    stride: int = 30,
    max_sequences: int | None = None,
) -> None:
    """
    Export all positive labeled batches from Scoliosis1K dataset as time windows.

    Walks ``dataset_root/<subject>/positive/<view>/*.pkl``, splits each
    sequence into fixed-size windows, and writes one PNG visualization plus
    one JSON metadata file per window into ``output_dir``.

    Args:
        dataset_root: Path to Scoliosis1K-sil-pkl directory
        output_dir: Output directory for visualizations
        window_size: Number of frames per window (default 30)
        stride: Stride between consecutive windows (default 30 = non-overlapping)
        max_sequences: Maximum number of sequences to process (None = all)
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    # Find all positive samples: (pkl_path, subject_id, view_name, pkl_name)
    positive_samples: list[tuple[Path, str, str, str]] = []
    for subject_dir in sorted(dataset_root.iterdir()):
        if not subject_dir.is_dir():
            continue
        subject_id = subject_dir.name
        # Check for positive class directory (lowercase)
        positive_dir = subject_dir / "positive"
        if not positive_dir.exists():
            continue
        # Iterate through views
        for view_dir in sorted(positive_dir.iterdir()):
            if not view_dir.is_dir():
                continue
            view_name = view_dir.name
            # Find .pkl files
            for pkl_file in sorted(view_dir.glob("*.pkl")):
                positive_samples.append(
                    (pkl_file, subject_id, view_name, pkl_file.stem)
                )
    print(f"Found {len(positive_samples)} positive labeled sequences")
    # BUGFIX: explicit None check so max_sequences=0 means "process nothing",
    # as the docstring contract implies, instead of falsy-0 processing all.
    if max_sequences is not None:
        positive_samples = positive_samples[:max_sequences]
        print(f"Processing first {max_sequences} sequences")
    total_windows = 0
    # Export each sequence's windows
    for seq_idx, (pkl_path, subject_id, view_name, pkl_name) in enumerate(
        positive_samples, 1
    ):
        print(
            f"[{seq_idx}/{len(positive_samples)}] Processing {subject_id}/{view_name}/{pkl_name}..."
        )
        # Load sequence; best-effort export: report and continue on bad files.
        try:
            sequence = load_pkl_sequence(pkl_path)
        except Exception as e:
            print(f" Error loading {pkl_path}: {e}")
            continue
        # Ensure correct shape (N, 64, 44)
        if len(sequence.shape) == 2:
            # Single frame -> promote to a 1-frame sequence
            sequence = sequence[np.newaxis, ...]
        elif len(sequence.shape) == 3:
            # (N, H, W) - expected
            pass
        else:
            print(f" Unexpected shape {sequence.shape}, skipping")
            continue
        n_frames = sequence.shape[0]
        print(f" Sequence has {n_frames} frames")
        # Skip if sequence is shorter than window size
        if n_frames < window_size:
            print(f" Skipping: sequence too short (< {window_size} frames)")
            continue
        # Normalize to [0, 1] if needed
        # (assumes max > 1.0 means 0-255 encoding — TODO confirm against data)
        if sequence.max() > 1.0:
            sequence = sequence / 255.0
        # Create windows
        windows = create_windows(sequence, window_size=window_size, stride=stride)
        print(f" Created {len(windows)} windows (size={window_size}, stride={stride})")
        # Export each window
        for window_idx, window in enumerate(windows):
            start_frame = window_idx * stride
            end_frame = start_frame + window_size
            # Create visualization for this window with full metadata
            vis_image = create_optimized_visualization(
                silhouettes=window,
                subject_id=subject_id,
                view_name=view_name,
                window_idx=window_idx,
                start_frame=start_frame,
                end_frame=end_frame,
                n_frames_total=n_frames,
            )
            # Save with descriptive filename including window index
            output_filename = (
                f"{subject_id}_{view_name}_{pkl_name}_win{window_idx:03d}.png"
            )
            output_path = output_dir / output_filename
            cv2.imwrite(str(output_path), vis_image)
            # Save metadata for this window
            meta = {
                "subject_id": subject_id,
                "view": view_name,
                "pkl_name": pkl_name,
                "window_index": window_idx,
                "window_size": window_size,
                "stride": stride,
                "start_frame": start_frame,
                "end_frame": end_frame,
                # list() so the JSON value is an explicit array rather than
                # an implicitly-converted tuple
                "sequence_shape": list(sequence.shape),
                "n_frames_total": n_frames,
                "source_path": str(pkl_path),
            }
            meta_filename = (
                f"{subject_id}_{view_name}_{pkl_name}_win{window_idx:03d}.json"
            )
            meta_path = output_dir / meta_filename
            with open(meta_path, "w") as f:
                json.dump(meta, f, indent=2)
            total_windows += 1
        print(f" Exported {len(windows)} windows")
    print(f"\nExport complete! Saved {total_windows} windows to {output_dir}")
def main() -> None:
    """Script entry point: export windowed positive-class visualizations."""
    # Hard-coded paths for the local Scoliosis1K deployment.
    root = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl")
    out_dir = Path("/home/crosstyan/Code/OpenGait/output/positive_batches")
    # Guard clause: nothing to do without the dataset on disk.
    if not root.exists():
        print(f"Error: Dataset not found at {root}")
        return
    # 30-frame, non-overlapping windows.
    export_positive_batches(root, out_dir, window_size=30, stride=30)


if __name__ == "__main__":
    main()