feat: Introduce Tracking class and refactor pose prediction

- Added a new `Tracking` class to encapsulate tracking data, including keypoints and velocity, with a method for predicting 3D poses based on velocity.
- Refactored the pose prediction logic in `calculate_camera_affinity_matrix_jax` to utilize the new `predict` method from the `Tracking` class, enhancing clarity and modularity.
- Introduced an `AffinityResult` class to manage results of affinity computations, including trackings and detections, improving the structure of the affinity calculation process.
- Updated imports and type hints for better organization and consistency across the codebase.
This commit is contained in:
2025-04-29 12:14:08 +08:00
parent 86fcc5f283
commit 65cc646927
3 changed files with 76 additions and 60 deletions

69
app/tracking/__init__.py Normal file
View File

@ -0,0 +1,69 @@
from dataclasses import dataclass
from datetime import datetime
from typing import (
Any,
Generator,
Optional,
TypeAlias,
TypedDict,
TypeVar,
cast,
overload,
)
import jax
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Float, jaxtyped
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
id: int
"""
The tracking id
"""
keypoints: Float[Array, "J 3"]
"""
The 3D keypoints of the tracking
"""
last_active_timestamp: datetime
velocity: Optional[Float[Array, "3"]] = None
"""
Could be `None`. Like when the 3D pose is initialized.
`velocity` should be updated when target association yields a new
3D pose.
"""
def __repr__(self) -> str:
return f"Tracking({self.id}, {self.last_active_timestamp})"
def predict(
self,
delta_t_s: float,
) -> Float[Array, "J 3"]:
"""
Predict the 3D pose of a tracking based on its velocity.
JAX-friendly implementation that avoids Python control flow.
Args:
delta_t_s: Time delta in seconds
Returns:
Predicted 3D pose keypoints
"""
# ------------------------------------------------------------------
# Step 1 decide velocity on the Python side
# ------------------------------------------------------------------
if self.velocity is None:
velocity = jnp.zeros_like(self.keypoints) # (J, 3)
else:
velocity = self.velocity # (J, 3)
# ------------------------------------------------------------------
# Step 2 pure JAX math
# ------------------------------------------------------------------
return self.keypoints + velocity * delta_t_s

View File

@ -2,7 +2,7 @@ from dataclasses import dataclass
from typing import Sequence, Callable, Generator from typing import Sequence, Callable, Generator
from app.camera import Detection from app.camera import Detection
from playground import Tracking from . import Tracking
from beartype.typing import Sequence, Mapping from beartype.typing import Sequence, Mapping
from jaxtyping import jaxtyped, Float, Int from jaxtyping import jaxtyped, Float, Int
from jax import Array from jax import Array

View File

