# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.17.0
#   kernelspec:
#     display_name: .venv
#     language: python
#     name: python3
# ---


# %%
from collections import OrderedDict
from copy import copy as shallow_copy
from copy import deepcopy as deep_copy
from dataclasses import dataclass
from datetime import datetime, timedelta
from functools import partial, reduce
from itertools import chain
from pathlib import Path
from typing import (
    Any,
    Generator,
    Optional,
    TypeAlias,
    TypedDict,
    TypeVar,
    cast,
    overload,
)

import awkward as ak
import jax
import jax.numpy as jnp
import numpy as np
import orjson
from beartype import beartype
from beartype.typing import Mapping, Sequence
from cv2 import undistortPoints
from IPython.display import display
from jaxtyping import Array, Float, Num, jaxtyped
from matplotlib import pyplot as plt
from numpy.typing import ArrayLike
from optax.assignment import hungarian_algorithm as linear_sum_assignment
from scipy.spatial.transform import Rotation as R
from typing_extensions import deprecated

from app.camera import (
    Camera,
    CameraID,
    CameraParams,
    Detection,
    calculate_affinity_matrix_by_epipolar_constraint,
    classify_by_camera,
)
from app.solver._old import GLPKSolver
from app.tracking import AffinityResult, Tracking
from app.visualize.whole_body import visualize_whole_body

NDArray: TypeAlias = np.ndarray

# %%
DATASET_PATH = Path("samples") / "04_02"
AK_CAMERA_DATASET: ak.Array = ak.from_parquet(DATASET_PATH / "camera_params.parquet")
DELTA_T_MIN = timedelta(milliseconds=10)
display(AK_CAMERA_DATASET)


# %%
class Resolution(TypedDict):
    width: int
    height: int


class Intrinsic(TypedDict):
    camera_matrix: Num[Array, "3 3"]
    """
    The 3x3 intrinsic matrix K
    """
    distortion_coefficients: Num[Array, "N"]
    """
    distortion coefficients; usually 5 values (k1, k2, p1, p2, k3)
    """


class Extrinsic(TypedDict):
    rvec: Num[NDArray, "3"]
    tvec: Num[NDArray, "3"]


class ExternalCameraParams(TypedDict):
    name: str
    port: int
    intrinsic: Intrinsic
    extrinsic: Extrinsic
    resolution: Resolution


# %%
def read_dataset_by_port(port: int) -> ak.Array:
    path = DATASET_PATH / f"{port}.parquet"
    return ak.from_parquet(path)


KEYPOINT_DATASET = {
    int(p): read_dataset_by_port(p) for p in ak.to_numpy(AK_CAMERA_DATASET["port"])
}


# %%
class KeypointDataset(TypedDict):
    frame_index: int
    boxes: Num[NDArray, "N 4"]
    kps: Num[NDArray, "N J 2"]
    kps_scores: Num[NDArray, "N J"]


@jaxtyped(typechecker=beartype)
def to_transformation_matrix(
    rvec: Num[NDArray, "3"], tvec: Num[NDArray, "3"]
) -> Num[NDArray, "4 4"]:
    """
    Build a 4x4 homogeneous transformation matrix from a Rodrigues rotation
    vector and a translation vector.
    """
    res = np.eye(4)
    res[:3, :3] = R.from_rotvec(rvec).as_matrix()
    res[:3, 3] = tvec
    return res


@jaxtyped(typechecker=beartype)
def undistort_points(
    points: Num[NDArray, "M 2"],
    camera_matrix: Num[NDArray, "3 3"],
    dist_coeffs: Num[NDArray, "N"],
) -> Num[NDArray, "M 2"]:
    """
    Undistort 2D points; passing `P=K` keeps the result in pixel coordinates
    instead of normalized image coordinates.
    """
    K = camera_matrix
    dist = dist_coeffs
    res = undistortPoints(points, K, dist, P=K)  # type: ignore
    return res.reshape(-1, 2)


def from_camera_params(camera: ExternalCameraParams) -> Camera:
    rt = jnp.array(
        to_transformation_matrix(
            ak.to_numpy(camera["extrinsic"]["rvec"]),
            ak.to_numpy(camera["extrinsic"]["tvec"]),
        )
    )
    K = jnp.array(camera["intrinsic"]["camera_matrix"]).reshape(3, 3)
    dist_coeffs = jnp.array(camera["intrinsic"]["distortion_coefficients"])
    image_size = jnp.array(
        (camera["resolution"]["width"], camera["resolution"]["height"])
    )
    return Camera(
        id=camera["name"],
        params=CameraParams(
            K=K,
            Rt=rt,
            dist_coeffs=dist_coeffs,
            image_size=image_size,
        ),
    )


def preprocess_keypoint_dataset(
    dataset: Sequence[KeypointDataset],
    camera: Camera,
    fps: float,
    start_timestamp: datetime,
) -> Generator[Detection, None, None]:
    """
    Convert per-frame keypoint records into a stream of `Detection`s, assigning
    each frame a synthetic timestamp derived from `start_timestamp` and `fps`.
    """
    frame_interval_s = 1 / fps
    for el in dataset:
        frame_index = el["frame_index"]
        timestamp = start_timestamp + timedelta(seconds=frame_index * frame_interval_s)
        for kp, kp_score in zip(el["kps"], el["kps_scores"]):
            yield Detection(
                keypoints=jnp.array(kp),
                confidences=jnp.array(kp_score),
                camera=camera,
                timestamp=timestamp,
            )


# %%
DetectionGenerator: TypeAlias = Generator[Detection, None, None]


def sync_batch_gen(gens: Sequence[DetectionGenerator], diff: timedelta):
    """
    Given a list of detection generators, return a generator that yields batches
    of detections whose timestamps fall within `diff` of each other.

    Args:
        gens: list of detection generators
        diff: maximum timestamp difference between detections to consider them part of the same batch
    """
    N = len(gens)
    last_batch_timestamp: Optional[datetime] = None
    next_batch_timestamp: Optional[datetime] = None
    current_batch: list[Detection] = []
    next_batch: list[Detection] = []
    paused: list[bool] = [False] * N
    finished: list[bool] = [False] * N

    def reset_paused():
        """
        Un-pause every generator that has not finished yet.
        """
        for i in range(N):
            paused[i] = finished[i]

    # a small epsilon to avoid floating point precision issues
    EPS = 1e-6
    diff_eps = diff - timedelta(seconds=EPS)
    while True:
        for i, gen in enumerate(gens):
            if finished[i] or paused[i]:
                continue
            try:
                val = next(gen)
            except StopIteration:
                finished[i] = True
                paused[i] = True
                continue
            if last_batch_timestamp is None:
                last_batch_timestamp = val.timestamp
                current_batch.append(val)
            elif abs(val.timestamp - last_batch_timestamp) >= diff_eps:
                # this detection belongs to the next batch; pause the
                # generator until the current batch has been flushed
                next_batch.append(val)
                if next_batch_timestamp is None:
                    next_batch_timestamp = val.timestamp
                paused[i] = True
            else:
                current_batch.append(val)
        if all(finished):
            if len(current_batch) > 0:
                # All generators exhausted, flush remaining batch and exit
                yield current_batch
            break
        if all(paused):
            # every live generator is holding a next-batch detection (finished
            # ones stay paused), so the current batch is complete
            yield current_batch
            current_batch = next_batch
            next_batch = []
            last_batch_timestamp = next_batch_timestamp
            next_batch_timestamp = None
            reset_paused()


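# %%
# Minimal sanity-check sketch for `sync_batch_gen` (illustrative only, not part
# of the pipeline). Any object with a `.timestamp` attribute works as a
# stand-in for `Detection` here, so we use `SimpleNamespace`.
from types import SimpleNamespace


def _fake_gen(offsets_ms: list[int]) -> Generator[Any, None, None]:
    t0 = datetime(2024, 4, 2, 12, 0, 0)
    for ms in offsets_ms:
        yield SimpleNamespace(timestamp=t0 + timedelta(milliseconds=ms))


_batches = list(
    sync_batch_gen(
        [_fake_gen([0, 50, 100]), _fake_gen([1, 51])],  # two slightly offset streams
        timedelta(milliseconds=40),
    )
)
# expect three batches, grouped in ~50 ms steps: [0, 1], [50, 51], [100]
display([[d.timestamp.microsecond // 1000 for d in b] for b in _batches])
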
# %%
@overload
def to_projection_matrix(
    transformation_matrix: Num[NDArray, "4 4"], camera_matrix: Num[NDArray, "3 3"]
) -> Num[NDArray, "3 4"]: ...


@overload
def to_projection_matrix(
    transformation_matrix: Num[Array, "4 4"], camera_matrix: Num[Array, "3 3"]
) -> Num[Array, "3 4"]: ...


@jaxtyped(typechecker=beartype)
def to_projection_matrix(
    transformation_matrix: Num[Any, "4 4"],
    camera_matrix: Num[Any, "3 3"],
) -> Num[Any, "3 4"]:
    """
    Compose the 3x4 projection matrix P = K @ [R | t].
    """
    return camera_matrix @ transformation_matrix[:3, :]


to_projection_matrix_jit = jax.jit(to_projection_matrix)


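# %%
# Illustrative check (not part of the pipeline): with an identity rotation and
# zero translation, the projection matrix reduces to K padded with a zero
# column, i.e. P = K @ [I | 0].
_K = np.array([[1000.0, 0.0, 640.0], [0.0, 1000.0, 360.0], [0.0, 0.0, 1.0]])
_Rt = to_transformation_matrix(np.zeros(3), np.zeros(3))
_P = to_projection_matrix(_Rt, _K)
assert _P.shape == (3, 4)
assert np.allclose(_P[:, :3], _K) and np.allclose(_P[:, 3], 0.0)
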
@jaxtyped(typechecker=beartype)
def dlt(
    H1: Num[NDArray, "3 4"],
    H2: Num[NDArray, "3 4"],
    p1: Num[NDArray, "2"],
    p2: Num[NDArray, "2"],
) -> Num[NDArray, "3"]:
    """
    Direct Linear Transformation: triangulate a 3D point from two views,
    given their projection matrices and the 2D observations.
    """
    from scipy import linalg

    A = [
        p1[1] * H1[2, :] - H1[1, :],
        H1[0, :] - p1[0] * H1[2, :],
        p2[1] * H2[2, :] - H2[1, :],
        H2[0, :] - p2[0] * H2[2, :],
    ]
    A = np.array(A).reshape((4, 4))

    B = A.transpose() @ A
    U, s, Vh = linalg.svd(B, full_matrices=False)
    return Vh[3, 0:3] / Vh[3, 3]


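# %%
# Illustrative round-trip check for `dlt` (a sketch, not part of the pipeline):
# project a known 3D point through two simple cameras (K = I), then recover it.
_H1 = np.hstack([np.eye(3), np.zeros((3, 1))])  # [I | 0]
_t = np.array([-1.0, 0.0, 0.0])
_H2 = np.hstack([np.eye(3), _t.reshape(3, 1)])  # [I | t]
_X = np.array([0.0, 0.0, 5.0])
_p1 = _X[:2] / _X[2]  # image of X in camera 1: (0, 0)
_p2 = (_X + _t)[:2] / (_X + _t)[2]  # image of X in camera 2: (-0.2, 0)
assert np.allclose(dlt(_H1, _H2, _p1, _p2), _X, atol=1e-6)
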
@overload
def homogeneous_to_euclidean(points: Num[NDArray, "N 4"]) -> Num[NDArray, "N 3"]: ...


@overload
def homogeneous_to_euclidean(points: Num[Array, "N 4"]) -> Num[Array, "N 3"]: ...


@jaxtyped(typechecker=beartype)
def homogeneous_to_euclidean(
    points: Num[Any, "N 4"],
) -> Num[Any, "N 3"]:
    """
    Convert homogeneous coordinates to Euclidean coordinates.

    Args:
        points: homogeneous coordinates (x, y, z, w) in numpy array or jax array

    Returns:
        euclidean coordinates (x, y, z) in numpy array or jax array
    """
    return points[..., :-1] / points[..., -1:]


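# %%
# Quick worked example (a sketch): dividing by w maps (2, 4, 6, 2) -> (1, 2, 3).
assert np.allclose(
    homogeneous_to_euclidean(np.array([[2.0, 4.0, 6.0, 2.0]])),
    np.array([[1.0, 2.0, 3.0]]),
)
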
# %%
FPS = 24
START_TIMESTAMP = datetime(2024, 4, 2, 12, 0, 0)


def make_detection_gen(port: int) -> DetectionGenerator:
    camera_row = AK_CAMERA_DATASET[AK_CAMERA_DATASET["port"] == port][0]
    return preprocess_keypoint_dataset(
        KEYPOINT_DATASET[port],
        from_camera_params(camera_row),  # type: ignore
        FPS,
        START_TIMESTAMP,
    )


image_gen_5600 = make_detection_gen(5600)
image_gen_5601 = make_detection_gen(5601)
image_gen_5602 = make_detection_gen(5602)

display(1 / FPS)
sync_gen = sync_batch_gen(
    [image_gen_5600, image_gen_5601, image_gen_5602], timedelta(seconds=1 / FPS)
)

# %%
sorted_detections, affinity_matrix = calculate_affinity_matrix_by_epipolar_constraint(
    next(sync_gen), alpha_2d=2000
)
display(sorted_detections)

# %%
display(
    [
        {"timestamp": str(x.timestamp), "camera": x.camera.id}
        for x in sorted_detections
    ]
)
with jnp.printoptions(precision=3, suppress=True):
    display(affinity_matrix)


# %%
def clusters_to_detections(
    clusters: Sequence[Sequence[int]], sorted_detections: Sequence[Detection]
) -> list[list[Detection]]:
    """
    Given a list of clusters (each cluster being a list of indices into
    `sorted_detections`), extract the corresponding detections.

    Args:
        clusters: list of clusters, each cluster is a list of indices of the detections in the `sorted_detections` list
        sorted_detections: list of SORTED detections

    Returns:
        list of clusters, each cluster is a list of detections
    """
    return [[sorted_detections[i] for i in cluster] for cluster in clusters]


solver = GLPKSolver()
aff_np = np.asarray(affinity_matrix).astype(np.float64)
clusters, sol_matrix = solver.solve(aff_np)
display(clusters)
display(sol_matrix)


# %%
T = TypeVar("T")


def flatten_values(
    d: Mapping[Any, Sequence[T]],
) -> list[T]:
    """
    Flatten a dictionary of sequences into a single list of values.
    """
    return [v for vs in d.values() for v in vs]


def flatten_values_len(
    d: Mapping[Any, Sequence[T]],
) -> int:
    """
    Count the total number of values across all sequences in the dictionary.
    """
    return reduce(lambda acc, xs: acc + len(xs), d.values(), 0)


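# %%
# Quick illustration (a sketch): both helpers agree on the total count.
_demo = {"a": [1, 2], "b": [3]}
assert flatten_values(_demo) == [1, 2, 3]
assert flatten_values_len(_demo) == 3
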
# %%
WIDTH = 2560
HEIGHT = 1440

clusters_detections = clusters_to_detections(clusters, sorted_detections)
im = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
for el in clusters_detections[0]:
    im = visualize_whole_body(np.asarray(el.keypoints), im)

p = plt.imshow(im)
display(p)

# %%
im_prime = np.zeros((HEIGHT, WIDTH, 3), dtype=np.uint8)
for el in clusters_detections[1]:
    im_prime = visualize_whole_body(np.asarray(el.keypoints), im_prime)

p_prime = plt.imshow(im_prime)
display(p_prime)


# %%
@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N 2"],
    confidences: Optional[Float[Array, "N"]] = None,
) -> Float[Array, "3"]:
    """
    Args:
        proj_matrices: sequence of projection matrices with shape (N, 3, 4)
        points: sequence of point coordinates with shape (N, 2)
        confidences: sequence of confidences with shape (N,), in [0.0, 1.0]

    Returns:
        point_3d: triangulated 3D point with shape (3,)
    """
    assert len(proj_matrices) == len(points)

    N = len(proj_matrices)
    confi: Float[Array, "N"]
    if confidences is None:
        confi = jnp.ones(N, dtype=np.float32)
    else:
        # Use square root of confidences for weighting - more balanced approach
        confi = jnp.sqrt(jnp.clip(confidences, 0, 1))

    A = jnp.zeros((N * 2, 4), dtype=np.float32)
    for i in range(N):
        x, y = points[i]
        A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])
        A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])
        A = A.at[2 * i].mul(confi[i])
        A = A.at[2 * i + 1].mul(confi[i])

    # https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.svd.html
    _, _, vh = jnp.linalg.svd(A, full_matrices=False)
    point_3d_homo = vh[-1]  # shape (4,)

    # replace the Python `if` with a jnp.where so the function stays traceable
    point_3d_homo = jnp.where(
        point_3d_homo[3] < 0,  # predicate (scalar bool tracer)
        -point_3d_homo,  # if True
        point_3d_homo,  # if False
    )

    point_3d = point_3d_homo[:3] / point_3d_homo[3]
    return point_3d


@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N P 2"],
    confidences: Optional[Float[Array, "N P"]] = None,
) -> Float[Array, "P 3"]:
    """
    Batch-triangulate P points observed by N cameras, linearly via SVD.

    Args:
        proj_matrices: (N, 3, 4) projection matrices
        points: (N, P, 2) image coordinates per view
        confidences: (N, P) optional per-view confidences in [0, 1]

    Returns:
        (P, 3) 3D point for each of the P tracks
    """
    N, P, _ = points.shape
    assert proj_matrices.shape[0] == N
    if confidences is None:
        conf = jnp.ones((N, P), dtype=jnp.float32)
    else:
        # clip only; the one-point routine already applies the
        # square-root weighting to the confidences it receives
        conf = jnp.clip(confidences, 0.0, 1.0)

    # vectorize the one-point routine over P
    vmap_triangulate = jax.vmap(
        triangulate_one_point_from_multiple_views_linear,
        in_axes=(None, 1, 1),  # proj_matrices static, map over points[:, p, :], conf[:, p]
        out_axes=0,
    )
    return vmap_triangulate(proj_matrices, points, conf)


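# %%
# Round-trip sanity check for the triangulation routines (illustrative sketch):
# project two known 3D points through two synthetic cameras, then verify the
# linear triangulation recovers them (float32 SVD, hence the loose tolerance).
_P1 = jnp.array(np.hstack([np.eye(3), np.zeros((3, 1))]), dtype=jnp.float32)
_P2 = jnp.array(
    np.hstack([np.eye(3), np.array([[-1.0], [0.0], [0.0]])]), dtype=jnp.float32
)
_proj = jnp.stack([_P1, _P2])  # (N=2, 3, 4)
_X = jnp.array([[0.0, 0.0, 5.0], [0.5, -0.5, 4.0]])  # (P=2, 3)
_X_h = jnp.concatenate([_X, jnp.ones((2, 1))], axis=1)  # homogeneous (P, 4)
_pts = jnp.stack([(_proj[i] @ _X_h.T).T for i in range(2)])  # (N, P, 3)
_pts = _pts[..., :2] / _pts[..., 2:]  # perspective divide -> (N, P, 2)
_recovered = triangulate_points_from_multiple_views_linear(_proj, _pts)
assert jnp.allclose(_recovered, _X, atol=1e-3)
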
# %%
@jaxtyped(typechecker=beartype)
def triangle_from_cluster(
    cluster: Sequence[Detection],
) -> tuple[Float[Array, "N 3"], datetime]:
    """
    Triangulate the 3D keypoints of one cluster of detections and return them
    together with the latest timestamp in the cluster.
    """
    proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])
    points = jnp.array([el.keypoints_undistorted for el in cluster])
    confidences = jnp.array([el.confidences for el in cluster])
    latest_timestamp = max(el.timestamp for el in cluster)
    return (
        triangulate_points_from_multiple_views_linear(
            proj_matrices, points, confidences=confidences
        ),
        latest_timestamp,
    )


# %%
class GlobalTrackingState:
    _last_id: int
    _trackings: dict[int, Tracking]

    def __init__(self):
        self._last_id = 0
        self._trackings = {}

    def __repr__(self) -> str:
        return (
            f"GlobalTrackingState(last_id={self._last_id}, trackings={self._trackings})"
        )

    @property
    def trackings(self) -> dict[int, Tracking]:
        return shallow_copy(self._trackings)

    def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
        kps_3d, latest_timestamp = triangle_from_cluster(cluster)
        next_id = self._last_id + 1
        tracking = Tracking(
            id=next_id, keypoints=kps_3d, last_active_timestamp=latest_timestamp
        )
        self._trackings[next_id] = tracking
        self._last_id = next_id
        return tracking


global_tracking_state = GlobalTrackingState()
for cluster in clusters_detections:
    global_tracking_state.add_tracking(cluster)
display(global_tracking_state)

# %%
next_group = next(sync_gen)
display(next_group)


# %%
@jaxtyped(typechecker=beartype)
def calculate_distance_2d(
    left: Num[Array, "J 2"],
    right: Num[Array, "J 2"],
    image_size: tuple[int, int] = (1, 1),
) -> Float[Array, "J"]:
    """
    Calculate the *normalized* distance between two sets of keypoints.

    Args:
        left: The left keypoints
        right: The right keypoints
        image_size: The size of the image

    Returns:
        Array of normalized Euclidean distances between corresponding keypoints
    """
    w, h = image_size
    if w == 1 and h == 1:
        # already normalized
        left_normalized = left
        right_normalized = right
    else:
        left_normalized = left / jnp.array([w, h])
        right_normalized = right / jnp.array([w, h])
    dist = jnp.linalg.norm(left_normalized - right_normalized, axis=-1)
    return dist


@jaxtyped(typechecker=beartype)
def calculate_affinity_2d(
    distance_2d: Float[Array, "J"],
    delta_t: timedelta,
    w_2d: float,
    alpha_2d: float,
    lambda_a: float,
) -> Float[Array, "J"]:
    """
    Calculate the per-keypoint affinity between two detections based on the
    distances between their keypoints.

    For each keypoint:
        A_2D = w_2D * (1 - distance_2D / (alpha_2D * delta_t)) * exp(-lambda_a * delta_t)

    Args:
        distance_2d: The normalized distances between keypoints (array with one value per keypoint)
        delta_t: The time delta between the two detections
        w_2d: The weight for 2D affinity
        alpha_2d: The normalization factor for distance
        lambda_a: The decay rate for time difference

    Returns:
        Affinity score for each keypoint (the caller sums over keypoints)
    """
    delta_t_s = delta_t.total_seconds()
    affinity_per_keypoint = (
        w_2d
        * (1 - distance_2d / (alpha_2d * delta_t_s))
        * jnp.exp(-lambda_a * delta_t_s)
    )
    return affinity_per_keypoint


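# %%
# Worked example for the 2D affinity (illustrative, arbitrary parameters):
# a normalized distance of 0.05 at delta_t = 1/24 s with w_2d = 1.0,
# alpha_2d = 2.0 and lambda_a = 0.1 gives
#   1.0 * (1 - 0.05 / (2.0 / 24)) * exp(-0.1 / 24) ≈ 0.398 per keypoint.
display(
    calculate_affinity_2d(
        jnp.array([0.05], dtype=jnp.float32),
        timedelta(seconds=1 / 24),
        w_2d=1.0,
        alpha_2d=2.0,
        lambda_a=0.1,
    )
)
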
@jaxtyped(typechecker=beartype)
def perpendicular_distance_point_to_line_two_points(
    point: Num[Array, "3"], line: tuple[Num[Array, "3"], Num[Array, "3"]]
) -> Float[Array, ""]:
    """
    Calculate the perpendicular distance between a point and a line,
    where the line is represented by two points: `(line_start, line_end)`.

    Args:
        point: The point to calculate the distance to
        line: The line to calculate the distance to, represented by two points

    Returns:
        The perpendicular distance between the point and the line
        (a scalar, i.e. a 0-d array)
    """
    line_start, line_end = line
    distance = jnp.linalg.norm(
        jnp.cross(line_end - line_start, line_start - point)
    ) / jnp.linalg.norm(line_end - line_start)
    return distance


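# %%
# Quick geometric check (a sketch): the point (0, 1, 0) lies at distance 1
# from the x-axis, i.e. from the line through (0, 0, 0) and (1, 0, 0).
_d = perpendicular_distance_point_to_line_two_points(
    jnp.array([0.0, 1.0, 0.0]),
    (jnp.array([0.0, 0.0, 0.0]), jnp.array([1.0, 0.0, 0.0])),
)
assert jnp.isclose(_d, 1.0)
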
@jaxtyped(typechecker=beartype)
def perpendicular_distance_camera_2d_points_to_tracking_raycasting(
    detection: Detection,
    tracking: Tracking,
    delta_t: timedelta,
) -> Float[Array, "J"]:
    """
    Calculate the perpendicular distances between predicted 3D tracking points
    and the rays cast from the camera center through the 2D image points.

    NOTE: `delta_t` is taken from the caller and NOT recomputed internally.

    Args:
        detection: The detection object containing 2D keypoints and camera parameters
        tracking: The tracking object containing 3D keypoints
        delta_t: Time delta between the tracking's last update and current observation

    Returns:
        Array of perpendicular distances for each keypoint
    """
    camera = detection.camera
    # Use the delta_t supplied by the caller, but clamp to DELTA_T_MIN to
    # avoid division-by-zero / exploding affinities.
    delta_t = max(delta_t, DELTA_T_MIN)
    delta_t_s = delta_t.total_seconds()
    predicted_pose = tracking.predict(delta_t_s)

    # Back-project the 2D points to 3D space
    # intersection with z=0 plane
    back_projected_points = detection.camera.unproject_points_to_z_plane(
        detection.keypoints, z=0.0
    )
    camera_center = camera.params.location

    def calc_distance(predicted_point, back_projected_point):
        return perpendicular_distance_point_to_line_two_points(
            predicted_point, (camera_center, back_projected_point)
        )

    # Vectorize over all keypoints
    vmap_calc_distance = jax.vmap(calc_distance)
    distances: Float[Array, "J"] = vmap_calc_distance(
        predicted_pose, back_projected_points
    )

    return distances


@jaxtyped(typechecker=beartype)
def calculate_affinity_3d(
    distances: Float[Array, "J"],
    delta_t: timedelta,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> Float[Array, "J"]:
    """
    Calculate the per-keypoint 3D affinity score between a tracking and a detection.

    For each keypoint:
        A_3D = w_3D * (1 - dl / alpha_3D) * exp(-lambda_a * delta_t)

    Args:
        distances: Array of perpendicular distances for each keypoint
        delta_t: Time difference between tracking and detection
        w_3d: Weight for 3D affinity
        alpha_3d: Normalization factor for distance
        lambda_a: Decay rate for time difference

    Returns:
        Affinity score for each keypoint (the caller sums over keypoints)
    """
    delta_t_s = delta_t.total_seconds()
    affinity_per_keypoint = (
        w_3d * (1 - distances / alpha_3d) * jnp.exp(-lambda_a * delta_t_s)
    )
    return affinity_per_keypoint


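# %%
# Worked example for the 3D affinity (illustrative, arbitrary parameters):
# a ray-to-point distance of 0.25 with w_3d = 1.0, alpha_3d = 1.0 and
# lambda_a = 0.1 at delta_t = 1/24 s gives
#   1.0 * (1 - 0.25 / 1.0) * exp(-0.1 / 24) ≈ 0.747 per keypoint.
display(
    calculate_affinity_3d(
        jnp.array([0.25], dtype=jnp.float32),
        timedelta(seconds=1 / 24),
        w_3d=1.0,
        alpha_3d=1.0,
        lambda_a=0.1,
    )
)
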
@beartype
def calculate_tracking_detection_affinity(
    tracking: Tracking,
    detection: Detection,
    w_2d: float,
    alpha_2d: float,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> float:
    """
    Calculate the affinity between a tracking and a detection.

    Args:
        tracking: The tracking object
        detection: The detection object
        w_2d: Weight for 2D affinity
        alpha_2d: Normalization factor for 2D distance
        w_3d: Weight for 3D affinity
        alpha_3d: Normalization factor for 3D distance
        lambda_a: Decay rate for time difference

    Returns:
        Combined affinity score
    """
    camera = detection.camera
    delta_t_raw = detection.timestamp - tracking.last_active_timestamp
    # Clamp delta_t to avoid division-by-zero / exploding affinity.
    delta_t = max(delta_t_raw, DELTA_T_MIN)

    # Calculate 2D affinity
    tracking_2d_projection = camera.project(tracking.keypoints)
    w, h = camera.params.image_size
    distance_2d = calculate_distance_2d(
        tracking_2d_projection,
        detection.keypoints,
        image_size=(int(w), int(h)),
    )
    affinity_2d = calculate_affinity_2d(
        distance_2d,
        delta_t,
        w_2d=w_2d,
        alpha_2d=alpha_2d,
        lambda_a=lambda_a,
    )

    # Calculate 3D affinity
    distances = perpendicular_distance_camera_2d_points_to_tracking_raycasting(
        detection, tracking, delta_t
    )
    affinity_3d = calculate_affinity_3d(
        distances,
        delta_t,
        w_3d=w_3d,
        alpha_3d=alpha_3d,
        lambda_a=lambda_a,
    )

    # Combine affinities
    total_affinity = affinity_2d + affinity_3d
    return jnp.sum(total_affinity).item()


# %%
@beartype
def calculate_camera_affinity_matrix_jax(
    trackings: Sequence[Tracking],
    camera_detections: Sequence[Detection],
    w_2d: float,
    alpha_2d: float,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> Float[Array, "T D"]:
    """
    Vectorized implementation to compute an affinity matrix between *trackings*
    and *detections* coming from **one** camera.

    Compared with the simple double-for-loop version, this leverages `jax`'s
    broadcasting + `vmap` facilities and avoids Python loops over every
    (tracking, detection) pair. The mathematical definition of the affinity
    is **unchanged**, so the result remains bit-identical to the reference
    implementation used in the tests.
    """

    # ------------------------------------------------------------------
    # Quick validations / early-exit guards
    # ------------------------------------------------------------------
    if len(trackings) == 0 or len(camera_detections) == 0:
        # Return an empty affinity matrix with appropriate shape.
        return jnp.zeros((len(trackings), len(camera_detections)))  # type: ignore[return-value]

    cam = next(iter(camera_detections)).camera
    # Ensure every detection truly belongs to the same camera (guard clause)
    cam_id = cam.id
    if any(det.camera.id != cam_id for det in camera_detections):
        raise ValueError(
            "All detections passed to `calculate_camera_affinity_matrix` must come from one camera."
        )

    # We will rely on a single `Camera` instance (all detections share it)
    w_img_, h_img_ = cam.params.image_size
    w_img, h_img = float(w_img_), float(h_img_)

    # ------------------------------------------------------------------
    # Gather data into ndarray / DeviceArray batches so that we can compute
    # everything in a single (or a few) fused kernels.
    # ------------------------------------------------------------------

    # === Tracking-side tensors ===
    kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
        [trk.keypoints for trk in trackings]
    )  # (T, J, 3)
    J = kps3d_trk.shape[1]
    # === Detection-side tensors ===
    kps2d_det: Float[Array, "D J 2"] = jnp.stack(
        [det.keypoints for det in camera_detections]
    )  # (D, J, 2)

    # ------------------------------------------------------------------
    # Compute Δt matrix - shape (T, D)
    # ------------------------------------------------------------------
    # Epoch timestamps are ~1.7e9; storing them in float32 wipes out
    # sub-second detail (resolution ~ 200 ms). Subtract a common origin
    # while still in Python floats (float64) so millisecond-scale Δt
    # values survive the cast to float32.
    # --- timestamps ----------
    t0 = min(
        chain(
            (trk.last_active_timestamp for trk in trackings),
            (det.timestamp for det in camera_detections),
        )
    ).timestamp()  # common origin (float)
    ts_trk = jnp.array(
        [trk.last_active_timestamp.timestamp() - t0 for trk in trackings],
        dtype=jnp.float32,  # now small, ms-scale fits in fp32
    )
    ts_det = jnp.array(
        [det.timestamp.timestamp() - t0 for det in camera_detections],
        dtype=jnp.float32,
    )
    # Δt in seconds, fp32 throughout
    delta_t = ts_det[None, :] - ts_trk[:, None]  # (T, D)
    min_dt_s = float(DELTA_T_MIN.total_seconds())
    delta_t = jnp.clip(delta_t, min_dt_s, None)

    # ------------------------------------------------------------------
    # ---------- 2D affinity -------------------------------------------
    # ------------------------------------------------------------------
    # Project each tracking's 3D keypoints onto the image once.
    # `Camera.project` works per-sample, so we vmap over the first axis.

    proj_fn = jax.vmap(cam.project, in_axes=0)  # maps over the keypoint sets
    kps2d_trk_proj: Float[Array, "T J 2"] = proj_fn(kps3d_trk)  # (T, J, 2)

    # Normalise keypoints by image size so absolute units do not bias distance
    norm_trk = kps2d_trk_proj / jnp.array([w_img, h_img])
    norm_det = kps2d_det / jnp.array([w_img, h_img])

    # L2 distance for every (T, D, J)
    # reshape for broadcasting: (T,1,J,2) vs (1,D,J,2)
    diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
    dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1)

    # Compute per-keypoint 2D affinity
    delta_t_broadcast = delta_t[:, :, None]  # (T, D, 1)
    affinity_2d = (
        w_2d
        * (1 - dist2d / (alpha_2d * delta_t_broadcast))
        * jnp.exp(-lambda_a * delta_t_broadcast)
    )

    # ------------------------------------------------------------------
    # ---------- 3D affinity -------------------------------------------
    # ------------------------------------------------------------------
    # For each detection pre-compute back-projected 3D points lying on the z=0 plane.

    backproj_points_list = [
        det.camera.unproject_points_to_z_plane(det.keypoints, z=0.0)
        for det in camera_detections
    ]  # each (J, 3)
    backproj: Float[Array, "D J 3"] = jnp.stack(backproj_points_list)  # (D, J, 3)

    zero_velocity = jnp.zeros((J, 3))
    trk_velocities = jnp.stack(
        [
            trk.velocity if trk.velocity is not None else zero_velocity
            for trk in trackings
        ]
    )

    predicted_pose: Float[Array, "T D J 3"] = (
        kps3d_trk[:, None, :, :]  # (T, 1, J, 3)
        + trk_velocities[:, None, :, :] * delta_t[:, :, None, None]  # (T, D, 1, 1)
    )

    # Camera center - shape (3,) -> will broadcast
    cam_center = cam.params.location

    # Compute perpendicular distance using the vectorized formula
    # p1 = cam_center       (3,)
    # p2 = backproj         (D, J, 3)
    # P  = predicted_pose   (T, D, J, 3)
    # Broadcast plan: v1 = P - p1             -> (T, D, J, 3)
    #                 v2 = p2[None, ...] - p1 -> (1, D, J, 3)
    # Shapes now line up; no stray singleton axis.
    p1 = cam_center
    p2 = backproj
    P = predicted_pose
    v1 = P - p1
    v2 = p2[None, :, :, :] - p1  # (1, D, J, 3)
    cross = jnp.cross(v1, v2)  # (T, D, J, 3)
    num = jnp.linalg.norm(cross, axis=-1)  # (T, D, J)
    den = jnp.linalg.norm(v2, axis=-1)  # (1, D, J)
    dist3d: Float[Array, "T D J"] = num / den

    affinity_3d = (
        w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
    )

    # ------------------------------------------------------------------
    # Combine and reduce across keypoints -> (T, D)
    # ------------------------------------------------------------------
    total_affinity: Float[Array, "T D"] = jnp.sum(affinity_2d + affinity_3d, axis=-1)
    return total_affinity  # type: ignore[return-value]


@beartype
def calculate_affinity_matrix(
    trackings: Sequence[Tracking],
    detections: Sequence[Detection] | Mapping[CameraID, list[Detection]],
    w_2d: float,
    alpha_2d: float,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> dict[CameraID, AffinityResult]:
    """
    Calculate the affinity matrix between a set of trackings and detections.

    Args:
        trackings: Sequence of tracking objects
        detections: Sequence of detection objects, or detections already grouped by camera ID
        w_2d: Weight for 2D affinity
        alpha_2d: Normalization factor for 2D distance
        w_3d: Weight for 3D affinity
        alpha_3d: Normalization factor for 3D distance
        lambda_a: Decay rate for time difference

    Returns:
        A dictionary mapping camera IDs to affinity results.
    """
    if isinstance(detections, Mapping):
        detection_by_camera = detections
    else:
        detection_by_camera = classify_by_camera(detections)

    res: dict[CameraID, AffinityResult] = {}
    for camera_id, camera_detections in detection_by_camera.items():
        affinity_matrix = calculate_camera_affinity_matrix_jax(
            trackings,
            camera_detections,
            w_2d,
            alpha_2d,
            w_3d,
            alpha_3d,
            lambda_a,
        )
        # row (tracking) and column (detection) indices of the assignment
        indices_T, indices_D = linear_sum_assignment(affinity_matrix)
        affinity_result = AffinityResult(
            matrix=affinity_matrix,
            trackings=trackings,
            detections=camera_detections,
            indices_T=indices_T,
            indices_D=indices_D,
        )
        res[camera_id] = affinity_result
    return res


# %%
# associate the next batch of detections with the existing trackings
W_2D = 1.0
ALPHA_2D = 1.0
LAMBDA_A = 0.1
W_3D = 1.0
ALPHA_3D = 1.0

trackings = sorted(global_tracking_state.trackings.values(), key=lambda x: x.id)
unmatched_detections = shallow_copy(next_group)
camera_detections = classify_by_camera(unmatched_detections)

affinities = calculate_affinity_matrix(
    trackings,
    unmatched_detections,
    w_2d=W_2D,
    alpha_2d=ALPHA_2D,
    w_3d=W_3D,
    alpha_3d=ALPHA_3D,
    lambda_a=LAMBDA_A,
)
display(affinities)

# %%