1
0
forked from HQU-gxy/CVTH3PE

feat: Introduce Tracking class and refactor pose prediction

- Added a new `Tracking` class to encapsulate tracking data, including keypoints and velocity, with a method for predicting 3D poses based on velocity.
- Refactored the pose prediction logic in `calculate_camera_affinity_matrix_jax` to utilize the new `predict` method from the `Tracking` class, enhancing clarity and modularity.
- Introduced an `AffinityResult` class to manage results of affinity computations, including trackings and detections, improving the structure of the affinity calculation process.
- Updated imports and type hints for better organization and consistency across the codebase.
This commit is contained in:
2025-04-29 12:14:08 +08:00
parent 86fcc5f283
commit 65cc646927
3 changed files with 76 additions and 60 deletions

69
app/tracking/__init__.py Normal file
View File

@ -0,0 +1,69 @@
from dataclasses import dataclass
from datetime import datetime
from typing import (
Any,
Generator,
Optional,
TypeAlias,
TypedDict,
TypeVar,
cast,
overload,
)
import jax
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Float, jaxtyped
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
    """
    Immutable record of one tracked target: its id, latest 3D keypoints,
    a timestamp and an optional velocity used for pose prediction.
    """

    id: int
    """
    Identifier of this tracking.
    """
    keypoints: Float[Array, "J 3"]
    """
    3D keypoints of the tracking, one row per joint.
    """
    # NOTE(review): presumably the time this tracking was last updated;
    # confirm against the code that constructs `Tracking`.
    last_active_timestamp: datetime
    velocity: Optional[Float[Array, "3"]] = None
    """
    May be `None`, e.g. right after the 3D pose has been initialized.
    Expected to be refreshed whenever target association produces a new
    3D pose.
    """

    def __repr__(self) -> str:
        # Compact form: only the id and timestamp, not the array payloads.
        return f"Tracking({self.id}, {self.last_active_timestamp})"

    def predict(
        self,
        delta_t_s: float,
    ) -> Float[Array, "J 3"]:
        """
        Extrapolate the 3D pose linearly along the stored velocity.

        The `None` check on `velocity` is ordinary Python control flow;
        everything after it is plain array arithmetic.

        Args:
            delta_t_s: Time delta in seconds

        Returns:
            Predicted 3D pose keypoints (`keypoints + velocity * delta_t_s`)
        """
        # Without a velocity, fall back to a zero displacement shaped like
        # the keypoints. A stored velocity is annotated with shape (3,) and
        # is broadcast over the joint axis by the addition below.
        displacement = (
            jnp.zeros_like(self.keypoints)
            if self.velocity is None
            else self.velocity
        ) * delta_t_s
        return self.keypoints + displacement

View File

@ -0,0 +1,37 @@
from dataclasses import dataclass
from typing import Sequence, Callable, Generator
from app.camera import Detection
from . import Tracking
from beartype.typing import Sequence, Mapping
from jaxtyping import jaxtyped, Float, Int
from jax import Array
@dataclass
class AffinityResult:
    """
    Result of affinity computation between trackings and detections.
    """

    matrix: Float[Array, "T D"]
    """
    Affinity matrix between trackings and detections.
    """
    trackings: Sequence[Tracking]
    """
    Trackings used to compute the affinity matrix.
    """
    detections: Sequence[Detection]
    """
    Detections used to compute the affinity matrix.
    """
    # Positional indices into `trackings` and `detections` respectively.
    # The two sequences are consumed pairwise, so entry k of each forms one
    # (tracking, detection) pair.
    # NOTE(review): presumably the output of an assignment step over
    # `matrix`; confirm with the producer of this object.
    indices_T: Sequence[int]
    indices_D: Sequence[int]

    def tracking_detections(
        self,
    ) -> Generator[tuple[Tracking, Detection], None, None]:
        """
        Iterate over the paired trackings and detections.

        `indices_T` and `indices_D` are walked in lockstep; iteration stops
        at the shorter of the two.

        Yields:
            `(tracking, detection)` tuples resolved through the index pairs.
        """
        # All three Generator parameters are spelled out: the one-argument
        # form is only accepted by `typing.Generator` on Python 3.13+, and
        # this annotation is evaluated when the class is defined.
        for t, d in zip(self.indices_T, self.indices_D):
            yield self.trackings[t], self.detections[d]