support CASIA-E dataset

This commit is contained in:
darkliang
2023-04-09 17:15:42 +08:00
parent 786aded8af
commit e69fb6f439
8 changed files with 1449 additions and 9 deletions
+19 -9
View File
@@ -70,12 +70,21 @@ def cross_view_gallery_evaluation(feature, label, seq_type, view, dataset, metri
def single_view_gallery_evaluation(feature, label, seq_type, view, dataset, metric):
probe_seq_dict = {'CASIA-B': {'NM': ['nm-05', 'nm-06'], 'BG': ['bg-01', 'bg-02'], 'CL': ['cl-01', 'cl-02']},
'OUMVLP': {'NM': ['00']}}
'OUMVLP': {'NM': ['00']},
'CASIA-E': {'NM': ['H-scene2-nm-1', 'H-scene2-nm-2', 'L-scene2-nm-1', 'L-scene2-nm-2', 'H-scene3-nm-1', 'H-scene3-nm-2', 'L-scene3-nm-1', 'L-scene3-nm-2', 'H-scene3_s-nm-1', 'H-scene3_s-nm-2', 'L-scene3_s-nm-1', 'L-scene3_s-nm-2',],
'BG': ['H-scene2-bg-1', 'H-scene2-bg-2', 'L-scene2-bg-1', 'L-scene2-bg-2', 'H-scene3-bg-1', 'H-scene3-bg-2', 'L-scene3-bg-1', 'L-scene3-bg-2', 'H-scene3_s-bg-1', 'H-scene3_s-bg-2', 'L-scene3_s-bg-1', 'L-scene3_s-bg-2'],
'CL': ['H-scene2-cl-1', 'H-scene2-cl-2', 'L-scene2-cl-1', 'L-scene2-cl-2', 'H-scene3-cl-1', 'H-scene3-cl-2', 'L-scene3-cl-1', 'L-scene3-cl-2', 'H-scene3_s-cl-1', 'H-scene3_s-cl-2', 'L-scene3_s-cl-1', 'L-scene3_s-cl-2']
}
}
gallery_seq_dict = {'CASIA-B': ['nm-01', 'nm-02', 'nm-03', 'nm-04'],
'OUMVLP': ['01']}
'OUMVLP': ['01'],
'CASIA-E': ['H-scene1-nm-1', 'H-scene1-nm-2', 'L-scene1-nm-1', 'L-scene1-nm-2']}
msg_mgr = get_msg_mgr()
acc = {}
view_list = sorted(np.unique(view))
if dataset == 'CASIA-E':
view_list.remove("270")
view_num = len(view_list)
num_rank = 1
for (type_, probe_seq) in probe_seq_dict[dataset].items():
@@ -92,8 +101,8 @@ def single_view_gallery_evaluation(feature, label, seq_type, view, dataset, metr
gallery_y = label[gseq_mask]
gallery_x = feature[gseq_mask, :]
dist = cuda_dist(probe_x, gallery_x, metric)
idx = dist.cpu().sort(1)[1].numpy()
acc[type_][v1, v2] = np.round(np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]], 1) > 0,
idx = dist.topk(num_rank, largest=False)[1].cpu().numpy()
acc[type_][v1, v2] = np.round(np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx], 1) > 0,
0) * 100 / dist.shape[0], 2)
result_dict = {}
@@ -113,7 +122,7 @@ def evaluate_indoor_dataset(data, dataset, metric='euc', cross_view_gallery=Fals
label = np.array(label)
view = np.array(view)
if dataset not in ('CASIA-B', 'OUMVLP'):
if dataset not in ('CASIA-B', 'OUMVLP', 'CASIA-E'):
raise KeyError("DataSet %s hasn't been supported !" % dataset)
if cross_view_gallery:
return cross_view_gallery_evaluation(
@@ -145,7 +154,7 @@ def evaluate_real_scene(data, dataset, metric='euc'):
probe_y = label[pseq_mask]
dist = cuda_dist(probe_x, gallery_x, metric)
idx = dist.cpu().sort(1)[1].numpy()
idx = dist.topk(num_rank, largest=False)[1].cpu().numpy()
acc = np.round(np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]], 1) > 0,
0) * 100 / dist.shape[0], 2)
msg_mgr.log_info('==Rank-1==')
@@ -173,8 +182,9 @@ def GREW_submission(data, dataset, metric='euc'):
probe_x = feature[pseq_mask, :]
probe_y = view[pseq_mask]
num_rank = 20
dist = cuda_dist(probe_x, gallery_x, metric)
idx = dist.cpu().sort(1)[1].numpy()
idx = dist.topk(num_rank, largest=False)[1].cpu().numpy()
save_path = os.path.join(
"GREW_result/"+strftime('%Y-%m%d-%H%M%S', localtime())+".csv")
@@ -182,8 +192,8 @@ def GREW_submission(data, dataset, metric='euc'):
with open(save_path, "w") as f:
f.write("videoId,rank1,rank2,rank3,rank4,rank5,rank6,rank7,rank8,rank9,rank10,rank11,rank12,rank13,rank14,rank15,rank16,rank17,rank18,rank19,rank20\n")
for i in range(len(idx)):
r_format = [int(idx) for idx in gallery_y[idx[i, 0:20]]]
output_row = '{}'+',{}'*20+'\n'
r_format = [int(idx) for idx in gallery_y[idx[i, 0:num_rank]]]
output_row = '{}'+',{}'*num_rank+'\n'
f.write(output_row.format(probe_y[i], *r_format))
print("GREW result saved to {}/{}".format(os.getcwd(), save_path))
return