1
0
forked from HQU-gxy/CVTH3PE
This commit is contained in:
2025-07-11 15:33:56 +08:00
parent 3ec4a89103
commit b3da8ef7b2
5 changed files with 26 additions and 845 deletions

View File

@ -459,7 +459,7 @@ def triangulate_one_point_from_multiple_views_linear(
proj_matrices: Float[Array, "N 3 4"],
points: Num[Array, "N 2"],
confidences: Optional[Float[Array, "N"]] = None,
conf_threshold: float = 0.2,
conf_threshold: float = 0.4, # 0.2
) -> Float[Array, "3"]:
"""
Args:
@ -473,7 +473,6 @@ def triangulate_one_point_from_multiple_views_linear(
assert len(proj_matrices) == len(points)
N = len(proj_matrices)
# 置信度加权DLT
# 置信度加权DLT
if confidences is None:
weights = jnp.ones(N, dtype=jnp.float32)
else:
@ -932,18 +931,18 @@ def filter_camera_port(detections: list[Detection]):
# 相机内外参路径
CAMERA_PATH = Path("/home/admin/Documents/ActualTest_WeiHua/camera_params")
CAMERA_PATH = Path("/home/admin/Documents/ActualTest_WeiHua/camera_params/")
# 所有机位的相机内外参
AK_CAMERA_DATASET: ak.Array = get_camera_params(CAMERA_PATH)
# 2d检测数据路径
DATASET_PATH = Path("/home/admin/Documents/ActualTest_WeiHua/Test_Video")
DATASET_PATH = Path("/home/admin/Documents/ActualTest_WeiHua/Segment_1/")
# 指定机位的2d检测数据
camera_port = [5602, 5603, 5604, 5605]
KEYPOINT_DATASET = get_camera_detect(DATASET_PATH, camera_port, AK_CAMERA_DATASET)
# 获取一段完整的跳跃片段
FRAME_INDEX = [i for i in range(700, 1600)] # 552 1488
FRAME_INDEX = [i for i in range(700, 1600)] # Segement_1:(700, 1600)
KEYPOINT_DATASET = get_segment(camera_port, FRAME_INDEX, KEYPOINT_DATASET)
@ -974,8 +973,10 @@ all_3d_kps: dict[str, list] = {}
tracking_initialized = False
lost_frame_count = 0
lost_frame_threshold = 12 # 0.5秒假设20fps
lost_frame_threshold = 12 # 0.5秒
# 丢失目标帧数计数器
loss_track_count = 0
# ===================== 主循环:逐帧处理检测与跟踪 =====================
while True:
@ -984,7 +985,8 @@ while True:
# 获取下一个时间步的所有相机检测结果
detections = next(sync_gen)
# 过滤低置信度的检测,提升后续三角化和跟踪的准确性
detections = filter_keypoints_by_scores(detections, threshold=0.5)
detections = filter_keypoints_by_scores(detections, threshold=0.2)
# detections = get_filter_detections(detections) # 伞降跳台时使用
except StopIteration:
# 检测数据读取完毕,退出主循环
break
@ -1026,7 +1028,7 @@ while True:
lost_frame_count += 1 # 丢失帧数+1
# 进一步完善退出条件:
# 1. 连续丢失阈值帧后才退出
# 2. 若丢失时,最后一次检测到的时间与当前帧时间间隔超过1秒,才彻底退出
# 2. 若丢失时,最后一次检测到的时间与当前帧时间间隔超过0.5秒,才彻底退出
last_tracking = None
if global_tracking_state.trackings:
last_tracking = list(global_tracking_state.trackings.values())[0]
@ -1058,6 +1060,10 @@ while True:
tracking_detection = [] # 存储每个跟踪目标在各相机下最优匹配的检测
temp_matrix = [] # 打印用:每个相机的最大相似度
for camera_name in affinities.keys():
# indices_T:表示匹配到检测的tracking的索引(在tracking列表中的下标)
# indices_D:表示匹配到tracking的detection的索引(在detections列表中的下标)
indices_T = affinities[camera_name].indices_T.item()
indices_D = affinities[camera_name].indices_D.item()
camera_matrix = jnp.array(affinities[camera_name].matrix).flatten()
detection_index = jnp.argmax(camera_matrix).item() # 取最大相似度的检测索引
if isnan(camera_matrix[detection_index].item()):
@ -1065,11 +1071,11 @@ while True:
temp_matrix.append(
f"{camera_name} : {camera_matrix[detection_index].item()}"
)
match_tracking = affinities[camera_name].trackings[indices_T]
# 选取相似度大于阈值的检测目标更新跟踪状态
# if camera_matrix[detection_index].item() > affinities_threshold:
tracking_detection.append(
affinities[camera_name].detections[detection_index]
)
# if match_tracking == element_tracking:
tracking_detection.append(affinities[camera_name].detections[indices_D])
print("affinities matrix:", temp_matrix)
# 只有匹配到足够多的检测目标时才更新跟踪如多于2个相机
if len(tracking_detection) > 2:
@ -1082,13 +1088,17 @@ while True:
"update tracking:",
global_tracking_state.trackings.values(),
)
# 不再在else分支里删除tracking只用lost_frame_count判定
else:
loss_track_count += 1
# ======如果单帧数据量不够,考虑如何更新跟踪=====
# 对每一个3d目标进行滑动窗口平滑处理
smoothed_points = smooth_3d_keypoints(all_3d_kps, window_size=5)
print("Tracking completed, total loss frames processed:", count)
# 将结果保存到json文件中
with open("samples/Test_WeiHua.json", "wb") as f:
with open("samples/Test_WeiHua_Segment_1.json", "wb") as f:
f.write(orjson.dumps(smoothed_points))
# 输出每个3d目标的维度
for element_3d_kps_id in smoothed_points.keys():