from dataclasses import dataclass
from datetime import datetime
from typing import (
    Any,
    Callable,
    Generator,
    Optional,
    Sequence,
    TypeAlias,
    TypedDict,
    TypeVar,
    cast,
    overload,
)

import jax
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Mapping, Sequence
from jax import Array
from jaxtyping import Array, Float, Int, jaxtyped

from app.camera import Detection

# NOTE(review): `Sequence` (from `typing`) and `Array` (from `jax`) are
# re-bound by the `beartype.typing` and `jaxtyping` imports above; the later
# bindings are the ones in effect in this module.


@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
    """
    An immutable 3D pose track: id, current keypoints, last activity time and
    an optional velocity used for constant-velocity prediction.
    """

    id: int
    """
    The tracking id
    """

    keypoints: Float[Array, "J 3"]
    """
    The 3D keypoints of the tracking (`J` joints by 3 coordinates)
    """

    last_active_timestamp: datetime
    """
    Timestamp at which the tracking was last active
    """

    velocity: Optional[Float[Array, "3"]] = None
    """
    May be `None`, e.g. right after the 3D pose is initialized.

    A single 3D velocity vector of shape `(3,)`, applied to every joint;
    `predict` scales it by a time delta in seconds, so it is expressed in
    keypoint units per second. `velocity` should be updated when target
    association yields a new 3D pose.
    """

    def __repr__(self) -> str:
        return f"Tracking({self.id}, {self.last_active_timestamp})"

    def predict(
        self,
        delta_t_s: float,
    ) -> Float[Array, "J 3"]:
        """
        Predict the 3D pose of a tracking `delta_t_s` seconds ahead.

        Uses a constant-velocity model: `keypoints + velocity * delta_t_s`,
        where the `(3,)` velocity is broadcast over all `J` joints. When no
        velocity is known yet, the pose is treated as static and the current
        keypoints are returned unchanged.

        The `None` check is plain Python control flow on a dataclass field,
        not on array values; only the final update is JAX arithmetic.

        Args:
            delta_t_s: Time delta in seconds

        Returns:
            Predicted 3D pose keypoints, same shape as `keypoints`
        """
        if self.velocity is None:
            # No motion estimate yet: JAX arrays are immutable, so handing back
            # the stored keypoints is safe and avoids allocating a zero array.
            return self.keypoints
        # (3,) velocity broadcasts against the (J, 3) keypoints.
        return self.keypoints + self.velocity * delta_t_s


@jaxtyped(typechecker=beartype)
@dataclass
class AffinityResult:
    """
    Result of affinity computation between trackings and detections.
    """

    # Affinity scores; row t corresponds to trackings[t], column d to
    # detections[d].
    matrix: Float[Array, "T D"]
    trackings: Sequence[Tracking]
    detections: Sequence[Detection]
    # Element-wise paired indices: (indices_T[i], indices_D[i]) is one matched
    # (tracking, detection) pair.
    # NOTE(review): these arrays are zipped pairwise, yet their annotations
    # bind their lengths to the matrix dims `T` and `D`. If they come from a
    # rectangular assignment (min(T, D) matches), construction would fail the
    # jaxtyping check when T != D -- confirm against the producer.
    indices_T: Int[Array, "T"]  # pylint: disable=invalid-name
    indices_D: Int[Array, "D"]  # pylint: disable=invalid-name

    def tracking_detections(
        self,
    ) -> Generator[tuple[float, Tracking, Detection], None, None]:
        """
        Iterate over the best matching trackings and detections.

        Yields `(affinity, tracking, detection)` for each index pair
        `(indices_T[i], indices_D[i])`, in order. As with `zip`, pairing stops
        at the shorter of the two index arrays.

        All affinities are gathered from `matrix` in a single indexing
        operation and moved to the host once, rather than performing one
        device gather plus an `.item()` host sync per pair.
        """
        # Mirror zip's truncation so the vectorized gather sees equal lengths.
        n_pairs = min(self.indices_T.shape[0], self.indices_D.shape[0])
        rows = self.indices_T[:n_pairs]
        cols = self.indices_D[:n_pairs]
        # One gather + one device->host transfer for every selected score.
        scores = self.matrix[rows, cols].tolist()
        for score, t, d in zip(scores, rows.tolist(), cols.tolist()):
            yield score, self.trackings[t], self.detections[d]