from dataclasses import dataclass from datetime import datetime from typing import ( Any, Generator, Optional, TypeAlias, TypedDict, TypeVar, cast, overload, ) import jax import jax.numpy as jnp from beartype import beartype from jaxtyping import Array, Float, jaxtyped @jaxtyped(typechecker=beartype) @dataclass(frozen=True) class Tracking: id: int """ The tracking id """ keypoints: Float[Array, "J 3"] """ The 3D keypoints of the tracking """ last_active_timestamp: datetime velocity: Optional[Float[Array, "3"]] = None """ Could be `None`. Like when the 3D pose is initialized. `velocity` should be updated when target association yields a new 3D pose. """ def __repr__(self) -> str: return f"Tracking({self.id}, {self.last_active_timestamp})" def predict( self, delta_t_s: float, ) -> Float[Array, "J 3"]: """ Predict the 3D pose of a tracking based on its velocity. JAX-friendly implementation that avoids Python control flow. Args: delta_t_s: Time delta in seconds Returns: Predicted 3D pose keypoints """ # ------------------------------------------------------------------ # Step 1 – decide velocity on the Python side # ------------------------------------------------------------------ if self.velocity is None: velocity = jnp.zeros_like(self.keypoints) # (J, 3) else: velocity = self.velocity # (J, 3) # ------------------------------------------------------------------ # Step 2 – pure JAX math # ------------------------------------------------------------------ return self.keypoints + velocity * delta_t_s