forked from HQU-gxy/CVTH3PE
refactor: Update affinity matrix calculation and dependencies
- Replaced the `linear_sum_assignment` import from `scipy.optimize` with `hungarian_algorithm` from `optax` to enhance performance in affinity matrix calculations. - Introduced a new `AffinityResult` class to encapsulate results of affinity computations, including trackings and detections, improving the structure of the affinity calculation process. - Removed deprecated functions and debug print statements to streamline the codebase. - Updated `pyproject.toml` and `uv.lock` to include `optax` as a dependency, ensuring compatibility with the new implementation. - Refactored imports and type hints for better organization and consistency across the codebase.
This commit is contained in:
@ -2,8 +2,10 @@ from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generator,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeAlias,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
@ -14,7 +16,11 @@ from typing import (
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from beartype import beartype
|
||||
from jaxtyping import Array, Float, jaxtyped
|
||||
from beartype.typing import Mapping, Sequence
|
||||
from jax import Array
|
||||
from jaxtyping import Array, Float, Int, jaxtyped
|
||||
|
||||
from app.camera import Detection
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
@ -67,3 +73,30 @@ class Tracking:
|
||||
# Step 2 – pure JAX math
|
||||
# ------------------------------------------------------------------
|
||||
return self.keypoints + velocity * delta_t_s
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
|
||||
@dataclass
|
||||
class AffinityResult:
|
||||
"""
|
||||
Result of affinity computation between trackings and detections.
|
||||
"""
|
||||
|
||||
matrix: Float[Array, "T D"]
|
||||
trackings: Sequence[Tracking]
|
||||
detections: Sequence[Detection]
|
||||
indices_T: Int[Array, "T"] # pylint: disable=invalid-name
|
||||
indices_D: Int[Array, "D"] # pylint: disable=invalid-name
|
||||
|
||||
def tracking_detections(
|
||||
self,
|
||||
) -> Generator[tuple[float, Tracking, Detection], None, None]:
|
||||
"""
|
||||
iterate over the best matching trackings and detections
|
||||
"""
|
||||
for t, d in zip(self.indices_T, self.indices_D):
|
||||
yield (
|
||||
self.matrix[t, d].item(),
|
||||
self.trackings[t],
|
||||
self.detections[d],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user