support HID
This commit is contained in:
@@ -140,3 +140,30 @@ def identification_real_scene(data, dataset, metric='euc'):
|
||||
msg_mgr.log_info('==Rank-5==')
|
||||
msg_mgr.log_info('%.3f' % (np.mean(acc[4])))
|
||||
return {"scalar/test_accuracy/Rank-1": np.mean(acc[0]), "scalar/test_accuracy/Rank-5": np.mean(acc[4])}
|
||||
|
||||
|
||||
def evaluate_HID(data, dataset, metric='euc'):
|
||||
msg_mgr = get_msg_mgr()
|
||||
msg_mgr.log_info("Evaluating HID")
|
||||
feature, label, seq_type = data['embeddings'], data['labels'], data['types']
|
||||
label = np.array(label)
|
||||
seq_type = np.array(seq_type)
|
||||
probe_mask = (label == "probe")
|
||||
gallery_mask = (label != "probe")
|
||||
gallery_x = feature[gallery_mask, :]
|
||||
gallery_y = label[gallery_mask]
|
||||
probe_x = feature[probe_mask, :]
|
||||
probe_y = seq_type[probe_mask]
|
||||
dist = cuda_dist(probe_x, gallery_x, metric)
|
||||
idx = dist.cpu().sort(1)[1].numpy()
|
||||
import os
|
||||
from time import strftime, localtime
|
||||
save_path = os.path.join(
|
||||
"HID_result/"+strftime('%Y-%m%d-%H%M%S', localtime())+".csv")
|
||||
os.makedirs("HID_result", exist_ok=True)
|
||||
with open(save_path, "w") as f:
|
||||
f.write("videoID,label\n")
|
||||
for i in range(len(idx)):
|
||||
f.write("{},{}\n".format(probe_y[i], gallery_y[idx[i, 0]]))
|
||||
print("HID result saved to {}/{}".format(os.getcwd(), save_path))
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user