From 487dd4e237e8c2880b90dc7e3a2e235c3b13dd85 Mon Sep 17 00:00:00 2001 From: crosstyan Date: Mon, 28 Apr 2025 18:21:08 +0800 Subject: [PATCH] feat: Add hypothesis testing and update dependencies - Added property-based tests for camera affinity calculations using Hypothesis, enhancing test coverage and reliability. - Updated `.gitignore` to include `.hypothesis` directory. - Added `hypothesis` and `pytest` as dependencies in `pyproject.toml` and `uv.lock` for improved testing capabilities. - Refactored imports in `playground.py` to include `Sequence` and `Mapping` from `beartype.typing`. - Introduced a new test file `tests/test_affinity.py` for structured testing of affinity calculations. --- .gitignore | 1 + playground.py | 3 +- pyproject.toml | 2 + tests/test_affinity.py | 119 +++++++++++++++++++++++++++++++++++++++++ uv.lock | 62 +++++++++++++++++++++ 5 files changed, 185 insertions(+), 2 deletions(-) create mode 100644 tests/test_affinity.py diff --git a/.gitignore b/.gitignore index d05d914..1977b52 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ wheels/ # Virtual environments .venv +.hypothesis samples diff --git a/playground.py b/playground.py index e2f33f1..8e549bd 100644 --- a/playground.py +++ b/playground.py @@ -24,9 +24,7 @@ from pathlib import Path from typing import ( Any, Generator, - Mapping, Optional, - Sequence, TypeAlias, TypedDict, TypeVar, @@ -40,6 +38,7 @@ import jax.numpy as jnp import numpy as np import orjson from beartype import beartype +from beartype.typing import Sequence, Mapping from cv2 import undistortPoints from IPython.display import display from jaxtyping import Array, Float, Num, jaxtyped diff --git a/pyproject.toml b/pyproject.toml index ab4e5a4..e8b261d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "awkward>=2.7.4", "beartype>=0.20.0", "cvxopt>=1.3.2", + "hypothesis>=6.131.9", "jax[cuda12]>=0.5.1", "jaxtyping>=0.2.38", "jupytext>=1.17.0", @@ -18,6 +19,7 @@ dependencies = [ "pandas>=2.2.3", "plotly>=6.0.1", "pyarrow>=19.0.1", + "pytest>=8.3.5", "scipy>=1.15.2", "torch>=2.6.0", "torchvision>=0.21.0", diff --git a/tests/test_affinity.py b/tests/test_affinity.py new file mode 100644 index 0000000..4037fd6 --- /dev/null +++ b/tests/test_affinity.py @@ -0,0 +1,119 @@ +from datetime import datetime, timedelta + +import jax.numpy as jnp +import numpy as np +import pytest +from hypothesis import given, settings, HealthCheck +from hypothesis import strategies as st + +from app.camera import Camera, CameraParams +from playground import ( + Detection, + Tracking, + calculate_affinity_matrix, + calculate_camera_affinity_matrix, +) + +# ---------------------------------------------------------------------------- +# Helper functions to generate synthetic cameras / trackings / detections +# ---------------------------------------------------------------------------- + + +def _make_dummy_camera(cam_id: str, rng: np.random.Generator) -> Camera: + K = jnp.eye(3) + Rt = jnp.eye(4) + dist = jnp.zeros(5) + image_size = jnp.array([1000, 1000]) + params = CameraParams(K=K, Rt=Rt, dist_coeffs=dist, image_size=image_size) + return Camera(id=cam_id, params=params) + + +def _random_keypoints_3d(rng: np.random.Generator, J: int): + return jnp.asarray(rng.uniform(-1.0, 1.0, size=(J, 3)).astype(np.float32)) + + +def _random_keypoints_2d(rng: np.random.Generator, J: int): + return jnp.asarray(rng.uniform(0.0, 1000.0, size=(J, 2)).astype(np.float32)) + + +def _make_trackings(rng: np.random.Generator, camera: Camera, T: int, J: int): + now = datetime.now() + trackings = [] + for i in range(T): + kps3d = _random_keypoints_3d(rng, J) + trk = Tracking( + id=i + 1, + keypoints=kps3d, + last_active_timestamp=now + - timedelta(milliseconds=int(rng.integers(20, 50))), + ) + trackings.append(trk) + return trackings + + +def _make_detections(rng: np.random.Generator, camera: Camera, D: int, J: int): + now = datetime.now() + detections = [] + for _ in range(D): + kps2d = _random_keypoints_2d(rng, J) + det = Detection( + keypoints=kps2d, + confidences=jnp.ones(J, dtype=jnp.float32), + camera=camera, + timestamp=now, + ) + detections.append(det) + return detections + + +# ---------------------------------------------------------------------------- +# Property-based test: per-camera vs naive slice should match +# ---------------------------------------------------------------------------- + + +@settings(max_examples=3, deadline=None, suppress_health_check=[HealthCheck.too_slow]) +@given( + T=st.integers(min_value=1, max_value=4), + D=st.integers(min_value=1, max_value=4), + J=st.integers(min_value=5, max_value=15), + seed=st.integers(min_value=0, max_value=10000), +) +def test_per_camera_matches_naive(T, D, J, seed): + rng = np.random.default_rng(seed) + + cam = _make_dummy_camera("C0", rng) + + trackings = _make_trackings(rng, cam, T, J) + detections = _make_detections(rng, cam, D, J) + + # Compute per-camera affinity (fast) + A_fast = calculate_camera_affinity_matrix( + trackings, + detections, + w_2d=1.0, + alpha_2d=1.0, + w_3d=1.0, + alpha_3d=1.0, + lambda_a=0.1, + ) + + # Compute naive multi-camera affinity and slice out this camera + from collections import OrderedDict + + det_dict = OrderedDict({"C0": detections}) + A_naive, _ = calculate_affinity_matrix( + trackings, + det_dict, + w_2d=1.0, + alpha_2d=1.0, + w_3d=1.0, + alpha_3d=1.0, + lambda_a=0.1, + ) + + # They should be close + np.testing.assert_allclose(A_fast, np.asarray(A_naive), rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__" and pytest is not None: + pytest.main([__file__]) diff --git a/uv.lock b/uv.lock index 25b98f8..255be76 100644 --- a/uv.lock +++ b/uv.lock @@ -448,6 +448,7 @@ dependencies = [ { name = "awkward" }, { name = "beartype" }, { name = "cvxopt" }, + { name = "hypothesis" }, { name = "jax", extra = ["cuda12"] }, { name = "jaxtyping" }, { name = "jupytext" }, @@ -457,6 +458,7 @@ dependencies = [ { name = "pandas" }, { name = "plotly" }, { name = "pyarrow" }, + { name = "pytest" }, { name = "scipy" }, { name = "torch" }, { name = "torchvision" }, @@ -474,6 +476,7 @@ requires-dist = [ { name = "awkward", specifier = ">=2.7.4" }, { name = "beartype", specifier = ">=0.20.0" }, { name = "cvxopt", specifier = ">=1.3.2" }, + { name = "hypothesis", specifier = ">=6.131.9" }, { name = "jax", extras = ["cuda12"], specifier = ">=0.5.1" }, { name = "jaxtyping", specifier = ">=0.2.38" }, { name = "jupytext", specifier = ">=1.17.0" }, @@ -483,6 +486,7 @@ requires-dist = [ { name = "pandas", specifier = ">=2.2.3" }, { name = "plotly", specifier = ">=6.0.1" }, { name = "pyarrow", specifier = ">=19.0.1" }, + { name = "pytest", specifier = ">=8.3.5" }, { name = "scipy", specifier = ">=1.15.2" }, { name = "torch", specifier = ">=2.6.0" }, { name = "torchvision", specifier = ">=0.21.0" }, @@ -711,6 +715,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 }, ] +[[package]] +name = "hypothesis" +version = "6.131.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "sortedcontainers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/10/ff/217417d065aa8a4e6815ddc39acee1222f1b67bd0e4803b85de86a837873/hypothesis-6.131.9.tar.gz", hash = "sha256:ee9b0e1403e1121c91921dbdc79d7f509fdb96d457a0389222d2a68d6c8a8f8e", size = 435415 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bd/e5/41a6733bfe11997795669dec3b3d785c28918e06568a2540dcc29f0d3fa7/hypothesis-6.131.9-py3-none-any.whl", hash = "sha256:7c2d9d6382e98e5337b27bd34e5b223bac23956787a827e1d087e00d893561d6", size = 499853 }, +] + [[package]] name = "idna" version = "3.10" @@ -732,6 +750,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/79/9d/0fb148dc4d6fa4a7dd1d8378168d9b4cd8d4560a6fbf6f0121c5fc34eb68/importlib_metadata-8.6.1-py3-none-any.whl", hash = "sha256:02a89390c1e15fdfdc0d7c6b25cb3e62650d0494005c97d6f148bf5b9787525e", size = 26971 }, ] +[[package]] +name = "iniconfig" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050 }, +] + [[package]] name = "ipykernel" version = "6.29.5" @@ -2143,6 +2170,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/65/ad2bc85f7377f5cfba5d4466d5474423a3fb7f6a97fd807c06f92dd3e721/plotly-6.0.1-py3-none-any.whl", hash = "sha256:4714db20fea57a435692c548a4eb4fae454f7daddf15f8d8ba7e1045681d7768", size = 14805757 }, ] +[[package]] +name = "pluggy" +version = "1.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, +] + [[package]] name = "prometheus-client" version = "0.21.1" @@ -2266,6 +2302,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl", hash = "sha256:a749938e02d6fd0b59b356ca504a24982314bb090c383e3cf201c95ef7e2bfcf", size = 111120 }, ] +[[package]] +name = "pytest" +version = "8.3.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ae/3c/c9d525a414d506893f0cd8a8d0de7706446213181570cdbd766691164e40/pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845", size = 1450891 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634 }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -2672,6 +2725,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 }, ] +[[package]] +name = "sortedcontainers" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575 }, +] + [[package]] name = "soupsieve" version = "2.6"