feat: Introduce Tracking class and refactor pose prediction

- Added a new `Tracking` class to encapsulate tracking data, including keypoints and velocity, with a method for predicting 3D poses based on velocity.
- Refactored the pose prediction logic in `calculate_camera_affinity_matrix_jax` to utilize the new `predict` method from the `Tracking` class, enhancing clarity and modularity.
- Introduced an `AffinityResult` class to manage results of affinity computations, including trackings and detections, improving the structure of the affinity calculation process.
- Updated imports and type hints for better organization and consistency across the codebase.
This commit is contained in:
2025-04-29 12:14:08 +08:00
parent 86fcc5f283
commit 65cc646927
3 changed files with 76 additions and 60 deletions

69
app/tracking/__init__.py Normal file
View File

@ -0,0 +1,69 @@
from dataclasses import dataclass
from datetime import datetime
from typing import (
Any,
Generator,
Optional,
TypeAlias,
TypedDict,
TypeVar,
cast,
overload,
)
import jax
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Float, jaxtyped
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
id: int
"""
The tracking id
"""
keypoints: Float[Array, "J 3"]
"""
The 3D keypoints of the tracking
"""
last_active_timestamp: datetime
velocity: Optional[Float[Array, "3"]] = None
"""
Could be `None`. Like when the 3D pose is initialized.
`velocity` should be updated when target association yields a new
3D pose.
"""
def __repr__(self) -> str:
return f"Tracking({self.id}, {self.last_active_timestamp})"
def predict(
self,
delta_t_s: float,
) -> Float[Array, "J 3"]:
"""
Predict the 3D pose of a tracking based on its velocity.
JAX-friendly implementation that avoids Python control flow.
Args:
delta_t_s: Time delta in seconds
Returns:
Predicted 3D pose keypoints
"""
# ------------------------------------------------------------------
# Step 1 decide velocity on the Python side
# ------------------------------------------------------------------
if self.velocity is None:
velocity = jnp.zeros_like(self.keypoints) # (J, 3)
else:
velocity = self.velocity # (J, 3)
# ------------------------------------------------------------------
# Step 2 pure JAX math
# ------------------------------------------------------------------
return self.keypoints + velocity * delta_t_s