f501119d43
Add preprocess-only silhouette export and configurable result exporters so demo runs can be persisted for offline analysis and reproducible evaluation. Include optional parquet support and CLI visualization dumps while updating tests and tracking notes for the verified pipeline/debug workflow.
612 lines
20 KiB
Python
612 lines
20 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from contextlib import suppress
|
|
import logging
|
|
from pathlib import Path
|
|
import time
|
|
from typing import Protocol, cast
|
|
|
|
from beartype import beartype
|
|
import click
|
|
import jaxtyping
|
|
from jaxtyping import Float, UInt8
|
|
import numpy as np
|
|
from numpy import ndarray
|
|
from numpy.typing import NDArray
|
|
from ultralytics.models.yolo.model import YOLO
|
|
|
|
from .input import FrameStream, create_source
|
|
from .output import ResultPublisher, create_publisher, create_result
|
|
from .preprocess import frame_to_person_mask, mask_to_silhouette
|
|
from .sconet_demo import ScoNetDemo
|
|
from .window import SilhouetteWindow, select_person
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# ``jaxtyping.jaxtyped`` re-exposed under an explicit "decorator factory"
# signature: calling it with keyword arguments yields a decorator that maps a
# callable to a callable.  ``cast`` only affects static typing; at runtime
# ``jaxtyped`` is the very same object as ``jaxtyping.jaxtyped``.
JaxtypedDecorator = Callable[[Callable[..., object]], Callable[..., object]]
JaxtypedFactory = Callable[..., JaxtypedDecorator]
jaxtyped = cast(JaxtypedFactory, jaxtyping.jaxtyped)
|
|
|
|
|
|
class _BoxesLike(Protocol):
|
|
@property
|
|
def xyxy(self) -> NDArray[np.float32] | object: ...
|
|
|
|
@property
|
|
def id(self) -> NDArray[np.int64] | object | None: ...
|
|
|
|
|
|
class _MasksLike(Protocol):
|
|
@property
|
|
def data(self) -> NDArray[np.float32] | object: ...
|
|
|
|
|
|
class _DetectionResultsLike(Protocol):
    """Structural type for one detection result exposing boxes and masks."""

    @property
    def boxes(self) -> _BoxesLike:
        """Box container of this result."""

    @property
    def masks(self) -> _MasksLike:
        """Mask container of this result."""
