refactor: Update affinity matrix calculation and dependencies
- Replaced the `linear_sum_assignment` import from `scipy.optimize` with `hungarian_algorithm` from `optax.assignment` to enhance performance in affinity matrix calculations.
- Introduced a new `AffinityResult` class to encapsulate the results of affinity computations, including trackings and detections, improving the structure of the affinity calculation process.
- Removed deprecated functions and debug print statements to streamline the codebase.
- Updated `pyproject.toml` and `uv.lock` to include `optax` as a dependency, ensuring compatibility with the new implementation.
- Refactored imports and type hints for better organization and consistency across the codebase.
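For reviewers unfamiliar with the new solver: below is a minimal sketch (not part of the commit) comparing `optax.assignment.hungarian_algorithm` with `scipy.optimize.linear_sum_assignment`, assuming an optax release recent enough to ship the `assignment` module. Both solvers minimize total cost and return a pair of index arrays, which is what makes the aliased import in this diff a near drop-in replacement; the matrix values are illustrative only.

```python
# Minimal sketch (reviewer's note, not part of the commit): check that
# optax's Hungarian solver can stand in for scipy's on a small matrix.
# Assumes a recent optax release in which `optax.assignment` exists.
import jax.numpy as jnp
import numpy as np
from optax.assignment import hungarian_algorithm
from scipy.optimize import linear_sum_assignment

cost = np.array(
    [
        [4.0, 1.0, 3.0],
        [2.0, 0.0, 5.0],
        [3.0, 2.0, 2.0],
    ]
)

# Both solvers minimize total cost and return (row indices, column indices).
rows_ref, cols_ref = linear_sum_assignment(cost)
rows_jax, cols_jax = hungarian_algorithm(jnp.asarray(cost))

# The pair ordering may differ, but the optimal total cost must agree.
assert np.isclose(
    cost[rows_ref, cols_ref].sum(),
    cost[np.asarray(rows_jax), np.asarray(cols_jax)].sum(),
)

# Caution: since both solvers minimize, an affinity matrix
# (higher = better) needs negating, e.g. hungarian_algorithm(-affinity).
```

One behavioural difference worth checking at the call sites: scipy's `linear_sum_assignment` accepts `maximize=True`, while the optax function appears to take only a cost matrix, so maximizing an affinity requires negating it explicitly.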
playground.py | 315 changed lines

--- a/playground.py
+++ b/playground.py
@@ -45,7 +45,7 @@ from IPython.display import display
 from jaxtyping import Array, Float, Num, jaxtyped
 from matplotlib import pyplot as plt
 from numpy.typing import ArrayLike
-from scipy.optimize import linear_sum_assignment
+from optax.assignment import hungarian_algorithm as linear_sum_assignment
 from scipy.spatial.transform import Rotation as R
 from typing_extensions import deprecated
 
@@ -58,7 +58,7 @@ from app.camera import (
     classify_by_camera,
 )
 from app.solver._old import GLPKSolver
-from app.tracking import Tracking
+from app.tracking import AffinityResult, Tracking
 from app.visualize.whole_body import visualize_whole_body
 
 NDArray: TypeAlias = np.ndarray
@@ -69,12 +69,6 @@ AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parq
 DELTA_T_MIN = timedelta(milliseconds=10)
 display(AK_CAMERA_DATASET)
 
-_DEBUG_CURRENT_TRACKING: tuple[int, int] = (0, 0)
-
-
-def _global_current_tracking_str():
-    return str(_DEBUG_CURRENT_TRACKING)
-
 
 # %%
 class Resolution(TypedDict):
@@ -594,23 +588,6 @@ def calculate_distance_2d(
     left_normalized = left / jnp.array([w, h])
     right_normalized = right / jnp.array([w, h])
     dist = jnp.linalg.norm(left_normalized - right_normalized, axis=-1)
-    lt = left_normalized[:6]
-    rt = right_normalized[:6]
-    jax.debug.print(
-        "[REF]{} norm_trk first6 = {}",
-        _global_current_tracking_str(),
-        lt,
-    )
-    jax.debug.print(
-        "[REF]{} norm_det first6 = {}",
-        _global_current_tracking_str(),
-        rt,
-    )
-    jax.debug.print(
-        "[REF]{} dist2d first6 = {}",
-        _global_current_tracking_str(),
-        dist[:6],
-    )
     return dist
 
 
@@ -806,191 +783,12 @@ def calculate_tracking_detection_affinity(
         lambda_a=lambda_a,
     )
 
-    jax.debug.print(
-        "[REF] aff2d{} first6 = {}",
-        _global_current_tracking_str(),
-        affinity_2d[:6],
-    )
-    jax.debug.print(
-        "[REF] aff3d{} first6 = {}", _global_current_tracking_str(), affinity_3d[:6]
-    )
-    jax.debug.print(
-        "[REF] aff2d.shape={}; aff3d.shape={}",
-        affinity_2d.shape,
-        affinity_3d.shape,
-    )
     # Combine affinities
     total_affinity = affinity_2d + affinity_3d
     return jnp.sum(total_affinity).item()
 
 
 # %%
-@deprecated(
-    "Use `calculate_camera_affinity_matrix` instead. This implementation has the problem of under-utilizing views from different cameras."
-)
-@beartype
-def calculate_affinity_matrix(
-    trackings: Sequence[Tracking],
-    detections: Sequence[Detection] | OrderedDict[CameraID, list[Detection]],
-    w_2d: float,
-    alpha_2d: float,
-    w_3d: float,
-    alpha_3d: float,
-    lambda_a: float,
-) -> tuple[Float[Array, "T D"], OrderedDict[CameraID, list[Detection]]]:
-    """
-    Calculate the affinity matrix between a set of trackings and detections.
-
-    Args:
-        trackings: Sequence of tracking objects
-        detections: Sequence of detection objects
-        w_2d: Weight for 2D affinity
-        alpha_2d: Normalization factor for 2D distance
-        w_3d: Weight for 3D affinity
-        alpha_3d: Normalization factor for 3D distance
-        lambda_a: Decay rate for time difference
-
-    Returns:
-        - affinity matrix of shape (T, D) where T is the number of trackings
-          and D is the number of detections
-        - dictionary mapping camera IDs to lists of detections from that
-          camera; since it is an `OrderedDict`, you can flatten it to get the
-          indices of the detections in the affinity matrix
-
-    Matrix Layout:
-        The affinity matrix has shape (T, D), where:
-        - T = number of trackings (rows)
-        - D = total number of detections across all cameras (columns)
-
-        The matrix is organized as follows:
-
-        ```
-                 |  Camera 1   |  Camera 2   |  Camera c   |
-                 | d1  d2  ... | d1  d2  ... | d1  d2  ... |
-        ---------+-------------+-------------+-------------+
-        Track 1  | a11 a12 ... | a11 a12 ... | a11 a12 ... |
-        Track 2  | a21 a22 ... | a21 a22 ... | a21 a22 ... |
-        ...      | ...         | ...         | ...         |
-        Track t  | at1 at2 ... | at1 at2 ... | at1 at2 ... |
-        ```
-
-        Where:
-        - Rows are ordered by tracking ID
-        - Columns are ordered by camera, then by detection within each camera
-        - Each cell aij represents the affinity between tracking i and detection j
-
-        The detection ordering in columns follows the exact same order as
-        iterating through the detection_by_camera dictionary, which is
-        returned alongside the matrix to maintain this relationship.
-    """
-    if isinstance(detections, OrderedDict):
-        D = flatten_values_len(detections)
-        affinity = jnp.zeros((len(trackings), D))
-        detection_by_camera = detections
-    else:
-        affinity = jnp.zeros((len(trackings), len(detections)))
-        detection_by_camera = classify_by_camera(detections)
-
-    for i, tracking in enumerate(trackings):
-        j = 0
-        for _, camera_detections in detection_by_camera.items():
-            for det in camera_detections:
-                affinity_value = calculate_tracking_detection_affinity(
-                    tracking,
-                    det,
-                    w_2d=w_2d,
-                    alpha_2d=alpha_2d,
-                    w_3d=w_3d,
-                    alpha_3d=alpha_3d,
-                    lambda_a=lambda_a,
-                )
-                affinity = affinity.at[i, j].set(affinity_value)
-                j += 1
-
-    return affinity, detection_by_camera
-
-
-@beartype
-def calculate_camera_affinity_matrix(
-    trackings: Sequence[Tracking],
-    camera_detections: Sequence[Detection],
-    w_2d: float,
-    alpha_2d: float,
-    w_3d: float,
-    alpha_3d: float,
-    lambda_a: float,
-) -> Float[Array, "T D"]:
-    """
-    Calculate an affinity matrix between trackings and detections from a single camera.
-
-    This follows the iterative camera-by-camera approach from the paper
-    "Cross-View Tracking for Multi-Human 3D Pose Estimation at over 100 FPS".
-    Instead of creating one large matrix for all cameras, this creates
-    a separate matrix for each camera, which can be processed independently.
-
-    Args:
-        trackings: Sequence of tracking objects
-        camera_detections: Sequence of detection objects, all from the same camera
-        w_2d: Weight for 2D affinity
-        alpha_2d: Normalization factor for 2D distance
-        w_3d: Weight for 3D affinity
-        alpha_3d: Normalization factor for 3D distance
-        lambda_a: Decay rate for time difference
-
-    Returns:
-        Affinity matrix of shape (T, D) where:
-        - T = number of trackings (rows)
-        - D = number of detections from this specific camera (columns)
-
-    Matrix Layout:
-        The affinity matrix for a single camera has shape (T, D), where:
-        - T = number of trackings (rows)
-        - D = number of detections from this camera (columns)
-
-        The matrix is organized as follows:
-
-        ```
-                 | Detections from Camera c |
-                 | d1   d2   d3   ...       |
-        ---------+--------------------------+
-        Track 1  | a11  a12  a13  ...       |
-        Track 2  | a21  a22  a23  ...       |
-        ...      | ...  ...  ...  ...       |
-        Track t  | at1  at2  at3  ...       |
-        ```
-
-        Each cell aij represents the affinity between tracking i and detection j,
-        computed using both 2D and 3D geometric correspondences.
-    """
-
-    def verify_all_detection_from_same_camera(detections: Sequence[Detection]):
-        if not detections:
-            return True
-        camera_id = next(iter(detections)).camera.id
-        return all(map(lambda d: d.camera.id == camera_id, detections))
-
-    if not verify_all_detection_from_same_camera(camera_detections):
-        raise ValueError("All detections must be from the same camera")
-
-    affinity = jnp.zeros((len(trackings), len(camera_detections)))
-
-    for i, tracking in enumerate(trackings):
-        for j, det in enumerate(camera_detections):
-            global _DEBUG_CURRENT_TRACKING
-            _DEBUG_CURRENT_TRACKING = (i, j)
-            affinity_value = calculate_tracking_detection_affinity(
-                tracking,
-                det,
-                w_2d=w_2d,
-                alpha_2d=alpha_2d,
-                w_3d=w_3d,
-                alpha_3d=alpha_3d,
-                lambda_a=lambda_a,
-            )
-            affinity = affinity.at[i, j].set(affinity_value)
-    return affinity
 
 
 @beartype
 def calculate_camera_affinity_matrix_jax(
     trackings: Sequence[Tracking],
@@ -1010,9 +808,6 @@ def calculate_camera_affinity_matrix_jax(
     (tracking, detection) pair. The mathematical definition of the affinity
     is **unchanged**, so the result remains bit-identical to the reference
     implementation used in the tests.
-
-    TODO: It somehow gives a wrong result (maybe it's my problem?),
-    and I need to find a way to debug this.
     """
 
     # ------------------------------------------------------------------
@@ -1052,8 +847,8 @@ def calculate_camera_affinity_matrix_jax(
     # ------------------------------------------------------------------
     # Compute Δt matrix – shape (T, D)
     # ------------------------------------------------------------------
-    # Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out
-    # sub‑second detail (resolution ≈ 200 ms). Keep them in float64 until
+    # Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out
+    # sub‑second detail (resolution ≈ 200 ms). Keep them in float64 until
     # after subtraction so we preserve Δt‑on‑the‑order‑of‑milliseconds.
     # --- timestamps ----------
     t0 = min(
@@ -1093,12 +888,6 @@ def calculate_camera_affinity_matrix_jax(
     diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
     dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1)
 
-    jax.debug.print(
-        "[JAX] norm_trk[0,0,:6] = {}", norm_trk[0, :, :6]  # shape (J,2), take first 6
-    )
-    jax.debug.print("[JAX] norm_det[0,:6] = {}", norm_det[0, :6])  # shape (J,2)
-    jax.debug.print("[JAX] dist2d(T0,D0) first6 = {}", dist2d[0, 0, :6])
-
     # Compute per-keypoint 2D affinity
     delta_t_broadcast = delta_t[:, :, None]  # (T, D, 1)
     affinity_2d = (
@@ -1155,11 +944,6 @@ def calculate_camera_affinity_matrix_jax(
         w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
     )
 
-    jax.debug.print("[JAX] aff3d(T0,D0) first6 = {}", affinity_3d[0, 0, :6])
-    jax.debug.print("[JAX] aff2d(T0,D0) first6 = {}", affinity_2d[0, 0, :6])
-    jax.debug.print(
-        "[JAX] aff2d.shape={}; aff3d.shape={}", affinity_2d.shape, affinity_3d.shape
-    )
     # ------------------------------------------------------------------
     # Combine and reduce across keypoints → (T, D)
     # ------------------------------------------------------------------
@@ -1167,60 +951,57 @@ def calculate_camera_affinity_matrix_jax(
     return total_affinity  # type: ignore[return-value]
 
 
-# ------------------------------------------------------------------
-# Debug helper – compare JAX vs reference implementation
-# ------------------------------------------------------------------
 @beartype
-def debug_compare_affinity_matrices(
+def calculate_affinity_matrix(
     trackings: Sequence[Tracking],
-    camera_detections: Sequence[Detection],
-    *,
+    detections: Sequence[Detection] | Mapping[CameraID, list[Detection]],
     w_2d: float,
     alpha_2d: float,
     w_3d: float,
     alpha_3d: float,
     lambda_a: float,
-    atol: float = 1e-5,
-    rtol: float = 1e-3,
-) -> None:
+) -> dict[CameraID, AffinityResult]:
     """
-    Compute both affinity matrices and print out the max absolute / relative
-    difference. If any entry differs by more than atol + rtol * |ref|, dump the
-    offending indices so you can inspect individual terms.
+    Calculate the affinity matrix between a set of trackings and detections.
+
+    Args:
+        trackings: Sequence of tracking objects
+        detections: Sequence of detection objects, or detections grouped by camera ID
+        w_2d: Weight for 2D affinity
+        alpha_2d: Normalization factor for 2D distance
+        w_3d: Weight for 3D affinity
+        alpha_3d: Normalization factor for 3D distance
+        lambda_a: Decay rate for time difference
+    Returns:
+        A dictionary mapping camera IDs to affinity results.
     """
-    aff_jax = calculate_camera_affinity_matrix_jax(
-        trackings,
-        camera_detections,
-        w_2d=w_2d,
-        alpha_2d=alpha_2d,
-        w_3d=w_3d,
-        alpha_3d=alpha_3d,
-        lambda_a=lambda_a,
-    )
-    aff_ref = calculate_camera_affinity_matrix(
-        trackings,
-        camera_detections,
-        w_2d=w_2d,
-        alpha_2d=alpha_2d,
-        w_3d=w_3d,
-        alpha_3d=alpha_3d,
-        lambda_a=lambda_a,
-    )
-
-    diff = jnp.abs(aff_jax - aff_ref)
-    max_abs = float(diff.max())
-    max_rel = float((diff / (jnp.abs(aff_ref) + 1e-12)).max())
-    jax.debug.print(f"[DEBUG] max abs diff {max_abs:.6g}, max rel diff {max_rel:.6g}")
-
-    bad = jnp.where(diff > atol + rtol * jnp.abs(aff_ref))
-    if bad[0].size > 0:
-        for t, d in zip(*[arr.tolist() for arr in bad]):
-            jax.debug.print(
-                f" ↳ mismatch at (T={t}, D={d}): "
-                f"jax={aff_jax[t,d]:.6g}, ref={aff_ref[t,d]:.6g}"
-            )
+    if isinstance(detections, Mapping):
+        detection_by_camera = detections
     else:
-        jax.debug.print("✅ matrices match within tolerance")
+        detection_by_camera = classify_by_camera(detections)
+
+    res: dict[CameraID, AffinityResult] = {}
+    for camera_id, camera_detections in detection_by_camera.items():
+        affinity_matrix = calculate_camera_affinity_matrix_jax(
+            trackings,
+            camera_detections,
+            w_2d,
+            alpha_2d,
+            w_3d,
+            alpha_3d,
+            lambda_a,
+        )
+        # row, col
+        indices_T, indices_D = linear_sum_assignment(affinity_matrix)
+        affinity_result = AffinityResult(
+            matrix=affinity_matrix,
+            trackings=trackings,
+            detections=camera_detections,
+            indices_T=indices_T,
+            indices_D=indices_D,
+        )
+        res[camera_id] = affinity_result
+    return res
 
 
 # %%
@@ -1235,15 +1016,15 @@ trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
 unmatched_detections = shallow_copy(next_group)
 camera_detections = classify_by_camera(unmatched_detections)
 
-camera_detections_next_batch = camera_detections["AE_08"]
-debug_compare_affinity_matrices(
+affinities = calculate_affinity_matrix(
     trackings,
-    camera_detections_next_batch,
+    unmatched_detections,
     w_2d=W_2D,
     alpha_2d=ALPHA_2D,
     w_3d=W_3D,
    alpha_3d=ALPHA_3D,
     lambda_a=LAMBDA_A,
 )
+display(affinities)
 
 # %%
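`AffinityResult` is imported from `app.tracking`, whose definition is not part of this diff. A plausible shape, inferred purely from the constructor call in `calculate_affinity_matrix` above, is sketched below; the dataclass form and the field types are assumptions, not the actual `app.tracking` code.

```python
# Hypothetical sketch of app.tracking.AffinityResult, inferred solely from
# the constructor call in calculate_affinity_matrix. Field names come from
# the diff; everything else (dataclass form, exact types) is assumed.
from dataclasses import dataclass
from typing import Sequence

from jaxtyping import Array, Float, Int


@dataclass(frozen=True)
class AffinityResult:
    """Per-camera affinity matrix plus its optimal tracking-detection assignment."""

    matrix: Float[Array, "T D"]        # affinity of tracking i vs. detection j
    trackings: Sequence["Tracking"]    # rows of the matrix
    detections: Sequence["Detection"]  # columns: detections from one camera
    indices_T: Int[Array, "K"]         # assigned row (tracking) indices
    indices_D: Int[Array, "K"]         # assigned column (detection) indices


# Under this assumption, consuming the dict returned by
# calculate_affinity_matrix would look like:
# for camera_id, result in affinities.items():
#     for t, d in zip(result.indices_T.tolist(), result.indices_D.tolist()):
#         matched_pair = (result.trackings[t], result.detections[d])
```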