revert
@@ -227,11 +227,13 @@ def project(
         # Fall back to normalized coordinates if image_size not provided
         valid = jnp.all(p2d >= 0, axis=1) & jnp.all(p2d < 1, axis=1)

-        # only valid points need distortion
-        if jnp.any(valid):
-            valid_p2d = p2d[valid]
-            distorted_valid = distortion(valid_p2d, K, dist_coeffs)
-            p2d = p2d.at[valid].set(distorted_valid)
+        # Distort *all* points, then blend results using `where` to keep
+        # numerical traces inside JAX – this avoids Python ``if`` with a traced
+        # value (which triggers TracerBoolConversionError when the function is
+        # vmapped/jitted).
+        distorted_all = distortion(p2d, K, dist_coeffs)
+        # Broadcast the valid mask over the last (x,y) dimension
+        p2d = jnp.where(valid[:, None], distorted_all, p2d)
     elif dist_coeffs is None and K is None:
         pass
     else:
@@ -239,7 +241,7 @@ def project(
             "dist_coeffs and K must be provided together to compute distortion"
         )

-    return jnp.squeeze(p2d)
+    return p2d  # type: ignore


 @jaxtyped(typechecker=beartype)
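For context on the change above: a Python `if` on a traced boolean is exactly what raises TracerBoolConversionError once `project` is wrapped in `jax.vmap` or `jax.jit`. Below is a minimal, self-contained sketch of the compute-both-then-select pattern the new code uses; `distort_sketch` is a hypothetical stand-in for the real `distortion` helper, not the project's implementation.

import jax
import jax.numpy as jnp

def distort_sketch(p2d):
    # Placeholder "distortion": any elementwise transform works for the demo.
    return p2d * 1.01

@jax.jit
def project_sketch(p2d):
    # Mask of points inside the normalized [0, 1) image area.
    valid = jnp.all(p2d >= 0, axis=1) & jnp.all(p2d < 1, axis=1)
    # `if jnp.any(valid):` would fail here: `valid` is a tracer, not a bool.
    distorted = distort_sketch(p2d)
    # Compute both branches, then select per row with a broadcast mask.
    return jnp.where(valid[:, None], distorted, p2d)

points = jnp.array([[0.5, 0.5], [1.5, -0.2]])
print(project_sketch(points))  # row 0 distorted, row 1 passed through

The trade-off is that distortion now runs on invalid points too; under `jit` this is usually cheaper than data-dependent control flow.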
playground.py (202 changes)
@@ -46,6 +46,7 @@ from jaxtyping import Array, Float, Num, jaxtyped
 from matplotlib import pyplot as plt
 from numpy.typing import ArrayLike
 from scipy.optimize import linear_sum_assignment
+from functools import partial, reduce
 from scipy.spatial.transform import Rotation as R
 from typing_extensions import deprecated

@@ -65,6 +66,7 @@ NDArray: TypeAlias = np.ndarray
 # %%
 DATASET_PATH = Path("samples") / "04_02"
 AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet")
+DELTA_T_MIN = timedelta(milliseconds=10)
 display(AK_CAMERA_DATASET)


@@ -388,6 +390,16 @@ def flatten_values(
     return [v for vs in d.values() for v in vs]


+def flatten_values_len(
+    d: Mapping[Any, Sequence[T]],
+) -> int:
+    """
+    Return the total number of values across all sequences in the dictionary.
+    """
+    val = reduce(lambda acc, xs: acc + len(xs), d.values(), 0)
+    return val
+
+
 # %%
 WIDTH = 2560
 HEIGHT = 1440
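The new `flatten_values_len` counts values across all groups without materializing the flattened list. A small usage sketch with made-up per-camera data; `sum(len(xs) for xs in d.values())` would be an equivalent one-liner:

from functools import reduce
from typing import Any, Mapping, Sequence, TypeVar

T = TypeVar("T")

def flatten_values_len(d: Mapping[Any, Sequence[T]]) -> int:
    # Fold the group lengths into a single running total.
    return reduce(lambda acc, xs: acc + len(xs), d.values(), 0)

by_camera = {"AE_08": [1, 2, 3], "AE_09": [4, 5]}  # hypothetical detections
assert flatten_values_len(by_camera) == 5
assert flatten_values_len(by_camera) == sum(len(xs) for xs in by_camera.values())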
@@ -676,6 +688,10 @@ def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
         Array of perpendicular distances for each keypoint
     """
     camera = detection.camera
+    assert detection.timestamp >= tracking.last_active_timestamp
+    delta_t_raw = detection.timestamp - tracking.last_active_timestamp
+    # Clamp delta_t to avoid division-by-zero / exploding affinity.
+    delta_t = max(delta_t_raw, DELTA_T_MIN)
     delta_t_s = delta_t.total_seconds()
     predicted_pose = predict_pose_3d(tracking, delta_t_s)

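The clamp works because `timedelta` values are totally ordered, so the built-in `max` picks the larger interval directly. A minimal sketch of the behaviour introduced above, assuming only the `DELTA_T_MIN` constant from this commit:

from datetime import timedelta

DELTA_T_MIN = timedelta(milliseconds=10)

def clamp_delta_t(delta_t_raw: timedelta) -> timedelta:
    # A near-zero delta_t would blow up terms like dist / (alpha * delta_t),
    # so never let it drop below DELTA_T_MIN.
    return max(delta_t_raw, DELTA_T_MIN)

assert clamp_delta_t(timedelta(0)).total_seconds() == 0.01
assert clamp_delta_t(timedelta(seconds=1)).total_seconds() == 1.0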
@@ -769,7 +785,9 @@ def calculate_tracking_detection_affinity(
         Combined affinity score
     """
     camera = detection.camera
-    delta_t = detection.timestamp - tracking.last_active_timestamp
+    delta_t_raw = detection.timestamp - tracking.last_active_timestamp
+    # Clamp delta_t to avoid division-by-zero / exploding affinity.
+    delta_t = max(delta_t_raw, DELTA_T_MIN)

     # Calculate 2D affinity
     tracking_2d_projection = camera.project(tracking.keypoints)
@@ -811,7 +829,7 @@ def calculate_tracking_detection_affinity(
 @beartype
 def calculate_affinity_matrix(
     trackings: Sequence[Tracking],
-    detections: Sequence[Detection],
+    detections: Sequence[Detection] | OrderedDict[CameraID, list[Detection]],
     w_2d: float,
     alpha_2d: float,
     w_3d: float,
@@ -863,6 +881,11 @@ def calculate_affinity_matrix(
     through the detection_by_camera dictionary, which is returned alongside
     the matrix to maintain this relationship.
     """
+    if isinstance(detections, OrderedDict):
+        D = flatten_values_len(detections)
+        affinity = jnp.zeros((len(trackings), D))
+        detection_by_camera = detections
+    else:
         affinity = jnp.zeros((len(trackings), len(detections)))
         detection_by_camera = classify_by_camera(detections)

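With the widened signature, callers can pass either a flat sequence or detections already grouped by camera, and the grouped form skips re-classification. A minimal sketch of that dual-input dispatch with stand-in types (not the real `Detection`/`CameraID`):

from collections import OrderedDict
from typing import Sequence

def detection_count(
    detections: Sequence[object] | OrderedDict[str, list[object]],
) -> int:
    if isinstance(detections, OrderedDict):
        # Already grouped by camera: count across all groups.
        return sum(len(xs) for xs in detections.values())
    return len(detections)

grouped = OrderedDict([("AE_08", [1, 2]), ("AE_09", [3])])
assert detection_count(grouped) == 3
assert detection_count([1, 2, 3, 4]) == 4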
@@ -896,31 +919,28 @@ def calculate_camera_affinity_matrix(
     lambda_a: float,
 ) -> Float[Array, "T D"]:
     """
-    Vectorized version (with JAX) that computes the affinity matrix between a set
-    of *trackings* and *detections* coming from **one** camera.
+    Calculate an affinity matrix between trackings and detections from a single camera.

-    The whole computation is done with JAX array operations and `vmap` – no
-    explicit Python ``for``-loops over the (T, D) pairs. This makes the routine
-    fully parallelisable on CPU/GPU/TPU without any extra `jit` compilation.
+    This follows the iterative camera-by-camera approach from the paper
+    "Cross-View Tracking for Multi-Human 3D Pose Estimation at over 100 FPS".
+    Instead of creating one large matrix for all cameras, this creates
+    a separate matrix for each camera, which can be processed independently.

-    Args
-    -----
-    trackings : Sequence[Tracking]
-        Existing 3-D track states (length = T)
-    camera_detections : Sequence[Detection]
-        Detections from *a single* camera (length = D). All detections **must**
-        share the same ``detection.camera`` instance.
-    w_2d, alpha_2d, w_3d, alpha_3d, lambda_a : float
-        Hyper-parameters exactly as defined in the paper (and earlier helper
-        functions).
+    Args:
+        trackings: Sequence of tracking objects
+        camera_detections: Sequence of detection objects, from the same camera
+        w_2d: Weight for 2D affinity
+        alpha_2d: Normalization factor for 2D distance
+        w_3d: Weight for 3D affinity
+        alpha_3d: Normalization factor for 3D distance
+        lambda_a: Decay rate for time difference

-    Returns
-    -------
-    affinity : jnp.ndarray (T x D)
-        Affinity matrix between each tracking (row) and detection (column).
+    Returns:
+        Affinity matrix of shape (T, D) where:
+        - T = number of trackings (rows)
+        - D = number of detections from this specific camera (columns)

-    Matrix Layout
-    -------
+    Matrix Layout:
     The affinity matrix for a single camera has shape (T, D), where:
     - T = number of trackings (rows)
     - D = number of detections from this camera (columns)
@@ -941,107 +961,31 @@ def calculate_camera_affinity_matrix(
     computed using both 2D and 3D geometric correspondences.
     """

-    # ---------- Safety checks & early exits --------------------------------
-    if len(trackings) == 0 or len(camera_detections) == 0:
-        return jnp.zeros((len(trackings), len(camera_detections)))  # pragma: no cover
+    def verify_all_detection_from_same_camera(detections: Sequence[Detection]):
+        if not detections:
+            return True
+        camera_id = next(iter(detections)).camera.id
+        return all(map(lambda d: d.camera.id == camera_id, detections))

-    # Ensure all detections come from the *same* camera
-    cam_id_ref = camera_detections[0].camera.id
-    if any(det.camera.id != cam_id_ref for det in camera_detections):
-        raise ValueError(
-            "All detections given to calculate_camera_affinity_matrix must come from the same camera."
-        )
+    if not verify_all_detection_from_same_camera(camera_detections):
+        raise ValueError("All detections must be from the same camera")

-    camera = camera_detections[0].camera  # shared camera object
-    cam_w, cam_h = map(int, camera.params.image_size)
-    cam_center = camera.params.location  # (3,)
-
-    # ---------- Pack tracking data into JAX arrays -------------------------
-    # (T, J, 3)
-    track_kps_3d = jnp.stack([trk.keypoints for trk in trackings])
-
-    # (T, 3) velocity – zero if None
-    velocities = jnp.stack(
-        [
-            (
-                trk.velocity
-                if trk.velocity is not None
-                else jnp.zeros(3, dtype=jnp.float32)
-            )
-            for trk in trackings
-        ]
-    )
-
-    # (T,) last update timestamps (float seconds)
-    track_last_ts = jnp.array(
-        [trk.last_active_timestamp.timestamp() for trk in trackings]
-    )
-
-    # Pre-project 3-D tracking points into 2-D for *this* camera – (T, J, 2)
-    track_proj_2d = jax.vmap(camera.project)(track_kps_3d)
-
-    # ---------- Pack detection data ----------------------------------------
-    # (D, J, 2)
-    det_kps_2d = jnp.stack([det.keypoints for det in camera_detections])
-
-    # (D,) detection timestamps (float seconds)
-    det_ts = jnp.array([det.timestamp.timestamp() for det in camera_detections])
-
-    # Back-project detection 2-D points to the z=0 plane in world coords – (D, J, 3)
-    det_backproj_3d = camera.unproject_points_to_z_plane(det_kps_2d, z=0.0)
-
-    # ---------- Broadcast / compute pair-wise quantities --------------------
-    # Time differences Δt (T, D) – always non-negative because detections are newer
-    delta_t = jnp.maximum(det_ts[None, :] - track_last_ts[:, None], 0.0)
-
-    # ---------- 2-D affinity --------------------------------------------------
-    # Normalise 2-D points by image size (already handled in helper but easier here)
-    track_proj_norm = track_proj_2d / jnp.array([cam_w, cam_h])  # (T, J, 2)
-    det_kps_norm = det_kps_2d / jnp.array([cam_w, cam_h])  # (D, J, 2)
-
-    # (T, D, J) Euclidean distances in normalised image space
-    dist_2d = jnp.linalg.norm(
-        track_proj_norm[:, None, :, :] - det_kps_norm[None, :, :, :],
-        axis=-1,
-    )
-
-    # (T, D, 1) for broadcasting with J dimension
-    delta_t_exp = delta_t[:, :, None]
-
-    affinity_2d_per_kp = (
-        w_2d
-        * (1.0 - dist_2d / (alpha_2d * jnp.clip(delta_t_exp, a_min=1e-6)))
-        * jnp.exp(-lambda_a * delta_t_exp)
-    )
-    affinity_2d = jnp.sum(affinity_2d_per_kp, axis=-1)  # (T, D)
-
-    # ---------- 3-D affinity --------------------------------------------------
-    # Predict 3-D pose at detection time for each (T, D) pair – (T, D, J, 3)
-    predicted_pose = (
-        track_kps_3d[:, None, :, :]
-        + velocities[:, None, None, :] * delta_t_exp[..., None]
-    )
-
-    # Camera ray for each detection/keypoint – (1, D, J, 3)
-    line_vec = det_backproj_3d[None, :, :, :] - cam_center  # broadcast (T, D, J, 3)
-
-    # Vector from camera centre to predicted point – (T, D, J, 3)
-    vec_cam_to_pred = cam_center - predicted_pose
-
-    # Cross-product norm and distance
-    cross_prod = jnp.cross(line_vec, vec_cam_to_pred)
-    numer = jnp.linalg.norm(cross_prod, axis=-1)  # (T, D, J)
-    denom = jnp.linalg.norm(line_vec, axis=-1)  # (1, D, J) broadcast automatically
-    dist_3d = numer / jnp.clip(denom, a_min=1e-6)
-
-    affinity_3d_per_kp = (
-        w_3d * (1.0 - dist_3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_exp)
-    )
-    affinity_3d = jnp.sum(affinity_3d_per_kp, axis=-1)  # (T, D)
-
-    # ---------- Final affinity ----------------------------------------------
-    affinity_total = affinity_2d + affinity_3d  # (T, D)
-    return affinity_total
+    affinity = jnp.zeros((len(trackings), len(camera_detections)))
+
+    for i, tracking in enumerate(trackings):
+        for j, det in enumerate(camera_detections):
+            affinity_value = calculate_tracking_detection_affinity(
+                tracking,
+                det,
+                w_2d=w_2d,
+                alpha_2d=alpha_2d,
+                w_3d=w_3d,
+                alpha_3d=alpha_3d,
+                lambda_a=lambda_a,
+            )
+            affinity = affinity.at[i, j].set(affinity_value)
+
+    return affinity


 # %%
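The restored loop fills the matrix through JAX's functional index-update API, since `jnp` arrays are immutable. A tiny sketch of `.at[i, j].set`, which returns a new array rather than mutating in place:

import jax.numpy as jnp

affinity = jnp.zeros((2, 2))
updated = affinity.at[0, 1].set(0.7)  # new array; `affinity` is unchanged
print(updated)
# [[0.  0.7]
#  [0.  0. ]]

Under `jit`, XLA can lower such updates to in-place writes, so the loop is less wasteful than it looks.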
@@ -1056,17 +1000,31 @@ trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
 unmatched_detections = shallow_copy(next_group)
 camera_detections = classify_by_camera(unmatched_detections)

+camera_detections_next_batch = camera_detections["AE_08"]
 affinity = calculate_camera_affinity_matrix(
     trackings,
-    next(iter(camera_detections.values())),
+    camera_detections_next_batch,
     w_2d=W_2D,
     alpha_2d=ALPHA_2D,
     w_3d=W_3D,
     alpha_3d=ALPHA_3D,
     lambda_a=LAMBDA_A,
 )
+display(camera_detections_next_batch)
 display(affinity)

+affinity_naive, _ = calculate_affinity_matrix(
+    trackings,
+    camera_detections,
+    w_2d=W_2D,
+    alpha_2d=ALPHA_2D,
+    w_3d=W_3D,
+    alpha_3d=ALPHA_3D,
+    lambda_a=LAMBDA_A,
+)
+display(camera_detections)
+display(affinity_naive)
+

 # %%
 # Perform Hungarian algorithm for assignment for each camera
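The cell that follows assigns detections to trackings per camera. A minimal sketch of that step with `scipy.optimize.linear_sum_assignment` (already imported in playground.py) on a made-up affinity matrix; `maximize=True` is needed because the solver minimizes cost by default:

import numpy as np
from scipy.optimize import linear_sum_assignment

# Hypothetical (T=2 trackings) x (D=3 detections) affinity matrix.
affinity = np.array([
    [0.9, 0.1, 0.3],
    [0.2, 0.8, 0.4],
])
rows, cols = linear_sum_assignment(affinity, maximize=True)
for t, d in zip(rows, cols):
    print(f"tracking {t} -> detection {d} (affinity {affinity[t, d]:.2f})")
# tracking 0 -> detection 0 (affinity 0.90)
# tracking 1 -> detection 1 (affinity 0.80)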