|
|
|
|
|
|
class _TrackCallable(Protocol):
|
|
def __call__(
|
|
self,
|
|
source: object,
|
|
*,
|
|
persist: bool = True,
|
|
verbose: bool = False,
|
|
device: str | None = None,
|
|
classes: list[int] | None = None,
|
|
) -> object: ...
|
|
|
|
|
|
class ScoliosisPipeline:
    """Per-frame pipeline: detector tracking, silhouette extraction, classification.

    ``run()`` iterates the frame source and hands every frame to
    ``process_frame()``, which calls the detector's ``track()``, converts one
    detected mask into a silhouette, pushes it into a sliding window and, when
    the window reports it is ready, classifies the window and publishes the
    result.  Silhouettes and results can additionally be buffered in memory
    and written to disk when the pipeline is closed.
    """

    # Detection model instance; ``process_frame`` looks up its ``track`` attribute.
    _detector: object
    # Iterable yielding ``(frame, metadata)`` pairs (consumed by ``run``).
    _source: FrameStream
    # Sliding silhouette buffer that decides when a classification is due.
    _window: SilhouetteWindow
    # Sink that receives every classification result.
    _publisher: ResultPublisher
    # Classifier invoked on the window tensor.
    _classifier: ScoNetDemo
    # Device string forwarded to the detector and to the window tensor.
    _device: str
    # True once ``close()`` has run.
    _closed: bool
    # When set, frame processing stops after silhouette extraction.
    _preprocess_only: bool
    # Destination file for buffered silhouettes, or None to skip the export.
    _silhouette_export_path: Path | None
    # "pickle" or "parquet"; any other value raises ValueError at export time.
    _silhouette_export_format: str
    # One dict per silhouette: "frame", "track_id", "timestamp_ns", "silhouette".
    _silhouette_buffer: list[dict[str, object]]
    # Directory receiving one PNG per silhouette, or None to skip.
    _silhouette_visualize_dir: Path | None
    # Destination file for buffered results, or None to skip the export.
    _result_export_path: Path | None
    # "json", "pickle" or "parquet"; any other value raises ValueError at export time.
    _result_export_format: str
    # Results produced by ``create_result``; filled only when an export path is set.
    _result_buffer: list[dict[str, object]]

    def __init__(
        self,
        *,
        source: str,
        checkpoint: str,
        config: str,
        device: str,
        yolo_model: str,
        window: int,
        stride: int,
        nats_url: str | None,
        nats_subject: str,
        max_frames: int | None,
        preprocess_only: bool = False,
        silhouette_export_path: str | None = None,
        silhouette_export_format: str = "pickle",
        silhouette_visualize_dir: str | None = None,
        result_export_path: str | None = None,
        result_export_format: str = "json",
    ) -> None:
        """Build every pipeline component from plain configuration values.

        Args:
            source: Frame source specifier, passed to ``create_source``.
            checkpoint: Classifier checkpoint path, passed to ``ScoNetDemo``.
            config: Classifier config path, passed to ``ScoNetDemo``.
            device: Device string used by the classifier, the detector call
                and the window tensor.
            yolo_model: Model specifier passed to ``YOLO``.
            window: Window size passed to ``SilhouetteWindow``.
            stride: Stride passed to ``SilhouetteWindow``.
            nats_url: Publisher URL passed to ``create_publisher`` (may be None).
            nats_subject: Publisher subject passed to ``create_publisher``.
            max_frames: Frame limit passed to ``create_source`` (may be None).
            preprocess_only: Stop after silhouette extraction on every frame.
            silhouette_export_path: File to write buffered silhouettes to on
                ``close()``; empty or None disables the export.
            silhouette_export_format: ``"pickle"`` or ``"parquet"``.
            silhouette_visualize_dir: Directory for per-silhouette PNG dumps;
                empty or None disables them.
            result_export_path: File to write buffered results to on
                ``close()``; empty or None disables the export.
            result_export_format: ``"json"``, ``"pickle"`` or ``"parquet"``.
        """
        self._detector = YOLO(yolo_model)
        self._source = create_source(source, max_frames=max_frames)
        self._window = SilhouetteWindow(window_size=window, stride=stride)
        self._publisher = create_publisher(nats_url=nats_url, subject=nats_subject)
        self._classifier = ScoNetDemo(
            cfg_path=config,
            checkpoint_path=checkpoint,
            device=device,
        )
        self._device = device
        self._closed = False
        self._preprocess_only = preprocess_only

        # Empty strings are treated like None: the feature stays disabled.
        self._silhouette_export_path = (
            Path(silhouette_export_path) if silhouette_export_path else None
        )
        self._silhouette_export_format = silhouette_export_format
        self._silhouette_buffer = []
        self._silhouette_visualize_dir = (
            Path(silhouette_visualize_dir) if silhouette_visualize_dir else None
        )
        self._result_export_path = (
            Path(result_export_path) if result_export_path else None
        )
        self._result_export_format = result_export_format
        self._result_buffer = []

    @staticmethod
    def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
        """Return ``meta[key]`` when it is an ``int``, otherwise *fallback*."""
        value = meta.get(key)
        if isinstance(value, int):
            return value
        return fallback

    @staticmethod
    def _extract_timestamp(meta: dict[str, object]) -> int:
        """Return ``meta["timestamp_ns"]`` when it is an ``int``.

        Falls back to ``time.monotonic_ns()`` when the key is missing or holds
        a non-integer value.
        """
        value = meta.get("timestamp_ns")
        if isinstance(value, int):
            return value
        return time.monotonic_ns()

    @staticmethod
    def _to_mask_u8(mask: ndarray) -> UInt8[ndarray, "h w"]:
        """Binarise *mask* at 0.5 into a uint8 array holding only 0 and 255."""
        binary = np.where(np.asarray(mask) > 0.5, np.uint8(255), np.uint8(0)).astype(
            np.uint8
        )
        return cast(UInt8[ndarray, "h w"], binary)

    def _first_result(self, detections: object) -> _DetectionResultsLike | None:
        """Return the first detection result from the detector output.

        A list or tuple yields its first element (``None`` when empty); any
        other object is returned as-is.
        """
        if isinstance(detections, (list, tuple)):
            return cast(_DetectionResultsLike, detections[0]) if detections else None
        return cast(_DetectionResultsLike, detections)

    def _select_silhouette(
        self,
        result: _DetectionResultsLike,
    ) -> tuple[Float[ndarray, "64 44"], int] | None:
        """Turn one detection result into ``(silhouette, track_id)``.

        The person chosen by ``select_person`` is tried first.  When that
        yields no silhouette, ``frame_to_person_mask`` is used as a fallback
        and the track id is reported as ``0``.

        Returns:
            The silhouette and its track id, or ``None`` when neither path
            produces a silhouette.
        """
        selected = select_person(result)
        if selected is not None:
            mask_raw, bbox, track_id = selected
            silhouette = cast(
                Float[ndarray, "64 44"] | None,
                mask_to_silhouette(self._to_mask_u8(mask_raw), bbox),
            )
            if silhouette is not None:
                return silhouette, int(track_id)

        fallback = cast(
            tuple[UInt8[ndarray, "h w"], tuple[int, int, int, int]] | None,
            frame_to_person_mask(result),
        )
        if fallback is None:
            return None

        mask_u8, bbox = fallback
        silhouette = cast(
            Float[ndarray, "64 44"] | None,
            mask_to_silhouette(mask_u8, bbox),
        )
        if silhouette is None:
            return None
        # The fallback path has no tracker identity, so id 0 is reported.
        return silhouette, 0

    @jaxtyped(typechecker=beartype)
    def process_frame(
        self,
        frame: UInt8[ndarray, "h w c"],
        metadata: dict[str, object],
    ) -> dict[str, object] | None:
        """Process one frame and return a classification result when one is due.

        Args:
            frame: Image to run the detector on.
            metadata: Per-frame metadata; ``"frame_count"`` and
                ``"timestamp_ns"`` are read when present as integers.

        Returns:
            The dict built by ``create_result`` after a classification, or
            ``None`` when there is no detection, no usable silhouette, the
            pipeline is in preprocess-only mode, or the window is not ready.

        Raises:
            RuntimeError: If the detector has no callable ``track`` attribute.
        """
        frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
        timestamp_ns = self._extract_timestamp(metadata)

        track_fn_obj = getattr(self._detector, "track", None)
        if not callable(track_fn_obj):
            raise RuntimeError("YOLO detector does not expose a callable track()")

        track_fn = cast(_TrackCallable, track_fn_obj)
        detections = track_fn(
            frame,
            persist=True,
            verbose=False,
            device=self._device,
            # Only class id 0 is requested - presumably "person"; confirm
            # against the model's class list.
            classes=[0],
        )
        first = self._first_result(detections)
        if first is None:
            return None

        selected = self._select_silhouette(first)
        if selected is None:
            return None
        silhouette, track_id = selected

        # Buffer the silhouette for export (export path set or preprocess-only).
        # NOTE(review): in preprocess-only mode without an export path the
        # buffer still grows, but close() never writes it out.
        if self._silhouette_export_path is not None or self._preprocess_only:
            self._silhouette_buffer.append(
                {
                    "frame": frame_idx,
                    "track_id": track_id,
                    "timestamp_ns": timestamp_ns,
                    # Copy so later changes to the array cannot alter the
                    # buffered entry.
                    "silhouette": silhouette.copy(),
                }
            )

        if self._silhouette_visualize_dir is not None:
            self._visualize_silhouette(silhouette, frame_idx, track_id)

        if self._preprocess_only:
            return None

        self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
        if not self._window.should_classify():
            return None

        window_tensor = self._window.get_tensor(device=self._device)
        label, confidence = cast(
            tuple[str, float],
            self._classifier.predict(window_tensor),
        )
        self._window.mark_classified()

        # Inclusive frame range covered by the window, clamped at frame 0.
        window_start = frame_idx - self._window.window_size + 1
        result = create_result(
            frame=frame_idx,
            track_id=track_id,
            label=label,
            confidence=float(confidence),
            window=(max(0, window_start), frame_idx),
            timestamp_ns=timestamp_ns,
        )

        if self._result_export_path is not None:
            self._result_buffer.append(result)

        self._publisher.publish(result)
        return result

    def run(self) -> int:
        """Consume the frame source until it is exhausted or interrupted.

        A failure while processing a single frame is logged as a warning and
        the frame is skipped.  ``close()`` is always called before returning.

        Returns:
            ``0`` when the source is exhausted, ``130`` after a
            ``KeyboardInterrupt``.
        """
        frame_count = 0
        start_time = time.perf_counter()
        try:
            for item in self._source:
                frame, metadata = item
                frame_u8 = np.asarray(frame, dtype=np.uint8)
                frame_idx = self._extract_int(metadata, "frame_count", fallback=0)
                frame_count += 1
                try:
                    _ = self.process_frame(frame_u8, metadata)
                except Exception as frame_error:
                    # Deliberate best effort: one bad frame must not stop the run.
                    logger.warning(
                        "Skipping frame %d due to processing error: %s",
                        frame_idx,
                        frame_error,
                    )
                if frame_count % 100 == 0:
                    elapsed = time.perf_counter() - start_time
                    fps = frame_count / elapsed if elapsed > 0 else 0.0
                    logger.info("Processed %d frames (%.2f FPS)", frame_count, fps)
            return 0
        except KeyboardInterrupt:
            logger.info("Interrupted by user, shutting down cleanly.")
            return 130
        finally:
            self.close()

    def close(self) -> None:
        """Write pending exports and close the publisher; later calls are no-ops.

        Export errors still propagate to the caller, but they no longer skip
        the remaining steps: the result export is attempted even when the
        silhouette export fails, and the publisher is closed and the pipeline
        marked closed regardless of either outcome.

        Raises:
            ValueError: For an unsupported export format.
            RuntimeError: When a parquet export is requested without pyarrow.
        """
        if self._closed:
            return

        try:
            try:
                if self._silhouette_export_path is not None and self._silhouette_buffer:
                    self._export_silhouettes()
            finally:
                if self._result_export_path is not None and self._result_buffer:
                    self._export_results()
        finally:
            close_fn = getattr(self._publisher, "close", None)
            if callable(close_fn):
                # Shutdown is best effort: publisher errors are ignored.
                with suppress(Exception):
                    _ = close_fn()
            self._closed = True

    def _export_silhouettes(self) -> None:
        """Write the silhouette buffer to the export path in the chosen format.

        Raises:
            ValueError: If the configured format is neither pickle nor parquet.
        """
        if self._silhouette_export_path is None:
            return

        self._silhouette_export_path.parent.mkdir(parents=True, exist_ok=True)

        if self._silhouette_export_format == "pickle":
            import pickle

            with open(self._silhouette_export_path, "wb") as f:
                pickle.dump(self._silhouette_buffer, f)
            logger.info(
                "Exported %d silhouettes to %s",
                len(self._silhouette_buffer),
                self._silhouette_export_path,
            )
        elif self._silhouette_export_format == "parquet":
            self._export_parquet_silhouettes()
        else:
            raise ValueError(
                f"Unsupported silhouette export format: {self._silhouette_export_format}"
            )

    def _visualize_silhouette(
        self,
        silhouette: Float[ndarray, "64 44"],
        frame_idx: int,
        track_id: int,
    ) -> None:
        """Save *silhouette* as a PNG named after its frame index and track id."""
        if self._silhouette_visualize_dir is None:
            return

        self._silhouette_visualize_dir.mkdir(parents=True, exist_ok=True)

        # Scale to 0-255.  Values are clipped to [0, 1] first so anything out
        # of range saturates instead of wrapping around in the uint8 cast;
        # in-range input is unaffected.
        silhouette_u8 = (np.clip(silhouette, 0.0, 1.0) * 255).astype(np.uint8)

        # Zero-padded, deterministic file name.
        filename = f"silhouette_frame{frame_idx:06d}_track{track_id:04d}.png"
        output_path = self._silhouette_visualize_dir / filename

        # Imported lazily so PIL is only needed when visualisation is enabled.
        from PIL import Image

        Image.fromarray(silhouette_u8).save(output_path)

    def _export_parquet_silhouettes(self) -> None:
        """Write the silhouette buffer as a parquet table.

        Columns: ``frame``, ``track_id``, ``timestamp_ns`` (int64) and
        ``silhouette`` (each array flattened into a list of float64).

        Raises:
            RuntimeError: If pyarrow cannot be imported.
        """
        import importlib

        try:
            pa = importlib.import_module("pyarrow")
            pq = importlib.import_module("pyarrow.parquet")
        except ImportError as e:
            raise RuntimeError(
                "Parquet export requires pyarrow. Install with: pip install pyarrow"
            ) from e

        entries = self._silhouette_buffer
        flattened = [
            cast(ndarray, entry["silhouette"]).flatten().tolist() for entry in entries
        ]

        table = pa.table(
            {
                "frame": pa.array([e["frame"] for e in entries], type=pa.int64()),
                "track_id": pa.array(
                    [e["track_id"] for e in entries], type=pa.int64()
                ),
                "timestamp_ns": pa.array(
                    [e["timestamp_ns"] for e in entries], type=pa.int64()
                ),
                "silhouette": pa.array(flattened, type=pa.list_(pa.float64())),
            }
        )

        pq.write_table(table, self._silhouette_export_path)
        logger.info(
            "Exported %d silhouettes to parquet: %s",
            len(self._silhouette_buffer),
            self._silhouette_export_path,
        )

    def _export_results(self) -> None:
        """Write the result buffer to the export path in the chosen format.

        ``"json"`` writes one JSON document per line; values JSON cannot
        encode are stringified via ``default=str``.

        Raises:
            ValueError: If the configured format is not json, pickle or parquet.
        """
        if self._result_export_path is None:
            return

        self._result_export_path.parent.mkdir(parents=True, exist_ok=True)

        if self._result_export_format == "json":
            import json

            with open(self._result_export_path, "w", encoding="utf-8") as f:
                for result in self._result_buffer:
                    f.write(json.dumps(result, ensure_ascii=False, default=str) + "\n")
            logger.info(
                "Exported %d results to JSON: %s",
                len(self._result_buffer),
                self._result_export_path,
            )
        elif self._result_export_format == "pickle":
            import pickle

            with open(self._result_export_path, "wb") as f:
                pickle.dump(self._result_buffer, f)
            logger.info(
                "Exported %d results to pickle: %s",
                len(self._result_buffer),
                self._result_export_path,
            )
        elif self._result_export_format == "parquet":
            self._export_parquet_results()
        else:
            raise ValueError(
                f"Unsupported result export format: {self._result_export_format}"
            )

    def _export_parquet_results(self) -> None:
        """Write the result buffer as a parquet table.

        Columns: ``frame``, ``track_id``, ``timestamp_ns`` (int64), ``label``
        (string), ``confidence`` (float64) and ``window``.  The ``window``
        column is int64 when every buffered window is a plain integer and a
        list of int64 otherwise.

        Raises:
            RuntimeError: If pyarrow cannot be imported.
        """
        import importlib

        try:
            pa = importlib.import_module("pyarrow")
            pq = importlib.import_module("pyarrow.parquet")
        except ImportError as e:
            raise RuntimeError(
                "Parquet export requires pyarrow. Install with: pip install pyarrow"
            ) from e

        results = self._result_buffer
        windows: list[object] = [r["window"] for r in results]

        # process_frame() passes a (start, end) tuple as ``window``; a tuple
        # cannot be converted into a flat int64 column.  Presumably
        # create_result() stores it unchanged - TODO confirm.  The flat
        # column is kept when every window is a plain int, otherwise each
        # window is stored as a list of integers.
        if all(isinstance(w, int) for w in windows):
            window_column = pa.array(windows, type=pa.int64())
        else:
            window_column = pa.array(
                [list(w) if isinstance(w, (tuple, list)) else [w] for w in windows],
                type=pa.list_(pa.int64()),
            )

        table = pa.table(
            {
                "frame": pa.array([r["frame"] for r in results], type=pa.int64()),
                "track_id": pa.array(
                    [r["track_id"] for r in results], type=pa.int64()
                ),
                "label": pa.array([r["label"] for r in results], type=pa.string()),
                "confidence": pa.array(
                    [r["confidence"] for r in results], type=pa.float64()
                ),
                "window": window_column,
                "timestamp_ns": pa.array(
                    [r["timestamp_ns"] for r in results], type=pa.int64()
                ),
            }
        )

        pq.write_table(table, self._result_export_path)
        logger.info(
            "Exported %d results to parquet: %s",
            len(self._result_buffer),
            self._result_export_path,
        )
