# CVTH3PE/single_people_detect_track.py
from _collections_abc import dict_values
from math import isnan
from pathlib import Path
from re import L
import awkward as ak
from typing import (
Any,
Generator,
Iterable,
Optional,
Sequence,
TypeAlias,
TypedDict,
cast,
TypeVar,
)
from datetime import datetime, timedelta
from jaxtyping import Array, Float, Num, jaxtyped
import numpy as np
from cv2 import undistortPoints
from app.camera import Camera, CameraParams, Detection
import jax.numpy as jnp
from beartype import beartype
from scipy.spatial.transform import Rotation as R
from filter_object_by_box import (
filter_kps_in_contours,
calculater_box_3d_points,
calculater_box_2d_points,
calculater_box_common_scope,
calculate_triangle_union,
get_contours,
)
from app.tracking import (
TrackingID,
AffinityResult,
LastDifferenceVelocityFilter,
Tracking,
TrackingState,
)
from app.camera import (
Camera,
CameraID,
CameraParams,
Detection,
calculate_affinity_matrix_by_epipolar_constraint,
classify_by_camera,
)
from copy import copy as shallow_copy
from pyrsistent import pvector, v, m, pmap, PMap, freeze, thaw
import jax
from optax.assignment import hungarian_algorithm as linear_sum_assignment
from beartype.typing import Mapping, Sequence
from itertools import chain
import orjson
NDArray: TypeAlias = np.ndarray
DetectionGenerator: TypeAlias = Generator[Detection, None, None]
DELTA_T_MIN = timedelta(milliseconds=1)
"""所有类型"""
T = TypeVar("T")
def unwrap(val: Optional[T]) -> T:
if val is None:
        raise ValueError("expected a value, got None")
return val
class KeypointDataset(TypedDict):
frame_index: int
boxes: Num[NDArray, "N 4"]
kps: Num[NDArray, "N J 2"]
kps_scores: Num[NDArray, "N J"]
class Resolution(TypedDict):
width: int
height: int
class Intrinsic(TypedDict):
camera_matrix: Num[Array, "3 3"]
"""
K
"""
distortion_coefficients: Num[Array, "N"]
"""
distortion coefficients; usually 5
"""
class Extrinsic(TypedDict):
rvec: Num[NDArray, "3"]
tvec: Num[NDArray, "3"]
class ExternalCameraParams(TypedDict):
name: str
port: int
intrinsic: Intrinsic
extrinsic: Extrinsic
resolution: Resolution
"""获得所有机位的相机内外参"""
def get_camera_params(camera_path: Path) -> ak.Array:
camera_dataset: ak.Array = ak.from_parquet(camera_path / "camera_params.parquet")
return camera_dataset
"""获取所有机位的2d检测数据"""
def get_camera_detect(
detect_path: Path, camera_port: list[int], camera_dataset: ak.Array
) -> dict[int, ak.Array]:
keypoint_data: dict[int, ak.Array] = {}
for element_port in ak.to_numpy(camera_dataset["port"]):
if element_port in camera_port:
keypoint_data[int(element_port)] = ak.from_parquet(
detect_path / f"{element_port}_detected.parquet"
)
return keypoint_data
"""获得指定帧的2d检测数据(一段完整的跳跃片段)"""
def get_segment(
camera_port: list[int], frame_index: list[int], keypoint_data: dict[int, ak.Array]
) -> dict[int, ak.Array]:
for port in camera_port:
        segment_data = []
        camera_data = keypoint_data[port]
        for index, element_frame in enumerate(ak.to_numpy(camera_data["frame_index"])):
            if element_frame in frame_index:
                segment_data.append(camera_data[index])
        keypoint_data[port] = ak.Array(segment_data)
return keypoint_data
"""将所有2d检测数据打包"""
@jaxtyped(typechecker=beartype)
def undistort_points(
points: Num[NDArray, "M 2"],
camera_matrix: Num[NDArray, "3 3"],
dist_coeffs: Num[NDArray, "N"],
) -> Num[NDArray, "M 2"]:
K = camera_matrix
dist = dist_coeffs
res = undistortPoints(points, K, dist, P=K) # type: ignore
return res.reshape(-1, 2)
@jaxtyped(typechecker=beartype)
def to_transformation_matrix(
rvec: Num[NDArray, "3"], tvec: Num[NDArray, "3"]
) -> Num[NDArray, "4 4"]:
res = np.eye(4)
res[:3, :3] = R.from_rotvec(rvec).as_matrix()
res[:3, 3] = tvec
return res
def from_camera_params(camera: ExternalCameraParams) -> Camera:
rt = jnp.array(
to_transformation_matrix(
ak.to_numpy(camera["extrinsic"]["rvec"]),
ak.to_numpy(camera["extrinsic"]["tvec"]),
)
)
K = jnp.array(camera["intrinsic"]["camera_matrix"]).reshape(3, 3)
dist_coeffs = jnp.array(camera["intrinsic"]["distortion_coefficients"])
image_size = jnp.array(
(camera["resolution"]["width"], camera["resolution"]["height"])
)
return Camera(
id=camera["name"],
params=CameraParams(
K=K,
Rt=rt,
dist_coeffs=dist_coeffs,
image_size=image_size,
),
)
def preprocess_keypoint_dataset(
dataset: Sequence[KeypointDataset],
camera: Camera,
fps: float,
start_timestamp: datetime,
) -> Generator[Detection, None, None]:
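    """
    Convert raw per-frame keypoint records into Detection objects: keypoints are
    undistorted with the camera intrinsics, and each frame gets a timestamp
    derived from its frame index and the capture FPS.
    """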
frame_interval_s = 1 / fps
for el in dataset:
frame_index = el["frame_index"]
timestamp = start_timestamp + timedelta(seconds=frame_index * frame_interval_s)
for kp, kp_score, boxes in zip(el["kps"], el["kps_scores"], el["boxes"]):
kp = undistort_points(
np.asarray(kp),
np.asarray(camera.params.K),
np.asarray(camera.params.dist_coeffs),
)
yield Detection(
keypoints=jnp.array(kp),
confidences=jnp.array(kp_score),
camera=camera,
timestamp=timestamp,
)
def sync_batch_gen(
gens: list[DetectionGenerator], diff: timedelta
) -> Generator[list[Detection], Any, None]:
    """
    Given a list of detection generators, return a generator that yields batches of detections.
    Args:
        gens: list of detection generators
        diff: maximum timestamp difference between detections for them to be considered part of the same batch
    """
    from more_itertools import partition
N = len(gens)
last_batch_timestamp: Optional[datetime] = None
current_batch: list[Detection] = []
paused: list[bool] = [False] * N
finished: list[bool] = [False] * N
    unmatched_detections: list[Detection] = []
def reset_paused():
"""
reset paused list based on finished list
"""
for i in range(N):
if not finished[i]:
paused[i] = False
else:
paused[i] = True
EPS = 1e-6
# a small epsilon to avoid floating point precision issues
    diff_eps = diff - timedelta(seconds=EPS)
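    # Round-robin over the generators: a detection whose timestamp is within
    # `diff` of the current batch anchor joins the batch; otherwise it is set
    # aside and its generator is paused. Once every generator is paused, the
    # batch is yielded and the anchor advances by `diff`.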
while True:
for i, gen in enumerate(gens):
try:
if finished[i] or paused[i]:
if all(finished):
if len(current_batch) > 0:
# All generators exhausted, flush remaining batch and exit
yield current_batch
return
else:
continue
val = next(gen)
if last_batch_timestamp is None:
last_batch_timestamp = val.timestamp
current_batch.append(val)
else:
                    if abs(val.timestamp - last_batch_timestamp) >= diff_eps:
                        unmatched_detections.append(val)
paused[i] = True
if all(paused):
yield current_batch
reset_paused()
last_batch_timestamp = last_batch_timestamp + diff
                            bad, good = partition(
                                lambda x: x.timestamp < unwrap(last_batch_timestamp),
                                unmatched_detections,
                            )
                            current_batch = list(good)
                            unmatched_detections = list(bad)
else:
current_batch.append(val)
except StopIteration:
return
def get_batch_detect(
keypoint_dataset,
camera_dataset,
camera_port: list[int],
FPS: int = 24,
batch_fps: int = 10,
) -> Generator[list[Detection], Any, None]:
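    """
    Build one Detection generator per camera port and merge them into
    timestamp-synchronized batches spaced at 1 / batch_fps seconds.
    """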
gen_data = [
preprocess_keypoint_dataset(
keypoint_dataset[port],
from_camera_params(camera_dataset[camera_dataset["port"] == port][0]),
FPS,
datetime(2024, 4, 2, 12, 0, 0),
)
for port in camera_port
]
sync_gen: Generator[list[Detection], Any, None] = sync_batch_gen(
gen_data,
timedelta(seconds=1 / batch_fps),
)
return sync_gen
"""通过盒子进行筛选构建第一帧匹配数据"""
def get_filter_detections(detections: list[Detection]) -> list[Detection]:
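    """
    Keep only detections whose keypoints fall inside the image-space contour of
    the predefined 3D filtering box (projected per camera via the
    filter_object_by_box helpers).
    """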
filter_detections: list[Detection] = []
for element_detection in detections:
filter_box_points_3d = calculater_box_3d_points()
box_points_2d = calculater_box_2d_points(
filter_box_points_3d, element_detection.camera
)
box_triangles_all_points = calculater_box_common_scope(box_points_2d)
union_area, union_polygon = calculate_triangle_union(box_triangles_all_points)
contours = get_contours(union_polygon)
if filter_kps_in_contours(element_detection.keypoints, contours):
filter_detections.append(element_detection)
return filter_detections
"""追踪"""
@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear_time_weighted(
proj_matrices: Float[Array, "N 3 4"],
points: Num[Array, "N P 2"],
delta_t: Num[Array, "N"],
lambda_t: float = 10.0,
confidences: Optional[Float[Array, "N P"]] = None,
) -> Float[Array, "P 3"]:
"""
Vectorized version that triangulates P points from N camera views with time-weighting.
This function uses JAX's vmap to efficiently triangulate multiple points in parallel.
Args:
proj_matrices: Shape (N, 3, 4) projection matrices for N cameras
points: Shape (N, P, 2) 2D points for P keypoints across N cameras
delta_t: Shape (N,) time differences between current time and each camera's timestamp (seconds)
lambda_t: Time penalty rate (higher values decrease influence of older observations)
confidences: Shape (N, P) confidence values for each point in each camera
Returns:
points_3d: Shape (P, 3) triangulated 3D points
"""
N, P, _ = points.shape
assert (
proj_matrices.shape[0] == N
), "Number of projection matrices must match number of cameras"
assert delta_t.shape[0] == N, "Number of time deltas must match number of cameras"
if confidences is None:
# Create uniform confidences if none provided
conf = jnp.ones((N, P), dtype=jnp.float32)
else:
conf = confidences
# Define the vmapped version of the single-point function
# We map over the second dimension (P points) of the input arrays
vmap_triangulate = jax.vmap(
triangulate_one_point_from_multiple_views_linear_time_weighted,
in_axes=(
None,
1,
None,
None,
1,
), # proj_matrices and delta_t static, map over points
out_axes=0, # Output has first dimension corresponding to points
)
# For each point p, extract the 2D coordinates from all cameras and triangulate
return vmap_triangulate(
proj_matrices, # (N, 3, 4) - static across points
points, # (N, P, 2) - map over dim 1 (P)
delta_t, # (N,) - static across points
lambda_t, # scalar - static
conf, # (N, P) - map over dim 1 (P)
)
@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear_time_weighted(
proj_matrices: Float[Array, "N 3 4"],
points: Num[Array, "N 2"],
delta_t: Num[Array, "N"],
lambda_t: float = 10.0,
confidences: Optional[Float[Array, "N"]] = None,
) -> Float[Array, "3"]:
"""
Triangulate one point from multiple views with time-weighted linear least squares.
Implements the incremental reconstruction method from "Cross-View Tracking for Multi-Human 3D Pose"
with weighting formula: w_i = exp(-λ_t(t-t_i)) / ||c^i^T||_2
Args:
proj_matrices: Shape (N, 3, 4) projection matrices sequence
points: Shape (N, 2) point coordinates sequence
delta_t: Time differences between current time and each observation (in seconds)
lambda_t: Time penalty rate (higher values decrease influence of older observations)
confidences: Shape (N,) confidence values in range [0.0, 1.0]
Returns:
point_3d: Shape (3,) triangulated 3D point
"""
assert len(proj_matrices) == len(points)
assert len(delta_t) == len(points)
N = len(proj_matrices)
# Prepare confidence weights
confi: Float[Array, "N"]
if confidences is None:
confi = jnp.ones(N, dtype=np.float32)
else:
confi = jnp.sqrt(jnp.clip(confidences, 0, 1))
time_weights = jnp.exp(-lambda_t * delta_t)
weights = time_weights * confi
sum_weights = jnp.sum(weights)
weights = jnp.where(sum_weights > 0, weights / sum_weights, weights)
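    # Build the weighted DLT system A x = 0: each view contributes two rows,
    # x * P[2] - P[0] and y * P[2] - P[1], each scaled by that view's weight.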
A = jnp.zeros((N * 2, 4), dtype=np.float32)
for i in range(N):
x, y = points[i]
row1 = proj_matrices[i, 2] * x - proj_matrices[i, 0]
row2 = proj_matrices[i, 2] * y - proj_matrices[i, 1]
A = A.at[2 * i].set(row1 * weights[i])
A = A.at[2 * i + 1].set(row2 * weights[i])
_, _, vh = jnp.linalg.svd(A, full_matrices=False)
point_3d_homo = vh[-1]
point_3d_homo = jnp.where(point_3d_homo[3] < 0, -point_3d_homo, point_3d_homo)
is_zero_weight = jnp.sum(weights) == 0
point_3d = jnp.where(
is_zero_weight,
jnp.full((3,), jnp.nan, dtype=jnp.float32),
jnp.where(
jnp.abs(point_3d_homo[3]) > 1e-8,
point_3d_homo[:3] / point_3d_homo[3],
jnp.full((3,), jnp.nan, dtype=jnp.float32),
),
)
return point_3d
@jaxtyped(typechecker=beartype)
def triangulate_one_point_from_multiple_views_linear(
proj_matrices: Float[Array, "N 3 4"],
points: Num[Array, "N 2"],
confidences: Optional[Float[Array, "N"]] = None,
conf_threshold: float = 0.2,
) -> Float[Array, "3"]:
"""
    Args:
        proj_matrices: Shape (N, 3, 4) projection matrix sequence
        points: Shape (N, 2) point coordinate sequence
        confidences: Shape (N,) confidence values in range [0.0, 1.0]
        conf_threshold: confidence threshold; observations below it are excluded from the DLT
    Returns:
        point_3d: Shape (3,) triangulated 3D point
"""
assert len(proj_matrices) == len(points)
N = len(proj_matrices)
    # Confidence-weighted DLT
if confidences is None:
weights = jnp.ones(N, dtype=jnp.float32)
else:
valid_mask = confidences >= conf_threshold
weights = jnp.where(valid_mask, confidences, 0.0)
sum_weights = jnp.sum(weights)
weights = jnp.where(sum_weights > 0, weights / sum_weights, weights)
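    # Assemble the confidence-weighted DLT system A x = 0 (two rows per view).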
A = jnp.zeros((N * 2, 4), dtype=jnp.float32)
for i in range(N):
x, y = points[i]
row1 = proj_matrices[i, 2] * x - proj_matrices[i, 0]
row2 = proj_matrices[i, 2] * y - proj_matrices[i, 1]
A = A.at[2 * i].set(row1 * weights[i])
A = A.at[2 * i + 1].set(row2 * weights[i])
_, _, vh = jnp.linalg.svd(A, full_matrices=False)
point_3d_homo = vh[-1]
point_3d_homo = jnp.where(point_3d_homo[3] < 0, -point_3d_homo, point_3d_homo)
is_zero_weight = jnp.sum(weights) == 0
point_3d = jnp.where(
is_zero_weight,
jnp.full((3,), jnp.nan, dtype=jnp.float32),
jnp.where(
jnp.abs(point_3d_homo[3]) > 1e-8,
point_3d_homo[:3] / point_3d_homo[3],
jnp.full((3,), jnp.nan, dtype=jnp.float32),
),
)
return point_3d
@jaxtyped(typechecker=beartype)
def triangulate_points_from_multiple_views_linear(
proj_matrices: Float[Array, "N 3 4"],
points: Num[Array, "N P 2"],
confidences: Optional[Float[Array, "N P"]] = None,
) -> Float[Array, "P 3"]:
"""
    Batch-triangulate P points observed by N cameras, linearly via SVD.
Args:
proj_matrices: (N, 3, 4) projection matrices
points: (N, P, 2) image-coordinates per view
        confidences: (N, P) optional per-view confidences in [0, 1]
Returns:
(P, 3) 3D point for each of the P tracks
"""
N, P, _ = points.shape
assert proj_matrices.shape[0] == N
if confidences is None:
conf = jnp.ones((N, P), dtype=jnp.float32)
else:
conf = jnp.array(confidences)
vmap_triangulate = jax.vmap(
triangulate_one_point_from_multiple_views_linear,
in_axes=(None, 1, 1),
out_axes=0,
)
return vmap_triangulate(proj_matrices, points, conf)
@jaxtyped(typechecker=beartype)
def triangle_from_cluster(
cluster: Sequence[Detection],
) -> tuple[Float[Array, "N 3"], datetime]:
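    """
    Triangulate 3D keypoints from a cluster of detections (one set of 2D
    keypoints per camera view) and return them with the newest timestamp
    observed in the cluster.
    """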
proj_matrices = jnp.array([el.camera.params.projection_matrix for el in cluster])
points = jnp.array([el.keypoints_undistorted for el in cluster])
confidences = jnp.array([el.confidences for el in cluster])
latest_timestamp = max(el.timestamp for el in cluster)
return (
triangulate_points_from_multiple_views_linear(
proj_matrices, points, confidences=confidences
),
latest_timestamp,
)
def group_by_cluster_by_camera(
cluster: Sequence[Detection],
) -> PMap[CameraID, Detection]:
"""
group the detections by camera, and preserve the latest detection for each camera
"""
    r: dict[CameraID, Detection] = {}
    for el in cluster:
        if el.camera.id in r:
            # keep the newer detection for this camera
            r[el.camera.id] = max(r[el.camera.id], el, key=lambda x: x.timestamp)
        else:
            r[el.camera.id] = el
    return pmap(r)
class GlobalTrackingState:
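    """
    Owns all active Tracking objects and hands out monotonically increasing ids.
    """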
_last_id: int
_trackings: dict[int, Tracking]
def __init__(self):
self._last_id = 0
self._trackings = {}
def __repr__(self) -> str:
return (
f"GlobalTrackingState(last_id={self._last_id}, trackings={self._trackings})"
)
@property
def trackings(self) -> dict[int, Tracking]:
return shallow_copy(self._trackings)
def add_tracking(self, cluster: Sequence[Detection]) -> Tracking:
if len(cluster) < 2:
raise ValueError(
"cluster must contain at least 2 detections to form a tracking"
)
kps_3d, latest_timestamp = triangle_from_cluster(cluster)
next_id = self._last_id + 1
tracking_state = TrackingState(
keypoints=kps_3d,
last_active_timestamp=latest_timestamp,
historical_detections_by_camera=group_by_cluster_by_camera(cluster),
)
tracking = Tracking(
id=next_id,
state=tracking_state,
velocity_filter=LastDifferenceVelocityFilter(kps_3d, latest_timestamp),
)
self._trackings[next_id] = tracking
self._last_id = next_id
return tracking
@beartype
def calculate_camera_affinity_matrix_jax(
trackings: Sequence[Tracking],
camera_detections: Sequence[Detection],
w_2d: float,
alpha_2d: float,
w_3d: float,
alpha_3d: float,
lambda_a: float,
) -> Float[Array, "T D"]:
"""
Vectorized implementation to compute an affinity matrix between *trackings*
and *detections* coming from **one** camera.
Compared with the simple double-for-loop version, this leverages `jax`'s
broadcasting + `vmap` facilities and avoids Python loops over every
(tracking, detection) pair. The mathematical definition of the affinity
is **unchanged**, so the result remains bit-identical to the reference
implementation used in the tests.
"""
# ------------------------------------------------------------------
# Quick validations / early-exit guards
# ------------------------------------------------------------------
if len(trackings) == 0 or len(camera_detections) == 0:
# Return an empty affinity matrix with appropriate shape.
return jnp.zeros((len(trackings), len(camera_detections))) # type: ignore[return-value]
cam = next(iter(camera_detections)).camera
# Ensure every detection truly belongs to the same camera (guard clause)
cam_id = cam.id
if any(det.camera.id != cam_id for det in camera_detections):
raise ValueError(
"All detections passed to `calculate_camera_affinity_matrix` must come from one camera."
)
# We will rely on a single `Camera` instance (all detections share it)
w_img_, h_img_ = cam.params.image_size
w_img, h_img = float(w_img_), float(h_img_)
# ------------------------------------------------------------------
# Gather data into ndarray / DeviceArray batches so that we can compute
# everything in a single (or a few) fused kernels.
# ------------------------------------------------------------------
# === Tracking-side tensors ===
kps3d_trk: Float[Array, "T J 3"] = jnp.stack(
[trk.state.keypoints for trk in trackings]
) # (T, J, 3)
J = kps3d_trk.shape[1]
# === Detection-side tensors ===
kps2d_det: Float[Array, "D J 2"] = jnp.stack(
[det.keypoints for det in camera_detections]
) # (D, J, 2)
# ------------------------------------------------------------------
# Compute Δt matrix shape (T, D)
# ------------------------------------------------------------------
# Epoch timestamps are ~1.7 × 10⁹; storing them in float32 wipes out
# subsecond detail (resolution ≈ 200 ms). Keep them in float64 until
    # after subtraction so we preserve Δt on the order of milliseconds.
# --- timestamps ----------
t0 = min(
chain(
(trk.state.last_active_timestamp for trk in trackings),
(det.timestamp for det in camera_detections),
)
).timestamp() # common origin (float)
ts_trk = jnp.array(
[trk.state.last_active_timestamp.timestamp() - t0 for trk in trackings],
dtype=jnp.float32, # now small, ms-scale fits in fp32
)
ts_det = jnp.array(
[det.timestamp.timestamp() - t0 for det in camera_detections],
dtype=jnp.float32,
)
# Δt in seconds, fp32 throughout
delta_t = ts_det[None, :] - ts_trk[:, None] # (T,D)
min_dt_s = float(DELTA_T_MIN.total_seconds())
delta_t = jnp.clip(delta_t, a_min=min_dt_s, a_max=None)
# ------------------------------------------------------------------
# ---------- 2D affinity -------------------------------------------
# ------------------------------------------------------------------
# Project each tracking's 3D keypoints onto the image once.
# `Camera.project` works per-sample, so we vmap over the first axis.
proj_fn = jax.vmap(cam.project, in_axes=0) # maps over the keypoint sets
kps2d_trk_proj: Float[Array, "T J 2"] = proj_fn(kps3d_trk) # (T, J, 2)
# Normalise keypoints by image size so absolute units do not bias distance
norm_trk = kps2d_trk_proj / jnp.array([w_img, h_img])
norm_det = kps2d_det / jnp.array([w_img, h_img])
# L2 distance for every (T, D, J)
# reshape for broadcasting: (T,1,J,2) vs (1,D,J,2)
diff2d = norm_trk[:, None, :, :] - norm_det[None, :, :, :]
dist2d: Float[Array, "T D J"] = jnp.linalg.norm(diff2d, axis=-1)
# Compute per-keypoint 2D affinity
delta_t_broadcast = delta_t[:, :, None] # (T, D, 1)
affinity_2d = (
w_2d
* (1 - dist2d / (alpha_2d * delta_t_broadcast))
* jnp.exp(-lambda_a * delta_t_broadcast)
)
# ------------------------------------------------------------------
# ---------- 3D affinity -------------------------------------------
# ------------------------------------------------------------------
# For each detection pre-compute back-projected 3D points lying on z=0 plane.
backproj_points_list = [
det.camera.unproject_points_to_z_plane(det.keypoints, z=0.0)
for det in camera_detections
] # each (J,3)
backproj: Float[Array, "D J 3"] = jnp.stack(backproj_points_list) # (D, J, 3)
zero_velocity = jnp.zeros((J, 3))
trk_velocities = jnp.stack(
[
trk.velocity if trk.velocity is not None else zero_velocity
for trk in trackings
]
)
predicted_pose: Float[Array, "T D J 3"] = (
kps3d_trk[:, None, :, :] # (T,1,J,3)
+ trk_velocities[:, None, :, :] * delta_t[:, :, None, None] # (T,D,1,1)
)
# Camera center shape (3,) -> will broadcast
cam_center = cam.params.location
# Compute perpendicular distance using vectorized formula
# p1 = cam_center (3,)
# p2 = backproj (D, J, 3)
# P = predicted_pose (T, D, J, 3)
# Broadcast plan: v1 = P - p1 → (T, D, J, 3)
# v2 = p2[None, ...]-p1 → (1, D, J, 3)
# Shapes now line up; no stray singleton axis.
p1 = cam_center
p2 = backproj
P = predicted_pose
v1 = P - p1
v2 = p2[None, :, :, :] - p1 # (1, D, J, 3)
cross = jnp.cross(v1, v2) # (T, D, J, 3)
num = jnp.linalg.norm(cross, axis=-1) # (T, D, J)
den = jnp.linalg.norm(v2, axis=-1) # (1, D, J)
dist3d: Float[Array, "T D J"] = num / den
affinity_3d = (
w_3d * (1 - dist3d / alpha_3d) * jnp.exp(-lambda_a * delta_t_broadcast)
)
# ------------------------------------------------------------------
# Combine and reduce across keypoints → (T, D)
# ------------------------------------------------------------------
total_affinity: Float[Array, "T D"] = jnp.sum(affinity_2d + affinity_3d, axis=-1)
return total_affinity # type: ignore[return-value]
@beartype
def calculate_affinity_matrix(
trackings: Sequence[Tracking],
detections: Sequence[Detection] | Mapping[CameraID, list[Detection]],
w_2d: float,
alpha_2d: float,
w_3d: float,
alpha_3d: float,
lambda_a: float,
) -> dict[CameraID, AffinityResult]:
"""
Calculate the affinity matrix between a set of trackings and detections.
Args:
trackings: Sequence of tracking objects
detections: Sequence of detection objects or a group detections by ID
w_2d: Weight for 2D affinity
alpha_2d: Normalization factor for 2D distance
w_3d: Weight for 3D affinity
alpha_3d: Normalization factor for 3D distance
lambda_a: Decay rate for time difference
Returns:
A dictionary mapping camera IDs to affinity results.
"""
if isinstance(detections, Mapping):
detection_by_camera = detections
else:
detection_by_camera = classify_by_camera(detections)
res: dict[CameraID, AffinityResult] = {}
for camera_id, camera_detections in detection_by_camera.items():
affinity_matrix = calculate_camera_affinity_matrix_jax(
trackings,
camera_detections,
w_2d,
alpha_2d,
w_3d,
alpha_3d,
lambda_a,
)
# row, col
indices_T, indices_D = linear_sum_assignment(affinity_matrix)
affinity_result = AffinityResult(
matrix=affinity_matrix,
trackings=trackings,
detections=camera_detections,
indices_T=indices_T,
indices_D=indices_D,
)
res[camera_id] = affinity_result
return res
DetectionMap: TypeAlias = PMap[CameraID, Detection]
def update_tracking(
tracking: Tracking,
detections: Sequence[Detection],
max_delta_t: timedelta = timedelta(milliseconds=100),
lambda_t: float = 10.0,
) -> None:
"""
update the tracking with a new set of detections
Args:
tracking: the tracking to update
detections: the detections to update the tracking with
max_delta_t: the maximum time difference between the last active timestamp and the latest detection
lambda_t: the lambda value for the time difference
Note:
        this function mutates the tracking object in place
"""
last_active_timestamp = tracking.state.last_active_timestamp
latest_timestamp = max(d.timestamp for d in detections)
d = tracking.state.historical_detections_by_camera
for detection in detections:
d = cast(DetectionMap, d.update({detection.camera.id: detection}))
    for camera_id, detection in d.items():
        # drop stored detections that are older than the newest one by more than max_delta_t
        if latest_timestamp - detection.timestamp > max_delta_t:
            d = d.remove(camera_id)
new_detections = d
new_detections_list = list(new_detections.values())
project_matrices = jnp.stack(
[detection.camera.params.projection_matrix for detection in new_detections_list]
)
delta_t = jnp.array(
[
detection.timestamp.timestamp() - last_active_timestamp.timestamp()
for detection in new_detections_list
]
)
kps = jnp.stack([detection.keypoints for detection in new_detections_list])
conf = jnp.stack([detection.confidences for detection in new_detections_list])
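    # Re-triangulate the 3D keypoints from the merged per-camera detections,
    # with per-view weights exp(-lambda_t * delta_t) applied inside the DLT.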
kps_3d = triangulate_points_from_multiple_views_linear_time_weighted(
project_matrices, kps, delta_t, lambda_t, conf
)
new_state = TrackingState(
keypoints=kps_3d,
last_active_timestamp=latest_timestamp,
historical_detections_by_camera=new_detections,
)
tracking.update(kps_3d, latest_timestamp)
tracking.state = new_state
# Apply sliding-window smoothing to each 3D target
def smooth_3d_keypoints(
all_3d_kps: dict[str, list], window_size: int = 5
) -> dict[str, list]:
# window_size = 5
kernel = np.ones(window_size) / window_size
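    # Note: np.convolve with mode="same" zero-pads at the sequence boundaries,
    # so the first and last window_size // 2 frames are biased toward zero.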
smoothed_points = dict()
for item_object_index in all_3d_kps.keys():
item_object = np.array(all_3d_kps[item_object_index])
if item_object.shape[0] < window_size:
            # If there are fewer frames than the window size, return the original data unchanged
smoothed_points[item_object_index] = item_object.tolist()
continue
        # Apply a moving average separately to each coordinate axis of every keypoint
item_smoothed = np.zeros_like(item_object)
        # Iterate over the 133 joints
for kp_idx in range(item_object.shape[1]):
            # Iterate over the three spatial coordinates of each joint
for axis in range(3):
                # Smoothing at frame i: smoothed[i] = (point[i-2] + point[i-1] + point[i] + point[i+1] + point[i+2]) / 5
item_smoothed[:, kp_idx, axis] = np.convolve(
item_object[:, kp_idx, axis], kernel, mode="same"
)
smoothed_points[item_object_index] = item_smoothed.tolist()
return smoothed_points
# Filter 2D detections by their mean keypoint confidence
def filter_keypoints_by_scores(
detections: Iterable[Detection], threshold: float = 0.5
) -> list[Detection]:
"""
Filter detections based on the average confidence score of their keypoints.
Only keep detections with an average score above the threshold.
"""
    def filter_detection(detection: Detection) -> bool:
        mean_score = np.mean(detection.confidences[:17])
        return float(mean_score) >= threshold
return [d for d in detections if filter_detection(d)]
def filter_camera_port(detections: list[Detection]):
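    """Return the distinct camera ids present in the given detections."""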
camera_port = set()
for detection in detections:
camera_port.add(detection.camera.id)
return list(camera_port)
# Path to the camera intrinsic/extrinsic parameters
CAMERA_PATH = Path("/home/admin/Documents/ActualTest_WeiHua/camera_params")
# Camera intrinsics and extrinsics for all camera positions
AK_CAMERA_DATASET: ak.Array = get_camera_params(CAMERA_PATH)
# Path to the 2D detection data
DATASET_PATH = Path("/home/admin/Documents/ActualTest_WeiHua/Test_Video")
# 2D detection data for the selected camera ports
camera_port = [5602, 5603, 5604, 5605]
KEYPOINT_DATASET = get_camera_detect(DATASET_PATH, camera_port, AK_CAMERA_DATASET)
# Extract one complete jump segment
FRAME_INDEX = [i for i in range(700, 1600)]  # 552 1488
KEYPOINT_DATASET = get_segment(camera_port, FRAME_INDEX, KEYPOINT_DATASET)
# Pack all 2D detection data into synchronized batches
sync_gen: Generator[list[Detection], Any, None] = get_batch_detect(
KEYPOINT_DATASET,
AK_CAMERA_DATASET,
camera_port,
batch_fps=24,
)
# Create the global tracking state
global_tracking_state = GlobalTrackingState()
# Tracking hyperparameters
W_2D = 0.2
ALPHA_2D = 60.0
LAMBDA_A = 5.0
W_3D = 0.8
ALPHA_3D = 0.15
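# (W_2D, ALPHA_2D) weight and normalize the image-space affinity term,
# (W_3D, ALPHA_3D) the back-projected 3D term, and LAMBDA_A sets the
# exponential decay applied with the time difference
# (see calculate_camera_affinity_matrix_jax).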
# Frame counter
count = 0
# Matching threshold for the tracking affinity matrix
affinities_threshold = -20
# Set of tracked targets
trackings: list[Tracking] = []
# 3D results: key is the tracking target id, value is all 3D keypoints for that target
all_3d_kps: dict[str, list] = {}
tracking_initialized = False
lost_frame_count = 0
lost_frame_threshold = 12  # 0.5 seconds, assuming 20 fps
# ===================== Main loop: per-frame detection and tracking =====================
while True:
    # Reworked tracking logic: single target, single initialization, robust loss handling
    try:
        # Get the detections from all cameras for the next time step
        detections = next(sync_gen)
        # Drop low-confidence detections to improve later triangulation and tracking
        detections = filter_keypoints_by_scores(detections, threshold=0.5)
    except StopIteration:
        # Detection data exhausted; exit the main loop
        break
    # 1. Check whether an initialized tracking target currently exists
    # Sort tracking targets by id for consistent downstream handling
trackings: list[Tracking] = sorted(
global_tracking_state.trackings.values(), key=lambda x: x.id
)
    # ========== Tracking-target initialization ==========
    if not tracking_initialized:
        # Initialize the tracking target only once, to avoid repeated initialization
        camera_port_this = filter_camera_port(detections)  # camera ports seen in this frame
        # If fewer than (total cameras - 1) cameras have detections, the initialization condition is not met; skip this frame
if len(camera_port_this) < len(camera_port) - 1:
print(
"init tracking error, filter_detections len:",
len(camera_port_this),
)
continue
        # Once the condition is met, initialize the global tracking state and add the tracking target
        global_tracking_state.add_tracking(detections)
        tracking_initialized = True  # mark as initialized
        lost_frame_count = 0  # reset the lost-frame counter
        # Keep the first frame's 3D pose, stored by tracking id in the all_3d_kps dict
for element_tracking in global_tracking_state.trackings.values():
if str(element_tracking.id) not in all_3d_kps.keys():
all_3d_kps[str(element_tracking.id)] = [
element_tracking.state.keypoints.tolist()
]
print("init tracking:", global_tracking_state.trackings.values())
        continue  # skip the rest of this frame and move on to the next one
    # ========== Lost-target handling ==========
    if len(detections) == 0:
        # No detections in this frame; run the lost-frame counting logic
        print("no detections in this frame, continue")
        lost_frame_count += 1  # one more lost frame
        # Refined exit conditions:
        # 1. Only reset after the loss threshold of consecutive frames is reached
        # 2. Optionally, fully reset only if the gap since the last detection exceeds 1 second
last_tracking = None
if global_tracking_state.trackings:
last_tracking = list(global_tracking_state.trackings.values())[0]
if lost_frame_count >= lost_frame_threshold:
should_remove = True
            # Optional: a time-gap check could be added here (extensible)
            if should_remove:
                global_tracking_state._trackings.clear()  # clear all tracking targets
                tracking_initialized = False  # allow re-initialization later
                print(
                    f"tracking lost after {lost_frame_count} frames, reset tracking state"
                )
                lost_frame_count = 0  # reset the lost-frame count
        continue  # skip the rest of this frame
    # ========== Normal tracking flow ==========
    lost_frame_count = 0  # detections present, reset the lost-frame count
    # Compute the affinity matrix between all tracked targets and this frame's detections (multi-camera)
affinities: dict[str, AffinityResult] = calculate_affinity_matrix(
trackings,
detections,
w_2d=W_2D,
alpha_2d=ALPHA_2D,
w_3d=W_3D,
alpha_3d=ALPHA_3D,
lambda_a=LAMBDA_A,
)
for element_tracking in trackings:
        tracking_detection = []  # best-matching detection for this target in each camera
        temp_matrix = []  # for printing: the maximum affinity per camera
for camera_name in affinities.keys():
camera_matrix = jnp.array(affinities[camera_name].matrix).flatten()
            detection_index = jnp.argmax(camera_matrix).item()  # index of the highest-affinity detection
if isnan(camera_matrix[detection_index].item()):
                breakpoint()  # drop into the debugger when an affinity is NaN
temp_matrix.append(
f"{camera_name} : {camera_matrix[detection_index].item()}"
)
            # Use detections with affinity above the threshold to update the tracking state
            # if camera_matrix[detection_index].item() > affinities_threshold:
tracking_detection.append(
affinities[camera_name].detections[detection_index]
)
print("affinities matrix:", temp_matrix)
        # Only update the tracking when enough detections are matched (more than 2 cameras)
if len(tracking_detection) > 2:
update_tracking(element_tracking, tracking_detection)
            # Record this frame's 3D keypoint result
all_3d_kps[str(element_tracking.id)].append(
element_tracking.state.keypoints.tolist()
)
print(
"update tracking:",
global_tracking_state.trackings.values(),
)
        # No longer delete the tracking in an else branch; rely solely on lost_frame_count
# Apply sliding-window smoothing to each 3D target
smoothed_points = smooth_3d_keypoints(all_3d_kps, window_size=5)
# Save the results to a JSON file
with open("samples/Test_WeiHua.json", "wb") as f:
    f.write(orjson.dumps(smoothed_points))
# Print the shape of each 3D target's data
for element_3d_kps_id in smoothed_points.keys():
print(f"{element_3d_kps_id} : {np.array(all_3d_kps[element_3d_kps_id]).shape}")