@ -12,14 +12,14 @@
# name: python3 # name: python3
# --- # ---
from collections import OrderedDict
# %% # %%
from collections import OrderedDict
from copy import copy as shallow_copy from copy import copy as shallow_copy
from copy import deepcopy
from copy import deepcopy as deep_copy from copy import deepcopy as deep_copy
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import partial, reduce
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Any, Any,
@ -38,14 +38,13 @@ import jax.numpy as jnp
import numpy as np import numpy as np
import orjson import orjson
from beartype import beartype from beartype import beartype
from beartype.typing import Sequence, Mapping from beartype.typing import Mapping, Sequence
from cv2 import undistortPoints from cv2 import undistortPoints
from IPython.display import display from IPython.display import display
from jaxtyping import Array, Float, Num, jaxtyped from jaxtyping import Array, Float, Num, jaxtyped
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from numpy.typing import ArrayLike from numpy.typing import ArrayLike
from scipy.optimize import linear_sum_assignment from scipy.optimize import linear_sum_assignment
from functools import partial, reduce
from scipy.spatial.transform import Rotation as R from scipy.spatial.transform import Rotation as R
from typing_extensions import deprecated from typing_extensions import deprecated
@ -58,6 +57,7 @@ from app.camera import (
classify_by_camera, classify_by_camera,
) )
from app.solver._old import GLPKSolver from app.solver._old import GLPKSolver
from app.tracking import Tracking
from app.visualize.whole_body import visualize_whole_body from app.visualize.whole_body import visualize_whole_body
NDArray: TypeAlias = np.ndarray NDArray: TypeAlias = np.ndarray
@ -503,29 +503,6 @@ def triangulate_points_from_multiple_views_linear(
# %% # %%
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
id: int
"""
The tracking id
"""
keypoints: Float[Array, "J 3"]
"""
The 3D keypoints of the tracking
"""
last_active_timestamp: datetime
velocity: Optional[Float[Array, "3"]] = None
"""
Could be `None`. Like when the 3D pose is initialized.
`velocity` should be updated when target association yields a new
3D pose.
"""
def __repr__(self) -> str:
return f"Tracking({self.id}, {self.last_active_timestamp})"
@jaxtyped(typechecker=beartype) @jaxtyped(typechecker=beartype)
@ -692,7 +669,7 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
# Clamp delta_t to avoid division-by-zero / exploding affinity. # Clamp delta_t to avoid division-by-zero / exploding affinity.
delta_t = max(delta_t_raw, DELTA_T_MIN) delta_t = max(delta_t_raw, DELTA_T_MIN)
delta_t_s = delta_t.total_seconds() delta_t_s = delta_t.total_seconds()
predicted_pose = predict_pose_3d(tracking, delta_t_s) predicted_pose = tracking.predict(delta_t_s)
# Back-project the 2D points to 3D space # Back-project the 2D points to 3D space
# intersection with z=0 plane # intersection with z=0 plane
@ -746,36 +723,6 @@ def calculate_affinity_3d(
return affinity_per_keypoint return affinity_per_keypoint
@jaxtyped(typechecker=beartype)
def predict_pose_3d(
tracking: Tracking,
delta_t_s: float,
) -> Float[Array, "J 3"]:
"""
Predict the 3D pose of a tracking based on its velocity.
JAX-friendly implementation that avoids Python control flow.
Args:
tracking: The tracking object containing keypoints and optional velocity
delta_t_s: Time delta in seconds
Returns:
Predicted 3D pose keypoints
"""
# ------------------------------------------------------------------
# Step 1 decide velocity on the Python side
# ------------------------------------------------------------------
if tracking.velocity is None:
velocity = jnp.zeros_like(tracking.keypoints) # (J, 3)
else:
velocity = tracking.velocity # (J, 3)
# ------------------------------------------------------------------
# Step 2 pure JAX math
# ------------------------------------------------------------------
return tracking.keypoints + velocity * delta_t_s
@beartype @beartype
def calculate_tracking_detection_affinity( def calculate_tracking_detection_affinity(
tracking: Tracking, tracking: Tracking,
@ -1137,7 +1084,6 @@ def calculate_camera_affinity_matrix_jax(
P = predicted_pose P = predicted_pose
v1 = P - p1 v1 = P - p1
v2 = p2[None, :, :, :] - p1 # (1, D, J, 3) v2 = p2[None, :, :, :] - p1 # (1, D, J, 3)
# jax.debug.print
cross = jnp.cross(v1, v2) # (T, D, J, 3) cross = jnp.cross(v1, v2) # (T, D, J, 3)
num = jnp.linalg.norm(cross, axis=-1) # (T, D, J) num = jnp.linalg.norm(cross, axis=-1) # (T, D, J)
den = jnp.linalg.norm(v2, axis=-1) # (1, D, J) den = jnp.linalg.norm(v2, axis=-1) # (1, D, J)
@ -1147,6 +1093,7 @@ def calculate_camera_affinity_matrix_jax(
w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast) w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
) )
jax.debug.print("Shapes: dist2d {} dist3d {}", dist2d.shape, dist3d.shape)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Combine and reduce across keypoints → (T, D) # Combine and reduce across keypoints → (T, D)
# ------------------------------------------------------------------ # ------------------------------------------------------------------