diff --git a/lib/utils/evaluation.py b/lib/utils/evaluation.py index fba16af..1fdc7e3 100644 --- a/lib/utils/evaluation.py +++ b/lib/utils/evaluation.py @@ -57,7 +57,7 @@ def identification(data, dataset, metric='euc'): raise KeyError("DataSet %s hasn't been supported !" % dataset) num_rank = 5 acc = np.zeros([len(probe_seq_dict[dataset]), - view_num, view_num, num_rank]) - 1. + view_num, view_num, num_rank]) - 1. for (p, probe_seq) in enumerate(probe_seq_dict[dataset]): for gallery_seq in gallery_seq_dict[dataset]: for (v1, probe_view) in enumerate(view_list): @@ -93,9 +93,9 @@ def identification(data, dataset, metric='euc'): de_diag(acc[0, :, :, i]), de_diag(acc[1, :, :, i]), de_diag(acc[2, :, :, i]))) - result_dict["scalar/test_accuracy/NM"] = acc[0, :, :, i] - result_dict["scalar/test_accuracy/BG"] = acc[1, :, :, i] - result_dict["scalar/test_accuracy/CL"] = acc[2, :, :, i] + result_dict["scalar/test_accuracy/NM"] = de_diag(acc[0, :, :, i]) + result_dict["scalar/test_accuracy/BG"] = de_diag(acc[1, :, :, i]) + result_dict["scalar/test_accuracy/CL"] = de_diag(acc[2, :, :, i]) np.set_printoptions(precision=2, floatmode='fixed') for i in range(1): msg_mgr.log_info( @@ -107,9 +107,8 @@ def identification(data, dataset, metric='euc'): msg_mgr.log_info('===Rank-1 (Include identical-view cases)===') msg_mgr.log_info('NM: %.3f ' % (np.mean(acc[0, :, :, 0]))) msg_mgr.log_info('===Rank-1 (Exclude identical-view cases)===') - msg_mgr.log_info('NM: %.3f ' % (np.mean(de_diag(acc[0, :, :, 0])))) - result_dict["scalar/test_accuracy/NM"] = np.mean( - de_diag(acc[0, :, :, 0])) + msg_mgr.log_info('NM: %.3f ' % (de_diag(acc[0, :, :, 0]))) + result_dict["scalar/test_accuracy/NM"] = de_diag(acc[0, :, :, 0]) return result_dict