fix bug of evaluator_cfg
This commit is contained in:
@@ -438,10 +438,10 @@ class BaseModel(MetaModel, nn.Module):
|
|||||||
@ staticmethod
|
@ staticmethod
|
||||||
def run_test(model):
|
def run_test(model):
|
||||||
"""Accept the instance object(model) here, and then run the test loop."""
|
"""Accept the instance object(model) here, and then run the test loop."""
|
||||||
|
evaluator_cfg = model.cfgs['evaluator_cfg']
|
||||||
if torch.distributed.get_world_size() != model.engine_cfg['sampler']['batch_size']:
|
if torch.distributed.get_world_size() != evaluator_cfg['sampler']['batch_size']:
|
||||||
raise ValueError("The batch size ({}) must be equal to the number of GPUs ({}) in testing mode!".format(
|
raise ValueError("The batch size ({}) must be equal to the number of GPUs ({}) in testing mode!".format(
|
||||||
model.engine_cfg['sampler']['batch_size'], torch.distributed.get_world_size()))
|
evaluator_cfg['sampler']['batch_size'], torch.distributed.get_world_size()))
|
||||||
rank = torch.distributed.get_rank()
|
rank = torch.distributed.get_rank()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
info_dict = model.inference(rank)
|
info_dict = model.inference(rank)
|
||||||
@@ -454,13 +454,13 @@ class BaseModel(MetaModel, nn.Module):
|
|||||||
info_dict.update({
|
info_dict.update({
|
||||||
'labels': label_list, 'types': types_list, 'views': views_list})
|
'labels': label_list, 'types': types_list, 'views': views_list})
|
||||||
|
|
||||||
if 'eval_func' in model.cfgs["evaluator_cfg"].keys():
|
if 'eval_func' in evaluator_cfg.keys():
|
||||||
eval_func = model.cfgs['evaluator_cfg']["eval_func"]
|
eval_func = evaluator_cfg["eval_func"]
|
||||||
else:
|
else:
|
||||||
eval_func = 'identification'
|
eval_func = 'identification'
|
||||||
eval_func = getattr(eval_functions, eval_func)
|
eval_func = getattr(eval_functions, eval_func)
|
||||||
valid_args = get_valid_args(
|
valid_args = get_valid_args(
|
||||||
eval_func, model.cfgs["evaluator_cfg"], ['metric'])
|
eval_func, evaluator_cfg, ['metric'])
|
||||||
try:
|
try:
|
||||||
dataset_name = model.cfgs['data_cfg']['test_dataset_name']
|
dataset_name = model.cfgs['data_cfg']['test_dataset_name']
|
||||||
except:
|
except:
|
||||||
|
|||||||
Reference in New Issue
Block a user