Fix `evaluator_cfg` bug in `BaseModel.run_test`: the distributed batch-size check read `model.engine_cfg` instead of the evaluator config

This commit is contained in:
darkliang
2024-03-20 18:10:12 +08:00
parent ec30085ea3
commit eb0b102f8b
+6 -6
View File
@@ -438,10 +438,10 @@ class BaseModel(MetaModel, nn.Module):
     @ staticmethod
     def run_test(model):
         """Accept the instance object(model) here, and then run the test loop."""
         evaluator_cfg = model.cfgs['evaluator_cfg']
-        if torch.distributed.get_world_size() != model.engine_cfg['sampler']['batch_size']:
-            raise ValueError("The batch size ({}) must be equal to the number of GPUs ({}) in testing mode!".format(
-                model.engine_cfg['sampler']['batch_size'], torch.distributed.get_world_size()))
+        if torch.distributed.get_world_size() != evaluator_cfg['sampler']['batch_size']:
+            raise ValueError("The batch size ({}) must be equal to the number of GPUs ({}) in testing mode!".format(
+                evaluator_cfg['sampler']['batch_size'], torch.distributed.get_world_size()))
         rank = torch.distributed.get_rank()
         with torch.no_grad():
             info_dict = model.inference(rank)
@@ -454,13 +454,13 @@ class BaseModel(MetaModel, nn.Module):
         info_dict.update({
             'labels': label_list, 'types': types_list, 'views': views_list})
-        if 'eval_func' in model.cfgs["evaluator_cfg"].keys():
-            eval_func = model.cfgs['evaluator_cfg']["eval_func"]
+        if 'eval_func' in evaluator_cfg.keys():
+            eval_func = evaluator_cfg["eval_func"]
         else:
             eval_func = 'identification'
         eval_func = getattr(eval_functions, eval_func)
-        valid_args = get_valid_args(
-            eval_func, model.cfgs["evaluator_cfg"], ['metric'])
+        valid_args = get_valid_args(
+            eval_func, evaluator_cfg, ['metric'])
         try:
             dataset_name = model.cfgs['data_cfg']['test_dataset_name']
         except: