1
0
forked from HQU-gxy/CVTH3PE
This commit is contained in:
2025-04-28 16:39:23 +08:00
parent ebcd38eb52
commit 7ee4002567

View File

@ -47,6 +47,7 @@ from matplotlib import pyplot as plt
from numpy.typing import ArrayLike from numpy.typing import ArrayLike
from scipy.optimize import linear_sum_assignment from scipy.optimize import linear_sum_assignment
from scipy.spatial.transform import Rotation as R from scipy.spatial.transform import Rotation as R
from typing_extensions import deprecated
from app.camera import ( from app.camera import (
Camera, Camera,
@ -349,9 +350,8 @@ display(
with jnp.printoptions(precision=3, suppress=True): with jnp.printoptions(precision=3, suppress=True):
display(affinity_matrix) display(affinity_matrix)
# %% # %%
def clusters_to_detections( def clusters_to_detections(
clusters: Sequence[Sequence[int]], sorted_detections: Sequence[Detection] clusters: Sequence[Sequence[int]], sorted_detections: Sequence[Detection]
) -> list[list[Detection]]: ) -> list[list[Detection]]:
@ -375,6 +375,19 @@ clusters, sol_matrix = solver.solve(aff_np)
display(clusters) display(clusters)
display(sol_matrix) display(sol_matrix)
# %%
T = TypeVar("T")
def flatten_values(
d: Mapping[Any, Sequence[T]],
) -> list[T]:
"""
Flatten a dictionary of sequences into a single list of values.
"""
return [v for vs in d.values() for v in vs]
# %% # %%
WIDTH = 2560 WIDTH = 2560
HEIGHT = 1440 HEIGHT = 1440
@ -792,6 +805,9 @@ def calculate_tracking_detection_affinity(
# %% # %%
@deprecated(
"Use `calculate_camera_affinity_matrix` instead. This implementation has the problem of under-utilizing views from different cameras."
)
@beartype @beartype
def calculate_affinity_matrix( def calculate_affinity_matrix(
trackings: Sequence[Tracking], trackings: Sequence[Tracking],
@ -880,28 +896,31 @@ def calculate_camera_affinity_matrix(
lambda_a: float, lambda_a: float,
) -> Float[Array, "T D"]: ) -> Float[Array, "T D"]:
""" """
Calculate an affinity matrix between trackings and detections from a single camera. Vectorized version (with JAX) that computes the affinity matrix between a set
of *trackings* and *detections* coming from **one** camera.
This follows the iterative camera-by-camera approach from the paper The whole computation is done with JAX array operations and `vmap` no
"Cross-View Tracking for Multi-Human 3D Pose Estimation at over 100 FPS". explicit Python ``for``-loops over the (T, D) pairs. This makes the routine
Instead of creating one large matrix for all cameras, this creates fully parallelisable on CPU/GPU/TPU without any extra `jit` compilation.
a separate matrix for each camera, which can be processed independently.
Args: Args
trackings: Sequence of tracking objects -----
camera_detections: Sequence of detection objects, from the same camera trackings : Sequence[Tracking]
w_2d: Weight for 2D affinity Existing 3-D track states (length = T)
alpha_2d: Normalization factor for 2D distance camera_detections : Sequence[Detection]
w_3d: Weight for 3D affinity Detections from *a single* camera (length = D). All detections **must**
alpha_3d: Normalization factor for 3D distance share the same ``detection.camera`` instance.
lambda_a: Decay rate for time difference w_2d, alpha_2d, w_3d, alpha_3d, lambda_a : float
Hyper-parameters exactly as defined in the paper (and earlier helper
functions).
Returns: Returns
Affinity matrix of shape (T, D) where: -------
- T = number of trackings (rows) affinity : jnp.ndarray (T x D)
- D = number of detections from this specific camera (columns) Affinity matrix between each tracking (row) and detection (column).
Matrix Layout: Matrix Layout
-------
The affinity matrix for a single camera has shape (T, D), where: The affinity matrix for a single camera has shape (T, D), where:
- T = number of trackings (rows) - T = number of trackings (rows)
- D = number of detections from this camera (columns) - D = number of detections from this camera (columns)
@ -922,100 +941,107 @@ def calculate_camera_affinity_matrix(
computed using both 2D and 3D geometric correspondences. computed using both 2D and 3D geometric correspondences.
""" """
def verify_all_detection_from_same_camera(detections: Sequence[Detection]): # ---------- Safety checks & early exits --------------------------------
if not detections: if len(trackings) == 0 or len(camera_detections) == 0:
return True return jnp.zeros((len(trackings), len(camera_detections))) # pragma: no cover
camera_id = next(iter(detections)).camera.id
return all(map(lambda d: d.camera.id == camera_id, detections))
if not verify_all_detection_from_same_camera(camera_detections): # Ensure all detections come from the *same* camera
raise ValueError("All detections must be from the same camera") cam_id_ref = camera_detections[0].camera.id
if any(det.camera.id != cam_id_ref for det in camera_detections):
affinity = jnp.zeros((len(trackings), len(camera_detections))) raise ValueError(
"All detections given to calculate_camera_affinity_matrix must come from the same camera."
for i, tracking in enumerate(trackings):
for j, det in enumerate(camera_detections):
affinity_value = calculate_tracking_detection_affinity(
tracking,
det,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
)
affinity = affinity.at[i, j].set(affinity_value)
return affinity
@beartype
def process_detections_iteratively(
trackings: Sequence[Tracking],
detections: Sequence[Detection],
w_2d: float = 1.0,
alpha_2d: float = 1.0,
w_3d: float = 1.0,
alpha_3d: float = 1.0,
lambda_a: float = 0.1,
) -> list[tuple[int, Detection]]:
"""
Process detections iteratively camera by camera, matching them to trackings.
This implements the paper's approach where each camera is processed
independently, and the affinity matrix is calculated for one camera at a time.
This approach has several advantages:
1. Computational cost scales linearly with number of cameras
2. Can handle non-synchronized camera frames
3. More efficient for large-scale camera systems
Args:
trackings: Sequence of tracking objects
detections: Sequence of detection objects
w_2d: Weight for 2D affinity
alpha_2d: Normalization factor for 2D distance
w_3d: Weight for 3D affinity
alpha_3d: Normalization factor for 3D distance
lambda_a: Decay rate for time difference
Returns:
List of (tracking_index, detection) pairs representing matches
"""
# Group detections by camera
detection_by_camera = classify_by_camera(detections)
# Store matches between trackings and detections
matches = []
# Process each camera one by one
for camera_id, camera_detections in detection_by_camera.items():
# Calculate affinity matrix for this camera only
camera_affinity = calculate_camera_affinity_matrix(
trackings,
camera_detections,
w_2d=w_2d,
alpha_2d=alpha_2d,
w_3d=w_3d,
alpha_3d=alpha_3d,
lambda_a=lambda_a,
) )
# Apply Hungarian algorithm for this camera only camera = camera_detections[0].camera # shared camera object
tracking_indices, detection_indices = linear_sum_assignment( cam_w, cam_h = map(int, camera.params.image_size)
camera_affinity, maximize=True cam_center = camera.params.location # (3,)
# ---------- Pack tracking data into JAX arrays -------------------------
# (T, J, 3)
track_kps_3d = jnp.stack([trk.keypoints for trk in trackings])
# (T, 3) velocity zero if None
velocities = jnp.stack(
[
(
trk.velocity
if trk.velocity is not None
else jnp.zeros(3, dtype=jnp.float32)
)
for trk in trackings
]
) )
tracking_indices = cast(Sequence[int], tracking_indices)
detection_indices = cast(Sequence[int], detection_indices)
# Add matches to result # (T,) last update timestamps (float seconds)
for t_idx, d_idx in zip(tracking_indices, detection_indices): track_last_ts = jnp.array(
# Skip matches with zero or negative affinity [trk.last_active_timestamp.timestamp() for trk in trackings]
if camera_affinity[t_idx, d_idx] <= 0: )
continue
matches.append((t_idx, camera_detections[d_idx])) # Pre-project 3-D tracking points into 2-D for *this* camera (T, J, 2)
track_proj_2d = jax.vmap(camera.project)(track_kps_3d)
return matches # ---------- Pack detection data ----------------------------------------
# (D, J, 2)
det_kps_2d = jnp.stack([det.keypoints for det in camera_detections])
# (D,) detection timestamps (float seconds)
det_ts = jnp.array([det.timestamp.timestamp() for det in camera_detections])
# Back-project detection 2-D points to the z=0 plane in world coords (D, J, 3)
det_backproj_3d = camera.unproject_points_to_z_plane(det_kps_2d, z=0.0)
# ---------- Broadcast / compute pair-wise quantities --------------------
# Time differences Δt (T, D) always non-negative because detections are newer
delta_t = jnp.maximum(det_ts[None, :] - track_last_ts[:, None], 0.0)
# ---------- 2-D affinity --------------------------------------------------
# Normalise 2-D points by image size (already handled in helper but easier here)
track_proj_norm = track_proj_2d / jnp.array([cam_w, cam_h]) # (T, J, 2)
det_kps_norm = det_kps_2d / jnp.array([cam_w, cam_h]) # (D, J, 2)
# (T, D, J) Euclidean distances in normalised image space
dist_2d = jnp.linalg.norm(
track_proj_norm[:, None, :, :] - det_kps_norm[None, :, :, :],
axis=-1,
)
# (T, D, 1) for broadcasting with J dimension
delta_t_exp = delta_t[:, :, None]
affinity_2d_per_kp = (
w_2d
* (1.0 - dist_2d / (alpha_2d * jnp.clip(delta_t_exp, a_min=1e-6)))
* jnp.exp(-lambda_a * delta_t_exp)
)
affinity_2d = jnp.sum(affinity_2d_per_kp, axis=-1) # (T, D)
# ---------- 3-D affinity --------------------------------------------------
# Predict 3-D pose at detection time for each (T, D) pair (T, D, J, 3)
predicted_pose = (
track_kps_3d[:, None, :, :]
+ velocities[:, None, None, :] * delta_t_exp[..., None]
)
# Camera ray for each detection/keypoint (1, D, J, 3)
line_vec = det_backproj_3d[None, :, :, :] - cam_center # broadcast (T, D, J, 3)
# Vector from camera centre to predicted point (T, D, J, 3)
vec_cam_to_pred = cam_center - predicted_pose
# Cross-product norm and distance
cross_prod = jnp.cross(line_vec, vec_cam_to_pred)
numer = jnp.linalg.norm(cross_prod, axis=-1) # (T, D, J)
denom = jnp.linalg.norm(line_vec, axis=-1) # (1, D, J) broadcast automatically
dist_3d = numer / jnp.clip(denom, a_min=1e-6)
affinity_3d_per_kp = (
w_3d * (1.0 - dist_3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_exp)
)
affinity_3d = jnp.sum(affinity_3d_per_kp, axis=-1) # (T, D)
# ---------- Final affinity ----------------------------------------------
affinity_total = affinity_2d + affinity_3d # (T, D)
return affinity_total
# %% # %%
@ -1028,10 +1054,11 @@ ALPHA_3D = 1.0
trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id) trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
unmatched_detections = shallow_copy(next_group) unmatched_detections = shallow_copy(next_group)
camera_detections = classify_by_camera(unmatched_detections)
affinity, detection_by_camera = calculate_affinity_matrix( affinity = calculate_camera_affinity_matrix(
trackings, trackings,
unmatched_detections, next(iter(camera_detections.values())),
w_2d=W_2D, w_2d=W_2D,
alpha_2d=ALPHA_2D, alpha_2d=ALPHA_2D,
w_3d=W_3D, w_3d=W_3D,
@ -1041,23 +1068,6 @@ affinity, detection_by_camera = calculate_affinity_matrix(
display(affinity) display(affinity)
# %%
T = TypeVar("T")
def flatten_values(
d: Mapping[Any, Sequence[T]],
) -> list[T]:
"""
Flatten a dictionary of sequences into a single list of values.
"""
return [v for vs in d.values() for v in vs]
detections_sorted = flatten_values(detection_by_camera)
display(detections_sorted)
display(detection_by_camera)
# %% # %%
# Perform Hungarian algorithm for assignment for each camera # Perform Hungarian algorithm for assignment for each camera
indices_T, indices_D = linear_sum_assignment(affinity, maximize=True) indices_T, indices_D = linear_sum_assignment(affinity, maximize=True)