
feat: Add calculate_affinity_matrix function to streamline affinity calculations

- Introduced a new function `calculate_affinity_matrix` to compute the affinity matrix between trackings and detections, enhancing modularity and clarity.
- Updated the existing affinity calculation logic to utilize the new function, improving code organization and readability.
- Adjusted the image size parameter in the distance calculation to ensure proper type handling.
- Enhanced documentation for the new function, detailing its parameters, return values, and matrix layout for better understanding.
2025-04-27 17:57:12 +08:00
parent 41e0141bde
commit d4ade248dc


@@ -42,9 +42,11 @@ from jaxtyping import Array, Float, Num, jaxtyped
from matplotlib import pyplot as plt
from numpy.typing import ArrayLike
from scipy.spatial.transform import Rotation as R
from collections import OrderedDict
from app.camera import (
    Camera,
    CameraID,
    CameraParams,
    Detection,
    calculate_affinity_matrix_by_epipolar_constraint,
@@ -754,7 +756,7 @@ def calculate_tracking_detection_affinity(
    distance_2d = calculate_distance_2d(
        tracking_2d_projection,
        detection.keypoints,
        image_size=(w, h),
        image_size=(int(w), int(h)),
    )
    affinity_2d = calculate_affinity_2d(
        distance_2d,
@@ -781,6 +783,83 @@ def calculate_tracking_detection_affinity(
    return jnp.sum(total_affinity).item()
@beartype
def calculate_affinity_matrix(
    trackings: Sequence[Tracking],
    detections: Sequence[Detection],
    w_2d: float,
    alpha_2d: float,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> tuple[Float[Array, "T D"], OrderedDict[CameraID, list[Detection]]]:
    """
    Calculate the affinity matrix between a set of trackings and detections.

    Args:
        trackings: Sequence of tracking objects
        detections: Sequence of detection objects
        w_2d: Weight for 2D affinity
        alpha_2d: Normalization factor for 2D distance
        w_3d: Weight for 3D affinity
        alpha_3d: Normalization factor for 3D distance
        lambda_a: Decay rate for time difference

    Returns:
        - affinity matrix of shape (T, D), where T is the number of trackings
          and D is the number of detections
        - an `OrderedDict` mapping camera IDs to lists of detections from that
          camera; since it is ordered, flattening its values yields the column
          order of the detections in the affinity matrix

    Matrix Layout:
        The affinity matrix has shape (T, D), where:
        - T = number of trackings (rows)
        - D = total number of detections across all cameras (columns)

        The matrix is organized as follows:

        ```
                 | Camera 1    | Camera 2    | Camera c    |
                 | d1  d2  ... | d1  d2  ... | d1  d2  ... |
        ---------+-------------+-------------+-------------+
        Track 1  | a11 a12 ... | a11 a12 ... | a11 a12 ... |
        Track 2  | a21 a22 ... | a21 a22 ... | a21 a22 ... |
        ...      | ...         | ...         | ...         |
        Track t  | at1 at2 ... | at1 at2 ... | at1 at2 ... |
        ```

        Where:
        - Rows are ordered by tracking ID
        - Columns are ordered by camera, then by detection within each camera
        - Each cell aij is the affinity between tracking i and detection j

        The column ordering of detections follows exactly the iteration order
        of the detection_by_camera dictionary, which is returned alongside the
        matrix to preserve this relationship.
    """
    affinity = jnp.zeros((len(trackings), len(detections)))
    detection_by_camera = classify_by_camera(detections)
    for i, tracking in enumerate(trackings):
        j = 0
        for c, camera_detections in detection_by_camera.items():
            for det in camera_detections:
                affinity_value = calculate_tracking_detection_affinity(
                    tracking,
                    det,
                    w_2d=w_2d,
                    alpha_2d=alpha_2d,
                    w_3d=w_3d,
                    alpha_3d=alpha_3d,
                    lambda_a=lambda_a,
                )
                affinity = affinity.at[i, j].set(affinity_value)
                j += 1
    return affinity, detection_by_camera
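
The `Returns` section above notes that the returned `OrderedDict` preserves the column order of the matrix. As a minimal sketch (not part of this commit; the name `flat_detections` is illustrative), flattening its values in iteration order recovers which camera and detection each column of `affinity` refers to:

```python
# Illustrative sketch, not part of the diff: given the pair returned by
# calculate_affinity_matrix (as in the cell below), flatten the per-camera
# lists in dict iteration order; flat_detections[j] is then the
# (camera_id, detection) pair scored in column j of `affinity`.
flat_detections = [
    (camera_id, det)
    for camera_id, dets in detection_by_camera.items()
    for det in dets
]
assert len(flat_detections) == affinity.shape[1]
```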
# %%
# let's do cross-view association
W_2D = 1.0
@@ -791,31 +870,14 @@ ALPHA_3D = 1.0
trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
unmatched_detections = shallow_copy(next_group)
# cross-view association matrix with shape (T, D), where T is the number of
# trackings, D is the number of detections
# layout:
# a_t1_c1_d1, a_t1_c1_d2, a_t1_c1_d3,...,a_t1_c2_d1,..., a_t1_cc_dd
# a_t2_c1_d1,...
# ...
# a_tt_c1_d1,... , a_tt_cc_dd
#
# where T <- [t1..tt]; D <- join(c1..cc), where `cn` is a collection of
# detections from camera `n`
affinity = jnp.zeros((len(trackings), len(unmatched_detections)))
detection_by_camera = classify_by_camera(unmatched_detections)
for i, tracking in enumerate(trackings):
    j = 0
    for c, detections in detection_by_camera.items():
        for det in detections:
            affinity_value = calculate_tracking_detection_affinity(
                tracking,
                det,
                w_2d=W_2D,
                alpha_2d=ALPHA_2D,
                w_3d=W_3D,
                alpha_3d=ALPHA_3D,
                lambda_a=LAMBDA_A,
            )
            affinity = affinity.at[i, j].set(affinity_value)
            j += 1
affinity, detection_by_camera = calculate_affinity_matrix(
    trackings,
    unmatched_detections,
    w_2d=W_2D,
    alpha_2d=ALPHA_2D,
    w_3d=W_3D,
    alpha_3d=ALPHA_3D,
    lambda_a=LAMBDA_A,
)
display(affinity)
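
How the matrix is consumed is outside this diff. As a rough sketch of one possible next step (an assumption, not the repository's code), a maximizing Hungarian assignment over the (T, D) scores would pair each tracking with at most one detection; `AFFINITY_THRESHOLD` is a purely illustrative cutoff:

```python
# Hedged sketch only; the actual association step is not shown in this commit.
import numpy as np
from scipy.optimize import linear_sum_assignment

AFFINITY_THRESHOLD = 0.0  # illustrative cutoff, not a value from the repo

scores = np.asarray(affinity)  # (T, D), higher means a better match
rows, cols = linear_sum_assignment(scores, maximize=True)
# Keep only (tracking index, detection column) pairs above the threshold;
# column j maps back to a detection via the returned detection_by_camera.
matches = [(i, j) for i, j in zip(rows, cols) if scores[i, j] > AFFINITY_THRESHOLD]
```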