diff --git a/affinity_result.py b/affinity_result.py index a35cad7..91631e8 100644 --- a/affinity_result.py +++ b/affinity_result.py @@ -1,6 +1,7 @@ from dataclasses import dataclass -from typing import Sequence +from typing import Sequence, Callable, Generator +from app.camera import Detection from playground import Tracking from beartype.typing import Sequence, Mapping from jaxtyping import jaxtyped, Float, Int @@ -23,12 +24,14 @@ class AffinityResult: Trackings used to compute the affinity matrix. """ - indices_T: Sequence[int] + detections: Sequence[Detection] """ - Indices of the trackings that were used to compute the affinity matrix. + Detections used to compute the affinity matrix. """ + indices_T: Sequence[int] indices_D: Sequence[int] - """ - Indices of the detections that were used to compute the affinity matrix. - """ + + def tracking_detections(self) -> Generator[tuple[Tracking, Detection]]: + for t, d in zip(self.indices_T, self.indices_D): + yield (self.trackings[t], self.detections[d])