forked from HQU-gxy/CVTH3PE
feat: Introduce Tracking class and refactor pose prediction
- Added a new `Tracking` class to encapsulate tracking data, including keypoints and velocity, with a method for predicting 3D poses based on velocity. - Refactored the pose prediction logic in `calculate_camera_affinity_matrix_jax` to utilize the new `predict` method from the `Tracking` class, enhancing clarity and modularity. - Introduced an `AffinityResult` class to manage results of affinity computations, including trackings and detections, improving the structure of the affinity calculation process. - Updated imports and type hints for better organization and consistency across the codebase.
This commit is contained in:
69
app/tracking/__init__.py
Normal file
69
app/tracking/__init__.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Generator,
|
||||||
|
Optional,
|
||||||
|
TypeAlias,
|
||||||
|
TypedDict,
|
||||||
|
TypeVar,
|
||||||
|
cast,
|
||||||
|
overload,
|
||||||
|
)
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from beartype import beartype
|
||||||
|
from jaxtyping import Array, Float, jaxtyped
|
||||||
|
|
||||||
|
|
||||||
|
@jaxtyped(typechecker=beartype)
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Tracking:
|
||||||
|
id: int
|
||||||
|
"""
|
||||||
|
The tracking id
|
||||||
|
"""
|
||||||
|
keypoints: Float[Array, "J 3"]
|
||||||
|
"""
|
||||||
|
The 3D keypoints of the tracking
|
||||||
|
"""
|
||||||
|
last_active_timestamp: datetime
|
||||||
|
|
||||||
|
velocity: Optional[Float[Array, "3"]] = None
|
||||||
|
"""
|
||||||
|
Could be `None`. Like when the 3D pose is initialized.
|
||||||
|
|
||||||
|
`velocity` should be updated when target association yields a new
|
||||||
|
3D pose.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"Tracking({self.id}, {self.last_active_timestamp})"
|
||||||
|
|
||||||
|
def predict(
|
||||||
|
self,
|
||||||
|
delta_t_s: float,
|
||||||
|
) -> Float[Array, "J 3"]:
|
||||||
|
"""
|
||||||
|
Predict the 3D pose of a tracking based on its velocity.
|
||||||
|
JAX-friendly implementation that avoids Python control flow.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
delta_t_s: Time delta in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Predicted 3D pose keypoints
|
||||||
|
"""
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Step 1 – decide velocity on the Python side
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
if self.velocity is None:
|
||||||
|
velocity = jnp.zeros_like(self.keypoints) # (J, 3)
|
||||||
|
else:
|
||||||
|
velocity = self.velocity # (J, 3)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Step 2 – pure JAX math
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
return self.keypoints + velocity * delta_t_s
|
||||||
@ -2,7 +2,7 @@ from dataclasses import dataclass
|
|||||||
from typing import Sequence, Callable, Generator
|
from typing import Sequence, Callable, Generator
|
||||||
|
|
||||||
from app.camera import Detection
|
from app.camera import Detection
|
||||||
from playground import Tracking
|
from . import Tracking
|
||||||
from beartype.typing import Sequence, Mapping
|
from beartype.typing import Sequence, Mapping
|
||||||
from jaxtyping import jaxtyped, Float, Int
|
from jaxtyping import jaxtyped, Float, Int
|
||||||
from jax import Array
|
from jax import Array
|
||||||
@ -12,14 +12,14 @@
|
|||||||
# name: python3
|
# name: python3
|
||||||
# ---
|
# ---
|
||||||
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
from collections import OrderedDict
|
||||||
from copy import copy as shallow_copy
|
from copy import copy as shallow_copy
|
||||||
from copy import deepcopy
|
|
||||||
from copy import deepcopy as deep_copy
|
from copy import deepcopy as deep_copy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from functools import partial, reduce
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
@ -38,14 +38,13 @@ import jax.numpy as jnp
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import orjson
|
import orjson
|
||||||
from beartype import beartype
|
from beartype import beartype
|
||||||
from beartype.typing import Sequence, Mapping
|
from beartype.typing import Mapping, Sequence
|
||||||
from cv2 import undistortPoints
|
from cv2 import undistortPoints
|
||||||
from IPython.display import display
|
from IPython.display import display
|
||||||
from jaxtyping import Array, Float, Num, jaxtyped
|
from jaxtyping import Array, Float, Num, jaxtyped
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from numpy.typing import ArrayLike
|
from numpy.typing import ArrayLike
|
||||||
from scipy.optimize import linear_sum_assignment
|
from scipy.optimize import linear_sum_assignment
|
||||||
from functools import partial, reduce
|
|
||||||
from scipy.spatial.transform import Rotation as R
|
from scipy.spatial.transform import Rotation as R
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
@ -58,6 +57,7 @@ from app.camera import (
|
|||||||
classify_by_camera,
|
classify_by_camera,
|
||||||
)
|
)
|
||||||
from app.solver._old import GLPKSolver
|
from app.solver._old import GLPKSolver
|
||||||
|
from app.tracking import Tracking
|
||||||
from app.visualize.whole_body import visualize_whole_body
|
from app.visualize.whole_body import visualize_whole_body
|
||||||
|
|
||||||
NDArray: TypeAlias = np.ndarray
|
NDArray: TypeAlias = np.ndarray
|
||||||
@ -503,29 +503,6 @@ def triangulate_points_from_multiple_views_linear(
|
|||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
@jaxtyped(typechecker=beartype)
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class Tracking:
|
|
||||||
id: int
|
|
||||||
"""
|
|
||||||
The tracking id
|
|
||||||
"""
|
|
||||||
keypoints: Float[Array, "J 3"]
|
|
||||||
"""
|
|
||||||
The 3D keypoints of the tracking
|
|
||||||
"""
|
|
||||||
last_active_timestamp: datetime
|
|
||||||
|
|
||||||
velocity: Optional[Float[Array, "3"]] = None
|
|
||||||
"""
|
|
||||||
Could be `None`. Like when the 3D pose is initialized.
|
|
||||||
|
|
||||||
`velocity` should be updated when target association yields a new
|
|
||||||
3D pose.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"Tracking({self.id}, {self.last_active_timestamp})"
|
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
@ -692,7 +669,7 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
|
|||||||
# Clamp delta_t to avoid division-by-zero / exploding affinity.
|
# Clamp delta_t to avoid division-by-zero / exploding affinity.
|
||||||
delta_t = max(delta_t_raw, DELTA_T_MIN)
|
delta_t = max(delta_t_raw, DELTA_T_MIN)
|
||||||
delta_t_s = delta_t.total_seconds()
|
delta_t_s = delta_t.total_seconds()
|
||||||
predicted_pose = predict_pose_3d(tracking, delta_t_s)
|
predicted_pose = tracking.predict(delta_t_s)
|
||||||
|
|
||||||
# Back-project the 2D points to 3D space
|
# Back-project the 2D points to 3D space
|
||||||
# intersection with z=0 plane
|
# intersection with z=0 plane
|
||||||
@ -746,36 +723,6 @@ def calculate_affinity_3d(
|
|||||||
return affinity_per_keypoint
|
return affinity_per_keypoint
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
|
||||||
def predict_pose_3d(
|
|
||||||
tracking: Tracking,
|
|
||||||
delta_t_s: float,
|
|
||||||
) -> Float[Array, "J 3"]:
|
|
||||||
"""
|
|
||||||
Predict the 3D pose of a tracking based on its velocity.
|
|
||||||
JAX-friendly implementation that avoids Python control flow.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tracking: The tracking object containing keypoints and optional velocity
|
|
||||||
delta_t_s: Time delta in seconds
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Predicted 3D pose keypoints
|
|
||||||
"""
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Step 1 – decide velocity on the Python side
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
if tracking.velocity is None:
|
|
||||||
velocity = jnp.zeros_like(tracking.keypoints) # (J, 3)
|
|
||||||
else:
|
|
||||||
velocity = tracking.velocity # (J, 3)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Step 2 – pure JAX math
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
return tracking.keypoints + velocity * delta_t_s
|
|
||||||
|
|
||||||
|
|
||||||
@beartype
|
@beartype
|
||||||
def calculate_tracking_detection_affinity(
|
def calculate_tracking_detection_affinity(
|
||||||
tracking: Tracking,
|
tracking: Tracking,
|
||||||
@ -1137,7 +1084,6 @@ def calculate_camera_affinity_matrix_jax(
|
|||||||
P = predicted_pose
|
P = predicted_pose
|
||||||
v1 = P - p1
|
v1 = P - p1
|
||||||
v2 = p2[None, :, :, :] - p1 # (1, D, J, 3)
|
v2 = p2[None, :, :, :] - p1 # (1, D, J, 3)
|
||||||
# jax.debug.print
|
|
||||||
cross = jnp.cross(v1, v2) # (T, D, J, 3)
|
cross = jnp.cross(v1, v2) # (T, D, J, 3)
|
||||||
num = jnp.linalg.norm(cross, axis=-1) # (T, D, J)
|
num = jnp.linalg.norm(cross, axis=-1) # (T, D, J)
|
||||||
den = jnp.linalg.norm(v2, axis=-1) # (1, D, J)
|
den = jnp.linalg.norm(v2, axis=-1) # (1, D, J)
|
||||||
@ -1147,6 +1093,7 @@ def calculate_camera_affinity_matrix_jax(
|
|||||||
w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
|
w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
jax.debug.print("Shapes: dist2d {} dist3d {}", dist2d.shape, dist3d.shape)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Combine and reduce across keypoints → (T, D)
|
# Combine and reduce across keypoints → (T, D)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|||||||
Reference in New Issue
Block a user