add gaitedge training code
This commit is contained in:
@@ -138,7 +138,7 @@ def clones(module, N):
|
||||
def config_loader(path):
|
||||
with open(path, 'r') as stream:
|
||||
src_cfgs = yaml.safe_load(stream)
|
||||
with open("./config/default.yaml", 'r') as stream:
|
||||
with open("./configs/default.yaml", 'r') as stream:
|
||||
dst_cfgs = yaml.safe_load(stream)
|
||||
MergeCfgsDict(src_cfgs, dst_cfgs)
|
||||
return dst_cfgs
|
||||
@@ -203,3 +203,15 @@ def get_ddp_module(module, **kwargs):
|
||||
def params_count(net):
|
||||
n_parameters = sum(p.numel() for p in net.parameters())
|
||||
return 'Parameters Count: {:.5f}M'.format(n_parameters / 1e6)
|
||||
|
||||
|
||||
def MeanIOU(msk1, msk2, eps=1.0e-9):
|
||||
if not is_tensor(msk1):
|
||||
msk1 = torch.from_numpy(msk1).cuda()
|
||||
if not is_tensor(msk2):
|
||||
msk2 = torch.from_numpy(msk2).cuda()
|
||||
n = msk1.size(0)
|
||||
inter = msk1 * msk2
|
||||
union = ((msk1 + msk2) > 0.).float()
|
||||
MeIOU = inter.view(n, -1).sum(-1) / (union.view(n, -1).sum(-1) + eps)
|
||||
return MeIOU
|
||||
|
||||
Reference in New Issue
Block a user