# zed-playground/py_workspace/aruco/svo_sync.py
# (260 lines, 9.0 KiB, Python — file-listing header preserved as a comment)
import pyzed.sl as sl
import numpy as np
from dataclasses import dataclass
from typing import Any
import os
from loguru import logger
@dataclass
class FrameData:
    """Data structure for a single frame from an SVO."""

    # Image pixels copied out of the SDK Mat (LEFT view — see SVOReader.grab_all).
    image: np.ndarray
    # Capture timestamp (sl.TIME_REFERENCE.IMAGE) in nanoseconds.
    timestamp_ns: int
    # SVO playback position reported at grab time.
    frame_index: int
    # Serial number of the camera that recorded this frame.
    serial_number: int
    # Per-pixel depth in meters (coordinate_units is METER); None when depth is disabled.
    depth_map: np.ndarray | None = None
    # Per-pixel depth confidence map; None when depth is disabled.
    confidence_map: np.ndarray | None = None
class SVOReader:
    """Handles synchronized playback of multiple SVO files.

    Opens every SVO path that exists, keeps one ``sl.Camera`` per file, and
    exposes frame-level access with optional depth/confidence retrieval.
    """

    svo_paths: list[str]
    runtime_params: sl.RuntimeParameters
    _depth_mode: sl.DEPTH_MODE

    def __init__(
        self, svo_paths: list[str], depth_mode: sl.DEPTH_MODE = sl.DEPTH_MODE.NONE
    ):
        """Open all existing SVO files in non-real-time mode.

        Args:
            svo_paths: Paths to .svo recordings. Missing or unopenable files
                are skipped with a log message rather than raising.
            depth_mode: SDK depth computation mode; ``NONE`` disables depth
                and confidence retrieval entirely.
        """
        self.svo_paths = svo_paths
        self.cameras: list[sl.Camera] = []
        self.camera_info: list[dict[str, Any]] = []
        self.runtime_params = sl.RuntimeParameters()
        self._depth_mode = depth_mode
        for path in svo_paths:
            if not os.path.exists(path):
                # Consistency fix: use loguru (as the rest of this module
                # does) instead of bare print().
                logger.warning(f"SVO file not found: {path}")
                continue
            init_params = sl.InitParameters()
            init_params.set_from_svo_file(path)
            init_params.svo_real_time_mode = False  # read as fast as possible
            init_params.depth_mode = depth_mode
            init_params.coordinate_units = sl.UNIT.METER
            cam = sl.Camera()
            status = cam.open(init_params)
            if status != sl.ERROR_CODE.SUCCESS:
                logger.error(f"Could not open {path}: {status}")
                continue
            info = cam.get_camera_information()
            self.cameras.append(cam)
            self.camera_info.append(
                {
                    "serial": info.serial_number,
                    "fps": info.camera_configuration.fps,
                    "total_frames": cam.get_svo_number_of_frames(),
                }
            )

    def sync_to_latest_start(self):
        """Aligns all SVOs to the timestamp of the latest starting SVO.

        Grabs one frame per camera to read its first image timestamp, then
        seeks every earlier-starting camera forward (by an fps-based frame
        estimate) so all recordings begin at roughly the same instant.
        """
        if not self.cameras:
            return
        start_timestamps: list[int] = []
        for cam in self.cameras:
            err = cam.grab(self.runtime_params)
            if err == sl.ERROR_CODE.SUCCESS:
                ts = cam.get_timestamp(sl.TIME_REFERENCE.IMAGE).get_nanoseconds()
                start_timestamps.append(ts)
            else:
                # 0 marks a failed grab; excluded from the max below.
                start_timestamps.append(0)
        valid_timestamps = [ts for ts in start_timestamps if ts > 0]
        if not valid_timestamps:
            # Nothing grabbed successfully: just rewind everything.
            for cam in self.cameras:
                cam.set_svo_position(0)
            return
        max_start_ts = max(valid_timestamps)
        for i, cam in enumerate(self.cameras):
            ts = start_timestamps[i]
            if ts == 0:
                # Bug fix: previously a failed grab (ts == 0) entered the
                # skip computation with diff_ns = max_start_ts, yielding a
                # huge frames_to_skip that seeked far past end-of-file.
                # Rewind that camera instead.
                cam.set_svo_position(0)
            elif ts < max_start_ts:
                diff_ns = max_start_ts - ts
                fps = self.camera_info[i]["fps"]
                frames_to_skip = int((diff_ns / 1_000_000_000) * fps)
                cam.set_svo_position(frames_to_skip)
            else:
                # The latest-starting camera restarts from its first frame.
                cam.set_svo_position(0)

    def _grab_frame(self, cam_index: int) -> FrameData | None:
        """Grab and package the next frame from one camera.

        Returns None on end-of-file (the SVO position is reset to 0 so
        playback loops on the next call) or on any other grab error.
        """
        cam = self.cameras[cam_index]
        err = cam.grab(self.runtime_params)
        if err == sl.ERROR_CODE.SUCCESS:
            mat = sl.Mat()
            cam.retrieve_image(mat, sl.VIEW.LEFT)
            return FrameData(
                # Copy out of the SDK-owned Mat so the array outlives it.
                image=mat.get_data().copy(),
                timestamp_ns=cam.get_timestamp(
                    sl.TIME_REFERENCE.IMAGE
                ).get_nanoseconds(),
                frame_index=cam.get_svo_position(),
                serial_number=self.camera_info[cam_index]["serial"],
                depth_map=self._retrieve_depth(cam),
                confidence_map=self._retrieve_confidence(cam),
            )
        if err == sl.ERROR_CODE.END_OF_SVOFILE_REACHED:
            cam.set_svo_position(0)  # loop playback on EOF
        return None

    def grab_all(self) -> list[FrameData | None]:
        """Grabs a frame from all cameras without strict synchronization.

        Returns one entry per camera: a FrameData on success, or None when
        the grab failed or hit end-of-file.
        """
        # The duplicated retrieve/package code that used to live here and in
        # grab_synced is now shared via _grab_frame.
        return [self._grab_frame(i) for i in range(len(self.cameras))]

    def grab_synced(self, tolerance_ms: int = 33) -> list[FrameData | None]:
        """
        Grabs frames from all cameras, attempting to keep them within tolerance_ms.
        If a camera falls behind, it skips frames.
        """
        frames = self.grab_all()
        if not any(frames):
            return frames
        # Find the latest timestamp among the frames we actually got.
        valid_timestamps = [f.timestamp_ns for f in frames if f is not None]
        if not valid_timestamps:
            return frames
        max_ts = max(valid_timestamps)
        tolerance_ns = tolerance_ms * 1_000_000
        for i, frame in enumerate(frames):
            if frame is None:
                continue
            lag_ns = max_ts - frame.timestamp_ns
            if lag_ns <= tolerance_ns:
                continue
            # This camera is behind: estimate how many frames to jump
            # forward, seek, then re-grab at the new position.
            cam = self.cameras[i]
            fps = self.camera_info[i]["fps"]
            frames_to_skip = int((lag_ns / 1_000_000_000) * fps)
            if frames_to_skip > 0:
                cam.set_svo_position(cam.get_svo_position() + frames_to_skip)
            frames[i] = self._grab_frame(i)
        return frames

    @property
    def enable_depth(self) -> bool:
        """True when a depth mode other than NONE was requested."""
        return self._depth_mode != sl.DEPTH_MODE.NONE

    def _retrieve_depth(self, cam: sl.Camera) -> np.ndarray | None:
        """Retrieve the depth map in meters, or None when depth is disabled."""
        if not self.enable_depth:
            return None
        depth_mat = sl.Mat()
        cam.retrieve_measure(depth_mat, sl.MEASURE.DEPTH)
        depth_data = depth_mat.get_data().copy()
        # Check if units are already in meters to avoid double scaling.
        # SDK coordinate_units is set to METER in __init__.
        units = cam.get_init_parameters().coordinate_units
        if units == sl.UNIT.METER:
            depth = depth_data
        else:
            # Fallback for safety, though coordinate_units should be METER.
            depth = depth_data / 1000.0
        # Sanity check and debug logging (NaN/inf and <=0 mark invalid pixels).
        valid_mask = np.isfinite(depth) & (depth > 0)
        if np.any(valid_mask):
            valid_depths = depth[valid_mask]
            logger.debug(
                f"Depth stats (m) - Min: {np.min(valid_depths):.3f}, "
                f"Median: {np.median(valid_depths):.3f}, "
                f"Max: {np.max(valid_depths):.3f}, "
                f"P95: {np.percentile(valid_depths, 95):.3f}"
            )
        else:
            logger.warning("No valid depth values retrieved")
        return depth

    def _retrieve_confidence(self, cam: sl.Camera) -> np.ndarray | None:
        """Retrieve the confidence map, or None when depth is disabled."""
        if not self.enable_depth:
            return None
        conf_mat = sl.Mat()
        cam.retrieve_measure(conf_mat, sl.MEASURE.CONFIDENCE)
        return conf_mat.get_data().copy()

    def get_depth_at(self, frame: FrameData, x: int, y: int) -> float | None:
        """Depth in meters at pixel (x, y), or None if out of bounds/invalid."""
        if frame.depth_map is None:
            return None
        h, w = frame.depth_map.shape[:2]
        if x < 0 or x >= w or y < 0 or y >= h:
            return None
        depth = frame.depth_map[y, x]
        if not np.isfinite(depth) or depth <= 0:
            return None
        return float(depth)

    def get_depth_window_median(
        self, frame: FrameData, x: int, y: int, size: int = 5
    ) -> float | None:
        """Median of valid depths in a size x size window centered at (x, y).

        Even sizes are bumped to the next odd value so the window has a
        center pixel; the window is clipped at the image borders. Returns
        None when there is no depth map or no valid pixel in the window.
        """
        if frame.depth_map is None:
            return None
        if size % 2 == 0:
            size += 1
        h, w = frame.depth_map.shape[:2]
        half = size // 2
        x_min = max(0, x - half)
        x_max = min(w, x + half + 1)
        y_min = max(0, y - half)
        y_max = min(h, y + half + 1)
        window = frame.depth_map[y_min:y_max, x_min:x_max]
        valid_depths = window[np.isfinite(window) & (window > 0)]
        if len(valid_depths) == 0:
            return None
        return float(np.median(valid_depths))

    def close(self):
        """Closes all cameras."""
        for cam in self.cameras:
            cam.close()
        self.cameras = []