add gaitedge training code

This commit is contained in:
darkliang
2022-07-17 13:47:50 +08:00
parent 4205c5f283
commit b183455eb8
17 changed files with 814 additions and 11 deletions
+11 -3
View File
@@ -3,7 +3,7 @@ from time import strftime, localtime
import torch
import numpy as np
import torch.nn.functional as F
from utils import get_msg_mgr, mkdir
from utils import get_msg_mgr, mkdir, MeanIOU
def cuda_dist(x, y, metric='euc'):
@@ -124,10 +124,10 @@ def identification_real_scene(data, dataset, metric='euc'):
gallery_seq_type = {'0001-1000': ['1', '2'],
"HID2021": ['0'], '0001-1000-test': ['0'],
'GREW': ['01']}
'GREW': ['01'], 'TTG-200': ['1']}
probe_seq_type = {'0001-1000': ['3', '4', '5', '6'],
"HID2021": ['1'], '0001-1000-test': ['1'],
'GREW': ['02']}
'GREW': ['02'], 'TTG-200': ['2', '3', '4', '5', '6']}
num_rank = 20
acc = np.zeros([num_rank]) - 1.
@@ -274,3 +274,11 @@ def re_ranking(original_dist, query_num, k1, k2, lambda_value):
del jaccard_dist
final_dist = final_dist[:query_num, query_num:]
return final_dist
def mean_iou(data, dataset):
labels = data['mask']
pred = data['pred']
miou = MeanIOU(pred, labels)
get_msg_mgr().log_info('mIOU: %.3f' % (miou.mean()))
return {"scalar/test_accuracy/mIOU": miou}