feat(demo): add export and silhouette visualization outputs
Add preprocess-only silhouette export and configurable result exporters so demo runs can be persisted for offline analysis and reproducible evaluation. Include optional parquet support and CLI visualization dumps while updating tests and tracking notes for the verified pipeline/debug workflow.
This commit is contained in:
@@ -70,6 +70,14 @@ class ScoliosisPipeline:
|
||||
_classifier: ScoNetDemo
|
||||
_device: str
|
||||
_closed: bool
|
||||
_preprocess_only: bool
|
||||
_silhouette_export_path: Path | None
|
||||
_silhouette_export_format: str
|
||||
_silhouette_buffer: list[dict[str, object]]
|
||||
_silhouette_visualize_dir: Path | None
|
||||
_result_export_path: Path | None
|
||||
_result_export_format: str
|
||||
_result_buffer: list[dict[str, object]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -84,6 +92,12 @@ class ScoliosisPipeline:
|
||||
nats_url: str | None,
|
||||
nats_subject: str,
|
||||
max_frames: int | None,
|
||||
preprocess_only: bool = False,
|
||||
silhouette_export_path: str | None = None,
|
||||
silhouette_export_format: str = "pickle",
|
||||
silhouette_visualize_dir: str | None = None,
|
||||
result_export_path: str | None = None,
|
||||
result_export_format: str = "json",
|
||||
) -> None:
|
||||
self._detector = YOLO(yolo_model)
|
||||
self._source = create_source(source, max_frames=max_frames)
|
||||
@@ -96,6 +110,20 @@ class ScoliosisPipeline:
|
||||
)
|
||||
self._device = device
|
||||
self._closed = False
|
||||
self._preprocess_only = preprocess_only
|
||||
self._silhouette_export_path = (
|
||||
Path(silhouette_export_path) if silhouette_export_path else None
|
||||
)
|
||||
self._silhouette_export_format = silhouette_export_format
|
||||
self._silhouette_buffer = []
|
||||
self._silhouette_visualize_dir = (
|
||||
Path(silhouette_visualize_dir) if silhouette_visualize_dir else None
|
||||
)
|
||||
self._result_export_path = (
|
||||
Path(result_export_path) if result_export_path else None
|
||||
)
|
||||
self._result_export_format = result_export_format
|
||||
self._result_buffer = []
|
||||
|
||||
@staticmethod
|
||||
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
|
||||
@@ -185,6 +213,25 @@ class ScoliosisPipeline:
|
||||
return None
|
||||
|
||||
silhouette, track_id = selected
|
||||
|
||||
# Store silhouette for export if in preprocess-only mode or if export requested
|
||||
if self._silhouette_export_path is not None or self._preprocess_only:
|
||||
self._silhouette_buffer.append(
|
||||
{
|
||||
"frame": frame_idx,
|
||||
"track_id": track_id,
|
||||
"timestamp_ns": timestamp_ns,
|
||||
"silhouette": silhouette.copy(),
|
||||
}
|
||||
)
|
||||
|
||||
# Visualize silhouette if requested
|
||||
if self._silhouette_visualize_dir is not None:
|
||||
self._visualize_silhouette(silhouette, frame_idx, track_id)
|
||||
|
||||
if self._preprocess_only:
|
||||
return None
|
||||
|
||||
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
|
||||
|
||||
if not self._window.should_classify():
|
||||
@@ -206,6 +253,11 @@ class ScoliosisPipeline:
|
||||
window=(max(0, window_start), frame_idx),
|
||||
timestamp_ns=timestamp_ns,
|
||||
)
|
||||
|
||||
# Store result for export if export path specified
|
||||
if self._result_export_path is not None:
|
||||
self._result_buffer.append(result)
|
||||
|
||||
self._publisher.publish(result)
|
||||
return result
|
||||
|
||||
@@ -240,12 +292,190 @@ class ScoliosisPipeline:
|
||||
def close(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
# Export silhouettes if requested
|
||||
if self._silhouette_export_path is not None and self._silhouette_buffer:
|
||||
self._export_silhouettes()
|
||||
|
||||
# Export results if requested
|
||||
if self._result_export_path is not None and self._result_buffer:
|
||||
self._export_results()
|
||||
|
||||
close_fn = getattr(self._publisher, "close", None)
|
||||
if callable(close_fn):
|
||||
with suppress(Exception):
|
||||
_ = close_fn()
|
||||
self._closed = True
|
||||
|
||||
def _export_silhouettes(self) -> None:
|
||||
"""Export silhouettes to file in specified format."""
|
||||
if self._silhouette_export_path is None:
|
||||
return
|
||||
|
||||
self._silhouette_export_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self._silhouette_export_format == "pickle":
|
||||
import pickle
|
||||
|
||||
with open(self._silhouette_export_path, "wb") as f:
|
||||
pickle.dump(self._silhouette_buffer, f)
|
||||
logger.info(
|
||||
"Exported %d silhouettes to %s",
|
||||
len(self._silhouette_buffer),
|
||||
self._silhouette_export_path,
|
||||
)
|
||||
elif self._silhouette_export_format == "parquet":
|
||||
self._export_parquet_silhouettes()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported silhouette export format: {self._silhouette_export_format}"
|
||||
)
|
||||
|
||||
def _visualize_silhouette(
|
||||
self,
|
||||
silhouette: Float[ndarray, "64 44"],
|
||||
frame_idx: int,
|
||||
track_id: int,
|
||||
) -> None:
|
||||
"""Save silhouette as PNG image."""
|
||||
if self._silhouette_visualize_dir is None:
|
||||
return
|
||||
|
||||
self._silhouette_visualize_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Convert float silhouette to uint8 (0-255)
|
||||
silhouette_u8 = (silhouette * 255).astype(np.uint8)
|
||||
|
||||
# Create deterministic filename
|
||||
filename = f"silhouette_frame{frame_idx:06d}_track{track_id:04d}.png"
|
||||
output_path = self._silhouette_visualize_dir / filename
|
||||
|
||||
# Save using PIL
|
||||
from PIL import Image
|
||||
|
||||
Image.fromarray(silhouette_u8).save(output_path)
|
||||
|
||||
def _export_parquet_silhouettes(self) -> None:
|
||||
"""Export silhouettes to parquet format."""
|
||||
import importlib
|
||||
|
||||
try:
|
||||
pa = importlib.import_module("pyarrow")
|
||||
pq = importlib.import_module("pyarrow.parquet")
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"Parquet export requires pyarrow. Install with: pip install pyarrow"
|
||||
) from e
|
||||
|
||||
# Convert silhouettes to columnar format
|
||||
frames = []
|
||||
track_ids = []
|
||||
timestamps = []
|
||||
silhouettes = []
|
||||
|
||||
for item in self._silhouette_buffer:
|
||||
frames.append(item["frame"])
|
||||
track_ids.append(item["track_id"])
|
||||
timestamps.append(item["timestamp_ns"])
|
||||
silhouette_array = cast(ndarray, item["silhouette"])
|
||||
silhouettes.append(silhouette_array.flatten().tolist())
|
||||
|
||||
table = pa.table(
|
||||
{
|
||||
"frame": pa.array(frames, type=pa.int64()),
|
||||
"track_id": pa.array(track_ids, type=pa.int64()),
|
||||
"timestamp_ns": pa.array(timestamps, type=pa.int64()),
|
||||
"silhouette": pa.array(silhouettes, type=pa.list_(pa.float64())),
|
||||
}
|
||||
)
|
||||
|
||||
pq.write_table(table, self._silhouette_export_path)
|
||||
logger.info(
|
||||
"Exported %d silhouettes to parquet: %s",
|
||||
len(self._silhouette_buffer),
|
||||
self._silhouette_export_path,
|
||||
)
|
||||
|
||||
def _export_results(self) -> None:
|
||||
"""Export results to file in specified format."""
|
||||
if self._result_export_path is None:
|
||||
return
|
||||
|
||||
self._result_export_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self._result_export_format == "json":
|
||||
import json
|
||||
|
||||
with open(self._result_export_path, "w", encoding="utf-8") as f:
|
||||
for result in self._result_buffer:
|
||||
f.write(json.dumps(result, ensure_ascii=False, default=str) + "\n")
|
||||
logger.info(
|
||||
"Exported %d results to JSON: %s",
|
||||
len(self._result_buffer),
|
||||
self._result_export_path,
|
||||
)
|
||||
elif self._result_export_format == "pickle":
|
||||
import pickle
|
||||
|
||||
with open(self._result_export_path, "wb") as f:
|
||||
pickle.dump(self._result_buffer, f)
|
||||
logger.info(
|
||||
"Exported %d results to pickle: %s",
|
||||
len(self._result_buffer),
|
||||
self._result_export_path,
|
||||
)
|
||||
elif self._result_export_format == "parquet":
|
||||
self._export_parquet_results()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported result export format: {self._result_export_format}"
|
||||
)
|
||||
|
||||
def _export_parquet_results(self) -> None:
|
||||
"""Export results to parquet format."""
|
||||
import importlib
|
||||
|
||||
try:
|
||||
pa = importlib.import_module("pyarrow")
|
||||
pq = importlib.import_module("pyarrow.parquet")
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"Parquet export requires pyarrow. Install with: pip install pyarrow"
|
||||
) from e
|
||||
|
||||
frames = []
|
||||
track_ids = []
|
||||
labels = []
|
||||
confidences = []
|
||||
windows = []
|
||||
timestamps = []
|
||||
|
||||
for result in self._result_buffer:
|
||||
frames.append(result["frame"])
|
||||
track_ids.append(result["track_id"])
|
||||
labels.append(result["label"])
|
||||
confidences.append(result["confidence"])
|
||||
windows.append(result["window"])
|
||||
timestamps.append(result["timestamp_ns"])
|
||||
|
||||
table = pa.table(
|
||||
{
|
||||
"frame": pa.array(frames, type=pa.int64()),
|
||||
"track_id": pa.array(track_ids, type=pa.int64()),
|
||||
"label": pa.array(labels, type=pa.string()),
|
||||
"confidence": pa.array(confidences, type=pa.float64()),
|
||||
"window": pa.array(windows, type=pa.int64()),
|
||||
"timestamp_ns": pa.array(timestamps, type=pa.int64()),
|
||||
}
|
||||
)
|
||||
|
||||
pq.write_table(table, self._result_export_path)
|
||||
logger.info(
|
||||
"Exported %d results to parquet: %s",
|
||||
len(self._result_buffer),
|
||||
self._result_export_path,
|
||||
)
|
||||
|
||||
|
||||
def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
|
||||
if source.startswith("cvmmap://") or source.isdigit():
|
||||
@@ -285,6 +515,44 @@ def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
|
||||
show_default=True,
|
||||
)
|
||||
@click.option("--max-frames", type=click.IntRange(min=1), default=None)
|
||||
@click.option(
|
||||
"--preprocess-only",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Only preprocess silhouettes, skip classification.",
|
||||
)
|
||||
@click.option(
|
||||
"--silhouette-export-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to export silhouettes (required for preprocess-only mode).",
|
||||
)
|
||||
@click.option(
|
||||
"--silhouette-export-format",
|
||||
type=click.Choice(["pickle", "parquet"]),
|
||||
default="pickle",
|
||||
show_default=True,
|
||||
help="Format for silhouette export.",
|
||||
)
|
||||
@click.option(
|
||||
"--result-export-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to export inference results.",
|
||||
)
|
||||
@click.option(
|
||||
"--result-export-format",
|
||||
type=click.Choice(["json", "pickle", "parquet"]),
|
||||
default="json",
|
||||
show_default=True,
|
||||
help="Format for result export.",
|
||||
)
|
||||
@click.option(
|
||||
"--silhouette-visualize-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory to save silhouette PNG visualizations.",
|
||||
)
|
||||
def main(
|
||||
source: str,
|
||||
checkpoint: str,
|
||||
@@ -296,12 +564,24 @@ def main(
|
||||
nats_url: str | None,
|
||||
nats_subject: str,
|
||||
max_frames: int | None,
|
||||
preprocess_only: bool,
|
||||
silhouette_export_path: str | None,
|
||||
silhouette_export_format: str,
|
||||
result_export_path: str | None,
|
||||
result_export_format: str,
|
||||
silhouette_visualize_dir: str | None,
|
||||
) -> None:
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
|
||||
# Validate preprocess-only mode requirements
|
||||
if preprocess_only and not silhouette_export_path:
|
||||
raise click.UsageError(
|
||||
"--silhouette-export-path is required when using --preprocess-only"
|
||||
)
|
||||
|
||||
try:
|
||||
validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
|
||||
pipeline = ScoliosisPipeline(
|
||||
@@ -315,6 +595,12 @@ def main(
|
||||
nats_url=nats_url,
|
||||
nats_subject=nats_subject,
|
||||
max_frames=max_frames,
|
||||
preprocess_only=preprocess_only,
|
||||
silhouette_export_path=silhouette_export_path,
|
||||
silhouette_export_format=silhouette_export_format,
|
||||
silhouette_visualize_dir=silhouette_visualize_dir,
|
||||
result_export_path=result_export_path,
|
||||
result_export_format=result_export_format,
|
||||
)
|
||||
raise SystemExit(pipeline.run())
|
||||
except ValueError as err:
|
||||
|
||||
+63
-10
@@ -5,7 +5,7 @@ with track ID tracking and gap detection.
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, Protocol, final
|
||||
from typing import TYPE_CHECKING, Protocol, cast, final
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -20,6 +20,9 @@ if TYPE_CHECKING:
|
||||
SIL_HEIGHT: int = 64
|
||||
SIL_WIDTH: int = 44
|
||||
|
||||
# Type alias for array-like inputs
|
||||
type _ArrayLike = torch.Tensor | ndarray
|
||||
|
||||
|
||||
class _Boxes(Protocol):
|
||||
"""Protocol for boxes with xyxy and id attributes."""
|
||||
@@ -207,6 +210,33 @@ class SilhouetteWindow:
|
||||
return len(self._buffer) / self.window_size
|
||||
|
||||
|
||||
def _to_numpy(obj: _ArrayLike) -> ndarray:
|
||||
"""Safely convert array-like object to numpy array.
|
||||
|
||||
Handles torch tensors (CPU or CUDA) by detaching and moving to CPU first.
|
||||
Falls back to np.asarray for other array-like objects.
|
||||
|
||||
Args:
|
||||
obj: Array-like object (numpy array, torch tensor, or similar).
|
||||
|
||||
Returns:
|
||||
Numpy array representation of the input.
|
||||
"""
|
||||
# Handle torch tensors (including CUDA tensors)
|
||||
detach_fn = getattr(obj, "detach", None)
|
||||
if detach_fn is not None and callable(detach_fn):
|
||||
# It's a torch tensor
|
||||
tensor = detach_fn()
|
||||
cpu_fn = getattr(tensor, "cpu", None)
|
||||
if cpu_fn is not None and callable(cpu_fn):
|
||||
tensor = cpu_fn()
|
||||
numpy_fn = getattr(tensor, "numpy", None)
|
||||
if numpy_fn is not None and callable(numpy_fn):
|
||||
return cast(ndarray, numpy_fn())
|
||||
# Fall back to np.asarray for other array-like objects
|
||||
return cast(ndarray, np.asarray(obj))
|
||||
|
||||
|
||||
def select_person(
|
||||
results: _DetectionResults,
|
||||
) -> tuple[ndarray, tuple[int, int, int, int], int] | None:
|
||||
@@ -232,7 +262,7 @@ def select_person(
|
||||
if track_ids_obj is None:
|
||||
return None
|
||||
|
||||
track_ids: ndarray = np.asarray(track_ids_obj)
|
||||
track_ids: ndarray = _to_numpy(cast(ndarray, track_ids_obj))
|
||||
if track_ids.size == 0:
|
||||
return None
|
||||
|
||||
@@ -241,7 +271,7 @@ def select_person(
|
||||
if xyxy_obj is None:
|
||||
return None
|
||||
|
||||
bboxes: ndarray = np.asarray(xyxy_obj)
|
||||
bboxes: ndarray = _to_numpy(cast(ndarray, xyxy_obj))
|
||||
if bboxes.ndim == 1:
|
||||
bboxes = bboxes.reshape(1, -1)
|
||||
|
||||
@@ -257,7 +287,7 @@ def select_person(
|
||||
if masks_data is None:
|
||||
return None
|
||||
|
||||
masks: ndarray = np.asarray(masks_data)
|
||||
masks: ndarray = _to_numpy(cast(ndarray, masks_data))
|
||||
if masks.ndim == 2:
|
||||
masks = masks[np.newaxis, ...]
|
||||
|
||||
@@ -284,12 +314,35 @@ def select_person(
|
||||
|
||||
# Extract mask and bbox
|
||||
mask: "NDArray[np.float32]" = masks[best_idx]
|
||||
bbox = (
|
||||
int(float(bboxes[best_idx][0])),
|
||||
int(float(bboxes[best_idx][1])),
|
||||
int(float(bboxes[best_idx][2])),
|
||||
int(float(bboxes[best_idx][3])),
|
||||
)
|
||||
mask_shape = mask.shape
|
||||
mask_h, mask_w = int(mask_shape[0]), int(mask_shape[1])
|
||||
|
||||
# Get original image dimensions from results (YOLO provides this)
|
||||
orig_shape = getattr(results, "orig_shape", None)
|
||||
# Validate orig_shape is a sequence of at least 2 numeric values
|
||||
if (
|
||||
orig_shape is not None
|
||||
and isinstance(orig_shape, (tuple, list))
|
||||
and len(orig_shape) >= 2
|
||||
):
|
||||
frame_h, frame_w = int(orig_shape[0]), int(orig_shape[1])
|
||||
# Scale bbox from frame space to mask space
|
||||
scale_x = mask_w / frame_w if frame_w > 0 else 1.0
|
||||
scale_y = mask_h / frame_h if frame_h > 0 else 1.0
|
||||
bbox = (
|
||||
int(float(bboxes[best_idx][0]) * scale_x),
|
||||
int(float(bboxes[best_idx][1]) * scale_y),
|
||||
int(float(bboxes[best_idx][2]) * scale_x),
|
||||
int(float(bboxes[best_idx][3]) * scale_y),
|
||||
)
|
||||
else:
|
||||
# Fallback: use bbox as-is (assume same coordinate space)
|
||||
bbox = (
|
||||
int(float(bboxes[best_idx][0])),
|
||||
int(float(bboxes[best_idx][1])),
|
||||
int(float(bboxes[best_idx][2])),
|
||||
int(float(bboxes[best_idx][3])),
|
||||
)
|
||||
track_id = int(track_ids[best_idx]) if best_idx < len(track_ids) else best_idx
|
||||
|
||||
return mask, bbox, track_id
|
||||
|
||||
Reference in New Issue
Block a user