add gaitedge training code
This commit is contained in:
@@ -3,7 +3,7 @@ from time import strftime, localtime
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from utils import get_msg_mgr, mkdir
|
||||
from utils import get_msg_mgr, mkdir, MeanIOU
|
||||
|
||||
|
||||
def cuda_dist(x, y, metric='euc'):
|
||||
@@ -124,10 +124,10 @@ def identification_real_scene(data, dataset, metric='euc'):
|
||||
|
||||
gallery_seq_type = {'0001-1000': ['1', '2'],
|
||||
"HID2021": ['0'], '0001-1000-test': ['0'],
|
||||
'GREW': ['01']}
|
||||
'GREW': ['01'], 'TTG-200': ['1']}
|
||||
probe_seq_type = {'0001-1000': ['3', '4', '5', '6'],
|
||||
"HID2021": ['1'], '0001-1000-test': ['1'],
|
||||
'GREW': ['02']}
|
||||
'GREW': ['02'], 'TTG-200': ['2', '3', '4', '5', '6']}
|
||||
|
||||
num_rank = 20
|
||||
acc = np.zeros([num_rank]) - 1.
|
||||
@@ -274,3 +274,11 @@ def re_ranking(original_dist, query_num, k1, k2, lambda_value):
|
||||
del jaccard_dist
|
||||
final_dist = final_dist[:query_num, query_num:]
|
||||
return final_dist
|
||||
|
||||
|
||||
def mean_iou(data, dataset):
|
||||
labels = data['mask']
|
||||
pred = data['pred']
|
||||
miou = MeanIOU(pred, labels)
|
||||
get_msg_mgr().log_info('mIOU: %.3f' % (miou.mean()))
|
||||
return {"scalar/test_accuracy/mIOU": miou}
|
||||
|
||||
Reference in New Issue
Block a user