diff --git a/playground.py b/playground.py
index 33fa20d..d3a1eb8 100644
--- a/playground.py
+++ b/playground.py
@@ -42,9 +42,11 @@ from jaxtyping import Array, Float, Num, jaxtyped
 from matplotlib import pyplot as plt
 from numpy.typing import ArrayLike
 from scipy.spatial.transform import Rotation as R
+from collections import OrderedDict
 
 from app.camera import (
     Camera,
+    CameraID,
     CameraParams,
     Detection,
     calculate_affinity_matrix_by_epipolar_constraint,
@@ -754,7 +756,7 @@ def calculate_tracking_detection_affinity(
     distance_2d = calculate_distance_2d(
         tracking_2d_projection,
         detection.keypoints,
-        image_size=(w, h),
+        image_size=(int(w), int(h)),
     )
     affinity_2d = calculate_affinity_2d(
         distance_2d,
@@ -781,6 +783,83 @@ def calculate_tracking_detection_affinity(
     return jnp.sum(total_affinity).item()
 
 
+@beartype
+def calculate_affinity_matrix(
+    trackings: Sequence[Tracking],
+    detections: Sequence[Detection],
+    w_2d: float,
+    alpha_2d: float,
+    w_3d: float,
+    alpha_3d: float,
+    lambda_a: float,
+) -> tuple[Float[Array, "T D"], OrderedDict[CameraID, list[Detection]]]:
+    """
+    Calculate the affinity matrix between a set of trackings and detections.
+
+    Args:
+        trackings: Sequence of tracking objects
+        detections: Sequence of detection objects
+        w_2d: Weight for 2D affinity
+        alpha_2d: Normalization factor for 2D distance
+        w_3d: Weight for 3D affinity
+        alpha_3d: Normalization factor for 3D distance
+        lambda_a: Decay rate for time difference
+
+    Returns:
+        - affinity matrix of shape (T, D) where T is number of trackings and D
+          is number of detections
+        - dictionary mapping camera IDs to lists of detections from that camera;
+          since it is an `OrderedDict` you can flatten it to recover the column
+          indices of detections in the affinity matrix
+
+    Matrix Layout:
+        The affinity matrix has shape (T, D), where:
+        - T = number of trackings (rows)
+        - D = total number of detections across all cameras (columns)
+
+        The matrix is organized as follows:
+
+        ```
+                 | Camera 1    | Camera 2    | Camera c    |
+                 | d1 d2 ...   | d1 d2 ...   | d1 d2 ...   |
+        ---------+-------------+-------------+-------------+
+        Track 1  | a11 a12 ... | a11 a12 ... | a11 a12 ... |
+        Track 2  | a21 a22 ... | a21 a22 ... | a21 a22 ... |
+        ...      | ...         | ...         | ...         |
+        Track t  | at1 at2 ... | at1 at2 ... | at1 at2 ... |
+        ```
+
+    Where:
+        - Rows are ordered by tracking ID
+        - Columns are ordered by camera, then by detection within each camera
+        - Each cell aij represents the affinity between tracking i and detection j
+
+    The detection ordering in columns follows the exact same order as iterating
+    through the detection_by_camera dictionary, which is returned alongside
+    the matrix to maintain this relationship.
+    """
+    affinity = jnp.zeros((len(trackings), len(detections)))
+    detection_by_camera = classify_by_camera(detections)
+
+    for i, tracking in enumerate(trackings):
+        j = 0
+        for c, camera_detections in detection_by_camera.items():
+            for det in camera_detections:
+                affinity_value = calculate_tracking_detection_affinity(
+                    tracking,
+                    det,
+                    w_2d=w_2d,
+                    alpha_2d=alpha_2d,
+                    w_3d=w_3d,
+                    alpha_3d=alpha_3d,
+                    lambda_a=lambda_a,
+                )
+                affinity = affinity.at[i, j].set(affinity_value)
+                j += 1
+
+    return affinity, detection_by_camera
+
+
 # %%
 # let's do cross-view association
 W_2D = 1.0
@@ -791,31 +870,14 @@ ALPHA_3D = 1.0
 
 trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
 unmatched_detections = shallow_copy(next_group)
-# cross-view association matrix with shape (T, D), where T is the number of
-# trackings, D is the number of detections
-# layout:
-# a_t1_c1_d1, a_t1_c1_d2, a_t1_c1_d3,...,a_t1_c2_d1,..., a_t1_cc_dd
-# a_t2_c1_d1,...
-# ...
-# a_tt_c1_d1,... , a_tt_cc_dd
-#
-# where T <- [t1..tt]; D <- join(c1..cc), where `cn` is a collection of
-# detections from camera `n`
-affinity = jnp.zeros((len(trackings), len(unmatched_detections)))
-detection_by_camera = classify_by_camera(unmatched_detections)
-for i, tracking in enumerate(trackings):
-    j = 0
-    for c, detections in detection_by_camera.items():
-        for det in detections:
-            affinity_value = calculate_tracking_detection_affinity(
-                tracking,
-                det,
-                w_2d=W_2D,
-                alpha_2d=ALPHA_2D,
-                w_3d=W_3D,
-                alpha_3d=ALPHA_3D,
-                lambda_a=LAMBDA_A,
-            )
-            affinity = affinity.at[i, j].set(affinity_value)
-            j += 1
+
+affinity, detection_by_camera = calculate_affinity_matrix(
+    trackings,
+    unmatched_detections,
+    w_2d=W_2D,
+    alpha_2d=ALPHA_2D,
+    w_3d=W_3D,
+    alpha_3d=ALPHA_3D,
+    lambda_a=LAMBDA_A,
+)
 display(affinity)