|
|
|
|
|
|
def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
    """Check that the file-based runtime inputs exist before starting.

    A *source* that starts with ``cvmmap://`` or consists only of digits is
    not checked on disk; any other source must be an existing file.  The
    checkpoint and the config must always be existing files.

    Raises:
        ValueError: Naming the first input that is not an existing file.
    """
    is_stream_or_index = source.startswith("cvmmap://") or source.isdigit()
    if not is_stream_or_index and not Path(source).is_file():
        raise ValueError(f"Video source not found: {source}")

    # Checked in this order so the checkpoint error wins when both are missing.
    for description, location in (("Checkpoint", checkpoint), ("Config", config)):
        if not Path(location).is_file():
            raise ValueError(f"{description} not found: {location}")
|
|
|
|
|
|
@click.command(context_settings={"help_option_names": ["-h", "--help"]})
@click.option("--source", type=str, required=True)
@click.option("--checkpoint", type=str, required=True)
@click.option(
    "--config",
    type=str,
    default="configs/sconet/sconet_scoliosis1k.yaml",
    show_default=True,
)
@click.option("--device", type=str, default="cuda:0", show_default=True)
@click.option("--yolo-model", type=str, default="yolo11n-seg.pt", show_default=True)
@click.option("--window", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--stride", type=click.IntRange(min=1), default=30, show_default=True)
@click.option("--nats-url", type=str, default=None)
@click.option(
    "--nats-subject", type=str, default="scoliosis.result", show_default=True
)
@click.option("--max-frames", type=click.IntRange(min=1), default=None)
@click.option(
    "--preprocess-only",
    is_flag=True,
    default=False,
    help="Only preprocess silhouettes, skip classification.",
)
@click.option(
    "--silhouette-export-path",
    type=str,
    default=None,
    help="Path to export silhouettes (required for preprocess-only mode).",
)
@click.option(
    "--silhouette-export-format",
    type=click.Choice(["pickle", "parquet"]),
    default="pickle",
    show_default=True,
    help="Format for silhouette export.",
)
@click.option(
    "--result-export-path",
    type=str,
    default=None,
    help="Path to export inference results.",
)
@click.option(
    "--result-export-format",
    type=click.Choice(["json", "pickle", "parquet"]),
    default="json",
    show_default=True,
    help="Format for result export.",
)
@click.option(
    "--silhouette-visualize-dir",
    type=str,
    default=None,
    help="Directory to save silhouette PNG visualizations.",
)
def main(
    source: str,
    checkpoint: str,
    config: str,
    device: str,
    yolo_model: str,
    window: int,
    stride: int,
    nats_url: str | None,
    nats_subject: str,
    max_frames: int | None,
    preprocess_only: bool,
    silhouette_export_path: str | None,
    silhouette_export_format: str,
    result_export_path: str | None,
    result_export_format: str,
    silhouette_visualize_dir: str | None,
) -> None:
    """Command-line entry point: validate inputs, build the pipeline, run it.

    The process exits with the pipeline's return code, with ``2`` when a
    ``ValueError`` is raised (reported as ``Error: ...``) and with ``1`` when
    a ``RuntimeError`` is raised (reported as ``Runtime error: ...``).
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    # Preprocess-only mode is useless without somewhere to write silhouettes.
    export_target_missing = not silhouette_export_path
    if preprocess_only and export_target_missing:
        raise click.UsageError(
            "--silhouette-export-path is required when using --preprocess-only"
        )

    try:
        validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
        pipeline = ScoliosisPipeline(
            source=source,
            checkpoint=checkpoint,
            config=config,
            device=device,
            yolo_model=yolo_model,
            window=window,
            stride=stride,
            nats_url=nats_url,
            nats_subject=nats_subject,
            max_frames=max_frames,
            preprocess_only=preprocess_only,
            silhouette_export_path=silhouette_export_path,
            silhouette_export_format=silhouette_export_format,
            silhouette_visualize_dir=silhouette_visualize_dir,
            result_export_path=result_export_path,
            result_export_format=result_export_format,
        )
        exit_code = pipeline.run()
    except ValueError as err:
        click.echo(f"Error: {err}", err=True)
        raise SystemExit(2) from err
    except RuntimeError as err:
        click.echo(f"Runtime error: {err}", err=True)
        raise SystemExit(1) from err

    # Both handlers above re-raise, so this line is reached only on success.
    raise SystemExit(exit_code)
|