from math import isnan
from pathlib import Path
from typing import (
    Any,
    Generator,
    Iterable,
    Optional,
    Sequence,
    TypeAlias,
    TypedDict,
    TypeVar,
    cast,
)
from datetime import datetime, timedelta
from copy import copy as shallow_copy
from itertools import chain

import awkward as ak
import numpy as np
import jax
import jax.numpy as jnp
import orjson
from jaxtyping import Array, Float, Num, jaxtyped
from cv2 import undistortPoints
from beartype import beartype
from beartype.typing import Mapping, Sequence
from scipy.spatial.transform import Rotation as R
from pyrsistent import pvector, v, m, pmap, PMap, freeze, thaw
from optax.assignment import hungarian_algorithm as linear_sum_assignment

from filter_object_by_box import (
    filter_kps_in_contours,
    calculater_box_3d_points,
    calculater_box_2d_points,
    calculater_box_common_scope,
    calculate_triangle_union,
    get_contours,
)
from app.tracking import (
    TrackingID,
    AffinityResult,
    LastDifferenceVelocityFilter,
    Tracking,
    TrackingState,
)
from app.camera import (
    Camera,
    CameraID,
    CameraParams,
    Detection,
    calculate_affinity_matrix_by_epipolar_constraint,
    classify_by_camera,
)

NDArray: TypeAlias = np.ndarray
DetectionGenerator: TypeAlias = Generator[Detection, None, None]

DELTA_T_MIN = timedelta(milliseconds=1)

"""All types."""

T = TypeVar("T")


def unwrap(val: Optional[T]) -> T:
    if val is None:
        raise ValueError("expected a value, got None")
    return val


class KeypointDataset(TypedDict):
    frame_index: int
    boxes: Num[NDArray, "N 4"]
    kps: Num[NDArray, "N J 2"]
    kps_scores: Num[NDArray, "N J"]


class Resolution(TypedDict):
    width: int
    height: int


class Intrinsic(TypedDict):
    camera_matrix: Num[Array, "3 3"]
    """
    The 3x3 intrinsic matrix K.
    """
    distortion_coefficients: Num[Array, "N"]
    """
    Distortion coefficients; usually 5 of them.
    """


class Extrinsic(TypedDict):
    rvec: Num[NDArray, "3"]
    tvec: Num[NDArray, "3"]


class ExternalCameraParams(TypedDict):
    name: str
    port: int
    intrinsic: Intrinsic
    extrinsic: Extrinsic
    resolution: Resolution
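

# Illustrative sketch (not used by the pipeline): what a single camera record is
# expected to look like once assembled as an ExternalCameraParams mapping. Every
# numeric value below is a made-up placeholder, not real calibration data.
def _example_external_camera_params() -> ExternalCameraParams:
    return {
        "name": "cam_5602",  # hypothetical camera name
        "port": 5602,
        "intrinsic": {
            "camera_matrix": jnp.array(
                [[1000.0, 0.0, 960.0], [0.0, 1000.0, 540.0], [0.0, 0.0, 1.0]]
            ),
            "distortion_coefficients": jnp.zeros(5),
        },
        "extrinsic": {
            "rvec": np.zeros(3),
            "tvec": np.zeros(3),
        },
        "resolution": {"width": 1920, "height": 1080},
    }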
"""获得所有机位的相机内外参"""
|
||
|
||
|
||
def get_camera_params(camera_path: Path) -> ak.Array:
|
||
camera_dataset: ak.Array = ak.from_parquet(camera_path / "camera_params.parquet")
|
||
return camera_dataset
|
||
|
||
|
||
"""获取所有机位的2d检测数据"""
|
||
|
||
|
||
def get_camera_detect(
|
||
detect_path: Path, camera_port: list[int], camera_dataset: ak.Array
|
||
) -> dict[int, ak.Array]:
|
||
keypoint_data: dict[int, ak.Array] = {}
|
||
for element_port in ak.to_numpy(camera_dataset["port"]):
|
||
if element_port in camera_port:
|
||
keypoint_data[int(element_port)] = ak.from_parquet(
|
||
detect_path / f"{element_port}_detected.parquet"
|
||
)
|
||
return keypoint_data
|
||
|
||
|
||
"""获得指定帧的2d检测数据(一段完整的跳跃片段)"""
|
||
|
||
|
||
def get_segment(
|
||
camera_port: list[int], frame_index: list[int], keypoint_data: dict[int, ak.Array]
|
||
) -> dict[int, ak.Array]:
|
||
for port in camera_port:
|
||
segement_data = []
|
||
camera_data = keypoint_data[port]
|
||
for index, element_frame in enumerate(ak.to_numpy(camera_data["frame_index"])):
|
||
if element_frame in frame_index:
|
||
segement_data.append(camera_data[index])
|
||
keypoint_data[port] = ak.Array(segement_data)
|
||
|
||
return keypoint_data
|
||
|
||
|
||
"""将所有2d检测数据打包"""
|
||
|
||
|
||


@jaxtyped(typechecker=beartype)
def undistort_points(
    points: Num[NDArray, "M 2"],
    camera_matrix: Num[NDArray, "3 3"],
    dist_coeffs: Num[NDArray, "N"],
) -> Num[NDArray, "M 2"]:
    K = camera_matrix
    dist = dist_coeffs
    res = undistortPoints(points, K, dist, P=K)  # type: ignore
    return res.reshape(-1, 2)
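

# Sanity-check sketch (not called anywhere): with all-zero distortion coefficients
# and P=K, cv2.undistortPoints amounts to an identity mapping on pixel coordinates,
# so the helper above should return its input essentially unchanged. K is made up.
def _example_undistort_points_identity() -> None:
    K = np.array([[1000.0, 0.0, 960.0], [0.0, 1000.0, 540.0], [0.0, 0.0, 1.0]])
    pts = np.array([[100.0, 200.0], [960.0, 540.0]])
    out = undistort_points(pts, K, np.zeros(5))
    assert np.allclose(out, pts, atol=1e-3)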


@jaxtyped(typechecker=beartype)
def to_transformation_matrix(
    rvec: Num[NDArray, "3"], tvec: Num[NDArray, "3"]
) -> Num[NDArray, "4 4"]:
    res = np.eye(4)
    res[:3, :3] = R.from_rotvec(rvec).as_matrix()
    res[:3, 3] = tvec
    return res
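

# Sanity-check sketch (not called anywhere): a 90° rotation about z plus a unit
# translation along x should land in the expected [R | t; 0 1] layout.
def _example_to_transformation_matrix() -> None:
    rvec = np.array([0.0, 0.0, np.pi / 2])  # axis-angle: 90° about z
    tvec = np.array([1.0, 0.0, 0.0])
    T = to_transformation_matrix(rvec, tvec)
    # rotating the x unit vector by 90° about z yields the y unit vector
    assert np.allclose(T[:3, :3] @ np.array([1.0, 0.0, 0.0]), [0.0, 1.0, 0.0])
    assert np.allclose(T[:3, 3], tvec)
    assert np.allclose(T[3], [0.0, 0.0, 0.0, 1.0])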


def from_camera_params(camera: ExternalCameraParams) -> Camera:
    rt = jnp.array(
        to_transformation_matrix(
            ak.to_numpy(camera["extrinsic"]["rvec"]),
            ak.to_numpy(camera["extrinsic"]["tvec"]),
        )
    )
    K = jnp.array(camera["intrinsic"]["camera_matrix"]).reshape(3, 3)
    dist_coeffs = jnp.array(camera["intrinsic"]["distortion_coefficients"])
    image_size = jnp.array(
        (camera["resolution"]["width"], camera["resolution"]["height"])
    )
    return Camera(
        id=camera["name"],
        params=CameraParams(
            K=K,
            Rt=rt,
            dist_coeffs=dist_coeffs,
            image_size=image_size,
        ),
    )


def preprocess_keypoint_dataset(
    dataset: Sequence[KeypointDataset],
    camera: Camera,
    fps: float,
    start_timestamp: datetime,
) -> Generator[Detection, None, None]:
    frame_interval_s = 1 / fps
    for el in dataset:
        frame_index = el["frame_index"]
        timestamp = start_timestamp + timedelta(seconds=frame_index * frame_interval_s)
        for kp, kp_score, boxes in zip(el["kps"], el["kps_scores"], el["boxes"]):
            kp = undistort_points(
                np.asarray(kp),
                np.asarray(camera.params.K),
                np.asarray(camera.params.dist_coeffs),
            )

            yield Detection(
                keypoints=jnp.array(kp),
                confidences=jnp.array(kp_score),
                camera=camera,
                timestamp=timestamp,
            )
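

# Worked example of the frame-index → timestamp arithmetic above (a sketch, not
# called anywhere): at 24 fps, frame 24 lands one second after the start timestamp
# (timedelta rounds to microsecond resolution, which absorbs the 1/24 rounding error).
def _example_frame_timestamp_mapping() -> None:
    start = datetime(2024, 4, 2, 12, 0, 0)
    fps = 24.0
    ts = start + timedelta(seconds=24 * (1 / fps))
    assert abs((ts - start) - timedelta(seconds=1)) <= timedelta(microseconds=1)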


def sync_batch_gen(
    gens: list[DetectionGenerator], diff: timedelta
) -> Generator[list[Detection], Any, None]:
    """
    Given a list of detection generators, return a generator that yields batches of detections.

    Args:
        gens: list of detection generators
        diff: maximum timestamp difference between detections to consider them part of the same batch
    """
    from more_itertools import partition

    N = len(gens)
    last_batch_timestamp: Optional[datetime] = None
    current_batch: list[Detection] = []
    paused: list[bool] = [False] * N
    finished: list[bool] = [False] * N
    unmatched_detections: list[Detection] = []

    def reset_paused():
        """
        Reset the paused list based on the finished list.
        """
        for i in range(N):
            if not finished[i]:
                paused[i] = False
            else:
                paused[i] = True

    EPS = 1e-6
    # a small epsilon to avoid floating point precision issues
    diff_eps = diff - timedelta(seconds=EPS)
    while True:
        for i, gen in enumerate(gens):
            try:
                if finished[i] or paused[i]:
                    if all(finished):
                        if len(current_batch) > 0:
                            # All generators exhausted, flush remaining batch and exit
                            yield current_batch
                        return
                    else:
                        continue
                val = next(gen)
                if last_batch_timestamp is None:
                    last_batch_timestamp = val.timestamp
                    current_batch.append(val)
                else:
                    if abs(val.timestamp - last_batch_timestamp) >= diff_eps:
                        unmatched_detections.append(val)
                        paused[i] = True
                        if all(paused):
                            yield current_batch
                            reset_paused()
                            last_batch_timestamp = last_batch_timestamp + diff
                            bad, good = partition(
                                lambda x: x.timestamp < unwrap(last_batch_timestamp),
                                unmatched_detections,
                            )
                            current_batch = list(good)
                            unmatched_detections = list(bad)
                    else:
                        current_batch.append(val)
            except StopIteration:
                # mark this generator as exhausted; the all(finished) check above
                # flushes the remaining batch and terminates once every source is done
                finished[i] = True
                paused[i] = True
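

# Self-contained sketch of the batching behaviour above (not called anywhere).
# It uses SimpleNamespace stand-ins instead of app.camera.Detection, because the
# generator only ever reads the `.timestamp` attribute of the items it receives;
# the stream layout (two 10 Hz cameras, 5 ms apart) is made up for illustration.
def _example_sync_batch_gen() -> None:
    from types import SimpleNamespace

    t0 = datetime(2024, 4, 2, 12, 0, 0)

    def fake_stream(offset_ms: float, n: int, period_ms: float):
        for k in range(n):
            yield SimpleNamespace(
                timestamp=t0 + timedelta(milliseconds=offset_ms + k * period_ms)
            )

    batches = list(
        sync_batch_gen(
            [fake_stream(0, 5, 100), fake_stream(5, 5, 100)],
            timedelta(milliseconds=100),
        )
    )
    # the first yielded batch pairs up the first item of each stream
    assert len(batches[0]) == 2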


def get_batch_detect(
    keypoint_dataset: dict[int, ak.Array],
    camera_dataset: ak.Array,
    camera_port: list[int],
    FPS: int = 24,
    batch_fps: int = 10,
) -> Generator[list[Detection], Any, None]:
    gen_data = [
        preprocess_keypoint_dataset(
            keypoint_dataset[port],
            from_camera_params(camera_dataset[camera_dataset["port"] == port][0]),
            FPS,
            datetime(2024, 4, 2, 12, 0, 0),
        )
        for port in camera_port
    ]

    sync_gen: Generator[list[Detection], Any, None] = sync_batch_gen(
        gen_data,
        timedelta(seconds=1 / batch_fps),
    )
    return sync_gen
"""通过盒子进行筛选构建第一帧匹配数据"""
|
||
|
||
|
||


def get_filter_detections(detections: list[Detection]) -> list[Detection]:
    filter_detections: list[Detection] = []
    for element_detection in detections:
        filter_box_points_3d = calculater_box_3d_points()
        box_points_2d = calculater_box_2d_points(
            filter_box_points_3d, element_detection.camera
        )
        box_triangles_all_points = calculater_box_common_scope(box_points_2d)
        union_area, union_polygon = calculate_triangle_union(box_triangles_all_points)
        contours = get_contours(union_polygon)

        if filter_kps_in_contours(element_detection.keypoints, contours):
            filter_detections.append(element_detection)
    return filter_detections
"""追踪"""
|
||
|
||
|
||


@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear_time_weighted(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N P 2"],
    delta_t: Num[Array, "N"],
    lambda_t: float = 10.0,
    confidences: Optional[Float[Array, "N P"]] = None,
) -> Float[Array, "P 3"]:
    """
    Vectorized version that triangulates P points from N camera views with time-weighting.

    This function uses JAX's vmap to efficiently triangulate multiple points in parallel.

    Args:
        proj_matrices: Shape (N, 3, 4) projection matrices for N cameras
        points: Shape (N, P, 2) 2D points for P keypoints across N cameras
        delta_t: Shape (N,) time differences between current time and each camera's timestamp (seconds)
        lambda_t: Time penalty rate (higher values decrease influence of older observations)
        confidences: Shape (N, P) confidence values for each point in each camera

    Returns:
        points_3d: Shape (P, 3) triangulated 3D points
    """
    N, P, _ = points.shape
    assert (
        proj_matrices.shape[0] == N
    ), "Number of projection matrices must match number of cameras"
    assert delta_t.shape[0] == N, "Number of time deltas must match number of cameras"

    if confidences is None:
        # Create uniform confidences if none provided
        conf = jnp.ones((N, P), dtype=jnp.float32)
    else:
        conf = confidences

    # Define the vmapped version of the single-point function
    # We map over the second dimension (P points) of the input arrays
    vmap_triangulate = jax.vmap(
        triangulate_one_point_from_multiple_views_linear_time_weighted,
        in_axes=(
            None,
            1,
            None,
            None,
            1,
        ),  # proj_matrices and delta_t static, map over points
        out_axes=0,  # Output has first dimension corresponding to points
    )

    # For each point p, extract the 2D coordinates from all cameras and triangulate
    return vmap_triangulate(
        proj_matrices,  # (N, 3, 4) - static across points
        points,  # (N, P, 2) - map over dim 1 (P)
        delta_t,  # (N,) - static across points
        lambda_t,  # scalar - static
        conf,  # (N, P) - map over dim 1 (P)
    )


@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear_time_weighted(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N 2"],
    delta_t: Num[Array, "N"],
    lambda_t: float = 10.0,
    confidences: Optional[Float[Array, "N"]] = None,
) -> Float[Array, "3"]:
    """
    Triangulate one point from multiple views with time-weighted linear least squares.

    Implements the incremental reconstruction method from "Cross-View Tracking for Multi-Human 3D Pose"
    with weighting formula: w_i = exp(-λ_t(t-t_i)) / ||c^i^T||_2

    Args:
        proj_matrices: Shape (N, 3, 4) projection matrices sequence
        points: Shape (N, 2) point coordinates sequence
        delta_t: Time differences between current time and each observation (in seconds)
        lambda_t: Time penalty rate (higher values decrease influence of older observations)
        confidences: Shape (N,) confidence values in range [0.0, 1.0]

    Returns:
        point_3d: Shape (3,) triangulated 3D point
    """
    assert len(proj_matrices) == len(points)
    assert len(delta_t) == len(points)

    N = len(proj_matrices)

    # Prepare confidence weights
    confi: Float[Array, "N"]
    if confidences is None:
        confi = jnp.ones(N, dtype=np.float32)
    else:
        confi = jnp.sqrt(jnp.clip(confidences, 0, 1))

    time_weights = jnp.exp(-lambda_t * delta_t)
    weights = time_weights * confi
    sum_weights = jnp.sum(weights)
    weights = jnp.where(sum_weights > 0, weights / sum_weights, weights)

    A = jnp.zeros((N * 2, 4), dtype=np.float32)
    for i in range(N):
        x, y = points[i]
        row1 = proj_matrices[i, 2] * x - proj_matrices[i, 0]
        row2 = proj_matrices[i, 2] * y - proj_matrices[i, 1]
        A = A.at[2 * i].set(row1 * weights[i])
        A = A.at[2 * i + 1].set(row2 * weights[i])

    _, _, vh = jnp.linalg.svd(A, full_matrices=False)
    point_3d_homo = vh[-1]
    point_3d_homo = jnp.where(point_3d_homo[3] < 0, -point_3d_homo, point_3d_homo)
    is_zero_weight = jnp.sum(weights) == 0
    point_3d = jnp.where(
        is_zero_weight,
        jnp.full((3,), jnp.nan, dtype=jnp.float32),
        jnp.where(
            jnp.abs(point_3d_homo[3]) > 1e-8,
            point_3d_homo[:3] / point_3d_homo[3],
            jnp.full((3,), jnp.nan, dtype=jnp.float32),
        ),
    )
    return point_3d
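

# Numerical sanity-check sketch for the time-weighted DLT above (not called
# anywhere): two synthetic pinhole cameras observe a known 3D point, and
# triangulating the two projections should recover that point. All camera
# parameters here are made up.
def _example_triangulate_one_point() -> None:
    K = jnp.array([[1000.0, 0.0, 500.0], [0.0, 1000.0, 500.0], [0.0, 0.0, 1.0]])
    # camera 0 at the origin, camera 1 shifted 1 unit along x, both looking down +z
    P0 = K @ jnp.concatenate([jnp.eye(3), jnp.zeros((3, 1))], axis=1)
    P1 = K @ jnp.concatenate([jnp.eye(3), jnp.array([[-1.0], [0.0], [0.0]])], axis=1)
    proj = jnp.stack([P0, P1]).astype(jnp.float32)  # (N=2, 3, 4)

    X = jnp.array([0.5, 0.2, 5.0, 1.0])  # homogeneous ground-truth point

    def project(P):
        x = P @ X
        return x[:2] / x[2]

    pts = jnp.stack([project(P0), project(P1)]).astype(jnp.float32)  # (N=2, 2)
    delta_t = jnp.zeros(2, dtype=jnp.float32)  # both views observed "now"
    rec = triangulate_one_point_from_multiple_views_linear_time_weighted(
        proj, pts, delta_t
    )
    assert jnp.allclose(rec, X[:3], atol=1e-2)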


@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N 2"],
    confidences: Optional[Float[Array, "N"]] = None,
    conf_threshold: float = 0.4,  # 0.2
) -> Float[Array, "3"]:
    """
    Args:
        proj_matrices: Shape (N, 3, 4) projection matrices sequence
        points: Shape (N, 2) point coordinates sequence
        confidences: Shape (N,) confidence values in range [0.0, 1.0]
        conf_threshold: Confidence threshold; observations below it are excluded from the DLT
    Returns:
        point_3d: Shape (3,) triangulated 3D point
    """
    assert len(proj_matrices) == len(points)
    N = len(proj_matrices)
    # confidence-weighted DLT
    if confidences is None:
        weights = jnp.ones(N, dtype=jnp.float32)
    else:
        valid_mask = confidences >= conf_threshold
        weights = jnp.where(valid_mask, confidences, 0.0)
    sum_weights = jnp.sum(weights)
    weights = jnp.where(sum_weights > 0, weights / sum_weights, weights)

    A = jnp.zeros((N * 2, 4), dtype=jnp.float32)
    for i in range(N):
        x, y = points[i]
        row1 = proj_matrices[i, 2] * x - proj_matrices[i, 0]
        row2 = proj_matrices[i, 2] * y - proj_matrices[i, 1]
        A = A.at[2 * i].set(row1 * weights[i])
        A = A.at[2 * i + 1].set(row2 * weights[i])

    _, _, vh = jnp.linalg.svd(A, full_matrices=False)
    point_3d_homo = vh[-1]
    point_3d_homo = jnp.where(point_3d_homo[3] < 0, -point_3d_homo, point_3d_homo)
    is_zero_weight = jnp.sum(weights) == 0
    point_3d = jnp.where(
        is_zero_weight,
        jnp.full((3,), jnp.nan, dtype=jnp.float32),
        jnp.where(
            jnp.abs(point_3d_homo[3]) > 1e-8,
            point_3d_homo[:3] / point_3d_homo[3],
            jnp.full((3,), jnp.nan, dtype=jnp.float32),
        ),
    )
    return point_3d


@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear(
    proj_matrices: Float[Array, "N 3 4"],
    points: Num[Array, "N P 2"],
    confidences: Optional[Float[Array, "N P"]] = None,
) -> Float[Array, "P 3"]:
    """
    Batch-triangulate P points observed by N cameras, linearly via SVD.

    Args:
        proj_matrices: (N, 3, 4) projection matrices
        points: (N, P, 2) image-coordinates per view
        confidences: (N, P) optional per-view confidences in [0, 1]

    Returns:
        (P, 3) 3D point for each of the P tracks
    """
    N, P, _ = points.shape
    assert proj_matrices.shape[0] == N

    if confidences is None:
        conf = jnp.ones((N, P), dtype=jnp.float32)
    else:
        conf = jnp.array(confidences)

    vmap_triangulate = jax.vmap(
        triangulate_one_point_from_multiple_views_linear,
        in_axes=(None, 1, 1),
        out_axes=0,
    )
    return vmap_triangulate(proj_matrices, points, conf)
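

# Sketch of the batched entry point above (not called anywhere): the same two
# synthetic cameras as in _example_triangulate_one_point, but with two keypoints
# triangulated in a single vmapped call. All values are made up.
def _example_triangulate_points_batch() -> None:
    K = jnp.array([[1000.0, 0.0, 500.0], [0.0, 1000.0, 500.0], [0.0, 0.0, 1.0]])
    P0 = K @ jnp.concatenate([jnp.eye(3), jnp.zeros((3, 1))], axis=1)
    P1 = K @ jnp.concatenate([jnp.eye(3), jnp.array([[-1.0], [0.0], [0.0]])], axis=1)
    proj = jnp.stack([P0, P1]).astype(jnp.float32)  # (N=2, 3, 4)

    Xs = jnp.array([[0.5, 0.2, 5.0, 1.0], [-0.3, 0.1, 4.0, 1.0]])  # (P=2, 4)
    homog = jnp.einsum("nij,pj->npi", proj, Xs)  # (N, P, 3)
    pts = (homog[..., :2] / homog[..., 2:3]).astype(jnp.float32)  # (N, P, 2)

    rec = triangulate_points_from_multiple_views_linear(proj, pts)  # (P, 3)
    assert jnp.allclose(rec, Xs[:, :3], atol=1e-2)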


@jaxtyped(typechecker=beartype)
def triangle_from_cluster(
    cluster: Sequence[Detection],
) -> tuple[Float[Array, "N 3"], datetime]:
    proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])
    points = jnp.array([el.keypoints_undistorted for el in cluster])
    confidences = jnp.array([el.confidences for el in cluster])
    latest_timestamp = max(el.timestamp for el in cluster)
    return (
        triangulate_points_from_multiple_views_linear(
            proj_matrices, points, confidences=confidences
        ),
        latest_timestamp,
    )


def group_by_cluster_by_camera(
    cluster: Sequence[Detection],
) -> PMap[CameraID, Detection]:
    """
    Group the detections by camera, and preserve the latest detection for each camera.
    """
    r: dict[CameraID, Detection] = {}
    for el in cluster:
        if el.camera.id in r:
            eld = r[el.camera.id]
            preserved = max([eld, el], key=lambda x: x.timestamp)
            r[el.camera.id] = preserved
        else:
            r[el.camera.id] = el
    return pmap(r)


class GlobalTrackingState:
    _last_id: int
    _trackings: dict[int, Tracking]

    def __init__(self):
        self._last_id = 0
        self._trackings = {}

    def __repr__(self) -> str:
        return (
            f"GlobalTrackingState(last_id={self._last_id}, trackings={self._trackings})"
        )

    @property
    def trackings(self) -> dict[int, Tracking]:
        return shallow_copy(self._trackings)

    def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
        if len(cluster) < 2:
            raise ValueError(
                "cluster must contain at least 2 detections to form a tracking"
            )
        kps_3d, latest_timestamp = triangle_from_cluster(cluster)
        next_id = self._last_id + 1
        tracking_state = TrackingState(
            keypoints=kps_3d,
            last_active_timestamp=latest_timestamp,
            historical_detections_by_camera=group_by_cluster_by_camera(cluster),
        )
        tracking = Tracking(
            id=next_id,
            state=tracking_state,
            velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
        )
        self._trackings[next_id] = tracking
        self._last_id = next_id
        return tracking


@beartype
def calculate_camera_affinity_matrix_jax(
    trackings: Sequence[Tracking],
    camera_detections: Sequence[Detection],
    w_2d: float,
    alpha_2d: float,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> Float[Array, "T D"]:
    """
    Vectorized implementation to compute an affinity matrix between *trackings*
    and *detections* coming from **one** camera.

    Compared with the simple double-for-loop version, this leverages `jax`'s
    broadcasting + `vmap` facilities and avoids Python loops over every
    (tracking, detection) pair. The mathematical definition of the affinity
    is **unchanged**, so the result remains bit-identical to the reference
    implementation used in the tests.
    """

    # ------------------------------------------------------------------
    # Quick validations / early-exit guards
    # ------------------------------------------------------------------
    if len(trackings) == 0 or len(camera_detections) == 0:
        # Return an empty affinity matrix with appropriate shape.
        return jnp.zeros((len(trackings), len(camera_detections)))  # type: ignore[return-value]

    cam = next(iter(camera_detections)).camera
    # Ensure every detection truly belongs to the same camera (guard clause)
    cam_id = cam.id
    if any(det.camera.id != cam_id for det in camera_detections):
        raise ValueError(
            "All detections passed to `calculate_camera_affinity_matrix` must come from one camera."
        )

    # We will rely on a single `Camera` instance (all detections share it)
    w_img_, h_img_ = cam.params.image_size
    w_img, h_img = float(w_img_), float(h_img_)

    # ------------------------------------------------------------------
    # Gather data into ndarray / DeviceArray batches so that we can compute
    # everything in a single (or a few) fused kernels.
    # ------------------------------------------------------------------

    # === Tracking-side tensors ===
    kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
        [trk.state.keypoints for trk in trackings]
    )  # (T, J, 3)
    J = kps3d_trk.shape[1]
    # === Detection-side tensors ===
    kps2d_det: Float[Array, "D J 2"] = jnp.stack(
        [det.keypoints for det in camera_detections]
    )  # (D, J, 2)

    # ------------------------------------------------------------------
    # Compute Δt matrix – shape (T, D)
    # ------------------------------------------------------------------
    # Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out
    # sub-second detail (resolution ≈ 200 ms). Keep them in float64 until
    # after subtraction so we preserve Δt on the order of milliseconds.
    # --- timestamps ----------
    t0 = min(
        chain(
            (trk.state.last_active_timestamp for trk in trackings),
            (det.timestamp for det in camera_detections),
        )
    ).timestamp()  # common origin (float)
    ts_trk = jnp.array(
        [trk.state.last_active_timestamp.timestamp() - t0 for trk in trackings],
        dtype=jnp.float32,  # now small, ms-scale fits in fp32
    )
    ts_det = jnp.array(
        [det.timestamp.timestamp() - t0 for det in camera_detections],
        dtype=jnp.float32,
    )
    # Δt in seconds, fp32 throughout
    delta_t = ts_det[None, :] - ts_trk[:, None]  # (T, D)
    min_dt_s = float(DELTA_T_MIN.total_seconds())
    delta_t = jnp.clip(delta_t, a_min=min_dt_s, a_max=None)

    # ------------------------------------------------------------------
    # ---------- 2D affinity -------------------------------------------
    # ------------------------------------------------------------------
    # Project each tracking's 3D keypoints onto the image once.
    # `Camera.project` works per-sample, so we vmap over the first axis.

    proj_fn = jax.vmap(cam.project, in_axes=0)  # maps over the keypoint sets
    kps2d_trk_proj: Float[Array, "T J 2"] = proj_fn(kps3d_trk)  # (T, J, 2)

    # Normalise keypoints by image size so absolute units do not bias distance
    norm_trk = kps2d_trk_proj / jnp.array([w_img, h_img])
    norm_det = kps2d_det / jnp.array([w_img, h_img])

    # L2 distance for every (T, D, J)
    # reshape for broadcasting: (T,1,J,2) vs (1,D,J,2)
    diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
    dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1)

    # Compute per-keypoint 2D affinity
    delta_t_broadcast = delta_t[:, :, None]  # (T, D, 1)
    affinity_2d = (
        w_2d
        * (1 - dist2d / (alpha_2d * delta_t_broadcast))
        * jnp.exp(-lambda_a * delta_t_broadcast)
    )

    # ------------------------------------------------------------------
    # ---------- 3D affinity -------------------------------------------
    # ------------------------------------------------------------------
    # For each detection pre-compute back-projected 3D points lying on z=0 plane.

    backproj_points_list = [
        det.camera.unproject_points_to_z_plane(det.keypoints, z=0.0)
        for det in camera_detections
    ]  # each (J, 3)
    backproj: Float[Array, "D J 3"] = jnp.stack(backproj_points_list)  # (D, J, 3)

    zero_velocity = jnp.zeros((J, 3))
    trk_velocities = jnp.stack(
        [
            trk.velocity if trk.velocity is not None else zero_velocity
            for trk in trackings
        ]
    )

    predicted_pose: Float[Array, "T D J 3"] = (
        kps3d_trk[:, None, :, :]  # (T,1,J,3)
        + trk_velocities[:, None, :, :] * delta_t[:, :, None, None]  # (T,D,1,1)
    )

    # Camera center – shape (3,) -> will broadcast
    cam_center = cam.params.location

    # Compute perpendicular distance using vectorized formula
    # p1 = cam_center (3,)
    # p2 = backproj (D, J, 3)
    # P = predicted_pose (T, D, J, 3)
    # Broadcast plan: v1 = P - p1 → (T, D, J, 3)
    # v2 = p2[None, ...] - p1 → (1, D, J, 3)
    # Shapes now line up; no stray singleton axis.
    p1 = cam_center
    p2 = backproj
    P = predicted_pose
    v1 = P - p1
    v2 = p2[None, :, :, :] - p1  # (1, D, J, 3)
    cross = jnp.cross(v1, v2)  # (T, D, J, 3)
    num = jnp.linalg.norm(cross, axis=-1)  # (T, D, J)
    den = jnp.linalg.norm(v2, axis=-1)  # (1, D, J)
    dist3d: Float[Array, "T D J"] = num / den

    affinity_3d = (
        w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
    )

    # ------------------------------------------------------------------
    # Combine and reduce across keypoints → (T, D)
    # ------------------------------------------------------------------
    total_affinity: Float[Array, "T D"] = jnp.sum(affinity_2d + affinity_3d, axis=-1)
    return total_affinity  # type: ignore[return-value]
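

# Worked numeric sketch of the per-keypoint affinity terms computed above (not
# called anywhere). The weights and normalisers mirror the hyper-parameters that
# are configured near the bottom of this file; the distances and Δt are made-up
# values, and the 3D distance is in the calibration's world units.
def _example_affinity_terms() -> None:
    w_2d, alpha_2d, w_3d, alpha_3d, lambda_a = 0.2, 60.0, 0.8, 0.15, 5.0
    delta_t = 0.05  # seconds between the tracking state and the detection
    dist_2d = 0.02  # image-normalised 2D distance between projection and detection
    dist_3d = 0.05  # distance between the predicted joint and the back-projected ray

    decay = float(jnp.exp(-lambda_a * delta_t))
    affinity_2d = w_2d * (1 - dist_2d / (alpha_2d * delta_t)) * decay
    affinity_3d = w_3d * (1 - dist_3d / alpha_3d) * decay
    # per keypoint the two terms are added, then summed over all keypoints
    assert affinity_2d + affinity_3d > 0  # a close, recent match contributes positively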


@beartype
def calculate_affinity_matrix(
    trackings: Sequence[Tracking],
    detections: Sequence[Detection] | Mapping[CameraID, list[Detection]],
    w_2d: float,
    alpha_2d: float,
    w_3d: float,
    alpha_3d: float,
    lambda_a: float,
) -> dict[CameraID, AffinityResult]:
    """
    Calculate the affinity matrix between a set of trackings and detections.

    Args:
        trackings: Sequence of tracking objects
        detections: Sequence of detection objects, or detections already grouped by camera ID
        w_2d: Weight for 2D affinity
        alpha_2d: Normalization factor for 2D distance
        w_3d: Weight for 3D affinity
        alpha_3d: Normalization factor for 3D distance
        lambda_a: Decay rate for time difference
    Returns:
        A dictionary mapping camera IDs to affinity results.
    """
    if isinstance(detections, Mapping):
        detection_by_camera = detections
    else:
        detection_by_camera = classify_by_camera(detections)

    res: dict[CameraID, AffinityResult] = {}
    for camera_id, camera_detections in detection_by_camera.items():
        affinity_matrix = calculate_camera_affinity_matrix_jax(
            trackings,
            camera_detections,
            w_2d,
            alpha_2d,
            w_3d,
            alpha_3d,
            lambda_a,
        )
        # row, col
        indices_T, indices_D = linear_sum_assignment(affinity_matrix)
        affinity_result = AffinityResult(
            matrix=affinity_matrix,
            trackings=trackings,
            detections=camera_detections,
            indices_T=indices_T,
            indices_D=indices_D,
        )
        res[camera_id] = affinity_result
    return res


DetectionMap: TypeAlias = PMap[CameraID, Detection]


def update_tracking(
    tracking: Tracking,
    detections: Sequence[Detection],
    max_delta_t: timedelta = timedelta(milliseconds=100),
    lambda_t: float = 10.0,
) -> None:
    """
    Update the tracking with a new set of detections.

    Args:
        tracking: the tracking to update
        detections: the detections to update the tracking with
        max_delta_t: the maximum time difference between the latest detection and a historical detection before the latter is discarded
        lambda_t: the lambda value for the time difference

    Note:
        the function would mutate the tracking object
    """
    last_active_timestamp = tracking.state.last_active_timestamp
    latest_timestamp = max(d.timestamp for d in detections)

    d = tracking.state.historical_detections_by_camera
    for detection in detections:
        d = cast(DetectionMap, d.update({detection.camera.id: detection}))
    for camera_id, detection in d.items():
        # drop historical detections that are too stale relative to the newest one
        if latest_timestamp - detection.timestamp > max_delta_t:
            d = d.remove(camera_id)
    new_detections = d
    new_detections_list = list(new_detections.values())
    project_matrices = jnp.stack(
        [detection.camera.params.projection_matrix for detection in new_detections_list]
    )
    delta_t = jnp.array(
        [
            detection.timestamp.timestamp() - last_active_timestamp.timestamp()
            for detection in new_detections_list
        ]
    )
    kps = jnp.stack([detection.keypoints for detection in new_detections_list])
    conf = jnp.stack([detection.confidences for detection in new_detections_list])
    kps_3d = triangulate_points_from_multiple_views_linear_time_weighted(
        project_matrices, kps, delta_t, lambda_t, conf
    )
    new_state = TrackingState(
        keypoints=kps_3d,
        last_active_timestamp=latest_timestamp,
        historical_detections_by_camera=new_detections,
    )
    tracking.update(kps_3d, latest_timestamp)
    tracking.state = new_state


# Apply sliding-window smoothing to every 3D target
def smooth_3d_keypoints(
    all_3d_kps: dict[str, list], window_size: int = 5
) -> dict[str, list]:
    kernel = np.ones(window_size) / window_size
    smoothed_points = dict()
    for item_object_index in all_3d_kps.keys():
        item_object = np.array(all_3d_kps[item_object_index])
        if item_object.shape[0] < window_size:
            # if there are fewer frames than the window size, return the original data unchanged
            smoothed_points[item_object_index] = item_object.tolist()
            continue

        # apply a moving average separately to every coordinate axis of every keypoint
        item_smoothed = np.zeros_like(item_object)
        # iterate over the 133 joints
        for kp_idx in range(item_object.shape[1]):
            # iterate over the three spatial coordinates of each joint
            for axis in range(3):
                # smoothing of frame i: smoothed[i] = (point[i-2] + point[i-1] + point[i] + point[i+1] + point[i+2]) / 5
                item_smoothed[:, kp_idx, axis] = np.convolve(
                    item_object[:, kp_idx, axis], kernel, mode="same"
                )
        smoothed_points[item_object_index] = item_smoothed.tolist()
    return smoothed_points
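

# Sketch of the smoothing above (not called anywhere): a constant trajectory stays
# constant under the moving average except for edge effects of mode="same".
def _example_smooth_3d_keypoints() -> None:
    frames, joints = 9, 2
    constant_pose = np.ones((frames, joints, 3))
    smoothed = smooth_3d_keypoints({"1": constant_pose.tolist()}, window_size=5)
    out = np.array(smoothed["1"])
    assert out.shape == (frames, joints, 3)
    # interior frames are untouched; the two frames at either edge shrink towards 0
    assert np.allclose(out[2:-2], 1.0)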


# Filter the 2D detections by their mean confidence score
def filter_keypoints_by_scores(
    detections: Iterable[Detection], threshold: float = 0.5
) -> list[Detection]:
    """
    Filter detections based on the average confidence score of their keypoints.
    Only keep detections with an average score above the threshold.
    """

    def filter_detection(detection: Detection) -> bool:
        mean_score = np.mean(detection.confidences[:17])
        # print(f"Mean score: {mean_score}")
        return float(mean_score) >= threshold

    return [d for d in detections if filter_detection(d)]


def filter_camera_port(detections: list[Detection]) -> list[CameraID]:
    camera_port = set()
    for detection in detections:
        camera_port.add(detection.camera.id)
    return list(camera_port)


# Path to the camera intrinsics/extrinsics
CAMERA_PATH = Path("/home/admin/Documents/ActualTest_WeiHua/camera_params/")
# Intrinsics and extrinsics for every camera position
AK_CAMERA_DATASET: ak.Array = get_camera_params(CAMERA_PATH)

# Path to the 2D detection data
DATASET_PATH = Path("/home/admin/Documents/ActualTest_WeiHua/Segment_1/")
# 2D detection data for the selected camera positions
camera_port = [5602, 5603, 5604, 5605]
KEYPOINT_DATASET = get_camera_detect(DATASET_PATH, camera_port, AK_CAMERA_DATASET)

# Extract one complete jump segment
FRAME_INDEX = list(range(700, 1600))  # Segment_1: (700, 1600)
KEYPOINT_DATASET = get_segment(camera_port, FRAME_INDEX, KEYPOINT_DATASET)


# Pack all 2D detection data into synchronized batches
sync_gen: Generator[list[Detection], Any, None] = get_batch_detect(
    KEYPOINT_DATASET,
    AK_CAMERA_DATASET,
    camera_port,
    batch_fps=24,
)

# Create the tracking targets
global_tracking_state = GlobalTrackingState()
# Tracking hyper-parameters
W_2D = 0.2
ALPHA_2D = 60.0
LAMBDA_A = 5.0
W_3D = 0.8
ALPHA_3D = 0.15
# Frame counter
count = 0
# Affinity threshold for accepting a tracking/detection match
affinities_threshold = -20
# Set of tracked targets
trackings: list[Tracking] = []
# 3D data: key is the tracking target id, value is all 3D poses of that target
all_3d_kps: dict[str, list] = {}

tracking_initialized = False
lost_frame_count = 0
lost_frame_threshold = 12  # 0.5 seconds

# Counter of frames where the target was lost
loss_track_count = 0


# ===================== Main loop: per-frame detection and tracking =====================
while True:
    # Reworked tracking logic: single target, one-shot initialization, robust loss handling
    try:
        # fetch all camera detections for the next time step
        detections = next(sync_gen)
        # drop low-confidence detections to improve the later triangulation and tracking
        detections = filter_keypoints_by_scores(detections, threshold=0.2)
        # detections = get_filter_detections(detections)  # used for the parachute jump platform
    except StopIteration:
        # the detection data is exhausted, leave the main loop
        break

    # 1. Check whether there are already initialized tracking targets
    # sort the tracking targets by id for consistent downstream processing
    trackings: list[Tracking] = sorted(
        global_tracking_state.trackings.values(), key=lambda x: x.id
    )

    # ========== Tracking target initialization ==========
    if not tracking_initialized:
        # initialize the tracking target only once to avoid repeated initialization
        camera_port_this = filter_camera_port(detections)  # cameras that detected something this frame
        # if fewer than (total cameras - 1) cameras have detections, skip this frame
        if len(camera_port_this) < len(camera_port) - 1:
            print(
                "init tracking error, filter_detections len:",
                len(camera_port_this),
            )
            continue
        # once the condition is met, initialize the global tracking state and add the target
        global_tracking_state.add_tracking(detections)
        tracking_initialized = True  # mark as initialized
        lost_frame_count = 0  # reset the lost-frame counter
        # keep the first-frame 3D pose, stored in all_3d_kps keyed by tracking id
        for element_tracking in global_tracking_state.trackings.values():
            if str(element_tracking.id) not in all_3d_kps.keys():
                all_3d_kps[str(element_tracking.id)] = [
                    element_tracking.state.keypoints.tolist()
                ]
        print("init tracking:", global_tracking_state.trackings.values())
        continue  # skip the rest of this frame and move on to the next one

    # ========== Lost-target handling ==========
    if len(detections) == 0:
        # nothing detected in this frame, enter the lost-frame counting logic
        print("no detections in this frame, continue")
        lost_frame_count += 1  # one more lost frame
        # refined exit conditions:
        # 1. only reset after the loss threshold of consecutive frames is reached
        # 2. optionally, only reset once the gap between the last detection and the current frame exceeds 0.5 s
        last_tracking = None
        if global_tracking_state.trackings:
            last_tracking = list(global_tracking_state.trackings.values())[0]
        if lost_frame_count >= lost_frame_threshold:
            should_remove = True
            # optional: a time-gap check could be added here
            if should_remove:
                global_tracking_state._trackings.clear()  # drop all tracking targets
                tracking_initialized = False  # allow re-initialization later
                print(
                    f"tracking lost after {lost_frame_count} frames, reset tracking state"
                )
                lost_frame_count = 0  # reset the lost-frame counter
        continue  # skip the rest of this frame

    # ========== Normal tracking ==========
    lost_frame_count = 0  # detections present, reset the lost-frame counter
    # compute the affinity matrix between all tracking targets and detections (per camera)
    affinities: dict[str, AffinityResult] = calculate_affinity_matrix(
        trackings,
        detections,
        w_2d=W_2D,
        alpha_2d=ALPHA_2D,
        w_3d=W_3D,
        alpha_3d=ALPHA_3D,
        lambda_a=LAMBDA_A,
    )
    for element_tracking in trackings:
        tracking_detection = []  # best-matching detection per camera for this tracking target
        temp_matrix = []  # for printing: the maximum affinity per camera
        for camera_name in affinities.keys():
            # indices_T: index of the tracking matched to a detection (position in the tracking list)
            # indices_D: index of the detection matched to a tracking (position in the detections list)
            indices_T = affinities[camera_name].indices_T.item()
            indices_D = affinities[camera_name].indices_D.item()
            camera_matrix = jnp.array(affinities[camera_name].matrix).flatten()
            detection_index = jnp.argmax(camera_matrix).item()  # index of the highest affinity
            if isnan(camera_matrix[detection_index].item()):
                breakpoint()  # drop into the debugger on anomalies
            temp_matrix.append(
                f"{camera_name} : {camera_matrix[detection_index].item()}"
            )
            match_tracking = affinities[camera_name].trackings[indices_T]
            # select detections whose affinity exceeds the threshold to update the tracking state
            # if camera_matrix[detection_index].item() > affinities_threshold:
            # if match_tracking == element_tracking:
            tracking_detection.append(affinities[camera_name].detections[indices_D])
        print("affinities matrix:", temp_matrix)
        # only update the tracking when enough detections were matched (e.g. more than 2 cameras)
        if len(tracking_detection) > 2:
            update_tracking(element_tracking, tracking_detection)
            # record the 3D keypoints of this frame
            all_3d_kps[str(element_tracking.id)].append(
                element_tracking.state.keypoints.tolist()
            )
            print(
                "update tracking:",
                global_tracking_state.trackings.values(),
            )
        else:
            loss_track_count += 1
            # ====== TODO: decide how to update the tracking when a single frame has too little data ======


# Apply sliding-window smoothing to every 3D target
smoothed_points = smooth_3d_keypoints(all_3d_kps, window_size=5)

print("Tracking completed, total loss frames processed:", loss_track_count)

# Save the results to a JSON file
with open("samples/Test_WeiHua_Segment_1.json", "wb") as f:
    f.write(orjson.dumps(smoothed_points))
# Print the dimensions of every 3D target
for element_3d_kps_id in smoothed_points.keys():
    print(f"{element_3d_kps_id} : {np.array(all_3d_kps[element_3d_kps_id]).shape}")