diff --git a/app/tracking/__init__.py b/app/tracking/__init__.py
new file mode 100644
index 0000000..e3b7916
--- /dev/null
+++ b/app/tracking/__init__.py
@@ -0,0 +1,78 @@
+from dataclasses import dataclass
+from datetime import datetime
+from typing import (
+    Any,
+    Generator,
+    Optional,
+    TypeAlias,
+    TypedDict,
+    TypeVar,
+    cast,
+    overload,
+)
+
+import jax
+import jax.numpy as jnp
+from beartype import beartype
+from jaxtyping import Array, Float, jaxtyped
+
+
+@jaxtyped(typechecker=beartype)
+@dataclass(frozen=True)
+class Tracking:
+    """
+    Immutable record of one tracked 3D pose.
+
+    The class is wrapped in `jaxtyped(typechecker=beartype)`, so the field
+    annotations below are meant to be enforced at runtime.
+    """
+
+    id: int
+    """
+    The tracking id
+    """
+    keypoints: Float[Array, "J 3"]
+    """
+    The 3D keypoints of the tracking
+    """
+    last_active_timestamp: datetime
+
+    velocity: Optional[Float[Array, "3"]] = None
+    """
+    Could be `None`. Like when the 3D pose is initialized.
+
+    When set it is a single 3-vector (shape `(3,)`), shared by every
+    keypoint.
+
+    `velocity` should be updated when target association yields a new
+    3D pose.
+    """
+
+    def __repr__(self) -> str:
+        """Short form: id and timestamp only (the arrays are omitted)."""
+        return f"Tracking({self.id}, {self.last_active_timestamp})"
+
+    def predict(
+        self,
+        delta_t_s: float,
+    ) -> Float[Array, "J 3"]:
+        """
+        Predict the 3D pose of a tracking based on its velocity.
+
+        The `velocity is None` test is plain Python on a static attribute;
+        only the extrapolation itself is JAX math.
+
+        Args:
+            delta_t_s: Time delta in seconds
+
+        Returns:
+            Predicted 3D pose keypoints. Without a velocity these are the
+            current keypoints, unchanged.
+        """
+        if self.velocity is None:
+            # No motion estimate yet: keep the pose where it is instead of
+            # allocating a zero array and doing a no-op multiply-add.
+            return self.keypoints
+
+        # (J, 3) + (3,) * scalar: the velocity broadcasts over all keypoints.
+        return self.keypoints + self.velocity * delta_t_s
diff --git a/affinity_result.py b/app/tracking/affinity_result.py similarity index 96% rename from affinity_result.py rename to app/tracking/affinity_result.py index 91631e8..78d325f 100644 --- a/affinity_result.py +++ b/app/tracking/affinity_result.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import Sequence, Callable, Generator from app.camera import Detection -from playground import Tracking +from . 
import Tracking from beartype.typing import Sequence, Mapping from jaxtyping import jaxtyped, Float, Int from jax import Array diff --git a/playground.py b/playground.py index 6135f7a..68067c6 100644 --- a/playground.py +++ b/playground.py @@ -12,14 +12,14 @@ # name: python3 # --- -from collections import OrderedDict # %% +from collections import OrderedDict from copy import copy as shallow_copy -from copy import deepcopy from copy import deepcopy as deep_copy from dataclasses import dataclass from datetime import datetime, timedelta +from functools import partial, reduce from pathlib import Path from typing import ( Any, @@ -38,14 +38,13 @@ import jax.numpy as jnp import numpy as np import orjson from beartype import beartype -from beartype.typing import Sequence, Mapping +from beartype.typing import Mapping, Sequence from cv2 import undistortPoints from IPython.display import display from jaxtyping import Array, Float, Num, jaxtyped from matplotlib import pyplot as plt from numpy.typing import ArrayLike from scipy.optimize import linear_sum_assignment -from functools import partial, reduce from scipy.spatial.transform import Rotation as R from typing_extensions import deprecated @@ -58,6 +57,7 @@ from app.camera import ( classify_by_camera, ) from app.solver._old import GLPKSolver +from app.tracking import Tracking from app.visualize.whole_body import visualize_whole_body NDArray: TypeAlias = np.ndarray @@ -503,29 +503,6 @@ def triangulate_points_from_multiple_views_linear( # %% -@jaxtyped(typechecker=beartype) -@dataclass(frozen=True) -class Tracking: - id: int - """ - The tracking id - """ - keypoints: Float[Array, "J 3"] - """ - The 3D keypoints of the tracking - """ - last_active_timestamp: datetime - - velocity: Optional[Float[Array, "3"]] = None - """ - Could be `None`. Like when the 3D pose is initialized. - - `velocity` should be updated when target association yields a new - 3D pose. 
- """ - - def __repr__(self) -> str: - return f"Tracking({self.id}, {self.last_active_timestamp})" @jaxtyped(typechecker=beartype) @@ -692,7 +669,7 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting( # Clamp delta_t to avoid division-by-zero / exploding affinity. delta_t = max(delta_t_raw, DELTA_T_MIN) delta_t_s = delta_t.total_seconds() - predicted_pose = predict_pose_3d(tracking, delta_t_s) + predicted_pose = tracking.predict(delta_t_s) # Back-project the 2D points to 3D space # intersection with z=0 plane @@ -746,36 +723,6 @@ def calculate_affinity_3d( return affinity_per_keypoint -@jaxtyped(typechecker=beartype) -def predict_pose_3d( - tracking: Tracking, - delta_t_s: float, -) -> Float[Array, "J 3"]: - """ - Predict the 3D pose of a tracking based on its velocity. - JAX-friendly implementation that avoids Python control flow. - - Args: - tracking: The tracking object containing keypoints and optional velocity - delta_t_s: Time delta in seconds - - Returns: - Predicted 3D pose keypoints - """ - # ------------------------------------------------------------------ - # Step 1 – decide velocity on the Python side - # ------------------------------------------------------------------ - if tracking.velocity is None: - velocity = jnp.zeros_like(tracking.keypoints) # (J, 3) - else: - velocity = tracking.velocity # (J, 3) - - # ------------------------------------------------------------------ - # Step 2 – pure JAX math - # ------------------------------------------------------------------ - return tracking.keypoints + velocity * delta_t_s - - @beartype def calculate_tracking_detection_affinity( tracking: Tracking, @@ -1137,7 +1084,6 @@ def calculate_camera_affinity_matrix_jax( P = predicted_pose v1 = P - p1 v2 = p2[None, :, :, :] - p1 # (1, D, J, 3) - # jax.debug.print cross = jnp.cross(v1, v2) # (T, D, J, 3) num = jnp.linalg.norm(cross, axis=-1) # (T, D, J) den = jnp.linalg.norm(v2, axis=-1) # (1, D, J) @@ -1147,6 +1093,7 @@ def 
calculate_camera_affinity_matrix_jax( w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast) ) + jax.debug.print("Shapes: dist2d {} dist3d {}", dist2d.shape, dist3d.shape) # ------------------------------------------------------------------ # Combine and reduce across keypoints → (T, D) # ------------------------------------------------------------------