feat: Add hypothesis testing and update dependencies

- Added property-based tests for camera affinity calculations using Hypothesis, enhancing test coverage and reliability.
- Updated `.gitignore` to include `.hypothesis` directory.
- Added `hypothesis` and `pytest` as dependencies in `pyproject.toml` and `uv.lock` for improved testing capabilities.
- Refactored imports in `playground.py` to include `Sequence` and `Mapping` from `beartype.typing`.
- Introduced a new test file `tests/test_affinity.py` for structured testing of affinity calculations.
This commit is contained in:
2025-04-28 18:21:08 +08:00
parent b3ed20296a
commit 487dd4e237
5 changed files with 185 additions and 2 deletions

1
.gitignore vendored
View File

@ -8,4 +8,5 @@ wheels/
# Virtual environments # Virtual environments
.venv .venv
.hypothesis
samples samples

View File

@ -24,9 +24,7 @@ from pathlib import Path
from typing import ( from typing import (
Any, Any,
Generator, Generator,
Mapping,
Optional, Optional,
Sequence,
TypeAlias, TypeAlias,
TypedDict, TypedDict,
TypeVar, TypeVar,
@ -40,6 +38,7 @@ import jax.numpy as jnp
import numpy as np import numpy as np
import orjson import orjson
from beartype import beartype from beartype import beartype
from beartype.typing import Sequence, Mapping
from cv2 import undistortPoints from cv2 import undistortPoints
from IPython.display import display from IPython.display import display
from jaxtyping import Array, Float, Num, jaxtyped from jaxtyping import Array, Float, Num, jaxtyped

View File

@ -9,6 +9,7 @@ dependencies = [
"awkward>=2.7.4", "awkward>=2.7.4",
"beartype>=0.20.0", "beartype>=0.20.0",
"cvxopt>=1.3.2", "cvxopt>=1.3.2",
"hypothesis>=6.131.9",
"jax[cuda12]>=0.5.1", "jax[cuda12]>=0.5.1",
"jaxtyping>=0.2.38", "jaxtyping>=0.2.38",
"jupytext>=1.17.0", "jupytext>=1.17.0",
@ -18,6 +19,7 @@ dependencies = [
"pandas>=2.2.3", "pandas>=2.2.3",
"plotly>=6.0.1", "plotly>=6.0.1",
"pyarrow>=19.0.1", "pyarrow>=19.0.1",
"pytest>=8.3.5",
"scipy>=1.15.2", "scipy>=1.15.2",
"torch>=2.6.0", "torch>=2.6.0",
"torchvision>=0.21.0", "torchvision>=0.21.0",

119
tests/test_affinity.py Normal file
View File

@ -0,0 +1,119 @@
from datetime import datetime, timedelta
import jax.numpy as jnp
import numpy as np
import pytest
from hypothesis import given, settings, HealthCheck
from hypothesis import strategies as st
from app.camera import Camera, CameraParams
from playground import (
Detection,
Tracking,
calculate_affinity_matrix,
calculate_camera_affinity_matrix,
)
# ----------------------------------------------------------------------------
# Helper functions to generate synthetic cameras / trackings / detections
# ----------------------------------------------------------------------------
def _make_dummy_camera(cam_id: str, rng: np.random.Generator) -> Camera:
K = jnp.eye(3)
Rt = jnp.eye(4)
dist = jnp.zeros(5)
image_size = jnp.array([1000, 1000])
params = CameraParams(K=K, Rt=Rt, dist_coeffs=dist, image_size=image_size)
return Camera(id=cam_id, params=params)
def _random_keypoints_3d(rng: np.random.Generator, J: int):
return jnp.asarray(rng.uniform(-1.0, 1.0, size=(J, 3)).astype(np.float32))
def _random_keypoints_2d(rng: np.random.Generator, J: int):
return jnp.asarray(rng.uniform(0.0, 1000.0, size=(J, 2)).astype(np.float32))
def _make_trackings(rng: np.random.Generator, camera: Camera, T: int, J: int):
now = datetime.now()
trackings = []
for i in range(T):
kps3d = _random_keypoints_3d(rng, J)
trk = Tracking(
id=i + 1,
keypoints=kps3d,
last_active_timestamp=now
- timedelta(milliseconds=int(rng.integers(20, 50))),
)
trackings.append(trk)
return trackings
def _make_detections(rng: np.random.Generator, camera: Camera, D: int, J: int):
now = datetime.now()
detections = []
for _ in range(D):
kps2d = _random_keypoints_2d(rng, J)
det = Detection(
keypoints=kps2d,
confidences=jnp.ones(J, dtype=jnp.float32),
camera=camera,
timestamp=now,
)
detections.append(det)
return detections
# ----------------------------------------------------------------------------
# Property-based test: per-camera vs naive slice should match
# ----------------------------------------------------------------------------
@settings(max_examples=3, deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(
T=st.integers(min_value=1, max_value=4),
D=st.integers(min_value=1, max_value=4),
J=st.integers(min_value=5, max_value=15),
seed=st.integers(min_value=0, max_value=10000),
)
def test_per_camera_matches_naive(T, D, J, seed):
rng = np.random.default_rng(seed)
cam = _make_dummy_camera("C0", rng)
trackings = _make_trackings(rng, cam, T, J)
detections = _make_detections(rng, cam, D, J)
# Compute per-camera affinity (fast)
A_fast = calculate_camera_affinity_matrix(
trackings,
detections,
w_2d=1.0,
alpha_2d=1.0,
w_3d=1.0,
alpha_3d=1.0,
lambda_a=0.1,
)
# Compute naive multi-camera affinity and slice out this camera
from collections import OrderedDict
det_dict = OrderedDict({"C0": detections})
A_naive, _ = calculate_affinity_matrix(
trackings,
det_dict,
w_2d=1.0,
alpha_2d=1.0,
w_3d=1.0,
alpha_3d=1.0,
lambda_a=0.1,
)
# They should be close
np.testing.assert_allclose(A_fast, np.asarray(A_naive), rtol=1e-5, atol=1e-5)
if __name__ == "__main__" and pytest is not None:
pytest.main([__file__])

62
uv.lock generated
View File

@ -448,6 +448,7 @@ dependencies = [
{ name = "awkward" }, { name = "awkward" },
{ name = "beartype" }, { name = "beartype" },
{ name = "cvxopt" }, { name = "cvxopt" },
{ name = "hypothesis" },
{ name = "jax", extra = ["cuda12"] }, { name = "jax", extra = ["cuda12"] },
{ name = "jaxtyping" }, { name = "jaxtyping" },
{ name = "jupytext" }, { name = "jupytext" },
@ -457,6 +458,7 @@ dependencies = [
{ name = "pandas" }, { name = "pandas" },
{ name = "plotly" }, { name = "plotly" },
{ name = "pyarrow" }, { name = "pyarrow" },
{ name = "pytest" },
{ name = "scipy" }, { name = "scipy" },
{ name = "torch" }, { name = "torch" },
{ name = "torchvision" }, { name = "torchvision" },
@ -474,6 +476,7 @@ requires-dist = [
{ name = "awkward", specifier = ">=2.7.4" }, { name = "awkward", specifier = ">=2.7.4" },
{ name = "beartype", specifier = ">=0.20.0" }, { name = "beartype", specifier = ">=0.20.0" },
{ name = "cvxopt", specifier = ">=1.3.2" }, { name = "cvxopt", specifier = ">=1.3.2" },
{ name = "hypothesis", specifier = ">=6.131.9" },
{ name = "jax", extras = ["cuda12"], specifier = ">=0.5.1" }, { name = "jax", extras = ["cuda12"], specifier = ">=0.5.1" },
{ name = "jaxtyping", specifier = ">=0.2.38" }, { name = "jaxtyping", specifier = ">=0.2.38" },
{ name = "jupytext", specifier = ">=1.17.0" }, { name = "jupytext", specifier = ">=1.17.0" },
@ -483,6 +486,7 @@ requires-dist = [
{ name = "pandas", specifier = ">=2.2.3" }, { name = "pandas", specifier = ">=2.2.3" },
{ name = "plotly", specifier = ">=6.0.1" }, { name = "plotly", specifier = ">=6.0.1" },
{ name = "pyarrow", specifier = ">=19.0.1" }, { name = "pyarrow", specifier = ">=19.0.1" },
{ name = "pytest", specifier = ">=8.3.5" },
{ name = "scipy", specifier = ">=1.15.2" }, { name = "scipy", specifier = ">=1.15.2" },
{ name = "torch", specifier = ">=2.6.0" }, { name = "torch", specifier = ">=2.6.0" },
{ name = "torchvision", specifier = ">=0.21.0" }, { name = "torchvision", specifier = ">=0.21.0" },
@ -711,6 +715,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 }, { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 },
] ]
[[package]]
name = "hypothesis"
version = "6.131.9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "attrs" },
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
{ name = "sortedcontainers" },
]
sdist = { url = "https://files.pythonhosted.org/packages/10/ff/217417d065aa8a4e6815ddc39acee1222f1b67bd0e4803b85de86a837873/hypothesis-6.131.9.tar.gz", hash = "sha256:ee9b0e1403e1121c91921dbdc79d7f509fdb96d457a0389222d2a68d6c8a8f8e", size = 435415 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/bd/e5/41a6733bfe11997795669dec3b3d785c28918e06568a2540dcc29f0d3fa7/hypothesis-6.131.9-py3-none-any.whl", hash = "sha256:7c2d9d6382e98e5337b27bd34e5b223bac23956787a827e1d087e00d893561d6", size = 499853 },
]
[[package]] [[package]]
name = "idna" name = "idna"
version = "3.10" version = "3.10"
@ -732,6 +750,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/79/9d/0fb148dc4d6fa4a7dd1d8378168d9b4cd8d4560a6fbf6f0121c5fc34eb68/importlib_metadata-8.6.1-py3-none-any.whl", hash = "sha256:02a89390c1e15fdfdc0d7c6b25cb3e62650d0494005c97d6f148bf5b9787525e", size = 26971 }, { url = "https://files.pythonhosted.org/packages/79/9d/0fb148dc4d6fa4a7dd1d8378168d9b4cd8d4560a6fbf6f0121c5fc34eb68/importlib_metadata-8.6.1-py3-none-any.whl", hash = "sha256:02a89390c1e15fdfdc0d7c6b25cb3e62650d0494005c97d6f148bf5b9787525e", size = 26971 },
] ]
[[package]]
name = "iniconfig"
version = "2.1.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050 },
]
[[package]] [[package]]
name = "ipykernel" name = "ipykernel"
version = "6.29.5" version = "6.29.5"
@ -2143,6 +2170,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/02/65/ad2bc85f7377f5cfba5d4466d5474423a3fb7f6a97fd807c06f92dd3e721/plotly-6.0.1-py3-none-any.whl", hash = "sha256:4714db20fea57a435692c548a4eb4fae454f7daddf15f8d8ba7e1045681d7768", size = 14805757 }, { url = "https://files.pythonhosted.org/packages/02/65/ad2bc85f7377f5cfba5d4466d5474423a3fb7f6a97fd807c06f92dd3e721/plotly-6.0.1-py3-none-any.whl", hash = "sha256:4714db20fea57a435692c548a4eb4fae454f7daddf15f8d8ba7e1045681d7768", size = 14805757 },
] ]
[[package]]
name = "pluggy"
version = "1.5.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 },
]
[[package]] [[package]]
name = "prometheus-client" name = "prometheus-client"
version = "0.21.1" version = "0.21.1"
@ -2266,6 +2302,23 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl", hash = "sha256:a749938e02d6fd0b59b356ca504a24982314bb090c383e3cf201c95ef7e2bfcf", size = 111120 }, { url = "https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl", hash = "sha256:a749938e02d6fd0b59b356ca504a24982314bb090c383e3cf201c95ef7e2bfcf", size = 111120 },
] ]
[[package]]
name = "pytest"
version = "8.3.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
{ name = "iniconfig" },
{ name = "packaging" },
{ name = "pluggy" },
{ name = "tomli", marker = "python_full_version < '3.11'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ae/3c/c9d525a414d506893f0cd8a8d0de7706446213181570cdbd766691164e40/pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845", size = 1450891 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634 },
]
[[package]] [[package]]
name = "python-dateutil" name = "python-dateutil"
version = "2.9.0.post0" version = "2.9.0.post0"
@ -2672,6 +2725,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 }, { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 },
] ]
[[package]]
name = "sortedcontainers"
version = "2.4.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575 },
]
[[package]] [[package]]
name = "soupsieve" name = "soupsieve"
version = "2.6" version = "2.6"