refactor: Update affinity matrix calculation and dependencies

- Replaced the `linear_sum_assignment` import from `scipy.optimize` with `hungarian_algorithm` from `optax` to enhance performance in affinity matrix calculations.
- Introduced a new `AffinityResult` class to encapsulate results of affinity computations, including trackings and detections, improving the structure of the affinity calculation process.
- Removed deprecated functions and debug print statements to streamline the codebase.
- Updated `pyproject.toml` and `uv.lock` to include `optax` as a dependency, ensuring compatibility with the new implementation.
- Refactored imports and type hints for better organization and consistency across the codebase.
This commit is contained in:
2025-04-29 15:45:24 +08:00
parent ce1d5f3cf7
commit 29ca66ad47
6 changed files with 152 additions and 529 deletions

View File

@ -2,8 +2,10 @@ from dataclasses import dataclass
from datetime import datetime
from typing import (
Any,
Callable,
Generator,
Optional,
Sequence,
TypeAlias,
TypedDict,
TypeVar,
@ -14,7 +16,11 @@ from typing import (
import jax
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Float, jaxtyped
from beartype.typing import Mapping, Sequence
from jax import Array
from jaxtyping import Array, Float, Int, jaxtyped
from app.camera import Detection
@jaxtyped(typechecker=beartype)
@ -67,3 +73,30 @@ class Tracking:
# Step 2 pure JAX math
# ------------------------------------------------------------------
return self.keypoints + velocity * delta_t_s
@jaxtyped(typechecker=beartype)
@dataclass
class AffinityResult:
"""
Result of affinity computation between trackings and detections.
"""
matrix: Float[Array, "T D"]
trackings: Sequence[Tracking]
detections: Sequence[Detection]
indices_T: Int[Array, "T"] # pylint: disable=invalid-name
indices_D: Int[Array, "D"] # pylint: disable=invalid-name
def tracking_detections(
self,
) -> Generator[tuple[float, Tracking, Detection], None, None]:
"""
iterate over the best matching trackings and detections
"""
for t, d in zip(self.indices_T, self.indices_D):
yield (
self.matrix[t, d].item(),
self.trackings[t],
self.detections[d],
)

View File

@ -1,37 +0,0 @@
from dataclasses import dataclass
from typing import Sequence, Callable, Generator
from app.camera import Detection
from . import Tracking
from beartype.typing import Sequence, Mapping
from jaxtyping import jaxtyped, Float, Int
from jax import Array
@dataclass
class AffinityResult:
"""
Result of affinity computation between trackings and detections.
"""
matrix: Float[Array, "T D"]
"""
Affinity matrix between trackings and detections.
"""
trackings: Sequence[Tracking]
"""
Trackings used to compute the affinity matrix.
"""
detections: Sequence[Detection]
"""
Detections used to compute the affinity matrix.
"""
indices_T: Sequence[int]
indices_D: Sequence[int]
def tracking_detections(self) -> Generator[tuple[Tracking, Detection]]:
for t, d in zip(self.indices_T, self.indices_D):
yield (self.trackings[t], self.detections[d])