support CASIA-E dataset
This commit is contained in:
@@ -70,12 +70,21 @@ def cross_view_gallery_evaluation(feature, label, seq_type, view, dataset, metri
|
||||
|
||||
def single_view_gallery_evaluation(feature, label, seq_type, view, dataset, metric):
|
||||
probe_seq_dict = {'CASIA-B': {'NM': ['nm-05', 'nm-06'], 'BG': ['bg-01', 'bg-02'], 'CL': ['cl-01', 'cl-02']},
|
||||
'OUMVLP': {'NM': ['00']}}
|
||||
'OUMVLP': {'NM': ['00']},
|
||||
'CASIA-E': {'NM': ['H-scene2-nm-1', 'H-scene2-nm-2', 'L-scene2-nm-1', 'L-scene2-nm-2', 'H-scene3-nm-1', 'H-scene3-nm-2', 'L-scene3-nm-1', 'L-scene3-nm-2', 'H-scene3_s-nm-1', 'H-scene3_s-nm-2', 'L-scene3_s-nm-1', 'L-scene3_s-nm-2',],
|
||||
'BG': ['H-scene2-bg-1', 'H-scene2-bg-2', 'L-scene2-bg-1', 'L-scene2-bg-2', 'H-scene3-bg-1', 'H-scene3-bg-2', 'L-scene3-bg-1', 'L-scene3-bg-2', 'H-scene3_s-bg-1', 'H-scene3_s-bg-2', 'L-scene3_s-bg-1', 'L-scene3_s-bg-2'],
|
||||
'CL': ['H-scene2-cl-1', 'H-scene2-cl-2', 'L-scene2-cl-1', 'L-scene2-cl-2', 'H-scene3-cl-1', 'H-scene3-cl-2', 'L-scene3-cl-1', 'L-scene3-cl-2', 'H-scene3_s-cl-1', 'H-scene3_s-cl-2', 'L-scene3_s-cl-1', 'L-scene3_s-cl-2']
|
||||
}
|
||||
|
||||
}
|
||||
gallery_seq_dict = {'CASIA-B': ['nm-01', 'nm-02', 'nm-03', 'nm-04'],
|
||||
'OUMVLP': ['01']}
|
||||
'OUMVLP': ['01'],
|
||||
'CASIA-E': ['H-scene1-nm-1', 'H-scene1-nm-2', 'L-scene1-nm-1', 'L-scene1-nm-2']}
|
||||
msg_mgr = get_msg_mgr()
|
||||
acc = {}
|
||||
view_list = sorted(np.unique(view))
|
||||
if dataset == 'CASIA-E':
|
||||
view_list.remove("270")
|
||||
view_num = len(view_list)
|
||||
num_rank = 1
|
||||
for (type_, probe_seq) in probe_seq_dict[dataset].items():
|
||||
@@ -92,8 +101,8 @@ def single_view_gallery_evaluation(feature, label, seq_type, view, dataset, metr
|
||||
gallery_y = label[gseq_mask]
|
||||
gallery_x = feature[gseq_mask, :]
|
||||
dist = cuda_dist(probe_x, gallery_x, metric)
|
||||
idx = dist.cpu().sort(1)[1].numpy()
|
||||
acc[type_][v1, v2] = np.round(np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]], 1) > 0,
|
||||
idx = dist.topk(num_rank, largest=False)[1].cpu().numpy()
|
||||
acc[type_][v1, v2] = np.round(np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx], 1) > 0,
|
||||
0) * 100 / dist.shape[0], 2)
|
||||
|
||||
result_dict = {}
|
||||
@@ -113,7 +122,7 @@ def evaluate_indoor_dataset(data, dataset, metric='euc', cross_view_gallery=Fals
|
||||
label = np.array(label)
|
||||
view = np.array(view)
|
||||
|
||||
if dataset not in ('CASIA-B', 'OUMVLP'):
|
||||
if dataset not in ('CASIA-B', 'OUMVLP', 'CASIA-E'):
|
||||
raise KeyError("DataSet %s hasn't been supported !" % dataset)
|
||||
if cross_view_gallery:
|
||||
return cross_view_gallery_evaluation(
|
||||
@@ -145,7 +154,7 @@ def evaluate_real_scene(data, dataset, metric='euc'):
|
||||
probe_y = label[pseq_mask]
|
||||
|
||||
dist = cuda_dist(probe_x, gallery_x, metric)
|
||||
idx = dist.cpu().sort(1)[1].numpy()
|
||||
idx = dist.topk(num_rank, largest=False)[1].cpu().numpy()
|
||||
acc = np.round(np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]], 1) > 0,
|
||||
0) * 100 / dist.shape[0], 2)
|
||||
msg_mgr.log_info('==Rank-1==')
|
||||
@@ -173,8 +182,9 @@ def GREW_submission(data, dataset, metric='euc'):
|
||||
probe_x = feature[pseq_mask, :]
|
||||
probe_y = view[pseq_mask]
|
||||
|
||||
num_rank = 20
|
||||
dist = cuda_dist(probe_x, gallery_x, metric)
|
||||
idx = dist.cpu().sort(1)[1].numpy()
|
||||
idx = dist.topk(num_rank, largest=False)[1].cpu().numpy()
|
||||
|
||||
save_path = os.path.join(
|
||||
"GREW_result/"+strftime('%Y-%m%d-%H%M%S', localtime())+".csv")
|
||||
@@ -182,8 +192,8 @@ def GREW_submission(data, dataset, metric='euc'):
|
||||
with open(save_path, "w") as f:
|
||||
f.write("videoId,rank1,rank2,rank3,rank4,rank5,rank6,rank7,rank8,rank9,rank10,rank11,rank12,rank13,rank14,rank15,rank16,rank17,rank18,rank19,rank20\n")
|
||||
for i in range(len(idx)):
|
||||
r_format = [int(idx) for idx in gallery_y[idx[i, 0:20]]]
|
||||
output_row = '{}'+',{}'*20+'\n'
|
||||
r_format = [int(idx) for idx in gallery_y[idx[i, 0:num_rank]]]
|
||||
output_row = '{}'+',{}'*num_rank+'\n'
|
||||
f.write(output_row.format(probe_y[i], *r_format))
|
||||
print("GREW result saved to {}/{}".format(os.getcwd(), save_path))
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user