forked from HQU-gxy/CVTH3PE
Single-person tracking: reworked the tracking initialization and exit conditions, and revised the weighted DLT computation
@@ -133,15 +133,9 @@ def get_camera_detect(
 def get_segment(
     camera_port: list[int], frame_index: list[int], keypoint_data: dict[int, ak.Array]
 ) -> dict[int, ak.Array]:
-    # for port in camera_port:
-    #     keypoint_data[port] = [
-    #         element_frame
-    #         for element_frame in KEYPOINT_DATASET[port]
-    #         if element_frame["frame_index"] in frame_index
-    #     ]
     for port in camera_port:
         segement_data = []
-        camera_data = KEYPOINT_DATASET[port]
+        camera_data = keypoint_data[port]
        for index, element_frame in enumerate(ak.to_numpy(camera_data["frame_index"])):
             if element_frame in frame_index:
                 segement_data.append(camera_data[index])
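
Note: the inner `if element_frame in frame_index` does a linear scan of a Python list for every frame. If this ever matters for performance, the same selection can be done in one vectorized step; a minimal sketch under the same inputs (the np.isin variant is an illustration, not code from this commit):

import awkward as ak
import numpy as np

def get_segment_vectorized(
    camera_port: list[int],
    frame_index: list[int],
    keypoint_data: dict[int, ak.Array],
) -> dict[int, ak.Array]:
    out: dict[int, ak.Array] = {}
    for port in camera_port:
        camera_data = keypoint_data[port]
        frames = ak.to_numpy(camera_data["frame_index"])
        mask = np.isin(frames, frame_index)  # one vectorized membership test
        out[port] = camera_data[mask]  # boolean mask keeps matching records
    return out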
@@ -431,43 +425,32 @@ def triangulate_one_point_from_multiple_views_linear_time_weighted(
     else:
         confi = jnp.sqrt(jnp.clip(confidences, 0, 1))

-    # First build the coefficient matrix without weights
-    A = jnp.zeros((N * 2, 4), dtype=np.float32)
+    A = jnp.zeros((N * 2, 4), dtype=np.float32)
+    time_weights = jnp.exp(-lambda_t * delta_t)
+    weights = time_weights * confi
+    sum_weights = jnp.sum(weights)
+    weights = jnp.where(sum_weights > 0, weights / sum_weights, weights)
+
     for i in range(N):
         x, y = points[i]
-        A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])
-        A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])
+        row1 = proj_matrices[i, 2] * x - proj_matrices[i, 0]
+        row2 = proj_matrices[i, 2] * y - proj_matrices[i, 1]
+        A = A.at[2 * i].set(row1 * weights[i])
+        A = A.at[2 * i + 1].set(row2 * weights[i])

-    # Then apply the time-based and confidence weights
-    for i in range(N):
-        # Calculate time-decay weight: e^(-λ_t * Δt)
-        time_weight = jnp.exp(-lambda_t * delta_t[i])
-
-        # Calculate normalization factor: ||c^i^T||_2
-        row_norm_1 = jnp.linalg.norm(A[2 * i])
-        row_norm_2 = jnp.linalg.norm(A[2 * i + 1])
-
-        # Apply combined weight: time_weight / row_norm * confidence
-        w1 = (time_weight / row_norm_1) * confi[i]
-        w2 = (time_weight / row_norm_2) * confi[i]
-
-        A = A.at[2 * i].mul(w1)
-        A = A.at[2 * i + 1].mul(w2)
-
-    # Solve using SVD
     _, _, vh = jnp.linalg.svd(A, full_matrices=False)
-    point_3d_homo = vh[-1]  # shape (4,)
-
-    # Ensure homogeneous coordinate is positive
-    point_3d_homo = jnp.where(
-        point_3d_homo[3] < 0,
-        -point_3d_homo,
-        point_3d_homo,
-    )
-
-    # Convert from homogeneous to Euclidean coordinates
-    point_3d = point_3d_homo[:3] / point_3d_homo[3]
+    point_3d_homo = vh[-1]
+    point_3d_homo = jnp.where(point_3d_homo[3] < 0, -point_3d_homo, point_3d_homo)
+    is_zero_weight = jnp.sum(weights) == 0
+    point_3d = jnp.where(
+        is_zero_weight,
+        jnp.full((3,), jnp.nan, dtype=jnp.float32),
+        jnp.where(
+            jnp.abs(point_3d_homo[3]) > 1e-8,
+            point_3d_homo[:3] / point_3d_homo[3],
+            jnp.full((3,), jnp.nan, dtype=jnp.float32),
+        ),
+    )
     return point_3d
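
For reference, the new weighting reduces to scaling each view's two DLT rows by a normalized weight w_i = exp(-lambda_t * dt_i) * sqrt(c_i) before the SVD, with a NaN fallback when every weight is zero. A self-contained sketch of that scheme (function and variable names here are illustrative, not the repo's):

import jax.numpy as jnp

def weighted_dlt_sketch(proj_matrices, points, confidences, delta_t, lambda_t=1.0):
    # proj_matrices: (N, 3, 4); points: (N, 2); confidences, delta_t: (N,)
    confi = jnp.sqrt(jnp.clip(confidences, 0, 1))
    weights = jnp.exp(-lambda_t * delta_t) * confi
    sum_w = jnp.sum(weights)
    weights = jnp.where(sum_w > 0, weights / sum_w, weights)

    # Each view contributes two weighted rows of the homogeneous system A X = 0.
    x = points[:, 0][:, None]
    y = points[:, 1][:, None]
    rows1 = proj_matrices[:, 2] * x - proj_matrices[:, 0]  # (N, 4)
    rows2 = proj_matrices[:, 2] * y - proj_matrices[:, 1]
    A = jnp.concatenate([rows1 * weights[:, None], rows2 * weights[:, None]], axis=0)

    _, _, vh = jnp.linalg.svd(A, full_matrices=False)
    h = vh[-1]
    h = jnp.where(h[3] < 0, -h, h)  # keep the homogeneous coordinate positive
    ok = (jnp.sum(weights) > 0) & (jnp.abs(h[3]) > 1e-8)
    return jnp.where(ok, h[:3] / h[3], jnp.full(3, jnp.nan))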
@@ -489,36 +472,37 @@ def triangulate_one_point_from_multiple_views_linear(
     """
     assert len(proj_matrices) == len(points)
     N = len(proj_matrices)
     # Confidence-weighted DLT
     if confidences is None:
         weights = jnp.ones(N, dtype=jnp.float32)
     else:
-        # Points below the confidence threshold get weight 0; the rest get sqrt(conf)
         valid_mask = confidences >= conf_threshold
-        weights = jnp.where(valid_mask, jnp.sqrt(jnp.clip(confidences, 0, 1)), 0.0)
-        # Normalize the weights so that no single frame dominates
+        weights = jnp.where(valid_mask, confidences, 0.0)
         sum_weights = jnp.sum(weights)
         weights = jnp.where(sum_weights > 0, weights / sum_weights, weights)

     A = jnp.zeros((N * 2, 4), dtype=jnp.float32)
     for i in range(N):
         x, y = points[i]
-        A = A.at[2 * i].set(proj_matrices[i, 2] * x - proj_matrices[i, 0])
-        A = A.at[2 * i + 1].set(proj_matrices[i, 2] * y - proj_matrices[i, 1])
-        A = A.at[2 * i].mul(weights[i])
-        A = A.at[2 * i + 1].mul(weights[i])
+        row1 = proj_matrices[i, 2] * x - proj_matrices[i, 0]
+        row2 = proj_matrices[i, 2] * y - proj_matrices[i, 1]
+        A = A.at[2 * i].set(row1 * weights[i])
+        A = A.at[2 * i + 1].set(row2 * weights[i])

     # https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.svd.html
     _, _, vh = jnp.linalg.svd(A, full_matrices=False)
-    point_3d_homo = vh[-1]  # shape (4,)
-
-    # replace the Python `if` with a jnp.where
-    point_3d_homo = jnp.where(
-        point_3d_homo[3] < 0,  # predicate (scalar bool tracer)
-        -point_3d_homo,  # if True
-        point_3d_homo,  # if False
-    )
-
-    point_3d = point_3d_homo[:3] / point_3d_homo[3]
+    point_3d_homo = vh[-1]
+    point_3d_homo = jnp.where(point_3d_homo[3] < 0, -point_3d_homo, point_3d_homo)
+    is_zero_weight = jnp.sum(weights) == 0
+    point_3d = jnp.where(
+        is_zero_weight,
+        jnp.full((3,), jnp.nan, dtype=jnp.float32),
+        jnp.where(
+            jnp.abs(point_3d_homo[3]) > 1e-8,
+            point_3d_homo[:3] / point_3d_homo[3],
+            jnp.full((3,), jnp.nan, dtype=jnp.float32),
+        ),
+    )
     return point_3d
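
The sign flip uses jnp.where rather than a Python if because under jax.jit the comparison yields an abstract tracer, not a concrete bool. A minimal standalone illustration of the pattern (not repo code):

import jax
import jax.numpy as jnp

@jax.jit
def fix_sign(h):
    # `if h[3] < 0:` would fail while tracing (TracerBoolConversionError);
    # jnp.where keeps the branch inside the compiled computation.
    return jnp.where(h[3] < 0, -h, h)

print(fix_sign(jnp.array([1.0, 2.0, 3.0, -0.5])))  # [-1. -2. -3.  0.5]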
@@ -542,16 +526,16 @@ def triangulate_points_from_multiple_views_linear(
     N, P, _ = points.shape
     assert proj_matrices.shape[0] == N

-    conf = jnp.array(confidences)
+    if confidences is None:
+        conf = jnp.ones((N, P), dtype=jnp.float32)
+    else:
+        conf = jnp.array(confidences)

     # vectorize the one-point routine over P
     vmap_triangulate = jax.vmap(
         triangulate_one_point_from_multiple_views_linear,
-        in_axes=(None, 1, 1),  # proj_matrices static, map over points[:,p,:], conf[:,p]
+        in_axes=(None, 1, 1),
         out_axes=0,
     )

     # returns (P, 3)
     return vmap_triangulate(proj_matrices, points, conf)
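
in_axes=(None, 1, 1) broadcasts the shared projection matrices while mapping over axis 1 (the per-keypoint axis) of the points and confidences. A toy demonstration of that axis mapping (shapes and the function are illustrative):

import jax
import jax.numpy as jnp

def one_point(ms, pt, c):
    # ms: (N, 3, 4) shared; pt: (N, 2) and c: (N,) for one keypoint
    return jnp.sum(ms) + jnp.sum(pt * c[:, None])

N, P = 4, 17
ms = jnp.ones((N, 3, 4))
pts = jnp.ones((N, P, 2))  # axis 1 indexes the P keypoints
conf = jnp.ones((N, P))

batched = jax.vmap(one_point, in_axes=(None, 1, 1), out_axes=0)
print(batched(ms, pts, conf).shape)  # (P,) -- one result per keypoint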
@@ -924,17 +908,20 @@ def smooth_3d_keypoints(


 # Filter 2D detections by their mean keypoint confidence
-def filter_keypoints_by_scores(detections: Iterable[Detection], threshold: float = 0.5):
+def filter_keypoints_by_scores(
+    detections: Iterable[Detection], threshold: float = 0.5
+) -> list[Detection]:
     """
     Filter detections based on the average confidence score of their keypoints.
     Only keep detections with an average score above the threshold.
     """

     def filter_detection(detection: Detection) -> bool:
-        avg_score = np.mean(detection.confidences)
-        return float(avg_score) >= threshold
+        median_score = np.mean(detection.confidences[:17])
+        # print(f"Mean score: {median_score}")
+        return float(median_score) >= threshold

-    return filter(filter_detection, detections)
+    return [d for d in detections if filter_detection(d)]


 def filter_camera_port(detections: list[Detection]):
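
Returning a list instead of a lazy filter object matters here because the result is reused later in the loop (len checks, indexing), and a filter iterator can only be consumed once. A quick standalone check with a stand-in class (the real Detection lives in the repo; the [:17] slice presumably keeps the 17 COCO body keypoints):

from dataclasses import dataclass
import numpy as np

@dataclass
class FakeDetection:  # stand-in for the repo's Detection
    confidences: np.ndarray

dets = [FakeDetection(np.full(17, 0.9)), FakeDetection(np.full(17, 0.2))]
kept = [d for d in dets if float(np.mean(d.confidences[:17])) >= 0.5]
print(len(kept))  # 1 -- the low-confidence detection is dropped
# A filter(...) object has no len() and is exhausted after one pass,
# which would break the repeated len()/indexing in the tracking loop.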
@@ -956,7 +943,7 @@ camera_port = [5602, 5603, 5604, 5605]
 KEYPOINT_DATASET = get_camera_detect(DATASET_PATH, camera_port, AK_CAMERA_DATASET)

 # Extract one complete jump segment
-FRAME_INDEX = [i for i in range(552, 1488)]  # 552, 1488
+FRAME_INDEX = [i for i in range(700, 1600)]  # 552, 1488
 KEYPOINT_DATASET = get_segment(camera_port, FRAME_INDEX, KEYPOINT_DATASET)


@@ -965,7 +952,7 @@ sync_gen: Generator[list[Detection], Any, None] = get_batch_detect(
     KEYPOINT_DATASET,
     AK_CAMERA_DATASET,
     camera_port,
-    batch_fps=20,
+    batch_fps=24,
 )

 # Set up the tracking target
@@ -985,134 +972,104 @@ trackings: list[Tracking] = []
 # 3D data: key is the tracking-target id, value is all of that target's 3D data
 all_3d_kps: dict[str, list] = {}

-# Iterate over the 2D data and exercise the tracking state
+tracking_initialized = False
+lost_frame_count = 0
+lost_frame_threshold = 12  # 0.5 s at batch_fps=24
+
 while True:
-    # Get the current tracking targets
-    try:
-        detections = next(sync_gen)
-        # Filter the 2D detections by mean confidence
-        # detections = list(filter_keypoints_by_scores(detections, threshold=0.5))
-    except StopIteration:
-        break
-
-    if len(detections) == 0:
-        print("no detections in this frame, continue")
-        continue
-
-    # Get the most recent frame of 2D data
-    # Check whether tracking has been established; if not, skip frames until it is
+    # Reworked tracking logic: single target, one-time initialization, robust loss handling
+    try:
+        detections = next(sync_gen)
+        detections = filter_keypoints_by_scores(detections, threshold=0.5)
+    except StopIteration:
+        break
+
+    # 1. Check whether an initialized tracking target currently exists
+    trackings: list[Tracking] = sorted(
+        global_tracking_state.trackings.values(), key=lambda x: x.id
+    )
+
     if not trackings:
-        """Used for the off-apparatus case; initializes the first frame"""
-        """
-        # Use the box-filtered 2D detections
-        filter_detections = get_filter_detections(detections)
-        # Only establish tracking once all 3 cameras see the target
-        # if len(filter_detections) == 0:
-        #     continue
-        if len(filter_detections) < len(camera_port):
-            print(
-                "init traincking error, filter filter_detections len:",
-                len(filter_detections),
-            )
-            continue
-        """
-        # Filter the 2D detections by mean confidence
-        # detections = list(filter_keypoints_by_scores(detections, threshold=0.7))
-
-        # Only establish tracking once all 4 cameras detect the target
-        camera_port = filter_camera_port(detections)
-        if len(detections) < len(camera_port):
-            print(
-                "init traincking error, filter_detections len:",
-                len(detections),
-            )
-        else:
-            # Build the tracking target from the first frame of data
-            global_tracking_state.add_tracking(detections)  # off-apparatus case: filter_detections
-            # Get the current tracking targets
-            trackings: list[Tracking] = sorted(
-                global_tracking_state.trackings.values(), key=lambda x: x.id
-            )
-            # Keep the first frame's 3D pose data
-            for element_tracking in trackings:
-                if str(element_tracking.id) not in all_3d_kps.keys():
-                    all_3d_kps[str(element_tracking.id)] = [
-                        element_tracking.state.keypoints.tolist()
-                    ]
-            print("initer tracking:", trackings)
-    else:
-        # Compute the affinity-matrix matching results
-        affinities: dict[str, AffinityResult] = calculate_affinity_matrix(
-            trackings,
-            detections,
-            w_2d=W_2D,
-            alpha_2d=ALPHA_2D,
-            w_3d=W_3D,
-            alpha_3d=ALPHA_3D,
-            lambda_a=LAMBDA_A,
-        )
-
-        # For each tracking target, collect its matched 2D detections
-        for element_tracking in trackings:
-            tracking_detection = []
-            # Index of the matched detection
-            detection_index = None
-
-            temp_matrix = []
-
-            # Iterate over each camera's affinity results
-            for camera_name in affinities.keys():
-                # Affinity matrix matched to this tracking target
-                camera_matrix = jnp.array(affinities[camera_name].matrix).flatten()
-                detection_index = jnp.argmax(camera_matrix).item()
-
-                if isnan(camera_matrix[detection_index].item()):
-                    breakpoint()
-                temp_matrix.append(
-                    f"{camera_name} : {camera_matrix[detection_index].item()}"
-                )
-
-                # Check whether the affinity maximum exceeds the threshold
-                # Currently only one tracking target; multi-target matching is not implemented yet
-                if camera_matrix[detection_index].item() > affinities_threshold:
-                    # Keep the corresponding 2D detection
-                    tracking_detection.append(
-                        affinities[camera_name].detections[detection_index]
-                    )
-            print("affinities matrix:", temp_matrix)
-            # Update the tracking state when 2 or more cameras see the target at once
-            if len(tracking_detection) >= 2:
-                update_tracking(element_tracking, tracking_detection)
-                # Keep the corresponding 3D pose data
-                all_3d_kps[str(element_tracking.id)].append(
-                    element_tracking.state.keypoints.tolist()
-                )
-                print(
-                    "update tracking:",
-                    global_tracking_state.trackings.values(),
-                )
-            else:
-                # if len(detections) == 0:
-                #     continue
-                # Time gap since the tracking target was lost
-                time_gap = (
-                    detections[0].timestamp
-                    - element_tracking.state.last_active_timestamp
-                )
-                # Remove the retained tracking state once the gap exceeds 1 s
-                if time_gap.seconds > 0.5:
-                    global_tracking_state._trackings.pop(element_tracking.id)
-                    print(
-                        "remove trackings:",
-                        global_tracking_state.trackings.values(),
-                        "time:",
-                        detections[0].timestamp,
-                    )
+        if not tracking_initialized:
+            # Initialize only once
+            camera_port_this = filter_camera_port(detections)
+            if len(camera_port_this) < len(camera_port) - 1:
+                print(
+                    "init tracking error, filter_detections len:",
+                    len(camera_port_this),
+                )
+                continue
+            global_tracking_state.add_tracking(detections)
+            tracking_initialized = True
+            lost_frame_count = 0
+            # Keep the first frame's 3D pose data
+            for element_tracking in global_tracking_state.trackings.values():
+                if str(element_tracking.id) not in all_3d_kps.keys():
+                    all_3d_kps[str(element_tracking.id)] = [
+                        element_tracking.state.keypoints.tolist()
+                    ]
+            print("init tracking:", global_tracking_state.trackings.values())
+        continue
+
+    # tracking_initialized = True
+    if len(detections) == 0:
+        print("no detections in this frame, continue")
+        lost_frame_count += 1
+        # Refined exit conditions:
+        # 1. Exit only after the target is lost for a threshold number of consecutive frames
+        # 2. Fully exit only if the gap between the last detection and the current frame exceeds 1 s
+        last_tracking = None
+        if global_tracking_state.trackings:
+            last_tracking = list(global_tracking_state.trackings.values())[0]
+        if lost_frame_count >= lost_frame_threshold:
+            should_remove = True
+            # Optional: add a time-gap check here
+            if should_remove:
+                global_tracking_state._trackings.clear()
+                tracking_initialized = False
+                print(
+                    f"tracking lost after {lost_frame_count} frames, reset tracking state"
+                )
+                lost_frame_count = 0
+        continue
+
+    # Detections present: track normally
+    lost_frame_count = 0
+    affinities: dict[str, AffinityResult] = calculate_affinity_matrix(
+        trackings,
+        detections,
+        w_2d=W_2D,
+        alpha_2d=ALPHA_2D,
+        w_3d=W_3D,
+        alpha_3d=ALPHA_3D,
+        lambda_a=LAMBDA_A,
+    )
+    for element_tracking in trackings:
+        tracking_detection = []
+        temp_matrix = []
+        for camera_name in affinities.keys():
+            camera_matrix = jnp.array(affinities[camera_name].matrix).flatten()
+            detection_index = jnp.argmax(camera_matrix).item()
+            if isnan(camera_matrix[detection_index].item()):
+                breakpoint()
+            temp_matrix.append(
+                f"{camera_name} : {camera_matrix[detection_index].item()}"
+            )
+            # Use detections whose affinity exceeds the threshold to update tracking
+            # if camera_matrix[detection_index].item() > affinities_threshold:
+            tracking_detection.append(
+                affinities[camera_name].detections[detection_index]
+            )
+        print("affinities matrix:", temp_matrix)
+        if len(tracking_detection) > 2:
+            update_tracking(element_tracking, tracking_detection)
+            all_3d_kps[str(element_tracking.id)].append(
+                element_tracking.state.keypoints.tolist()
+            )
+            print(
+                "update tracking:",
+                global_tracking_state.trackings.values(),
+            )
+    # No longer delete the tracking in the else branch; rely solely on lost_frame_count

 # Apply sliding-window smoothing to each 3D target
 smoothed_points = smooth_3d_keypoints(all_3d_kps, window_size=5)
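
The new exit condition amounts to a small counter-based state machine: any frame with detections clears the streak, and only lost_frame_threshold consecutive empty frames trigger a reset. A condensed standalone restatement (illustrative class, not repo code):

class LostFrameMonitor:
    # Reset tracking only after `threshold` consecutive empty frames.
    def __init__(self, threshold: int = 12):  # 12 frames = 0.5 s at 24 fps
        self.threshold = threshold
        self.lost = 0

    def update(self, has_detections: bool) -> bool:
        # Returns True when the tracking state should be reset.
        if has_detections:
            self.lost = 0  # any detection ends the lost streak
            return False
        self.lost += 1
        if self.lost >= self.threshold:
            self.lost = 0  # start a fresh count after the reset
            return True
        return False

mon = LostFrameMonitor()
flags = [mon.update(False) for _ in range(12)]
assert flags[:11] == [False] * 11 and flags[11] is True  # resets on the 12th frame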