forked from HQU-gxy/CVTH3PE

feat: Add calculate_affinity_matrix function to streamline affinity calculations

- Introduced a new function `calculate_affinity_matrix` to compute the affinity matrix between trackings and detections, enhancing modularity and clarity.
- Updated the existing affinity calculation logic to utilize the new function, improving code organization and readability.
- Adjusted the image size parameter in the distance calculation to ensure proper type handling.
- Enhanced documentation for the new function, detailing its parameters, return values, and matrix layout for better understanding.
2025-04-27 17:57:12 +08:00
parent 41e0141bde
commit d4ade248dc


@@ -42,9 +42,11 @@ from jaxtyping import Array, Float, Num, jaxtyped
 from matplotlib import pyplot as plt
 from numpy.typing import ArrayLike
 from scipy.spatial.transform import Rotation as R
+from collections import OrderedDict
 from app.camera import (
     Camera,
+    CameraID,
     CameraParams,
     Detection,
     calculate_affinity_matrix_by_epipolar_constraint,
@@ -754,7 +756,7 @@ def calculate_tracking_detection_affinity(
     distance_2d = calculate_distance_2d(
         tracking_2d_projection,
         detection.keypoints,
-        image_size=(w, h),
+        image_size=(int(w), int(h)),
     )
     affinity_2d = calculate_affinity_2d(
         distance_2d,
@@ -781,6 +783,83 @@ def calculate_tracking_detection_affinity(
     return jnp.sum(total_affinity).item()
+
+
+@beartype
+def calculate_affinity_matrix(
+    trackings: Sequence[Tracking],
+    detections: Sequence[Detection],
+    w_2d: float,
+    alpha_2d: float,
+    w_3d: float,
+    alpha_3d: float,
+    lambda_a: float,
+) -> tuple[Float[Array, "T D"], OrderedDict[CameraID, list[Detection]]]:
+    """
+    Calculate the affinity matrix between a set of trackings and detections.
+
+    Args:
+        trackings: Sequence of tracking objects
+        detections: Sequence of detection objects
+        w_2d: Weight for 2D affinity
+        alpha_2d: Normalization factor for 2D distance
+        w_3d: Weight for 3D affinity
+        alpha_3d: Normalization factor for 3D distance
+        lambda_a: Decay rate for time difference
+
+    Returns:
+        - affinity matrix of shape (T, D), where T is the number of trackings
+          and D is the number of detections
+        - dictionary mapping camera IDs to lists of detections from that
+          camera; since it is an `OrderedDict`, it can be flattened to recover
+          the column indices of detections in the affinity matrix
+
+    Matrix Layout:
+        The affinity matrix has shape (T, D), where:
+        - T = number of trackings (rows)
+        - D = total number of detections across all cameras (columns)
+
+        The matrix is organized as follows:
+
+        ```
+                 | Camera 1    | Camera 2    | Camera c    |
+                 | d1 d2 ...   | d1 d2 ...   | d1 d2 ...   |
+        ---------+-------------+-------------+-------------+
+        Track 1  | a11 a12 ... | a11 a12 ... | a11 a12 ... |
+        Track 2  | a21 a22 ... | a21 a22 ... | a21 a22 ... |
+        ...      | ...         | ...         | ...         |
+        Track t  | at1 at2 ... | at1 at2 ... | at1 at2 ... |
+        ```
+
+        Where:
+        - Rows are ordered by tracking ID
+        - Columns are ordered by camera, then by detection within each camera
+        - Each cell a_ij represents the affinity between tracking i and detection j
+
+        The detection ordering in columns follows the exact same order as
+        iterating through the `detection_by_camera` dictionary, which is
+        returned alongside the matrix to preserve this relationship.
+    """
+    affinity = jnp.zeros((len(trackings), len(detections)))
+    detection_by_camera = classify_by_camera(detections)
+    for i, tracking in enumerate(trackings):
+        j = 0
+        for c, camera_detections in detection_by_camera.items():
+            for det in camera_detections:
+                affinity_value = calculate_tracking_detection_affinity(
+                    tracking,
+                    det,
+                    w_2d=w_2d,
+                    alpha_2d=alpha_2d,
+                    w_3d=w_3d,
+                    alpha_3d=alpha_3d,
+                    lambda_a=lambda_a,
+                )
+                affinity = affinity.at[i, j].set(affinity_value)
+                j += 1
+    return affinity, detection_by_camera
 # %%
 # let's do cross-view association
 W_2D = 1.0
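The docstring's note that the returned `OrderedDict` can be flattened to recover column indices deserves a concrete illustration. A minimal sketch, assuming `trackings` and `detections` are already built; only `calculate_affinity_matrix` and its return contract come from the diff above, and the weight values are illustrative:

```python
# Flatten detection_by_camera in its (stable) iteration order; this
# reproduces exactly the column ordering used inside calculate_affinity_matrix.
affinity, detection_by_camera = calculate_affinity_matrix(
    trackings,
    detections,
    w_2d=1.0,
    alpha_2d=1.0,
    w_3d=1.0,
    alpha_3d=1.0,
    lambda_a=0.5,  # illustrative value; the script uses LAMBDA_A
)
flat = [
    (camera_id, det)
    for camera_id, dets in detection_by_camera.items()
    for det in dets
]
assert len(flat) == affinity.shape[1]
# Column j of `affinity` scores every tracking against flat[j][1],
# the detection observed by camera flat[j][0].
```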
@@ -791,31 +870,14 @@ ALPHA_3D = 1.0
 trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
 unmatched_detections = shallow_copy(next_group)
-# cross-view association matrix with shape (T, D), where T is the number of
-# trackings, D is the number of detections
-# layout:
-# a_t1_c1_d1, a_t1_c1_d2, a_t1_c1_d3,...,a_t1_c2_d1,..., a_t1_cc_dd
-# a_t2_c1_d1,...
-# ...
-# a_tt_c1_d1,... , a_tt_cc_dd
-#
-# where T <- [t1..tt]; D <- join(c1..cc), where `cn` is a collection of
-# detections from camera `n`
-affinity = jnp.zeros((len(trackings), len(unmatched_detections)))
-detection_by_camera = classify_by_camera(unmatched_detections)
-for i, tracking in enumerate(trackings):
-    j = 0
-    for c, detections in detection_by_camera.items():
-        for det in detections:
-            affinity_value = calculate_tracking_detection_affinity(
-                tracking,
-                det,
-                w_2d=W_2D,
-                alpha_2d=ALPHA_2D,
-                w_3d=W_3D,
-                alpha_3d=ALPHA_3D,
-                lambda_a=LAMBDA_A,
-            )
-            affinity = affinity.at[i, j].set(affinity_value)
-            j += 1
+affinity, detection_by_camera = calculate_affinity_matrix(
+    trackings,
+    unmatched_detections,
+    w_2d=W_2D,
+    alpha_2d=ALPHA_2D,
+    w_3d=W_3D,
+    alpha_3d=ALPHA_3D,
+    lambda_a=LAMBDA_A,
+)
 display(affinity)
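The commit stops at `display(affinity)`, but the (T, D) matrix usually feeds an assignment step next. A sketch of how it might be consumed, using `scipy.optimize.linear_sum_assignment`; this is not part of the commit, and the positivity gate is an illustrative threshold choice:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Hungarian assignment maximizing total affinity: rows index `trackings`,
# columns index detections in detection_by_camera iteration order.
flat_detections = [
    (camera_id, det)
    for camera_id, dets in detection_by_camera.items()
    for det in dets
]
row_idx, col_idx = linear_sum_assignment(-np.asarray(affinity))  # negate to maximize
matches = [
    (trackings[i], flat_detections[j])
    for i, j in zip(row_idx, col_idx)
    if affinity[i, j] > 0.0  # keep only pairs with positive affinity
]
```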