diff --git a/app/tracking/__init__.py b/app/tracking/__init__.py index e3b7916..5035d34 100644 --- a/app/tracking/__init__.py +++ b/app/tracking/__init__.py @@ -2,8 +2,10 @@ from dataclasses import dataclass from datetime import datetime from typing import ( Any, + Callable, Generator, Optional, + Sequence, TypeAlias, TypedDict, TypeVar, @@ -14,7 +16,11 @@ from typing import ( import jax import jax.numpy as jnp from beartype import beartype -from jaxtyping import Array, Float, jaxtyped +from beartype.typing import Mapping, Sequence +from jax import Array +from jaxtyping import Array, Float, Int, jaxtyped + +from app.camera import Detection @jaxtyped(typechecker=beartype) @@ -67,3 +73,30 @@ class Tracking: # Step 2 – pure JAX math # ------------------------------------------------------------------ return self.keypoints + velocity * delta_t_s + + +@jaxtyped(typechecker=beartype) +@dataclass +class AffinityResult: + """ + Result of affinity computation between trackings and detections. + """ + + matrix: Float[Array, "T D"] + trackings: Sequence[Tracking] + detections: Sequence[Detection] + indices_T: Int[Array, "T"] # pylint: disable=invalid-name + indices_D: Int[Array, "D"] # pylint: disable=invalid-name + + def tracking_detections( + self, + ) -> Generator[tuple[float, Tracking, Detection], None, None]: + """ + iterate over the best matching trackings and detections + """ + for t, d in zip(self.indices_T, self.indices_D): + yield ( + self.matrix[t, d].item(), + self.trackings[t], + self.detections[d], + ) diff --git a/app/tracking/affinity_result.py b/app/tracking/affinity_result.py deleted file mode 100644 index 78d325f..0000000 --- a/app/tracking/affinity_result.py +++ /dev/null @@ -1,37 +0,0 @@ -from dataclasses import dataclass -from typing import Sequence, Callable, Generator - -from app.camera import Detection -from . import Tracking -from beartype.typing import Sequence, Mapping -from jaxtyping import jaxtyped, Float, Int -from jax import Array - - -@dataclass -class AffinityResult: - """ - Result of affinity computation between trackings and detections. - """ - - matrix: Float[Array, "T D"] - """ - Affinity matrix between trackings and detections. - """ - - trackings: Sequence[Tracking] - """ - Trackings used to compute the affinity matrix. - """ - - detections: Sequence[Detection] - """ - Detections used to compute the affinity matrix. - """ - - indices_T: Sequence[int] - indices_D: Sequence[int] - - def tracking_detections(self) -> Generator[tuple[Tracking, Detection]]: - for t, d in zip(self.indices_T, self.indices_D): - yield (self.trackings[t], self.detections[d]) diff --git a/playground.py b/playground.py index 10431c2..39cdb6a 100644 --- a/playground.py +++ b/playground.py @@ -45,7 +45,7 @@ from IPython.display import display from jaxtyping import Array, Float, Num, jaxtyped from matplotlib import pyplot as plt from numpy.typing import ArrayLike -from scipy.optimize import linear_sum_assignment +from optax.assignment import hungarian_algorithm as linear_sum_assignment from scipy.spatial.transform import Rotation as R from typing_extensions import deprecated @@ -58,7 +58,7 @@ from app.camera import ( classify_by_camera, ) from app.solver._old import GLPKSolver -from app.tracking import Tracking +from app.tracking import AffinityResult, Tracking from app.visualize.whole_body import visualize_whole_body NDArray: TypeAlias = np.ndarray @@ -69,12 +69,6 @@ AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parq DELTA_T_MIN = timedelta(milliseconds=10) display(AK_CAMERA_DATASET) -_DEBUG_CURRENT_TRACKING: tuple[int, int] = (0, 0) - - -def _global_current_tracking_str(): - return str(_DEBUG_CURRENT_TRACKING) - # %% class Resolution(TypedDict): @@ -594,23 +588,6 @@ def calculate_distance_2d( left_normalized = left / jnp.array([w, h]) right_normalized = right / jnp.array([w, h]) dist = jnp.linalg.norm(left_normalized - right_normalized, axis=-1) - lt = left_normalized[:6] - rt = right_normalized[:6] - jax.debug.print( - "[REF]{} norm_trk first6 = {}", - _global_current_tracking_str(), - lt, - ) - jax.debug.print( - "[REF]{} norm_det first6 = {}", - _global_current_tracking_str(), - rt, - ) - jax.debug.print( - "[REF]{} dist2d first6 = {}", - _global_current_tracking_str(), - dist[:6], - ) return dist @@ -806,191 +783,12 @@ def calculate_tracking_detection_affinity( lambda_a=lambda_a, ) - jax.debug.print( - "[REF] aff2d{} first6 = {}", - _global_current_tracking_str(), - affinity_2d[:6], - ) - jax.debug.print( - "[REF] aff3d{} first6 = {}", _global_current_tracking_str(), affinity_3d[:6] - ) - jax.debug.print( - "[REF] aff2d.shape={}; aff3d.shape={}", - affinity_2d.shape, - affinity_3d.shape, - ) # Combine affinities total_affinity = affinity_2d + affinity_3d return jnp.sum(total_affinity).item() # %% -@deprecated( - "Use `calculate_camera_affinity_matrix` instead. This implementation has the problem of under-utilizing views from different cameras." -) -@beartype -def calculate_affinity_matrix( - trackings: Sequence[Tracking], - detections: Sequence[Detection] | OrderedDict[CameraID, list[Detection]], - w_2d: float, - alpha_2d: float, - w_3d: float, - alpha_3d: float, - lambda_a: float, -) -> tuple[Float[Array, "T D"], OrderedDict[CameraID, list[Detection]]]: - """ - Calculate the affinity matrix between a set of trackings and detections. - - Args: - trackings: Sequence of tracking objects - detections: Sequence of detection objects - w_2d: Weight for 2D affinity - alpha_2d: Normalization factor for 2D distance - w_3d: Weight for 3D affinity - alpha_3d: Normalization factor for 3D distance - lambda_a: Decay rate for time difference - - Returns: - - affinity matrix of shape (T, D) where T is number of trackings and D - is number of detections - - dictionary mapping camera IDs to lists of detections from that camera, - since it's a `OrderDict` you could flat it out to get the indices of - detections in the affinity matrix - - Matrix Layout: - The affinity matrix has shape (T, D), where: - - T = number of trackings (rows) - - D = total number of detections across all cameras (columns) - - The matrix is organized as follows: - - ``` - | Camera 1 | Camera 2 | Camera c | - | d1 d2 ... | d1 d2 ... | d1 d2 ... | - ---------+-------------+-------------+-------------+ - Track 1 | a11 a12 ... | a11 a12 ... | a11 a12 ... | - Track 2 | a21 a22 ... | a21 a22 ... | a21 a22 ... | - ... | ... | ... | ... | - Track t | at1 at2 ... | at1 at2 ... | at1 at2 ... | - ``` - - Where: - - Rows are ordered by tracking ID - - Columns are ordered by camera, then by detection within each camera - - Each cell aij represents the affinity between tracking i and detection j - - The detection ordering in columns follows the exact same order as iterating - through the detection_by_camera dictionary, which is returned alongside - the matrix to maintain this relationship. - """ - if isinstance(detections, OrderedDict): - D = flatten_values_len(detections) - affinity = jnp.zeros((len(trackings), D)) - detection_by_camera = detections - else: - affinity = jnp.zeros((len(trackings), len(detections))) - detection_by_camera = classify_by_camera(detections) - - for i, tracking in enumerate(trackings): - j = 0 - for _, camera_detections in detection_by_camera.items(): - for det in camera_detections: - affinity_value = calculate_tracking_detection_affinity( - tracking, - det, - w_2d=w_2d, - alpha_2d=alpha_2d, - w_3d=w_3d, - alpha_3d=alpha_3d, - lambda_a=lambda_a, - ) - affinity = affinity.at[i, j].set(affinity_value) - j += 1 - - return affinity, detection_by_camera - - -@beartype -def calculate_camera_affinity_matrix( - trackings: Sequence[Tracking], - camera_detections: Sequence[Detection], - w_2d: float, - alpha_2d: float, - w_3d: float, - alpha_3d: float, - lambda_a: float, -) -> Float[Array, "T D"]: - """ - Calculate an affinity matrix between trackings and detections from a single camera. - - This follows the iterative camera-by-camera approach from the paper - "Cross-View Tracking for Multi-Human 3D Pose Estimation at over 100 FPS". - Instead of creating one large matrix for all cameras, this creates - a separate matrix for each camera, which can be processed independently. - - Args: - trackings: Sequence of tracking objects - camera_detections: Sequence of detection objects, from the same camera - w_2d: Weight for 2D affinity - alpha_2d: Normalization factor for 2D distance - w_3d: Weight for 3D affinity - alpha_3d: Normalization factor for 3D distance - lambda_a: Decay rate for time difference - - Returns: - Affinity matrix of shape (T, D) where: - - T = number of trackings (rows) - - D = number of detections from this specific camera (columns) - - Matrix Layout: - The affinity matrix for a single camera has shape (T, D), where: - - T = number of trackings (rows) - - D = number of detections from this camera (columns) - - The matrix is organized as follows: - - ``` - | Detections from Camera c | - | d1 d2 d3 ... | - ---------+------------------------+ - Track 1 | a11 a12 a13 ... | - Track 2 | a21 a22 a23 ... | - ... | ... ... ... ... | - Track t | at1 at2 at3 ... | - ``` - - Each cell aij represents the affinity between tracking i and detection j, - computed using both 2D and 3D geometric correspondences. - """ - - def verify_all_detection_from_same_camera(detections: Sequence[Detection]): - if not detections: - return True - camera_id = next(iter(detections)).camera.id - return all(map(lambda d: d.camera.id == camera_id, detections)) - - if not verify_all_detection_from_same_camera(camera_detections): - raise ValueError("All detections must be from the same camera") - - affinity = jnp.zeros((len(trackings), len(camera_detections))) - - for i, tracking in enumerate(trackings): - for j, det in enumerate(camera_detections): - global _DEBUG_CURRENT_TRACKING - _DEBUG_CURRENT_TRACKING = (i, j) - affinity_value = calculate_tracking_detection_affinity( - tracking, - det, - w_2d=w_2d, - alpha_2d=alpha_2d, - w_3d=w_3d, - alpha_3d=alpha_3d, - lambda_a=lambda_a, - ) - affinity = affinity.at[i, j].set(affinity_value) - return affinity - - @beartype def calculate_camera_affinity_matrix_jax( trackings: Sequence[Tracking], @@ -1010,9 +808,6 @@ def calculate_camera_affinity_matrix_jax( (tracking, detection) pair. The mathematical definition of the affinity is **unchanged**, so the result remains bit-identical to the reference implementation used in the tests. - - TODO: It gives a wrong result (maybe it's my problem?) somehow, - and I need to find a way to debug this. """ # ------------------------------------------------------------------ @@ -1052,8 +847,8 @@ def calculate_camera_affinity_matrix_jax( # ------------------------------------------------------------------ # Compute Δt matrix – shape (T, D) # ------------------------------------------------------------------ - # Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out - # sub‑second detail (resolution ≈ 200 ms). Keep them in float64 until + # Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out + # sub‑second detail (resolution ≈ 200 ms). Keep them in float64 until # after subtraction so we preserve Δt‑on‑the‑order‑of‑milliseconds. # --- timestamps ---------- t0 = min( @@ -1093,12 +888,6 @@ def calculate_camera_affinity_matrix_jax( diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :] dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1) - jax.debug.print( - "[JAX] norm_trk[0,0,:6] = {}", norm_trk[0, :, :6] # shape (J,2) 取前6 - ) - jax.debug.print("[JAX] norm_det[0,:6] = {}", norm_det[0, :6]) # shape (J,2) - jax.debug.print("[JAX] dist2d(T0,D0) first6 = {}", dist2d[0, 0, :6]) - # Compute per-keypoint 2D affinity delta_t_broadcast = delta_t[:, :, None] # (T, D, 1) affinity_2d = ( @@ -1155,11 +944,6 @@ def calculate_camera_affinity_matrix_jax( w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast) ) - jax.debug.print("[JAX] aff3d(T0,D0) first6 = {}", affinity_3d[0, 0, :6]) - jax.debug.print("[JAX] aff2d(T0,D0) first6 = {}", affinity_2d[0, 0, :6]) - jax.debug.print( - "[JAX] aff2d.shape={}; aff3d.shape={}", affinity_2d.shape, affinity_3d.shape - ) # ------------------------------------------------------------------ # Combine and reduce across keypoints → (T, D) # ------------------------------------------------------------------ @@ -1167,60 +951,57 @@ def calculate_camera_affinity_matrix_jax( return total_affinity # type: ignore[return-value] -# ------------------------------------------------------------------ -# Debug helper – compare JAX vs reference implementation -# ------------------------------------------------------------------ @beartype -def debug_compare_affinity_matrices( +def calculate_affinity_matrix( trackings: Sequence[Tracking], - camera_detections: Sequence[Detection], - *, + detections: Sequence[Detection] | Mapping[CameraID, list[Detection]], w_2d: float, alpha_2d: float, w_3d: float, alpha_3d: float, lambda_a: float, - atol: float = 1e-5, - rtol: float = 1e-3, -) -> None: +) -> dict[CameraID, AffinityResult]: """ - Compute both affinity matrices and print out the max absolute / relative - difference. If any entry differs more than atol+rtol*|ref|, dump the - offending indices so you can inspect individual terms. + Calculate the affinity matrix between a set of trackings and detections. + + Args: + trackings: Sequence of tracking objects + detections: Sequence of detection objects or a group detections by ID + w_2d: Weight for 2D affinity + alpha_2d: Normalization factor for 2D distance + w_3d: Weight for 3D affinity + alpha_3d: Normalization factor for 3D distance + lambda_a: Decay rate for time difference + Returns: + A dictionary mapping camera IDs to affinity results. """ - aff_jax = calculate_camera_affinity_matrix_jax( - trackings, - camera_detections, - w_2d=w_2d, - alpha_2d=alpha_2d, - w_3d=w_3d, - alpha_3d=alpha_3d, - lambda_a=lambda_a, - ) - aff_ref = calculate_camera_affinity_matrix( - trackings, - camera_detections, - w_2d=w_2d, - alpha_2d=alpha_2d, - w_3d=w_3d, - alpha_3d=alpha_3d, - lambda_a=lambda_a, - ) - - diff = jnp.abs(aff_jax - aff_ref) - max_abs = float(diff.max()) - max_rel = float((diff / (jnp.abs(aff_ref) + 1e-12)).max()) - jax.debug.print(f"[DEBUG] max abs diff {max_abs:.6g}, max rel diff {max_rel:.6g}") - - bad = jnp.where(diff > atol + rtol * jnp.abs(aff_ref)) - if bad[0].size > 0: - for t, d in zip(*[arr.tolist() for arr in bad]): - jax.debug.print( - f" ↳ mismatch at (T={t}, D={d}): " - f"jax={aff_jax[t,d]:.6g}, ref={aff_ref[t,d]:.6g}" - ) + if isinstance(detections, Mapping): + detection_by_camera = detections else: - jax.debug.print("✅ matrices match within tolerance") + detection_by_camera = classify_by_camera(detections) + + res: dict[CameraID, AffinityResult] = {} + for camera_id, camera_detections in detection_by_camera.items(): + affinity_matrix = calculate_camera_affinity_matrix_jax( + trackings, + camera_detections, + w_2d, + alpha_2d, + w_3d, + alpha_3d, + lambda_a, + ) + # row, col + indices_T, indices_D = linear_sum_assignment(affinity_matrix) + affinity_result = AffinityResult( + matrix=affinity_matrix, + trackings=trackings, + detections=camera_detections, + indices_T=indices_T, + indices_D=indices_D, + ) + res[camera_id] = affinity_result + return res # %% @@ -1235,15 +1016,15 @@ trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id) unmatched_detections = shallow_copy(next_group) camera_detections = classify_by_camera(unmatched_detections) -camera_detections_next_batch = camera_detections["AE_08"] -debug_compare_affinity_matrices( +affinities = calculate_affinity_matrix( trackings, - camera_detections_next_batch, + unmatched_detections, w_2d=W_2D, alpha_2d=ALPHA_2D, w_3d=W_3D, alpha_3d=ALPHA_3D, lambda_a=LAMBDA_A, ) +display(affinities) # %% diff --git a/pyproject.toml b/pyproject.toml index e8b261d..4dd80b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "jupytext>=1.17.0", "matplotlib>=3.10.1", "opencv-python-headless>=4.11.0.86", + "optax>=0.2.4", "orjson>=3.10.15", "pandas>=2.2.3", "plotly>=6.0.1", diff --git a/tests/test_affinity.py b/tests/test_affinity.py deleted file mode 100644 index 88bd091..0000000 --- a/tests/test_affinity.py +++ /dev/null @@ -1,224 +0,0 @@ -from datetime import datetime, timedelta -import time - -import jax.numpy as jnp -import numpy as np -import pytest -from hypothesis import given, settings, HealthCheck -from hypothesis import strategies as st - -from app.camera import Camera, CameraParams -from playground import ( - Detection, - Tracking, - calculate_affinity_matrix, - calculate_camera_affinity_matrix, -) - -# ---------------------------------------------------------------------------- -# Helper functions to generate synthetic cameras / trackings / detections -# ---------------------------------------------------------------------------- - - -def _make_dummy_camera(cam_id: str, rng: np.random.Generator) -> Camera: - K = jnp.eye(3) - Rt = jnp.eye(4) - dist = jnp.zeros(5) - image_size = jnp.array([1000, 1000]) - params = CameraParams(K=K, Rt=Rt, dist_coeffs=dist, image_size=image_size) - return Camera(id=cam_id, params=params) - - -def _random_keypoints_3d(rng: np.random.Generator, J: int): - return jnp.asarray(rng.uniform(-1.0, 1.0, size=(J, 3)).astype(np.float32)) - - -def _random_keypoints_2d(rng: np.random.Generator, J: int): - return jnp.asarray(rng.uniform(0.0, 1000.0, size=(J, 2)).astype(np.float32)) - - -def _make_trackings(rng: np.random.Generator, camera: Camera, T: int, J: int): - now = datetime.now() - trackings = [] - for i in range(T): - kps3d = _random_keypoints_3d(rng, J) - trk = Tracking( - id=i + 1, - keypoints=kps3d, - last_active_timestamp=now - - timedelta(milliseconds=int(rng.integers(20, 50))), - ) - trackings.append(trk) - return trackings - - -def _make_detections(rng: np.random.Generator, camera: Camera, D: int, J: int): - now = datetime.now() - detections = [] - for _ in range(D): - kps2d = _random_keypoints_2d(rng, J) - det = Detection( - keypoints=kps2d, - confidences=jnp.ones(J, dtype=jnp.float32), - camera=camera, - timestamp=now, - ) - detections.append(det) - return detections - - -# ---------------------------------------------------------------------------- -# Property-based test: per-camera vs naive slice should match -# ---------------------------------------------------------------------------- - - -@settings(max_examples=3, deadline=None, suppress_health_check=[HealthCheck.too_slow]) -@given( - T=st.integers(min_value=1, max_value=4), - D=st.integers(min_value=1, max_value=4), - J=st.integers(min_value=5, max_value=15), - seed=st.integers(min_value=0, max_value=10000), -) -def test_per_camera_matches_naive(T, D, J, seed): - rng = np.random.default_rng(seed) - - cam = _make_dummy_camera("C0", rng) - - trackings = _make_trackings(rng, cam, T, J) - detections = _make_detections(rng, cam, D, J) - - # Parameters - W_2D = 1.0 - ALPHA_2D = 1.0 - LAMBDA_A = 0.1 - W_3D = 1.0 - ALPHA_3D = 1.0 - - # Compute per-camera affinity (fast) - A_fast = calculate_camera_affinity_matrix( - trackings, - detections, - w_2d=W_2D, - alpha_2d=ALPHA_2D, - w_3d=W_3D, - alpha_3d=ALPHA_3D, - lambda_a=LAMBDA_A, - ) - - # Compute naive multi-camera affinity and slice out this camera - from collections import OrderedDict - - det_dict = OrderedDict({"C0": detections}) - A_naive, _ = calculate_affinity_matrix( - trackings, - det_dict, - w_2d=W_2D, - alpha_2d=ALPHA_2D, - w_3d=W_3D, - alpha_3d=ALPHA_3D, - lambda_a=LAMBDA_A, - ) - # both fast and naive implementation gives NaN - # we need to inject real-world data - - # print("A_fast") - # print(A_fast) - # print("A_naive") - # print(A_naive) - - # They should be close - np.testing.assert_allclose(A_fast, np.asarray(A_naive), rtol=1e-5, atol=1e-5) - - -@pytest.mark.parametrize("T,D,J", [(2, 3, 10), (4, 4, 15), (6, 8, 20)]) -def test_benchmark_affinity_matrix(T, D, J): - """Compare performance between naive and fast affinity matrix calculation.""" - seed = 42 - rng = np.random.default_rng(seed) - cam = _make_dummy_camera("C0", rng) - - trackings = _make_trackings(rng, cam, T, J) - detections = _make_detections(rng, cam, D, J) - - # Parameters - w_2d = 1.0 - alpha_2d = 1.0 - w_3d = 1.0 - alpha_3d = 1.0 - lambda_a = 0.1 - - # Setup for naive - from collections import OrderedDict - - det_dict = OrderedDict({"C0": detections}) - - # First run to compile - A_fast = calculate_camera_affinity_matrix( - trackings, - detections, - w_2d=w_2d, - alpha_2d=alpha_2d, - w_3d=w_3d, - alpha_3d=alpha_3d, - lambda_a=lambda_a, - ) - A_naive, _ = calculate_affinity_matrix( - trackings, - det_dict, - w_2d=w_2d, - alpha_2d=alpha_2d, - w_3d=w_3d, - alpha_3d=alpha_3d, - lambda_a=lambda_a, - ) - - # Assert they match before timing - np.testing.assert_allclose(A_fast, np.asarray(A_naive), rtol=1e-5, atol=1e-5) - - # Timing - num_runs = 3 - - # Time the vectorized version - start = time.perf_counter() - for _ in range(num_runs): - calculate_camera_affinity_matrix( - trackings, - detections, - w_2d=w_2d, - alpha_2d=alpha_2d, - w_3d=w_3d, - alpha_3d=alpha_3d, - lambda_a=lambda_a, - ) - end = time.perf_counter() - vectorized_time = (end - start) / num_runs - - # Time the naive version - start = time.perf_counter() - for _ in range(num_runs): - calculate_affinity_matrix( - trackings, - det_dict, - w_2d=w_2d, - alpha_2d=alpha_2d, - w_3d=w_3d, - alpha_3d=alpha_3d, - lambda_a=lambda_a, - ) - end = time.perf_counter() - naive_time = (end - start) / num_runs - - speedup = naive_time / vectorized_time - print(f"\nBenchmark T={T}, D={D}, J={J}:") - print(f" Vectorized: {vectorized_time*1000:.2f}ms per run") - print(f" Naive: {naive_time*1000:.2f}ms per run") - print(f" Speedup: {speedup:.2f}x") - - # Sanity check - vectorized should be faster! - assert speedup > 1.0, "Vectorized implementation should be faster" - - -if __name__ == "__main__" and pytest is not None: - # python -m pytest -xvs -k test_benchmark - # pytest.main([__file__]) - pytest.main(["-xvs", __file__, "-k", "test_benchmark"]) diff --git a/uv.lock b/uv.lock index 255be76..ddba9a0 100644 --- a/uv.lock +++ b/uv.lock @@ -16,6 +16,15 @@ resolution-markers = [ "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')", ] +[[package]] +name = "absl-py" +version = "2.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/f0/e6342091061ed3a46aadc116b13edd7bb5249c3ab1b3ef07f24b0c248fc3/absl_py-2.2.2.tar.gz", hash = "sha256:bf25b2c2eed013ca456918c453d687eab4e8309fba81ee2f4c1a6aa2494175eb", size = 119982 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/d4/349f7f4bd5ea92dab34f5bb0fe31775ef6c311427a14d5a5b31ecb442341/absl_py-2.2.2-py3-none-any.whl", hash = "sha256:e5797bc6abe45f64fd95dc06394ca3f2bedf3b5d895e9da691c9ee3397d70092", size = 135565 }, +] + [[package]] name = "anyio" version = "4.8.0" @@ -354,6 +363,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767 }, ] +[[package]] +name = "chex" +version = "0.1.89" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "jax" }, + { name = "jaxlib" }, + { name = "numpy" }, + { name = "setuptools", marker = "python_full_version >= '3.12'" }, + { name = "toolz" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ca/ac/504a8019f7ef372fc6cc3999ec9e3d0fbb38e6992f55d845d5b928010c11/chex-0.1.89.tar.gz", hash = "sha256:78f856e6a0a8459edfcbb402c2c044d2b8102eac4b633838cbdfdcdb09c6c8e0", size = 90676 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5e/6c/309972937d931069816dc8b28193a650485bc35cca92c04c8c15c4bd181e/chex-0.1.89-py3-none-any.whl", hash = "sha256:145241c27d8944adb634fb7d472a460e1c1b643f561507d4031ad5156ef82dfa", size = 99908 }, +] + [[package]] name = "colorama" version = "0.4.6" @@ -454,6 +481,7 @@ dependencies = [ { name = "jupytext" }, { name = "matplotlib" }, { name = "opencv-python-headless" }, + { name = "optax" }, { name = "orjson" }, { name = "pandas" }, { name = "plotly" }, @@ -482,6 +510,7 @@ requires-dist = [ { name = "jupytext", specifier = ">=1.17.0" }, { name = "matplotlib", specifier = ">=3.10.1" }, { name = "opencv-python-headless", specifier = ">=4.11.0.86" }, + { name = "optax", specifier = ">=0.2.4" }, { name = "orjson", specifier = ">=3.10.15" }, { name = "pandas", specifier = ">=2.2.3" }, { name = "plotly", specifier = ">=6.0.1" }, @@ -583,6 +612,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604 }, ] +[[package]] +name = "etils" +version = "1.12.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e4/12/1cc11e88a0201280ff389bc4076df7c3432e39d9f22cba8b71aa263f67b8/etils-1.12.2.tar.gz", hash = "sha256:c6b9e1f0ce66d1bbf54f99201b08a60ba396d3446d9eb18d4bc39b26a2e1a5ee", size = 104711 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dd/71/40ee142e564b8a34a7ae9546e99e665e0001011a3254d5bbbe113d72ccba/etils-1.12.2-py3-none-any.whl", hash = "sha256:4600bec9de6cf5cb043a171e1856e38b5f273719cf3ecef90199f7091a6b3912", size = 167613 }, +] + +[package.optional-dependencies] +epy = [ + { name = "typing-extensions" }, +] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -1925,6 +1968,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932 }, ] +[[package]] +name = "optax" +version = "0.2.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "chex" }, + { name = "etils", extra = ["epy"] }, + { name = "jax" }, + { name = "jaxlib" }, + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/af/b5/f88a0d851547b2e6b2c7e7e6509ad66236b3e7019f1f095bb03dbaa61fa1/optax-0.2.4.tar.gz", hash = "sha256:4e05d3d5307e6dde4c319187ae36e6cd3a0c035d4ed25e9e992449a304f47336", size = 229717 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/24/28d0bb21600a78e46754947333ec9a297044af884d360092eb8561575fe9/optax-0.2.4-py3-none-any.whl", hash = "sha256:db35c04e50b52596662efb002334de08c2a0a74971e4da33f467e84fac08886a", size = 319212 }, +] + [[package]] name = "orjson" version = "3.10.15" @@ -2834,6 +2894,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257 }, ] +[[package]] +name = "toolz" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8a/0b/d80dfa675bf592f636d1ea0b835eab4ec8df6e9415d8cfd766df54456123/toolz-1.0.0.tar.gz", hash = "sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02", size = 66790 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/03/98/eb27cc78ad3af8e302c9d8ff4977f5026676e130d28dd7578132a457170c/toolz-1.0.0-py3-none-any.whl", hash = "sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236", size = 56383 }, +] + [[package]] name = "torch" version = "2.6.0"