1
0
forked from HQU-gxy/CVTH3PE

refactor: Optimize affinity matrix calculation with parallel processing

- Introduced parallel processing using `ThreadPoolExecutor` to enhance the performance of the `calculate_affinity_matrix` function, allowing simultaneous processing of multiple cameras.
- Added a new helper function `_single_camera_job` to encapsulate the logic for processing individual camera detections, improving code organization and readability.
- Updated the function signature to include an optional `max_workers` parameter for controlling the number of threads used in parallel execution.
- Enhanced documentation to clarify the purpose and parameters of the new parallel processing implementation.
This commit is contained in:
2025-04-29 16:14:07 +08:00
parent 29ca66ad47
commit e79e899b87

View File

@@ -14,7 +14,11 @@
# %% # %%
## imports
import time
from collections import OrderedDict from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import copy as shallow_copy from copy import copy as shallow_copy
from copy import deepcopy as deep_copy from copy import deepcopy as deep_copy
from dataclasses import dataclass from dataclasses import dataclass
@@ -960,9 +964,11 @@ def calculate_affinity_matrix(
w_3d: float, w_3d: float,
alpha_3d: float, alpha_3d: float,
lambda_a: float, lambda_a: float,
max_workers: int | None = None,
) -> dict[CameraID, AffinityResult]: ) -> dict[CameraID, AffinityResult]:
""" """
Calculate the affinity matrix between a set of trackings and detections. Calculate the affinity matrix between a set of trackings and detections.
Uses parallel processing with ThreadPoolExecutor to handle multiple cameras concurrently.
Args: Args:
trackings: Sequence of tracking objects trackings: Sequence of tracking objects
@@ -972,6 +978,7 @@ def calculate_affinity_matrix(
w_3d: Weight for 3D affinity w_3d: Weight for 3D affinity
alpha_3d: Normalization factor for 3D distance alpha_3d: Normalization factor for 3D distance
lambda_a: Decay rate for time difference lambda_a: Decay rate for time difference
max_workers: Maximum number of parallel workers (threads), None = auto
Returns: Returns:
A dictionary mapping camera IDs to affinity results. A dictionary mapping camera IDs to affinity results.
""" """
@@ -980,8 +987,30 @@
else: else:
detection_by_camera = classify_by_camera(detections) detection_by_camera = classify_by_camera(detections)
res: dict[CameraID, AffinityResult] = {} def _single_camera_job(
for camera_id, camera_detections in detection_by_camera.items(): camera_id: CameraID,
camera_detections: list[Detection],
trackings: Sequence[Tracking],
w_2d: float,
alpha_2d: float,
w_3d: float,
alpha_3d: float,
lambda_a: float,
) -> tuple[CameraID, AffinityResult]:
"""
Process a single camera's detections to calculate affinity and assignment.
This function is designed to be run in parallel for each camera.
Args:
camera_id: The camera ID
camera_detections: List of detections from this camera
trackings: Sequence of tracking objects
w_2d, alpha_2d, w_3d, alpha_3d, lambda_a: Affinity parameters
Returns:
A tuple of (camera_id, AffinityResult)
"""
# 1) Calculate affinity matrix
affinity_matrix = calculate_camera_affinity_matrix_jax( affinity_matrix = calculate_camera_affinity_matrix_jax(
trackings, trackings,
camera_detections, camera_detections,
@@ -991,16 +1020,37 @@
alpha_3d, alpha_3d,
lambda_a, lambda_a,
) )
# row, col # 2) Calculate assignment
indices_T, indices_D = linear_sum_assignment(affinity_matrix) indices_T, indices_D = linear_sum_assignment(affinity_matrix)
affinity_result = AffinityResult( return camera_id, AffinityResult(
matrix=affinity_matrix, matrix=affinity_matrix,
trackings=trackings, trackings=trackings,
detections=camera_detections, detections=camera_detections,
indices_T=indices_T, indices_T=indices_T,
indices_D=indices_D, indices_D=indices_D,
) )
res[camera_id] = affinity_result
# Run cameras in parallel
res: dict[CameraID, AffinityResult] = {}
job = partial(
_single_camera_job,
trackings=trackings,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = [
pool.submit(job, cam_id, cam_dets)
for cam_id, cam_dets in detection_by_camera.items()
]
for fut in as_completed(futures):
cam_id, out = fut.result()
res[cam_id] = out
return res return res
@@ -1016,6 +1066,7 @@ trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
unmatched_detections = shallow_copy(next_group) unmatched_detections = shallow_copy(next_group)
camera_detections = classify_by_camera(unmatched_detections) camera_detections = classify_by_camera(unmatched_detections)
t0 = time.perf_counter()
affinities = calculate_affinity_matrix( affinities = calculate_affinity_matrix(
trackings, trackings,
unmatched_detections, unmatched_detections,
@@ -1026,5 +1077,7 @@ affinities = calculate_affinity_matrix(
lambda_a=LAMBDA_A, lambda_a=LAMBDA_A,
) )
display(affinities) display(affinities)
t1 = time.perf_counter()
print(f"Time taken: {t1 - t0} seconds")
# %% # %%