forked from HQU-gxy/CVTH3PE
refactor: Update affinity matrix calculation and dependencies
- Replaced the `linear_sum_assignment` import from `scipy.optimize` with `hungarian_algorithm` from `optax` to enhance performance in affinity matrix calculations. - Introduced a new `AffinityResult` class to encapsulate results of affinity computations, including trackings and detections, improving the structure of the affinity calculation process. - Removed deprecated functions and debug print statements to streamline the codebase. - Updated `pyproject.toml` and `uv.lock` to include `optax` as a dependency, ensuring compatibility with the new implementation. - Refactored imports and type hints for better organization and consistency across the codebase.
This commit is contained in:
@ -2,8 +2,10 @@ from dataclasses import dataclass
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
Callable,
|
||||||
Generator,
|
Generator,
|
||||||
Optional,
|
Optional,
|
||||||
|
Sequence,
|
||||||
TypeAlias,
|
TypeAlias,
|
||||||
TypedDict,
|
TypedDict,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
@ -14,7 +16,11 @@ from typing import (
|
|||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from beartype import beartype
|
from beartype import beartype
|
||||||
from jaxtyping import Array, Float, jaxtyped
|
from beartype.typing import Mapping, Sequence
|
||||||
|
from jax import Array
|
||||||
|
from jaxtyping import Array, Float, Int, jaxtyped
|
||||||
|
|
||||||
|
from app.camera import Detection
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
@ -67,3 +73,30 @@ class Tracking:
|
|||||||
# Step 2 – pure JAX math
|
# Step 2 – pure JAX math
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
return self.keypoints + velocity * delta_t_s
|
return self.keypoints + velocity * delta_t_s
|
||||||
|
|
||||||
|
|
||||||
|
@jaxtyped(typechecker=beartype)
|
||||||
|
@dataclass
|
||||||
|
class AffinityResult:
|
||||||
|
"""
|
||||||
|
Result of affinity computation between trackings and detections.
|
||||||
|
"""
|
||||||
|
|
||||||
|
matrix: Float[Array, "T D"]
|
||||||
|
trackings: Sequence[Tracking]
|
||||||
|
detections: Sequence[Detection]
|
||||||
|
indices_T: Int[Array, "T"] # pylint: disable=invalid-name
|
||||||
|
indices_D: Int[Array, "D"] # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
def tracking_detections(
|
||||||
|
self,
|
||||||
|
) -> Generator[tuple[float, Tracking, Detection], None, None]:
|
||||||
|
"""
|
||||||
|
iterate over the best matching trackings and detections
|
||||||
|
"""
|
||||||
|
for t, d in zip(self.indices_T, self.indices_D):
|
||||||
|
yield (
|
||||||
|
self.matrix[t, d].item(),
|
||||||
|
self.trackings[t],
|
||||||
|
self.detections[d],
|
||||||
|
)
|
||||||
|
|||||||
@ -1,37 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from typing import Sequence, Callable, Generator
|
|
||||||
|
|
||||||
from app.camera import Detection
|
|
||||||
from . import Tracking
|
|
||||||
from beartype.typing import Sequence, Mapping
|
|
||||||
from jaxtyping import jaxtyped, Float, Int
|
|
||||||
from jax import Array
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AffinityResult:
|
|
||||||
"""
|
|
||||||
Result of affinity computation between trackings and detections.
|
|
||||||
"""
|
|
||||||
|
|
||||||
matrix: Float[Array, "T D"]
|
|
||||||
"""
|
|
||||||
Affinity matrix between trackings and detections.
|
|
||||||
"""
|
|
||||||
|
|
||||||
trackings: Sequence[Tracking]
|
|
||||||
"""
|
|
||||||
Trackings used to compute the affinity matrix.
|
|
||||||
"""
|
|
||||||
|
|
||||||
detections: Sequence[Detection]
|
|
||||||
"""
|
|
||||||
Detections used to compute the affinity matrix.
|
|
||||||
"""
|
|
||||||
|
|
||||||
indices_T: Sequence[int]
|
|
||||||
indices_D: Sequence[int]
|
|
||||||
|
|
||||||
def tracking_detections(self) -> Generator[tuple[Tracking, Detection]]:
|
|
||||||
for t, d in zip(self.indices_T, self.indices_D):
|
|
||||||
yield (self.trackings[t], self.detections[d])
|
|
||||||
315
playground.py
315
playground.py
@ -45,7 +45,7 @@ from IPython.display import display
|
|||||||
from jaxtyping import Array, Float, Num, jaxtyped
|
from jaxtyping import Array, Float, Num, jaxtyped
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from numpy.typing import ArrayLike
|
from numpy.typing import ArrayLike
|
||||||
from scipy.optimize import linear_sum_assignment
|
from optax.assignment import hungarian_algorithm as linear_sum_assignment
|
||||||
from scipy.spatial.transform import Rotation as R
|
from scipy.spatial.transform import Rotation as R
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ from app.camera import (
|
|||||||
classify_by_camera,
|
classify_by_camera,
|
||||||
)
|
)
|
||||||
from app.solver._old import GLPKSolver
|
from app.solver._old import GLPKSolver
|
||||||
from app.tracking import Tracking
|
from app.tracking import AffinityResult, Tracking
|
||||||
from app.visualize.whole_body import visualize_whole_body
|
from app.visualize.whole_body import visualize_whole_body
|
||||||
|
|
||||||
NDArray: TypeAlias = np.ndarray
|
NDArray: TypeAlias = np.ndarray
|
||||||
@ -69,12 +69,6 @@ AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parq
|
|||||||
DELTA_T_MIN = timedelta(milliseconds=10)
|
DELTA_T_MIN = timedelta(milliseconds=10)
|
||||||
display(AK_CAMERA_DATASET)
|
display(AK_CAMERA_DATASET)
|
||||||
|
|
||||||
_DEBUG_CURRENT_TRACKING: tuple[int, int] = (0, 0)
|
|
||||||
|
|
||||||
|
|
||||||
def _global_current_tracking_str():
|
|
||||||
return str(_DEBUG_CURRENT_TRACKING)
|
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
class Resolution(TypedDict):
|
class Resolution(TypedDict):
|
||||||
@ -594,23 +588,6 @@ def calculate_distance_2d(
|
|||||||
left_normalized = left / jnp.array([w, h])
|
left_normalized = left / jnp.array([w, h])
|
||||||
right_normalized = right / jnp.array([w, h])
|
right_normalized = right / jnp.array([w, h])
|
||||||
dist = jnp.linalg.norm(left_normalized - right_normalized, axis=-1)
|
dist = jnp.linalg.norm(left_normalized - right_normalized, axis=-1)
|
||||||
lt = left_normalized[:6]
|
|
||||||
rt = right_normalized[:6]
|
|
||||||
jax.debug.print(
|
|
||||||
"[REF]{} norm_trk first6 = {}",
|
|
||||||
_global_current_tracking_str(),
|
|
||||||
lt,
|
|
||||||
)
|
|
||||||
jax.debug.print(
|
|
||||||
"[REF]{} norm_det first6 = {}",
|
|
||||||
_global_current_tracking_str(),
|
|
||||||
rt,
|
|
||||||
)
|
|
||||||
jax.debug.print(
|
|
||||||
"[REF]{} dist2d first6 = {}",
|
|
||||||
_global_current_tracking_str(),
|
|
||||||
dist[:6],
|
|
||||||
)
|
|
||||||
return dist
|
return dist
|
||||||
|
|
||||||
|
|
||||||
@ -806,191 +783,12 @@ def calculate_tracking_detection_affinity(
|
|||||||
lambda_a=lambda_a,
|
lambda_a=lambda_a,
|
||||||
)
|
)
|
||||||
|
|
||||||
jax.debug.print(
|
|
||||||
"[REF] aff2d{} first6 = {}",
|
|
||||||
_global_current_tracking_str(),
|
|
||||||
affinity_2d[:6],
|
|
||||||
)
|
|
||||||
jax.debug.print(
|
|
||||||
"[REF] aff3d{} first6 = {}", _global_current_tracking_str(), affinity_3d[:6]
|
|
||||||
)
|
|
||||||
jax.debug.print(
|
|
||||||
"[REF] aff2d.shape={}; aff3d.shape={}",
|
|
||||||
affinity_2d.shape,
|
|
||||||
affinity_3d.shape,
|
|
||||||
)
|
|
||||||
# Combine affinities
|
# Combine affinities
|
||||||
total_affinity = affinity_2d + affinity_3d
|
total_affinity = affinity_2d + affinity_3d
|
||||||
return jnp.sum(total_affinity).item()
|
return jnp.sum(total_affinity).item()
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
@deprecated(
|
|
||||||
"Use `calculate_camera_affinity_matrix` instead. This implementation has the problem of under-utilizing views from different cameras."
|
|
||||||
)
|
|
||||||
@beartype
|
|
||||||
def calculate_affinity_matrix(
|
|
||||||
trackings: Sequence[Tracking],
|
|
||||||
detections: Sequence[Detection] | OrderedDict[CameraID, list[Detection]],
|
|
||||||
w_2d: float,
|
|
||||||
alpha_2d: float,
|
|
||||||
w_3d: float,
|
|
||||||
alpha_3d: float,
|
|
||||||
lambda_a: float,
|
|
||||||
) -> tuple[Float[Array, "T D"], OrderedDict[CameraID, list[Detection]]]:
|
|
||||||
"""
|
|
||||||
Calculate the affinity matrix between a set of trackings and detections.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
trackings: Sequence of tracking objects
|
|
||||||
detections: Sequence of detection objects
|
|
||||||
w_2d: Weight for 2D affinity
|
|
||||||
alpha_2d: Normalization factor for 2D distance
|
|
||||||
w_3d: Weight for 3D affinity
|
|
||||||
alpha_3d: Normalization factor for 3D distance
|
|
||||||
lambda_a: Decay rate for time difference
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- affinity matrix of shape (T, D) where T is number of trackings and D
|
|
||||||
is number of detections
|
|
||||||
- dictionary mapping camera IDs to lists of detections from that camera,
|
|
||||||
since it's a `OrderDict` you could flat it out to get the indices of
|
|
||||||
detections in the affinity matrix
|
|
||||||
|
|
||||||
Matrix Layout:
|
|
||||||
The affinity matrix has shape (T, D), where:
|
|
||||||
- T = number of trackings (rows)
|
|
||||||
- D = total number of detections across all cameras (columns)
|
|
||||||
|
|
||||||
The matrix is organized as follows:
|
|
||||||
|
|
||||||
```
|
|
||||||
| Camera 1 | Camera 2 | Camera c |
|
|
||||||
| d1 d2 ... | d1 d2 ... | d1 d2 ... |
|
|
||||||
---------+-------------+-------------+-------------+
|
|
||||||
Track 1 | a11 a12 ... | a11 a12 ... | a11 a12 ... |
|
|
||||||
Track 2 | a21 a22 ... | a21 a22 ... | a21 a22 ... |
|
|
||||||
... | ... | ... | ... |
|
|
||||||
Track t | at1 at2 ... | at1 at2 ... | at1 at2 ... |
|
|
||||||
```
|
|
||||||
|
|
||||||
Where:
|
|
||||||
- Rows are ordered by tracking ID
|
|
||||||
- Columns are ordered by camera, then by detection within each camera
|
|
||||||
- Each cell aij represents the affinity between tracking i and detection j
|
|
||||||
|
|
||||||
The detection ordering in columns follows the exact same order as iterating
|
|
||||||
through the detection_by_camera dictionary, which is returned alongside
|
|
||||||
the matrix to maintain this relationship.
|
|
||||||
"""
|
|
||||||
if isinstance(detections, OrderedDict):
|
|
||||||
D = flatten_values_len(detections)
|
|
||||||
affinity = jnp.zeros((len(trackings), D))
|
|
||||||
detection_by_camera = detections
|
|
||||||
else:
|
|
||||||
affinity = jnp.zeros((len(trackings), len(detections)))
|
|
||||||
detection_by_camera = classify_by_camera(detections)
|
|
||||||
|
|
||||||
for i, tracking in enumerate(trackings):
|
|
||||||
j = 0
|
|
||||||
for _, camera_detections in detection_by_camera.items():
|
|
||||||
for det in camera_detections:
|
|
||||||
affinity_value = calculate_tracking_detection_affinity(
|
|
||||||
tracking,
|
|
||||||
det,
|
|
||||||
w_2d=w_2d,
|
|
||||||
alpha_2d=alpha_2d,
|
|
||||||
w_3d=w_3d,
|
|
||||||
alpha_3d=alpha_3d,
|
|
||||||
lambda_a=lambda_a,
|
|
||||||
)
|
|
||||||
affinity = affinity.at[i, j].set(affinity_value)
|
|
||||||
j += 1
|
|
||||||
|
|
||||||
return affinity, detection_by_camera
|
|
||||||
|
|
||||||
|
|
||||||
@beartype
|
|
||||||
def calculate_camera_affinity_matrix(
|
|
||||||
trackings: Sequence[Tracking],
|
|
||||||
camera_detections: Sequence[Detection],
|
|
||||||
w_2d: float,
|
|
||||||
alpha_2d: float,
|
|
||||||
w_3d: float,
|
|
||||||
alpha_3d: float,
|
|
||||||
lambda_a: float,
|
|
||||||
) -> Float[Array, "T D"]:
|
|
||||||
"""
|
|
||||||
Calculate an affinity matrix between trackings and detections from a single camera.
|
|
||||||
|
|
||||||
This follows the iterative camera-by-camera approach from the paper
|
|
||||||
"Cross-View Tracking for Multi-Human 3D Pose Estimation at over 100 FPS".
|
|
||||||
Instead of creating one large matrix for all cameras, this creates
|
|
||||||
a separate matrix for each camera, which can be processed independently.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
trackings: Sequence of tracking objects
|
|
||||||
camera_detections: Sequence of detection objects, from the same camera
|
|
||||||
w_2d: Weight for 2D affinity
|
|
||||||
alpha_2d: Normalization factor for 2D distance
|
|
||||||
w_3d: Weight for 3D affinity
|
|
||||||
alpha_3d: Normalization factor for 3D distance
|
|
||||||
lambda_a: Decay rate for time difference
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Affinity matrix of shape (T, D) where:
|
|
||||||
- T = number of trackings (rows)
|
|
||||||
- D = number of detections from this specific camera (columns)
|
|
||||||
|
|
||||||
Matrix Layout:
|
|
||||||
The affinity matrix for a single camera has shape (T, D), where:
|
|
||||||
- T = number of trackings (rows)
|
|
||||||
- D = number of detections from this camera (columns)
|
|
||||||
|
|
||||||
The matrix is organized as follows:
|
|
||||||
|
|
||||||
```
|
|
||||||
| Detections from Camera c |
|
|
||||||
| d1 d2 d3 ... |
|
|
||||||
---------+------------------------+
|
|
||||||
Track 1 | a11 a12 a13 ... |
|
|
||||||
Track 2 | a21 a22 a23 ... |
|
|
||||||
... | ... ... ... ... |
|
|
||||||
Track t | at1 at2 at3 ... |
|
|
||||||
```
|
|
||||||
|
|
||||||
Each cell aij represents the affinity between tracking i and detection j,
|
|
||||||
computed using both 2D and 3D geometric correspondences.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def verify_all_detection_from_same_camera(detections: Sequence[Detection]):
|
|
||||||
if not detections:
|
|
||||||
return True
|
|
||||||
camera_id = next(iter(detections)).camera.id
|
|
||||||
return all(map(lambda d: d.camera.id == camera_id, detections))
|
|
||||||
|
|
||||||
if not verify_all_detection_from_same_camera(camera_detections):
|
|
||||||
raise ValueError("All detections must be from the same camera")
|
|
||||||
|
|
||||||
affinity = jnp.zeros((len(trackings), len(camera_detections)))
|
|
||||||
|
|
||||||
for i, tracking in enumerate(trackings):
|
|
||||||
for j, det in enumerate(camera_detections):
|
|
||||||
global _DEBUG_CURRENT_TRACKING
|
|
||||||
_DEBUG_CURRENT_TRACKING = (i, j)
|
|
||||||
affinity_value = calculate_tracking_detection_affinity(
|
|
||||||
tracking,
|
|
||||||
det,
|
|
||||||
w_2d=w_2d,
|
|
||||||
alpha_2d=alpha_2d,
|
|
||||||
w_3d=w_3d,
|
|
||||||
alpha_3d=alpha_3d,
|
|
||||||
lambda_a=lambda_a,
|
|
||||||
)
|
|
||||||
affinity = affinity.at[i, j].set(affinity_value)
|
|
||||||
return affinity
|
|
||||||
|
|
||||||
|
|
||||||
@beartype
|
@beartype
|
||||||
def calculate_camera_affinity_matrix_jax(
|
def calculate_camera_affinity_matrix_jax(
|
||||||
trackings: Sequence[Tracking],
|
trackings: Sequence[Tracking],
|
||||||
@ -1010,9 +808,6 @@ def calculate_camera_affinity_matrix_jax(
|
|||||||
(tracking, detection) pair. The mathematical definition of the affinity
|
(tracking, detection) pair. The mathematical definition of the affinity
|
||||||
is **unchanged**, so the result remains bit-identical to the reference
|
is **unchanged**, so the result remains bit-identical to the reference
|
||||||
implementation used in the tests.
|
implementation used in the tests.
|
||||||
|
|
||||||
TODO: It gives a wrong result (maybe it's my problem?) somehow,
|
|
||||||
and I need to find a way to debug this.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@ -1052,8 +847,8 @@ def calculate_camera_affinity_matrix_jax(
|
|||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Compute Δt matrix – shape (T, D)
|
# Compute Δt matrix – shape (T, D)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out
|
# Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out
|
||||||
# sub‑second detail (resolution ≈ 200 ms). Keep them in float64 until
|
# sub‑second detail (resolution ≈ 200 ms). Keep them in float64 until
|
||||||
# after subtraction so we preserve Δt‑on‑the‑order‑of‑milliseconds.
|
# after subtraction so we preserve Δt‑on‑the‑order‑of‑milliseconds.
|
||||||
# --- timestamps ----------
|
# --- timestamps ----------
|
||||||
t0 = min(
|
t0 = min(
|
||||||
@ -1093,12 +888,6 @@ def calculate_camera_affinity_matrix_jax(
|
|||||||
diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
|
diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
|
||||||
dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1)
|
dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1)
|
||||||
|
|
||||||
jax.debug.print(
|
|
||||||
"[JAX] norm_trk[0,0,:6] = {}", norm_trk[0, :, :6] # shape (J,2) 取前6
|
|
||||||
)
|
|
||||||
jax.debug.print("[JAX] norm_det[0,:6] = {}", norm_det[0, :6]) # shape (J,2)
|
|
||||||
jax.debug.print("[JAX] dist2d(T0,D0) first6 = {}", dist2d[0, 0, :6])
|
|
||||||
|
|
||||||
# Compute per-keypoint 2D affinity
|
# Compute per-keypoint 2D affinity
|
||||||
delta_t_broadcast = delta_t[:, :, None] # (T, D, 1)
|
delta_t_broadcast = delta_t[:, :, None] # (T, D, 1)
|
||||||
affinity_2d = (
|
affinity_2d = (
|
||||||
@ -1155,11 +944,6 @@ def calculate_camera_affinity_matrix_jax(
|
|||||||
w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
|
w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
|
||||||
)
|
)
|
||||||
|
|
||||||
jax.debug.print("[JAX] aff3d(T0,D0) first6 = {}", affinity_3d[0, 0, :6])
|
|
||||||
jax.debug.print("[JAX] aff2d(T0,D0) first6 = {}", affinity_2d[0, 0, :6])
|
|
||||||
jax.debug.print(
|
|
||||||
"[JAX] aff2d.shape={}; aff3d.shape={}", affinity_2d.shape, affinity_3d.shape
|
|
||||||
)
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Combine and reduce across keypoints → (T, D)
|
# Combine and reduce across keypoints → (T, D)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@ -1167,60 +951,57 @@ def calculate_camera_affinity_matrix_jax(
|
|||||||
return total_affinity # type: ignore[return-value]
|
return total_affinity # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Debug helper – compare JAX vs reference implementation
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
@beartype
|
@beartype
|
||||||
def debug_compare_affinity_matrices(
|
def calculate_affinity_matrix(
|
||||||
trackings: Sequence[Tracking],
|
trackings: Sequence[Tracking],
|
||||||
camera_detections: Sequence[Detection],
|
detections: Sequence[Detection] | Mapping[CameraID, list[Detection]],
|
||||||
*,
|
|
||||||
w_2d: float,
|
w_2d: float,
|
||||||
alpha_2d: float,
|
alpha_2d: float,
|
||||||
w_3d: float,
|
w_3d: float,
|
||||||
alpha_3d: float,
|
alpha_3d: float,
|
||||||
lambda_a: float,
|
lambda_a: float,
|
||||||
atol: float = 1e-5,
|
) -> dict[CameraID, AffinityResult]:
|
||||||
rtol: float = 1e-3,
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Compute both affinity matrices and print out the max absolute / relative
|
Calculate the affinity matrix between a set of trackings and detections.
|
||||||
difference. If any entry differs more than atol+rtol*|ref|, dump the
|
|
||||||
offending indices so you can inspect individual terms.
|
Args:
|
||||||
|
trackings: Sequence of tracking objects
|
||||||
|
detections: Sequence of detection objects or a group detections by ID
|
||||||
|
w_2d: Weight for 2D affinity
|
||||||
|
alpha_2d: Normalization factor for 2D distance
|
||||||
|
w_3d: Weight for 3D affinity
|
||||||
|
alpha_3d: Normalization factor for 3D distance
|
||||||
|
lambda_a: Decay rate for time difference
|
||||||
|
Returns:
|
||||||
|
A dictionary mapping camera IDs to affinity results.
|
||||||
"""
|
"""
|
||||||
aff_jax = calculate_camera_affinity_matrix_jax(
|
if isinstance(detections, Mapping):
|
||||||
trackings,
|
detection_by_camera = detections
|
||||||
camera_detections,
|
|
||||||
w_2d=w_2d,
|
|
||||||
alpha_2d=alpha_2d,
|
|
||||||
w_3d=w_3d,
|
|
||||||
alpha_3d=alpha_3d,
|
|
||||||
lambda_a=lambda_a,
|
|
||||||
)
|
|
||||||
aff_ref = calculate_camera_affinity_matrix(
|
|
||||||
trackings,
|
|
||||||
camera_detections,
|
|
||||||
w_2d=w_2d,
|
|
||||||
alpha_2d=alpha_2d,
|
|
||||||
w_3d=w_3d,
|
|
||||||
alpha_3d=alpha_3d,
|
|
||||||
lambda_a=lambda_a,
|
|
||||||
)
|
|
||||||
|
|
||||||
diff = jnp.abs(aff_jax - aff_ref)
|
|
||||||
max_abs = float(diff.max())
|
|
||||||
max_rel = float((diff / (jnp.abs(aff_ref) + 1e-12)).max())
|
|
||||||
jax.debug.print(f"[DEBUG] max abs diff {max_abs:.6g}, max rel diff {max_rel:.6g}")
|
|
||||||
|
|
||||||
bad = jnp.where(diff > atol + rtol * jnp.abs(aff_ref))
|
|
||||||
if bad[0].size > 0:
|
|
||||||
for t, d in zip(*[arr.tolist() for arr in bad]):
|
|
||||||
jax.debug.print(
|
|
||||||
f" ↳ mismatch at (T={t}, D={d}): "
|
|
||||||
f"jax={aff_jax[t,d]:.6g}, ref={aff_ref[t,d]:.6g}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
jax.debug.print("✅ matrices match within tolerance")
|
detection_by_camera = classify_by_camera(detections)
|
||||||
|
|
||||||
|
res: dict[CameraID, AffinityResult] = {}
|
||||||
|
for camera_id, camera_detections in detection_by_camera.items():
|
||||||
|
affinity_matrix = calculate_camera_affinity_matrix_jax(
|
||||||
|
trackings,
|
||||||
|
camera_detections,
|
||||||
|
w_2d,
|
||||||
|
alpha_2d,
|
||||||
|
w_3d,
|
||||||
|
alpha_3d,
|
||||||
|
lambda_a,
|
||||||
|
)
|
||||||
|
# row, col
|
||||||
|
indices_T, indices_D = linear_sum_assignment(affinity_matrix)
|
||||||
|
affinity_result = AffinityResult(
|
||||||
|
matrix=affinity_matrix,
|
||||||
|
trackings=trackings,
|
||||||
|
detections=camera_detections,
|
||||||
|
indices_T=indices_T,
|
||||||
|
indices_D=indices_D,
|
||||||
|
)
|
||||||
|
res[camera_id] = affinity_result
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
@ -1235,15 +1016,15 @@ trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
|
|||||||
unmatched_detections = shallow_copy(next_group)
|
unmatched_detections = shallow_copy(next_group)
|
||||||
camera_detections = classify_by_camera(unmatched_detections)
|
camera_detections = classify_by_camera(unmatched_detections)
|
||||||
|
|
||||||
camera_detections_next_batch = camera_detections["AE_08"]
|
affinities = calculate_affinity_matrix(
|
||||||
debug_compare_affinity_matrices(
|
|
||||||
trackings,
|
trackings,
|
||||||
camera_detections_next_batch,
|
unmatched_detections,
|
||||||
w_2d=W_2D,
|
w_2d=W_2D,
|
||||||
alpha_2d=ALPHA_2D,
|
alpha_2d=ALPHA_2D,
|
||||||
w_3d=W_3D,
|
w_3d=W_3D,
|
||||||
alpha_3d=ALPHA_3D,
|
alpha_3d=ALPHA_3D,
|
||||||
lambda_a=LAMBDA_A,
|
lambda_a=LAMBDA_A,
|
||||||
)
|
)
|
||||||
|
display(affinities)
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
|||||||
@ -15,6 +15,7 @@ dependencies = [
|
|||||||
"jupytext>=1.17.0",
|
"jupytext>=1.17.0",
|
||||||
"matplotlib>=3.10.1",
|
"matplotlib>=3.10.1",
|
||||||
"opencv-python-headless>=4.11.0.86",
|
"opencv-python-headless>=4.11.0.86",
|
||||||
|
"optax>=0.2.4",
|
||||||
"orjson>=3.10.15",
|
"orjson>=3.10.15",
|
||||||
"pandas>=2.2.3",
|
"pandas>=2.2.3",
|
||||||
"plotly>=6.0.1",
|
"plotly>=6.0.1",
|
||||||
|
|||||||
@ -1,224 +0,0 @@
|
|||||||
from datetime import datetime, timedelta
|
|
||||||
import time
|
|
||||||
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
from hypothesis import given, settings, HealthCheck
|
|
||||||
from hypothesis import strategies as st
|
|
||||||
|
|
||||||
from app.camera import Camera, CameraParams
|
|
||||||
from playground import (
|
|
||||||
Detection,
|
|
||||||
Tracking,
|
|
||||||
calculate_affinity_matrix,
|
|
||||||
calculate_camera_affinity_matrix,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------
|
|
||||||
# Helper functions to generate synthetic cameras / trackings / detections
|
|
||||||
# ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _make_dummy_camera(cam_id: str, rng: np.random.Generator) -> Camera:
|
|
||||||
K = jnp.eye(3)
|
|
||||||
Rt = jnp.eye(4)
|
|
||||||
dist = jnp.zeros(5)
|
|
||||||
image_size = jnp.array([1000, 1000])
|
|
||||||
params = CameraParams(K=K, Rt=Rt, dist_coeffs=dist, image_size=image_size)
|
|
||||||
return Camera(id=cam_id, params=params)
|
|
||||||
|
|
||||||
|
|
||||||
def _random_keypoints_3d(rng: np.random.Generator, J: int):
|
|
||||||
return jnp.asarray(rng.uniform(-1.0, 1.0, size=(J, 3)).astype(np.float32))
|
|
||||||
|
|
||||||
|
|
||||||
def _random_keypoints_2d(rng: np.random.Generator, J: int):
|
|
||||||
return jnp.asarray(rng.uniform(0.0, 1000.0, size=(J, 2)).astype(np.float32))
|
|
||||||
|
|
||||||
|
|
||||||
def _make_trackings(rng: np.random.Generator, camera: Camera, T: int, J: int):
|
|
||||||
now = datetime.now()
|
|
||||||
trackings = []
|
|
||||||
for i in range(T):
|
|
||||||
kps3d = _random_keypoints_3d(rng, J)
|
|
||||||
trk = Tracking(
|
|
||||||
id=i + 1,
|
|
||||||
keypoints=kps3d,
|
|
||||||
last_active_timestamp=now
|
|
||||||
- timedelta(milliseconds=int(rng.integers(20, 50))),
|
|
||||||
)
|
|
||||||
trackings.append(trk)
|
|
||||||
return trackings
|
|
||||||
|
|
||||||
|
|
||||||
def _make_detections(rng: np.random.Generator, camera: Camera, D: int, J: int):
|
|
||||||
now = datetime.now()
|
|
||||||
detections = []
|
|
||||||
for _ in range(D):
|
|
||||||
kps2d = _random_keypoints_2d(rng, J)
|
|
||||||
det = Detection(
|
|
||||||
keypoints=kps2d,
|
|
||||||
confidences=jnp.ones(J, dtype=jnp.float32),
|
|
||||||
camera=camera,
|
|
||||||
timestamp=now,
|
|
||||||
)
|
|
||||||
detections.append(det)
|
|
||||||
return detections
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------
|
|
||||||
# Property-based test: per-camera vs naive slice should match
|
|
||||||
# ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@settings(max_examples=3, deadline=None, suppress_health_check=[HealthCheck.too_slow])
|
|
||||||
@given(
|
|
||||||
T=st.integers(min_value=1, max_value=4),
|
|
||||||
D=st.integers(min_value=1, max_value=4),
|
|
||||||
J=st.integers(min_value=5, max_value=15),
|
|
||||||
seed=st.integers(min_value=0, max_value=10000),
|
|
||||||
)
|
|
||||||
def test_per_camera_matches_naive(T, D, J, seed):
|
|
||||||
rng = np.random.default_rng(seed)
|
|
||||||
|
|
||||||
cam = _make_dummy_camera("C0", rng)
|
|
||||||
|
|
||||||
trackings = _make_trackings(rng, cam, T, J)
|
|
||||||
detections = _make_detections(rng, cam, D, J)
|
|
||||||
|
|
||||||
# Parameters
|
|
||||||
W_2D = 1.0
|
|
||||||
ALPHA_2D = 1.0
|
|
||||||
LAMBDA_A = 0.1
|
|
||||||
W_3D = 1.0
|
|
||||||
ALPHA_3D = 1.0
|
|
||||||
|
|
||||||
# Compute per-camera affinity (fast)
|
|
||||||
A_fast = calculate_camera_affinity_matrix(
|
|
||||||
trackings,
|
|
||||||
detections,
|
|
||||||
w_2d=W_2D,
|
|
||||||
alpha_2d=ALPHA_2D,
|
|
||||||
w_3d=W_3D,
|
|
||||||
alpha_3d=ALPHA_3D,
|
|
||||||
lambda_a=LAMBDA_A,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compute naive multi-camera affinity and slice out this camera
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
det_dict = OrderedDict({"C0": detections})
|
|
||||||
A_naive, _ = calculate_affinity_matrix(
|
|
||||||
trackings,
|
|
||||||
det_dict,
|
|
||||||
w_2d=W_2D,
|
|
||||||
alpha_2d=ALPHA_2D,
|
|
||||||
w_3d=W_3D,
|
|
||||||
alpha_3d=ALPHA_3D,
|
|
||||||
lambda_a=LAMBDA_A,
|
|
||||||
)
|
|
||||||
# both fast and naive implementation gives NaN
|
|
||||||
# we need to inject real-world data
|
|
||||||
|
|
||||||
# print("A_fast")
|
|
||||||
# print(A_fast)
|
|
||||||
# print("A_naive")
|
|
||||||
# print(A_naive)
|
|
||||||
|
|
||||||
# They should be close
|
|
||||||
np.testing.assert_allclose(A_fast, np.asarray(A_naive), rtol=1e-5, atol=1e-5)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("T,D,J", [(2, 3, 10), (4, 4, 15), (6, 8, 20)])
|
|
||||||
def test_benchmark_affinity_matrix(T, D, J):
|
|
||||||
"""Compare performance between naive and fast affinity matrix calculation."""
|
|
||||||
seed = 42
|
|
||||||
rng = np.random.default_rng(seed)
|
|
||||||
cam = _make_dummy_camera("C0", rng)
|
|
||||||
|
|
||||||
trackings = _make_trackings(rng, cam, T, J)
|
|
||||||
detections = _make_detections(rng, cam, D, J)
|
|
||||||
|
|
||||||
# Parameters
|
|
||||||
w_2d = 1.0
|
|
||||||
alpha_2d = 1.0
|
|
||||||
w_3d = 1.0
|
|
||||||
alpha_3d = 1.0
|
|
||||||
lambda_a = 0.1
|
|
||||||
|
|
||||||
# Setup for naive
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
det_dict = OrderedDict({"C0": detections})
|
|
||||||
|
|
||||||
# First run to compile
|
|
||||||
A_fast = calculate_camera_affinity_matrix(
|
|
||||||
trackings,
|
|
||||||
detections,
|
|
||||||
w_2d=w_2d,
|
|
||||||
alpha_2d=alpha_2d,
|
|
||||||
w_3d=w_3d,
|
|
||||||
alpha_3d=alpha_3d,
|
|
||||||
lambda_a=lambda_a,
|
|
||||||
)
|
|
||||||
A_naive, _ = calculate_affinity_matrix(
|
|
||||||
trackings,
|
|
||||||
det_dict,
|
|
||||||
w_2d=w_2d,
|
|
||||||
alpha_2d=alpha_2d,
|
|
||||||
w_3d=w_3d,
|
|
||||||
alpha_3d=alpha_3d,
|
|
||||||
lambda_a=lambda_a,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assert they match before timing
|
|
||||||
np.testing.assert_allclose(A_fast, np.asarray(A_naive), rtol=1e-5, atol=1e-5)
|
|
||||||
|
|
||||||
# Timing
|
|
||||||
num_runs = 3
|
|
||||||
|
|
||||||
# Time the vectorized version
|
|
||||||
start = time.perf_counter()
|
|
||||||
for _ in range(num_runs):
|
|
||||||
calculate_camera_affinity_matrix(
|
|
||||||
trackings,
|
|
||||||
detections,
|
|
||||||
w_2d=w_2d,
|
|
||||||
alpha_2d=alpha_2d,
|
|
||||||
w_3d=w_3d,
|
|
||||||
alpha_3d=alpha_3d,
|
|
||||||
lambda_a=lambda_a,
|
|
||||||
)
|
|
||||||
end = time.perf_counter()
|
|
||||||
vectorized_time = (end - start) / num_runs
|
|
||||||
|
|
||||||
# Time the naive version
|
|
||||||
start = time.perf_counter()
|
|
||||||
for _ in range(num_runs):
|
|
||||||
calculate_affinity_matrix(
|
|
||||||
trackings,
|
|
||||||
det_dict,
|
|
||||||
w_2d=w_2d,
|
|
||||||
alpha_2d=alpha_2d,
|
|
||||||
w_3d=w_3d,
|
|
||||||
alpha_3d=alpha_3d,
|
|
||||||
lambda_a=lambda_a,
|
|
||||||
)
|
|
||||||
end = time.perf_counter()
|
|
||||||
naive_time = (end - start) / num_runs
|
|
||||||
|
|
||||||
speedup = naive_time / vectorized_time
|
|
||||||
print(f"\nBenchmark T={T}, D={D}, J={J}:")
|
|
||||||
print(f" Vectorized: {vectorized_time*1000:.2f}ms per run")
|
|
||||||
print(f" Naive: {naive_time*1000:.2f}ms per run")
|
|
||||||
print(f" Speedup: {speedup:.2f}x")
|
|
||||||
|
|
||||||
# Sanity check - vectorized should be faster!
|
|
||||||
assert speedup > 1.0, "Vectorized implementation should be faster"
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__" and pytest is not None:
|
|
||||||
# python -m pytest -xvs -k test_benchmark
|
|
||||||
# pytest.main([__file__])
|
|
||||||
pytest.main(["-xvs", __file__, "-k", "test_benchmark"])
|
|
||||||
69
uv.lock
generated
69
uv.lock
generated
@ -16,6 +16,15 @@ resolution-markers = [
|
|||||||
"(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
"(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "absl-py"
|
||||||
|
version = "2.2.2"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/b5/f0/e6342091061ed3a46aadc116b13edd7bb5249c3ab1b3ef07f24b0c248fc3/absl_py-2.2.2.tar.gz", hash = "sha256:bf25b2c2eed013ca456918c453d687eab4e8309fba81ee2f4c1a6aa2494175eb", size = 119982 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/f6/d4/349f7f4bd5ea92dab34f5bb0fe31775ef6c311427a14d5a5b31ecb442341/absl_py-2.2.2-py3-none-any.whl", hash = "sha256:e5797bc6abe45f64fd95dc06394ca3f2bedf3b5d895e9da691c9ee3397d70092", size = 135565 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "anyio"
|
name = "anyio"
|
||||||
version = "4.8.0"
|
version = "4.8.0"
|
||||||
@ -354,6 +363,24 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767 },
|
{ url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "chex"
|
||||||
|
version = "0.1.89"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "absl-py" },
|
||||||
|
{ name = "jax" },
|
||||||
|
{ name = "jaxlib" },
|
||||||
|
{ name = "numpy" },
|
||||||
|
{ name = "setuptools", marker = "python_full_version >= '3.12'" },
|
||||||
|
{ name = "toolz" },
|
||||||
|
{ name = "typing-extensions" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/ca/ac/504a8019f7ef372fc6cc3999ec9e3d0fbb38e6992f55d845d5b928010c11/chex-0.1.89.tar.gz", hash = "sha256:78f856e6a0a8459edfcbb402c2c044d2b8102eac4b633838cbdfdcdb09c6c8e0", size = 90676 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/5e/6c/309972937d931069816dc8b28193a650485bc35cca92c04c8c15c4bd181e/chex-0.1.89-py3-none-any.whl", hash = "sha256:145241c27d8944adb634fb7d472a460e1c1b643f561507d4031ad5156ef82dfa", size = 99908 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "colorama"
|
name = "colorama"
|
||||||
version = "0.4.6"
|
version = "0.4.6"
|
||||||
@ -454,6 +481,7 @@ dependencies = [
|
|||||||
{ name = "jupytext" },
|
{ name = "jupytext" },
|
||||||
{ name = "matplotlib" },
|
{ name = "matplotlib" },
|
||||||
{ name = "opencv-python-headless" },
|
{ name = "opencv-python-headless" },
|
||||||
|
{ name = "optax" },
|
||||||
{ name = "orjson" },
|
{ name = "orjson" },
|
||||||
{ name = "pandas" },
|
{ name = "pandas" },
|
||||||
{ name = "plotly" },
|
{ name = "plotly" },
|
||||||
@ -482,6 +510,7 @@ requires-dist = [
|
|||||||
{ name = "jupytext", specifier = ">=1.17.0" },
|
{ name = "jupytext", specifier = ">=1.17.0" },
|
||||||
{ name = "matplotlib", specifier = ">=3.10.1" },
|
{ name = "matplotlib", specifier = ">=3.10.1" },
|
||||||
{ name = "opencv-python-headless", specifier = ">=4.11.0.86" },
|
{ name = "opencv-python-headless", specifier = ">=4.11.0.86" },
|
||||||
|
{ name = "optax", specifier = ">=0.2.4" },
|
||||||
{ name = "orjson", specifier = ">=3.10.15" },
|
{ name = "orjson", specifier = ">=3.10.15" },
|
||||||
{ name = "pandas", specifier = ">=2.2.3" },
|
{ name = "pandas", specifier = ">=2.2.3" },
|
||||||
{ name = "plotly", specifier = ">=6.0.1" },
|
{ name = "plotly", specifier = ">=6.0.1" },
|
||||||
@ -583,6 +612,20 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604 },
|
{ url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "etils"
|
||||||
|
version = "1.12.2"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/e4/12/1cc11e88a0201280ff389bc4076df7c3432e39d9f22cba8b71aa263f67b8/etils-1.12.2.tar.gz", hash = "sha256:c6b9e1f0ce66d1bbf54f99201b08a60ba396d3446d9eb18d4bc39b26a2e1a5ee", size = 104711 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/dd/71/40ee142e564b8a34a7ae9546e99e665e0001011a3254d5bbbe113d72ccba/etils-1.12.2-py3-none-any.whl", hash = "sha256:4600bec9de6cf5cb043a171e1856e38b5f273719cf3ecef90199f7091a6b3912", size = 167613 },
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.optional-dependencies]
|
||||||
|
epy = [
|
||||||
|
{ name = "typing-extensions" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "exceptiongroup"
|
name = "exceptiongroup"
|
||||||
version = "1.2.2"
|
version = "1.2.2"
|
||||||
@ -1925,6 +1968,23 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932 },
|
{ url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "optax"
|
||||||
|
version = "0.2.4"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "absl-py" },
|
||||||
|
{ name = "chex" },
|
||||||
|
{ name = "etils", extra = ["epy"] },
|
||||||
|
{ name = "jax" },
|
||||||
|
{ name = "jaxlib" },
|
||||||
|
{ name = "numpy" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/af/b5/f88a0d851547b2e6b2c7e7e6509ad66236b3e7019f1f095bb03dbaa61fa1/optax-0.2.4.tar.gz", hash = "sha256:4e05d3d5307e6dde4c319187ae36e6cd3a0c035d4ed25e9e992449a304f47336", size = 229717 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/5c/24/28d0bb21600a78e46754947333ec9a297044af884d360092eb8561575fe9/optax-0.2.4-py3-none-any.whl", hash = "sha256:db35c04e50b52596662efb002334de08c2a0a74971e4da33f467e84fac08886a", size = 319212 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "orjson"
|
name = "orjson"
|
||||||
version = "3.10.15"
|
version = "3.10.15"
|
||||||
@ -2834,6 +2894,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257 },
|
{ url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "toolz"
|
||||||
|
version = "1.0.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/8a/0b/d80dfa675bf592f636d1ea0b835eab4ec8df6e9415d8cfd766df54456123/toolz-1.0.0.tar.gz", hash = "sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02", size = 66790 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/03/98/eb27cc78ad3af8e302c9d8ff4977f5026676e130d28dd7578132a457170c/toolz-1.0.0-py3-none-any.whl", hash = "sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236", size = 56383 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "torch"
|
name = "torch"
|
||||||
version = "2.6.0"
|
version = "2.6.0"
|
||||||
|
|||||||
Reference in New Issue
Block a user