From e79e899b874806b134ea7cc50f2f02bbaefd9507 Mon Sep 17 00:00:00 2001 From: crosstyan Date: Tue, 29 Apr 2025 16:14:07 +0800 Subject: [PATCH] refactor: Optimize affinity matrix calculation with parallel processing - Introduced parallel processing using `ThreadPoolExecutor` to enhance the performance of the `calculate_affinity_matrix` function, allowing simultaneous processing of multiple cameras. - Added a new helper function `_single_camera_job` to encapsulate the logic for processing individual camera detections, improving code organization and readability. - Updated the function signature to include an optional `max_workers` parameter for controlling the number of threads used in parallel execution. - Enhanced documentation to clarify the purpose and parameters of the new parallel processing implementation. --- playground.py | 63 +++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 58 insertions(+), 5 deletions(-) diff --git a/playground.py b/playground.py index 39cdb6a..43f1e92 100644 --- a/playground.py +++ b/playground.py @@ -14,7 +14,11 @@ # %% +## imports + +import time from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor, as_completed from copy import copy as shallow_copy from copy import deepcopy as deep_copy from dataclasses import dataclass @@ -960,9 +964,11 @@ def calculate_affinity_matrix( w_3d: float, alpha_3d: float, lambda_a: float, + max_workers: int | None = None, ) -> dict[CameraID, AffinityResult]: """ Calculate the affinity matrix between a set of trackings and detections. + Uses parallel processing with ThreadPoolExecutor to handle multiple cameras concurrently. Args: trackings: Sequence of tracking objects @@ -972,6 +978,7 @@ def calculate_affinity_matrix( w_3d: Weight for 3D affinity alpha_3d: Normalization factor for 3D distance lambda_a: Decay rate for time difference + max_workers: Maximum number of parallel workers (threads), None = auto Returns: A dictionary mapping camera IDs to affinity results. """ @@ -980,8 +987,30 @@ def calculate_affinity_matrix( else: detection_by_camera = classify_by_camera(detections) - res: dict[CameraID, AffinityResult] = {} - for camera_id, camera_detections in detection_by_camera.items(): + def _single_camera_job( + camera_id: CameraID, + camera_detections: list[Detection], + trackings: Sequence[Tracking], + w_2d: float, + alpha_2d: float, + w_3d: float, + alpha_3d: float, + lambda_a: float, + ) -> tuple[CameraID, AffinityResult]: + """ + Process a single camera's detections to calculate affinity and assignment. + This function is designed to be run in parallel for each camera. + + Args: + camera_id: The camera ID + camera_detections: List of detections from this camera + trackings: Sequence of tracking objects + w_2d, alpha_2d, w_3d, alpha_3d, lambda_a: Affinity parameters + + Returns: + A tuple of (camera_id, AffinityResult) + """ + # 1) Calculate affinity matrix affinity_matrix = calculate_camera_affinity_matrix_jax( trackings, camera_detections, @@ -991,16 +1020,37 @@ def calculate_affinity_matrix( alpha_3d, lambda_a, ) - # row, col + # 2) Calculate assignment indices_T, indices_D = linear_sum_assignment(affinity_matrix) - affinity_result = AffinityResult( + return camera_id, AffinityResult( matrix=affinity_matrix, trackings=trackings, detections=camera_detections, indices_T=indices_T, indices_D=indices_D, ) - res[camera_id] = affinity_result + + # Run cameras in parallel + res: dict[CameraID, AffinityResult] = {} + job = partial( + _single_camera_job, + trackings=trackings, + w_2d=w_2d, + alpha_2d=alpha_2d, + w_3d=w_3d, + alpha_3d=alpha_3d, + lambda_a=lambda_a, + ) + + with ThreadPoolExecutor(max_workers=max_workers) as pool: + futures = [ + pool.submit(job, cam_id, cam_dets) + for cam_id, cam_dets in detection_by_camera.items() + ] + for fut in as_completed(futures): + cam_id, out = fut.result() + res[cam_id] = out + return res @@ -1016,6 +1066,7 @@ trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id) unmatched_detections = shallow_copy(next_group) camera_detections = classify_by_camera(unmatched_detections) +t0 = time.perf_counter() affinities = calculate_affinity_matrix( trackings, unmatched_detections, @@ -1026,5 +1077,7 @@ affinities = calculate_affinity_matrix( lambda_a=LAMBDA_A, ) display(affinities) +t1 = time.perf_counter() +print(f"Time taken: {t1 - t0} seconds") # %%