feat: Add hypothesis testing and update dependencies
- Added property-based tests for camera affinity calculations using Hypothesis, enhancing test coverage and reliability. - Updated `.gitignore` to include `.hypothesis` directory. - Added `hypothesis` and `pytest` as dependencies in `pyproject.toml` and `uv.lock` for improved testing capabilities. - Refactored imports in `playground.py` to include `Sequence` and `Mapping` from `beartype.typing`. - Introduced a new test file `tests/test_affinity.py` for structured testing of affinity calculations.
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -8,4 +8,5 @@ wheels/
|
|||||||
|
|
||||||
# Virtual environments
|
# Virtual environments
|
||||||
.venv
|
.venv
|
||||||
|
.hypothesis
|
||||||
samples
|
samples
|
||||||
|
|||||||
@ -24,9 +24,7 @@ from pathlib import Path
|
|||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Generator,
|
Generator,
|
||||||
Mapping,
|
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
|
||||||
TypeAlias,
|
TypeAlias,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
@ -40,6 +38,7 @@ import jax.numpy as jnp
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import orjson
|
import orjson
|
||||||
from beartype import beartype
|
from beartype import beartype
|
||||||
|
from beartype.typing import Sequence, Mapping
|
||||||
from cv2 import undistortPoints
|
from cv2 import undistortPoints
|
||||||
from IPython.display import display
|
from IPython.display import display
|
||||||
from jaxtyping import Array, Float, Num, jaxtyped
|
from jaxtyping import Array, Float, Num, jaxtyped
|
||||||
|
|||||||
@ -9,6 +9,7 @@ dependencies = [
|
|||||||
"awkward>=2.7.4",
|
"awkward>=2.7.4",
|
||||||
"beartype>=0.20.0",
|
"beartype>=0.20.0",
|
||||||
"cvxopt>=1.3.2",
|
"cvxopt>=1.3.2",
|
||||||
|
"hypothesis>=6.131.9",
|
||||||
"jax[cuda12]>=0.5.1",
|
"jax[cuda12]>=0.5.1",
|
||||||
"jaxtyping>=0.2.38",
|
"jaxtyping>=0.2.38",
|
||||||
"jupytext>=1.17.0",
|
"jupytext>=1.17.0",
|
||||||
@ -18,6 +19,7 @@ dependencies = [
|
|||||||
"pandas>=2.2.3",
|
"pandas>=2.2.3",
|
||||||
"plotly>=6.0.1",
|
"plotly>=6.0.1",
|
||||||
"pyarrow>=19.0.1",
|
"pyarrow>=19.0.1",
|
||||||
|
"pytest>=8.3.5",
|
||||||
"scipy>=1.15.2",
|
"scipy>=1.15.2",
|
||||||
"torch>=2.6.0",
|
"torch>=2.6.0",
|
||||||
"torchvision>=0.21.0",
|
"torchvision>=0.21.0",
|
||||||
|
|||||||
119
tests/test_affinity.py
Normal file
119
tests/test_affinity.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from hypothesis import given, settings, HealthCheck
|
||||||
|
from hypothesis import strategies as st
|
||||||
|
|
||||||
|
from app.camera import Camera, CameraParams
|
||||||
|
from playground import (
|
||||||
|
Detection,
|
||||||
|
Tracking,
|
||||||
|
calculate_affinity_matrix,
|
||||||
|
calculate_camera_affinity_matrix,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
# Helper functions to generate synthetic cameras / trackings / detections
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_dummy_camera(cam_id: str, rng: np.random.Generator) -> Camera:
|
||||||
|
K = jnp.eye(3)
|
||||||
|
Rt = jnp.eye(4)
|
||||||
|
dist = jnp.zeros(5)
|
||||||
|
image_size = jnp.array([1000, 1000])
|
||||||
|
params = CameraParams(K=K, Rt=Rt, dist_coeffs=dist, image_size=image_size)
|
||||||
|
return Camera(id=cam_id, params=params)
|
||||||
|
|
||||||
|
|
||||||
|
def _random_keypoints_3d(rng: np.random.Generator, J: int):
|
||||||
|
return jnp.asarray(rng.uniform(-1.0, 1.0, size=(J, 3)).astype(np.float32))
|
||||||
|
|
||||||
|
|
||||||
|
def _random_keypoints_2d(rng: np.random.Generator, J: int):
|
||||||
|
return jnp.asarray(rng.uniform(0.0, 1000.0, size=(J, 2)).astype(np.float32))
|
||||||
|
|
||||||
|
|
||||||
|
def _make_trackings(rng: np.random.Generator, camera: Camera, T: int, J: int):
|
||||||
|
now = datetime.now()
|
||||||
|
trackings = []
|
||||||
|
for i in range(T):
|
||||||
|
kps3d = _random_keypoints_3d(rng, J)
|
||||||
|
trk = Tracking(
|
||||||
|
id=i + 1,
|
||||||
|
keypoints=kps3d,
|
||||||
|
last_active_timestamp=now
|
||||||
|
- timedelta(milliseconds=int(rng.integers(20, 50))),
|
||||||
|
)
|
||||||
|
trackings.append(trk)
|
||||||
|
return trackings
|
||||||
|
|
||||||
|
|
||||||
|
def _make_detections(rng: np.random.Generator, camera: Camera, D: int, J: int):
|
||||||
|
now = datetime.now()
|
||||||
|
detections = []
|
||||||
|
for _ in range(D):
|
||||||
|
kps2d = _random_keypoints_2d(rng, J)
|
||||||
|
det = Detection(
|
||||||
|
keypoints=kps2d,
|
||||||
|
confidences=jnp.ones(J, dtype=jnp.float32),
|
||||||
|
camera=camera,
|
||||||
|
timestamp=now,
|
||||||
|
)
|
||||||
|
detections.append(det)
|
||||||
|
return detections
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
# Property-based test: per-camera vs naive slice should match
|
||||||
|
# ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@settings(max_examples=3, deadline=None, suppress_health_check=[HealthCheck.too_slow])
|
||||||
|
@given(
|
||||||
|
T=st.integers(min_value=1, max_value=4),
|
||||||
|
D=st.integers(min_value=1, max_value=4),
|
||||||
|
J=st.integers(min_value=5, max_value=15),
|
||||||
|
seed=st.integers(min_value=0, max_value=10000),
|
||||||
|
)
|
||||||
|
def test_per_camera_matches_naive(T, D, J, seed):
|
||||||
|
rng = np.random.default_rng(seed)
|
||||||
|
|
||||||
|
cam = _make_dummy_camera("C0", rng)
|
||||||
|
|
||||||
|
trackings = _make_trackings(rng, cam, T, J)
|
||||||
|
detections = _make_detections(rng, cam, D, J)
|
||||||
|
|
||||||
|
# Compute per-camera affinity (fast)
|
||||||
|
A_fast = calculate_camera_affinity_matrix(
|
||||||
|
trackings,
|
||||||
|
detections,
|
||||||
|
w_2d=1.0,
|
||||||
|
alpha_2d=1.0,
|
||||||
|
w_3d=1.0,
|
||||||
|
alpha_3d=1.0,
|
||||||
|
lambda_a=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute naive multi-camera affinity and slice out this camera
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
det_dict = OrderedDict({"C0": detections})
|
||||||
|
A_naive, _ = calculate_affinity_matrix(
|
||||||
|
trackings,
|
||||||
|
det_dict,
|
||||||
|
w_2d=1.0,
|
||||||
|
alpha_2d=1.0,
|
||||||
|
w_3d=1.0,
|
||||||
|
alpha_3d=1.0,
|
||||||
|
lambda_a=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# They should be close
|
||||||
|
np.testing.assert_allclose(A_fast, np.asarray(A_naive), rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__" and pytest is not None:
|
||||||
|
pytest.main([__file__])
|
||||||
62
uv.lock
generated
62
uv.lock
generated
@ -448,6 +448,7 @@ dependencies = [
|
|||||||
{ name = "awkward" },
|
{ name = "awkward" },
|
||||||
{ name = "beartype" },
|
{ name = "beartype" },
|
||||||
{ name = "cvxopt" },
|
{ name = "cvxopt" },
|
||||||
|
{ name = "hypothesis" },
|
||||||
{ name = "jax", extra = ["cuda12"] },
|
{ name = "jax", extra = ["cuda12"] },
|
||||||
{ name = "jaxtyping" },
|
{ name = "jaxtyping" },
|
||||||
{ name = "jupytext" },
|
{ name = "jupytext" },
|
||||||
@ -457,6 +458,7 @@ dependencies = [
|
|||||||
{ name = "pandas" },
|
{ name = "pandas" },
|
||||||
{ name = "plotly" },
|
{ name = "plotly" },
|
||||||
{ name = "pyarrow" },
|
{ name = "pyarrow" },
|
||||||
|
{ name = "pytest" },
|
||||||
{ name = "scipy" },
|
{ name = "scipy" },
|
||||||
{ name = "torch" },
|
{ name = "torch" },
|
||||||
{ name = "torchvision" },
|
{ name = "torchvision" },
|
||||||
@ -474,6 +476,7 @@ requires-dist = [
|
|||||||
{ name = "awkward", specifier = ">=2.7.4" },
|
{ name = "awkward", specifier = ">=2.7.4" },
|
||||||
{ name = "beartype", specifier = ">=0.20.0" },
|
{ name = "beartype", specifier = ">=0.20.0" },
|
||||||
{ name = "cvxopt", specifier = ">=1.3.2" },
|
{ name = "cvxopt", specifier = ">=1.3.2" },
|
||||||
|
{ name = "hypothesis", specifier = ">=6.131.9" },
|
||||||
{ name = "jax", extras = ["cuda12"], specifier = ">=0.5.1" },
|
{ name = "jax", extras = ["cuda12"], specifier = ">=0.5.1" },
|
||||||
{ name = "jaxtyping", specifier = ">=0.2.38" },
|
{ name = "jaxtyping", specifier = ">=0.2.38" },
|
||||||
{ name = "jupytext", specifier = ">=1.17.0" },
|
{ name = "jupytext", specifier = ">=1.17.0" },
|
||||||
@ -483,6 +486,7 @@ requires-dist = [
|
|||||||
{ name = "pandas", specifier = ">=2.2.3" },
|
{ name = "pandas", specifier = ">=2.2.3" },
|
||||||
{ name = "plotly", specifier = ">=6.0.1" },
|
{ name = "plotly", specifier = ">=6.0.1" },
|
||||||
{ name = "pyarrow", specifier = ">=19.0.1" },
|
{ name = "pyarrow", specifier = ">=19.0.1" },
|
||||||
|
{ name = "pytest", specifier = ">=8.3.5" },
|
||||||
{ name = "scipy", specifier = ">=1.15.2" },
|
{ name = "scipy", specifier = ">=1.15.2" },
|
||||||
{ name = "torch", specifier = ">=2.6.0" },
|
{ name = "torch", specifier = ">=2.6.0" },
|
||||||
{ name = "torchvision", specifier = ">=0.21.0" },
|
{ name = "torchvision", specifier = ">=0.21.0" },
|
||||||
@ -711,6 +715,20 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 },
|
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hypothesis"
|
||||||
|
version = "6.131.9"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "attrs" },
|
||||||
|
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
|
||||||
|
{ name = "sortedcontainers" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/10/ff/217417d065aa8a4e6815ddc39acee1222f1b67bd0e4803b85de86a837873/hypothesis-6.131.9.tar.gz", hash = "sha256:ee9b0e1403e1121c91921dbdc79d7f509fdb96d457a0389222d2a68d6c8a8f8e", size = 435415 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/bd/e5/41a6733bfe11997795669dec3b3d785c28918e06568a2540dcc29f0d3fa7/hypothesis-6.131.9-py3-none-any.whl", hash = "sha256:7c2d9d6382e98e5337b27bd34e5b223bac23956787a827e1d087e00d893561d6", size = 499853 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "idna"
|
name = "idna"
|
||||||
version = "3.10"
|
version = "3.10"
|
||||||
@ -732,6 +750,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/79/9d/0fb148dc4d6fa4a7dd1d8378168d9b4cd8d4560a6fbf6f0121c5fc34eb68/importlib_metadata-8.6.1-py3-none-any.whl", hash = "sha256:02a89390c1e15fdfdc0d7c6b25cb3e62650d0494005c97d6f148bf5b9787525e", size = 26971 },
|
{ url = "https://files.pythonhosted.org/packages/79/9d/0fb148dc4d6fa4a7dd1d8378168d9b4cd8d4560a6fbf6f0121c5fc34eb68/importlib_metadata-8.6.1-py3-none-any.whl", hash = "sha256:02a89390c1e15fdfdc0d7c6b25cb3e62650d0494005c97d6f148bf5b9787525e", size = 26971 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "iniconfig"
|
||||||
|
version = "2.1.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ipykernel"
|
name = "ipykernel"
|
||||||
version = "6.29.5"
|
version = "6.29.5"
|
||||||
@ -2143,6 +2170,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/02/65/ad2bc85f7377f5cfba5d4466d5474423a3fb7f6a97fd807c06f92dd3e721/plotly-6.0.1-py3-none-any.whl", hash = "sha256:4714db20fea57a435692c548a4eb4fae454f7daddf15f8d8ba7e1045681d7768", size = 14805757 },
|
{ url = "https://files.pythonhosted.org/packages/02/65/ad2bc85f7377f5cfba5d4466d5474423a3fb7f6a97fd807c06f92dd3e721/plotly-6.0.1-py3-none-any.whl", hash = "sha256:4714db20fea57a435692c548a4eb4fae454f7daddf15f8d8ba7e1045681d7768", size = 14805757 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pluggy"
|
||||||
|
version = "1.5.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "prometheus-client"
|
name = "prometheus-client"
|
||||||
version = "0.21.1"
|
version = "0.21.1"
|
||||||
@ -2266,6 +2302,23 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl", hash = "sha256:a749938e02d6fd0b59b356ca504a24982314bb090c383e3cf201c95ef7e2bfcf", size = 111120 },
|
{ url = "https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl", hash = "sha256:a749938e02d6fd0b59b356ca504a24982314bb090c383e3cf201c95ef7e2bfcf", size = 111120 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pytest"
|
||||||
|
version = "8.3.5"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||||
|
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
|
||||||
|
{ name = "iniconfig" },
|
||||||
|
{ name = "packaging" },
|
||||||
|
{ name = "pluggy" },
|
||||||
|
{ name = "tomli", marker = "python_full_version < '3.11'" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/ae/3c/c9d525a414d506893f0cd8a8d0de7706446213181570cdbd766691164e40/pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845", size = 1450891 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "python-dateutil"
|
name = "python-dateutil"
|
||||||
version = "2.9.0.post0"
|
version = "2.9.0.post0"
|
||||||
@ -2672,6 +2725,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 },
|
{ url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sortedcontainers"
|
||||||
|
version = "2.4.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "soupsieve"
|
name = "soupsieve"
|
||||||
version = "2.6"
|
version = "2.6"
|
||||||
|
|||||||
Reference in New Issue
Block a user