fix bug of evaluator_cfg
This commit is contained in:
@@ -438,10 +438,10 @@ class BaseModel(MetaModel, nn.Module):
|
||||
@ staticmethod
|
||||
def run_test(model):
|
||||
"""Accept the instance object(model) here, and then run the test loop."""
|
||||
|
||||
if torch.distributed.get_world_size() != model.engine_cfg['sampler']['batch_size']:
|
||||
evaluator_cfg = model.cfgs['evaluator_cfg']
|
||||
if torch.distributed.get_world_size() != evaluator_cfg['sampler']['batch_size']:
|
||||
raise ValueError("The batch size ({}) must be equal to the number of GPUs ({}) in testing mode!".format(
|
||||
model.engine_cfg['sampler']['batch_size'], torch.distributed.get_world_size()))
|
||||
evaluator_cfg['sampler']['batch_size'], torch.distributed.get_world_size()))
|
||||
rank = torch.distributed.get_rank()
|
||||
with torch.no_grad():
|
||||
info_dict = model.inference(rank)
|
||||
@@ -454,13 +454,13 @@ class BaseModel(MetaModel, nn.Module):
|
||||
info_dict.update({
|
||||
'labels': label_list, 'types': types_list, 'views': views_list})
|
||||
|
||||
if 'eval_func' in model.cfgs["evaluator_cfg"].keys():
|
||||
eval_func = model.cfgs['evaluator_cfg']["eval_func"]
|
||||
if 'eval_func' in evaluator_cfg.keys():
|
||||
eval_func = evaluator_cfg["eval_func"]
|
||||
else:
|
||||
eval_func = 'identification'
|
||||
eval_func = getattr(eval_functions, eval_func)
|
||||
valid_args = get_valid_args(
|
||||
eval_func, model.cfgs["evaluator_cfg"], ['metric'])
|
||||
eval_func, evaluator_cfg, ['metric'])
|
||||
try:
|
||||
dataset_name = model.cfgs['data_cfg']['test_dataset_name']
|
||||
except:
|
||||
|
||||
Reference in New Issue
Block a user