forked from HQU-gxy/CVTH3PE
- Added `pyrsistent` as a dependency to manage historical detections in the `Tracking` class, improving data integrity and immutability.
- Updated the `GlobalTrackingState` to include historical detections using `PVector`, facilitating better tracking of past detections.
- Introduced a new `update_tracking` function to handle timestamp updates for tracking objects, enhancing the tracking logic.
- Refactored imports and type hints for improved organization and clarity across the codebase.
115 lines
2.9 KiB
Python
115 lines
2.9 KiB
Python
from dataclasses import dataclass
|
||
from datetime import datetime
|
||
from typing import (
|
||
Any,
|
||
Callable,
|
||
Generator,
|
||
Optional,
|
||
Sequence,
|
||
TypeAlias,
|
||
TypedDict,
|
||
TypeVar,
|
||
cast,
|
||
overload,
|
||
)
|
||
|
||
import jax.numpy as jnp
|
||
from beartype import beartype
|
||
from beartype.typing import Mapping, Sequence
|
||
from jax import Array
|
||
from jaxtyping import Array, Float, Int, jaxtyped
|
||
from pyrsistent import PVector
|
||
|
||
from app.camera import Detection
|
||
|
||
|
||
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
    """
    Immutable record describing one tracked target.

    Bundles the target's id, its current 3D keypoints, the time it was
    last active, the detections gathered for it so far and, optionally,
    a velocity that `predict` uses to extrapolate the keypoints.
    """

    id: int
    """
    The tracking id
    """
    keypoints: Float[Array, "J 3"]
    """
    The 3D keypoints of the tracking

    Used for calculate affinity 3D
    """
    last_active_timestamp: datetime
    """
    The last active timestamp of the tracking
    """

    historical_detections: PVector[Detection]
    """
    Historical detections of the tracking.

    Used for 3D re-triangulation
    """

    # NOTE(review): the annotation asks for shape (3,), whereas the zero
    # fallback in `predict` has the shape of `keypoints` (J, 3). Both
    # broadcast against `keypoints`; confirm which shape callers supply.
    velocity: Optional[Float[Array, "3"]] = None
    """
    Could be `None`. Like when the 3D pose is initialized.

    `velocity` should be updated when target association yields a new
    3D pose.
    """

    def __repr__(self) -> str:
        # Compact form: only the id and the last active timestamp are shown.
        return f"Tracking({self.id}, {self.last_active_timestamp})"

    def predict(
        self,
        delta_t_s: float,
    ) -> Float[Array, "J 3"]:
        """
        Extrapolate the 3D keypoints `delta_t_s` seconds ahead.

        Computes `keypoints + velocity * delta_t_s`. When `velocity` is
        `None`, a zero array shaped like `keypoints` is used in its place.
        The `None` check is an ordinary Python branch; only the arithmetic
        that follows is expressed with JAX arrays.

        Args:
            delta_t_s: Time delta in seconds

        Returns:
            Predicted 3D pose keypoints
        """
        # Pick the displacement rate on the Python side first ...
        rate = (
            jnp.zeros_like(self.keypoints)
            if self.velocity is None
            else self.velocity
        )
        # ... then apply the linear motion model.
        return self.keypoints + rate * delta_t_s
|
||
|
||
|
||
@jaxtyped(typechecker=beartype)
@dataclass
class AffinityResult:
    """
    Result of affinity computation between trackings and detections.

    `matrix` is indexed as `[tracking, detection]`; `indices_T` and
    `indices_D` are paired position by position to name the matched
    (tracking, detection) entries.
    """

    matrix: Float[Array, "T D"]
    trackings: Sequence[Tracking]
    detections: Sequence[Detection]
    # NOTE(review): the two index arrays are consumed pairwise with `zip`,
    # which stops at the shorter one, while their annotations tie their
    # lengths to T and D respectively -- confirm the intended lengths.
    indices_T: Int[Array, "T"]  # pylint: disable=invalid-name
    indices_D: Int[Array, "D"]  # pylint: disable=invalid-name

    def tracking_detections(
        self,
    ) -> Generator[tuple[float, Tracking, Detection], None, None]:
        """
        Iterate over the best matching trackings and detections.

        Yields:
            `(affinity, tracking, detection)` for each pair
            `(indices_T[k], indices_D[k])`, where `affinity` is the
            corresponding `matrix` entry as a Python float.
        """
        # Convert the index arrays to plain Python ints once. Iterating a
        # JAX array directly yields 0-d arrays, which would then be used to
        # index the Python sequences below and cost one array per element.
        rows = self.indices_T.tolist()
        cols = self.indices_D.tolist()
        for t, d in zip(rows, cols):
            yield (
                self.matrix[t, d].item(),
                self.trackings[t],
                self.detections[d],
            )
|