add gaitedge training code

This commit is contained in:
darkliang
2022-07-17 13:47:50 +08:00
parent 4205c5f283
commit b183455eb8
17 changed files with 814 additions and 11 deletions
+13 -1
View File
@@ -138,7 +138,7 @@ def clones(module, N):
def config_loader(path):
with open(path, 'r') as stream:
src_cfgs = yaml.safe_load(stream)
with open("./config/default.yaml", 'r') as stream:
with open("./configs/default.yaml", 'r') as stream:
dst_cfgs = yaml.safe_load(stream)
MergeCfgsDict(src_cfgs, dst_cfgs)
return dst_cfgs
@@ -203,3 +203,15 @@ def get_ddp_module(module, **kwargs):
def params_count(net):
n_parameters = sum(p.numel() for p in net.parameters())
return 'Parameters Count: {:.5f}M'.format(n_parameters / 1e6)
def MeanIOU(msk1, msk2, eps=1.0e-9):
if not is_tensor(msk1):
msk1 = torch.from_numpy(msk1).cuda()
if not is_tensor(msk2):
msk2 = torch.from_numpy(msk2).cuda()
n = msk1.size(0)
inter = msk1 * msk2
union = ((msk1 + msk2) > 0.).float()
MeIOU = inter.view(n, -1).sum(-1) / (union.view(n, -1).sum(-1) + eps)
return MeIOU