feat(demo): add export and silhouette visualization outputs
Add preprocess-only silhouette export and configurable result exporters so demo runs can be persisted for offline analysis and reproducible evaluation. Include optional parquet support and CLI visualization dumps while updating tests and tracking notes for the verified pipeline/debug workflow.
This commit is contained in:
@@ -70,6 +70,14 @@ class ScoliosisPipeline:
|
||||
_classifier: ScoNetDemo
|
||||
_device: str
|
||||
_closed: bool
|
||||
_preprocess_only: bool
|
||||
_silhouette_export_path: Path | None
|
||||
_silhouette_export_format: str
|
||||
_silhouette_buffer: list[dict[str, object]]
|
||||
_silhouette_visualize_dir: Path | None
|
||||
_result_export_path: Path | None
|
||||
_result_export_format: str
|
||||
_result_buffer: list[dict[str, object]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -84,6 +92,12 @@ class ScoliosisPipeline:
|
||||
nats_url: str | None,
|
||||
nats_subject: str,
|
||||
max_frames: int | None,
|
||||
preprocess_only: bool = False,
|
||||
silhouette_export_path: str | None = None,
|
||||
silhouette_export_format: str = "pickle",
|
||||
silhouette_visualize_dir: str | None = None,
|
||||
result_export_path: str | None = None,
|
||||
result_export_format: str = "json",
|
||||
) -> None:
|
||||
self._detector = YOLO(yolo_model)
|
||||
self._source = create_source(source, max_frames=max_frames)
|
||||
@@ -96,6 +110,20 @@ class ScoliosisPipeline:
|
||||
)
|
||||
self._device = device
|
||||
self._closed = False
|
||||
self._preprocess_only = preprocess_only
|
||||
self._silhouette_export_path = (
|
||||
Path(silhouette_export_path) if silhouette_export_path else None
|
||||
)
|
||||
self._silhouette_export_format = silhouette_export_format
|
||||
self._silhouette_buffer = []
|
||||
self._silhouette_visualize_dir = (
|
||||
Path(silhouette_visualize_dir) if silhouette_visualize_dir else None
|
||||
)
|
||||
self._result_export_path = (
|
||||
Path(result_export_path) if result_export_path else None
|
||||
)
|
||||
self._result_export_format = result_export_format
|
||||
self._result_buffer = []
|
||||
|
||||
@staticmethod
|
||||
def _extract_int(meta: dict[str, object], key: str, fallback: int) -> int:
|
||||
@@ -185,6 +213,25 @@ class ScoliosisPipeline:
|
||||
return None
|
||||
|
||||
silhouette, track_id = selected
|
||||
|
||||
# Store silhouette for export if in preprocess-only mode or if export requested
|
||||
if self._silhouette_export_path is not None or self._preprocess_only:
|
||||
self._silhouette_buffer.append(
|
||||
{
|
||||
"frame": frame_idx,
|
||||
"track_id": track_id,
|
||||
"timestamp_ns": timestamp_ns,
|
||||
"silhouette": silhouette.copy(),
|
||||
}
|
||||
)
|
||||
|
||||
# Visualize silhouette if requested
|
||||
if self._silhouette_visualize_dir is not None:
|
||||
self._visualize_silhouette(silhouette, frame_idx, track_id)
|
||||
|
||||
if self._preprocess_only:
|
||||
return None
|
||||
|
||||
self._window.push(silhouette, frame_idx=frame_idx, track_id=track_id)
|
||||
|
||||
if not self._window.should_classify():
|
||||
@@ -206,6 +253,11 @@ class ScoliosisPipeline:
|
||||
window=(max(0, window_start), frame_idx),
|
||||
timestamp_ns=timestamp_ns,
|
||||
)
|
||||
|
||||
# Store result for export if export path specified
|
||||
if self._result_export_path is not None:
|
||||
self._result_buffer.append(result)
|
||||
|
||||
self._publisher.publish(result)
|
||||
return result
|
||||
|
||||
@@ -240,12 +292,190 @@ class ScoliosisPipeline:
|
||||
def close(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
# Export silhouettes if requested
|
||||
if self._silhouette_export_path is not None and self._silhouette_buffer:
|
||||
self._export_silhouettes()
|
||||
|
||||
# Export results if requested
|
||||
if self._result_export_path is not None and self._result_buffer:
|
||||
self._export_results()
|
||||
|
||||
close_fn = getattr(self._publisher, "close", None)
|
||||
if callable(close_fn):
|
||||
with suppress(Exception):
|
||||
_ = close_fn()
|
||||
self._closed = True
|
||||
|
||||
def _export_silhouettes(self) -> None:
|
||||
"""Export silhouettes to file in specified format."""
|
||||
if self._silhouette_export_path is None:
|
||||
return
|
||||
|
||||
self._silhouette_export_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self._silhouette_export_format == "pickle":
|
||||
import pickle
|
||||
|
||||
with open(self._silhouette_export_path, "wb") as f:
|
||||
pickle.dump(self._silhouette_buffer, f)
|
||||
logger.info(
|
||||
"Exported %d silhouettes to %s",
|
||||
len(self._silhouette_buffer),
|
||||
self._silhouette_export_path,
|
||||
)
|
||||
elif self._silhouette_export_format == "parquet":
|
||||
self._export_parquet_silhouettes()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported silhouette export format: {self._silhouette_export_format}"
|
||||
)
|
||||
|
||||
def _visualize_silhouette(
|
||||
self,
|
||||
silhouette: Float[ndarray, "64 44"],
|
||||
frame_idx: int,
|
||||
track_id: int,
|
||||
) -> None:
|
||||
"""Save silhouette as PNG image."""
|
||||
if self._silhouette_visualize_dir is None:
|
||||
return
|
||||
|
||||
self._silhouette_visualize_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Convert float silhouette to uint8 (0-255)
|
||||
silhouette_u8 = (silhouette * 255).astype(np.uint8)
|
||||
|
||||
# Create deterministic filename
|
||||
filename = f"silhouette_frame{frame_idx:06d}_track{track_id:04d}.png"
|
||||
output_path = self._silhouette_visualize_dir / filename
|
||||
|
||||
# Save using PIL
|
||||
from PIL import Image
|
||||
|
||||
Image.fromarray(silhouette_u8).save(output_path)
|
||||
|
||||
def _export_parquet_silhouettes(self) -> None:
|
||||
"""Export silhouettes to parquet format."""
|
||||
import importlib
|
||||
|
||||
try:
|
||||
pa = importlib.import_module("pyarrow")
|
||||
pq = importlib.import_module("pyarrow.parquet")
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"Parquet export requires pyarrow. Install with: pip install pyarrow"
|
||||
) from e
|
||||
|
||||
# Convert silhouettes to columnar format
|
||||
frames = []
|
||||
track_ids = []
|
||||
timestamps = []
|
||||
silhouettes = []
|
||||
|
||||
for item in self._silhouette_buffer:
|
||||
frames.append(item["frame"])
|
||||
track_ids.append(item["track_id"])
|
||||
timestamps.append(item["timestamp_ns"])
|
||||
silhouette_array = cast(ndarray, item["silhouette"])
|
||||
silhouettes.append(silhouette_array.flatten().tolist())
|
||||
|
||||
table = pa.table(
|
||||
{
|
||||
"frame": pa.array(frames, type=pa.int64()),
|
||||
"track_id": pa.array(track_ids, type=pa.int64()),
|
||||
"timestamp_ns": pa.array(timestamps, type=pa.int64()),
|
||||
"silhouette": pa.array(silhouettes, type=pa.list_(pa.float64())),
|
||||
}
|
||||
)
|
||||
|
||||
pq.write_table(table, self._silhouette_export_path)
|
||||
logger.info(
|
||||
"Exported %d silhouettes to parquet: %s",
|
||||
len(self._silhouette_buffer),
|
||||
self._silhouette_export_path,
|
||||
)
|
||||
|
||||
def _export_results(self) -> None:
|
||||
"""Export results to file in specified format."""
|
||||
if self._result_export_path is None:
|
||||
return
|
||||
|
||||
self._result_export_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self._result_export_format == "json":
|
||||
import json
|
||||
|
||||
with open(self._result_export_path, "w", encoding="utf-8") as f:
|
||||
for result in self._result_buffer:
|
||||
f.write(json.dumps(result, ensure_ascii=False, default=str) + "\n")
|
||||
logger.info(
|
||||
"Exported %d results to JSON: %s",
|
||||
len(self._result_buffer),
|
||||
self._result_export_path,
|
||||
)
|
||||
elif self._result_export_format == "pickle":
|
||||
import pickle
|
||||
|
||||
with open(self._result_export_path, "wb") as f:
|
||||
pickle.dump(self._result_buffer, f)
|
||||
logger.info(
|
||||
"Exported %d results to pickle: %s",
|
||||
len(self._result_buffer),
|
||||
self._result_export_path,
|
||||
)
|
||||
elif self._result_export_format == "parquet":
|
||||
self._export_parquet_results()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported result export format: {self._result_export_format}"
|
||||
)
|
||||
|
||||
def _export_parquet_results(self) -> None:
|
||||
"""Export results to parquet format."""
|
||||
import importlib
|
||||
|
||||
try:
|
||||
pa = importlib.import_module("pyarrow")
|
||||
pq = importlib.import_module("pyarrow.parquet")
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"Parquet export requires pyarrow. Install with: pip install pyarrow"
|
||||
) from e
|
||||
|
||||
frames = []
|
||||
track_ids = []
|
||||
labels = []
|
||||
confidences = []
|
||||
windows = []
|
||||
timestamps = []
|
||||
|
||||
for result in self._result_buffer:
|
||||
frames.append(result["frame"])
|
||||
track_ids.append(result["track_id"])
|
||||
labels.append(result["label"])
|
||||
confidences.append(result["confidence"])
|
||||
windows.append(result["window"])
|
||||
timestamps.append(result["timestamp_ns"])
|
||||
|
||||
table = pa.table(
|
||||
{
|
||||
"frame": pa.array(frames, type=pa.int64()),
|
||||
"track_id": pa.array(track_ids, type=pa.int64()),
|
||||
"label": pa.array(labels, type=pa.string()),
|
||||
"confidence": pa.array(confidences, type=pa.float64()),
|
||||
"window": pa.array(windows, type=pa.int64()),
|
||||
"timestamp_ns": pa.array(timestamps, type=pa.int64()),
|
||||
}
|
||||
)
|
||||
|
||||
pq.write_table(table, self._result_export_path)
|
||||
logger.info(
|
||||
"Exported %d results to parquet: %s",
|
||||
len(self._result_buffer),
|
||||
self._result_export_path,
|
||||
)
|
||||
|
||||
|
||||
def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
|
||||
if source.startswith("cvmmap://") or source.isdigit():
|
||||
@@ -285,6 +515,44 @@ def validate_runtime_inputs(source: str, checkpoint: str, config: str) -> None:
|
||||
show_default=True,
|
||||
)
|
||||
@click.option("--max-frames", type=click.IntRange(min=1), default=None)
|
||||
@click.option(
|
||||
"--preprocess-only",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Only preprocess silhouettes, skip classification.",
|
||||
)
|
||||
@click.option(
|
||||
"--silhouette-export-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to export silhouettes (required for preprocess-only mode).",
|
||||
)
|
||||
@click.option(
|
||||
"--silhouette-export-format",
|
||||
type=click.Choice(["pickle", "parquet"]),
|
||||
default="pickle",
|
||||
show_default=True,
|
||||
help="Format for silhouette export.",
|
||||
)
|
||||
@click.option(
|
||||
"--result-export-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to export inference results.",
|
||||
)
|
||||
@click.option(
|
||||
"--result-export-format",
|
||||
type=click.Choice(["json", "pickle", "parquet"]),
|
||||
default="json",
|
||||
show_default=True,
|
||||
help="Format for result export.",
|
||||
)
|
||||
@click.option(
|
||||
"--silhouette-visualize-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory to save silhouette PNG visualizations.",
|
||||
)
|
||||
def main(
|
||||
source: str,
|
||||
checkpoint: str,
|
||||
@@ -296,12 +564,24 @@ def main(
|
||||
nats_url: str | None,
|
||||
nats_subject: str,
|
||||
max_frames: int | None,
|
||||
preprocess_only: bool,
|
||||
silhouette_export_path: str | None,
|
||||
silhouette_export_format: str,
|
||||
result_export_path: str | None,
|
||||
result_export_format: str,
|
||||
silhouette_visualize_dir: str | None,
|
||||
) -> None:
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
|
||||
# Validate preprocess-only mode requirements
|
||||
if preprocess_only and not silhouette_export_path:
|
||||
raise click.UsageError(
|
||||
"--silhouette-export-path is required when using --preprocess-only"
|
||||
)
|
||||
|
||||
try:
|
||||
validate_runtime_inputs(source=source, checkpoint=checkpoint, config=config)
|
||||
pipeline = ScoliosisPipeline(
|
||||
@@ -315,6 +595,12 @@ def main(
|
||||
nats_url=nats_url,
|
||||
nats_subject=nats_subject,
|
||||
max_frames=max_frames,
|
||||
preprocess_only=preprocess_only,
|
||||
silhouette_export_path=silhouette_export_path,
|
||||
silhouette_export_format=silhouette_export_format,
|
||||
silhouette_visualize_dir=silhouette_visualize_dir,
|
||||
result_export_path=result_export_path,
|
||||
result_export_format=result_export_format,
|
||||
)
|
||||
raise SystemExit(pipeline.run())
|
||||
except ValueError as err:
|
||||
|
||||
Reference in New Issue
Block a user