add gaitedge training code
This commit is contained in:
@@ -7,4 +7,5 @@ from .common import mkdir, clones
|
||||
from .common import MergeCfgsDict
|
||||
from .common import get_attr_from
|
||||
from .common import NoOp
|
||||
from .common import MeanIOU
|
||||
from .msg_manager import get_msg_mgr
|
||||
@@ -138,7 +138,7 @@ def clones(module, N):
|
||||
def config_loader(path):
|
||||
with open(path, 'r') as stream:
|
||||
src_cfgs = yaml.safe_load(stream)
|
||||
with open("./config/default.yaml", 'r') as stream:
|
||||
with open("./configs/default.yaml", 'r') as stream:
|
||||
dst_cfgs = yaml.safe_load(stream)
|
||||
MergeCfgsDict(src_cfgs, dst_cfgs)
|
||||
return dst_cfgs
|
||||
@@ -203,3 +203,15 @@ def get_ddp_module(module, **kwargs):
|
||||
def params_count(net):
|
||||
n_parameters = sum(p.numel() for p in net.parameters())
|
||||
return 'Parameters Count: {:.5f}M'.format(n_parameters / 1e6)
|
||||
|
||||
|
||||
def MeanIOU(msk1, msk2, eps=1.0e-9):
|
||||
if not is_tensor(msk1):
|
||||
msk1 = torch.from_numpy(msk1).cuda()
|
||||
if not is_tensor(msk2):
|
||||
msk2 = torch.from_numpy(msk2).cuda()
|
||||
n = msk1.size(0)
|
||||
inter = msk1 * msk2
|
||||
union = ((msk1 + msk2) > 0.).float()
|
||||
MeIOU = inter.view(n, -1).sum(-1) / (union.view(n, -1).sum(-1) + eps)
|
||||
return MeIOU
|
||||
|
||||
@@ -3,7 +3,7 @@ from time import strftime, localtime
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from utils import get_msg_mgr, mkdir
|
||||
from utils import get_msg_mgr, mkdir, MeanIOU
|
||||
|
||||
|
||||
def cuda_dist(x, y, metric='euc'):
|
||||
@@ -124,10 +124,10 @@ def identification_real_scene(data, dataset, metric='euc'):
|
||||
|
||||
gallery_seq_type = {'0001-1000': ['1', '2'],
|
||||
"HID2021": ['0'], '0001-1000-test': ['0'],
|
||||
'GREW': ['01']}
|
||||
'GREW': ['01'], 'TTG-200': ['1']}
|
||||
probe_seq_type = {'0001-1000': ['3', '4', '5', '6'],
|
||||
"HID2021": ['1'], '0001-1000-test': ['1'],
|
||||
'GREW': ['02']}
|
||||
'GREW': ['02'], 'TTG-200': ['2', '3', '4', '5', '6']}
|
||||
|
||||
num_rank = 20
|
||||
acc = np.zeros([num_rank]) - 1.
|
||||
@@ -274,3 +274,11 @@ def re_ranking(original_dist, query_num, k1, k2, lambda_value):
|
||||
del jaccard_dist
|
||||
final_dist = final_dist[:query_num, query_num:]
|
||||
return final_dist
|
||||
|
||||
|
||||
def mean_iou(data, dataset):
|
||||
labels = data['mask']
|
||||
pred = data['pred']
|
||||
miou = MeanIOU(pred, labels)
|
||||
get_msg_mgr().log_info('mIOU: %.3f' % (miou.mean()))
|
||||
return {"scalar/test_accuracy/mIOU": miou}
|
||||
|
||||
Reference in New Issue
Block a user