diff --git a/playground.py b/playground.py index ca6fb4b..e6ef055 100644 --- a/playground.py +++ b/playground.py @@ -1197,13 +1197,22 @@ display(affinities) # %% def affinity_result_by_tracking( results: Iterable[AffinityResult], + min_affinity: float = 0.0, ) -> dict[TrackingID, list[Detection]]: """ Group affinity results by target ID. + + Args: + results: the affinity results to group + min_affinity: the minimum affinity to consider + Returns: + a dictionary mapping tracking IDs to a list of detections """ res: dict[TrackingID, list[Detection]] = defaultdict(list) for affinity_result in results: - for _affinity, t, d in affinity_result.tracking_association(): + for affinity, t, d in affinity_result.tracking_association(): + if affinity < min_affinity: + continue res[t.id].append(d) return res