forked from HQU-gxy/CVTH3PE
2025-04-28 18:01:24 +08:00
parent 7ee4002567
commit b3ed20296a
2 changed files with 103 additions and 143 deletions

View File

@@ -227,11 +227,13 @@ def project(
         # Fall back to normalized coordinates if image_size not provided
         valid = jnp.all(p2d >= 0, axis=1) & jnp.all(p2d < 1, axis=1)
-        # only valid points need distortion
-        if jnp.any(valid):
-            valid_p2d = p2d[valid]
-            distorted_valid = distortion(valid_p2d, K, dist_coeffs)
-            p2d = p2d.at[valid].set(distorted_valid)
+        # Distort *all* points, then blend results using `where` to keep the
+        # computation traced inside JAX; this avoids a Python ``if`` on a
+        # traced value (which triggers TracerBoolConversionError when the
+        # function is vmapped/jitted).
+        distorted_all = distortion(p2d, K, dist_coeffs)
+        # Broadcast the valid mask over the last (x, y) dimension
+        p2d = jnp.where(valid[:, None], distorted_all, p2d)
     elif dist_coeffs is None and K is None:
         pass
     else:
@@ -239,7 +241,7 @@ def project(
             "dist_coeffs and K must be provided together to compute distortion"
         )
-    return jnp.squeeze(p2d)
+    return p2d  # type: ignore


 @jaxtyped(typechecker=beartype)
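The `jnp.where` pattern above is the standard way to keep a data-dependent branch traceable. A minimal standalone sketch (names here are illustrative, not from the repository): compute the branch for every element, then select with the mask; a Python `if` on the traced mask would fail under `jit`/`vmap`.

    import jax
    import jax.numpy as jnp

    def mask_blend(p2d, valid):
        # Stand-in "distortion": computed unconditionally for *all* points...
        distorted_all = p2d * 1.1
        # ...then kept only where the mask is True, original values elsewhere.
        return jnp.where(valid[:, None], distorted_all, p2d)

    p2d = jnp.array([[0.2, 0.3], [1.5, 0.1]])
    valid = jnp.all((p2d >= 0) & (p2d < 1), axis=1)
    print(jax.jit(mask_blend)(p2d, valid))
    # Writing `if jnp.any(valid):` inside the jitted function would instead raise
    # TracerBoolConversionError, since a traced mask has no concrete boolean value.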

View File

@@ -46,6 +46,7 @@ from jaxtyping import Array, Float, Num, jaxtyped
 from matplotlib import pyplot as plt
 from numpy.typing import ArrayLike
 from scipy.optimize import linear_sum_assignment
+from functools import partial, reduce
 from scipy.spatial.transform import Rotation as R
 from typing_extensions import deprecated
@@ -65,6 +66,7 @@ NDArray: TypeAlias = np.ndarray
 # %%
 DATASET_PATH = Path("samples") / "04_02"
 AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet")
+DELTA_T_MIN = timedelta(milliseconds=10)
 display(AK_CAMERA_DATASET)
@@ -388,6 +390,16 @@ def flatten_values(
     return [v for vs in d.values() for v in vs]


+def flatten_values_len(
+    d: Mapping[Any, Sequence[T]],
+) -> int:
+    """
+    Return the total number of values across all sequences in the dictionary.
+    """
+    val = reduce(lambda acc, xs: acc + len(xs), d.values(), 0)
+    return val


 # %%
 WIDTH = 2560
 HEIGHT = 1440
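A quick toy check of the new helper (made-up data): the count it returns should equal the length of the flattened value list.

    from collections import OrderedDict

    toy = OrderedDict([("cam_a", [1, 2, 3]), ("cam_b", [4])])
    assert flatten_values_len(toy) == len(flatten_values(toy)) == 4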
@@ -676,6 +688,10 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
         Array of perpendicular distances for each keypoint
     """
     camera = detection.camera
+    assert detection.timestamp >= tracking.last_active_timestamp
+    delta_t_raw = detection.timestamp - tracking.last_active_timestamp
+    # Clamp delta_t to avoid division-by-zero / exploding affinity.
+    delta_t = max(delta_t_raw, DELTA_T_MIN)
     delta_t_s = delta_t.total_seconds()

     predicted_pose = predict_pose_3d(tracking, delta_t_s)
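A standalone sketch of why the clamp matters (the helper name is hypothetical): the 2D affinity divides the normalized distance by alpha_2d * delta_t, so a zero time gap between a detection and the tracking's last update would blow up the score.

    from datetime import datetime, timedelta

    DELTA_T_MIN = timedelta(milliseconds=10)

    def clamped_delta_seconds(detection_ts: datetime, last_active_ts: datetime) -> float:
        # Never let the gap reach zero; keep at least DELTA_T_MIN between updates.
        return max(detection_ts - last_active_ts, DELTA_T_MIN).total_seconds()

    now = datetime(2025, 4, 28, 18, 1, 24)
    print(clamped_delta_seconds(now, now))  # 0.01 rather than 0.0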
@@ -769,7 +785,9 @@ def calculate_tracking_detection_affinity(
         Combined affinity score
     """
     camera = detection.camera
-    delta_t = detection.timestamp - tracking.last_active_timestamp
+    delta_t_raw = detection.timestamp - tracking.last_active_timestamp
+    # Clamp delta_t to avoid division-by-zero / exploding affinity.
+    delta_t = max(delta_t_raw, DELTA_T_MIN)

     # Calculate 2D affinity
     tracking_2d_projection = camera.project(tracking.keypoints)
@@ -811,7 +829,7 @@ def calculate_tracking_detection_affinity(
 @beartype
 def calculate_affinity_matrix(
     trackings: Sequence[Tracking],
-    detections: Sequence[Detection],
+    detections: Sequence[Detection] | OrderedDict[CameraID, list[Detection]],
     w_2d: float,
     alpha_2d: float,
     w_3d: float,
@@ -863,6 +881,11 @@ def calculate_affinity_matrix(
     through the detection_by_camera dictionary, which is returned alongside
     the matrix to maintain this relationship.
     """
-    affinity = jnp.zeros((len(trackings), len(detections)))
-    detection_by_camera = classify_by_camera(detections)
+    if isinstance(detections, OrderedDict):
+        D = flatten_values_len(detections)
+        affinity = jnp.zeros((len(trackings), D))
+        detection_by_camera = detections
+    else:
+        affinity = jnp.zeros((len(trackings), len(detections)))
+        detection_by_camera = classify_by_camera(detections)
@@ -896,31 +919,28 @@ def calculate_camera_affinity_matrix(
     lambda_a: float,
 ) -> Float[Array, "T D"]:
     """
-    Vectorized version (with JAX) that computes the affinity matrix between a set
-    of *trackings* and *detections* coming from **one** camera.
-
-    The whole computation is done with JAX array operations and `vmap`; no
-    explicit Python ``for``-loops over the (T, D) pairs. This makes the routine
-    fully parallelisable on CPU/GPU/TPU without any extra `jit` compilation.
-
-    Args
-    -----
-    trackings : Sequence[Tracking]
-        Existing 3-D track states (length = T)
-    camera_detections : Sequence[Detection]
-        Detections from *a single* camera (length = D). All detections **must**
-        share the same ``detection.camera`` instance.
-    w_2d, alpha_2d, w_3d, alpha_3d, lambda_a : float
-        Hyper-parameters exactly as defined in the paper (and earlier helper
-        functions).
-
-    Returns
-    -------
-    affinity : jnp.ndarray (T x D)
-        Affinity matrix between each tracking (row) and detection (column).
-
-    Matrix Layout
-    -------
+    Calculate an affinity matrix between trackings and detections from a single camera.
+
+    This follows the iterative camera-by-camera approach from the paper
+    "Cross-View Tracking for Multi-Human 3D Pose Estimation at over 100 FPS".
+    Instead of creating one large matrix for all cameras, this creates
+    a separate matrix for each camera, which can be processed independently.
+
+    Args:
+        trackings: Sequence of tracking objects
+        camera_detections: Sequence of detection objects, all from the same camera
+        w_2d: Weight for 2D affinity
+        alpha_2d: Normalization factor for 2D distance
+        w_3d: Weight for 3D affinity
+        alpha_3d: Normalization factor for 3D distance
+        lambda_a: Decay rate for time difference
+
+    Returns:
+        Affinity matrix of shape (T, D) where:
+        - T = number of trackings (rows)
+        - D = number of detections from this specific camera (columns)
+
+    Matrix Layout:
     The affinity matrix for a single camera has shape (T, D), where:
     - T = number of trackings (rows)
     - D = number of detections from this camera (columns)
@@ -941,107 +961,31 @@ def calculate_camera_affinity_matrix(
     computed using both 2D and 3D geometric correspondences.
     """
-    # ---------- Safety checks & early exits --------------------------------
-    if len(trackings) == 0 or len(camera_detections) == 0:
-        return jnp.zeros((len(trackings), len(camera_detections)))  # pragma: no cover
-
-    # Ensure all detections come from the *same* camera
-    cam_id_ref = camera_detections[0].camera.id
-    if any(det.camera.id != cam_id_ref for det in camera_detections):
-        raise ValueError(
-            "All detections given to calculate_camera_affinity_matrix must come from the same camera."
-        )
-
-    camera = camera_detections[0].camera  # shared camera object
-    cam_w, cam_h = map(int, camera.params.image_size)
-    cam_center = camera.params.location  # (3,)
-
-    # ---------- Pack tracking data into JAX arrays -------------------------
-    # (T, J, 3)
-    track_kps_3d = jnp.stack([trk.keypoints for trk in trackings])
-    # (T, 3) velocity, zero if None
-    velocities = jnp.stack(
-        [
-            (
-                trk.velocity
-                if trk.velocity is not None
-                else jnp.zeros(3, dtype=jnp.float32)
-            )
-            for trk in trackings
-        ]
-    )
-    # (T,) last update timestamps (float seconds)
-    track_last_ts = jnp.array(
-        [trk.last_active_timestamp.timestamp() for trk in trackings]
-    )
-
-    # Pre-project 3-D tracking points into 2-D for *this* camera (T, J, 2)
-    track_proj_2d = jax.vmap(camera.project)(track_kps_3d)
-
-    # ---------- Pack detection data ----------------------------------------
-    # (D, J, 2)
-    det_kps_2d = jnp.stack([det.keypoints for det in camera_detections])
-    # (D,) detection timestamps (float seconds)
-    det_ts = jnp.array([det.timestamp.timestamp() for det in camera_detections])
-    # Back-project detection 2-D points to the z=0 plane in world coords (D, J, 3)
-    det_backproj_3d = camera.unproject_points_to_z_plane(det_kps_2d, z=0.0)
-
-    # ---------- Broadcast / compute pair-wise quantities --------------------
-    # Time differences Δt (T, D), always non-negative because detections are newer
-    delta_t = jnp.maximum(det_ts[None, :] - track_last_ts[:, None], 0.0)
-
-    # ---------- 2-D affinity ------------------------------------------------
-    # Normalise 2-D points by image size (already handled in helper but easier here)
-    track_proj_norm = track_proj_2d / jnp.array([cam_w, cam_h])  # (T, J, 2)
-    det_kps_norm = det_kps_2d / jnp.array([cam_w, cam_h])  # (D, J, 2)
-    # (T, D, J) Euclidean distances in normalised image space
-    dist_2d = jnp.linalg.norm(
-        track_proj_norm[:, None, :, :] - det_kps_norm[None, :, :, :],
-        axis=-1,
-    )
-    # (T, D, 1) for broadcasting with J dimension
-    delta_t_exp = delta_t[:, :, None]
-    affinity_2d_per_kp = (
-        w_2d
-        * (1.0 - dist_2d / (alpha_2d * jnp.clip(delta_t_exp, a_min=1e-6)))
-        * jnp.exp(-lambda_a * delta_t_exp)
-    )
-    affinity_2d = jnp.sum(affinity_2d_per_kp, axis=-1)  # (T, D)
-
-    # ---------- 3-D affinity ------------------------------------------------
-    # Predict 3-D pose at detection time for each (T, D) pair (T, D, J, 3)
-    predicted_pose = (
-        track_kps_3d[:, None, :, :]
-        + velocities[:, None, None, :] * delta_t_exp[..., None]
-    )
-    # Camera ray for each detection/keypoint (1, D, J, 3)
-    line_vec = det_backproj_3d[None, :, :, :] - cam_center  # broadcast (T, D, J, 3)
-    # Vector from camera centre to predicted point (T, D, J, 3)
-    vec_cam_to_pred = cam_center - predicted_pose
-    # Cross-product norm and distance
-    cross_prod = jnp.cross(line_vec, vec_cam_to_pred)
-    numer = jnp.linalg.norm(cross_prod, axis=-1)  # (T, D, J)
-    denom = jnp.linalg.norm(line_vec, axis=-1)  # (1, D, J), broadcast automatically
-    dist_3d = numer / jnp.clip(denom, a_min=1e-6)
-    affinity_3d_per_kp = (
-        w_3d * (1.0 - dist_3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_exp)
-    )
-    affinity_3d = jnp.sum(affinity_3d_per_kp, axis=-1)  # (T, D)
-
-    # ---------- Final affinity ----------------------------------------------
-    affinity_total = affinity_2d + affinity_3d  # (T, D)
-
-    return affinity_total
+    def verify_all_detection_from_same_camera(detections: Sequence[Detection]):
+        if not detections:
+            return True
+        camera_id = next(iter(detections)).camera.id
+        return all(map(lambda d: d.camera.id == camera_id, detections))
+
+    if not verify_all_detection_from_same_camera(camera_detections):
+        raise ValueError("All detections must be from the same camera")
+
+    affinity = jnp.zeros((len(trackings), len(camera_detections)))
+
+    for i, tracking in enumerate(trackings):
+        for j, det in enumerate(camera_detections):
+            affinity_value = calculate_tracking_detection_affinity(
+                tracking,
+                det,
+                w_2d=w_2d,
+                alpha_2d=alpha_2d,
+                w_3d=w_3d,
+                alpha_3d=alpha_3d,
+                lambda_a=lambda_a,
+            )
+            affinity = affinity.at[i, j].set(affinity_value)
+
+    return affinity

 # %%
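The nested loop delegates each (tracking, detection) pair to calculate_tracking_detection_affinity, whose score combines the same 2D and 3D terms that the removed vectorized code spelled out. Roughly, per keypoint (a sketch that mirrors the removed formulas above, not the exact helper):

    import math

    def per_keypoint_affinity(dist_2d_norm, dist_3d, delta_t_s,
                              w_2d, alpha_2d, w_3d, alpha_3d, lambda_a):
        # 2D term: normalized image-space distance, scaled by the elapsed time.
        a_2d = w_2d * (1.0 - dist_2d_norm / (alpha_2d * delta_t_s)) * math.exp(-lambda_a * delta_t_s)
        # 3D term: perpendicular distance from the predicted joint to the camera ray.
        a_3d = w_3d * (1.0 - dist_3d / alpha_3d) * math.exp(-lambda_a * delta_t_s)
        return a_2d + a_3d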
@@ -1056,17 +1000,31 @@ trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
 unmatched_detections = shallow_copy(next_group)
 camera_detections = classify_by_camera(unmatched_detections)
+camera_detections_next_batch = camera_detections["AE_08"]
+
 affinity = calculate_camera_affinity_matrix(
     trackings,
-    next(iter(camera_detections.values())),
+    camera_detections_next_batch,
     w_2d=W_2D,
     alpha_2d=ALPHA_2D,
     w_3d=W_3D,
     alpha_3d=ALPHA_3D,
     lambda_a=LAMBDA_A,
 )
+display(camera_detections_next_batch)
 display(affinity)
+
+affinity_naive, _ = calculate_affinity_matrix(
+    trackings,
+    camera_detections,
+    w_2d=W_2D,
+    alpha_2d=ALPHA_2D,
+    w_3d=W_3D,
+    alpha_3d=ALPHA_3D,
+    lambda_a=LAMBDA_A,
+)
+display(camera_detections)
+display(affinity_naive)

 # %%
 # Perform Hungarian algorithm for assignment for each camera
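If the columns of calculate_affinity_matrix follow the camera-by-camera order of camera_detections (an assumption, not verified here), the single-camera matrix for "AE_08" should match the corresponding column block of affinity_naive:

    offset = 0
    for cam_id, dets in camera_detections.items():
        if cam_id == "AE_08":
            block = affinity_naive[:, offset : offset + len(dets)]
            assert jnp.allclose(affinity, block, atol=1e-5)
            break
        offset += len(dets)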
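For the per-camera assignment step named in the comment above, a toy sketch of feeding an affinity matrix to the already-imported linear_sum_assignment; maximize=True because larger affinity means a better match.

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    toy_affinity = np.array([
        [0.9, 0.1, 0.3],  # tracking 0 vs. 3 detections from one camera
        [0.2, 0.8, 0.4],  # tracking 1
    ])
    rows, cols = linear_sum_assignment(toy_affinity, maximize=True)
    matches = list(zip(rows.tolist(), cols.tolist()))  # [(0, 0), (1, 1)]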