update evaluation.py

This commit is contained in:
darkliang
2022-03-31 13:06:29 +08:00
parent b005bdd49b
commit 0d14d25596
+11 -15
View File
@@ -1,7 +1,9 @@
import os
from time import strftime, localtime
import torch import torch
import numpy as np import numpy as np
import torch.nn.functional as F import torch.nn.functional as F
from utils import get_msg_mgr from utils import get_msg_mgr, mkdir
def cuda_dist(x, y, metric='euc'): def cuda_dist(x, y, metric='euc'):
@@ -150,16 +152,14 @@ def identification_real_scene(data, dataset, metric='euc'):
msg_mgr.log_info('%.3f' % (np.mean(acc[19]))) msg_mgr.log_info('%.3f' % (np.mean(acc[19])))
return {"scalar/test_accuracy/Rank-1": np.mean(acc[0]), "scalar/test_accuracy/Rank-5": np.mean(acc[4])} return {"scalar/test_accuracy/Rank-1": np.mean(acc[0]), "scalar/test_accuracy/Rank-5": np.mean(acc[4])}
def identification_GREW_submission(data, dataset, metric='euc'): def identification_GREW_submission(data, dataset, metric='euc'):
msg_mgr = get_msg_mgr() get_msg_mgr().log_info("Evaluating GREW")
feature, label, seq_type, view = data['embeddings'], data['labels'], data['types'], data['views'] feature, label, seq_type, view = data['embeddings'], data['labels'], data['types'], data['views']
label = np.array(label) label = np.array(label)
view = np.array(view) view = np.array(view)
gallery_seq_type = {'GREW': ['01','02']} gallery_seq_type = {'GREW': ['01', '02']}
probe_seq_type = {'GREW': ['03']} probe_seq_type = {'GREW': ['03']}
num_rank = 20
acc = np.zeros([num_rank]) - 1.
gseq_mask = np.isin(seq_type, gallery_seq_type[dataset]) gseq_mask = np.isin(seq_type, gallery_seq_type[dataset])
gallery_x = feature[gseq_mask, :] gallery_x = feature[gseq_mask, :]
gallery_y = label[gseq_mask] gallery_y = label[gseq_mask]
@@ -170,11 +170,9 @@ def identification_GREW_submission(data, dataset, metric='euc'):
dist = cuda_dist(probe_x, gallery_x, metric) dist = cuda_dist(probe_x, gallery_x, metric)
idx = dist.cpu().sort(1)[1].numpy() idx = dist.cpu().sort(1)[1].numpy()
import os
from time import strftime, localtime
save_path = os.path.join( save_path = os.path.join(
"GREW_result/"+strftime('%Y-%m%d-%H%M%S', localtime())+".csv") "GREW_result/"+strftime('%Y-%m%d-%H%M%S', localtime())+".csv")
os.makedirs("GREW_result", exist_ok=True) mkdir("GREW_result")
with open(save_path, "w") as f: with open(save_path, "w") as f:
f.write("videoId,rank1,rank2,rank3,rank4,rank5,rank6,rank7,rank8,rank9,rank10,rank11,rank12,rank13,rank14,rank15,rank16,rank17,rank18,rank19,rank20\n") f.write("videoId,rank1,rank2,rank3,rank4,rank5,rank6,rank7,rank8,rank9,rank10,rank11,rank12,rank13,rank14,rank15,rank16,rank17,rank18,rank19,rank20\n")
for i in range(len(idx)): for i in range(len(idx)):
@@ -182,8 +180,8 @@ def identification_GREW_submission(data, dataset, metric='euc'):
output_row = '{}'+',{}'*20+'\n' output_row = '{}'+',{}'*20+'\n'
f.write(output_row.format(probe_y[i], *r_format)) f.write(output_row.format(probe_y[i], *r_format))
print("GREW result saved to {}/{}".format(os.getcwd(), save_path)) print("GREW result saved to {}/{}".format(os.getcwd(), save_path))
return
return
def evaluate_HID(data, dataset, metric='euc'): def evaluate_HID(data, dataset, metric='euc'):
msg_mgr = get_msg_mgr() msg_mgr = get_msg_mgr()
@@ -200,14 +198,13 @@ def evaluate_HID(data, dataset, metric='euc'):
feat = np.concatenate([probe_x, gallery_x]) feat = np.concatenate([probe_x, gallery_x])
dist = cuda_dist(feat, feat, metric).cpu().numpy() dist = cuda_dist(feat, feat, metric).cpu().numpy()
msg_mgr.log_info("Starting Re-ranking")
re_rank = re_ranking(dist, probe_x.shape[0], k1=6, k2=6, lambda_value=0.3) re_rank = re_ranking(dist, probe_x.shape[0], k1=6, k2=6, lambda_value=0.3)
idx = np.argsort(re_rank, axis=1) idx = np.argsort(re_rank, axis=1)
import os
from time import strftime, localtime
save_path = os.path.join( save_path = os.path.join(
"HID_result/"+strftime('%Y-%m%d-%H%M%S', localtime())+".csv") "HID_result/"+strftime('%Y-%m%d-%H%M%S', localtime())+".csv")
os.makedirs("HID_result", exist_ok=True) mkdir("HID_result")
with open(save_path, "w") as f: with open(save_path, "w") as f:
f.write("videoID,label\n") f.write("videoID,label\n")
for i in range(len(idx)): for i in range(len(idx)):
@@ -223,7 +220,6 @@ def re_ranking(original_dist, query_num, k1, k2, lambda_value):
V = np.zeros_like(original_dist).astype(np.float16) V = np.zeros_like(original_dist).astype(np.float16)
initial_rank = np.argsort(original_dist).astype(np.int32) initial_rank = np.argsort(original_dist).astype(np.int32)
print('starting re_ranking')
for i in range(all_num): for i in range(all_num):
# k-reciprocal neighbors # k-reciprocal neighbors
forward_k_neigh_index = initial_rank[i, :k1 + 1] forward_k_neigh_index = initial_rank[i, :k1 + 1]