from dataclasses import dataclass
from datetime import datetime
from typing import (
    Any,
    Callable,
    Generator,
    Optional,
    Sequence,
    TypeAlias,
    TypedDict,
    TypeVar,
    cast,
    overload,
)

import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Mapping, Sequence
from jax import Array
from jaxtyping import Array, Float, Int, jaxtyped
from pyrsistent import PVector

from app.camera import Detection


@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
    id: int
    """
    The tracking id
    """
    keypoints: Float[Array, "J 3"]
    """
    The 3D keypoints of the tracking

    Used to calculate the 3D affinity
    """
    last_active_timestamp: datetime
    """
    The last active timestamp of the tracking
    """

    historical_detections: PVector[Detection]
    """
    Historical detections of the tracking.

    Used for 3D re-triangulation
    """

    velocity: Optional[Float[Array, "3"]] = None
    """
    Could be `None`, e.g. when the 3D pose has just been initialized.

    `velocity` should be updated when target association yields a new 3D pose.
    """

    def __repr__(self) -> str:
        return f"Tracking({self.id}, {self.last_active_timestamp})"

    def predict(
        self,
        delta_t_s: float,
    ) -> Float[Array, "J 3"]:
        """
        Predict the 3D pose of a tracking based on its velocity.

        Computes ``keypoints + velocity * delta_t_s``. When ``velocity`` is
        ``None`` a zero velocity is used, so the result equals ``keypoints``.
        The ``None`` check is ordinary Python control flow; only the final
        arithmetic is JAX.

        Args:
            delta_t_s: Time delta in seconds

        Returns:
            Predicted 3D pose keypoints
        """
        if self.velocity is None:
            velocity = jnp.zeros_like(self.keypoints)  # (J, 3)
        else:
            # The field annotation enforces shape (3,); it broadcasts over the
            # J joints of `keypoints` (J, 3) in the addition below.
            velocity = self.velocity

        return self.keypoints + velocity * delta_t_s


@jaxtyped(typechecker=beartype)
@dataclass
class AffinityResult:
    """
    Result of affinity computation between trackings and detections.
    """

    matrix: Float[Array, "T D"]
    trackings: Sequence[Tracking]
    detections: Sequence[Detection]
    # NOTE(review): the index arrays reuse the "T"/"D" dimension names of
    # `matrix`, so jaxtyping requires len(indices_T) == T and
    # len(indices_D) == D. If they hold matched (row, col) pairs (usually
    # min(T, D) long), that check fails whenever T != D -- confirm against
    # the code that builds this result.
    indices_T: Int[Array, "T"]  # pylint: disable=invalid-name
    indices_D: Int[Array, "D"]  # pylint: disable=invalid-name

    def tracking_detections(
        self,
    ) -> Generator[tuple[float, Tracking, Detection], None, None]:
        """
        Iterate over the best matching trackings and detections.

        For each position ``i`` shared by ``indices_T`` and ``indices_D``
        (pairing stops at the shorter array, as with ``zip``), yields
        ``(matrix[indices_T[i], indices_D[i]], trackings[indices_T[i]],
        detections[indices_D[i]])`` with the affinity as a Python float.
        """
        # Pair only the common prefix, matching `zip` truncation.
        n = min(self.indices_T.shape[0], self.indices_D.shape[0])
        rows = self.indices_T[:n]
        cols = self.indices_D[:n]
        # One gather plus one host transfer per array, instead of iterating
        # device arrays element-wise and calling `.item()` for every pair.
        # The gather uses the same JAX indexing as `matrix[t, d]`.
        scores = self.matrix[rows, cols].tolist()
        for score, t, d in zip(scores, rows.tolist(), cols.tolist()):
            yield (
                score,
                self.trackings[t],
                self.detections[d],
            )