forked from HQU-gxy/CVTH3PE
refactor: Optimize affinity matrix calculation with parallel processing
- Introduced parallel processing using `ThreadPoolExecutor` to enhance the performance of the `calculate_affinity_matrix` function, allowing simultaneous processing of multiple cameras. - Added a new helper function `_single_camera_job` to encapsulate the logic for processing individual camera detections, improving code organization and readability. - Updated the function signature to include an optional `max_workers` parameter for controlling the number of threads used in parallel execution. - Enhanced documentation to clarify the purpose and parameters of the new parallel processing implementation.
This commit is contained in:
@ -14,7 +14,11 @@
|
|||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
## imports
|
||||||
|
|
||||||
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from copy import copy as shallow_copy
|
from copy import copy as shallow_copy
|
||||||
from copy import deepcopy as deep_copy
|
from copy import deepcopy as deep_copy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -960,9 +964,11 @@ def calculate_affinity_matrix(
|
|||||||
w_3d: float,
|
w_3d: float,
|
||||||
alpha_3d: float,
|
alpha_3d: float,
|
||||||
lambda_a: float,
|
lambda_a: float,
|
||||||
|
max_workers: int | None = None,
|
||||||
) -> dict[CameraID, AffinityResult]:
|
) -> dict[CameraID, AffinityResult]:
|
||||||
"""
|
"""
|
||||||
Calculate the affinity matrix between a set of trackings and detections.
|
Calculate the affinity matrix between a set of trackings and detections.
|
||||||
|
Uses parallel processing with ThreadPoolExecutor to handle multiple cameras concurrently.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
trackings: Sequence of tracking objects
|
trackings: Sequence of tracking objects
|
||||||
@ -972,6 +978,7 @@ def calculate_affinity_matrix(
|
|||||||
w_3d: Weight for 3D affinity
|
w_3d: Weight for 3D affinity
|
||||||
alpha_3d: Normalization factor for 3D distance
|
alpha_3d: Normalization factor for 3D distance
|
||||||
lambda_a: Decay rate for time difference
|
lambda_a: Decay rate for time difference
|
||||||
|
max_workers: Maximum number of parallel workers (threads), None = auto
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary mapping camera IDs to affinity results.
|
A dictionary mapping camera IDs to affinity results.
|
||||||
"""
|
"""
|
||||||
@ -980,8 +987,30 @@ def calculate_affinity_matrix(
|
|||||||
else:
|
else:
|
||||||
detection_by_camera = classify_by_camera(detections)
|
detection_by_camera = classify_by_camera(detections)
|
||||||
|
|
||||||
res: dict[CameraID, AffinityResult] = {}
|
def _single_camera_job(
|
||||||
for camera_id, camera_detections in detection_by_camera.items():
|
camera_id: CameraID,
|
||||||
|
camera_detections: list[Detection],
|
||||||
|
trackings: Sequence[Tracking],
|
||||||
|
w_2d: float,
|
||||||
|
alpha_2d: float,
|
||||||
|
w_3d: float,
|
||||||
|
alpha_3d: float,
|
||||||
|
lambda_a: float,
|
||||||
|
) -> tuple[CameraID, AffinityResult]:
|
||||||
|
"""
|
||||||
|
Process a single camera's detections to calculate affinity and assignment.
|
||||||
|
This function is designed to be run in parallel for each camera.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
camera_id: The camera ID
|
||||||
|
camera_detections: List of detections from this camera
|
||||||
|
trackings: Sequence of tracking objects
|
||||||
|
w_2d, alpha_2d, w_3d, alpha_3d, lambda_a: Affinity parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (camera_id, AffinityResult)
|
||||||
|
"""
|
||||||
|
# 1) Calculate affinity matrix
|
||||||
affinity_matrix = calculate_camera_affinity_matrix_jax(
|
affinity_matrix = calculate_camera_affinity_matrix_jax(
|
||||||
trackings,
|
trackings,
|
||||||
camera_detections,
|
camera_detections,
|
||||||
@ -991,16 +1020,37 @@ def calculate_affinity_matrix(
|
|||||||
alpha_3d,
|
alpha_3d,
|
||||||
lambda_a,
|
lambda_a,
|
||||||
)
|
)
|
||||||
# row, col
|
# 2) Calculate assignment
|
||||||
indices_T, indices_D = linear_sum_assignment(affinity_matrix)
|
indices_T, indices_D = linear_sum_assignment(affinity_matrix)
|
||||||
affinity_result = AffinityResult(
|
return camera_id, AffinityResult(
|
||||||
matrix=affinity_matrix,
|
matrix=affinity_matrix,
|
||||||
trackings=trackings,
|
trackings=trackings,
|
||||||
detections=camera_detections,
|
detections=camera_detections,
|
||||||
indices_T=indices_T,
|
indices_T=indices_T,
|
||||||
indices_D=indices_D,
|
indices_D=indices_D,
|
||||||
)
|
)
|
||||||
res[camera_id] = affinity_result
|
|
||||||
|
# Run cameras in parallel
|
||||||
|
res: dict[CameraID, AffinityResult] = {}
|
||||||
|
job = partial(
|
||||||
|
_single_camera_job,
|
||||||
|
trackings=trackings,
|
||||||
|
w_2d=w_2d,
|
||||||
|
alpha_2d=alpha_2d,
|
||||||
|
w_3d=w_3d,
|
||||||
|
alpha_3d=alpha_3d,
|
||||||
|
lambda_a=lambda_a,
|
||||||
|
)
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||||
|
futures = [
|
||||||
|
pool.submit(job, cam_id, cam_dets)
|
||||||
|
for cam_id, cam_dets in detection_by_camera.items()
|
||||||
|
]
|
||||||
|
for fut in as_completed(futures):
|
||||||
|
cam_id, out = fut.result()
|
||||||
|
res[cam_id] = out
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
@ -1016,6 +1066,7 @@ trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
|
|||||||
unmatched_detections = shallow_copy(next_group)
|
unmatched_detections = shallow_copy(next_group)
|
||||||
camera_detections = classify_by_camera(unmatched_detections)
|
camera_detections = classify_by_camera(unmatched_detections)
|
||||||
|
|
||||||
|
t0 = time.perf_counter()
|
||||||
affinities = calculate_affinity_matrix(
|
affinities = calculate_affinity_matrix(
|
||||||
trackings,
|
trackings,
|
||||||
unmatched_detections,
|
unmatched_detections,
|
||||||
@ -1026,5 +1077,7 @@ affinities = calculate_affinity_matrix(
|
|||||||
lambda_a=LAMBDA_A,
|
lambda_a=LAMBDA_A,
|
||||||
)
|
)
|
||||||
display(affinities)
|
display(affinities)
|
||||||
|
t1 = time.perf_counter()
|
||||||
|
print(f"Time taken: {t1 - t0} seconds")
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
|||||||
Reference in New Issue
Block a user