forked from HQU-gxy/CVTH3PE
new
This commit is contained in:
@ -459,7 +459,7 @@ def triangulate_one_point_from_multiple_views_linear(
|
||||
proj_matrices: Float[Array, "N 3 4"],
|
||||
points: Num[Array, "N 2"],
|
||||
confidences: Optional[Float[Array, "N"]] = None,
|
||||
conf_threshold: float = 0.2,
|
||||
conf_threshold: float = 0.4, # 0.2
|
||||
) -> Float[Array, "3"]:
|
||||
"""
|
||||
Args:
|
||||
@ -473,7 +473,6 @@ def triangulate_one_point_from_multiple_views_linear(
|
||||
assert len(proj_matrices) == len(points)
|
||||
N = len(proj_matrices)
|
||||
# 置信度加权DLT
|
||||
# 置信度加权DLT
|
||||
if confidences is None:
|
||||
weights = jnp.ones(N, dtype=jnp.float32)
|
||||
else:
|
||||
@ -932,18 +931,18 @@ def filter_camera_port(detections: list[Detection]):
|
||||
|
||||
|
||||
# 相机内外参路径
|
||||
CAMERA_PATH = Path("/home/admin/Documents/ActualTest_WeiHua/camera_params")
|
||||
CAMERA_PATH = Path("/home/admin/Documents/ActualTest_WeiHua/camera_params/")
|
||||
# 所有机位的相机内外参
|
||||
AK_CAMERA_DATASET: ak.Array = get_camera_params(CAMERA_PATH)
|
||||
|
||||
# 2d检测数据路径
|
||||
DATASET_PATH = Path("/home/admin/Documents/ActualTest_WeiHua/Test_Video")
|
||||
DATASET_PATH = Path("/home/admin/Documents/ActualTest_WeiHua/Segment_1/")
|
||||
# 指定机位的2d检测数据
|
||||
camera_port = [5602, 5603, 5604, 5605]
|
||||
KEYPOINT_DATASET = get_camera_detect(DATASET_PATH, camera_port, AK_CAMERA_DATASET)
|
||||
|
||||
# 获取一段完整的跳跃片段
|
||||
FRAME_INDEX = [i for i in range(700, 1600)] # 552, 1488
|
||||
FRAME_INDEX = [i for i in range(700, 1600)] # Segement_1:(700, 1600)
|
||||
KEYPOINT_DATASET = get_segment(camera_port, FRAME_INDEX, KEYPOINT_DATASET)
|
||||
|
||||
|
||||
@ -974,8 +973,10 @@ all_3d_kps: dict[str, list] = {}
|
||||
|
||||
tracking_initialized = False
|
||||
lost_frame_count = 0
|
||||
lost_frame_threshold = 12 # 0.5秒,假设20fps
|
||||
lost_frame_threshold = 12 # 0.5秒
|
||||
|
||||
# 丢失目标帧数计数器
|
||||
loss_track_count = 0
|
||||
|
||||
# ===================== 主循环:逐帧处理检测与跟踪 =====================
|
||||
while True:
|
||||
@ -984,7 +985,8 @@ while True:
|
||||
# 获取下一个时间步的所有相机检测结果
|
||||
detections = next(sync_gen)
|
||||
# 过滤低置信度的检测,提升后续三角化和跟踪的准确性
|
||||
detections = filter_keypoints_by_scores(detections, threshold=0.5)
|
||||
detections = filter_keypoints_by_scores(detections, threshold=0.2)
|
||||
# detections = get_filter_detections(detections) # 伞降跳台时使用
|
||||
except StopIteration:
|
||||
# 检测数据读取完毕,退出主循环
|
||||
break
|
||||
@ -1026,7 +1028,7 @@ while True:
|
||||
lost_frame_count += 1 # 丢失帧数+1
|
||||
# 进一步完善退出条件:
|
||||
# 1. 连续丢失阈值帧后才退出
|
||||
# 2. 若丢失时,最后一次检测到的时间与当前帧时间间隔超过1秒,才彻底退出
|
||||
# 2. 若丢失时,最后一次检测到的时间与当前帧时间间隔超过0.5秒,才彻底退出
|
||||
last_tracking = None
|
||||
if global_tracking_state.trackings:
|
||||
last_tracking = list(global_tracking_state.trackings.values())[0]
|
||||
@ -1058,6 +1060,10 @@ while True:
|
||||
tracking_detection = [] # 存储每个跟踪目标在各相机下最优匹配的检测
|
||||
temp_matrix = [] # 打印用:每个相机的最大相似度
|
||||
for camera_name in affinities.keys():
|
||||
# indices_T:表示匹配到检测的tracking的索引(在tracking列表中的下标)
|
||||
# indices_D:表示匹配到tracking的detection的索引(在detections列表中的下标)
|
||||
indices_T = affinities[camera_name].indices_T.item()
|
||||
indices_D = affinities[camera_name].indices_D.item()
|
||||
camera_matrix = jnp.array(affinities[camera_name].matrix).flatten()
|
||||
detection_index = jnp.argmax(camera_matrix).item() # 取最大相似度的检测索引
|
||||
if isnan(camera_matrix[detection_index].item()):
|
||||
@ -1065,11 +1071,11 @@ while True:
|
||||
temp_matrix.append(
|
||||
f"{camera_name} : {camera_matrix[detection_index].item()}"
|
||||
)
|
||||
match_tracking = affinities[camera_name].trackings[indices_T]
|
||||
# 选取相似度大于阈值的检测目标更新跟踪状态
|
||||
# if camera_matrix[detection_index].item() > affinities_threshold:
|
||||
tracking_detection.append(
|
||||
affinities[camera_name].detections[detection_index]
|
||||
)
|
||||
# if match_tracking == element_tracking:
|
||||
tracking_detection.append(affinities[camera_name].detections[indices_D])
|
||||
print("affinities matrix:", temp_matrix)
|
||||
# 只有匹配到足够多的检测目标时才更新跟踪(如多于2个相机)
|
||||
if len(tracking_detection) > 2:
|
||||
@ -1082,13 +1088,17 @@ while True:
|
||||
"update tracking:",
|
||||
global_tracking_state.trackings.values(),
|
||||
)
|
||||
# 不再在else分支里删除tracking,只用lost_frame_count判定
|
||||
else:
|
||||
loss_track_count += 1
|
||||
# ======如果单帧数据量不够,考虑如何更新跟踪=====
|
||||
|
||||
# 对每一个3d目标进行滑动窗口平滑处理
|
||||
smoothed_points = smooth_3d_keypoints(all_3d_kps, window_size=5)
|
||||
|
||||
print("Tracking completed, total loss frames processed:", count)
|
||||
|
||||
# 将结果保存到json文件中
|
||||
with open("samples/Test_WeiHua.json", "wb") as f:
|
||||
with open("samples/Test_WeiHua_Segment_1.json", "wb") as f:
|
||||
f.write(orjson.dumps(smoothed_points))
|
||||
# 输出每个3d目标的维度
|
||||
for element_3d_kps_id in smoothed_points.keys():
|
||||
|
||||
Reference in New Issue
Block a user