forked from HQU-gxy/CVTH3PE
feat: Add calculate_affinity_matrix function to streamline affinity calculations
- Introduced a new function `calculate_affinity_matrix` to compute the affinity matrix between trackings and detections, enhancing modularity and clarity.
- Updated the existing affinity calculation logic to utilize the new function, improving code organization and readability.
- Adjusted the image size parameter in the distance calculation to ensure proper type handling.
- Enhanced documentation for the new function, detailing its parameters, return values, and matrix layout.
playground.py (118 lines changed)
@@ -42,9 +42,11 @@ from jaxtyping import Array, Float, Num, jaxtyped
 from matplotlib import pyplot as plt
 from numpy.typing import ArrayLike
 from scipy.spatial.transform import Rotation as R
+from collections import OrderedDict
+
 from app.camera import (
     Camera,
     CameraID,
     CameraParams,
     Detection,
     calculate_affinity_matrix_by_epipolar_constraint,
@@ -754,7 +756,7 @@ def calculate_tracking_detection_affinity(
     distance_2d = calculate_distance_2d(
         tracking_2d_projection,
         detection.keypoints,
-        image_size=(w, h),
+        image_size=(int(w), int(h)),
     )
     affinity_2d = calculate_affinity_2d(
         distance_2d,
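The commit message calls the `int(...)` cast above "proper type handling". A plausible reading, given the file's use of `@beartype`: numpy integer scalars fail `isinstance(x, int)` checks in Python 3, so a runtime-checked `tuple[int, int]` annotation would reject uncast values. A minimal sketch of that behavior (the `np.int64` value is illustrative; the commit does not show where `w` and `h` come from):

```python
import numpy as np

w = np.int64(1920)
print(isinstance(w, int))       # False: a numpy scalar is not a Python int
print(isinstance(int(w), int))  # True after the cast, as in the changed line
```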
@@ -781,6 +783,83 @@ def calculate_tracking_detection_affinity(
     return jnp.sum(total_affinity).item()
 
 
+@beartype
+def calculate_affinity_matrix(
+    trackings: Sequence[Tracking],
+    detections: Sequence[Detection],
+    w_2d: float,
+    alpha_2d: float,
+    w_3d: float,
+    alpha_3d: float,
+    lambda_a: float,
+) -> tuple[Float[Array, "T D"], OrderedDict[CameraID, list[Detection]]]:
+    """
+    Calculate the affinity matrix between a set of trackings and detections.
+
+    Args:
+        trackings: Sequence of tracking objects
+        detections: Sequence of detection objects
+        w_2d: Weight for 2D affinity
+        alpha_2d: Normalization factor for 2D distance
+        w_3d: Weight for 3D affinity
+        alpha_3d: Normalization factor for 3D distance
+        lambda_a: Decay rate for time difference
+
+    Returns:
+        - affinity matrix of shape (T, D), where T is the number of trackings
+          and D is the number of detections
+        - dictionary mapping camera IDs to lists of detections from that
+          camera; since it is an `OrderedDict`, you can flatten it to recover
+          the indices of detections in the affinity matrix
+
+    Matrix Layout:
+        The affinity matrix has shape (T, D), where:
+        - T = number of trackings (rows)
+        - D = total number of detections across all cameras (columns)
+
+        The matrix is organized as follows:
+
+        ```
+                 | Camera 1    | Camera 2    | Camera c    |
+                 | d1 d2 ...   | d1 d2 ...   | d1 d2 ...   |
+        ---------+-------------+-------------+-------------+
+        Track 1  | a11 a12 ... | a11 a12 ... | a11 a12 ... |
+        Track 2  | a21 a22 ... | a21 a22 ... | a21 a22 ... |
+        ...      | ...         | ...         | ...         |
+        Track t  | at1 at2 ... | at1 at2 ... | at1 at2 ... |
+        ```
+
+        Where:
+        - Rows are ordered by tracking ID
+        - Columns are ordered by camera, then by detection within each camera
+        - Each cell aij represents the affinity between tracking i and detection j
+
+        The detection ordering in the columns follows exactly the order of
+        iterating through the detection_by_camera dictionary, which is
+        returned alongside the matrix to preserve this relationship.
+    """
+    affinity = jnp.zeros((len(trackings), len(detections)))
+    detection_by_camera = classify_by_camera(detections)
+
+    for i, tracking in enumerate(trackings):
+        j = 0
+        for c, camera_detections in detection_by_camera.items():
+            for det in camera_detections:
+                affinity_value = calculate_tracking_detection_affinity(
+                    tracking,
+                    det,
+                    w_2d=w_2d,
+                    alpha_2d=alpha_2d,
+                    w_3d=w_3d,
+                    alpha_3d=alpha_3d,
+                    lambda_a=lambda_a,
+                )
+                affinity = affinity.at[i, j].set(affinity_value)
+                j += 1
+
+    return affinity, detection_by_camera
+
+
 # %%
 # let's do cross-view association
 W_2D = 1.0
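The docstring's note about flattening the returned `OrderedDict` can be made concrete. A minimal sketch, assuming `affinity` and `detection_by_camera` are the two values returned by `calculate_affinity_matrix` as defined above (`flat_detections` and `best_j` are illustrative names, not part of the commit):

```python
import jax.numpy as jnp
from itertools import chain

# Iterating the OrderedDict reproduces the column order used when the
# matrix was filled, so flat_detections[j] is the (camera_id, detection)
# pair behind column j of the affinity matrix.
flat_detections = list(
    chain.from_iterable(
        ((cam_id, det) for det in dets)
        for cam_id, dets in detection_by_camera.items()
    )
)

best_j = int(jnp.argmax(affinity[0]))  # strongest column for tracking row 0
cam_id, det = flat_detections[best_j]
```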
@@ -791,31 +870,14 @@ ALPHA_3D = 1.0
 
 trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
 unmatched_detections = shallow_copy(next_group)
-# cross-view association matrix with shape (T, D), where T is the number of
-# trackings, D is the number of detections
-# layout:
-# a_t1_c1_d1, a_t1_c1_d2, a_t1_c1_d3,...,a_t1_c2_d1,..., a_t1_cc_dd
-# a_t2_c1_d1,...
-# ...
-# a_tt_c1_d1,... , a_tt_cc_dd
-#
-# where T <- [t1..tt]; D <- join(c1..cc), where `cn` is a collection of
-# detections from camera `n`
-affinity = jnp.zeros((len(trackings), len(unmatched_detections)))
-detection_by_camera = classify_by_camera(unmatched_detections)
-for i, tracking in enumerate(trackings):
-    j = 0
-    for c, detections in detection_by_camera.items():
-        for det in detections:
-            affinity_value = calculate_tracking_detection_affinity(
-                tracking,
-                det,
-                w_2d=W_2D,
-                alpha_2d=ALPHA_2D,
-                w_3d=W_3D,
-                alpha_3d=ALPHA_3D,
-                lambda_a=LAMBDA_A,
-            )
-            affinity = affinity.at[i, j].set(affinity_value)
-            j += 1
+
+affinity, detection_by_camera = calculate_affinity_matrix(
+    trackings,
+    unmatched_detections,
+    w_2d=W_2D,
+    alpha_2d=ALPHA_2D,
+    w_3d=W_3D,
+    alpha_3d=ALPHA_3D,
+    lambda_a=LAMBDA_A,
+)
 display(affinity)
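The commit stops at displaying the matrix. For context, cross-view association on such an affinity matrix is commonly completed with a one-to-one assignment step; the following is a hedged sketch using SciPy's Hungarian solver, not something present in this commit (`AFFINITY_THRESHOLD` is a hypothetical gating value):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

AFFINITY_THRESHOLD = 0.5  # hypothetical gate; tune for the data

# One-to-one tracking-detection assignment maximizing total affinity.
row_ind, col_ind = linear_sum_assignment(np.asarray(affinity), maximize=True)
matches = [
    (t, d)
    for t, d in zip(row_ind, col_ind)
    if float(affinity[t, d]) > AFFINITY_THRESHOLD
]
```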