feat: Enhance playground.py with new tracking and affinity calculation functionalities
- Introduced helper functions for 2D keypoint distance, 2D affinity, point-to-line perpendicular distance, and velocity-based 3D pose prediction, to support tracking.
- Added detailed attribute docstrings and an optional `velocity` field to the existing `Tracking` dataclass.
- Moved scattered in-cell imports to the top of the file and switched the cross-view association cell to a `shallow_copy` of the incoming detections (`unmatched_detections`).
- Extended the cross-view association logic (affinity-matrix layout comment and per-detection loop stub) so the new helpers can be integrated with the existing tracking state.
This commit is contained in:
134
playground.py
134
playground.py
@ -13,7 +13,10 @@
|
|||||||
# ---
|
# ---
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
|
from copy import copy as shallow_copy
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from copy import deepcopy as deep_copy
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -34,14 +37,21 @@ import numpy as np
|
|||||||
import orjson
|
import orjson
|
||||||
from beartype import beartype
|
from beartype import beartype
|
||||||
from cv2 import undistortPoints
|
from cv2 import undistortPoints
|
||||||
|
from IPython.display import display
|
||||||
from jaxtyping import Array, Float, Num, jaxtyped
|
from jaxtyping import Array, Float, Num, jaxtyped
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from numpy.typing import ArrayLike
|
from numpy.typing import ArrayLike
|
||||||
from scipy.spatial.transform import Rotation as R
|
from scipy.spatial.transform import Rotation as R
|
||||||
|
|
||||||
from app.camera import Camera, CameraParams, Detection
|
from app.camera import (
|
||||||
|
Camera,
|
||||||
|
CameraParams,
|
||||||
|
Detection,
|
||||||
|
calculate_affinity_matrix_by_epipolar_constraint,
|
||||||
|
classify_by_camera,
|
||||||
|
)
|
||||||
|
from app.solver._old import GLPKSolver
|
||||||
from app.visualize.whole_body import visualize_whole_body
|
from app.visualize.whole_body import visualize_whole_body
|
||||||
from IPython.display import display
|
|
||||||
|
|
||||||
NDArray: TypeAlias = np.ndarray
|
NDArray: TypeAlias = np.ndarray
|
||||||
|
|
||||||
@ -316,14 +326,10 @@ sync_gen = sync_batch_gen(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
detections = next(sync_gen)
|
|
||||||
|
|
||||||
# %%
|
|
||||||
from app.camera import calculate_affinity_matrix_by_epipolar_constraint
|
|
||||||
|
|
||||||
sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint(
|
sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint(
|
||||||
detections, alpha_2d=2000
|
next(sync_gen), alpha_2d=2000
|
||||||
)
|
)
|
||||||
|
display(sorted_detections)
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
display(
|
display(
|
||||||
@ -338,7 +344,6 @@ with jnp.printoptions(precision=3, suppress=True):
|
|||||||
display(affinity_matrix)
|
display(affinity_matrix)
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
from app.solver._old import GLPKSolver
|
|
||||||
|
|
||||||
|
|
||||||
def clusters_to_detections(
|
def clusters_to_detections(
|
||||||
@ -464,23 +469,33 @@ def triangulate_points_from_multiple_views_linear(
|
|||||||
in_axes=(None, 1, 1), # proj_matrices static, map over points[:,p,:], conf[:,p]
|
in_axes=(None, 1, 1), # proj_matrices static, map over points[:,p,:], conf[:,p]
|
||||||
out_axes=0,
|
out_axes=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# returns (P, 3)
|
|
||||||
return vmap_triangulate(proj_matrices, points, conf)
|
return vmap_triangulate(proj_matrices, points, conf)
|
||||||
|
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
from dataclasses import dataclass
|
|
||||||
from copy import copy as shallow_copy, deepcopy as deep_copy
|
|
||||||
|
|
||||||
|
|
||||||
@jaxtyped(typechecker=beartype)
|
@jaxtyped(typechecker=beartype)
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Tracking:
|
class Tracking:
|
||||||
id: int
|
id: int
|
||||||
|
"""
|
||||||
|
The tracking id
|
||||||
|
"""
|
||||||
keypoints: Float[Array, "J 3"]
|
keypoints: Float[Array, "J 3"]
|
||||||
|
"""
|
||||||
|
The 3D keypoints of the tracking
|
||||||
|
"""
|
||||||
last_active_timestamp: datetime
|
last_active_timestamp: datetime
|
||||||
|
|
||||||
|
velocity: Optional[Float[Array, "3"]] = None
|
||||||
|
"""
|
||||||
|
Could be `None`. Like when the 3D pose is initialized.
|
||||||
|
|
||||||
|
`velocity` should be updated when target association yields a new
|
||||||
|
3D pose.
|
||||||
|
"""
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"Tracking({self.id}, {self.last_active_timestamp})"
|
return f"Tracking({self.id}, {self.last_active_timestamp})"
|
||||||
|
|
||||||
@ -546,17 +561,98 @@ display(global_tracking_state)
|
|||||||
next_group = next(sync_gen)
|
next_group = next(sync_gen)
|
||||||
display(next_group)
|
display(next_group)
|
||||||
|
|
||||||
# %%
|
|
||||||
from app.camera import classify_by_camera
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
@jaxtyped(typechecker=beartype)
def calculate_distance_2d(
    left: Num[Array, "J 2"],
    right: Num[Array, "J 2"],
    image_size: tuple[int, int] = (1, 1),
):
    """
    Per-keypoint Euclidean distance between two 2D keypoint sets, measured
    after scaling both sets by the image size.

    Args:
        left: First set of keypoints, one 2D point per row
        right: Second set of keypoints, same number of rows as `left`
        image_size: `(w, h)` divisor applied to the first and second
            coordinate respectively. The default `(1, 1)` means the inputs
            are treated as already normalized and are used untouched.

    Returns:
        The norm of `left - right` taken along the last axis, i.e. one
        distance value per keypoint row.
    """
    w, h = image_size

    # Guard clause: unit image size means there is nothing to rescale.
    if (w, h) == (1, 1):
        return jnp.linalg.norm(left - right, axis=-1)

    # Broadcast the (w, h) divisor over every keypoint row.
    scale = jnp.array([w, h])
    left_scaled = left / scale
    right_scaled = right / scale
    return jnp.linalg.norm(left_scaled - right_scaled, axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
@jaxtyped(typechecker=beartype)
def calculate_affinity_2d(
    distance_2d: float, w_2d: float, alpha_2d: float, lambda_a: float, delta_t: float
) -> float:
    """
    Score how well two detections match from their 2D keypoint distance
    and the time elapsed between them.

    The result is `w_2d * (1 - distance_2d / (alpha_2d * delta_t))
    * exp(-lambda_a * delta_t)`.

    Args:
        distance_2d: The normalized distance between the two keypoints (see `calculate_distance_2d`)
        w_2d: Weight applied to the whole score (parameter)
        alpha_2d: Scales, together with `delta_t`, the distance at which the
            spatial term reaches zero (parameter)
        lambda_a: Rate of the exponential decay over `delta_t` (parameter)
        delta_t: The time delta between the two detections, in seconds

    Note:
        `alpha_2d * delta_t` is used as a divisor and is not checked here,
        so callers must make sure it is non-zero.
    """
    # Spatial part: 1 at zero distance, decreasing linearly with distance.
    spatial_term = 1 - distance_2d / (alpha_2d * delta_t)
    # Temporal part: exponential decay with the elapsed time.
    temporal_decay = np.exp(-lambda_a * delta_t)
    return w_2d * spatial_term * temporal_decay
|
||||||
|
|
||||||
|
|
||||||
|
@jaxtyped(typechecker=beartype)
def perpendicular_distance_point_to_line_two_points(
    point: Num[Array, "2"], line: tuple[Num[Array, "2"], Num[Array, "2"]]
):
    """
    Perpendicular distance from a 2D point to the line through two points.

    Args:
        point: The query point
        line: The line, given as the pair `(line_start, line_end)`

    Returns:
        The magnitude of the cross product of the line direction with the
        vector from `point` to `line_start`, divided by the length of the
        line direction.

    Note:
        The divisor is the length of `line_end - line_start`; coincident
        end points are not guarded against here.
    """
    line_start, line_end = line
    direction = line_end - line_start
    to_start = line_start - point

    # |direction x to_start| is the area of the parallelogram spanned by the
    # two vectors; dividing by the base length leaves its height.
    parallelogram_area = jnp.linalg.norm(jnp.cross(direction, to_start))
    base_length = jnp.linalg.norm(direction)
    return parallelogram_area / base_length
|
||||||
|
|
||||||
|
|
||||||
|
def predict_pose_3d(
    tracking: Tracking,
    delta_t: float,
) -> Float[Array, "J 3"]:
    """
    Extrapolate a tracking's 3D keypoints forward by `delta_t`.

    Args:
        tracking: The tracking whose `keypoints` and `velocity` are read
        delta_t: Multiplier applied to the velocity before it is added to
            the keypoints

    Returns:
        `keypoints + velocity * delta_t` when the tracking has a velocity;
        otherwise the stored keypoints unchanged.
    """
    velocity = tracking.velocity
    if velocity is not None:
        # Constant-velocity step; the velocity broadcasts over all keypoints.
        return tracking.keypoints + velocity * delta_t
    # No velocity is available, so the last known pose is the prediction.
    return tracking.keypoints
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
# let's do cross-view association
|
# let's do cross-view association
|
||||||
trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
|
trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
|
||||||
detections = shallow_copy(next_group)
|
unmatched_detections = shallow_copy(next_group)
|
||||||
# cross-view association matrix with shape (T, D), where T is the number of trackings, D is the number of detections
|
# cross-view association matrix with shape (T, D), where T is the number of
|
||||||
affinity = np.zeros((len(trackings), len(detections)))
|
# trackings, D is the number of detections
|
||||||
detection_by_camera = classify_by_camera(detections)
|
# layout:
|
||||||
|
# a_t1_c1_d1, a_t1_c1_d2, a_t1_c1_d3,...,a_t1_c2_d1,..., a_t1_cc_dd
|
||||||
|
# a_t2_c1_d1,...
|
||||||
|
# ...
|
||||||
|
# a_tt_c1_d1,... , a_tt_cc_dd
|
||||||
|
#
|
||||||
|
# where T <- [t1..tt]; D <- join(c1..cc), where `cn` is a collection of
|
||||||
|
# detections from camera `n`
|
||||||
|
affinity = np.zeros((len(trackings), len(unmatched_detections)))
|
||||||
|
detection_by_camera = classify_by_camera(unmatched_detections)
|
||||||
for i, tracking in enumerate(trackings):
|
for i, tracking in enumerate(trackings):
|
||||||
for c, detections in detection_by_camera.items():
|
for c, detections in detection_by_camera.items():
|
||||||
camera = next(iter(detections)).camera
|
camera = next(iter(detections)).camera
|
||||||
# pixel space, unnormalized
|
# pixel space, unnormalized
|
||||||
tracking_2d_projection = camera.project(tracking.keypoints)
|
tracking_2d_projection = camera.project(tracking.keypoints)
|
||||||
|
for det in detections:
|
||||||
|
...
|
||||||
|
|||||||
Reference in New Issue
Block a user