from datetime import datetime, timedelta
from typing import Optional

import jax.numpy as jnp
import numpy as np
import pytest
from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st

from app.camera import Camera, CameraParams
from playground import (
    Detection,
    Tracking,
    calculate_affinity_matrix,
    calculate_camera_affinity_matrix,
)

# ----------------------------------------------------------------------------
# Helper functions to generate synthetic cameras / trackings / detections
# ----------------------------------------------------------------------------


def _make_dummy_camera(cam_id: str, rng: np.random.Generator) -> Camera:
    """Build a camera with identity K / Rt, zero distortion and a 1000x1000 image size.

    Args:
        cam_id: Identifier assigned to the returned camera.
        rng: Unused; accepted so every helper shares the same call shape.

    Returns:
        The constructed ``Camera``.
    """
    params = CameraParams(
        K=jnp.eye(3),
        Rt=jnp.eye(4),
        dist_coeffs=jnp.zeros(5),
        image_size=jnp.array([1000, 1000]),
    )
    return Camera(id=cam_id, params=params)


def _random_keypoints_3d(rng: np.random.Generator, J: int) -> jnp.ndarray:
    """Return a float32 ``(J, 3)`` array drawn uniformly from [-1, 1)."""
    return jnp.asarray(rng.uniform(-1.0, 1.0, size=(J, 3)).astype(np.float32))


def _random_keypoints_2d(rng: np.random.Generator, J: int) -> jnp.ndarray:
    """Return a float32 ``(J, 2)`` array drawn uniformly from [0, 1000).

    The range matches the ``image_size`` used by ``_make_dummy_camera``.
    """
    return jnp.asarray(rng.uniform(0.0, 1000.0, size=(J, 2)).astype(np.float32))


def _make_trackings(
    rng: np.random.Generator,
    camera: Camera,
    T: int,
    J: int,
    *,
    now: Optional[datetime] = None,
) -> list:
    """Create ``T`` trackings with random 3D keypoints.

    Tracking ids are 1-based. Each ``last_active_timestamp`` lies 20-49 ms
    (``rng.integers(20, 50)``) before ``now``.

    Args:
        rng: Source of randomness for keypoints and time offsets.
        camera: Unused; kept for signature symmetry with ``_make_detections``.
        T: Number of trackings to create.
        J: Number of keypoints per tracking.
        now: Reference instant. Defaults to ``datetime.now()``. Pass the same
            value to ``_make_detections`` so the tracking/detection time gap
            depends only on ``rng`` rather than on wall-clock jitter between
            the two helper calls.

    Returns:
        A list of ``Tracking`` objects.
    """
    if now is None:
        now = datetime.now()
    # Keyword arguments are evaluated left to right, so per tracking the
    # keypoints are drawn before the time offset -- the same RNG draw order
    # as a sequential loop, keeping seeds reproducible.
    return [
        Tracking(
            id=i + 1,
            keypoints=_random_keypoints_3d(rng, J),
            last_active_timestamp=now
            - timedelta(milliseconds=int(rng.integers(20, 50))),
        )
        for i in range(T)
    ]


def _make_detections(
    rng: np.random.Generator,
    camera: Camera,
    D: int,
    J: int,
    *,
    now: Optional[datetime] = None,
) -> list:
    """Create ``D`` detections on ``camera`` with random 2D keypoints.

    Every detection has unit confidences and is stamped with ``now``.

    Args:
        rng: Source of randomness for keypoints.
        camera: Camera attached to every detection.
        D: Number of detections to create.
        J: Number of keypoints per detection.
        now: Timestamp for all detections. Defaults to ``datetime.now()``.

    Returns:
        A list of ``Detection`` objects.
    """
    if now is None:
        now = datetime.now()
    return [
        Detection(
            keypoints=_random_keypoints_2d(rng, J),
            confidences=jnp.ones(J, dtype=jnp.float32),
            camera=camera,
            timestamp=now,
        )
        for _ in range(D)
    ]


# ----------------------------------------------------------------------------
# Property-based test: per-camera vs naive slice should match
# ----------------------------------------------------------------------------

# Hyper-parameters shared by both affinity computations. Keeping them in one
# place guarantees the fast and naive paths are always compared under
# identical settings.
_AFFINITY_KWARGS = dict(
    w_2d=1.0,
    alpha_2d=1.0,
    w_3d=1.0,
    alpha_3d=1.0,
    lambda_a=0.1,
)


@settings(max_examples=3, deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(
    T=st.integers(min_value=1, max_value=4),
    D=st.integers(min_value=1, max_value=4),
    J=st.integers(min_value=5, max_value=15),
    seed=st.integers(min_value=0, max_value=10000),
)
def test_per_camera_matches_naive(T, D, J, seed):
    """Check the per-camera affinity path against the multi-camera path.

    Builds ``T`` trackings and ``D`` detections (``J`` keypoints each) on a
    single camera ``"C0"``. It computes the affinity matrix with
    ``calculate_camera_affinity_matrix`` and with
    ``calculate_affinity_matrix`` given a one-camera mapping, then asserts the
    two are numerically close.
    """
    from collections import OrderedDict

    rng = np.random.default_rng(seed)
    cam = _make_dummy_camera("C0", rng)
    trackings = _make_trackings(rng, cam, T, J)
    detections = _make_detections(rng, cam, D, J)

    # Per-camera (fast) affinity.
    A_fast = calculate_camera_affinity_matrix(trackings, detections, **_AFFINITY_KWARGS)

    # Naive multi-camera affinity. With only "C0" in the mapping, the whole
    # returned matrix is compared against the fast result (no slicing needed).
    det_dict = OrderedDict({"C0": detections})
    A_naive, _ = calculate_affinity_matrix(trackings, det_dict, **_AFFINITY_KWARGS)

    np.testing.assert_allclose(
        np.asarray(A_fast), np.asarray(A_naive), rtol=1e-5, atol=1e-5
    )


if __name__ == "__main__":
    # Propagate pytest's exit status so direct runs fail loudly in CI/shells.
    raise SystemExit(pytest.main([__file__]))