add gaitedge training code

This commit is contained in:
darkliang
2022-07-17 13:47:50 +08:00
parent 4205c5f283
commit b183455eb8
17 changed files with 814 additions and 11 deletions
+1
View File
@@ -7,4 +7,5 @@ from .common import mkdir, clones
from .common import MergeCfgsDict
from .common import get_attr_from
from .common import NoOp
from .common import MeanIOU
from .msg_manager import get_msg_mgr
+13 -1
View File
@@ -138,7 +138,7 @@ def clones(module, N):
def config_loader(path):
with open(path, 'r') as stream:
src_cfgs = yaml.safe_load(stream)
with open("./config/default.yaml", 'r') as stream:
with open("./configs/default.yaml", 'r') as stream:
dst_cfgs = yaml.safe_load(stream)
MergeCfgsDict(src_cfgs, dst_cfgs)
return dst_cfgs
@@ -203,3 +203,15 @@ def get_ddp_module(module, **kwargs):
def params_count(net):
n_parameters = sum(p.numel() for p in net.parameters())
return 'Parameters Count: {:.5f}M'.format(n_parameters / 1e6)
def MeanIOU(msk1, msk2, eps=1.0e-9):
if not is_tensor(msk1):
msk1 = torch.from_numpy(msk1).cuda()
if not is_tensor(msk2):
msk2 = torch.from_numpy(msk2).cuda()
n = msk1.size(0)
inter = msk1 * msk2
union = ((msk1 + msk2) > 0.).float()
MeIOU = inter.view(n, -1).sum(-1) / (union.view(n, -1).sum(-1) + eps)
return MeIOU
+11 -3
View File
@@ -3,7 +3,7 @@ from time import strftime, localtime
import torch
import numpy as np
import torch.nn.functional as F
from utils import get_msg_mgr, mkdir
from utils import get_msg_mgr, mkdir, MeanIOU
def cuda_dist(x, y, metric='euc'):
@@ -124,10 +124,10 @@ def identification_real_scene(data, dataset, metric='euc'):
gallery_seq_type = {'0001-1000': ['1', '2'],
"HID2021": ['0'], '0001-1000-test': ['0'],
'GREW': ['01']}
'GREW': ['01'], 'TTG-200': ['1']}
probe_seq_type = {'0001-1000': ['3', '4', '5', '6'],
"HID2021": ['1'], '0001-1000-test': ['1'],
'GREW': ['02']}
'GREW': ['02'], 'TTG-200': ['2', '3', '4', '5', '6']}
num_rank = 20
acc = np.zeros([num_rank]) - 1.
@@ -274,3 +274,11 @@ def re_ranking(original_dist, query_num, k1, k2, lambda_value):
del jaccard_dist
final_dist = final_dist[:query_num, query_num:]
return final_dist
def mean_iou(data, dataset):
labels = data['mask']
pred = data['pred']
miou = MeanIOU(pred, labels)
get_msg_mgr().log_info('mIOU: %.3f' % (miou.mean()))
return {"scalar/test_accuracy/mIOU": miou}