chore: update demo runtime, tests, and agent docs

2026-03-02 12:33:17 +08:00
parent 1f8f959ad7
commit cbb3284c13
14 changed files with 1491 additions and 236 deletions
+1
@@ -147,4 +147,5 @@ dmypy.json
 cython_debug/
 ckpt/
+output/
 assets/*
+180 -73
@@ -1,84 +1,191 @@
# OpenGait Agent Guide
This file is for autonomous coding agents working in this repository.
Use it as the default playbook for commands, conventions, and safety checks.
## Scope and Ground Truth
- Repository: `OpenGait`
- Runtime package: `opengait/`
- Primary entrypoint: `opengait/main.py`
- Package/runtime tool: `uv`
Critical source-of-truth rule:
- `opengait/demo` is an implementation layer and may contain project-specific behavior.
- When asked to “refer to the paper” or verify methodology, use the paper and official citations as ground truth.
- Do not treat demo/runtime behavior as proof of paper method unless explicitly cited by the paper.
## Environment Setup
Install dependencies with uv:
```bash
uv sync --extra torch
```
Notes from `pyproject.toml`:
- Python requirement: `>=3.10`
- Dev tooling includes `pytest` and `basedpyright`
- Optional extras include `torch` and `parquet`
## Build / Run Commands
Train (DDP):
```bash
CUDA_VISIBLE_DEVICES=0,1 uv run python -m torch.distributed.launch \
--nproc_per_node=2 opengait/main.py \
--cfgs ./configs/baseline/baseline.yaml --phase train
```
Test (DDP):
```bash
CUDA_VISIBLE_DEVICES=0,1 uv run python -m torch.distributed.launch \
--nproc_per_node=2 opengait/main.py \
--cfgs ./configs/baseline/baseline.yaml --phase test
```
Single-GPU eval example:
```bash
CUDA_VISIBLE_DEVICES=0 uv run python -m torch.distributed.launch \
--nproc_per_node=1 opengait/main.py \
--cfgs ./configs/sconet/sconet_scoliosis1k_local_eval_1gpu.yaml --phase test
```
Demo CLI entry:
```bash
uv run python -m opengait.demo --help
```
## DDP Constraints (Important)
- `--nproc_per_node` must match visible GPU count in `CUDA_VISIBLE_DEVICES`.
- Test/evaluator sampling settings are strict and can fail if world size mismatches config.
- If interrupted DDP leaves stale processes:
```bash
sh misc/clean_process.sh
```
## Test Commands (especially single test)
Run all tests:
```bash
uv run pytest tests
```
Run one file:
```bash
uv run pytest tests/demo/test_pipeline.py -v
```
Run one test function:
```bash
uv run pytest tests/demo/test_pipeline.py::test_resolve_stride_modes -v
```
Run by keyword:
```bash
uv run pytest tests/demo/test_window.py -k "stride" -v
```
## Lint / Typecheck
Typecheck with basedpyright:
```bash
uv run basedpyright opengait tests
```
Project currently has no enforced formatter config in root tool files.
Follow existing local formatting and keep edits minimal.
## High-Value Paths
- `opengait/main.py` — runtime bootstrap
- `opengait/modeling/base_model.py` — model lifecycle contract
- `opengait/modeling/models/` — model zoo implementations
- `opengait/data/dataset.py` — dataset loading rules
- `opengait/data/collate_fn.py` — frame sampling behavior
- `opengait/evaluation/evaluator.py` — evaluation dispatch
- `configs/` — experiment definitions
- `datasets/` — preprocessing and partitions
## Code Style Guidelines
### Imports
- Keep ordering consistent: stdlib, third-party, local.
- Prefer explicit imports; avoid wildcard imports.
- Avoid introducing heavy imports in hot paths unless needed.
### Formatting
- Match surrounding file style (spacing, wrapping, structure).
- Avoid unrelated formatting churn.
- Keep diffs surgical.
### Types
- Add type annotations for new public APIs and non-trivial helpers.
- Reuse established typing style: `typing`, `numpy.typing`, `jaxtyping` where already used.
- Do not suppress type safety with blanket casts; keep unavoidable casts narrow.
### Naming
- `snake_case` for functions/variables
- `PascalCase` for classes
- `UPPER_SNAKE_CASE` for constants
- Preserve existing config key names and schema conventions
### Error Handling
- Raise explicit, actionable errors on invalid inputs.
- Fail fast for missing files, bad args, invalid shapes, and runtime preconditions.
- Never swallow exceptions silently.
- Preserve CLI error semantics (clear messages, non-zero exits).
### Logging
- Use module-level logger pattern already in codebase.
- Keep logs concise and operational.
- Avoid excessive per-frame logging in realtime/demo loops.
## Model and Config Contracts
- New models should conform to `BaseModel` expectations.
- Respect forward output dictionary contract used by loss/evaluator pipeline.
- Keep model registration/discovery patterns consistent with current package layout.
- Respect sampler semantics from config (`fixed_unordered`, `all_ordered`, etc.).
## Data Contracts
- Runtime data expects preprocessed `.pkl` sequence files.
- Partition JSON files are required for train/test split behavior.
- Do not mix modalities accidentally (silhouette / pose / pointcloud) across pipelines.
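For illustration only, a hedged sketch of loading one sequence file, assuming each sequence `.pkl` holds a `[T, H, W]` silhouette array (confirm the exact on-disk layout against `datasets/pretreatment.py`; the helper name here is illustrative):

```python
import pickle

import numpy as np


def load_silhouette_seq(pkl_path: str) -> np.ndarray:
    """Load one preprocessed sequence and validate the [T, H, W] contract."""
    with open(pkl_path, "rb") as f:
        seq = pickle.load(f)
    arr = np.asarray(seq)
    if arr.ndim != 3:
        raise ValueError(f"expected [T, H, W] silhouettes, got shape {arr.shape}")
    return arr
```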
## Research-Verification Policy
When answering methodology questions:
- Prefer primary sources (paper PDF, official project docs, official code tied to publication).
- Quote/cite paper statements when concluding method behavior.
- If local implementation differs from paper, state divergence explicitly.
- For this repo specifically, remember: `opengait/demo` may differ from paper intent.
## Cursor / Copilot Rules Check
Checked these paths:
- `.cursor/rules/`
- `.cursorrules`
- `.github/copilot-instructions.md`
Current status: no Cursor/Copilot instruction files found.
## Agent Checklist Before Finishing
- Commands executed with `uv run ...` where applicable
- Targeted tests for changed files pass
- Typecheck is clean for modified code
- Behavior/documentation updated together for user-facing changes
- Paper-vs-implementation claims clearly separated when relevant
@@ -0,0 +1,101 @@
data_cfg:
  dataset_name: Scoliosis1K
  dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl
  dataset_partition: /mnt/public/data/Scoliosis1K/Scoliosis1K_1116.json
  num_workers: 1
  remove_no_gallery: false
  test_dataset_name: Scoliosis1K
evaluator_cfg:
  enable_float16: true
  restore_ckpt_strict: true
  restore_hint: ./ckpt/ScoNet-20000.pt
  save_name: ScoNet
  eval_func: evaluate_scoliosis
  sampler:
    batch_shuffle: false
    batch_size: 1
    sample_type: all_ordered
    frames_all_limit: 720
  metric: euc
  transform:
    - type: BaseSilCuttingTransform
loss_cfg:
  - loss_term_weight: 1.0
    margin: 0.2
    type: TripletLoss
    log_prefix: triplet
  - loss_term_weight: 1.0
    scale: 16
    type: CrossEntropyLoss
    log_prefix: softmax
    log_accuracy: true
model_cfg:
  model: ScoNet
  backbone_cfg:
    type: ResNet9
    block: BasicBlock
    channels:
      - 64
      - 128
      - 256
      - 512
    layers:
      - 1
      - 1
      - 1
      - 1
    strides:
      - 1
      - 2
      - 2
      - 1
    maxpool: false
  SeparateFCs:
    in_channels: 512
    out_channels: 256
    parts_num: 16
  SeparateBNNecks:
    class_num: 3
    in_channels: 256
    parts_num: 16
  bin_num:
    - 16
optimizer_cfg:
  lr: 0.1
  momentum: 0.9
  solver: SGD
  weight_decay: 0.0005
scheduler_cfg:
  gamma: 0.1
  milestones:
    - 10000
    - 14000
    - 18000
  scheduler: MultiStepLR
trainer_cfg:
  enable_float16: true
  fix_BN: false
  with_test: false
  log_iter: 100
  restore_ckpt_strict: true
  restore_hint: 0
  save_iter: 20000
  save_name: ScoNet
  sync_BN: true
  total_iter: 20000
  sampler:
    batch_shuffle: true
    batch_size:
      - 8
      - 8
    frames_num_fixed: 30
    sample_type: fixed_unordered
    type: TripletSampler
  transform:
    - type: BaseSilCuttingTransform
+144
@@ -0,0 +1,144 @@
# Demo Window, Stride, and Sequence Behavior (ScoNet)
This note explains how the `opengait/demo` runtime feeds silhouettes into the neural network, what `stride` means, and when the sliding window is reset.
## Why sequence input (not single frame)
ScoNet-style inference is sequence-based.
- ScoNet / ScoNet-MT paper: https://arxiv.org/html/2407.05726v3
- DRF follow-up paper: https://arxiv.org/html/2509.00872v1
Both works use temporal information across walking frames rather than a single independent image.
### Direct quotes from the papers
From ScoNet / ScoNet-MT (MICCAI 2024, `2407.05726v3`):
> "For experiments, **30 frames were selected from each gait sequence as input**."
> (Section 4.1, Implementation Details)
From the same paper's dataset description:
> "Each sequence, containing approximately **300 frames at 15 frames per second**..."
> (Section 2.2, Data Collection and Preprocessing)
From DRF (MICCAI 2025, `2509.00872v1`):
DRF follows ScoNet-MT's sequence-level setup/architecture in its implementation details, and its PAV branch also aggregates across frames:
> "Sequence-Level PAV Refinement ... (2) **Temporal Aggregation**: For each metric, the mean of valid measurements across **all frames** is computed..."
> (Section 3.1, PAV: Discrete Clinical Prior)
## What papers say (and do not say) about stride
The papers define sequence-based inputs and temporal aggregation, but they do **not** define a deployment/runtime `stride` knob for online inference windows.
In other words:
- Paper gives the sequence framing (e.g., 30-frame inputs in ScoNet experiments).
- Demo `stride` is an engineering control for how often to run inference in streaming mode.
## What the demo feeds into the network
In `opengait/demo`, each inference uses the current silhouette buffer from `SilhouetteWindow`:
- Per-frame silhouette shape: `64 x 44`
- Tensor shape for inference: `[1, 1, window_size, 64, 44]`
- Default `window_size`: `30`
So by default, one prediction uses **30 silhouettes**.
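As a sanity check on the shapes above, here is a minimal sketch (the helper name is illustrative, not the demo's actual API) of stacking a full buffer into the `[1, 1, window_size, 64, 44]` layout:

```python
import numpy as np


def to_inference_batch(buffer: list[np.ndarray]) -> np.ndarray:
    """Stack T silhouettes of shape [64, 44] into a [1, 1, T, 64, 44] batch."""
    seq = np.stack(buffer, axis=0)           # [T, 64, 44]
    return seq[np.newaxis, np.newaxis, ...]  # [1, 1, T, 64, 44]


window = [np.zeros((64, 44), dtype=np.float32) for _ in range(30)]
batch = to_inference_batch(window)
assert batch.shape == (1, 1, 30, 64, 44)
```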
## What is stride?
`stride` means the minimum frame distance between two consecutive classifications **after** the window is already full.
In this demo, the window is a true sliding buffer. It is **not** cleared after each inference. After inference, the pipeline only records the last classified frame and continues buffering new silhouettes.
- If `stride = 1`: classify at every new frame once ready
- If `stride = 30` (default): classify every 30 frames once ready
## Window mode shortcut (`--window-mode`)
To make window scheduling explicit, the demo CLI supports:
- `--window-mode manual` (default): use the exact `--stride` value
- `--window-mode sliding`: force `stride = 1` (max overlap)
- `--window-mode chunked`: force `stride = window` (no overlap)
This is only a shortcut for runtime behavior. It does not change ScoNet weights or architecture.
Examples:
- Sliding windows: `--window 30 --window-mode sliding` -> windows like `0-29, 1-30, 2-31, ...`
- Chunked windows: `--window 30 --window-mode chunked` -> windows like `0-29, 30-59, 60-89, ...`
- Manual stride: `--window 30 --stride 10 --window-mode manual` -> windows every 10 frames
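The mode-to-stride mapping is simple enough to state as code; the demo's pipeline module exposes a `resolve_stride` helper with this behavior:

```python
from typing import Literal

WindowMode = Literal["manual", "sliding", "chunked"]


def resolve_stride(window: int, stride: int, window_mode: WindowMode) -> int:
    """Map a window mode to the effective stride used by the scheduler."""
    if window_mode == "manual":
        return stride
    if window_mode == "sliding":
        return 1
    return window  # chunked: stride equals the window size


assert resolve_stride(30, 10, "manual") == 10
assert resolve_stride(30, 10, "sliding") == 1
assert resolve_stride(30, 10, "chunked") == 30
```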
Time interval between predictions is approximately:
`prediction_interval_seconds ~= stride / fps`
If `--target-fps` is set, use the emitted (downsampled) fps in this formula.
Examples:
- `stride=30`, `fps=15` -> about `2.0s`
- `stride=15`, `fps=30` -> about `0.5s`
First prediction latency is approximately:
`first_prediction_latency_seconds ~= window_size / fps`
assuming detections are continuous.
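A quick worked check of both formulas:

```python
def prediction_interval_s(stride: int, fps: float) -> float:
    """Approximate seconds between consecutive predictions."""
    return stride / fps


def first_prediction_latency_s(window_size: int, fps: float) -> float:
    """Approximate seconds until the first prediction (continuous detections)."""
    return window_size / fps


assert prediction_interval_s(30, 15) == 2.0   # stride=30 at 15 fps
assert prediction_interval_s(15, 30) == 0.5   # stride=15 at 30 fps
assert first_prediction_latency_s(30, 15) == 2.0
```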
## Does the window clear when tracking target switches?
Yes. The window is reset in either case:
1. **Track ID changed** (new tracking target)
2. **Frame gap too large** (`frame_idx - last_frame > gap_threshold`)
Default `gap_threshold` in demo is `15` frames.
This prevents silhouettes from different people or long interrupted segments from being mixed into one inference window.
To be explicit:
- **Inference finished** -> window stays (sliding continues)
- **Track ID changed** -> window reset
- **Frame gap > gap_threshold** -> window reset
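The two reset conditions can be sketched as one predicate (names are illustrative; `gap_threshold` defaults to 15 as noted above):

```python
def needs_reset(
    track_id: int,
    last_track_id: int | None,
    frame_idx: int,
    last_frame: int | None,
    gap_threshold: int = 15,
) -> bool:
    """Reset the window on a track switch or a too-large detection gap."""
    if last_track_id is not None and track_id != last_track_id:
        return True  # tracking target switched
    if last_frame is not None and frame_idx - last_frame > gap_threshold:
        return True  # detection gap too large
    return False


assert needs_reset(2, 1, 100, 99)       # track switch
assert needs_reset(1, 1, 120, 100)      # gap of 20 > 15
assert not needs_reset(1, 1, 110, 100)  # gap of 10 is fine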
## Practical note about real-time detections
The window fills only when a valid silhouette is produced (i.e., person detection/segmentation succeeds). If detections are intermittent, the real-world time covered by one `window_size` can be longer than `window_size / fps`.
## Online vs offline behavior (important)
ScoNet's neural network does not hard-code a fixed frame count in the model graph. In OpenGait, frame count is controlled by sampling/runtime policy:
- Training config typically uses `frames_num_fixed: 30` with random fixed-frame sampling.
- Offline evaluation often uses `all_ordered` sequences (with `frames_all_limit` as a memory guard).
- Online demo uses the runtime window/stride scheduler.
So this does not mean the method only works offline. It means online performance depends on the latency/robustness trade-off you choose:
- Smaller windows / larger stride -> lower latency, potentially less stable predictions
- Larger windows / overlap -> smoother predictions, higher compute/latency
If you want behavior closest to ScoNet training assumptions, start from `--window 30` and tune stride (or `--window-mode`) for your deployment latency budget.
## Temporal downsampling (`--target-fps`)
Use `--target-fps` to normalize incoming frame cadence before silhouettes are pushed into the classification window.
- Default (`--target-fps 15`): timestamp-based pacing emits frames at approximately 15 FPS into the window
- Optional override (`--no-target-fps`): disable temporal downsampling and use all frames
Current default is `--target-fps 15` to align runtime cadence with ScoNet training assumptions.
For offline video sources, pacing uses video-time timestamps (`CAP_PROP_POS_MSEC`) when available, with an FPS-based synthetic timestamp fallback. This avoids coupling downsampling to processing throughput.
This is useful when camera FPS differs from training cadence. For example, with a 24 FPS camera:
- `--target-fps 15 --window 30` keeps model input near ~2.0 seconds of gait context (close to paper setup)
- `--stride` is interpreted in emitted-frame units after pacing
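The timestamp-based pacing described above can be sketched roughly like this (a simplified stand-in for the demo's internal pacer; the class name is illustrative):

```python
class FramePacer:
    """Emit frames at approximately target_fps based on source timestamps."""

    def __init__(self, target_fps: float) -> None:
        if target_fps <= 0:
            raise ValueError("target_fps must be positive")
        self._interval_ns = int(1_000_000_000 / target_fps)
        self._next_emit_ns: int | None = None

    def should_emit(self, timestamp_ns: int) -> bool:
        if self._next_emit_ns is None:
            self._next_emit_ns = timestamp_ns + self._interval_ns
            return True
        if timestamp_ns >= self._next_emit_ns:
            # Catch up past any missed slots instead of bursting later.
            while self._next_emit_ns <= timestamp_ns:
                self._next_emit_ns += self._interval_ns
            return True
        return False


# A 30 FPS source paced to 15 FPS emits every other frame.
pacer = FramePacer(15.0)
frame_interval = 1_000_000_000 // 30
emitted = [pacer.should_emit(i * frame_interval) for i in range(6)]
assert emitted == [True, False, True, False, True, False]
```

Because pacing keys off timestamps rather than frame counts, it behaves the same for live cameras and for file sources with synthetic video-time timestamps.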
+53 -22
@@ -3,8 +3,16 @@ from __future__ import annotations
 import argparse
 import logging
 import sys
+from typing import cast

-from .pipeline import ScoliosisPipeline
+from .pipeline import ScoliosisPipeline, WindowMode, resolve_stride
+
+
+def _positive_float(value: str) -> float:
+    parsed = float(value)
+    if parsed <= 0:
+        raise argparse.ArgumentTypeError("target-fps must be positive")
+    return parsed

 if __name__ == "__main__":
@@ -29,6 +37,24 @@ if __name__ == "__main__":
         "--window", type=int, default=30, help="Window size for classification"
     )
     parser.add_argument("--stride", type=int, default=30, help="Stride for window")
+    parser.add_argument(
+        "--target-fps",
+        type=_positive_float,
+        default=15.0,
+        help="Target FPS for temporal downsampling before windowing",
+    )
+    parser.add_argument(
+        "--window-mode",
+        type=str,
+        choices=["manual", "sliding", "chunked"],
+        default="manual",
+        help="Window scheduling mode: manual uses --stride; sliding uses stride=1; chunked uses stride=window",
+    )
+    parser.add_argument(
+        "--no-target-fps",
+        action="store_true",
+        help="Disable temporal downsampling and use all frames",
+    )
     parser.add_argument(
         "--nats-url", type=str, default=None, help="NATS URL for result publishing"
     )
@@ -88,27 +114,32 @@ if __name__ == "__main__":
             source=args.source, checkpoint=args.checkpoint, config=args.config
         )

-        # Build kwargs based on what ScoliosisPipeline accepts
-        pipeline_kwargs = {
-            "source": args.source,
-            "checkpoint": args.checkpoint,
-            "config": args.config,
-            "device": args.device,
-            "yolo_model": args.yolo_model,
-            "window": args.window,
-            "stride": args.stride,
-            "nats_url": args.nats_url,
-            "nats_subject": args.nats_subject,
-            "max_frames": args.max_frames,
-            "preprocess_only": args.preprocess_only,
-            "silhouette_export_path": args.silhouette_export_path,
-            "silhouette_export_format": args.silhouette_export_format,
-            "silhouette_visualize_dir": args.silhouette_visualize_dir,
-            "result_export_path": args.result_export_path,
-            "result_export_format": args.result_export_format,
-            "visualize": args.visualize,
-        }
-        pipeline = ScoliosisPipeline(**pipeline_kwargs)
+        effective_stride = resolve_stride(
+            window=cast(int, args.window),
+            stride=cast(int, args.stride),
+            window_mode=cast(WindowMode, args.window_mode),
+        )
+
+        pipeline = ScoliosisPipeline(
+            source=cast(str, args.source),
+            checkpoint=cast(str, args.checkpoint),
+            config=cast(str, args.config),
+            device=cast(str, args.device),
+            yolo_model=cast(str, args.yolo_model),
+            window=cast(int, args.window),
+            stride=effective_stride,
+            target_fps=(None if args.no_target_fps else cast(float, args.target_fps)),
+            nats_url=cast(str | None, args.nats_url),
+            nats_subject=cast(str, args.nats_subject),
+            max_frames=cast(int | None, args.max_frames),
+            preprocess_only=cast(bool, args.preprocess_only),
+            silhouette_export_path=cast(str | None, args.silhouette_export_path),
+            silhouette_export_format=cast(str, args.silhouette_export_format),
+            silhouette_visualize_dir=cast(str | None, args.silhouette_visualize_dir),
+            result_export_path=cast(str | None, args.result_export_path),
+            result_export_format=cast(str, args.result_export_format),
+            visualize=cast(bool, args.visualize),
+        )
         raise SystemExit(pipeline.run())
     except ValueError as err:
         print(f"Error: {err}", file=sys.stderr)
+18 -3
@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
 # Type alias for frame stream: (frame_array, metadata_dict)
 FrameStream = Iterable[tuple[np.ndarray, dict[str, object]]]
+

 # Protocol for cv-mmap metadata (needed at runtime for nested function annotation)
 class _FrameMetadata(Protocol):
     frame_count: int
@@ -58,6 +59,13 @@ def opencv_source(
     if not cap.isOpened():
         raise RuntimeError(f"Failed to open video source: {path}")

+    is_file_source = isinstance(path, str)
+    source_fps = float(cap.get(cv2.CAP_PROP_FPS)) if is_file_source else 0.0
+    fps_valid = source_fps > 0.0 and np.isfinite(source_fps)
+    fallback_fps = source_fps if fps_valid else 30.0
+    fallback_interval_ns = int(1_000_000_000 / fallback_fps)
+    start_ns = time.monotonic_ns()
+
     frame_idx = 0
     try:
         while max_frames is None or frame_idx < max_frames:
@@ -66,14 +74,22 @@ def opencv_source(
                 # End of stream
                 break

-            # Get timestamp if available (some backends support this)
-            timestamp_ns = time.monotonic_ns()
+            if is_file_source:
+                pos_msec = float(cap.get(cv2.CAP_PROP_POS_MSEC))
+                if np.isfinite(pos_msec) and pos_msec > 0.0:
+                    timestamp_ns = start_ns + int(pos_msec * 1_000_000)
+                else:
+                    timestamp_ns = start_ns + frame_idx * fallback_interval_ns
+            else:
+                timestamp_ns = time.monotonic_ns()

             metadata: dict[str, object] = {
                 "frame_count": frame_idx,
                 "timestamp_ns": timestamp_ns,
                 "source": path,
             }
+            if fps_valid:
+                metadata["source_fps"] = source_fps
             yield frame, metadata
             frame_idx += 1
@@ -118,7 +134,6 @@ def cvmmap_source(
     # Import cvmmap only when function is called
     # Use try/except for runtime import check
-
     try:
         from cvmmap import CvMmapClient as _CvMmapClientReal  # pyright: ignore[reportMissingTypeStubs]
     except ImportError as e:
         raise ImportError(
+105 -5
@@ -5,7 +5,7 @@ from contextlib import suppress
import logging import logging
from pathlib import Path from pathlib import Path
import time import time
from typing import TYPE_CHECKING, Protocol, cast from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, cast
from beartype import beartype from beartype import beartype
import click import click
@@ -31,6 +31,16 @@ JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator] JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped) jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
WindowMode: TypeAlias = Literal["manual", "sliding", "chunked"]
def resolve_stride(window: int, stride: int, window_mode: WindowMode) -> int:
if window_mode == "manual":
return stride
if window_mode == "sliding":
return 1
return window
class _BoxesLike(Protocol): class _BoxesLike(Protocol):
@property @property
@@ -65,6 +75,27 @@ class _TrackCallable(Protocol):
) -> object: ... ) -> object: ...
class _FramePacer:
_interval_ns: int
_next_emit_ns: int | None
def __init__(self, target_fps: float) -> None:
if target_fps <= 0:
raise ValueError(f"target_fps must be positive, got {target_fps}")
self._interval_ns = int(1_000_000_000 / target_fps)
self._next_emit_ns = None
def should_emit(self, timestamp_ns: int) -> bool:
if self._next_emit_ns is None:
self._next_emit_ns = timestamp_ns + self._interval_ns
return True
if timestamp_ns >= self._next_emit_ns:
while self._next_emit_ns <= timestamp_ns:
self._next_emit_ns += self._interval_ns
return True
return False
class ScoliosisPipeline: class ScoliosisPipeline:
_detector: object _detector: object
_source: FrameStream _source: FrameStream
@@ -83,6 +114,7 @@ class ScoliosisPipeline:
_result_buffer: list[DemoResult] _result_buffer: list[DemoResult]
_visualizer: OpenCVVisualizer | None _visualizer: OpenCVVisualizer | None
_last_viz_payload: dict[str, object] | None _last_viz_payload: dict[str, object] | None
_frame_pacer: _FramePacer | None
def __init__( def __init__(
self, self,
@@ -104,6 +136,7 @@ class ScoliosisPipeline:
result_export_path: str | None = None, result_export_path: str | None = None,
result_export_format: str = "json", result_export_format: str = "json",
visualize: bool = False, visualize: bool = False,
target_fps: float | None = 15.0,
) -> None: ) -> None:
self._detector = YOLO(yolo_model) self._detector = YOLO(yolo_model)
self._source = create_source(source, max_frames=max_frames) self._source = create_source(source, max_frames=max_frames)
@@ -140,6 +173,7 @@ class ScoliosisPipeline:
else: else:
self._visualizer = None self._visualizer = None
self._last_viz_payload = None self._last_viz_payload = None
self._frame_pacer = _FramePacer(target_fps) if target_fps is not None else None
@staticmethod @staticmethod
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int: def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
@@ -177,6 +211,7 @@ class ScoliosisPipeline:
Float[ndarray, "64 44"], Float[ndarray, "64 44"],
UInt8[ndarray, "h w"], UInt8[ndarray, "h w"],
BBoxXYXY, BBoxXYXY,
BBoxXYXY,
int, int,
] ]
| None | None
@@ -189,7 +224,7 @@ class ScoliosisPipeline:
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask), mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask),
) )
if silhouette is not None: if silhouette is not None:
return silhouette, mask_raw, bbox_frame, int(track_id) return silhouette, mask_raw, bbox_frame, bbox_mask, int(track_id)
fallback = cast( fallback = cast(
tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None, tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
@@ -231,7 +266,7 @@ class ScoliosisPipeline:
# Fallback: use mask-space bbox if orig_shape unavailable # Fallback: use mask-space bbox if orig_shape unavailable
bbox_frame = bbox_mask bbox_frame = bbox_mask
# For fallback case, mask_raw is the same as mask_u8 # For fallback case, mask_raw is the same as mask_u8
return silhouette, mask_u8, bbox_frame, 0 return silhouette, mask_u8, bbox_frame, bbox_mask, 0
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
def process_frame( def process_frame(
@@ -262,7 +297,7 @@ class ScoliosisPipeline:
if selected is None: if selected is None:
return None return None
silhouette, mask_raw, bbox, track_id = selected silhouette, mask_raw, bbox, bbox_mask, track_id = selected
# Store silhouette for export if in preprocess-only mode or if export requested # Store silhouette for export if in preprocess-only mode or if export requested
if self._silhouette_export_path is not None or self._preprocess_only: if self._silhouette_export_path is not None or self._preprocess_only:
@@ -284,20 +319,39 @@ class ScoliosisPipeline:
return { return {
"mask_raw": mask_raw, "mask_raw": mask_raw,
"bbox": bbox, "bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette, "silhouette": silhouette,
"segmentation_input": None,
"track_id": track_id,
"label": None,
"confidence": None,
}
if self._frame_pacer is not None and not self._frame_pacer.should_emit(
timestamp_ns
):
return {
"mask_raw": mask_raw,
"bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette,
"segmentation_input": self._window.buffered_silhouettes,
"track_id": track_id, "track_id": track_id,
"label": None, "label": None,
"confidence": None, "confidence": None,
} }
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id) self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
segmentation_input = self._window.buffered_silhouettes
if not self._window.should_classify(): if not self._window.should_classify():
# Return visualization payload even when not classifying yet # Return visualization payload even when not classifying yet
return { return {
"mask_raw": mask_raw, "mask_raw": mask_raw,
"bbox": bbox, "bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette, "silhouette": silhouette,
"segmentation_input": segmentation_input,
"track_id": track_id, "track_id": track_id,
"label": None, "label": None,
"confidence": None, "confidence": None,
@@ -330,7 +384,9 @@ class ScoliosisPipeline:
"result": result, "result": result,
"mask_raw": mask_raw, "mask_raw": mask_raw,
"bbox": bbox, "bbox": bbox,
"bbox_mask": bbox_mask,
"silhouette": silhouette, "silhouette": silhouette,
"segmentation_input": segmentation_input,
"track_id": track_id, "track_id": track_id,
"label": label, "label": label,
"confidence": confidence, "confidence": confidence,
@@ -400,7 +456,9 @@ class ScoliosisPipeline:
viz_dict = cast(dict[str, object], viz_data) viz_dict = cast(dict[str, object], viz_data)
mask_raw_obj = viz_dict.get("mask_raw") mask_raw_obj = viz_dict.get("mask_raw")
bbox_obj = viz_dict.get("bbox") bbox_obj = viz_dict.get("bbox")
bbox_mask_obj = viz_dict.get("bbox_mask")
silhouette_obj = viz_dict.get("silhouette") silhouette_obj = viz_dict.get("silhouette")
segmentation_input_obj = viz_dict.get("segmentation_input")
track_id_val = viz_dict.get("track_id", 0) track_id_val = viz_dict.get("track_id", 0)
track_id = track_id_val if isinstance(track_id_val, int) else 0 track_id = track_id_val if isinstance(track_id_val, int) else 0
label_obj = viz_dict.get("label") label_obj = viz_dict.get("label")
@@ -409,24 +467,33 @@ class ScoliosisPipeline:
            # Cast extracted values to expected types
            mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
            bbox = cast(BBoxXYXY | None, bbox_obj)
+           bbox_mask = cast(BBoxXYXY | None, bbox_mask_obj)
            silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
+           segmentation_input = cast(
+               NDArray[np.float32] | None,
+               segmentation_input_obj,
+           )
            label = cast(str | None, label_obj)
            confidence = cast(float | None, confidence_obj)
        else:
            # No detection and no cache - use default values
            mask_raw = None
            bbox = None
+           bbox_mask = None
            track_id = 0
            silhouette = None
+           segmentation_input = None
            label = None
            confidence = None
        keep_running = self._visualizer.update(
            frame_u8,
            bbox,
+           bbox_mask,
            track_id,
            mask_raw,
            silhouette,
+           segmentation_input,
            label,
            confidence,
            ema_fps,
@@ -671,6 +738,23 @@ def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
+@click.option(
+    "--window-mode",
+    type=click.Choice(["manual", "sliding", "chunked"], case_sensitive=False),
+    default="manual",
+    show_default=True,
+    help=(
+        "Window scheduling mode: manual uses --stride; "
+        "sliding forces stride=1; chunked forces stride=window"
+    ),
+)
+@click.option(
+    "--target-fps",
+    type=click.FloatRange(min=0.1),
+    default=15.0,
+    show_default=True,
+)
+@click.option("--no-target-fps", is_flag=True, default=False)
@click.option("--nats-url", type=str, default=None)
@click.option(
    "--nats-subject",
@@ -725,6 +809,9 @@ def main(
    yolo_model: str,
    window: int,
    stride: int,
+   window_mode: str,
+   target_fps: float | None,
+   no_target_fps: bool,
    nats_url: str | None,
    nats_subject: str,
    max_frames: int | None,
@@ -748,6 +835,18 @@ def main(
    try:
        validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
+       effective_stride = resolve_stride(
+           window=window,
+           stride=stride,
+           window_mode=cast(WindowMode, window_mode.lower()),
+       )
+       if effective_stride != stride:
+           logger.info(
+               "window_mode=%s overrides stride=%d -> effective_stride=%d",
+               window_mode,
+               stride,
+               effective_stride,
+           )
        pipeline = ScoliosisPipeline(
            source=source,
            checkpoint=checkpoint,
@@ -755,7 +854,8 @@ def main(
            device=device,
            yolo_model=yolo_model,
            window=window,
-           stride=stride,
+           stride=effective_stride,
+           target_fps=None if no_target_fps else target_fps,
            nats_url=nats_url,
            nats_subject=nats_subject,
            max_frames=max_frames,
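The hunk above calls `resolve_stride`, whose definition is not part of this diff. A minimal sketch of what it plausibly does, inferred only from the `--window-mode` help text ("manual uses --stride; sliding forces stride=1; chunked forces stride=window") — the signature and `WindowMode` alias are assumptions from the call site, not the repository's actual code:

```python
from typing import Literal

WindowMode = Literal["manual", "sliding", "chunked"]


def resolve_stride(*, window: int, stride: int, window_mode: WindowMode) -> int:
    """Map a window-scheduling mode to an effective stride.

    manual  -> honor the user-provided --stride
    sliding -> classify on every new frame once the window is full
    chunked -> non-overlapping windows (stride == window)
    """
    if window_mode == "sliding":
        return 1
    if window_mode == "chunked":
        return window
    return stride
```

Resolving the stride once at startup (rather than branching on the mode inside the pipeline loop) keeps `ScoliosisPipeline` unaware of scheduling modes, which matches the `stride=effective_stride` call site above.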
+3 -1
@@ -23,8 +23,10 @@ jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
UInt8Array = NDArray[np.uint8]
Float32Array = NDArray[np.float32]

-#: Bounding box in XYXY format: (x1, y1, x2, y2) where (x1,y1) is top-left and (x2,y2) is bottom-right.
BBoxXYXY = tuple[int, int, int, int]
+"""
+Bounding box in XYXY format: (x1, y1, x2, y2) where (x1,y1) is top-left and (x2,y2) is bottom-right.
+"""


def _read_attr(container: object, key: str) -> object | None:
+246 -122
@@ -20,7 +20,9 @@ logger = logging.getLogger(__name__)
# Window names
MAIN_WINDOW = "Scoliosis Detection"
-SEG_WINDOW = "Segmentation"
+SEG_WINDOW = "Normalized Silhouette"
+RAW_WINDOW = "Raw Mask"
+WINDOW_SEG_INPUT = "Segmentation Input"

# Silhouette dimensions (from preprocess.py)
SIL_HEIGHT = 64
@@ -29,43 +31,45 @@ SIL_WIDTH = 44

# Display dimensions for upscaled silhouette
DISPLAY_HEIGHT = 256
DISPLAY_WIDTH = 176
+RAW_STATS_PAD = 54
+MODE_LABEL_PAD = 26

# Colors (BGR)
COLOR_GREEN = (0, 255, 0)
COLOR_WHITE = (255, 255, 255)
COLOR_BLACK = (0, 0, 0)
+COLOR_DARK_GRAY = (56, 56, 56)
COLOR_RED = (0, 0, 255)
COLOR_YELLOW = (0, 255, 255)

-# Mode labels
-MODE_LABELS = ["Both", "Raw Mask", "Normalized"]

# Type alias for image arrays (NDArray or cv2.Mat)
ImageArray = NDArray[np.uint8]


class OpenCVVisualizer:
-    """Real-time visualizer for gait analysis demo.
-
-    Displays two windows:
-    - Main stream: Original frame with bounding box and metadata overlay
-    - Segmentation: Raw mask, normalized silhouette, or side-by-side view
-
-    Supports interactive mode switching via keyboard.
-    """

    def __init__(self) -> None:
-        """Initialize visualizer with default mask mode."""
-        self.mask_mode: int = 0  # 0: Both, 1: Raw, 2: Normalized
+        self.show_raw_window: bool = False
+        self.show_raw_debug: bool = False
        self._windows_created: bool = False
+        self._raw_window_created: bool = False

    def _ensure_windows(self) -> None:
-        """Create OpenCV windows if not already created."""
        if not self._windows_created:
            cv2.namedWindow(MAIN_WINDOW, cv2.WINDOW_NORMAL)
            cv2.namedWindow(SEG_WINDOW, cv2.WINDOW_NORMAL)
+            cv2.namedWindow(WINDOW_SEG_INPUT, cv2.WINDOW_NORMAL)
            self._windows_created = True

+    def _ensure_raw_window(self) -> None:
+        if not self._raw_window_created:
+            cv2.namedWindow(RAW_WINDOW, cv2.WINDOW_NORMAL)
+            self._raw_window_created = True
+
+    def _hide_raw_window(self) -> None:
+        if self._raw_window_created:
+            cv2.destroyWindow(RAW_WINDOW)
+            self._raw_window_created = False

    def _draw_bbox(
        self,
        frame: ImageArray,
@@ -215,33 +219,181 @@ class OpenCVVisualizer:
        return upscaled
+    def _normalize_mask_for_display(self, mask: NDArray[np.generic]) -> ImageArray:
+        mask_array = np.asarray(mask)
+        if mask_array.dtype == np.bool_:
+            bool_scaled = np.where(mask_array, np.uint8(255), np.uint8(0)).astype(np.uint8)
+            return cast(ImageArray, bool_scaled)
+        if mask_array.dtype == np.uint8:
+            mask_array = cast(ImageArray, mask_array)
+            max_u8 = int(np.max(mask_array)) if mask_array.size > 0 else 0
+            if max_u8 <= 1:
+                scaled_u8 = np.where(mask_array > 0, np.uint8(255), np.uint8(0)).astype(np.uint8)
+                return cast(ImageArray, scaled_u8)
+            return cast(ImageArray, mask_array)
+        if np.issubdtype(mask_array.dtype, np.integer):
+            max_int = float(np.max(mask_array)) if mask_array.size > 0 else 0.0
+            if max_int <= 1.0:
+                return cast(ImageArray, (mask_array.astype(np.float32) * 255.0).astype(np.uint8))
+            clipped = np.clip(mask_array, 0, 255).astype(np.uint8)
+            return cast(ImageArray, clipped)
+        mask_float = np.asarray(mask_array, dtype=np.float32)
+        max_val = float(np.max(mask_float)) if mask_float.size > 0 else 0.0
+        if max_val <= 0.0:
+            return np.zeros(mask_float.shape, dtype=np.uint8)
+        normalized = np.clip((mask_float / max_val) * 255.0, 0.0, 255.0).astype(np.uint8)
+        return cast(ImageArray, normalized)
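The method above dispatches on dtype so that bool, {0,1} label masks, full-range uint8, and float masks all render visibly under `cv2.imshow`. The same rules, condensed into a standalone function for clarity (the free-function name is illustrative, not the class method):

```python
import numpy as np


def normalize_mask_for_display(mask: np.ndarray) -> np.ndarray:
    """Scale any mask dtype to uint8 [0, 255] so cv2.imshow renders it visibly."""
    arr = np.asarray(mask)
    if arr.dtype == np.bool_:
        return np.where(arr, np.uint8(255), np.uint8(0))
    if np.issubdtype(arr.dtype, np.integer):
        peak = int(arr.max()) if arr.size else 0
        if peak <= 1:
            # {0, 1} label masks would render near-black; stretch to full range
            return np.where(arr > 0, np.uint8(255), np.uint8(0))
        return np.clip(arr, 0, 255).astype(np.uint8)
    f = arr.astype(np.float32)
    peak_f = float(f.max()) if f.size else 0.0
    if peak_f <= 0.0:
        return np.zeros(f.shape, dtype=np.uint8)
    return np.clip(f / peak_f * 255.0, 0.0, 255.0).astype(np.uint8)
```

Float masks are rescaled by their own peak rather than a fixed 1.0, so low-confidence probability maps still use the full display range.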
+    def _draw_raw_stats(self, image: ImageArray, mask_raw: ImageArray | None) -> None:
+        if mask_raw is None:
+            return
+        mask = np.asarray(mask_raw)
+        if mask.size == 0:
+            return
+        stats = [
+            f"raw: {mask.dtype}",
+            f"min/max: {float(mask.min()):.3f}/{float(mask.max()):.3f}",
+            f"nnz: {int(np.count_nonzero(mask))}",
+        ]
+        font = cv2.FONT_HERSHEY_SIMPLEX
+        font_scale = 0.45
+        thickness = 1
+        line_h = 18
+        x0 = 8
+        y0 = 20
+        for i, txt in enumerate(stats):
+            y = y0 + i * line_h
+            (tw, th), _ = cv2.getTextSize(txt, font, font_scale, thickness)
+            _ = cv2.rectangle(
+                image, (x0 - 4, y - th - 4), (x0 + tw + 4, y + 4), COLOR_BLACK, -1
+            )
+            _ = cv2.putText(
+                image, txt, (x0, y), font, font_scale, COLOR_YELLOW, thickness
+            )
    def _prepare_segmentation_view(
        self,
        mask_raw: ImageArray | None,
        silhouette: NDArray[np.float32] | None,
+       bbox: BBoxXYXY | None,
    ) -> ImageArray:
-        """Prepare segmentation window content based on current mode.
-
-        Args:
-            mask_raw: Raw binary mask (H, W) uint8 or None
-            silhouette: Normalized silhouette (64, 44) float32 or None
-
-        Returns:
-            Displayable image (H, W, 3) uint8
-        """
-        if self.mask_mode == 0:
-            # Mode 0: Both (side by side)
-            return self._prepare_both_view(mask_raw, silhouette)
-        elif self.mask_mode == 1:
-            # Mode 1: Raw mask only
-            return self._prepare_raw_view(mask_raw)
-        else:
-            # Mode 2: Normalized silhouette only
-            return self._prepare_normalized_view(silhouette)
+        _ = mask_raw
+        _ = bbox
+        return self._prepare_normalized_view(silhouette)
+
+    def _fit_gray_to_display(
+        self,
+        gray: ImageArray,
+        out_h: int = DISPLAY_HEIGHT,
+        out_w: int = DISPLAY_WIDTH,
+    ) -> ImageArray:
+        src_h, src_w = gray.shape[:2]
+        if src_h <= 0 or src_w <= 0:
+            return np.zeros((out_h, out_w), dtype=np.uint8)
+        scale = min(out_w / src_w, out_h / src_h)
+        new_w = max(1, int(round(src_w * scale)))
+        new_h = max(1, int(round(src_h * scale)))
+        resized = cast(
+            ImageArray,
+            cv2.resize(gray, (new_w, new_h), interpolation=cv2.INTER_NEAREST),
+        )
+        canvas = np.zeros((out_h, out_w), dtype=np.uint8)
+        x0 = (out_w - new_w) // 2
+        y0 = (out_h - new_h) // 2
+        canvas[y0 : y0 + new_h, x0 : x0 + new_w] = resized
+        return cast(ImageArray, canvas)
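`_fit_gray_to_display` is an aspect-preserving letterbox: one scale factor for both axes, then a centered paste onto a black canvas. The geometry can be checked without OpenCV (helper name is hypothetical, defaults mirror the 256x176 display constants):

```python
def letterbox_geometry(
    src_h: int, src_w: int, out_h: int = 256, out_w: int = 176
) -> tuple[int, int, int, int]:
    """Aspect-preserving fit: (new_h, new_w, y_offset, x_offset) on the canvas."""
    # One scale for both axes keeps the aspect ratio; min() guarantees it fits
    scale = min(out_w / src_w, out_h / src_h)
    new_w = max(1, round(src_w * scale))
    new_h = max(1, round(src_h * scale))
    return new_h, new_w, (out_h - new_h) // 2, (out_w - new_w) // 2
```

A 64x44 silhouette scales exactly 4x to fill the canvas, while a full 480x640 frame mask is shrunk to 132x176 and centered vertically, which is why the raw window needs the black padding bands.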
+    def _crop_mask_to_bbox(
+        self,
+        mask_gray: ImageArray,
+        bbox: BBoxXYXY | None,
+    ) -> ImageArray:
+        if bbox is None:
+            return mask_gray
+        h, w = mask_gray.shape[:2]
+        x1, y1, x2, y2 = bbox
+        x1c = max(0, min(w, int(x1)))
+        x2c = max(0, min(w, int(x2)))
+        y1c = max(0, min(h, int(y1)))
+        y2c = max(0, min(h, int(y2)))
+        if x2c <= x1c or y2c <= y1c:
+            return mask_gray
+        cropped = mask_gray[y1c:y2c, x1c:x2c]
+        if cropped.size == 0:
+            return mask_gray
+        return cast(ImageArray, cropped)
+    def _prepare_segmentation_input_view(
+        self,
+        silhouettes: NDArray[np.float32] | None,
+    ) -> ImageArray:
+        if silhouettes is None or silhouettes.size == 0:
+            placeholder = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH, 3), dtype=np.uint8)
+            self._draw_mode_indicator(placeholder, "Input Silhouettes (No Data)")
+            return placeholder
+        n_frames = int(silhouettes.shape[0])
+        tiles_per_row = int(np.ceil(np.sqrt(n_frames)))
+        rows = int(np.ceil(n_frames / tiles_per_row))
+        tile_h = DISPLAY_HEIGHT
+        tile_w = DISPLAY_WIDTH
+        grid = np.zeros((rows * tile_h, tiles_per_row * tile_w), dtype=np.uint8)
+        for idx in range(n_frames):
+            sil = silhouettes[idx]
+            tile = self._upscale_silhouette(sil)
+            r = idx // tiles_per_row
+            c = idx % tiles_per_row
+            y0, y1 = r * tile_h, (r + 1) * tile_h
+            x0, x1 = c * tile_w, (c + 1) * tile_w
+            grid[y0:y1, x0:x1] = tile
+        grid_bgr = cast(ImageArray, cv2.cvtColor(grid, cv2.COLOR_GRAY2BGR))
+        for idx in range(n_frames):
+            r = idx // tiles_per_row
+            c = idx % tiles_per_row
+            y0 = r * tile_h
+            x0 = c * tile_w
+            cv2.putText(
+                grid_bgr,
+                str(idx),
+                (x0 + 8, y0 + 22),
+                cv2.FONT_HERSHEY_SIMPLEX,
+                0.6,
+                (0, 255, 255),
+                2,
+                cv2.LINE_AA,
+            )
+        return grid_bgr
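The tiling above lays the buffered frames out in a near-square grid via `ceil(sqrt(n))` columns. A tiny helper to sanity-check that layout math (the helper name is illustrative):

```python
import math


def grid_shape(n_frames: int) -> tuple[int, int]:
    """(rows, cols) of the near-square tile grid used for the window view."""
    cols = math.ceil(math.sqrt(n_frames))  # columns first, so the grid is at most one row short
    rows = math.ceil(n_frames / cols)
    return rows, cols
```

With the default 30-frame window this gives a 5x6 grid; trailing cells beyond `n_frames` stay black because the canvas is zero-initialized.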
    def _prepare_raw_view(
        self,
        mask_raw: ImageArray | None,
+       bbox: BBoxXYXY | None = None,
    ) -> ImageArray:
        """Prepare raw mask view.
@@ -261,20 +413,23 @@ class OpenCVVisualizer:
        if len(mask_raw.shape) == 3:
            mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
        else:
-            mask_gray = mask_raw
-        # Resize to display size
-        mask_resized = cast(
-            ImageArray,
-            cv2.resize(
-                mask_gray,
-                (DISPLAY_WIDTH, DISPLAY_HEIGHT),
-                interpolation=cv2.INTER_NEAREST,
-            ),
-        )
+            mask_gray = cast(ImageArray, mask_raw)
+        mask_gray = self._normalize_mask_for_display(mask_gray)
+        mask_gray = self._crop_mask_to_bbox(mask_gray, bbox)
+
+        debug_pad = RAW_STATS_PAD if self.show_raw_debug else 0
+        content_h = max(1, DISPLAY_HEIGHT - debug_pad - MODE_LABEL_PAD)
+        mask_resized = self._fit_gray_to_display(
+            mask_gray, out_h=content_h, out_w=DISPLAY_WIDTH
+        )
+        full_mask = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
+        full_mask[debug_pad : debug_pad + content_h, :] = mask_resized
        # Convert to BGR for display
-        mask_bgr = cast(ImageArray, cv2.cvtColor(mask_resized, cv2.COLOR_GRAY2BGR))
+        mask_bgr = cast(ImageArray, cv2.cvtColor(full_mask, cv2.COLOR_GRAY2BGR))
+        if self.show_raw_debug:
+            self._draw_raw_stats(mask_bgr, mask_raw)
        self._draw_mode_indicator(mask_bgr, "Raw Mask")
        return mask_bgr
@@ -299,80 +454,21 @@ class OpenCVVisualizer:
        # Upscale and convert
        upscaled = self._upscale_silhouette(silhouette)
-        sil_bgr = cast(ImageArray, cv2.cvtColor(upscaled, cv2.COLOR_GRAY2BGR))
+        content_h = max(1, DISPLAY_HEIGHT - MODE_LABEL_PAD)
+        sil_compact = self._fit_gray_to_display(
+            upscaled, out_h=content_h, out_w=DISPLAY_WIDTH
+        )
+        sil_canvas = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
+        sil_canvas[:content_h, :] = sil_compact
+        sil_bgr = cast(ImageArray, cv2.cvtColor(sil_canvas, cv2.COLOR_GRAY2BGR))
        self._draw_mode_indicator(sil_bgr, "Normalized")
        return sil_bgr

-    def _prepare_both_view(
-        self,
-        mask_raw: ImageArray | None,
-        silhouette: NDArray[np.float32] | None,
-    ) -> ImageArray:
-        """Prepare side-by-side view of both masks.
-
-        Args:
-            mask_raw: Raw binary mask or None
-            silhouette: Normalized silhouette or None
-
-        Returns:
-            Displayable side-by-side image
-        """
-        # Prepare individual views without mode indicators (will be drawn on combined)
-        # Raw view preparation (without indicator)
-        if mask_raw is None:
-            raw_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
-        else:
-            if len(mask_raw.shape) == 3:
-                mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
-            else:
-                mask_gray = mask_raw
-            # Normalize to uint8 [0,255] for display (handles both float [0,1] and uint8 inputs)
-            if mask_gray.dtype == np.float32 or mask_gray.dtype == np.float64:
-                mask_gray = (mask_gray * 255).astype(np.uint8)
-            raw_gray = cast(
-                ImageArray,
-                cv2.resize(
-                    mask_gray,
-                    (DISPLAY_WIDTH, DISPLAY_HEIGHT),
-                    interpolation=cv2.INTER_NEAREST,
-                ),
-            )
-        # Normalized view preparation (without indicator)
-        if silhouette is None:
-            norm_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
-        else:
-            upscaled = self._upscale_silhouette(silhouette)
-            norm_gray = upscaled
-        # Stack horizontally
-        combined = np.hstack([raw_gray, norm_gray])
-        # Convert back to BGR
-        combined_bgr = cast(ImageArray, cv2.cvtColor(combined, cv2.COLOR_GRAY2BGR))
-        # Add mode indicator
-        self._draw_mode_indicator(combined_bgr, "Both: Raw | Normalized")
-        return combined_bgr
-
-    def _draw_mode_indicator(
-        self,
-        image: ImageArray,
-        label: str,
-    ) -> None:
-        """Draw mode indicator text on image.
-
-        Args:
-            image: Image to draw on (modified in place)
-            label: Mode label text
-        """
+    def _draw_mode_indicator(self, image: ImageArray, label: str) -> None:
        h, w = image.shape[:2]
-        # Mode text at bottom
-        mode_text = f"Mode: {MODE_LABELS[self.mask_mode]} ({self.mask_mode}) - {label}"
+        mode_text = label
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
@@ -383,15 +479,22 @@ class OpenCVVisualizer:
            mode_text, font, font_scale, thickness
        )
-        # Draw background at bottom center
-        x_pos = (w - text_width) // 2
-        y_pos = h - 10
+        x_pos = 14
+        y_pos = h - 8
+        y_top = max(0, h - MODE_LABEL_PAD)
        _ = cv2.rectangle(
            image,
-            (x_pos - 5, y_pos - text_height - 5),
-            (x_pos + text_width + 5, y_pos + 5),
-            COLOR_BLACK,
+            (0, y_top),
+            (w, h),
+            COLOR_DARK_GRAY,
+            -1,
+        )
+        _ = cv2.rectangle(
+            image,
+            (x_pos - 6, y_pos - text_height - 6),
+            (x_pos + text_width + 8, y_pos + 6),
+            COLOR_DARK_GRAY,
            -1,
        )
@@ -410,9 +513,11 @@ class OpenCVVisualizer:
        self,
        frame: ImageArray,
        bbox: BBoxXYXY | None,
+       bbox_mask: BBoxXYXY | None,
        track_id: int,
        mask_raw: ImageArray | None,
        silhouette: NDArray[np.float32] | None,
+       segmentation_input: NDArray[np.float32] | None,
        label: str | None,
        confidence: float | None,
        fps: float,
@@ -441,23 +546,42 @@ class OpenCVVisualizer:
        cv2.imshow(MAIN_WINDOW, main_display)

        # Prepare and show segmentation window
-        seg_display = self._prepare_segmentation_view(mask_raw, silhouette)
+        seg_display = self._prepare_segmentation_view(mask_raw, silhouette, bbox)
        cv2.imshow(SEG_WINDOW, seg_display)
+        if self.show_raw_window:
+            self._ensure_raw_window()
+            raw_display = self._prepare_raw_view(mask_raw, bbox_mask)
+            cv2.imshow(RAW_WINDOW, raw_display)
+        seg_input_display = self._prepare_segmentation_input_view(segmentation_input)
+        cv2.imshow(WINDOW_SEG_INPUT, seg_input_display)

        # Handle keyboard input
        key = cv2.waitKey(1) & 0xFF
        if key == ord("q"):
            return False
-        elif key == ord("m"):
-            # Cycle through modes: 0 -> 1 -> 2 -> 0
-            self.mask_mode = (self.mask_mode + 1) % 3
-            logger.debug("Switched to mask mode: %s", MODE_LABELS[self.mask_mode])
+        elif key == ord("r"):
+            self.show_raw_window = not self.show_raw_window
+            if self.show_raw_window:
+                self._ensure_raw_window()
+                logger.debug("Raw mask window enabled")
+            else:
+                self._hide_raw_window()
+                logger.debug("Raw mask window disabled")
+        elif key == ord("d"):
+            self.show_raw_debug = not self.show_raw_debug
+            logger.debug(
+                "Raw mask debug overlay %s",
+                "enabled" if self.show_raw_debug else "disabled",
+            )
        return True

    def close(self) -> None:
-        """Close all OpenCV windows and cleanup."""
        if self._windows_created:
+            self._hide_raw_window()
            cv2.destroyAllWindows()
            self._windows_created = False
+            self._raw_window_created = False
+9
@@ -216,6 +216,15 @@ class SilhouetteWindow:
            raise ValueError("Window is empty")
        return int(self._frame_indices[0])

+    @property
+    def buffered_silhouettes(self) -> Float[ndarray, "n 64 44"]:
+        if not self._buffer:
+            return np.empty((0, SIL_HEIGHT, SIL_WIDTH), dtype=np.float32)
+        return cast(
+            Float[ndarray, "n 64 44"],
+            np.stack(list(self._buffer), axis=0).astype(np.float32, copy=True),
+        )
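The new property snapshots the buffer as an `(n, 64, 44)` float32 array. The surrounding `SilhouetteWindow` class is not shown here; the following is a simplified sketch of the window/stride buffering its API implies — the class name, counter, and reset-on-classify behavior are all assumptions, and the real class additionally tracks frame indices and track ids:

```python
from collections import deque

import numpy as np


class WindowSketch:
    """Keep the last `window` silhouettes; classify every `stride` pushes once full."""

    def __init__(self, window: int = 30, stride: int = 30) -> None:
        self._buffer: deque[np.ndarray] = deque(maxlen=window)  # drops oldest when full
        self._stride = stride
        self._pushes_since = 0

    def push(self, sil: np.ndarray) -> None:
        self._buffer.append(sil)
        self._pushes_since += 1

    def should_classify(self) -> bool:
        if len(self._buffer) < (self._buffer.maxlen or 0):
            return False  # not enough frames yet
        if self._pushes_since >= self._stride:
            self._pushes_since = 0  # reset so the next window needs `stride` new frames
            return True
        return False

    @property
    def buffered_silhouettes(self) -> np.ndarray:
        if not self._buffer:
            return np.empty((0, 64, 44), dtype=np.float32)
        # Copy so callers can hold the snapshot while the deque keeps rotating
        return np.stack(list(self._buffer), axis=0).astype(np.float32, copy=True)
```

Copying in the property matters because the deque keeps mutating as new frames arrive while the visualizer still holds the previous snapshot.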
def _to_numpy(obj: _ArrayLike) -> ndarray:
    """Safely convert array-like object to numpy array.
+394
@@ -0,0 +1,394 @@
#!/usr/bin/env python3
"""
Export all positive labeled batches from Scoliosis1K dataset as time windows.
Creates grid visualizations similar to visualizer._prepare_segmentation_input_view()
for all positive class samples, arranged in sliding time windows.
Optimized UI with:
- Subject ID and batch info footer
- Dual frame counts (window-relative and sequence-relative)
- Clean layout with proper spacing
"""
from __future__ import annotations
import json
import pickle
from pathlib import Path
from typing import Final
import cv2
import numpy as np
from numpy.typing import NDArray
# Constants matching visualizer.py
DISPLAY_HEIGHT: Final = 256
DISPLAY_WIDTH: Final = 176
SIL_HEIGHT: Final = 64
SIL_WIDTH: Final = 44
# Footer settings
FOOTER_HEIGHT: Final = 80 # Height for metadata footer
FOOTER_BG_COLOR: Final = (40, 40, 40) # Dark gray background
TEXT_COLOR: Final = (255, 255, 255) # White text
ACCENT_COLOR: Final = (0, 255, 255) # Cyan for emphasis
FONT: Final = cv2.FONT_HERSHEY_SIMPLEX
FONT_SCALE: Final = 0.6
FONT_THICKNESS: Final = 2
def upscale_silhouette(
    silhouette: NDArray[np.float32] | NDArray[np.uint8],
) -> NDArray[np.uint8]:
    """Upscale silhouette to display size."""
    if silhouette.dtype == np.float32 or silhouette.dtype == np.float64:
        sil_u8 = (silhouette * 255).astype(np.uint8)
    else:
        sil_u8 = silhouette.astype(np.uint8)
    upscaled = cv2.resize(
        sil_u8, (DISPLAY_WIDTH, DISPLAY_HEIGHT), interpolation=cv2.INTER_NEAREST
    )
    return upscaled
def create_optimized_visualization(
silhouettes: NDArray[np.float32],
subject_id: str,
view_name: str,
window_idx: int,
start_frame: int,
end_frame: int,
n_frames_total: int,
tile_height: int = DISPLAY_HEIGHT,
tile_width: int = DISPLAY_WIDTH,
) -> NDArray[np.uint8]:
"""
Create optimized visualization with grid and metadata footer.
Args:
silhouettes: Array of shape (n_frames, 64, 44) float32
subject_id: Subject identifier
view_name: View identifier (e.g., "000_180")
window_idx: Window index within sequence
start_frame: Starting frame index in sequence
end_frame: Ending frame index in sequence
n_frames_total: Total frames in the sequence
tile_height: Height of each tile in the grid
tile_width: Width of each tile in the grid
Returns:
Combined image with grid visualization and metadata footer
"""
n_frames = int(silhouettes.shape[0])
tiles_per_row = int(np.ceil(np.sqrt(n_frames)))
rows = int(np.ceil(n_frames / tiles_per_row))
# Create grid
grid = np.zeros((rows * tile_height, tiles_per_row * tile_width), dtype=np.uint8)
# Place each silhouette in the grid
for idx in range(n_frames):
sil = silhouettes[idx]
tile = upscale_silhouette(sil)
r = idx // tiles_per_row
c = idx % tiles_per_row
y0, y1 = r * tile_height, (r + 1) * tile_height
x0, x1 = c * tile_width, (c + 1) * tile_width
grid[y0:y1, x0:x1] = tile
# Convert to BGR
grid_bgr = cv2.cvtColor(grid, cv2.COLOR_GRAY2BGR)
# Add frame indices as text (both window-relative and sequence-relative)
for idx in range(n_frames):
r = idx // tiles_per_row
c = idx % tiles_per_row
y0 = r * tile_height
x0 = c * tile_width
# Window frame count (top-left)
cv2.putText(
grid_bgr,
f"{idx}", # Window-relative frame number
(x0 + 8, y0 + 22),
FONT,
FONT_SCALE,
ACCENT_COLOR,
FONT_THICKNESS,
cv2.LINE_AA,
)
# Sequence frame count (bottom-left of tile)
seq_frame = start_frame + idx
cv2.putText(
grid_bgr,
f"#{seq_frame}", # Sequence-relative frame number
(x0 + 8, y0 + tile_height - 10),
FONT,
0.45, # Slightly smaller font
(180, 180, 180), # Light gray
1,
cv2.LINE_AA,
)
# Create footer with metadata
grid_width = grid_bgr.shape[1]
footer = np.full((FOOTER_HEIGHT, grid_width, 3), FOOTER_BG_COLOR, dtype=np.uint8)
# Line 1: Subject ID and view
line1 = f"Subject: {subject_id} | View: {view_name}"
cv2.putText(
footer,
line1,
(15, 25),
FONT,
0.7,
TEXT_COLOR,
FONT_THICKNESS,
cv2.LINE_AA,
)
# Line 2: Window batch frame range
line2 = f"Window {window_idx}: frames [{start_frame:03d} - {end_frame - 1:03d}] ({n_frames} frames)"
cv2.putText(
footer,
line2,
(15, 50),
FONT,
0.7,
ACCENT_COLOR,
FONT_THICKNESS,
cv2.LINE_AA,
)
# Line 3: Progress within sequence
progress_pct = (end_frame / n_frames_total) * 100
line3 = f"Sequence: {n_frames_total} frames total | Progress: {progress_pct:.1f}%"
cv2.putText(
footer,
line3,
(15, 72),
FONT,
0.6,
(200, 200, 200),
1,
cv2.LINE_AA,
)
# Combine grid and footer
combined = np.vstack([grid_bgr, footer])
return combined
def load_pkl_sequence(pkl_path: Path) -> NDArray[np.float32]:
    """Load a .pkl file containing a silhouette sequence."""
    with open(pkl_path, "rb") as f:
        data = pickle.load(f)
    # Handle different possible structures
    if isinstance(data, np.ndarray):
        return data.astype(np.float32)
    elif isinstance(data, list):
        # List of frames
        return np.stack([np.array(frame) for frame in data]).astype(np.float32)
    else:
        raise ValueError(f"Unexpected data type in {pkl_path}: {type(data)}")
def create_windows(
    sequence: NDArray[np.float32],
    window_size: int = 30,
    stride: int = 30,
) -> list[NDArray[np.float32]]:
    """
    Split a sequence into sliding windows.

    Args:
        sequence: Array of shape (N, 64, 44)
        window_size: Number of frames per window
        stride: Stride between consecutive windows

    Returns:
        List of window arrays, each of shape (window_size, 64, 44)
    """
    n_frames = sequence.shape[0]
    windows = []
    for start_idx in range(0, n_frames - window_size + 1, stride):
        end_idx = start_idx + window_size
        window = sequence[start_idx:end_idx]
        windows.append(window)
    return windows
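The loop above yields `(n_frames - window_size) // stride + 1` windows whenever the sequence is long enough; trailing frames that do not fill a whole window are dropped. A closed-form count that follows directly from the `range` bounds (the helper name is illustrative):

```python
def n_windows(n_frames: int, window: int, stride: int) -> int:
    """Window count produced by create_windows(); partial trailing windows are dropped."""
    if n_frames < window:
        return 0
    # range(0, n_frames - window + 1, stride) has this many elements
    return (n_frames - window) // stride + 1
```

For example, a 95-frame sequence with the default window=30, stride=30 yields three windows covering frames [0, 30), [30, 60), [60, 90); the last 5 frames are discarded.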
def export_positive_batches(
dataset_root: Path,
output_dir: Path,
window_size: int = 30,
stride: int = 30,
max_sequences: int | None = None,
) -> None:
"""
Export all positive labeled batches from Scoliosis1K dataset as time windows.
Args:
dataset_root: Path to Scoliosis1K-sil-pkl directory
output_dir: Output directory for visualizations
window_size: Number of frames per window (default 30)
stride: Stride between consecutive windows (default 30 = non-overlapping)
max_sequences: Maximum number of sequences to process (None = all)
"""
output_dir.mkdir(parents=True, exist_ok=True)
# Find all positive samples
positive_samples: list[
tuple[Path, str, str, str]
] = [] # (pkl_path, subject_id, view_name, pkl_name)
for subject_dir in sorted(dataset_root.iterdir()):
if not subject_dir.is_dir():
continue
subject_id = subject_dir.name
# Check for positive class directory (lowercase)
positive_dir = subject_dir / "positive"
if not positive_dir.exists():
continue
# Iterate through views
for view_dir in sorted(positive_dir.iterdir()):
if not view_dir.is_dir():
continue
view_name = view_dir.name
# Find .pkl files
for pkl_file in sorted(view_dir.glob("*.pkl")):
positive_samples.append(
(pkl_file, subject_id, view_name, pkl_file.stem)
)
print(f"Found {len(positive_samples)} positive labeled sequences")
if max_sequences:
positive_samples = positive_samples[:max_sequences]
print(f"Processing first {max_sequences} sequences")
total_windows = 0
# Export each sequence's windows
for seq_idx, (pkl_path, subject_id, view_name, pkl_name) in enumerate(
positive_samples, 1
):
print(
f"[{seq_idx}/{len(positive_samples)}] Processing {subject_id}/{view_name}/{pkl_name}..."
)
# Load sequence
try:
sequence = load_pkl_sequence(pkl_path)
except Exception as e:
print(f" Error loading {pkl_path}: {e}")
continue
# Ensure correct shape (N, 64, 44)
if len(sequence.shape) == 2:
# Single frame
sequence = sequence[np.newaxis, ...]
elif len(sequence.shape) == 3:
# (N, H, W) - expected
pass
else:
print(f" Unexpected shape {sequence.shape}, skipping")
continue
n_frames = sequence.shape[0]
print(f" Sequence has {n_frames} frames")
# Skip if sequence is shorter than window size
if n_frames < window_size:
print(f" Skipping: sequence too short (< {window_size} frames)")
continue
# Normalize if needed
if sequence.max() > 1.0:
sequence = sequence / 255.0
# Create windows
windows = create_windows(sequence, window_size=window_size, stride=stride)
print(f" Created {len(windows)} windows (size={window_size}, stride={stride})")
# Export each window
for window_idx, window in enumerate(windows):
start_frame = window_idx * stride
end_frame = start_frame + window_size
# Create visualization for this window with full metadata
vis_image = create_optimized_visualization(
silhouettes=window,
subject_id=subject_id,
view_name=view_name,
window_idx=window_idx,
start_frame=start_frame,
end_frame=end_frame,
n_frames_total=n_frames,
)
# Save with descriptive filename including window index
output_filename = (
f"{subject_id}_{view_name}_{pkl_name}_win{window_idx:03d}.png"
)
output_path = output_dir / output_filename
cv2.imwrite(str(output_path), vis_image)
# Save metadata for this window
meta = {
"subject_id": subject_id,
"view": view_name,
"pkl_name": pkl_name,
"window_index": window_idx,
"window_size": window_size,
"stride": stride,
"start_frame": start_frame,
"end_frame": end_frame,
"sequence_shape": sequence.shape,
"n_frames_total": n_frames,
"source_path": str(pkl_path),
}
meta_filename = (
f"{subject_id}_{view_name}_{pkl_name}_win{window_idx:03d}.json"
)
meta_path = output_dir / meta_filename
with open(meta_path, "w") as f:
json.dump(meta, f, indent=2)
total_windows += 1
print(f" Exported {len(windows)} windows")
print(f"\nExport complete! Saved {total_windows} windows to {output_dir}")
def main() -> None:
"""Main entry point."""
# Paths
dataset_root = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl")
output_dir = Path("/home/crosstyan/Code/OpenGait/output/positive_batches")
if not dataset_root.exists():
print(f"Error: Dataset not found at {dataset_root}")
return
# Export all positive batches with windowing
export_positive_batches(
dataset_root,
output_dir,
window_size=30, # 30 frames per window
stride=30, # Non-overlapping windows
)
if __name__ == "__main__":
main()
+66 -10
@@ -7,7 +7,7 @@ from pathlib import Path
import subprocess
import sys
import time
-from typing import Final, cast
+from typing import Final, Literal, cast
from unittest import mock

import numpy as np
@@ -693,9 +693,11 @@ class MockVisualizer:
        self,
        frame: NDArray[np.uint8],
        bbox: tuple[int, int, int, int] | None,
+       bbox_mask: tuple[int, int, int, int] | None,
        track_id: int,
        mask_raw: NDArray[np.uint8] | None,
        silhouette: NDArray[np.float32] | None,
+       segmentation_input: NDArray[np.float32] | None,
        label: str | None,
        confidence: float | None,
        fps: float,
@@ -704,9 +706,11 @@ class MockVisualizer:
            {
                "frame": frame,
                "bbox": bbox,
+               "bbox_mask": bbox_mask,
                "track_id": track_id,
                "mask_raw": mask_raw,
                "silhouette": silhouette,
+               "segmentation_input": segmentation_input,
                "label": label,
                "confidence": confidence,
                "fps": fps,
@@ -761,9 +765,8 @@ def test_pipeline_visualizer_updates_on_no_detection() -> None:
        visualize=True,
    )

-    # Replace the visualizer with our mock
    mock_viz = MockVisualizer()
-    pipeline._visualizer = mock_viz  # type: ignore[assignment]
+    setattr(pipeline, "_visualizer", mock_viz)

    # Run pipeline
    _ = pipeline.run()
@@ -779,13 +782,14 @@ def test_pipeline_visualizer_updates_on_no_detection() -> None:
    for call in mock_viz.update_calls:
        assert call["track_id"] == 0  # Default track_id when no detection
        assert call["bbox"] is None  # No bbox when no detection
+       assert call["bbox_mask"] is None
        assert call["mask_raw"] is None  # No mask when no detection
        assert call["silhouette"] is None  # No silhouette when no detection
+       assert call["segmentation_input"] is None
        assert call["label"] is None  # No label when no detection
        assert call["confidence"] is None  # No confidence when no detection
def test_pipeline_visualizer_uses_cached_detection_on_no_detection() -> None:
    """Test that visualizer reuses last valid detection when current frame has no detection.
@@ -818,8 +822,8 @@ def test_pipeline_visualizer_uses_cached_detection_on_no_detection() -> None:
    mock_detector.track.side_effect = [
        [mock_result],  # Frame 0: valid detection
        [mock_result],  # Frame 1: valid detection
        [],  # Frame 2: no detection
        [],  # Frame 3: no detection
    ]
    mock_yolo.return_value = mock_detector
@@ -835,7 +839,12 @@ def test_pipeline_visualizer_uses_cached_detection_on_no_detection() -> None:
    dummy_mask = np.random.randint(0, 256, (480, 640), dtype=np.uint8)
    dummy_bbox_mask = (100, 100, 200, 300)
    dummy_bbox_frame = (100, 100, 200, 300)
    mock_select_person.return_value = (
        dummy_mask,
        dummy_bbox_mask,
        dummy_bbox_frame,
        1,
    )
    # Setup mock mask_to_silhouette to return valid silhouette
    dummy_silhouette = np.random.rand(64, 44).astype(np.float32)
@@ -856,9 +865,8 @@ def test_pipeline_visualizer_uses_cached_detection_on_no_detection() -> None:
        visualize=True,
    )

    mock_viz = MockVisualizer()
    setattr(pipeline, "_visualizer", mock_viz)

    # Run pipeline
    _ = pipeline.run()
@@ -886,9 +894,57 @@ def test_pipeline_visualizer_uses_cached_detection_on_no_detection() -> None:
        "not None/blank"
    )

    segmentation_inputs = [
        call["segmentation_input"] for call in mock_viz.update_calls
    ]
    bbox_mask_calls = [call["bbox_mask"] for call in mock_viz.update_calls]
    assert segmentation_inputs[0] is not None
    assert segmentation_inputs[1] is not None
    assert segmentation_inputs[2] is not None
    assert segmentation_inputs[3] is not None
    assert bbox_mask_calls[0] == dummy_bbox_mask
    assert bbox_mask_calls[1] == dummy_bbox_mask
    assert bbox_mask_calls[2] == dummy_bbox_mask
    assert bbox_mask_calls[3] == dummy_bbox_mask

    # The cached masks should be copies (different objects) to prevent mutation issues
    if mask_raw_calls[1] is not None and mask_raw_calls[2] is not None:
        assert mask_raw_calls[1] is not mask_raw_calls[2], (
            "Cached mask should be a copy, not the same object reference"
        )
def test_frame_pacer_emission_count_24_to_15() -> None:
    from opengait.demo.pipeline import _FramePacer

    pacer = _FramePacer(15.0)
    interval_ns = int(1_000_000_000 / 24)
    emitted = sum(pacer.should_emit(i * interval_ns) for i in range(100))
    assert 60 <= emitted <= 65


def test_frame_pacer_requires_positive_target_fps() -> None:
    from opengait.demo.pipeline import _FramePacer

    with pytest.raises(ValueError, match="target_fps must be positive"):
        _FramePacer(0.0)


@pytest.mark.parametrize(
    ("window", "stride", "mode", "expected"),
    [
        (30, 30, "manual", 30),
        (30, 7, "manual", 7),
        (30, 30, "sliding", 1),
        (30, 1, "chunked", 30),
        (15, 3, "chunked", 15),
    ],
)
def test_resolve_stride_modes(
    window: int,
    stride: int,
    mode: Literal["manual", "sliding", "chunked"],
    expected: int,
) -> None:
    from opengait.demo.pipeline import resolve_stride

    assert resolve_stride(window, stride, mode) == expected
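The tests above pin down behavior without showing the implementations. A sketch consistent with those assertions, under the assumption that "manual" keeps the caller's stride, "sliding" advances one frame per inference, and "chunked" jumps a full window (the `FramePacer` name and grid-based scheduling here are illustrative, not the project's actual `_FramePacer` code):

```python
from typing import Literal


def resolve_stride(
    window: int,
    stride: int,
    mode: Literal["manual", "sliding", "chunked"],
) -> int:
    # "manual" trusts the caller's stride, "sliding" re-infers every frame,
    # "chunked" waits for a whole fresh window between inferences.
    if mode == "manual":
        return stride
    if mode == "sliding":
        return 1
    return window


class FramePacer:
    """Downsample a timestamped frame stream to a fixed target rate."""

    def __init__(self, target_fps: float) -> None:
        if target_fps <= 0:
            raise ValueError("target_fps must be positive")
        self._interval_ns = 1_000_000_000 / target_fps
        self._next_emit_ns: float | None = None

    def should_emit(self, timestamp_ns: int) -> bool:
        if self._next_emit_ns is None or timestamp_ns >= self._next_emit_ns:
            # Schedule the next slot on the ideal grid so rounding error
            # does not accumulate across a long stream.
            base = timestamp_ns if self._next_emit_ns is None else self._next_emit_ns
            self._next_emit_ns = base + self._interval_ns
            return True
        return False
```

With a 24 fps source and a 15 fps target, 100 input frames span about 4.12 s, so a correct pacer emits roughly 62 frames, which is why the test accepts the 60–65 band rather than an exact count.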
+171
View File
@@ -0,0 +1,171 @@
from __future__ import annotations

from pathlib import Path
from typing import cast
from unittest import mock

import numpy as np
import pytest

from opengait.demo.input import create_source
from opengait.demo.visualizer import (
    DISPLAY_HEIGHT,
    DISPLAY_WIDTH,
    ImageArray,
    OpenCVVisualizer,
)
from opengait.demo.window import select_person

REPO_ROOT = Path(__file__).resolve().parents[2]
SAMPLE_VIDEO_PATH = REPO_ROOT / "assets" / "sample.mp4"
YOLO_MODEL_PATH = REPO_ROOT / "ckpt" / "yolo11n-seg.pt"
def test_prepare_raw_view_float_mask_has_visible_signal() -> None:
    viz = OpenCVVisualizer()
    mask_float = np.zeros((64, 64), dtype=np.float32)
    mask_float[16:48, 16:48] = 1.0

    rendered = viz._prepare_raw_view(cast(ImageArray, mask_float))
    assert rendered.dtype == np.uint8
    assert rendered.shape == (256, 176, 3)

    mask_zero = np.zeros((64, 64), dtype=np.float32)
    rendered_zero = viz._prepare_raw_view(cast(ImageArray, mask_zero))
    roi = slice(0, DISPLAY_HEIGHT - 40)
    diff = np.abs(rendered[roi].astype(np.int16) - rendered_zero[roi].astype(np.int16))
    assert int(np.count_nonzero(diff)) > 0
def test_prepare_raw_view_handles_values_slightly_above_one() -> None:
    viz = OpenCVVisualizer()
    mask = np.zeros((64, 64), dtype=np.float32)
    mask[20:40, 20:40] = 1.0001
    rendered = viz._prepare_raw_view(cast(ImageArray, mask))
    roi = rendered[: DISPLAY_HEIGHT - 40, :, 0]
    assert int(np.count_nonzero(roi)) > 0
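These two tests pin down a clip-then-scale contract for float masks: values in [0, 1] map to visible pixels, and slight overshoot like 1.0001 must not disappear or wrap. A plausible normalization step satisfying that contract (an assumption about what `_prepare_raw_view` does internally, not its actual code):

```python
import numpy as np


def float_mask_to_uint8(mask: np.ndarray) -> np.ndarray:
    # Clip before scaling so values slightly above 1.0 (float segmentation
    # outputs often overshoot) saturate at 255 instead of wrapping when
    # cast down to uint8.
    clipped = np.clip(mask.astype(np.float32), 0.0, 1.0)
    return (clipped * 255.0).astype(np.uint8)


m = np.zeros((64, 64), dtype=np.float32)
m[20:40, 20:40] = 1.0001
out = float_mask_to_uint8(m)
assert out.dtype == np.uint8
assert out.max() == 255  # overshoot saturates rather than wrapping to 0
```

Clipping first is the design point: a bare `(mask * 255).astype(np.uint8)` wraps modulo 256 for larger overshoots, which would render bright regions as near-black.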
def test_segmentation_view_is_normalized_only_shape() -> None:
    viz = OpenCVVisualizer()
    mask = np.zeros((480, 640), dtype=np.uint8)
    sil = np.random.rand(64, 44).astype(np.float32)
    seg = viz._prepare_segmentation_view(cast(ImageArray, mask), sil, (0, 0, 100, 100))
    assert seg.shape == (DISPLAY_HEIGHT, DISPLAY_WIDTH, 3)
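The shape assertions here and in the raw-view test imply a fixed display canvas of 256×176 regardless of the input mask size. A minimal sketch of resizing onto such a canvas with nearest-neighbor indexing (the constants are inferred from the asserted shapes, and this is illustrative only; the real visualizer presumably uses an OpenCV resize):

```python
import numpy as np

DISPLAY_HEIGHT, DISPLAY_WIDTH = 256, 176  # values implied by the asserts above


def to_display_canvas(mask: np.ndarray) -> np.ndarray:
    """Nearest-neighbor resize of a 2-D mask onto the fixed canvas, as 3-channel uint8."""
    h, w = mask.shape
    # Map each output row/column back to its nearest source index.
    rows = (np.arange(DISPLAY_HEIGHT) * h // DISPLAY_HEIGHT).clip(0, h - 1)
    cols = (np.arange(DISPLAY_WIDTH) * w // DISPLAY_WIDTH).clip(0, w - 1)
    resized = mask[rows[:, None], cols[None, :]]
    # Replicate to 3 channels so drawing/overlay code can treat it as BGR.
    return np.repeat(resized[:, :, None], 3, axis=2).astype(np.uint8)


out = to_display_canvas(np.zeros((480, 640), dtype=np.uint8))
assert out.shape == (256, 176, 3)
```

Fixing the canvas size up front keeps every auxiliary window a constant shape, which is what lets the tests assert exact shapes without knowing the source resolution.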
def test_update_toggles_raw_window_with_r_key() -> None:
    viz = OpenCVVisualizer()
    frame = np.zeros((240, 320, 3), dtype=np.uint8)
    mask = np.zeros((240, 320), dtype=np.uint8)
    mask[20:100, 30:120] = 255
    sil = np.random.rand(64, 44).astype(np.float32)
    seg_input = np.random.rand(4, 64, 44).astype(np.float32)

    with (
        mock.patch("cv2.namedWindow") as named_window,
        mock.patch("cv2.imshow"),
        mock.patch("cv2.destroyWindow") as destroy_window,
        mock.patch("cv2.waitKey", side_effect=[ord("r"), ord("r"), ord("q")]),
    ):
        assert viz.update(
            frame,
            (10, 10, 120, 150),
            (10, 10, 120, 150),
            1,
            cast(ImageArray, mask),
            sil,
            seg_input,
            None,
            None,
            15.0,
        )
        assert viz.show_raw_window is True
        assert viz._raw_window_created is True

        assert viz.update(
            frame,
            (10, 10, 120, 150),
            (10, 10, 120, 150),
            1,
            cast(ImageArray, mask),
            sil,
            seg_input,
            None,
            None,
            15.0,
        )
        assert viz.show_raw_window is False
        assert viz._raw_window_created is False
        assert destroy_window.called

        assert (
            viz.update(
                frame,
                (10, 10, 120, 150),
                (10, 10, 120, 150),
                1,
                cast(ImageArray, mask),
                sil,
                seg_input,
                None,
                None,
                15.0,
            )
            is False
        )
    assert named_window.called
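The mocked `cv2.waitKey` sequence above drives a small key-handling state machine: `r` toggles the raw window (creating or destroying it), and `q` makes `update` return `False` to stop the loop. That logic can be sketched without OpenCV (hypothetical `RawWindowToggle`, mirroring the behavior the test exercises rather than the visualizer's actual code):

```python
class RawWindowToggle:
    """Track 'r'-key toggling and 'q'-quit, as exercised by the test above."""

    def __init__(self) -> None:
        self.show_raw_window = False
        self.window_created = False

    def handle_key(self, key: int) -> bool:
        # Returns False when the user asked to quit, True to keep running.
        if key == ord("q"):
            return False
        if key == ord("r"):
            self.show_raw_window = not self.show_raw_window
            if self.show_raw_window:
                self.window_created = True  # the real code would call cv2.namedWindow here
            else:
                self.window_created = False  # ... and cv2.destroyWindow here
        return True


t = RawWindowToggle()
assert t.handle_key(ord("r")) and t.show_raw_window
assert t.handle_key(ord("r")) and not t.show_raw_window
assert t.handle_key(ord("q")) is False
```

Keeping the toggle state separate from the drawing calls is what makes the behavior testable with `cv2` fully mocked out, as the test does.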
def test_sample_video_raw_mask_shape_range_and_render_signal() -> None:
    if not SAMPLE_VIDEO_PATH.is_file():
        pytest.skip(f"Missing sample video: {SAMPLE_VIDEO_PATH}")
    if not YOLO_MODEL_PATH.is_file():
        pytest.skip(f"Missing YOLO model file: {YOLO_MODEL_PATH}")
    ultralytics = pytest.importorskip("ultralytics")
    yolo_cls = getattr(ultralytics, "YOLO")

    viz = OpenCVVisualizer()
    detector = yolo_cls(str(YOLO_MODEL_PATH))

    masks_seen = 0
    rendered_nonzero: list[int] = []
    for frame, _meta in create_source(str(SAMPLE_VIDEO_PATH), max_frames=30):
        detections = detector.track(
            frame,
            persist=True,
            verbose=False,
            classes=[0],
            device="cpu",
        )
        if not isinstance(detections, list) or not detections:
            continue
        selected = select_person(detections[0])
        if selected is None:
            continue
        mask_raw, _, _, _ = selected
        masks_seen += 1

        arr = np.asarray(mask_raw)
        assert arr.ndim == 2
        assert arr.shape[0] > 0 and arr.shape[1] > 0
        assert np.issubdtype(arr.dtype, np.number)
        assert float(arr.min()) >= 0.0
        assert float(arr.max()) <= 255.0
        assert int(np.count_nonzero(arr)) > 0

        rendered = viz._prepare_raw_view(arr)
        roi = rendered[: DISPLAY_HEIGHT - 40, :, 0]
        rendered_nonzero.append(int(np.count_nonzero(roi)))

    assert masks_seen > 0
    assert min(rendered_nonzero) > 0