add view angle result for OUMVLP

This commit is contained in:
darkliang
2022-03-11 22:24:21 +08:00
parent 7d7b81f48e
commit 22940ff3e0
2 changed files with 6 additions and 2 deletions
+5 -1
View File
@@ -40,7 +40,7 @@ def de_diag(acc, each_angle=False):
def identification(data, dataset, metric='euc'): def identification(data, dataset, metric='euc'):
msg_mgr = get_msg_mgr() msg_mgr = get_msg_mgr()
feature, label, seq_type, view = data['embeddings'], data['labels'], data['types'], data['views'] feature, label, seq_type, view = data['embeddings'], data['labels'], data['types'], data['views']
label = np.array(label) label = np.array(label)
view_list = list(set(view)) view_list = list(set(view))
@@ -78,6 +78,7 @@ def identification(data, dataset, metric='euc'):
np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]], 1) > 0, np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]], 1) > 0,
0) * 100 / dist.shape[0], 2) 0) * 100 / dist.shape[0], 2)
result_dict = {} result_dict = {}
np.set_printoptions(precision=3, suppress=True)
if 'OUMVLP' not in dataset: if 'OUMVLP' not in dataset:
for i in range(1): for i in range(1):
msg_mgr.log_info( msg_mgr.log_info(
@@ -108,6 +109,9 @@ def identification(data, dataset, metric='euc'):
msg_mgr.log_info('NM: %.3f ' % (np.mean(acc[0, :, :, 0]))) msg_mgr.log_info('NM: %.3f ' % (np.mean(acc[0, :, :, 0])))
msg_mgr.log_info('===Rank-1 (Exclude identical-view cases)===') msg_mgr.log_info('===Rank-1 (Exclude identical-view cases)===')
msg_mgr.log_info('NM: %.3f ' % (de_diag(acc[0, :, :, 0]))) msg_mgr.log_info('NM: %.3f ' % (de_diag(acc[0, :, :, 0])))
msg_mgr.log_info(
'===Rank-1 of each angle (Exclude identical-view cases)===')
msg_mgr.log_info('NM: {}'.format(de_diag(acc[0, :, :, 0], True)))
result_dict["scalar/test_accuracy/NM"] = de_diag(acc[0, :, :, 0]) result_dict["scalar/test_accuracy/NM"] = de_diag(acc[0, :, :, 0])
return result_dict return result_dict
+1 -1
View File
@@ -120,7 +120,7 @@ def download_file_and_uncompress(url,
if __name__ == "__main__": if __name__ == "__main__":
urls = [ urls = [
"https://github.com/ShiqiYu/OpenGait/releases/download/v1.0/pretrained_casiab_model.zip", "https://github.com/ShiqiYu/OpenGait/releases/download/v1.0/pretrained_casiab_model.zip",
"https://github.com/ShiqiYu/OpenGait/releases/download/v1.0/pretrained_oumvlp_model.zip"] "https://github.com/ShiqiYu/OpenGait/releases/download/v1.1/pretrained_oumvlp_model.zip"]
for url in urls: for url in urls:
download_file_and_uncompress( download_file_and_uncompress(
url=url, extrapath='output') url=url, extrapath='output')