feat: Enhance playground.py with new tracking and affinity calculation functionalities

- Introduced new functions for calculating 2D distances and affinities between detections, improving tracking capabilities.
- Added a `Tracking` dataclass with detailed docstrings for better clarity on its attributes.
- Refactored code to utilize `shallow_copy` for handling detections and improved organization of imports.
- Enhanced the cross-view association logic to accommodate the new functionalities, ensuring better integration with existing tracking systems.
This commit is contained in:
2025-04-27 16:03:05 +08:00
parent 2e63a3f9bf
commit 5b5ccbd92b

View File

@ -13,7 +13,10 @@
# --- # ---
# %% # %%
from copy import copy as shallow_copy
from copy import deepcopy from copy import deepcopy
from copy import deepcopy as deep_copy
from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
@ -34,14 +37,21 @@ import numpy as np
import orjson import orjson
from beartype import beartype from beartype import beartype
from cv2 import undistortPoints from cv2 import undistortPoints
from IPython.display import display
from jaxtyping import Array, Float, Num, jaxtyped from jaxtyping import Array, Float, Num, jaxtyped
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from numpy.typing import ArrayLike from numpy.typing import ArrayLike
from scipy.spatial.transform import Rotation as R from scipy.spatial.transform import Rotation as R
from app.camera import Camera, CameraParams, Detection from app.camera import (
Camera,
CameraParams,
Detection,
calculate_affinity_matrix_by_epipolar_constraint,
classify_by_camera,
)
from app.solver._old import GLPKSolver
from app.visualize.whole_body import visualize_whole_body from app.visualize.whole_body import visualize_whole_body
from IPython.display import display
NDArray: TypeAlias = np.ndarray NDArray: TypeAlias = np.ndarray
@ -316,14 +326,10 @@ sync_gen = sync_batch_gen(
) )
# %% # %%
detections = next(sync_gen)
# %%
from app.camera import calculate_affinity_matrix_by_epipolar_constraint
sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint( sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint(
detections, alpha_2d=2000 next(sync_gen), alpha_2d=2000
) )
display(sorted_detections)
# %% # %%
display( display(
@ -338,7 +344,6 @@ with jnp.printoptions(precision=3, suppress=True):
display(affinity_matrix) display(affinity_matrix)
# %% # %%
from app.solver._old import GLPKSolver
def clusters_to_detections( def clusters_to_detections(
@ -464,23 +469,33 @@ def triangulate_points_from_multiple_views_linear(
in_axes=(None, 1, 1), # proj_matrices static, map over points[:,p,:], conf[:,p] in_axes=(None, 1, 1), # proj_matrices static, map over points[:,p,:], conf[:,p]
out_axes=0, out_axes=0,
) )
# returns (P, 3)
return vmap_triangulate(proj_matrices, points, conf) return vmap_triangulate(proj_matrices, points, conf)
# %% # %%
from dataclasses import dataclass
from copy import copy as shallow_copy, deepcopy as deep_copy
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
    """
    One tracked target: an identified 3D pose plus recency metadata.

    Frozen dataclass — updating a tracking means constructing a new instance.
    """

    id: int
    """
    The tracking id
    """

    keypoints: Float[Array, "J 3"]
    """
    The 3D keypoints of the tracking
    """

    # Timestamp of the last detection this tracking was matched to.
    last_active_timestamp: datetime

    velocity: Optional[Float[Array, "3"]] = None
    """
    Could be `None`. Like when the 3D pose is initialized.
    `velocity` should be updated when target association yields a new
    3D pose.
    """

    def __repr__(self) -> str:
        # Deliberately omits `keypoints`/`velocity`: raw arrays are too noisy
        # for interactive display.
        return f"Tracking({self.id}, {self.last_active_timestamp})"
@ -546,17 +561,98 @@ display(global_tracking_state)
next_group = next(sync_gen) next_group = next(sync_gen)
display(next_group) display(next_group)
# %%
from app.camera import classify_by_camera
# %%
@jaxtyped(typechecker=beartype)
def calculate_distance_2d(
    left: Num[Array, "J 2"],
    right: Num[Array, "J 2"],
    image_size: tuple[int, int] = (1, 1),
):
    """
    Per-joint *normalized* Euclidean distance between two keypoint sets.

    Args:
        left: keypoints of the first detection
        right: keypoints of the second detection, in the same space as `left`
        image_size: `(width, height)` used to bring pixel coordinates into
            [0, 1]; the default `(1, 1)` signals the inputs are already
            normalized and are used untouched.

    Returns:
        A length-J array with one distance per joint.
    """
    width, height = image_size
    if (width, height) == (1, 1):
        # Already in normalized coordinates — skip the rescale entirely so
        # integer inputs are not promoted to float before differencing.
        delta = left - right
    else:
        scale = jnp.array([width, height])
        delta = left / scale - right / scale
    return jnp.linalg.norm(delta, axis=-1)
@jaxtyped(typechecker=beartype)
def calculate_affinity_2d(
    distance_2d: float, w_2d: float, alpha_2d: float, lambda_a: float, delta_t: float
) -> float:
    """
    Affinity between two detections from their 2D keypoint distance.

    The score falls off linearly with distance (scaled by `alpha_2d * delta_t`)
    and decays exponentially as the time gap grows.

    Args:
        distance_2d: normalized distance between the two keypoint sets
            (see `calculate_distance_2d`)
        w_2d: weight of the 2D term (parameter)
        alpha_2d: distance scale (parameter)
        lambda_a: temporal decay rate (parameter)
        delta_t: seconds between the two detections
            NOTE(review): `delta_t == 0` raises ZeroDivisionError — confirm
            callers never compare detections with identical timestamps.
    """
    closeness = 1 - distance_2d / (alpha_2d * delta_t)
    temporal_decay = np.exp(-lambda_a * delta_t)
    return w_2d * closeness * temporal_decay
@jaxtyped(typechecker=beartype)
def perpendicular_distance_point_to_line_two_points(
    point: Num[Array, "2"], line: tuple[Num[Array, "2"], Num[Array, "2"]]
):
    """
    Perpendicular distance from `point` to the infinite line through the
    two endpoints in `line = (line_start, line_end)`.

    Uses the cross-product identity |d × (start - point)| / |d| with
    d = end - start.

    NOTE(review): coincident endpoints (start == end) divide by zero and
    produce nan — confirm callers never pass a degenerate line.
    """
    start, end = line
    direction = end - start
    parallelogram_area = jnp.cross(direction, start - point)
    return jnp.linalg.norm(parallelogram_area) / jnp.linalg.norm(direction)
def predict_pose_3d(
    tracking: Tracking,
    delta_t: float,
) -> Float[Array, "J 3"]:
    """
    Extrapolate a tracking's 3D keypoints `delta_t` seconds forward under a
    constant-velocity model.

    Falls back to the last known pose when no velocity estimate exists yet
    (e.g. a tracking that was just initialized).
    """
    velocity = tracking.velocity
    if velocity is None:
        return tracking.keypoints
    # (3,) velocity broadcasts across all J joints of the (J, 3) keypoints.
    return tracking.keypoints + velocity * delta_t
# %%
# let's do cross-view association
trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
unmatched_detections = shallow_copy(next_group)
# cross-view association matrix with shape (T, D), where T is the number of
# trackings, D is the number of detections
# layout:
# a_t1_c1_d1, a_t1_c1_d2, a_t1_c1_d3,...,a_t1_c2_d1,..., a_t1_cc_dd
# a_t2_c1_d1,...
# ...
# a_tt_c1_d1,... , a_tt_cc_dd
#
# where T <- [t1..tt]; D <- join(c1..cc), where `cn` is a collection of
# detections from camera `n`
affinity = np.zeros((len(trackings), len(unmatched_detections)))
detection_by_camera = classify_by_camera(unmatched_detections)
for i, tracking in enumerate(trackings):
    for c, detections in detection_by_camera.items():
        # NOTE(review): assumes every detection in a camera group carries the
        # same `camera` object — the first one is taken as representative.
        camera = next(iter(detections)).camera
        # pixel space, unnormalized
        tracking_2d_projection = camera.project(tracking.keypoints)
        for det in detections:
            # TODO: fill affinity[i, ...] per detection (e.g. from the 2D
            # distance between `tracking_2d_projection` and det's keypoints).
            ...