chore: update demo runtime, tests, and agent docs
@@ -147,4 +147,5 @@ dmypy.json
 cython_debug/
 ckpt/
+output/
 assets/*
@@ -1,84 +1,191 @@
# OpenGait Agent Guide

This file is for autonomous coding agents working in this repository.
Use it as the default playbook for commands, conventions, and safety checks.

## Scope and Ground Truth

- Repository: `OpenGait`
- Runtime package: `opengait/`
- Primary entrypoint: `opengait/main.py`
- Package/runtime tool: `uv`

Critical source-of-truth rule:

- `opengait/demo` is an implementation layer and may contain project-specific behavior.
- When asked to “refer to the paper” or verify methodology, use the paper and official citations as ground truth.
- Do not treat demo/runtime behavior as proof of paper method unless explicitly cited by the paper.
## Environment Setup

Install dependencies with uv:

```bash
uv sync --extra torch
```

Notes from `pyproject.toml`:

- Python requirement: `>=3.10`
- Dev tooling includes `pytest` and `basedpyright`
- Optional extras include `torch` and `parquet`
## Build / Run Commands

Train (DDP):

```bash
CUDA_VISIBLE_DEVICES=0,1 uv run python -m torch.distributed.launch \
    --nproc_per_node=2 opengait/main.py \
    --cfgs ./configs/baseline/baseline.yaml --phase train
```

Test (DDP):

```bash
CUDA_VISIBLE_DEVICES=0,1 uv run python -m torch.distributed.launch \
    --nproc_per_node=2 opengait/main.py \
    --cfgs ./configs/baseline/baseline.yaml --phase test
```

Single-GPU eval example:

```bash
CUDA_VISIBLE_DEVICES=0 uv run python -m torch.distributed.launch \
    --nproc_per_node=1 opengait/main.py \
    --cfgs ./configs/sconet/sconet_scoliosis1k_local_eval_1gpu.yaml --phase test
```

Demo CLI entry:

```bash
uv run python -m opengait.demo --help
```

## DDP Constraints (Important)

- `--nproc_per_node` must match the number of visible GPUs in `CUDA_VISIBLE_DEVICES`.
- Test/evaluator sampler settings are strict and will fail if the world size does not match the config.
- If an interrupted DDP run leaves stale processes:

```bash
sh misc/clean_process.sh
```
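To keep the first constraint from biting, the process count can be derived from `CUDA_VISIBLE_DEVICES` instead of hand-typed. This is a convenience sketch, not a repository script:

```shell
# Derive the world size from CUDA_VISIBLE_DEVICES so --nproc_per_node always matches.
CUDA_VISIBLE_DEVICES=0,1
NPROC=$(echo "$CUDA_VISIBLE_DEVICES" | tr ',' '\n' | wc -l)
echo "launch with --nproc_per_node=$NPROC"
```

Pass `$NPROC` to `--nproc_per_node` in the launch commands above.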
## Test Commands (especially single test)

Run all tests:

```bash
uv run pytest tests
```

Run one file:

```bash
uv run pytest tests/demo/test_pipeline.py -v
```

Run one test function:

```bash
uv run pytest tests/demo/test_pipeline.py::test_resolve_stride_modes -v
```

Run by keyword:

```bash
uv run pytest tests/demo/test_window.py -k "stride" -v
```

## Lint / Typecheck

Typecheck with basedpyright:

```bash
uv run basedpyright opengait tests
```

The project currently has no enforced formatter config in its root tool files.
Follow existing local formatting and keep edits minimal.
## High-Value Paths

- `opengait/main.py` — runtime bootstrap
- `opengait/modeling/base_model.py` — model lifecycle contract
- `opengait/modeling/models/` — model zoo implementations
- `opengait/data/dataset.py` — dataset loading rules
- `opengait/data/collate_fn.py` — frame sampling behavior
- `opengait/evaluation/evaluator.py` — evaluation dispatch
- `configs/` — experiment definitions
- `datasets/` — preprocessing and partitions

## Code Style Guidelines

### Imports

- Keep ordering consistent: stdlib, third-party, local.
- Prefer explicit imports; avoid wildcard imports.
- Avoid introducing heavy imports in hot paths unless needed.

### Formatting

- Match surrounding file style (spacing, wrapping, structure).
- Avoid unrelated formatting churn.
- Keep diffs surgical.

### Types

- Add type annotations for new public APIs and non-trivial helpers.
- Reuse the established typing style: `typing`, `numpy.typing`, `jaxtyping` where already used.
- Do not suppress type safety with blanket casts; keep unavoidable casts narrow.

### Naming

- `snake_case` for functions/variables
- `PascalCase` for classes
- `UPPER_SNAKE_CASE` for constants
- Preserve existing config key names and schema conventions

### Error Handling

- Raise explicit, actionable errors on invalid inputs.
- Fail fast on missing files, bad args, invalid shapes, and runtime preconditions.
- Never swallow exceptions silently.
- Preserve CLI error semantics (clear messages, non-zero exits).

### Logging

- Use the module-level logger pattern already in the codebase.
- Keep logs concise and operational.
- Avoid excessive per-frame logging in realtime/demo loops.
## Model and Config Contracts

- New models should conform to `BaseModel` expectations.
- Respect the forward-output dictionary contract used by the loss/evaluator pipeline.
- Keep model registration/discovery patterns consistent with the current package layout.
- Respect sampler semantics from config (`fixed_unordered`, `all_ordered`, etc.).
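The forward-output contract above groups training features under the log prefixes configured in `loss_cfg`. A hedged sketch of that dictionary's shape (key names follow the common OpenGait pattern; the values are placeholders, not real tensors):

```python
# Illustrative sketch of the forward-output dictionary consumed by the
# loss/evaluator pipeline. Keys under "training_feat" are the log_prefix
# values from loss_cfg (e.g. triplet, softmax); real models return tensors.
def forward_output(embeddings, logits, labels):
    return {
        "training_feat": {
            "triplet": {"embeddings": embeddings, "labels": labels},
            "softmax": {"logits": logits, "labels": labels},
        },
        "visual_summary": {"image/sils": embeddings},
        "inference_feat": {"embeddings": embeddings},
    }
```

The loss aggregator dispatches each `training_feat` entry to the loss whose `log_prefix` matches its key.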
## Data Contracts

- Runtime data loaders expect preprocessed `.pkl` sequence files.
- Partition JSON files are required for train/test split behavior.
- Do not mix modalities accidentally (silhouette / pose / pointcloud) across pipelines.
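A minimal sketch of producing one such `.pkl` sequence file (the `[T, H, W]` layout and the file name here are illustrative assumptions; check `datasets/pretreatment.py` for the authoritative format):

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np

# One sequence stored as a single stacked silhouette array: [T, H, W].
sequence = np.zeros((30, 64, 44), dtype=np.uint8)

out_dir = Path(tempfile.mkdtemp())
pkl_path = out_dir / "00.pkl"  # hypothetical output name
with pkl_path.open("wb") as f:
    pickle.dump(sequence, f)

with pkl_path.open("rb") as f:
    loaded = pickle.load(f)
```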
## Research-Verification Policy

When answering methodology questions:

- Prefer primary sources (paper PDF, official project docs, official code tied to the publication).
- Quote/cite paper statements when concluding method behavior.
- If the local implementation differs from the paper, state the divergence explicitly.
- For this repo specifically, remember: `opengait/demo` may differ from paper intent.

## Cursor / Copilot Rules Check

Checked these paths:

- `.cursor/rules/`
- `.cursorrules`
- `.github/copilot-instructions.md`

Current status: no Cursor/Copilot instruction files found.

## Agent Checklist Before Finishing

- Commands executed with `uv run ...` where applicable
- Targeted tests for changed files pass
- Typecheck is clean for modified code
- Behavior and documentation updated together for user-facing changes
- Paper-vs-implementation claims clearly separated when relevant
@@ -0,0 +1,101 @@
data_cfg:
  dataset_name: Scoliosis1K
  dataset_root: /mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl
  dataset_partition: /mnt/public/data/Scoliosis1K/Scoliosis1K_1116.json
  num_workers: 1
  remove_no_gallery: false
  test_dataset_name: Scoliosis1K

evaluator_cfg:
  enable_float16: true
  restore_ckpt_strict: true
  restore_hint: ./ckpt/ScoNet-20000.pt
  save_name: ScoNet
  eval_func: evaluate_scoliosis
  sampler:
    batch_shuffle: false
    batch_size: 1
    sample_type: all_ordered
    frames_all_limit: 720
  metric: euc
  transform:
    - type: BaseSilCuttingTransform

loss_cfg:
  - loss_term_weight: 1.0
    margin: 0.2
    type: TripletLoss
    log_prefix: triplet
  - loss_term_weight: 1.0
    scale: 16
    type: CrossEntropyLoss
    log_prefix: softmax
    log_accuracy: true

model_cfg:
  model: ScoNet
  backbone_cfg:
    type: ResNet9
    block: BasicBlock
    channels:
      - 64
      - 128
      - 256
      - 512
    layers:
      - 1
      - 1
      - 1
      - 1
    strides:
      - 1
      - 2
      - 2
      - 1
    maxpool: false
  SeparateFCs:
    in_channels: 512
    out_channels: 256
    parts_num: 16
  SeparateBNNecks:
    class_num: 3
    in_channels: 256
    parts_num: 16
  bin_num:
    - 16

optimizer_cfg:
  lr: 0.1
  momentum: 0.9
  solver: SGD
  weight_decay: 0.0005

scheduler_cfg:
  gamma: 0.1
  milestones:
    - 10000
    - 14000
    - 18000
  scheduler: MultiStepLR

trainer_cfg:
  enable_float16: true
  fix_BN: false
  with_test: false
  log_iter: 100
  restore_ckpt_strict: true
  restore_hint: 0
  save_iter: 20000
  save_name: ScoNet
  sync_BN: true
  total_iter: 20000
  sampler:
    batch_shuffle: true
    batch_size:
      - 8
      - 8
    frames_num_fixed: 30
    sample_type: fixed_unordered
    type: TripletSampler
  transform:
    - type: BaseSilCuttingTransform
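Experiment configs like the one above are overlaid onto `configs/default.yaml`, with local keys taking precedence. A minimal sketch of that recursive-overlay semantics (an illustration of the behavior, not OpenGait's actual `config_loader`):

```python
def overlay(default: dict, custom: dict) -> dict:
    """Recursively merge a custom config over defaults; custom keys win."""
    merged = dict(default)
    for key, value in custom.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = overlay(merged[key], value)  # descend into nested sections
        else:
            merged[key] = value  # scalar or list: custom replaces default
    return merged


default_cfg = {"data_cfg": {"num_workers": 4, "cache": False}}
custom_cfg = {"data_cfg": {"num_workers": 1}}
print(overlay(default_cfg, custom_cfg))
# {'data_cfg': {'num_workers': 1, 'cache': False}}
```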
@@ -0,0 +1,144 @@
# Demo Window, Stride, and Sequence Behavior (ScoNet)

This note explains how the `opengait/demo` runtime feeds silhouettes into the neural network, what `stride` means, and when the sliding window is reset.

## Why sequence input (not single frame)

ScoNet-style inference is sequence-based.

- ScoNet / ScoNet-MT paper: https://arxiv.org/html/2407.05726v3
- DRF follow-up paper: https://arxiv.org/html/2509.00872v1

Both works use temporal information across walking frames rather than a single independent image.

### Direct quotes from the papers

From ScoNet / ScoNet-MT (MICCAI 2024, `2407.05726v3`):

> "For experiments, **30 frames were selected from each gait sequence as input**."
> (Section 4.1, Implementation Details)

From the same paper's dataset description:

> "Each sequence, containing approximately **300 frames at 15 frames per second**..."
> (Section 2.2, Data Collection and Preprocessing)

From DRF (MICCAI 2025, `2509.00872v1`):

DRF follows ScoNet-MT's sequence-level setup and architecture in its implementation details, and its PAV branch also aggregates across frames:

> "Sequence-Level PAV Refinement ... (2) **Temporal Aggregation**: For each metric, the mean of valid measurements across **all frames** is computed..."
> (Section 3.1, PAV: Discrete Clinical Prior)

## What papers say (and do not say) about stride

The papers define sequence-based inputs and temporal aggregation, but they do **not** define a deployment/runtime `stride` knob for online inference windows.

In other words:

- The paper gives the sequence framing (e.g., 30-frame inputs in ScoNet experiments).
- The demo `stride` is an engineering control for how often to run inference in streaming mode.
## What the demo feeds into the network

In `opengait/demo`, each inference uses the current silhouette buffer from `SilhouetteWindow`:

- Per-frame silhouette shape: `64 x 44`
- Tensor shape for inference: `[1, 1, window_size, 64, 44]`
- Default `window_size`: `30`

So by default, one prediction uses **30 silhouettes**.
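The shape bookkeeping above can be sketched in a few lines (variable names here are illustrative, not the demo's actual internals):

```python
import numpy as np

window_size = 30  # default demo window

# A full buffer of 64x44 silhouettes, as produced by segmentation.
buffer = [np.zeros((64, 44), dtype=np.float32) for _ in range(window_size)]

# Stack along the temporal axis, then add batch and channel dims:
# [T, 64, 44] -> [1, 1, T, 64, 44]
sequence = np.stack(buffer, axis=0)
batch = sequence[np.newaxis, np.newaxis, ...]

print(batch.shape)  # (1, 1, 30, 64, 44)
```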
## What is stride?

`stride` is the minimum frame distance between two consecutive classifications **after** the window is already full.

In this demo, the window is a true sliding buffer. It is **not** cleared after each inference. After inference, the pipeline only records the last classified frame and continues buffering new silhouettes.

- If `stride = 1`: classify at every new frame once ready
- If `stride = 30` (default): classify every 30 frames once ready
## Window mode shortcut (`--window-mode`)

To make window scheduling explicit, the demo CLI supports:

- `--window-mode manual` (default): use the exact `--stride` value
- `--window-mode sliding`: force `stride = 1` (max overlap)
- `--window-mode chunked`: force `stride = window` (no overlap)

This is only a shortcut for runtime behavior. It does not change ScoNet weights or architecture.

Examples:

- Sliding windows: `--window 30 --window-mode sliding` -> windows like `0-29, 1-30, 2-31, ...`
- Chunked windows: `--window 30 --window-mode chunked` -> windows like `0-29, 30-59, 60-89, ...`
- Manual stride: `--window 30 --stride 10 --window-mode manual` -> windows every 10 frames
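The mode-to-stride mapping above is implemented by `resolve_stride` in the demo pipeline; a self-contained sketch of that mapping:

```python
from typing import Literal

WindowMode = Literal["manual", "sliding", "chunked"]


def resolve_stride(window: int, stride: int, window_mode: WindowMode) -> int:
    """Map --window-mode to the effective stride used by the scheduler."""
    if window_mode == "manual":
        return stride  # use --stride exactly as given
    if window_mode == "sliding":
        return 1       # maximum overlap
    return window      # chunked: non-overlapping windows
```

For example, `resolve_stride(30, 10, "chunked")` ignores the manual stride and returns `30`.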
The time interval between predictions is approximately:

`prediction_interval_seconds ~= stride / fps`

If `--target-fps` is set, use the emitted (downsampled) fps in this formula.

Examples:

- `stride=30`, `fps=15` -> about `2.0s`
- `stride=15`, `fps=30` -> about `0.5s`

First-prediction latency is approximately:

`first_prediction_latency_seconds ~= window_size / fps`

assuming detections are continuous.
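Both formulas are trivial to compute when budgeting latency (helper names here are illustrative):

```python
def prediction_interval_s(stride: int, fps: float) -> float:
    """Approximate seconds between consecutive predictions once the window is full."""
    return stride / fps


def first_prediction_latency_s(window_size: int, fps: float) -> float:
    """Approximate seconds until the first prediction, assuming continuous detections."""
    return window_size / fps


print(prediction_interval_s(30, 15))        # 2.0
print(prediction_interval_s(15, 30))        # 0.5
print(first_prediction_latency_s(30, 15))   # 2.0
```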
## Does the window clear when the tracking target switches?

Yes. The window is reset in either case:

1. **Track ID changed** (new tracking target)
2. **Frame gap too large** (`frame_idx - last_frame > gap_threshold`)

The default `gap_threshold` in the demo is `15` frames.

This prevents silhouettes from different people, or from long interrupted segments, from being mixed into one inference window.

To be explicit:

- **Inference finished** -> window stays (sliding continues)
- **Track ID changed** -> window reset
- **Frame gap > gap_threshold** -> window reset
## Practical note about real-time detections

The window fills only when a valid silhouette is produced (i.e., person detection/segmentation succeeds). If detections are intermittent, the real-world time covered by one `window_size` can be longer than `window_size / fps`.

## Online vs offline behavior (important)

ScoNet's neural network does not hard-code a fixed frame count in the model graph. In OpenGait, the frame count is controlled by sampling/runtime policy:

- The training config typically uses `frames_num_fixed: 30` with random fixed-frame sampling.
- Offline evaluation often uses `all_ordered` sequences (with `frames_all_limit` as a memory guard).
- The online demo uses the runtime window/stride scheduler.

So this does not mean the method only works offline. It means online performance depends on the latency/robustness trade-off you choose:

- Smaller windows / larger stride -> lower latency, potentially less stable predictions
- Larger windows / more overlap -> smoother predictions, higher compute/latency

If you want behavior closest to ScoNet's training assumptions, start from `--window 30` and tune stride (or `--window-mode`) for your deployment latency budget.
## Temporal downsampling (`--target-fps`)

Use `--target-fps` to normalize the incoming frame cadence before silhouettes are pushed into the classification window.

- Default (`--target-fps 15`): timestamp-based pacing emits frames at approximately 15 FPS into the window
- Optional override (`--no-target-fps`): disable temporal downsampling and use all frames

The current default is `--target-fps 15`, aligning runtime cadence with ScoNet's training assumptions.

For offline video sources, pacing uses video-time timestamps (`CAP_PROP_POS_MSEC`) when available, with an FPS-based synthetic-timestamp fallback. This avoids coupling downsampling to processing throughput.

This is useful when camera FPS differs from the training cadence. For example, with a 24 FPS camera:

- `--target-fps 15 --window 30` keeps the model input near ~2.0 seconds of gait context (close to the paper setup)
- `--stride` is interpreted in emitted-frame units after pacing
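The timestamp-based pacing is implemented by `_FramePacer` in `pipeline.py`; this self-contained sketch mirrors its behavior and simulates a 24 FPS camera paced down to roughly 15 FPS:

```python
class FramePacer:
    """Timestamp-based frame pacing, mirroring the demo's _FramePacer."""

    def __init__(self, target_fps: float) -> None:
        if target_fps <= 0:
            raise ValueError(f"target_fps must be positive, got {target_fps}")
        self._interval_ns = int(1_000_000_000 / target_fps)
        self._next_emit_ns: int | None = None

    def should_emit(self, timestamp_ns: int) -> bool:
        if self._next_emit_ns is None:
            self._next_emit_ns = timestamp_ns + self._interval_ns
            return True  # always emit the first frame
        if timestamp_ns >= self._next_emit_ns:
            # Catch up past any missed slots so cadence stays anchored to timestamps.
            while self._next_emit_ns <= timestamp_ns:
                self._next_emit_ns += self._interval_ns
            return True
        return False


# Simulate one second of a 24 FPS camera paced to ~15 FPS.
frame_interval_ns = int(1_000_000_000 / 24)
pacer = FramePacer(target_fps=15.0)
emitted = sum(pacer.should_emit(i * frame_interval_ns) for i in range(24))
print(emitted)  # 15
```

Because pacing keys off timestamps rather than frame counts, slow processing does not change the emitted cadence.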
+53 -22
@@ -3,8 +3,16 @@ from __future__ import annotations
import argparse
import logging
import sys
from typing import cast

from .pipeline import ScoliosisPipeline, WindowMode, resolve_stride


def _positive_float(value: str) -> float:
    parsed = float(value)
    if parsed <= 0:
        raise argparse.ArgumentTypeError("target-fps must be positive")
    return parsed


if __name__ == "__main__":
@@ -29,6 +37,24 @@ if __name__ == "__main__":
        "--window", type=int, default=30, help="Window size for classification"
    )
    parser.add_argument("--stride", type=int, default=30, help="Stride for window")
    parser.add_argument(
        "--target-fps",
        type=_positive_float,
        default=15.0,
        help="Target FPS for temporal downsampling before windowing",
    )
    parser.add_argument(
        "--window-mode",
        type=str,
        choices=["manual", "sliding", "chunked"],
        default="manual",
        help="Window scheduling mode: manual uses --stride; sliding uses stride=1; chunked uses stride=window",
    )
    parser.add_argument(
        "--no-target-fps",
        action="store_true",
        help="Disable temporal downsampling and use all frames",
    )
    parser.add_argument(
        "--nats-url", type=str, default=None, help="NATS URL for result publishing"
    )
@@ -88,27 +114,32 @@ if __name__ == "__main__":
            source=args.source, checkpoint=args.checkpoint, config=args.config
        )

        effective_stride = resolve_stride(
            window=cast(int, args.window),
            stride=cast(int, args.stride),
            window_mode=cast(WindowMode, args.window_mode),
        )

        pipeline = ScoliosisPipeline(
            source=cast(str, args.source),
            checkpoint=cast(str, args.checkpoint),
            config=cast(str, args.config),
            device=cast(str, args.device),
            yolo_model=cast(str, args.yolo_model),
            window=cast(int, args.window),
            stride=effective_stride,
            target_fps=(None if args.no_target_fps else cast(float, args.target_fps)),
            nats_url=cast(str | None, args.nats_url),
            nats_subject=cast(str, args.nats_subject),
            max_frames=cast(int | None, args.max_frames),
            preprocess_only=cast(bool, args.preprocess_only),
            silhouette_export_path=cast(str | None, args.silhouette_export_path),
            silhouette_export_format=cast(str, args.silhouette_export_format),
            silhouette_visualize_dir=cast(str | None, args.silhouette_visualize_dir),
            result_export_path=cast(str | None, args.result_export_path),
            result_export_format=cast(str, args.result_export_format),
            visualize=cast(bool, args.visualize),
        )
        raise SystemExit(pipeline.run())
    except ValueError as err:
        print(f"Error: {err}", file=sys.stderr)
+18 -3
@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
# Type alias for frame stream: (frame_array, metadata_dict)
FrameStream = Iterable[tuple[np.ndarray, dict[str, object]]]


# Protocol for cv-mmap metadata (needed at runtime for nested function annotation)
class _FrameMetadata(Protocol):
    frame_count: int
@@ -58,6 +59,13 @@ def opencv_source(
    if not cap.isOpened():
        raise RuntimeError(f"Failed to open video source: {path}")

    is_file_source = isinstance(path, str)
    source_fps = float(cap.get(cv2.CAP_PROP_FPS)) if is_file_source else 0.0
    fps_valid = source_fps > 0.0 and np.isfinite(source_fps)
    fallback_fps = source_fps if fps_valid else 30.0
    fallback_interval_ns = int(1_000_000_000 / fallback_fps)
    start_ns = time.monotonic_ns()

    frame_idx = 0
    try:
        while max_frames is None or frame_idx < max_frames:
@@ -66,14 +74,22 @@ def opencv_source(
                # End of stream
                break

            if is_file_source:
                pos_msec = float(cap.get(cv2.CAP_PROP_POS_MSEC))
                if np.isfinite(pos_msec) and pos_msec > 0.0:
                    timestamp_ns = start_ns + int(pos_msec * 1_000_000)
                else:
                    timestamp_ns = start_ns + frame_idx * fallback_interval_ns
            else:
                timestamp_ns = time.monotonic_ns()

            metadata: dict[str, object] = {
                "frame_count": frame_idx,
                "timestamp_ns": timestamp_ns,
                "source": path,
            }
            if fps_valid:
                metadata["source_fps"] = source_fps

            yield frame, metadata
            frame_idx += 1
@@ -118,7 +134,6 @@ def cvmmap_source(
    # Import cvmmap only when function is called
    # Use try/except for runtime import check
    try:
        from cvmmap import CvMmapClient as _CvMmapClientReal  # pyright: ignore[reportMissingTypeStubs]
    except ImportError as e:
        raise ImportError(
+105 -5
@@ -5,7 +5,7 @@ from contextlib import suppress
import logging
from pathlib import Path
import time
from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, cast

from beartype import beartype
import click
@@ -31,6 +31,16 @@ JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)

WindowMode: TypeAlias = Literal["manual", "sliding", "chunked"]


def resolve_stride(window: int, stride: int, window_mode: WindowMode) -> int:
    if window_mode == "manual":
        return stride
    if window_mode == "sliding":
        return 1
    return window


class _BoxesLike(Protocol):
    @property
@@ -65,6 +75,27 @@ class _TrackCallable(Protocol):
    ) -> object: ...


class _FramePacer:
    _interval_ns: int
    _next_emit_ns: int | None

    def __init__(self, target_fps: float) -> None:
        if target_fps <= 0:
            raise ValueError(f"target_fps must be positive, got {target_fps}")
        self._interval_ns = int(1_000_000_000 / target_fps)
        self._next_emit_ns = None

    def should_emit(self, timestamp_ns: int) -> bool:
        if self._next_emit_ns is None:
            self._next_emit_ns = timestamp_ns + self._interval_ns
            return True
        if timestamp_ns >= self._next_emit_ns:
            while self._next_emit_ns <= timestamp_ns:
                self._next_emit_ns += self._interval_ns
            return True
        return False


class ScoliosisPipeline:
    _detector: object
    _source: FrameStream
@@ -83,6 +114,7 @@ class ScoliosisPipeline:
    _result_buffer: list[DemoResult]
    _visualizer: OpenCVVisualizer | None
    _last_viz_payload: dict[str, object] | None
    _frame_pacer: _FramePacer | None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -104,6 +136,7 @@ class ScoliosisPipeline:
|
|||||||
result_export_path: str | None = None,
|
result_export_path: str | None = None,
|
||||||
result_export_format: str = "json",
|
result_export_format: str = "json",
|
||||||
visualize: bool = False,
|
visualize: bool = False,
|
||||||
|
target_fps: float | None = 15.0,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._detector = YOLO(yolo_model)
|
self._detector = YOLO(yolo_model)
|
||||||
self._source = create_source(source, max_frames=max_frames)
|
self._source = create_source(source, max_frames=max_frames)
|
||||||
@@ -140,6 +173,7 @@ class ScoliosisPipeline:
|
|||||||
else:
|
else:
|
||||||
self._visualizer = None
|
self._visualizer = None
|
||||||
self._last_viz_payload = None
|
self._last_viz_payload = None
|
||||||
|
self._frame_pacer = _FramePacer(target_fps) if target_fps is not None else None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
|
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
|
||||||
@@ -177,6 +211,7 @@ class ScoliosisPipeline:
|
|||||||
Float[ndarray, "64 44"],
|
Float[ndarray, "64 44"],
|
||||||
UInt8[ndarray, "h w"],
|
UInt8[ndarray, "h w"],
|
||||||
BBoxXYXY,
|
BBoxXYXY,
|
||||||
|
BBoxXYXY,
|
||||||
int,
|
int,
|
||||||
]
|
]
|
||||||
| None
|
| None
|
||||||
@@ -189,7 +224,7 @@ class ScoliosisPipeline:
|
|||||||
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask),
|
mask_to_silhouette(self._to_mask_u8(mask_raw), bbox_mask),
|
||||||
)
|
)
|
||||||
if silhouette is not None:
|
if silhouette is not None:
|
||||||
return silhouette, mask_raw, bbox_frame, int(track_id)
|
return silhouette, mask_raw, bbox_frame, bbox_mask, int(track_id)
|
||||||
|
|
||||||
fallback = cast(
|
fallback = cast(
|
||||||
tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
|
tuple[UInt8[ndarray, "h w"], BBoxXYXY] | None,
|
||||||
@@ -231,7 +266,7 @@ class ScoliosisPipeline:
|
|||||||
# Fallback: use mask-space bbox if orig_shape unavailable
|
# Fallback: use mask-space bbox if orig_shape unavailable
|
||||||
bbox_frame = bbox_mask
|
bbox_frame = bbox_mask
|
||||||
# For fallback case, mask_raw is the same as mask_u8
|
# For fallback case, mask_raw is the same as mask_u8
|
||||||
return silhouette, mask_u8, bbox_frame, 0
|
return silhouette, mask_u8, bbox_frame, bbox_mask, 0
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
def process_frame(
|
def process_frame(
|
||||||
@@ -262,7 +297,7 @@ class ScoliosisPipeline:
|
|||||||
if selected is None:
|
if selected is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
silhouette, mask_raw, bbox, track_id = selected
|
silhouette, mask_raw, bbox, bbox_mask, track_id = selected
|
||||||
|
|
||||||
# Store silhouette for export if in preprocess-only mode or if export requested
|
# Store silhouette for export if in preprocess-only mode or if export requested
|
||||||
if self._silhouette_export_path is not None or self._preprocess_only:
|
if self._silhouette_export_path is not None or self._preprocess_only:
|
||||||
@@ -284,20 +319,39 @@ class ScoliosisPipeline:
|
|||||||
return {
|
return {
|
||||||
"mask_raw": mask_raw,
|
"mask_raw": mask_raw,
|
||||||
"bbox": bbox,
|
"bbox": bbox,
|
||||||
|
"bbox_mask": bbox_mask,
|
||||||
"silhouette": silhouette,
|
"silhouette": silhouette,
|
||||||
|
"segmentation_input": None,
|
||||||
|
"track_id": track_id,
|
||||||
|
"label": None,
|
||||||
|
"confidence": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self._frame_pacer is not None and not self._frame_pacer.should_emit(
|
||||||
|
timestamp_ns
|
||||||
|
):
|
||||||
|
return {
|
||||||
|
"mask_raw": mask_raw,
|
||||||
|
"bbox": bbox,
|
||||||
|
"bbox_mask": bbox_mask,
|
||||||
|
"silhouette": silhouette,
|
||||||
|
"segmentation_input": self._window.buffered_silhouettes,
|
||||||
"track_id": track_id,
|
"track_id": track_id,
|
||||||
"label": None,
|
"label": None,
|
||||||
"confidence": None,
|
"confidence": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
|
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
|
||||||
|
segmentation_input = self._window.buffered_silhouettes
|
||||||
|
|
||||||
if not self._window.should_classify():
|
if not self._window.should_classify():
|
||||||
# Return visualization payload even when not classifying yet
|
# Return visualization payload even when not classifying yet
|
||||||
return {
|
return {
|
||||||
"mask_raw": mask_raw,
|
"mask_raw": mask_raw,
|
||||||
"bbox": bbox,
|
"bbox": bbox,
|
||||||
|
"bbox_mask": bbox_mask,
|
||||||
"silhouette": silhouette,
|
"silhouette": silhouette,
|
||||||
|
"segmentation_input": segmentation_input,
|
||||||
"track_id": track_id,
|
"track_id": track_id,
|
||||||
"label": None,
|
"label": None,
|
||||||
"confidence": None,
|
"confidence": None,
|
||||||
@@ -330,7 +384,9 @@ class ScoliosisPipeline:
|
|||||||
"result": result,
|
"result": result,
|
||||||
"mask_raw": mask_raw,
|
"mask_raw": mask_raw,
|
||||||
"bbox": bbox,
|
"bbox": bbox,
|
||||||
|
"bbox_mask": bbox_mask,
|
||||||
"silhouette": silhouette,
|
"silhouette": silhouette,
|
||||||
|
"segmentation_input": segmentation_input,
|
||||||
"track_id": track_id,
|
"track_id": track_id,
|
||||||
"label": label,
|
"label": label,
|
||||||
"confidence": confidence,
|
"confidence": confidence,
|
||||||
@@ -400,7 +456,9 @@ class ScoliosisPipeline:
|
|||||||
viz_dict = cast(dict[str, object], viz_data)
|
viz_dict = cast(dict[str, object], viz_data)
|
||||||
mask_raw_obj = viz_dict.get("mask_raw")
|
mask_raw_obj = viz_dict.get("mask_raw")
|
||||||
bbox_obj = viz_dict.get("bbox")
|
bbox_obj = viz_dict.get("bbox")
|
||||||
|
bbox_mask_obj = viz_dict.get("bbox_mask")
|
||||||
silhouette_obj = viz_dict.get("silhouette")
|
silhouette_obj = viz_dict.get("silhouette")
|
||||||
|
segmentation_input_obj = viz_dict.get("segmentation_input")
|
||||||
track_id_val = viz_dict.get("track_id", 0)
|
track_id_val = viz_dict.get("track_id", 0)
|
||||||
track_id = track_id_val if isinstance(track_id_val, int) else 0
|
track_id = track_id_val if isinstance(track_id_val, int) else 0
|
||||||
label_obj = viz_dict.get("label")
|
label_obj = viz_dict.get("label")
|
||||||
@@ -409,24 +467,33 @@ class ScoliosisPipeline:
|
|||||||
# Cast extracted values to expected types
|
# Cast extracted values to expected types
|
||||||
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
|
mask_raw = cast(NDArray[np.uint8] | None, mask_raw_obj)
|
||||||
bbox = cast(BBoxXYXY | None, bbox_obj)
|
bbox = cast(BBoxXYXY | None, bbox_obj)
|
||||||
|
bbox_mask = cast(BBoxXYXY | None, bbox_mask_obj)
|
||||||
silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
|
silhouette = cast(NDArray[np.float32] | None, silhouette_obj)
|
||||||
|
segmentation_input = cast(
|
||||||
|
NDArray[np.float32] | None,
|
||||||
|
segmentation_input_obj,
|
||||||
|
)
|
||||||
label = cast(str | None, label_obj)
|
label = cast(str | None, label_obj)
|
||||||
confidence = cast(float | None, confidence_obj)
|
confidence = cast(float | None, confidence_obj)
|
||||||
else:
|
else:
|
||||||
# No detection and no cache - use default values
|
# No detection and no cache - use default values
|
||||||
mask_raw = None
|
mask_raw = None
|
||||||
bbox = None
|
bbox = None
|
||||||
|
bbox_mask = None
|
||||||
track_id = 0
|
track_id = 0
|
||||||
silhouette = None
|
silhouette = None
|
||||||
|
segmentation_input = None
|
||||||
label = None
|
label = None
|
||||||
confidence = None
|
confidence = None
|
||||||
|
|
||||||
keep_running = self._visualizer.update(
|
keep_running = self._visualizer.update(
|
||||||
frame_u8,
|
frame_u8,
|
||||||
bbox,
|
bbox,
|
||||||
|
bbox_mask,
|
||||||
track_id,
|
track_id,
|
||||||
mask_raw,
|
mask_raw,
|
||||||
silhouette,
|
silhouette,
|
||||||
|
segmentation_input,
|
||||||
label,
|
label,
|
||||||
confidence,
|
confidence,
|
||||||
ema_fps,
|
ema_fps,
|
||||||
@@ -671,6 +738,23 @@ def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
|
|||||||
)
|
)
|
||||||
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
|
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
|
||||||
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
|
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
|
||||||
|
@click.option(
|
||||||
|
"--window-mode",
|
||||||
|
type=click.Choice(["manual", "sliding", "chunked"], case_sensitive=False),
|
||||||
|
default="manual",
|
||||||
|
show_default=True,
|
||||||
|
help=(
|
||||||
|
"Window scheduling mode: manual uses --stride; "
|
||||||
|
"sliding forces stride=1; chunked forces stride=window"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--target-fps",
|
||||||
|
type=click.FloatRange(min=0.1),
|
||||||
|
default=15.0,
|
||||||
|
show_default=True,
|
||||||
|
)
|
||||||
|
@click.option("--no-target-fps", is_flag=True, default=False)
|
||||||
@click.option("--nats-url", type=str, default=None)
|
@click.option("--nats-url", type=str, default=None)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--nats-subject",
|
"--nats-subject",
|
||||||
@@ -725,6 +809,9 @@ def main(
|
|||||||
yolo_model: str,
|
yolo_model: str,
|
||||||
window: int,
|
window: int,
|
||||||
stride: int,
|
stride: int,
|
||||||
|
window_mode: str,
|
||||||
|
target_fps: float | None,
|
||||||
|
no_target_fps: bool,
|
||||||
nats_url: str | None,
|
nats_url: str | None,
|
||||||
nats_subject: str,
|
nats_subject: str,
|
||||||
max_frames: int | None,
|
max_frames: int | None,
|
||||||
@@ -748,6 +835,18 @@ def main(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
|
validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
|
||||||
|
effective_stride = resolve_stride(
|
||||||
|
window=window,
|
||||||
|
stride=stride,
|
||||||
|
window_mode=cast(WindowMode, window_mode.lower()),
|
||||||
|
)
|
||||||
|
if effective_stride != stride:
|
||||||
|
logger.info(
|
||||||
|
"window_mode=%s overrides stride=%d -> effective_stride=%d",
|
||||||
|
window_mode,
|
||||||
|
stride,
|
||||||
|
effective_stride,
|
||||||
|
)
|
||||||
pipeline = ScoliosisPipeline(
|
pipeline = ScoliosisPipeline(
|
||||||
source=source,
|
source=source,
|
||||||
checkpoint=checkpoint,
|
checkpoint=checkpoint,
|
||||||
@@ -755,7 +854,8 @@ def main(
|
|||||||
device=device,
|
device=device,
|
||||||
yolo_model=yolo_model,
|
yolo_model=yolo_model,
|
||||||
window=window,
|
window=window,
|
||||||
stride=stride,
|
stride=effective_stride,
|
||||||
|
target_fps=None if no_target_fps else target_fps,
|
||||||
nats_url=nats_url,
|
nats_url=nats_url,
|
||||||
nats_subject=nats_subject,
|
nats_subject=nats_subject,
|
||||||
max_frames=max_frames,
|
max_frames=max_frames,
|
||||||
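The two helpers this file adds are small enough to sanity-check standalone. The sketch below restates the diff's `resolve_stride` contract and the `_FramePacer` emit logic (renamed `FramePacer` here, with the `TypeAlias`/`jaxtyping` scaffolding dropped so it runs on its own); the pacing invariant is that `should_emit` returns `True` at most once per `1/target_fps` interval and resynchronizes after gaps.

```python
from typing import Literal

WindowMode = Literal["manual", "sliding", "chunked"]


def resolve_stride(window: int, stride: int, window_mode: WindowMode) -> int:
    # Same contract as the diff: manual keeps --stride, sliding forces 1,
    # chunked forces stride == window (non-overlapping windows).
    if window_mode == "manual":
        return stride
    if window_mode == "sliding":
        return 1
    return window


class FramePacer:
    """Emit at most one frame per interval; skip ahead after long gaps."""

    def __init__(self, target_fps: float) -> None:
        if target_fps <= 0:
            raise ValueError(f"target_fps must be positive, got {target_fps}")
        self._interval_ns = int(1_000_000_000 / target_fps)
        self._next_emit_ns: int | None = None

    def should_emit(self, timestamp_ns: int) -> bool:
        if self._next_emit_ns is None:  # first frame always passes
            self._next_emit_ns = timestamp_ns + self._interval_ns
            return True
        if timestamp_ns >= self._next_emit_ns:
            # Advance past the current timestamp so a stall does not
            # cause a burst of emissions afterwards.
            while self._next_emit_ns <= timestamp_ns:
                self._next_emit_ns += self._interval_ns
            return True
        return False


assert resolve_stride(30, 5, "manual") == 5
assert resolve_stride(30, 5, "sliding") == 1
assert resolve_stride(30, 5, "chunked") == 30

pacer = FramePacer(target_fps=10.0)  # one frame per 100 ms
assert pacer.should_emit(0)
assert not pacer.should_emit(50_000_000)  # 50 ms later: throttled
assert pacer.should_emit(100_000_000)  # next interval boundary: emitted
```

The catch-up `while` loop is what distinguishes this from naive `next += interval`: after a multi-interval stall the pacer drops the backlog instead of emitting every queued frame at once.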
@@ -23,8 +23,10 @@ jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
 UInt8Array = NDArray[np.uint8]
 Float32Array = NDArray[np.float32]

-#: Bounding box in XYXY format: (x1, y1, x2, y2) where (x1,y1) is top-left and (x2,y2) is bottom-right.
 BBoxXYXY = tuple[int, int, int, int]
+"""
+Bounding box in XYXY format: (x1, y1, x2, y2) where (x1,y1) is top-left and (x2,y2) is bottom-right.
+"""


 def _read_attr(container: object, key: str) -> object | None:
|
|||||||
+246
-122
@@ -20,7 +20,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# Window names
|
# Window names
|
||||||
MAIN_WINDOW = "Scoliosis Detection"
|
MAIN_WINDOW = "Scoliosis Detection"
|
||||||
SEG_WINDOW = "Segmentation"
|
SEG_WINDOW = "Normalized Silhouette"
|
||||||
|
RAW_WINDOW = "Raw Mask"
|
||||||
|
WINDOW_SEG_INPUT = "Segmentation Input"
|
||||||
|
|
||||||
# Silhouette dimensions (from preprocess.py)
|
# Silhouette dimensions (from preprocess.py)
|
||||||
SIL_HEIGHT = 64
|
SIL_HEIGHT = 64
|
||||||
@@ -29,43 +31,45 @@ SIL_WIDTH = 44
|
|||||||
# Display dimensions for upscaled silhouette
|
# Display dimensions for upscaled silhouette
|
||||||
DISPLAY_HEIGHT = 256
|
DISPLAY_HEIGHT = 256
|
||||||
DISPLAY_WIDTH = 176
|
DISPLAY_WIDTH = 176
|
||||||
|
RAW_STATS_PAD = 54
|
||||||
|
MODE_LABEL_PAD = 26
|
||||||
|
|
||||||
# Colors (BGR)
|
# Colors (BGR)
|
||||||
COLOR_GREEN = (0, 255, 0)
|
COLOR_GREEN = (0, 255, 0)
|
||||||
COLOR_WHITE = (255, 255, 255)
|
COLOR_WHITE = (255, 255, 255)
|
||||||
COLOR_BLACK = (0, 0, 0)
|
COLOR_BLACK = (0, 0, 0)
|
||||||
|
COLOR_DARK_GRAY = (56, 56, 56)
|
||||||
COLOR_RED = (0, 0, 255)
|
COLOR_RED = (0, 0, 255)
|
||||||
COLOR_YELLOW = (0, 255, 255)
|
COLOR_YELLOW = (0, 255, 255)
|
||||||
|
|
||||||
# Mode labels
|
|
||||||
MODE_LABELS = ["Both", "Raw Mask", "Normalized"]
|
|
||||||
|
|
||||||
# Type alias for image arrays (NDArray or cv2.Mat)
|
# Type alias for image arrays (NDArray or cv2.Mat)
|
||||||
ImageArray = NDArray[np.uint8]
|
ImageArray = NDArray[np.uint8]
|
||||||
|
|
||||||
|
|
||||||
class OpenCVVisualizer:
|
class OpenCVVisualizer:
|
||||||
"""Real-time visualizer for gait analysis demo.
|
|
||||||
|
|
||||||
Displays two windows:
|
|
||||||
- Main stream: Original frame with bounding box and metadata overlay
|
|
||||||
- Segmentation: Raw mask, normalized silhouette, or side-by-side view
|
|
||||||
|
|
||||||
Supports interactive mode switching via keyboard.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize visualizer with default mask mode."""
|
self.show_raw_window: bool = False
|
||||||
self.mask_mode: int = 0 # 0: Both, 1: Raw, 2: Normalized
|
self.show_raw_debug: bool = False
|
||||||
self._windows_created: bool = False
|
self._windows_created: bool = False
|
||||||
|
self._raw_window_created: bool = False
|
||||||
|
|
||||||
def _ensure_windows(self) -> None:
|
def _ensure_windows(self) -> None:
|
||||||
"""Create OpenCV windows if not already created."""
|
|
||||||
if not self._windows_created:
|
if not self._windows_created:
|
||||||
cv2.namedWindow(MAIN_WINDOW, cv2.WINDOW_NORMAL)
|
cv2.namedWindow(MAIN_WINDOW, cv2.WINDOW_NORMAL)
|
||||||
cv2.namedWindow(SEG_WINDOW, cv2.WINDOW_NORMAL)
|
cv2.namedWindow(SEG_WINDOW, cv2.WINDOW_NORMAL)
|
||||||
|
cv2.namedWindow(WINDOW_SEG_INPUT, cv2.WINDOW_NORMAL)
|
||||||
self._windows_created = True
|
self._windows_created = True
|
||||||
|
|
||||||
|
def _ensure_raw_window(self) -> None:
|
||||||
|
if not self._raw_window_created:
|
||||||
|
cv2.namedWindow(RAW_WINDOW, cv2.WINDOW_NORMAL)
|
||||||
|
self._raw_window_created = True
|
||||||
|
|
||||||
|
def _hide_raw_window(self) -> None:
|
||||||
|
if self._raw_window_created:
|
||||||
|
cv2.destroyWindow(RAW_WINDOW)
|
||||||
|
self._raw_window_created = False
|
||||||
|
|
||||||
def _draw_bbox(
|
def _draw_bbox(
|
||||||
self,
|
self,
|
||||||
frame: ImageArray,
|
frame: ImageArray,
|
||||||
@@ -215,33 +219,181 @@ class OpenCVVisualizer:
|
|||||||
|
|
||||||
return upscaled
|
return upscaled
|
||||||
|
|
||||||
|
def _normalize_mask_for_display(self, mask: NDArray[np.generic]) -> ImageArray:
|
||||||
|
mask_array = np.asarray(mask)
|
||||||
|
if mask_array.dtype == np.bool_:
|
||||||
|
bool_scaled = np.where(mask_array, np.uint8(255), np.uint8(0)).astype(
|
||||||
|
np.uint8
|
||||||
|
)
|
||||||
|
return cast(ImageArray, bool_scaled)
|
||||||
|
|
||||||
|
if mask_array.dtype == np.uint8:
|
||||||
|
mask_array = cast(ImageArray, mask_array)
|
||||||
|
max_u8 = int(np.max(mask_array)) if mask_array.size > 0 else 0
|
||||||
|
if max_u8 <= 1:
|
||||||
|
scaled_u8 = np.where(mask_array > 0, np.uint8(255), np.uint8(0)).astype(
|
||||||
|
np.uint8
|
||||||
|
)
|
||||||
|
return cast(ImageArray, scaled_u8)
|
||||||
|
return cast(ImageArray, mask_array)
|
||||||
|
|
||||||
|
if np.issubdtype(mask_array.dtype, np.integer):
|
||||||
|
max_int = float(np.max(mask_array)) if mask_array.size > 0 else 0.0
|
||||||
|
if max_int <= 1.0:
|
||||||
|
return cast(
|
||||||
|
ImageArray, (mask_array.astype(np.float32) * 255.0).astype(np.uint8)
|
||||||
|
)
|
||||||
|
clipped = np.clip(mask_array, 0, 255).astype(np.uint8)
|
||||||
|
return cast(ImageArray, clipped)
|
||||||
|
|
||||||
|
mask_float = np.asarray(mask_array, dtype=np.float32)
|
||||||
|
max_val = float(np.max(mask_float)) if mask_float.size > 0 else 0.0
|
||||||
|
if max_val <= 0.0:
|
||||||
|
return np.zeros(mask_float.shape, dtype=np.uint8)
|
||||||
|
|
||||||
|
normalized = np.clip((mask_float / max_val) * 255.0, 0.0, 255.0).astype(
|
||||||
|
np.uint8
|
||||||
|
)
|
||||||
|
return cast(ImageArray, normalized)
|
||||||
|
|
||||||
|
def _draw_raw_stats(self, image: ImageArray, mask_raw: ImageArray | None) -> None:
|
||||||
|
if mask_raw is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
mask = np.asarray(mask_raw)
|
||||||
|
if mask.size == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
stats = [
|
||||||
|
f"raw: {mask.dtype}",
|
||||||
|
f"min/max: {float(mask.min()):.3f}/{float(mask.max()):.3f}",
|
||||||
|
f"nnz: {int(np.count_nonzero(mask))}",
|
||||||
|
]
|
||||||
|
|
||||||
|
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||||
|
font_scale = 0.45
|
||||||
|
thickness = 1
|
||||||
|
line_h = 18
|
||||||
|
x0 = 8
|
||||||
|
y0 = 20
|
||||||
|
|
||||||
|
for i, txt in enumerate(stats):
|
||||||
|
y = y0 + i * line_h
|
||||||
|
(tw, th), _ = cv2.getTextSize(txt, font, font_scale, thickness)
|
||||||
|
_ = cv2.rectangle(
|
||||||
|
image, (x0 - 4, y - th - 4), (x0 + tw + 4, y + 4), COLOR_BLACK, -1
|
||||||
|
)
|
||||||
|
_ = cv2.putText(
|
||||||
|
image, txt, (x0, y), font, font_scale, COLOR_YELLOW, thickness
|
||||||
|
)
|
||||||
|
|
||||||
def _prepare_segmentation_view(
|
def _prepare_segmentation_view(
|
||||||
self,
|
self,
|
||||||
mask_raw: ImageArray | None,
|
mask_raw: ImageArray | None,
|
||||||
silhouette: NDArray[np.float32] | None,
|
silhouette: NDArray[np.float32] | None,
|
||||||
|
bbox: BBoxXYXY | None,
|
||||||
) -> ImageArray:
|
) -> ImageArray:
|
||||||
"""Prepare segmentation window content based on current mode.
|
_ = mask_raw
|
||||||
|
_ = bbox
|
||||||
|
return self._prepare_normalized_view(silhouette)
|
||||||
|
|
||||||
Args:
|
def _fit_gray_to_display(
|
||||||
mask_raw: Raw binary mask (H, W) uint8 or None
|
self,
|
||||||
silhouette: Normalized silhouette (64, 44) float32 or None
|
gray: ImageArray,
|
||||||
|
out_h: int = DISPLAY_HEIGHT,
|
||||||
|
out_w: int = DISPLAY_WIDTH,
|
||||||
|
) -> ImageArray:
|
||||||
|
src_h, src_w = gray.shape[:2]
|
||||||
|
if src_h <= 0 or src_w <= 0:
|
||||||
|
return np.zeros((out_h, out_w), dtype=np.uint8)
|
||||||
|
|
||||||
Returns:
|
scale = min(out_w / src_w, out_h / src_h)
|
||||||
Displayable image (H, W, 3) uint8
|
new_w = max(1, int(round(src_w * scale)))
|
||||||
"""
|
new_h = max(1, int(round(src_h * scale)))
|
||||||
if self.mask_mode == 0:
|
|
||||||
# Mode 0: Both (side by side)
|
resized = cast(
|
||||||
return self._prepare_both_view(mask_raw, silhouette)
|
ImageArray,
|
||||||
elif self.mask_mode == 1:
|
cv2.resize(gray, (new_w, new_h), interpolation=cv2.INTER_NEAREST),
|
||||||
# Mode 1: Raw mask only
|
)
|
||||||
return self._prepare_raw_view(mask_raw)
|
canvas = np.zeros((out_h, out_w), dtype=np.uint8)
|
||||||
else:
|
x0 = (out_w - new_w) // 2
|
||||||
# Mode 2: Normalized silhouette only
|
y0 = (out_h - new_h) // 2
|
||||||
return self._prepare_normalized_view(silhouette)
|
canvas[y0 : y0 + new_h, x0 : x0 + new_w] = resized
|
||||||
|
return cast(ImageArray, canvas)
|
||||||
|
|
||||||
|
def _crop_mask_to_bbox(
|
||||||
|
self,
|
||||||
|
mask_gray: ImageArray,
|
||||||
|
bbox: BBoxXYXY | None,
|
||||||
|
) -> ImageArray:
|
||||||
|
if bbox is None:
|
||||||
|
return mask_gray
|
||||||
|
|
||||||
|
h, w = mask_gray.shape[:2]
|
||||||
|
x1, y1, x2, y2 = bbox
|
||||||
|
x1c = max(0, min(w, int(x1)))
|
||||||
|
x2c = max(0, min(w, int(x2)))
|
||||||
|
y1c = max(0, min(h, int(y1)))
|
||||||
|
y2c = max(0, min(h, int(y2)))
|
||||||
|
|
||||||
|
if x2c <= x1c or y2c <= y1c:
|
||||||
|
return mask_gray
|
||||||
|
|
||||||
|
cropped = mask_gray[y1c:y2c, x1c:x2c]
|
||||||
|
if cropped.size == 0:
|
||||||
|
return mask_gray
|
||||||
|
return cast(ImageArray, cropped)
|
||||||
|
|
||||||
|
def _prepare_segmentation_input_view(
|
||||||
|
self,
|
||||||
|
silhouettes: NDArray[np.float32] | None,
|
||||||
|
) -> ImageArray:
|
||||||
|
if silhouettes is None or silhouettes.size == 0:
|
||||||
|
placeholder = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH, 3), dtype=np.uint8)
|
||||||
|
self._draw_mode_indicator(placeholder, "Input Silhouettes (No Data)")
|
||||||
|
return placeholder
|
||||||
|
|
||||||
|
n_frames = int(silhouettes.shape[0])
|
||||||
|
tiles_per_row = int(np.ceil(np.sqrt(n_frames)))
|
||||||
|
rows = int(np.ceil(n_frames / tiles_per_row))
|
||||||
|
|
||||||
|
tile_h = DISPLAY_HEIGHT
|
||||||
|
tile_w = DISPLAY_WIDTH
|
||||||
|
grid = np.zeros((rows * tile_h, tiles_per_row * tile_w), dtype=np.uint8)
|
||||||
|
|
||||||
|
for idx in range(n_frames):
|
||||||
|
sil = silhouettes[idx]
|
||||||
|
tile = self._upscale_silhouette(sil)
|
||||||
|
r = idx // tiles_per_row
|
||||||
|
c = idx % tiles_per_row
|
||||||
|
y0, y1 = r * tile_h, (r + 1) * tile_h
|
||||||
|
x0, x1 = c * tile_w, (c + 1) * tile_w
|
||||||
|
grid[y0:y1, x0:x1] = tile
|
||||||
|
|
||||||
|
grid_bgr = cast(ImageArray, cv2.cvtColor(grid, cv2.COLOR_GRAY2BGR))
|
||||||
|
|
||||||
|
for idx in range(n_frames):
|
||||||
|
r = idx // tiles_per_row
|
||||||
|
c = idx % tiles_per_row
|
||||||
|
y0 = r * tile_h
|
||||||
|
x0 = c * tile_w
|
||||||
|
cv2.putText(
|
||||||
|
grid_bgr,
|
||||||
|
str(idx),
|
||||||
|
(x0 + 8, y0 + 22),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX,
|
||||||
|
0.6,
|
||||||
|
(0, 255, 255),
|
||||||
|
2,
|
||||||
|
cv2.LINE_AA,
|
||||||
|
)
|
||||||
|
|
||||||
|
return grid_bgr
|
||||||
|
|
||||||
def _prepare_raw_view(
|
def _prepare_raw_view(
|
||||||
self,
|
self,
|
||||||
mask_raw: ImageArray | None,
|
mask_raw: ImageArray | None,
|
||||||
|
bbox: BBoxXYXY | None = None,
|
||||||
) -> ImageArray:
|
) -> ImageArray:
|
||||||
"""Prepare raw mask view.
|
"""Prepare raw mask view.
|
||||||
|
|
||||||
@@ -261,20 +413,23 @@ class OpenCVVisualizer:
|
|||||||
if len(mask_raw.shape) == 3:
|
if len(mask_raw.shape) == 3:
|
||||||
mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
|
mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
|
||||||
else:
|
else:
|
||||||
mask_gray = mask_raw
|
mask_gray = cast(ImageArray, mask_raw)
|
||||||
|
|
||||||
# Resize to display size
|
mask_gray = self._normalize_mask_for_display(mask_gray)
|
||||||
mask_resized = cast(
|
mask_gray = self._crop_mask_to_bbox(mask_gray, bbox)
|
||||||
ImageArray,
|
|
||||||
cv2.resize(
|
debug_pad = RAW_STATS_PAD if self.show_raw_debug else 0
|
||||||
mask_gray,
|
content_h = max(1, DISPLAY_HEIGHT - debug_pad - MODE_LABEL_PAD)
|
||||||
(DISPLAY_WIDTH, DISPLAY_HEIGHT),
|
mask_resized = self._fit_gray_to_display(
|
||||||
interpolation=cv2.INTER_NEAREST,
|
mask_gray, out_h=content_h, out_w=DISPLAY_WIDTH
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
full_mask = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
|
||||||
|
full_mask[debug_pad : debug_pad + content_h, :] = mask_resized
|
||||||
|
|
||||||
# Convert to BGR for display
|
# Convert to BGR for display
|
||||||
mask_bgr = cast(ImageArray, cv2.cvtColor(mask_resized, cv2.COLOR_GRAY2BGR))
|
mask_bgr = cast(ImageArray, cv2.cvtColor(full_mask, cv2.COLOR_GRAY2BGR))
|
||||||
|
if self.show_raw_debug:
|
||||||
|
self._draw_raw_stats(mask_bgr, mask_raw)
|
||||||
self._draw_mode_indicator(mask_bgr, "Raw Mask")
|
self._draw_mode_indicator(mask_bgr, "Raw Mask")
|
||||||
|
|
||||||
return mask_bgr
|
return mask_bgr
|
||||||
@@ -299,80 +454,21 @@ class OpenCVVisualizer:
|
|||||||
|
|
||||||
# Upscale and convert
|
# Upscale and convert
|
||||||
upscaled = self._upscale_silhouette(silhouette)
|
upscaled = self._upscale_silhouette(silhouette)
|
||||||
sil_bgr = cast(ImageArray, cv2.cvtColor(upscaled, cv2.COLOR_GRAY2BGR))
|
content_h = max(1, DISPLAY_HEIGHT - MODE_LABEL_PAD)
|
||||||
|
sil_compact = self._fit_gray_to_display(
|
||||||
|
upscaled, out_h=content_h, out_w=DISPLAY_WIDTH
|
||||||
|
)
|
||||||
|
sil_canvas = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
|
||||||
|
sil_canvas[:content_h, :] = sil_compact
|
||||||
|
sil_bgr = cast(ImageArray, cv2.cvtColor(sil_canvas, cv2.COLOR_GRAY2BGR))
|
||||||
self._draw_mode_indicator(sil_bgr, "Normalized")
|
self._draw_mode_indicator(sil_bgr, "Normalized")
|
||||||
|
|
||||||
return sil_bgr
|
return sil_bgr
|
||||||
|
|
||||||
def _prepare_both_view(
|
def _draw_mode_indicator(self, image: ImageArray, label: str) -> None:
|
||||||
self,
|
|
||||||
mask_raw: ImageArray | None,
|
|
||||||
silhouette: NDArray[np.float32] | None,
|
|
||||||
) -> ImageArray:
|
|
||||||
"""Prepare side-by-side view of both masks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mask_raw: Raw binary mask or None
|
|
||||||
silhouette: Normalized silhouette or None
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Displayable side-by-side image
|
|
||||||
"""
|
|
||||||
# Prepare individual views without mode indicators (will be drawn on combined)
|
|
||||||
# Raw view preparation (without indicator)
|
|
||||||
if mask_raw is None:
|
|
||||||
raw_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
|
|
||||||
else:
|
|
||||||
if len(mask_raw.shape) == 3:
|
|
||||||
mask_gray = cast(ImageArray, cv2.cvtColor(mask_raw, cv2.COLOR_BGR2GRAY))
|
|
||||||
else:
|
|
||||||
mask_gray = mask_raw
|
|
||||||
# Normalize to uint8 [0,255] for display (handles both float [0,1] and uint8 inputs)
|
|
||||||
if mask_gray.dtype == np.float32 or mask_gray.dtype == np.float64:
|
|
||||||
mask_gray = (mask_gray * 255).astype(np.uint8)
|
|
||||||
raw_gray = cast(
|
|
||||||
ImageArray,
|
|
||||||
cv2.resize(
|
|
||||||
mask_gray,
|
|
||||||
(DISPLAY_WIDTH, DISPLAY_HEIGHT),
|
|
||||||
interpolation=cv2.INTER_NEAREST,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Normalized view preparation (without indicator)
|
|
||||||
if silhouette is None:
|
|
||||||
norm_gray = np.zeros((DISPLAY_HEIGHT, DISPLAY_WIDTH), dtype=np.uint8)
|
|
||||||
else:
|
|
||||||
upscaled = self._upscale_silhouette(silhouette)
|
|
||||||
norm_gray = upscaled
|
|
||||||
|
|
||||||
# Stack horizontally
|
|
||||||
combined = np.hstack([raw_gray, norm_gray])
|
|
||||||
|
|
||||||
# Convert back to BGR
|
|
||||||
combined_bgr = cast(ImageArray, cv2.cvtColor(combined, cv2.COLOR_GRAY2BGR))
|
|
||||||
|
|
||||||
# Add mode indicator
|
|
||||||
self._draw_mode_indicator(combined_bgr, "Both: Raw | Normalized")
|
|
||||||
|
|
||||||
return combined_bgr
|
|
||||||
|
|
||||||
def _draw_mode_indicator(
|
|
||||||
self,
|
|
||||||
image: ImageArray,
|
|
||||||
label: str,
|
|
||||||
) -> None:
|
|
||||||
"""Draw mode indicator text on image.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image: Image to draw on (modified in place)
|
|
||||||
label: Mode label text
|
|
||||||
"""
|
|
||||||
h, w = image.shape[:2]
|
h, w = image.shape[:2]
|
||||||
|
|
||||||
# Mode text at bottom
|
mode_text = label
|
||||||
mode_text = f"Mode: {MODE_LABELS[self.mask_mode]} ({self.mask_mode}) - {label}"
|
|
||||||
|
|
||||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||||
font_scale = 0.5
|
font_scale = 0.5
|
||||||
@@ -383,15 +479,22 @@ class OpenCVVisualizer:
|
|||||||
mode_text, font, font_scale, thickness
|
mode_text, font, font_scale, thickness
|
||||||
)
|
)
|
||||||
|
|
||||||
# Draw background at bottom center
|
x_pos = 14
|
||||||
x_pos = (w - text_width) // 2
|
y_pos = h - 8
|
||||||
y_pos = h - 10
|
y_top = max(0, h - MODE_LABEL_PAD)
|
||||||
|
|
||||||
_ = cv2.rectangle(
|
_ = cv2.rectangle(
|
||||||
image,
|
image,
|
||||||
(x_pos - 5, y_pos - text_height - 5),
|
(0, y_top),
|
||||||
(x_pos + text_width + 5, y_pos + 5),
|
(w, h),
|
||||||
COLOR_BLACK,
|
COLOR_DARK_GRAY,
|
||||||
|
-1,
|
||||||
|
)
|
||||||
|
_ = cv2.rectangle(
|
||||||
|
image,
|
||||||
|
(x_pos - 6, y_pos - text_height - 6),
|
||||||
|
(x_pos + text_width + 8, y_pos + 6),
|
||||||
|
COLOR_DARK_GRAY,
|
||||||
-1,
|
-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -410,9 +513,11 @@ class OpenCVVisualizer:
|
|||||||
self,
|
self,
|
||||||
frame: ImageArray,
|
frame: ImageArray,
|
||||||
bbox: BBoxXYXY | None,
|
bbox: BBoxXYXY | None,
|
||||||
|
bbox_mask: BBoxXYXY | None,
|
||||||
track_id: int,
|
track_id: int,
|
||||||
mask_raw: ImageArray | None,
|
mask_raw: ImageArray | None,
|
||||||
silhouette: NDArray[np.float32] | None,
|
silhouette: NDArray[np.float32] | None,
|
||||||
|
segmentation_input: NDArray[np.float32] | None,
|
||||||
label: str | None,
|
label: str | None,
|
||||||
confidence: float | None,
|
confidence: float | None,
|
||||||
fps: float,
|
fps: float,
|
||||||
@@ -441,23 +546,42 @@ class OpenCVVisualizer:
         cv2.imshow(MAIN_WINDOW, main_display)

         # Prepare and show segmentation window
-        seg_display = self._prepare_segmentation_view(mask_raw, silhouette)
+        seg_display = self._prepare_segmentation_view(mask_raw, silhouette, bbox)
         cv2.imshow(SEG_WINDOW, seg_display)

+        if self.show_raw_window:
+            self._ensure_raw_window()
+            raw_display = self._prepare_raw_view(mask_raw, bbox_mask)
+            cv2.imshow(RAW_WINDOW, raw_display)
+
+        seg_input_display = self._prepare_segmentation_input_view(segmentation_input)
+        cv2.imshow(WINDOW_SEG_INPUT, seg_input_display)
+
         # Handle keyboard input
         key = cv2.waitKey(1) & 0xFF

         if key == ord("q"):
             return False
-        elif key == ord("m"):
-            # Cycle through modes: 0 -> 1 -> 2 -> 0
-            self.mask_mode = (self.mask_mode + 1) % 3
-            logger.debug("Switched to mask mode: %s", MODE_LABELS[self.mask_mode])
+        elif key == ord("r"):
+            self.show_raw_window = not self.show_raw_window
+            if self.show_raw_window:
+                self._ensure_raw_window()
+                logger.debug("Raw mask window enabled")
+            else:
+                self._hide_raw_window()
+                logger.debug("Raw mask window disabled")
+        elif key == ord("d"):
+            self.show_raw_debug = not self.show_raw_debug
+            logger.debug(
+                "Raw mask debug overlay %s",
+                "enabled" if self.show_raw_debug else "disabled",
+            )

         return True

     def close(self) -> None:
-        """Close all OpenCV windows and cleanup."""
         if self._windows_created:
+            self._hide_raw_window()
             cv2.destroyAllWindows()
             self._windows_created = False
+            self._raw_window_created = False
@@ -216,6 +216,15 @@ class SilhouetteWindow:
             raise ValueError("Window is empty")
         return int(self._frame_indices[0])

+    @property
+    def buffered_silhouettes(self) -> Float[ndarray, "n 64 44"]:
+        if not self._buffer:
+            return np.empty((0, SIL_HEIGHT, SIL_WIDTH), dtype=np.float32)
+        return cast(
+            Float[ndarray, "n 64 44"],
+            np.stack(list(self._buffer), axis=0).astype(np.float32, copy=True),
+        )
+

 def _to_numpy(obj: _ArrayLike) -> ndarray:
     """Safely convert array-like object to numpy array.
@@ -0,0 +1,394 @@
#!/usr/bin/env python3
"""
Export all positive labeled batches from Scoliosis1K dataset as time windows.

Creates grid visualizations similar to visualizer._prepare_segmentation_input_view()
for all positive class samples, arranged in sliding time windows.

Optimized UI with:
- Subject ID and batch info footer
- Dual frame counts (window-relative and sequence-relative)
- Clean layout with proper spacing
"""

from __future__ import annotations

import json
import pickle
from pathlib import Path
from typing import Final

import cv2
import numpy as np
from numpy.typing import NDArray

# Constants matching visualizer.py
DISPLAY_HEIGHT: Final = 256
DISPLAY_WIDTH: Final = 176
SIL_HEIGHT: Final = 64
SIL_WIDTH: Final = 44

# Footer settings
FOOTER_HEIGHT: Final = 80  # Height for metadata footer
FOOTER_BG_COLOR: Final = (40, 40, 40)  # Dark gray background
TEXT_COLOR: Final = (255, 255, 255)  # White text
ACCENT_COLOR: Final = (0, 255, 255)  # Cyan for emphasis
FONT: Final = cv2.FONT_HERSHEY_SIMPLEX
FONT_SCALE: Final = 0.6
FONT_THICKNESS: Final = 2


def upscale_silhouette(
    silhouette: NDArray[np.float32] | NDArray[np.uint8],
) -> NDArray[np.uint8]:
    """Upscale silhouette to display size."""
    if silhouette.dtype == np.float32 or silhouette.dtype == np.float64:
        sil_u8 = (silhouette * 255).astype(np.uint8)
    else:
        sil_u8 = silhouette.astype(np.uint8)
    upscaled = cv2.resize(
        sil_u8, (DISPLAY_WIDTH, DISPLAY_HEIGHT), interpolation=cv2.INTER_NEAREST
    )
    return upscaled


def create_optimized_visualization(
    silhouettes: NDArray[np.float32],
    subject_id: str,
    view_name: str,
    window_idx: int,
    start_frame: int,
    end_frame: int,
    n_frames_total: int,
    tile_height: int = DISPLAY_HEIGHT,
    tile_width: int = DISPLAY_WIDTH,
) -> NDArray[np.uint8]:
    """
    Create optimized visualization with grid and metadata footer.

    Args:
        silhouettes: Array of shape (n_frames, 64, 44) float32
        subject_id: Subject identifier
        view_name: View identifier (e.g., "000_180")
        window_idx: Window index within sequence
        start_frame: Starting frame index in sequence
        end_frame: Ending frame index in sequence
        n_frames_total: Total frames in the sequence
        tile_height: Height of each tile in the grid
        tile_width: Width of each tile in the grid

    Returns:
        Combined image with grid visualization and metadata footer
    """
    n_frames = int(silhouettes.shape[0])
    tiles_per_row = int(np.ceil(np.sqrt(n_frames)))
    rows = int(np.ceil(n_frames / tiles_per_row))

    # Create grid
    grid = np.zeros((rows * tile_height, tiles_per_row * tile_width), dtype=np.uint8)

    # Place each silhouette in the grid
    for idx in range(n_frames):
        sil = silhouettes[idx]
        tile = upscale_silhouette(sil)
        r = idx // tiles_per_row
        c = idx % tiles_per_row
        y0, y1 = r * tile_height, (r + 1) * tile_height
        x0, x1 = c * tile_width, (c + 1) * tile_width
        grid[y0:y1, x0:x1] = tile

    # Convert to BGR
    grid_bgr = cv2.cvtColor(grid, cv2.COLOR_GRAY2BGR)

    # Add frame indices as text (both window-relative and sequence-relative)
    for idx in range(n_frames):
        r = idx // tiles_per_row
        c = idx % tiles_per_row
        y0 = r * tile_height
        x0 = c * tile_width

        # Window frame count (top-left)
        cv2.putText(
            grid_bgr,
            f"{idx}",  # Window-relative frame number
            (x0 + 8, y0 + 22),
            FONT,
            FONT_SCALE,
            ACCENT_COLOR,
            FONT_THICKNESS,
            cv2.LINE_AA,
        )

        # Sequence frame count (bottom-left of tile)
        seq_frame = start_frame + idx
        cv2.putText(
            grid_bgr,
            f"#{seq_frame}",  # Sequence-relative frame number
            (x0 + 8, y0 + tile_height - 10),
            FONT,
            0.45,  # Slightly smaller font
            (180, 180, 180),  # Light gray
            1,
            cv2.LINE_AA,
        )

    # Create footer with metadata
    grid_width = grid_bgr.shape[1]
    footer = np.full((FOOTER_HEIGHT, grid_width, 3), FOOTER_BG_COLOR, dtype=np.uint8)

    # Line 1: Subject ID and view
    line1 = f"Subject: {subject_id} | View: {view_name}"
    cv2.putText(
        footer,
        line1,
        (15, 25),
        FONT,
        0.7,
        TEXT_COLOR,
        FONT_THICKNESS,
        cv2.LINE_AA,
    )

    # Line 2: Window batch frame range
    line2 = f"Window {window_idx}: frames [{start_frame:03d} - {end_frame - 1:03d}] ({n_frames} frames)"
    cv2.putText(
        footer,
        line2,
        (15, 50),
        FONT,
        0.7,
        ACCENT_COLOR,
        FONT_THICKNESS,
        cv2.LINE_AA,
    )

    # Line 3: Progress within sequence
    progress_pct = (end_frame / n_frames_total) * 100
    line3 = f"Sequence: {n_frames_total} frames total | Progress: {progress_pct:.1f}%"
    cv2.putText(
        footer,
        line3,
        (15, 72),
        FONT,
        0.6,
        (200, 200, 200),
        1,
        cv2.LINE_AA,
    )

    # Combine grid and footer
    combined = np.vstack([grid_bgr, footer])

    return combined


def load_pkl_sequence(pkl_path: Path) -> NDArray[np.float32]:
    """Load a .pkl file containing silhouette sequence."""
    with open(pkl_path, "rb") as f:
        data = pickle.load(f)

    # Handle different possible structures
    if isinstance(data, np.ndarray):
        return data.astype(np.float32)
    elif isinstance(data, list):
        # List of frames
        return np.stack([np.array(frame) for frame in data]).astype(np.float32)
    else:
        raise ValueError(f"Unexpected data type in {pkl_path}: {type(data)}")


def create_windows(
    sequence: NDArray[np.float32],
    window_size: int = 30,
    stride: int = 30,
) -> list[NDArray[np.float32]]:
    """
    Split a sequence into sliding windows.

    Args:
        sequence: Array of shape (N, 64, 44)
        window_size: Number of frames per window
        stride: Stride between consecutive windows

    Returns:
        List of window arrays, each of shape (window_size, 64, 44)
    """
    n_frames = sequence.shape[0]
    windows = []

    for start_idx in range(0, n_frames - window_size + 1, stride):
        end_idx = start_idx + window_size
        window = sequence[start_idx:end_idx]
        windows.append(window)

    return windows


def export_positive_batches(
    dataset_root: Path,
    output_dir: Path,
    window_size: int = 30,
    stride: int = 30,
    max_sequences: int | None = None,
) -> None:
    """
    Export all positive labeled batches from Scoliosis1K dataset as time windows.

    Args:
        dataset_root: Path to Scoliosis1K-sil-pkl directory
        output_dir: Output directory for visualizations
        window_size: Number of frames per window (default 30)
        stride: Stride between consecutive windows (default 30 = non-overlapping)
        max_sequences: Maximum number of sequences to process (None = all)
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # Find all positive samples
    positive_samples: list[
        tuple[Path, str, str, str]
    ] = []  # (pkl_path, subject_id, view_name, pkl_name)

    for subject_dir in sorted(dataset_root.iterdir()):
        if not subject_dir.is_dir():
            continue
        subject_id = subject_dir.name

        # Check for positive class directory (lowercase)
        positive_dir = subject_dir / "positive"
        if not positive_dir.exists():
            continue

        # Iterate through views
        for view_dir in sorted(positive_dir.iterdir()):
            if not view_dir.is_dir():
                continue
            view_name = view_dir.name

            # Find .pkl files
            for pkl_file in sorted(view_dir.glob("*.pkl")):
                positive_samples.append(
                    (pkl_file, subject_id, view_name, pkl_file.stem)
                )

    print(f"Found {len(positive_samples)} positive labeled sequences")

    if max_sequences:
        positive_samples = positive_samples[:max_sequences]
        print(f"Processing first {max_sequences} sequences")

    total_windows = 0

    # Export each sequence's windows
    for seq_idx, (pkl_path, subject_id, view_name, pkl_name) in enumerate(
        positive_samples, 1
    ):
        print(
            f"[{seq_idx}/{len(positive_samples)}] Processing {subject_id}/{view_name}/{pkl_name}..."
        )

        # Load sequence
        try:
            sequence = load_pkl_sequence(pkl_path)
        except Exception as e:
            print(f"  Error loading {pkl_path}: {e}")
            continue

        # Ensure correct shape (N, 64, 44)
        if len(sequence.shape) == 2:
            # Single frame
            sequence = sequence[np.newaxis, ...]
        elif len(sequence.shape) == 3:
            # (N, H, W) - expected
            pass
        else:
            print(f"  Unexpected shape {sequence.shape}, skipping")
            continue

        n_frames = sequence.shape[0]
        print(f"  Sequence has {n_frames} frames")

        # Skip if sequence is shorter than window size
        if n_frames < window_size:
            print(f"  Skipping: sequence too short (< {window_size} frames)")
            continue

        # Normalize if needed
        if sequence.max() > 1.0:
            sequence = sequence / 255.0

        # Create windows
        windows = create_windows(sequence, window_size=window_size, stride=stride)
        print(f"  Created {len(windows)} windows (size={window_size}, stride={stride})")

        # Export each window
        for window_idx, window in enumerate(windows):
            start_frame = window_idx * stride
            end_frame = start_frame + window_size

            # Create visualization for this window with full metadata
            vis_image = create_optimized_visualization(
                silhouettes=window,
                subject_id=subject_id,
                view_name=view_name,
                window_idx=window_idx,
                start_frame=start_frame,
                end_frame=end_frame,
                n_frames_total=n_frames,
            )

            # Save with descriptive filename including window index
            output_filename = (
                f"{subject_id}_{view_name}_{pkl_name}_win{window_idx:03d}.png"
            )
            output_path = output_dir / output_filename
            cv2.imwrite(str(output_path), vis_image)

            # Save metadata for this window
            meta = {
                "subject_id": subject_id,
                "view": view_name,
                "pkl_name": pkl_name,
                "window_index": window_idx,
                "window_size": window_size,
                "stride": stride,
                "start_frame": start_frame,
                "end_frame": end_frame,
                "sequence_shape": sequence.shape,
                "n_frames_total": n_frames,
                "source_path": str(pkl_path),
            }
            meta_filename = (
                f"{subject_id}_{view_name}_{pkl_name}_win{window_idx:03d}.json"
            )
            meta_path = output_dir / meta_filename
            with open(meta_path, "w") as f:
                json.dump(meta, f, indent=2)

            total_windows += 1

        print(f"  Exported {len(windows)} windows")

    print(f"\nExport complete! Saved {total_windows} windows to {output_dir}")


def main() -> None:
    """Main entry point."""
    # Paths
    dataset_root = Path("/mnt/public/data/Scoliosis1K/Scoliosis1K-sil-pkl")
    output_dir = Path("/home/crosstyan/Code/OpenGait/output/positive_batches")

    if not dataset_root.exists():
        print(f"Error: Dataset not found at {dataset_root}")
        return

    # Export all positive batches with windowing
    export_positive_batches(
        dataset_root,
        output_dir,
        window_size=30,  # 30 frames per window
        stride=30,  # Non-overlapping windows
    )


if __name__ == "__main__":
    main()
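The window count produced by `create_windows` above follows the usual sliding-window formula: for a sequence of N frames with window size w and stride s (N ≥ w), the loop yields floor((N − w) / s) + 1 windows. A minimal sketch of just that count, mirroring the script's `range()` (the helper name `count_windows` is illustrative, not part of the codebase):

```python
def count_windows(n_frames: int, window_size: int, stride: int) -> int:
    # Mirrors range(0, n_frames - window_size + 1, stride) in create_windows:
    # starts at 0, steps by stride, keeps only starts where a full window fits.
    if n_frames < window_size:
        return 0
    return (n_frames - window_size) // stride + 1


# 95 frames, 30-frame windows, stride 30 -> starts at 0, 30, 60 -> 3 windows
print(count_windows(95, 30, 30))  # 3
# Overlapping windows: stride 10 -> starts at 0, 10, ..., 60 -> 7 windows
print(count_windows(95, 30, 10))  # 7
```

With the script's defaults (window_size=30, stride=30) the windows tile the sequence without overlap, and any tail shorter than 30 frames is dropped.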
@@ -7,7 +7,7 @@ from pathlib import Path
 import subprocess
 import sys
 import time
-from typing import Final, cast
+from typing import Final, Literal, cast
 from unittest import mock

 import numpy as np
@@ -693,9 +693,11 @@ class MockVisualizer:
         self,
         frame: NDArray[np.uint8],
         bbox: tuple[int, int, int, int] | None,
+        bbox_mask: tuple[int, int, int, int] | None,
         track_id: int,
         mask_raw: NDArray[np.uint8] | None,
         silhouette: NDArray[np.float32] | None,
+        segmentation_input: NDArray[np.float32] | None,
         label: str | None,
         confidence: float | None,
         fps: float,
@@ -704,9 +706,11 @@ class MockVisualizer:
             {
                 "frame": frame,
                 "bbox": bbox,
+                "bbox_mask": bbox_mask,
                 "track_id": track_id,
                 "mask_raw": mask_raw,
                 "silhouette": silhouette,
+                "segmentation_input": segmentation_input,
                 "label": label,
                 "confidence": confidence,
                 "fps": fps,
@@ -761,9 +765,8 @@ def test_pipeline_visualizer_updates_on_no_detection() -> None:
         visualize=True,
     )

-    # Replace the visualizer with our mock
     mock_viz = MockVisualizer()
-    pipeline._visualizer = mock_viz  # type: ignore[assignment]
+    setattr(pipeline, "_visualizer", mock_viz)

     # Run pipeline
     _ = pipeline.run()
@@ -779,13 +782,14 @@ def test_pipeline_visualizer_updates_on_no_detection() -> None:
     for call in mock_viz.update_calls:
         assert call["track_id"] == 0  # Default track_id when no detection
         assert call["bbox"] is None  # No bbox when no detection
+        assert call["bbox_mask"] is None
         assert call["mask_raw"] is None  # No mask when no detection
         assert call["silhouette"] is None  # No silhouette when no detection
+        assert call["segmentation_input"] is None
         assert call["label"] is None  # No label when no detection
         assert call["confidence"] is None  # No confidence when no detection


 def test_pipeline_visualizer_uses_cached_detection_on_no_detection() -> None:
     """Test that visualizer reuses last valid detection when current frame has no detection.

@@ -818,8 +822,8 @@ def test_pipeline_visualizer_uses_cached_detection_on_no_detection() -> None:
     mock_detector.track.side_effect = [
         [mock_result],  # Frame 0: valid detection
         [mock_result],  # Frame 1: valid detection
         [],  # Frame 2: no detection
         [],  # Frame 3: no detection
     ]
     mock_yolo.return_value = mock_detector

@@ -835,7 +839,12 @@ def test_pipeline_visualizer_uses_cached_detection_on_no_detection() -> None:
     dummy_mask = np.random.randint(0, 256, (480, 640), dtype=np.uint8)
     dummy_bbox_mask = (100, 100, 200, 300)
     dummy_bbox_frame = (100, 100, 200, 300)
-    mock_select_person.return_value = (dummy_mask, dummy_bbox_mask, dummy_bbox_frame, 1)
+    mock_select_person.return_value = (
+        dummy_mask,
+        dummy_bbox_mask,
+        dummy_bbox_frame,
+        1,
+    )

     # Setup mock mask_to_silhouette to return valid silhouette
     dummy_silhouette = np.random.rand(64, 44).astype(np.float32)
@@ -856,9 +865,8 @@ def test_pipeline_visualizer_uses_cached_detection_on_no_detection() -> None:
         visualize=True,
     )

-    # Replace the visualizer with our mock
     mock_viz = MockVisualizer()
-    pipeline._visualizer = mock_viz  # type: ignore[assignment]
+    setattr(pipeline, "_visualizer", mock_viz)

     # Run pipeline
     _ = pipeline.run()
@@ -886,9 +894,57 @@ def test_pipeline_visualizer_uses_cached_detection_on_no_detection() -> None:
         "not None/blank"
     )

-    # The cached masks should be copies (different objects) to prevent mutation issues
+    segmentation_inputs = [
+        call["segmentation_input"] for call in mock_viz.update_calls
+    ]
+    bbox_mask_calls = [call["bbox_mask"] for call in mock_viz.update_calls]
+    assert segmentation_inputs[0] is not None
+    assert segmentation_inputs[1] is not None
+    assert segmentation_inputs[2] is not None
+    assert segmentation_inputs[3] is not None
+    assert bbox_mask_calls[0] == dummy_bbox_mask
+    assert bbox_mask_calls[1] == dummy_bbox_mask
+    assert bbox_mask_calls[2] == dummy_bbox_mask
+    assert bbox_mask_calls[3] == dummy_bbox_mask
+
     if mask_raw_calls[1] is not None and mask_raw_calls[2] is not None:
         assert mask_raw_calls[1] is not mask_raw_calls[2], (
             "Cached mask should be a copy, not the same object reference"
         )
+
+
+def test_frame_pacer_emission_count_24_to_15() -> None:
+    from opengait.demo.pipeline import _FramePacer
+
+    pacer = _FramePacer(15.0)
+    interval_ns = int(1_000_000_000 / 24)
+    emitted = sum(pacer.should_emit(i * interval_ns) for i in range(100))
+    assert 60 <= emitted <= 65
+
+
+def test_frame_pacer_requires_positive_target_fps() -> None:
+    from opengait.demo.pipeline import _FramePacer
+
+    with pytest.raises(ValueError, match="target_fps must be positive"):
+        _FramePacer(0.0)
+
+
+@pytest.mark.parametrize(
+    ("window", "stride", "mode", "expected"),
+    [
+        (30, 30, "manual", 30),
+        (30, 7, "manual", 7),
+        (30, 30, "sliding", 1),
+        (30, 1, "chunked", 30),
+        (15, 3, "chunked", 15),
+    ],
+)
+def test_resolve_stride_modes(
+    window: int,
+    stride: int,
+    mode: Literal["manual", "sliding", "chunked"],
+    expected: int,
+) -> None:
+    from opengait.demo.pipeline import resolve_stride
+
+    assert resolve_stride(window, stride, mode) == expected
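The bound in `test_frame_pacer_emission_count_24_to_15` follows from schedule-based pacing: 100 frames at 24 fps span about 4.125 s, which holds roughly 4.125 × 15 ≈ 62 emission deadlines. A minimal stand-in under that assumption (`_FramePacer` itself lives in `opengait/demo/pipeline.py`; `FramePacerSketch` below is an illustrative sketch, not the project's implementation):

```python
class FramePacerSketch:
    """Emit frames at target_fps by tracking scheduled deadlines in nanoseconds."""

    def __init__(self, target_fps: float) -> None:
        if target_fps <= 0:
            raise ValueError("target_fps must be positive")
        self._interval_ns = 1_000_000_000 / target_fps
        self._next_emit_ns = 0.0

    def should_emit(self, timestamp_ns: int) -> bool:
        if timestamp_ns < self._next_emit_ns:
            return False
        # Advance the schedule rather than anchoring on the actual emit
        # time, so late frames do not accumulate drift.
        self._next_emit_ns += self._interval_ns
        return True


pacer = FramePacerSketch(15.0)
interval_ns = int(1_000_000_000 / 24)  # 24 fps input timestamps
emitted = sum(pacer.should_emit(i * interval_ns) for i in range(100))
print(emitted)  # 62: one emission per 1/15 s deadline inside the 4.125 s span
```

Because the input frame spacing (1/24 s) is shorter than the output interval (1/15 s), every deadline inside the span is hit by some frame, giving exactly the deadline count.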
@@ -0,0 +1,171 @@
from __future__ import annotations

from pathlib import Path
from typing import cast
from unittest import mock

import numpy as np
import pytest

from opengait.demo.input import create_source
from opengait.demo.visualizer import (
    DISPLAY_HEIGHT,
    DISPLAY_WIDTH,
    ImageArray,
    OpenCVVisualizer,
)
from opengait.demo.window import select_person

REPO_ROOT = Path(__file__).resolve().parents[2]
SAMPLE_VIDEO_PATH = REPO_ROOT / "assets" / "sample.mp4"
YOLO_MODEL_PATH = REPO_ROOT / "ckpt" / "yolo11n-seg.pt"


def test_prepare_raw_view_float_mask_has_visible_signal() -> None:
    viz = OpenCVVisualizer()

    mask_float = np.zeros((64, 64), dtype=np.float32)
    mask_float[16:48, 16:48] = 1.0

    rendered = viz._prepare_raw_view(cast(ImageArray, mask_float))
    assert rendered.dtype == np.uint8
    assert rendered.shape == (256, 176, 3)

    mask_zero = np.zeros((64, 64), dtype=np.float32)
    rendered_zero = viz._prepare_raw_view(cast(ImageArray, mask_zero))

    roi = slice(0, DISPLAY_HEIGHT - 40)
    diff = np.abs(rendered[roi].astype(np.int16) - rendered_zero[roi].astype(np.int16))
    assert int(np.count_nonzero(diff)) > 0


def test_prepare_raw_view_handles_values_slightly_above_one() -> None:
    viz = OpenCVVisualizer()

    mask = np.zeros((64, 64), dtype=np.float32)
    mask[20:40, 20:40] = 1.0001
    rendered = viz._prepare_raw_view(cast(ImageArray, mask))

    roi = rendered[: DISPLAY_HEIGHT - 40, :, 0]
    assert int(np.count_nonzero(roi)) > 0


def test_segmentation_view_is_normalized_only_shape() -> None:
    viz = OpenCVVisualizer()
    mask = np.zeros((480, 640), dtype=np.uint8)
    sil = np.random.rand(64, 44).astype(np.float32)

    seg = viz._prepare_segmentation_view(cast(ImageArray, mask), sil, (0, 0, 100, 100))
    assert seg.shape == (DISPLAY_HEIGHT, DISPLAY_WIDTH, 3)


def test_update_toggles_raw_window_with_r_key() -> None:
    viz = OpenCVVisualizer()
    frame = np.zeros((240, 320, 3), dtype=np.uint8)
    mask = np.zeros((240, 320), dtype=np.uint8)
    mask[20:100, 30:120] = 255
    sil = np.random.rand(64, 44).astype(np.float32)
    seg_input = np.random.rand(4, 64, 44).astype(np.float32)

    with (
        mock.patch("cv2.namedWindow") as named_window,
        mock.patch("cv2.imshow"),
        mock.patch("cv2.destroyWindow") as destroy_window,
        mock.patch("cv2.waitKey", side_effect=[ord("r"), ord("r"), ord("q")]),
    ):
        assert viz.update(
            frame,
            (10, 10, 120, 150),
            (10, 10, 120, 150),
            1,
            cast(ImageArray, mask),
            sil,
            seg_input,
            None,
            None,
            15.0,
        )
        assert viz.show_raw_window is True
        assert viz._raw_window_created is True

        assert viz.update(
            frame,
            (10, 10, 120, 150),
            (10, 10, 120, 150),
            1,
            cast(ImageArray, mask),
            sil,
            seg_input,
            None,
            None,
            15.0,
        )
        assert viz.show_raw_window is False
        assert viz._raw_window_created is False
        assert destroy_window.called

        assert (
            viz.update(
                frame,
                (10, 10, 120, 150),
                (10, 10, 120, 150),
                1,
                cast(ImageArray, mask),
                sil,
                seg_input,
                None,
                None,
                15.0,
            )
            is False
        )
        assert named_window.called


def test_sample_video_raw_mask_shape_range_and_render_signal() -> None:
    if not SAMPLE_VIDEO_PATH.is_file():
        pytest.skip(f"Missing sample video: {SAMPLE_VIDEO_PATH}")
    if not YOLO_MODEL_PATH.is_file():
        pytest.skip(f"Missing YOLO model file: {YOLO_MODEL_PATH}")

    ultralytics = pytest.importorskip("ultralytics")
    yolo_cls = getattr(ultralytics, "YOLO")

    viz = OpenCVVisualizer()
    detector = yolo_cls(str(YOLO_MODEL_PATH))

    masks_seen = 0
    rendered_nonzero: list[int] = []

    for frame, _meta in create_source(str(SAMPLE_VIDEO_PATH), max_frames=30):
        detections = detector.track(
            frame,
            persist=True,
            verbose=False,
            classes=[0],
            device="cpu",
        )
        if not isinstance(detections, list) or not detections:
            continue

        selected = select_person(detections[0])
        if selected is None:
            continue

        mask_raw, _, _, _ = selected
        masks_seen += 1

        arr = np.asarray(mask_raw)
        assert arr.ndim == 2
        assert arr.shape[0] > 0 and arr.shape[1] > 0
        assert np.issubdtype(arr.dtype, np.number)
        assert float(arr.min()) >= 0.0
        assert float(arr.max()) <= 255.0
        assert int(np.count_nonzero(arr)) > 0

        rendered = viz._prepare_raw_view(arr)
        roi = rendered[: DISPLAY_HEIGHT - 40, :, 0]
        rendered_nonzero.append(int(np.count_nonzero(roi)))

    assert masks_seen > 0
    assert min(rendered_nonzero) > 0
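The grid views exercised by these tests (and built by the export script's `create_optimized_visualization`) lay n tiles out near-square: tiles_per_row = ceil(sqrt(n)), rows = ceil(n / tiles_per_row). A small sketch of just that layout math, using the same 256×176 tile size (the helper `grid_shape` is illustrative, not part of the codebase):

```python
import math

import numpy as np

TILE_H, TILE_W = 256, 176  # display size of one upscaled 64x44 silhouette


def grid_shape(n_tiles: int) -> tuple[int, int]:
    # Near-square layout: e.g. 30 tiles -> 6 per row, 5 rows.
    tiles_per_row = math.ceil(math.sqrt(n_tiles))
    rows = math.ceil(n_tiles / tiles_per_row)
    return rows * TILE_H, tiles_per_row * TILE_W


canvas = np.zeros(grid_shape(30), dtype=np.uint8)
print(canvas.shape)  # (1280, 1056): 5 rows x 6 columns of 256x176 tiles
```

For the default 30-frame window this gives a 6×5 grid, with up to tiles_per_row² − n trailing tiles left blank.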