forked from HQU-gxy/CVTH3PE

refactor: Update affinity matrix calculation and dependencies

- Replaced the `linear_sum_assignment` import from `scipy.optimize` with `hungarian_algorithm` from `optax` to enhance performance in affinity matrix calculations (see the sketch after this list).
- Introduced a new `AffinityResult` class that bundles the affinity matrix with its trackings, detections, and assignment indices, improving the structure of the affinity calculation process.
- Removed deprecated functions and debug print statements to streamline the codebase.
- Updated `pyproject.toml` and `uv.lock` to include `optax` as a dependency, ensuring compatibility with the new implementation.
- Refactored imports and type hints for better organization and consistency across the codebase.
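A minimal sketch of the swap, assuming `optax.assignment.hungarian_algorithm` takes a 2-D cost matrix and returns a pair of row/column index arrays analogous to SciPy's `linear_sum_assignment`. Note that both solvers minimize total cost, so an affinity (similarity) matrix would normally be negated first; whether that happens elsewhere is not visible in this diff:

```
import jax.numpy as jnp
from optax.assignment import hungarian_algorithm

# Hypothetical 3x3 affinity matrix: higher values mean better matches.
affinity = jnp.array(
    [[0.9, 0.1, 0.0],
     [0.2, 0.8, 0.3],
     [0.0, 0.4, 0.7]]
)

# Both SciPy's linear_sum_assignment and optax's hungarian_algorithm
# minimize cost, so negate the affinity to maximize it instead.
rows, cols = hungarian_algorithm(-affinity)
# Expected here: cols == [0, 1, 2], the diagonal maximizing total affinity.
```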
2025-04-29 15:45:24 +08:00
parent ce1d5f3cf7
commit 29ca66ad47
6 changed files with 152 additions and 529 deletions


@@ -45,7 +45,7 @@ from IPython.display import display
from jaxtyping import Array, Float, Num, jaxtyped
from matplotlib import pyplot as plt
from numpy.typing import ArrayLike
from scipy.optimize import linear_sum_assignment
from optax.assignment import hungarian_algorithm as linear_sum_assignment
from scipy.spatial.transform import Rotation as R
from typing_extensions import deprecated
@@ -58,7 +58,7 @@ from app.camera import (
classify_by_camera,
)
from app.solver._old import GLPKSolver
from app.tracking import Tracking
from app.tracking import AffinityResult, Tracking
from app.visualize.whole_body import visualize_whole_body
NDArray: TypeAlias = np.ndarray
@@ -69,12 +69,6 @@ AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parq
DELTA_T_MIN = timedelta(milliseconds=10)
display(AK_CAMERA_DATASET)
_DEBUG_CURRENT_TRACKING: tuple[int, int] = (0, 0)
def _global_current_tracking_str():
return str(_DEBUG_CURRENT_TRACKING)
# %%
class Resolution(TypedDict):
@@ -594,23 +588,6 @@ def calculate_distance_2d(
left_normalized = left / jnp.array([w, h])
right_normalized = right / jnp.array([w, h])
dist = jnp.linalg.norm(left_normalized - right_normalized, axis=-1)
lt = left_normalized[:6]
rt = right_normalized[:6]
jax.debug.print(
"[REF]{} norm_trk first6 = {}",
_global_current_tracking_str(),
lt,
)
jax.debug.print(
"[REF]{} norm_det first6 = {}",
_global_current_tracking_str(),
rt,
)
jax.debug.print(
"[REF]{} dist2d first6 = {}",
_global_current_tracking_str(),
dist[:6],
)
return dist
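For reference, the computation that survives this hunk reduces to the following, a minimal sketch assuming keypoints arrive as `(J, 2)` pixel-coordinate arrays and `w`, `h` are the frame dimensions:

```
import jax.numpy as jnp
from jaxtyping import Array, Float

def distance_2d(
    left: Float[Array, "J 2"],
    right: Float[Array, "J 2"],
    w: float,
    h: float,
) -> Float[Array, "J"]:
    # Normalize pixel coordinates to [0, 1] so the distance is independent
    # of the camera resolution, then take the per-keypoint Euclidean norm.
    scale = jnp.array([w, h])
    return jnp.linalg.norm(left / scale - right / scale, axis=-1)

dist = distance_2d(jnp.zeros((17, 2)), jnp.full((17, 2), 10.0), 1920.0, 1080.0)
```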
@@ -806,191 +783,12 @@ def calculate_tracking_detection_affinity(
lambda_a=lambda_a,
)
jax.debug.print(
"[REF] aff2d{} first6 = {}",
_global_current_tracking_str(),
affinity_2d[:6],
)
jax.debug.print(
"[REF] aff3d{} first6 = {}", _global_current_tracking_str(), affinity_3d[:6]
)
jax.debug.print(
"[REF] aff2d.shape={}; aff3d.shape={}",
affinity_2d.shape,
affinity_3d.shape,
)
# Combine affinities
total_affinity = affinity_2d + affinity_3d
return jnp.sum(total_affinity).item()
# %%
@deprecated(
"Use `calculate_camera_affinity_matrix` instead. This implementation has the problem of under-utilizing views from different cameras."
)
@beartype
def calculate_affinity_matrix(
trackings: Sequence[Tracking],
detections: Sequence[Detection] | OrderedDict[CameraID, list[Detection]],
w_2d: float,
alpha_2d: float,
w_3d: float,
alpha_3d: float,
lambda_a: float,
) -> tuple[Float[Array, "T D"], OrderedDict[CameraID, list[Detection]]]:
"""
Calculate the affinity matrix between a set of trackings and detections.
Args:
trackings: Sequence of tracking objects
detections: Sequence of detection objects
w_2d: Weight for 2D affinity
alpha_2d: Normalization factor for 2D distance
w_3d: Weight for 3D affinity
alpha_3d: Normalization factor for 3D distance
lambda_a: Decay rate for time difference
Returns:
- affinity matrix of shape (T, D) where T is number of trackings and D
is number of detections
- dictionary mapping camera IDs to lists of detections from that camera,
since it's an `OrderedDict` you can flatten it out to get the indices of
detections in the affinity matrix
Matrix Layout:
The affinity matrix has shape (T, D), where:
- T = number of trackings (rows)
- D = total number of detections across all cameras (columns)
The matrix is organized as follows:
```
| Camera 1 | Camera 2 | Camera c |
| d1 d2 ... | d1 d2 ... | d1 d2 ... |
---------+-------------+-------------+-------------+
Track 1 | a11 a12 ... | a11 a12 ... | a11 a12 ... |
Track 2 | a21 a22 ... | a21 a22 ... | a21 a22 ... |
... | ... | ... | ... |
Track t | at1 at2 ... | at1 at2 ... | at1 at2 ... |
```
Where:
- Rows are ordered by tracking ID
- Columns are ordered by camera, then by detection within each camera
- Each cell aij represents the affinity between tracking i and detection j
The detection ordering in columns follows the exact same order as iterating
through the detection_by_camera dictionary, which is returned alongside
the matrix to maintain this relationship.
"""
if isinstance(detections, OrderedDict):
D = flatten_values_len(detections)
affinity = jnp.zeros((len(trackings), D))
detection_by_camera = detections
else:
affinity = jnp.zeros((len(trackings), len(detections)))
detection_by_camera = classify_by_camera(detections)
for i, tracking in enumerate(trackings):
j = 0
for _, camera_detections in detection_by_camera.items():
for det in camera_detections:
affinity_value = calculate_tracking_detection_affinity(
tracking,
det,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
affinity = affinity.at[i, j].set(affinity_value)
j += 1
return affinity, detection_by_camera
@beartype
def calculate_camera_affinity_matrix(
trackings: Sequence[Tracking],
camera_detections: Sequence[Detection],
w_2d: float,
alpha_2d: float,
w_3d: float,
alpha_3d: float,
lambda_a: float,
) -> Float[Array, "T D"]:
"""
Calculate an affinity matrix between trackings and detections from a single camera.
This follows the iterative camera-by-camera approach from the paper
"Cross-View Tracking for Multi-Human 3D Pose Estimation at over 100 FPS".
Instead of creating one large matrix for all cameras, this creates
a separate matrix for each camera, which can be processed independently.
Args:
trackings: Sequence of tracking objects
camera_detections: Sequence of detection objects, from the same camera
w_2d: Weight for 2D affinity
alpha_2d: Normalization factor for 2D distance
w_3d: Weight for 3D affinity
alpha_3d: Normalization factor for 3D distance
lambda_a: Decay rate for time difference
Returns:
Affinity matrix of shape (T, D) where:
- T = number of trackings (rows)
- D = number of detections from this specific camera (columns)
Matrix Layout:
The affinity matrix for a single camera has shape (T, D), where:
- T = number of trackings (rows)
- D = number of detections from this camera (columns)
The matrix is organized as follows:
```
| Detections from Camera c |
| d1 d2 d3 ... |
---------+------------------------+
Track 1 | a11 a12 a13 ... |
Track 2 | a21 a22 a23 ... |
... | ... ... ... ... |
Track t | at1 at2 at3 ... |
```
Each cell aij represents the affinity between tracking i and detection j,
computed using both 2D and 3D geometric correspondences.
"""
def verify_all_detection_from_same_camera(detections: Sequence[Detection]):
if not detections:
return True
camera_id = next(iter(detections)).camera.id
return all(map(lambda d: d.camera.id == camera_id, detections))
if not verify_all_detection_from_same_camera(camera_detections):
raise ValueError("All detections must be from the same camera")
affinity = jnp.zeros((len(trackings), len(camera_detections)))
for i, tracking in enumerate(trackings):
for j, det in enumerate(camera_detections):
global _DEBUG_CURRENT_TRACKING
_DEBUG_CURRENT_TRACKING = (i, j)
affinity_value = calculate_tracking_detection_affinity(
tracking,
det,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
affinity = affinity.at[i, j].set(affinity_value)
return affinity
@beartype
def calculate_camera_affinity_matrix_jax(
trackings: Sequence[Tracking],
@@ -1010,9 +808,6 @@ def calculate_camera_affinity_matrix_jax(
(tracking, detection) pair. The mathematical definition of the affinity
is **unchanged**, so the result remains bit-identical to the reference
implementation used in the tests.
TODO: It gives a wrong result (maybe it's my problem?) somehow,
and I need to find a way to debug this.
"""
# ------------------------------------------------------------------
@@ -1052,8 +847,8 @@ def calculate_camera_affinity_matrix_jax(
# ------------------------------------------------------------------
# Compute Δt matrix shape (T, D)
# ------------------------------------------------------------------
# Epoch timestamps are ~1.7×10⁹; storing them in float32 wipes out
# subsecond detail (resolution ≈ 200ms). Keep them in float64 until
# Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out
# subsecond detail (resolution ≈ 200 ms). Keep them in float64 until
# after subtraction so we preserve Δt on the order of milliseconds.
# --- timestamps ----------
t0 = min(
@@ -1093,12 +888,6 @@ def calculate_camera_affinity_matrix_jax(
diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1)
jax.debug.print(
"[JAX] norm_trk[0,0,:6] = {}", norm_trk[0, :, :6] # shape (J,2) 取前6
)
jax.debug.print("[JAX] norm_det[0,:6] = {}", norm_det[0, :6]) # shape (J,2)
jax.debug.print("[JAX] dist2d(T0,D0) first6 = {}", dist2d[0, 0, :6])
# Compute per-keypoint 2D affinity
delta_t_broadcast = delta_t[:, :, None] # (T, D, 1)
affinity_2d = (
@@ -1155,11 +944,6 @@ def calculate_camera_affinity_matrix_jax(
w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
)
jax.debug.print("[JAX] aff3d(T0,D0) first6 = {}", affinity_3d[0, 0, :6])
jax.debug.print("[JAX] aff2d(T0,D0) first6 = {}", affinity_2d[0, 0, :6])
jax.debug.print(
"[JAX] aff2d.shape={}; aff3d.shape={}", affinity_2d.shape, affinity_3d.shape
)
# ------------------------------------------------------------------
# Combine and reduce across keypoints → (T, D)
# ------------------------------------------------------------------
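The vectorized implementation replaces the reference loop's per-pair calls with a single broadcast before this reduction; a minimal sketch of that pattern with hypothetical shapes (`T` trackings, `D` detections, `J` keypoints):

```
import jax.numpy as jnp

T, D, J = 4, 5, 17
norm_trk = jnp.zeros((T, J, 2))  # normalized tracking keypoints
norm_det = jnp.ones((D, J, 2))   # normalized detection keypoints

# Broadcast to (T, D, J, 2), reduce the coordinate axis to per-keypoint
# distances (T, D, J), then sum over keypoints for a (T, D) matrix.
diff = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
dist2d = jnp.linalg.norm(diff, axis=-1)
total = dist2d.sum(axis=-1)
assert total.shape == (T, D)
```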
@@ -1167,60 +951,57 @@ def calculate_camera_affinity_matrix_jax(
return total_affinity # type: ignore[return-value]
# ------------------------------------------------------------------
# Debug helper compare JAX vs reference implementation
# ------------------------------------------------------------------
@beartype
def debug_compare_affinity_matrices(
def calculate_affinity_matrix(
trackings: Sequence[Tracking],
camera_detections: Sequence[Detection],
*,
detections: Sequence[Detection] | Mapping[CameraID, list[Detection]],
w_2d: float,
alpha_2d: float,
w_3d: float,
alpha_3d: float,
lambda_a: float,
atol: float = 1e-5,
rtol: float = 1e-3,
) -> None:
) -> dict[CameraID, AffinityResult]:
"""
Compute both affinity matrices and print out the max absolute / relative
difference. If any entry differs by more than atol + rtol * |ref|, dump the
offending indices so you can inspect individual terms.
Calculate the affinity matrix between a set of trackings and detections.
Args:
trackings: Sequence of tracking objects
detections: Sequence of detection objects, or detections grouped by camera ID
w_2d: Weight for 2D affinity
alpha_2d: Normalization factor for 2D distance
w_3d: Weight for 3D affinity
alpha_3d: Normalization factor for 3D distance
lambda_a: Decay rate for time difference
Returns:
A dictionary mapping camera IDs to affinity results.
"""
aff_jax = calculate_camera_affinity_matrix_jax(
trackings,
camera_detections,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
aff_ref = calculate_camera_affinity_matrix(
trackings,
camera_detections,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
diff = jnp.abs(aff_jax - aff_ref)
max_abs = float(diff.max())
max_rel = float((diff / (jnp.abs(aff_ref) + 1e-12)).max())
jax.debug.print(f"[DEBUG] max abs diff {max_abs:.6g}, max rel diff {max_rel:.6g}")
bad = jnp.where(diff > atol + rtol * jnp.abs(aff_ref))
if bad[0].size > 0:
for t, d in zip(*[arr.tolist() for arr in bad]):
jax.debug.print(
f" ↳ mismatch at (T={t}, D={d}): "
f"jax={aff_jax[t,d]:.6g}, ref={aff_ref[t,d]:.6g}"
)
if isinstance(detections, Mapping):
detection_by_camera = detections
else:
jax.debug.print("✅ matrices match within tolerance")
detection_by_camera = classify_by_camera(detections)
res: dict[CameraID, AffinityResult] = {}
for camera_id, camera_detections in detection_by_camera.items():
affinity_matrix = calculate_camera_affinity_matrix_jax(
trackings,
camera_detections,
w_2d,
alpha_2d,
w_3d,
alpha_3d,
lambda_a,
)
# row, col
indices_T, indices_D = linear_sum_assignment(affinity_matrix)
affinity_result = AffinityResult(
matrix=affinity_matrix,
trackings=trackings,
detections=camera_detections,
indices_T=indices_T,
indices_D=indices_D,
)
res[camera_id] = affinity_result
return res
# %%
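`AffinityResult` lives in `app.tracking` and is not shown in this diff; judging from the construction site above, it plausibly looks like the following hypothetical sketch (field names come from the call site, everything else is assumed):

```
from dataclasses import dataclass
from typing import Sequence

from jaxtyping import Array, Float, Int

@dataclass(frozen=True)
class AffinityResult:
    """Per-camera affinity matrix plus the Hungarian assignment (assumed)."""

    matrix: Float[Array, "T D"]
    trackings: Sequence["Tracking"]    # rows; `Tracking` from app.tracking
    detections: Sequence["Detection"]  # columns; detections from one camera
    indices_T: Int[Array, "K"]         # matched row indices
    indices_D: Int[Array, "K"]         # matched column indices
```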
@@ -1235,15 +1016,15 @@ trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
unmatched_detections = shallow_copy(next_group)
camera_detections = classify_by_camera(unmatched_detections)
camera_detections_next_batch = camera_detections["AE_08"]
debug_compare_affinity_matrices(
affinities = calculate_affinity_matrix(
trackings,
camera_detections_next_batch,
unmatched_detections,
w_2d=W_2D,
alpha_2d=ALPHA_2D,
w_3d=W_3D,
alpha_3d=ALPHA_3D,
lambda_a=LAMBDA_A,
)
display(affinities)
# %%
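A usage sketch for the returned mapping, assuming the hypothetical `AffinityResult` fields above; the match-acceptance threshold is not specified in this commit:

```
# Iterate matched (tracking, detection) pairs per camera.
for camera_id, result in affinities.items():
    for t_idx, d_idx in zip(result.indices_T.tolist(), result.indices_D.tolist()):
        score = float(result.matrix[t_idx, d_idx])
        tracking = result.trackings[t_idx]
        detection = result.detections[d_idx]
        # Some threshold on `score` would gate a real match between
        # `tracking` and `detection` here.
        print(camera_id, tracking.id, score)
```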