diff --git a/src/pose_tracking_exp/detection/__init__.py b/src/pose_tracking_exp/detection/__init__.py index 4d17de1..9927445 100644 --- a/src/pose_tracking_exp/detection/__init__.py +++ b/src/pose_tracking_exp/detection/__init__.py @@ -25,6 +25,13 @@ from pose_tracking_exp.detection.sources import ( VideoFrameSource, parse_video_input_specs, ) +from pose_tracking_exp.detection.video_alignment import ( + AlignedFrameBundle, + VideoScanResult, + align_timestamp_sequences, + scan_video, + write_aligned_videos, +) from pose_tracking_exp.schema.detection import BoxDetections, CocoKeypointSchema, PoseBatchRequest, PoseDetections, SourceFrame from pose_tracking_exp.detection.yolo_rtmpose import ( WholeBodyPoseEstimator, @@ -39,6 +46,7 @@ __all__ = [ "CvmmapFrameSource", "DEFAULT_BACKEND", "DetectionRunnerConfig", + "AlignedFrameBundle", "IteratorFrameSource", "NatsPoseSink", "ParquetPoseSink", @@ -50,8 +58,10 @@ __all__ = [ "SimpleMovingAverage", "SourceFrame", "SourceSlot", + "VideoScanResult", "VideoFrameSource", "WholeBodyPoseEstimator", + "align_timestamp_sequences", "YoloRtmposeShim", "build_pose_shim", "build_yolo_rtmpose_shim", @@ -61,6 +71,8 @@ __all__ = [ "resolve_default_pose_config", "resolve_instances", "run_detection_runner", + "scan_video", "store_latest_frame", "take_pending_batch", + "write_aligned_videos", ] diff --git a/src/pose_tracking_exp/detection/video_alignment.py b/src/pose_tracking_exp/detection/video_alignment.py new file mode 100644 index 0000000..28b503b --- /dev/null +++ b/src/pose_tracking_exp/detection/video_alignment.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from pathlib import Path + +import cv2 +import numpy as np +from beartype import beartype + + +@dataclass(frozen=True, slots=True) +class VideoScanResult: + source_name: str + path: Path + fps: float + frame_size: tuple[int, int] + timestamps_unix_ns: tuple[int, ...] + + +@dataclass(frozen=True, slots=True) +class AlignedFrameBundle: + bundle_index: int + timestamp_unix_ns: int + frame_indices_by_source: dict[str, int] + + +def _timestamp_from_capture(position_msec: float, frame_index: int, fps: float) -> int: + if np.isfinite(position_msec) and (position_msec > 0.0 or frame_index == 0): + return int(round(position_msec * 1_000_000.0)) + return int(round((frame_index / fps) * 1_000_000_000.0)) + + +@beartype +def scan_video(path: Path, *, source_name: str, default_fps: float = 30.0) -> VideoScanResult: + capture = cv2.VideoCapture(str(path)) + if not capture.isOpened(): + capture.release() + raise RuntimeError(f"Could not open video input: {path}") + + fps = float(capture.get(cv2.CAP_PROP_FPS)) + if not np.isfinite(fps) or fps <= 0: + fps = default_fps + + width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) + timestamps: list[int] = [] + frame_index = 0 + try: + while True: + success, _frame = capture.read() + if not success: + break + timestamps.append( + _timestamp_from_capture( + float(capture.get(cv2.CAP_PROP_POS_MSEC)), + frame_index, + fps, + ) + ) + frame_index += 1 + finally: + capture.release() + + return VideoScanResult( + source_name=source_name, + path=path, + fps=fps, + frame_size=(width, height), + timestamps_unix_ns=tuple(timestamps), + ) + + +@beartype +def align_timestamp_sequences( + scans: Sequence[VideoScanResult], + *, + reference_name: str | None = None, + max_skew_ns: int, + min_views: int | None = None, +) -> tuple[AlignedFrameBundle, ...]: + if not scans: + return () + + ordered = tuple(scans) + by_name = {scan.source_name: scan for scan in ordered} + if reference_name is None: + reference_name = ordered[0].source_name + if reference_name not in by_name: + raise ValueError(f"Unknown reference source: {reference_name}") + + required_views = min_views if min_views is not None else len(ordered) + if required_views < 1 or required_views > len(ordered): + raise ValueError(f"min_views must be between 1 and {len(ordered)}, got {required_views}.") + + reference_timestamps = by_name[reference_name].timestamps_unix_ns + camera_names = tuple(scan.source_name for scan in ordered) + cursors = {camera_name: 0 for camera_name in camera_names} + bundles: list[AlignedFrameBundle] = [] + + for reference_index, reference_timestamp in enumerate(reference_timestamps): + matched = {reference_name: reference_index} + cursors[reference_name] = reference_index + 1 + for camera_name in camera_names: + if camera_name == reference_name: + continue + timestamps = by_name[camera_name].timestamps_unix_ns + cursor = cursors[camera_name] + best_index = -1 + best_skew = max_skew_ns + 1 + while cursor < len(timestamps): + skew = abs(timestamps[cursor] - reference_timestamp) + if skew <= best_skew: + best_skew = skew + best_index = cursor + if timestamps[cursor] > reference_timestamp and skew > best_skew: + break + cursor += 1 + if best_index >= 0 and best_skew <= max_skew_ns: + matched[camera_name] = best_index + cursors[camera_name] = best_index + 1 + if len(matched) >= required_views: + bundles.append( + AlignedFrameBundle( + bundle_index=len(bundles), + timestamp_unix_ns=reference_timestamp, + frame_indices_by_source=matched, + ) + ) + return tuple(bundles) + + +@beartype +def write_aligned_videos( + scans: Sequence[VideoScanResult], + bundles: Sequence[AlignedFrameBundle], + *, + output_dir: Path, + output_fps: float | None = None, + codec: str = "mp4v", +) -> dict[str, Path]: + if not scans: + return {} + + output_dir.mkdir(parents=True, exist_ok=True) + selected_indices_by_source: dict[str, tuple[int, ...]] = { + scan.source_name: tuple( + bundle.frame_indices_by_source[scan.source_name] + for bundle in bundles + if scan.source_name in bundle.frame_indices_by_source + ) + for scan in scans + } + writer_fps = output_fps if output_fps is not None else scans[0].fps + if writer_fps <= 0: + raise ValueError("output_fps must be positive.") + + outputs: dict[str, Path] = {} + fourcc = cv2.VideoWriter.fourcc(*codec) + for scan in scans: + selected = selected_indices_by_source[scan.source_name] + output_path = output_dir / f"{scan.source_name}.mp4" + outputs[scan.source_name] = output_path + capture = cv2.VideoCapture(str(scan.path)) + if not capture.isOpened(): + capture.release() + raise RuntimeError(f"Could not reopen video input: {scan.path}") + writer = cv2.VideoWriter(str(output_path), fourcc, writer_fps, scan.frame_size) + if not writer.isOpened(): + capture.release() + writer.release() + raise RuntimeError(f"Could not open output video writer: {output_path}") + try: + selected_cursor = 0 + frame_index = 0 + selected_count = len(selected) + while selected_cursor < selected_count: + success, frame = capture.read() + if not success or frame is None: + break + target_index = selected[selected_cursor] + if frame_index == target_index: + writer.write(frame) + selected_cursor += 1 + frame_index += 1 + finally: + writer.release() + capture.release() + return outputs diff --git a/tests/support/align_videos.py b/tests/support/align_videos.py new file mode 100644 index 0000000..5e27492 --- /dev/null +++ b/tests/support/align_videos.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import json +from pathlib import Path + +import click +from loguru import logger + +from pose_tracking_exp.detection import ( + align_timestamp_sequences, + parse_video_input_specs, + scan_video, + write_aligned_videos, +) +from pose_tracking_exp.schema import TrackerConfig + + +@click.command() +@click.argument("inputs", nargs=-1, type=str, required=True) +@click.option("--output-dir", type=click.Path(path_type=Path, file_okay=False), required=True) +@click.option("--reference", "reference_name", type=str) +@click.option("--max-skew-ms", type=float, default=None, help="Max timestamp skew in milliseconds.") +@click.option("--min-views", type=click.IntRange(min=1), default=None) +@click.option("--codec", type=str, default="mp4v", show_default=True) +def main( + inputs: tuple[str, ...], + output_dir: Path, + reference_name: str | None, + max_skew_ms: float | None, + min_views: int | None, + codec: str, +) -> None: + logger.remove() + logger.add( + click.get_text_stream("stderr"), + level="INFO", + format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}", + ) + + parsed_inputs = parse_video_input_specs(inputs) + tracker_defaults = TrackerConfig() + scans = tuple( + scan_video(path, source_name=source_name) + for source_name, path in parsed_inputs + ) + if reference_name is None: + reference_name = scans[0].source_name + if min_views is None: + min_views = len(scans) + max_skew_ns = ( + int(round(max_skew_ms * 1_000_000.0)) + if max_skew_ms is not None + else tracker_defaults.max_sync_skew_ns + ) + + bundles = align_timestamp_sequences( + scans, + reference_name=reference_name, + max_skew_ns=max_skew_ns, + min_views=min_views, + ) + if not bundles: + raise click.ClickException("No aligned frame bundles were found.") + + outputs = write_aligned_videos( + scans, + bundles, + output_dir=output_dir, + output_fps=scans[0].fps, + codec=codec, + ) + metadata = { + "reference_name": reference_name, + "max_skew_ns": max_skew_ns, + "min_views": min_views, + "bundle_count": len(bundles), + "sources": { + scan.source_name: { + "input_path": str(scan.path), + "output_path": str(outputs[scan.source_name]), + "input_fps": scan.fps, + "input_frame_count": len(scan.timestamps_unix_ns), + "output_frame_count": sum( + 1 for bundle in bundles if scan.source_name in bundle.frame_indices_by_source + ), + } + for scan in scans + }, + "bundles": [ + { + "bundle_index": bundle.bundle_index, + "timestamp_unix_ns": bundle.timestamp_unix_ns, + "frame_indices_by_source": bundle.frame_indices_by_source, + } + for bundle in bundles + ], + } + (output_dir / "alignment.json").write_text(json.dumps(metadata, indent=2), encoding="utf-8") + logger.info( + "aligned {} bundles across {} sources into {}", + len(bundles), + len(scans), + output_dir, + ) + + +if __name__ == "__main__": + main() diff --git a/tests/test_video_alignment.py b/tests/test_video_alignment.py new file mode 100644 index 0000000..19c63fe --- /dev/null +++ b/tests/test_video_alignment.py @@ -0,0 +1,97 @@ +from pathlib import Path + +import cv2 +import numpy as np + +from pose_tracking_exp.detection.video_alignment import ( + align_timestamp_sequences, + write_aligned_videos, + VideoScanResult, +) + + +def test_align_timestamp_sequences_matches_full_common_window() -> None: + scans = ( + VideoScanResult( + source_name="cam0", + path=Path("/tmp/cam0.mp4"), + fps=30.0, + frame_size=(8, 6), + timestamps_unix_ns=(0, 33_000_000, 66_000_000, 99_000_000), + ), + VideoScanResult( + source_name="cam1", + path=Path("/tmp/cam1.mp4"), + fps=29.97, + frame_size=(8, 6), + timestamps_unix_ns=(1_000_000, 34_000_000, 67_000_000, 100_000_000), + ), + VideoScanResult( + source_name="cam2", + path=Path("/tmp/cam2.mp4"), + fps=29.5, + frame_size=(8, 6), + timestamps_unix_ns=(20_000_000, 90_000_000, 160_000_000), + ), + ) + + bundles = align_timestamp_sequences( + scans, + reference_name="cam0", + max_skew_ns=12_000_000, + min_views=2, + ) + + assert len(bundles) == 4 + assert bundles[0].frame_indices_by_source == {"cam0": 0, "cam1": 0} + assert bundles[-1].frame_indices_by_source == {"cam0": 3, "cam1": 3, "cam2": 1} + + +def _write_colored_video(path: Path, frame_values: list[int]) -> None: + writer = cv2.VideoWriter(str(path), cv2.VideoWriter.fourcc(*"mp4v"), 10.0, (8, 6)) + if not writer.isOpened(): + raise RuntimeError(f"Could not create {path}") + try: + for value in frame_values: + writer.write(np.full((6, 8, 3), value, dtype=np.uint8)) + finally: + writer.release() + + +def test_write_aligned_videos_selects_requested_frames(tmp_path: Path) -> None: + source0 = tmp_path / "cam0.mp4" + source1 = tmp_path / "cam1.mp4" + _write_colored_video(source0, [10, 20, 30, 40]) + _write_colored_video(source1, [11, 21, 31, 41]) + + scans = ( + VideoScanResult("cam0", source0, 10.0, (8, 6), (0, 100_000_000, 200_000_000, 300_000_000)), + VideoScanResult("cam1", source1, 10.0, (8, 6), (0, 100_000_000, 200_000_000, 300_000_000)), + ) + bundles = ( + # choose original frame indices 1 and 3 from both sources + *( + bundle + for bundle in ( + align_timestamp_sequences(scans, max_skew_ns=1_000_000, min_views=2) + ) + if bundle.bundle_index in {1, 3} + ), + ) + + outputs = write_aligned_videos(scans, bundles, output_dir=tmp_path / "aligned", output_fps=10.0) + + for source_name, expected_values in (("cam0", [20, 40]), ("cam1", [21, 41])): + capture = cv2.VideoCapture(str(outputs[source_name])) + frames: list[int] = [] + try: + while True: + success, frame = capture.read() + if not success or frame is None: + break + frames.append(int(round(float(frame.mean())))) + finally: + capture.release() + assert len(frames) == 2 + assert abs(frames[0] - expected_values[0]) <= 5 + assert abs(frames[1] - expected_values[1]) <= 5