Files
CVTH3PE/app/tracking/__init__.py
crosstyan 072bf1c46f feat: Integrate pyrsistent for enhanced tracking state management
- Added `pyrsistent` as a dependency to manage historical detections in the `Tracking` class, improving data integrity and immutability.
- Updated the `GlobalTrackingState` to include historical detections using `PVector`, facilitating better tracking of past detections.
- Introduced a new `update_tracking` function to handle timestamp updates for tracking objects, enhancing the tracking logic.
- Refactored imports and type hints for improved organization and clarity across the codebase.
2025-04-30 10:04:39 +08:00

115 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from dataclasses import dataclass
from datetime import datetime
from typing import (
Any,
Callable,
Generator,
Optional,
Sequence,
TypeAlias,
TypedDict,
TypeVar,
cast,
overload,
)
import jax.numpy as jnp
from beartype import beartype
from beartype.typing import Mapping, Sequence
from jax import Array
from jaxtyping import Array, Float, Int, jaxtyped
from pyrsistent import PVector
from app.camera import Detection
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
    """
    Immutable record describing one tracked target.

    Bundles the current 3D keypoints, the time the target was last active,
    its historical detections and an optional velocity used by `predict`.
    """

    id: int
    """
    Identifier of this tracking.
    """

    keypoints: Float[Array, "J 3"]
    """
    3D keypoints of the tracking, shape `(J, 3)`.
    Used when calculating the 3D affinity.
    """

    last_active_timestamp: datetime
    """
    Timestamp at which the tracking was last active.
    """

    historical_detections: PVector[Detection]
    """
    Historical detections of the tracking.
    Used for 3D re-triangulation.
    """

    velocity: Optional[Float[Array, "3"]] = None
    """
    May be `None`, e.g. when the 3D pose has just been initialized.
    `velocity` should be updated whenever target association yields a new
    3D pose.
    """

    def __repr__(self) -> str:
        # Compact form: only the id and the last-active timestamp are shown.
        return f"Tracking({self.id}, {self.last_active_timestamp})"

    def predict(
        self,
        delta_t_s: float,
    ) -> Float[Array, "J 3"]:
        """
        Extrapolate the 3D keypoints `delta_t_s` seconds ahead.

        The result is `keypoints + velocity * delta_t_s`. When `velocity`
        is `None`, a zero array with the keypoints' shape is used in its
        place.

        Args:
            delta_t_s: Time delta in seconds

        Returns:
            Predicted 3D pose keypoints
        """
        # The `None` check is ordinary Python; everything after it is array
        # arithmetic. A velocity matching its `"3"` annotation broadcasts
        # across the `(J, 3)` keypoints.
        effective_velocity = (
            jnp.zeros_like(self.keypoints)
            if self.velocity is None
            else self.velocity
        )
        displacement = effective_velocity * delta_t_s
        return self.keypoints + displacement
@jaxtyped(typechecker=beartype)
@dataclass
class AffinityResult:
    """
    Result of affinity computation between trackings and detections.

    `matrix[t, d]` is the affinity value looked up for `trackings[t]` and
    `detections[d]`; `indices_T[k]` and `indices_D[k]` together identify
    the k-th matched pair.
    """

    matrix: Float[Array, "T D"]
    trackings: Sequence[Tracking]
    detections: Sequence[Detection]
    # NOTE(review): the dimension names "T" and "D" below are the same ones
    # used by `matrix`, while `tracking_detections` zips the two index
    # arrays pairwise (so they are expected to have equal length). Confirm
    # these annotations match what the caller actually passes.
    indices_T: Int[Array, "T"]  # pylint: disable=invalid-name
    indices_D: Int[Array, "D"]  # pylint: disable=invalid-name

    def tracking_detections(
        self,
    ) -> Generator[tuple[float, Tracking, Detection], None, None]:
        """
        Iterate over the best matching trackings and detections.

        Yields:
            `(affinity, tracking, detection)` for every index pair
            `(indices_T[k], indices_D[k])`, where `affinity` is
            `matrix[t, d]` converted to a Python scalar via `.item()`.
        """
        # Convert each index array to plain Python ints once, instead of
        # iterating the arrays element by element and indexing the Python
        # sequences with array scalars.
        tracking_indices = self.indices_T.tolist()
        detection_indices = self.indices_D.tolist()
        for t, d in zip(tracking_indices, detection_indices):
            yield (
                self.matrix[t, d].item(),
                self.trackings[t],
                self.detections[d],
            )