forked from HQU-gxy/CVTH3PE
feat: Introduce Tracking class and refactor pose prediction
- Added a new `Tracking` class to encapsulate tracking data, including keypoints and velocity, with a method for predicting 3D poses based on velocity.
- Refactored the pose prediction logic in `calculate_camera_affinity_matrix_jax` to use the new `predict` method of the `Tracking` class, improving clarity and modularity.
- Introduced an `AffinityResult` class to hold the results of affinity computations, including trackings and detections, improving the structure of the affinity calculation process.
- Updated imports and type hints for better organization and consistency across the codebase.
This commit is contained in:
69
app/tracking/__init__.py
Normal file
69
app/tracking/__init__.py
Normal file
@ -0,0 +1,69 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
Generator,
|
||||
Optional,
|
||||
TypeAlias,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from beartype import beartype
|
||||
from jaxtyping import Array, Float, jaxtyped
|
||||
|
||||
|
||||
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
    """
    Immutable record of one tracked target: an id, its 3D keypoints, the
    timestamp it was last active, and an optional velocity used by
    `predict` to extrapolate the pose.
    """

    id: int
    """
    Identifier of this tracking
    """
    keypoints: Float[Array, "J 3"]
    """
    3D keypoints of the tracked pose, one row per joint
    """
    last_active_timestamp: datetime
    """
    Timestamp at which this tracking was last active
    """

    velocity: Optional[Float[Array, "3"]] = None
    """
    May be `None`, e.g. right after the 3D pose has been initialized.

    `velocity` is expected to be refreshed whenever target association
    produces a new 3D pose.
    """

    def __repr__(self) -> str:
        # Compact form: only the id and the last-active timestamp are shown,
        # the keypoint/velocity arrays are intentionally left out.
        return f"Tracking({self.id}, {self.last_active_timestamp})"

    def predict(
        self,
        delta_t_s: float,
    ) -> Float[Array, "J 3"]:
        """
        Linearly extrapolate the 3D keypoints `delta_t_s` seconds ahead.

        The "no velocity yet" case is resolved with an ordinary Python
        conditional before any array math: a missing velocity is replaced
        by zeros shaped like `keypoints`. When a velocity is present it is
        used as stored and combined with the keypoints through normal
        array broadcasting.

        Args:
            delta_t_s: Time delta in seconds

        Returns:
            `keypoints + velocity * delta_t_s`
        """
        # Fall back to an all-zero velocity when none has been estimated.
        rate = (
            jnp.zeros_like(self.keypoints)
            if self.velocity is None
            else self.velocity
        )

        # Constant-velocity motion model.
        return self.keypoints + rate * delta_t_s
|
||||
37
app/tracking/affinity_result.py
Normal file
37
app/tracking/affinity_result.py
Normal file
@ -0,0 +1,37 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Sequence, Callable, Generator
|
||||
|
||||
from app.camera import Detection
|
||||
from . import Tracking
|
||||
from beartype.typing import Sequence, Mapping
|
||||
from jaxtyping import jaxtyped, Float, Int
|
||||
from jax import Array
|
||||
|
||||
|
||||
@dataclass
class AffinityResult:
    """
    Result of affinity computation between trackings and detections.
    """

    matrix: Float[Array, "T D"]
    """
    Affinity matrix between trackings and detections.
    """

    trackings: Sequence[Tracking]
    """
    Trackings used to compute the affinity matrix.
    """

    detections: Sequence[Detection]
    """
    Detections used to compute the affinity matrix.
    """

    indices_T: Sequence[int]
    """
    Indices into `trackings`; paired position-by-position with `indices_D`.
    """

    indices_D: Sequence[int]
    """
    Indices into `detections`; paired position-by-position with `indices_T`.
    """

    # NOTE: `Generator` is spelled with all three type parameters. The
    # one-parameter form `Generator[X]` is only accepted from Python 3.13
    # onwards; on earlier versions it raises `TypeError` while this class
    # body is being evaluated.
    def tracking_detections(
        self,
    ) -> Generator[tuple[Tracking, Detection], None, None]:
        """
        Iterate over the (tracking, detection) pairs selected by the indices.

        `indices_T` and `indices_D` are walked in lockstep; iteration stops
        at the end of the shorter of the two.

        Yields:
            `(self.trackings[t], self.detections[d])` for each index pair
            `(t, d)`.
        """
        for t, d in zip(self.indices_T, self.indices_D):
            yield (self.trackings[t], self.detections[d])
|
||||
Reference in New Issue
Block a user