forked from HQU-gxy/CVTH3PE

Single-person tracking: reworked the tracking initialization and tracking-exit conditions, and revised the weighted computation in the DLT triangulation.

2025-07-08 17:00:10 +08:00
parent 835367cd6d
commit b4ac324d8f
2 changed files with 968 additions and 183 deletions
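For orientation before the diff: the revised weighted DLT combines an exponential time-decay term with sqrt-compressed keypoint confidence and then normalizes the weights, exactly as the first triangulation hunk below adds it. A minimal standalone sketch of that weighting (the helper name dlt_weights is ours; lambda_t, delta_t, and confidences follow the diff, and the default decay rate here is arbitrary):

import jax.numpy as jnp

def dlt_weights(confidences, delta_t, lambda_t=1.0):
    # sqrt-compress confidences, clipped to [0, 1]
    confi = jnp.sqrt(jnp.clip(confidences, 0, 1))
    # exponential time decay: e^(-lambda_t * delta_t)
    time_weights = jnp.exp(-lambda_t * delta_t)
    weights = time_weights * confi
    # normalize so no single view dominates; leave an all-zero vector untouched
    sum_weights = jnp.sum(weights)
    return jnp.where(sum_weights > 0, weights / sum_weights, weights)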


@@ -133,15 +133,9 @@ def get_camera_detect(
def get_segment(
camera_port: list[int], frame_index: list[int], keypoint_data: dict[int, ak.Array]
) -> dict[int, ak.Array]:
# for port in camera_port:
# keypoint_data[port] = [
# element_frame
# for element_frame in KEYPOINT_DATASET[port]
# if element_frame["frame_index"] in frame_index
# ]
for port in camera_port:
segment_data = []
camera_data = KEYPOINT_DATASET[port]
camera_data = keypoint_data[port]
for index, element_frame in enumerate(ak.to_numpy(camera_data["frame_index"])):
if element_frame in frame_index:
segment_data.append(camera_data[index])
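A hypothetical call, reusing the camera ports and frame range that appear later in this file (assumes KEYPOINT_DATASET has already been populated by get_camera_detect):

segment = get_segment(
    camera_port=[5602, 5603, 5604, 5605],
    frame_index=list(range(552, 1488)),
    keypoint_data=KEYPOINT_DATASET,
)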
@@ -431,43 +425,32 @@ def triangulate_one_point_from_multiple_views_linear_time_weighted(
else:
confi = jnp.sqrt(jnp.clip(confidences, 0, 1))
A = jnp.zeros((N * 2, 4), dtype=np.float32)
time_weights = jnp.exp(-lambda_t * delta_t)
weights = time_weights * confi
sum_weights = jnp.sum(weights)
weights = jnp.where(sum_weights > 0, weights / sum_weights, weights)
# First build the coefficient matrix without weights
A = jnp.zeros((N * 2, 4), dtype=np.float32)
for i in range(N):
x, y = points[i]
A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])
A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])
row1 = proj_matrices[i, 2] * x - proj_matrices[i, 0]
row2 = proj_matrices[i, 2] * y - proj_matrices[i, 1]
A = A.at[2 * i].set(row1 * weights[i])
A = A.at[2 * i + 1].set(row2 * weights[i])
# Then apply the time-based and confidence weights
for i in range(N):
# Calculate time-decay weight: e^(-λ_t * Δt)
time_weight = jnp.exp(-lambda_t * delta_t[i])
# Calculate normalization factor: ||c^i^T||_2
row_norm_1 = jnp.linalg.norm(A[2 * i])
row_norm_2 = jnp.linalg.norm(A[2 * i + 1])
# Apply combined weight: time_weight / row_norm * confidence
w1 = (time_weight / row_norm_1) * confi[i]
w2 = (time_weight / row_norm_2) * confi[i]
A = A.at[2 * i].mul(w1)
A = A.at[2 * i + 1].mul(w2)
# Solve using SVD
_, _, vh = jnp.linalg.svd(A, full_matrices=False)
point_3d_homo = vh[-1] # shape (4,)
# Ensure homogeneous coordinate is positive
point_3d_homo = jnp.where(
point_3d_homo[3] < 0,
-point_3d_homo,
point_3d_homo,
)
point_3d_homo = vh[-1]
point_3d_homo = jnp.where(point_3d_homo[3] < 0, -point_3d_homo, point_3d_homo)
is_zero_weight = jnp.sum(weights) == 0
point_3d = jnp.where(
is_zero_weight,
jnp.full((3,), jnp.nan, dtype=jnp.float32),
jnp.where(
jnp.abs(point_3d_homo[3]) > 1e-8,
point_3d_homo[:3] / point_3d_homo[3],
jnp.full((3,), jnp.nan, dtype=jnp.float32),
),
)
# Convert from homogeneous to Euclidean coordinates
point_3d = point_3d_homo[:3] / point_3d_homo[3]
return point_3d
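A minimal usage sketch for the time-weighted routine above. The argument names come from the function body, but the positional order and the all-zero calibration matrices are assumptions for illustration only:

import jax.numpy as jnp

N = 4  # one view per camera port
proj_matrices = jnp.zeros((N, 3, 4))  # placeholders; real 3x4 projection matrices required
points = jnp.full((N, 2), 100.0)      # one 2D keypoint per view
confidences = jnp.array([0.9, 0.8, 0.6, 0.7])
delta_t = jnp.array([0.0, 0.05, 0.1, 0.15])  # per-view age relative to the newest frame

point_3d = triangulate_one_point_from_multiple_views_linear_time_weighted(
    proj_matrices, points, confidences, delta_t
)  # -> (3,) array; NaN when the weights sum to zero or the solution is degenerate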
@@ -489,36 +472,37 @@ def triangulate_one_point_from_multiple_views_linear(
"""
assert len(proj_matrices) == len(points)
N = len(proj_matrices)
# Confidence-weighted DLT
if confidences is None:
weights = jnp.ones(N, dtype=jnp.float32)
else:
# Points with confidence below the threshold get weight 0; the rest get sqrt(conf)
valid_mask = confidences >= conf_threshold
weights = jnp.where(valid_mask, jnp.sqrt(jnp.clip(confidences, 0, 1)), 0.0)
# Normalize the weights so that no single frame's weight dominates
weights = jnp.where(valid_mask, confidences, 0.0)
sum_weights = jnp.sum(weights)
weights = jnp.where(sum_weights > 0, weights / sum_weights, weights)
A = jnp.zeros((N * 2, 4), dtype=jnp.float32)
for i in range(N):
x, y = points[i]
A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])
A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])
A = A.at[2 * i].mul(weights[i])
A = A.at[2 * i + 1].mul(weights[i])
row1 = proj_matrices[i, 2] * x - proj_matrices[i, 0]
row2 = proj_matrices[i, 2] * y - proj_matrices[i, 1]
A = A.at[2 * i].set(row1 * weights[i])
A = A.at[2 * i + 1].set(row2 * weights[i])
# https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.svd.html
_, _, vh = jnp.linalg.svd(A, full_matrices=False)
point_3d_homo = vh[-1] # shape (4,)
# replace the Python `if` with a jnp.where
point_3d_homo = jnp.where(
point_3d_homo[3] < 0, # predicate (scalar bool tracer)
-point_3d_homo, # if True
point_3d_homo, # if False
)
point_3d_homo = vh[-1]
point_3d_homo = jnp.where(point_3d_homo[3] < 0, -point_3d_homo, point_3d_homo)
is_zero_weight = jnp.sum(weights) == 0
point_3d = jnp.where(
is_zero_weight,
jnp.full((3,), jnp.nan, dtype=jnp.float32),
jnp.where(
jnp.abs(point_3d_homo[3]) > 1e-8,
point_3d_homo[:3] / point_3d_homo[3],
jnp.full((3,), jnp.nan, dtype=jnp.float32),
),
)
point_3d = point_3d_homo[:3] / point_3d_homo[3]
return point_3d
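To make the thresholded weighting concrete: with conf_threshold = 0.5, a confidence vector [0.9, 0.2, 0.6] is masked to [0.9, 0.0, 0.6] and then normalized to [0.6, 0.0, 0.4] (a toy check, not repo code):

import jax.numpy as jnp

confidences = jnp.array([0.9, 0.2, 0.6])
valid_mask = confidences >= 0.5
weights = jnp.where(valid_mask, confidences, 0.0)  # [0.9, 0.0, 0.6]
weights = weights / jnp.sum(weights)               # [0.6, 0.0, 0.4]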
@@ -542,16 +526,16 @@ def triangulate_points_from_multiple_views_linear(
N, P, _ = points.shape
assert proj_matrices.shape[0] == N
conf = jnp.array(confidences)
if confidences is None:
conf = jnp.ones((N, P), dtype=jnp.float32)
else:
conf = jnp.array(confidences)
# vectorize the one-point routine over P
vmap_triangulate = jax.vmap(
triangulate_one_point_from_multiple_views_linear,
in_axes=(None, 1, 1), # proj_matrices static, map over points[:,p,:], conf[:,p]
in_axes=(None, 1, 1),
out_axes=0,
)
# returns (P, 3)
return vmap_triangulate(proj_matrices, points, conf)
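The in_axes=(None, 1, 1) spec holds proj_matrices fixed and maps over axis 1 of points (N, P, 2) and confidences (N, P), so each vmapped call sees one keypoint across all views. A shape-only sketch (placeholder values; P = 17 matches the keypoint count used elsewhere in this file):

import jax.numpy as jnp

proj_matrices = jnp.zeros((4, 3, 4))  # N = 4 views, placeholder calibration
points = jnp.zeros((4, 17, 2))        # P = 17 keypoints per view
confidences = jnp.ones((4, 17))

points_3d = triangulate_points_from_multiple_views_linear(
    proj_matrices, points, confidences
)  # -> shape (17, 3)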
@@ -924,17 +908,20 @@ def smooth_3d_keypoints(
# Filter 2D detections by their average confidence
def filter_keypoints_by_scores(detections: Iterable[Detection], threshold: float = 0.5):
def filter_keypoints_by_scores(
detections: Iterable[Detection], threshold: float = 0.5
) -> list[Detection]:
"""
Filter detections based on the average confidence score of their keypoints.
Only keep detections with an average score above the threshold.
"""
def filter_detection(detection: Detection) -> bool:
avg_score = np.mean(detection.confidences)
return float(avg_score) >= threshold
mean_score = np.mean(detection.confidences[:17])
# print(f"Mean score: {mean_score}")
return float(mean_score) >= threshold
return filter(filter_detection, detections)
return [d for d in detections if filter_detection(d)]
def filter_camera_port(detections: list[Detection]):
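Since the function now returns a list rather than a lazy filter object, the result can be sized and reused; the [:17] slice presumably restricts the mean to the 17 body keypoints. A short usage sketch:

kept = filter_keypoints_by_scores(detections, threshold=0.5)
print(f"kept {len(kept)} of {len(detections)} detections")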
@@ -956,7 +943,7 @@ camera_port = [5602, 5603, 5604, 5605]
KEYPOINT_DATASET = get_camera_detect(DATASET_PATH, camera_port, AK_CAMERA_DATASET)
# Extract one complete jump segment
FRAME_INDEX = [i for i in range(552, 1488)] # 552 1488
FRAME_INDEX = [i for i in range(700, 1600)] # 552 1488
KEYPOINT_DATASET = get_segment(camera_port, FRAME_INDEX, KEYPOINT_DATASET)
@@ -965,7 +952,7 @@ sync_gen: Generator[list[Detection], Any, None] = get_batch_detect(
KEYPOINT_DATASET,
AK_CAMERA_DATASET,
camera_port,
batch_fps=20,
batch_fps=24,
)
# Create the tracking target
@@ -985,134 +972,104 @@ trackings: list[Tracking] = []
# 3D data: key is the tracking target id, value is all of that target's 3D data
all_3d_kps: dict[str, list] = {}
# Iterate over the 2D data to exercise the tracking state
tracking_initialized = False
lost_frame_count = 0
lost_frame_threshold = 12 # 0.5 s, assuming 24 fps
while True:
# Get the current tracking targets
# Reworked tracking logic: a single target, one-time initialization, and robust loss detection
try:
detections = next(sync_gen)
detections = filter_keypoints_by_scores(detections, threshold=0.5)
except StopIteration:
break
# 1. Check whether an initialized tracking target already exists
trackings: list[Tracking] = sorted(
global_tracking_state.trackings.values(), key=lambda x: x.id
)
try:
detections = next(sync_gen)
# Filter 2D detections by their average confidence
# detections = list(filter_keypoints_by_scores(detections, threshold=0.5))
except StopIteration:
break
if len(detections) == 0:
print("no detections in this frame, continue")
continue
# Get the latest frame of 2D data
# Check whether tracking was established successfully; if not, skip this frame until tracking is established
if not trackings:
"""离机时使用,用于初始化第一帧"""
"""
# Use the box-filtered 2D detections
filter_detections = get_filter_detections(detections)
# Establish the tracking state only when all 3 camera views have a target
# if len(filter_detections) == 0:
# continue
if len(filter_detections) < len(camera_port):
if not tracking_initialized:
# 只初始化一次
camera_port_this = filter_camera_port(detections)
if len(camera_port_this) < len(camera_port) - 1:
print(
"init traincking error, filter filter_detections len:",
len(filter_detections),
"init tracking error, filter_detections len:",
len(camera_port_this),
)
continue
"""
# Filter 2D detections by their average confidence
# detections = list(filter_keypoints_by_scores(detections, threshold=0.7))
# Establish the tracking state only when all 4 camera views detect the target
camera_port = filter_camera_port(detections)
if len(detections) < len(camera_port):
print(
"init traincking error, filter_detections len:",
len(detections),
)
else:
# Build the tracking target from the first frame of data
global_tracking_state.add_tracking(detections) # use filter_detections at takeoff
# Get the current tracking targets
trackings: list[Tracking] = sorted(
global_tracking_state.trackings.values(), key=lambda x: x.id
)
# Keep the first frame's 3D pose data
for element_tracking in trackings:
if str(element_tracking.id) not in all_3d_kps.keys():
all_3d_kps[str(element_tracking.id)] = [
element_tracking.state.keypoints.tolist()
]
print("initer tracking:", trackings)
else:
# Compute the affinity-matrix matching results
affinities: dict[str, AffinityResult] = calculate_affinity_matrix(
trackings,
detections,
w_2d=W_2D,
alpha_2d=ALPHA_2D,
w_3d=W_3D,
alpha_3d=ALPHA_3D,
lambda_a=LAMBDA_A,
)
# For each tracking target, gather its matched 2D detections
for element_tracking in trackings:
tracking_detection = []
# Index of the matched detection
detection_index = None
temp_matrix = []
# Iterate over each camera's affinity-matching results
for camera_name in affinities.keys():
# Get the affinity matrix matched against each tracking target
camera_matrix = jnp.array(affinities[camera_name].matrix).flatten()
detection_index = jnp.argmax(camera_matrix).item()
if isnan(camera_matrix[detection_index].item()):
breakpoint()
temp_matrix.append(
f"{camera_name} : {camera_matrix[detection_index].item()}"
)
# Check whether the affinity-matrix maximum exceeds the threshold
# Only one tracking target for now; multi-target matching is not yet implemented
if camera_matrix[detection_index].item() > affinities_threshold:
# Keep the corresponding 2D detection
tracking_detection.append(
affinities[camera_name].detections[detection_index]
)
print("affinities matrix:", temp_matrix)
# Update the tracking state when 2 or more camera views detect the target simultaneously
if len(tracking_detection) >= 2:
update_tracking(element_tracking, tracking_detection)
# Keep the corresponding 3D pose data
all_3d_kps[str(element_tracking.id)].append(
global_tracking_state.add_tracking(detections)
tracking_initialized = True
lost_frame_count = 0
# Keep the first frame's 3D pose data
for element_tracking in global_tracking_state.trackings.values():
if str(element_tracking.id) not in all_3d_kps.keys():
all_3d_kps[str(element_tracking.id)] = [
element_tracking.state.keypoints.tolist()
)
]
print("init tracking:", global_tracking_state.trackings.values())
continue
# tracking_initialized = True
if len(detections) == 0:
print("no detections in this frame, continue")
lost_frame_count += 1
# Refine the exit conditions further:
# 1. Exit only after the loss has persisted for the threshold number of consecutive frames
# 2. Fully exit only if more than 1 second separates the last detection from the current frame
last_tracking = None
if global_tracking_state.trackings:
last_tracking = list(global_tracking_state.trackings.values())[0]
if lost_frame_count >= lost_frame_threshold:
should_remove = True
# Optional: add a time-gap check here
if should_remove:
global_tracking_state._trackings.clear()
tracking_initialized = False
print(
"update tracking:",
global_tracking_state.trackings.values(),
f"tracking lost after {lost_frame_count} frames, reset tracking state"
)
else:
# if len(detections) == 0:
# continue
# Time gap since the tracking target was lost
time_gap = (
detections[0].timestamp
- element_tracking.state.last_active_timestamp
)
# Remove the retained tracking state once the gap exceeds 1 s
if time_gap.seconds > 0.5:
global_tracking_state._trackings.pop(element_tracking.id)
print(
"remove trackings:",
global_tracking_state.trackings.values(),
"time:",
detections[0].timestamp,
)
lost_frame_count = 0
continue
# Detections present: track normally
lost_frame_count = 0
affinities: dict[str, AffinityResult] = calculate_affinity_matrix(
trackings,
detections,
w_2d=W_2D,
alpha_2d=ALPHA_2D,
w_3d=W_3D,
alpha_3d=ALPHA_3D,
lambda_a=LAMBDA_A,
)
for element_tracking in trackings:
tracking_detection = []
temp_matrix = []
for camera_name in affinities.keys():
camera_matrix = jnp.array(affinities[camera_name].matrix).flatten()
detection_index = jnp.argmax(camera_matrix).item()
if isnan(camera_matrix[detection_index].item()):
breakpoint()
temp_matrix.append(
f"{camera_name} : {camera_matrix[detection_index].item()}"
)
# Select detections whose affinity exceeds the threshold to update the tracking state
# if camera_matrix[detection_index].item() > affinities_threshold:
tracking_detection.append(
affinities[camera_name].detections[detection_index]
)
print("affinities matrix:", temp_matrix)
if len(tracking_detection) > 2:
update_tracking(element_tracking, tracking_detection)
all_3d_kps[str(element_tracking.id)].append(
element_tracking.state.keypoints.tolist()
)
print(
"update tracking:",
global_tracking_state.trackings.values(),
)
# No longer delete the tracking in the else branch; rely solely on lost_frame_count
# Apply sliding-window smoothing to every 3D target
smoothed_points = smooth_3d_keypoints(all_3d_kps, window_size=5)
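smooth_3d_keypoints itself is outside this diff; a sliding-window moving average with window_size=5 would look roughly like the sketch below (an illustration of the idea, not the repo's implementation):

import numpy as np

def moving_average_3d(frames: list, window_size: int = 5) -> np.ndarray:
    """frames: a list of (K, 3) keypoint arrays, one entry per time step."""
    arr = np.asarray(frames, dtype=np.float32)  # (T, K, 3)
    kernel = np.ones(window_size) / window_size
    # convolve every (keypoint, coordinate) series along the time axis
    return np.apply_along_axis(
        lambda series: np.convolve(series, kernel, mode="same"), 0, arr
    )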