1
0
forked from HQU-gxy/CVTH3PE

feat: Enhance playground.py with new tracking and affinity calculation functionalities

- Introduced new functions for calculating 2D distances and affinities between detections, improving tracking capabilities.
- Added a `Tracking` dataclass with detailed docstrings for better clarity on its attributes.
- Refactored code to utilize `shallow_copy` for handling detections and improved organization of imports.
- Enhanced the cross-view association logic to accommodate the new functionalities, ensuring better integration with existing tracking systems.
This commit is contained in:
2025-04-27 16:03:05 +08:00
parent 2e63a3f9bf
commit 5b5ccbd92b

View File

@ -13,7 +13,10 @@
# ---
# %%
from copy import copy as shallow_copy
from copy import deepcopy
from copy import deepcopy as deep_copy
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import (
@ -34,14 +37,21 @@ import numpy as np
import orjson
from beartype import beartype
from cv2 import undistortPoints
from IPython.display import display
from jaxtyping import Array, Float, Num, jaxtyped
from matplotlib import pyplot as plt
from numpy.typing import ArrayLike
from scipy.spatial.transform import Rotation as R
from app.camera import Camera, CameraParams, Detection
from app.camera import (
Camera,
CameraParams,
Detection,
calculate_affinity_matrix_by_epipolar_constraint,
classify_by_camera,
)
from app.solver._old import GLPKSolver
from app.visualize.whole_body import visualize_whole_body
from IPython.display import display
NDArray: TypeAlias = np.ndarray
@ -316,14 +326,10 @@ sync_gen = sync_batch_gen(
)
# %%
detections = next(sync_gen)
# %%
from app.camera import calculate_affinity_matrix_by_epipolar_constraint
sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint(
detections, alpha_2d=2000
next(sync_gen), alpha_2d=2000
)
display(sorted_detections)
# %%
display(
@ -338,7 +344,6 @@ with jnp.printoptions(precision=3, suppress=True):
display(affinity_matrix)
# %%
from app.solver._old import GLPKSolver
def clusters_to_detections(
@ -464,23 +469,33 @@ def triangulate_points_from_multiple_views_linear(
in_axes=(None, 1, 1), # proj_matrices static, map over points[:,p,:], conf[:,p]
out_axes=0,
)
# returns (P, 3)
return vmap_triangulate(proj_matrices, points, conf)
# %%
from dataclasses import dataclass
from copy import copy as shallow_copy, deepcopy as deep_copy
@jaxtyped(typechecker=beartype)
@dataclass(frozen=True)
class Tracking:
id: int
"""
The tracking id
"""
keypoints: Float[Array, "J 3"]
"""
The 3D keypoints of the tracking
"""
last_active_timestamp: datetime
velocity: Optional[Float[Array, "3"]] = None
"""
Could be `None`. Like when the 3D pose is initialized.
`velocity` should be updated when target association yields a new
3D pose.
"""
def __repr__(self) -> str:
return f"Tracking({self.id}, {self.last_active_timestamp})"
@ -546,17 +561,98 @@ display(global_tracking_state)
next_group = next(sync_gen)
display(next_group)
# %%
from app.camera import classify_by_camera
# %%
@jaxtyped(typechecker=beartype)
def calculate_distance_2d(
left: Num[Array, "J 2"],
right: Num[Array, "J 2"],
image_size: tuple[int, int] = (1, 1),
):
"""
Calculate the *normalized* distance between two sets of keypoints.
Args:
left: The left keypoints
right: The right keypoints
image_size: The size of the image
"""
w, h = image_size
if w == 1 and h == 1:
# already normalized
left_normalized = left
right_normalized = right
else:
left_normalized = left / jnp.array([w, h])
right_normalized = right / jnp.array([w, h])
return jnp.linalg.norm(left_normalized - right_normalized, axis=-1)
@jaxtyped(typechecker=beartype)
def calculate_affinity_2d(
distance_2d: float, w_2d: float, alpha_2d: float, lambda_a: float, delta_t: float
) -> float:
"""
Calculate the affinity between two detections based on the distance between their keypoints.
Args:
distance_2d: The normalized distance between the two keypoints (see `calculate_distance_2d`)
w_2d: The weight of the distance (parameter)
alpha_2d: The alpha value for the distance calculation (parameter)
lambda_a: The lambda value for the distance calculation (parameter)
delta_t: The time delta between the two detections, in seconds
"""
return w_2d * (1 - distance_2d / (alpha_2d * delta_t)) * np.exp(-lambda_a * delta_t)
@jaxtyped(typechecker=beartype)
def perpendicular_distance_point_to_line_two_points(
point: Num[Array, "2"], line: tuple[Num[Array, "2"], Num[Array, "2"]]
):
"""
Calculate the perpendicular distance between a point and a line.
where `line` is represented by two points: `(line_start, line_end)`
"""
line_start, line_end = line
distance = jnp.linalg.norm(
jnp.cross(line_end - line_start, line_start - point)
) / jnp.linalg.norm(line_end - line_start)
return distance
def predict_pose_3d(
tracking: Tracking,
delta_t: float,
) -> Float[Array, "J 3"]:
"""
Predict the 3D pose of a tracking based on its velocity.
"""
if tracking.velocity is None:
return tracking.keypoints
return tracking.keypoints + tracking.velocity * delta_t
# %%
# let's do cross-view association
trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
detections = shallow_copy(next_group)
# cross-view association matrix with shape (T, D), where T is the number of trackings, D is the number of detections
affinity = np.zeros((len(trackings), len(detections)))
detection_by_camera = classify_by_camera(detections)
unmatched_detections = shallow_copy(next_group)
# cross-view association matrix with shape (T, D), where T is the number of
# trackings, D is the number of detections
# layout:
# a_t1_c1_d1, a_t1_c1_d2, a_t1_c1_d3,...,a_t1_c2_d1,..., a_t1_cc_dd
# a_t2_c1_d1,...
# ...
# a_tt_c1_d1,... , a_tt_cc_dd
#
# where T <- [t1..tt]; D <- join(c1..cc), where `cn` is a collection of
# detections from camera `n`
affinity = np.zeros((len(trackings), len(unmatched_detections)))
detection_by_camera = classify_by_camera(unmatched_detections)
for i, tracking in enumerate(trackings):
for c, detections in detection_by_camera.items():
camera = next(iter(detections)).camera
# pixel space, unnormalized
tracking_2d_projection = camera.project(tracking.keypoints)
for det in detections:
...