fix up bugs:

1. with_test does not work: during training, evaluation settings were looked up in engine_cfg (which is then the trainer config) instead of evaluator_cfg
2. training stays at iteration 0 because the AMP loss scale changed: a skipped optimizer step also skipped the iteration counter, so train_step now reports the skip to its caller
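
For context on fix 2: with torch.cuda.amp, GradScaler.step() silently skips the optimizer step when the scaled gradients contain inf/NaN, and the following GradScaler.update() then lowers the scale, so comparing the scale before and after update() detects the skipped step. Below is a minimal sketch of that pattern in plain PyTorch (net/optimizer/scheduler are placeholder names, not this repo's API; assumes a CUDA device):

    import torch

    net = torch.nn.Linear(8, 2).cuda()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)
    scaler = torch.cuda.amp.GradScaler()

    def train_step(loss) -> bool:
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)   # silently skipped if grads contain inf/NaN
        scale = scaler.get_scale()
        scaler.update()          # lowers the scale after a skipped step
        if scale != scaler.get_scale():
            # Optimizer step was skipped; report it so the caller can
            # drop this iteration instead of counting and logging it.
            return False
        scheduler.step()         # only advance the LR schedule on real updates
        return True

In the diff below, train_step() returns this flag and run_train() continues on False, so the iteration counter and the learning-rate schedule only advance on steps that actually updated the weights.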
darkliang
2021-10-19 15:59:36 +08:00
parent 84c4d484a6
commit c8e5bc1cbe
+17 -18
@@ -83,7 +83,7 @@ class MetaModel(metaclass=ABCMeta):
         raise NotImplementedError
 
     @abstractmethod
-    def train_step(self, loss_num):
+    def train_step(self, loss_num) -> bool:
         raise NotImplementedError
 
     @abstractmethod
@@ -292,37 +292,34 @@ class BaseModel(MetaModel, nn.Module):
         del seqs
         return ipts, labs, typs, vies, seqL
 
-    def train_step(self, loss_sum):
+    def train_step(self, loss_sum) -> bool:
         '''
         Conduct loss_sum.backward(), self.optimizer.step() and self.scheduler.step().
         '''
-        skip_lr_sched = False
         self.optimizer.zero_grad()
         if loss_sum <= 1e-9:
-            self.msg_mgr.log_warning("Find the loss sum less than 1e-9 but the training process will continue!)")
+            self.msg_mgr.log_warning(
+                "Find the loss sum less than 1e-9 but the training process will continue!")
 
         if self.engine_cfg['enable_float16']:
             self.Scaler.scale(loss_sum).backward()
             self.Scaler.step(self.optimizer)
             scale = self.Scaler.get_scale()
             self.Scaler.update()
-            skip_lr_sched = (scale != self.Scaler.get_scale())
             # Warning caused by optimizer skip when NaN
             # https://discuss.pytorch.org/t/optimizer-step-before-lr-scheduler-step-error-using-gradscaler/92930/5
-            #for debug
-            # for name, param in self.named_parameters():
-            #     if param.grad is None:
-            #         print(name)
+            if scale != self.Scaler.get_scale():
+                self.msg_mgr.log_debug("Training step skip. Expected the former scale equals to the present, got {} and {}".format(
+                    scale, self.Scaler.get_scale()))
+                return False
         else:
             loss_sum.backward()
             self.optimizer.step()
-        if not skip_lr_sched:
-            self.iteration += 1
-            self.scheduler.step()
+        self.iteration += 1
+        self.scheduler.step()
+        return True
 
     def inference(self, rank):
         total_size = len(self.test_loader)
@@ -368,7 +365,9 @@ class BaseModel(MetaModel, nn.Module):
                 training_feat, visual_summary = retval['training_feat'], retval['visual_summary']
                 del retval
             loss_sum, loss_info = model.loss_aggregator(training_feat)
-            model.train_step(loss_sum)
+            ok = model.train_step(loss_sum)
+            if not ok:
+                continue
 
             visual_summary.update(loss_info)
             visual_summary['scalar/learning_rate'] = model.optimizer.param_groups[0]['lr']
@@ -403,13 +402,13 @@ class BaseModel(MetaModel, nn.Module):
         info_dict.update({
             'labels': label_list, 'types': types_list, 'views': views_list})
 
-        if 'eval_func' in model.engine_cfg.keys():
-            eval_func = model.engine_cfg['eval_func']
+        if 'eval_func' in model.cfgs["evaluator_cfg"].keys():
+            eval_func = model.cfgs['evaluator_cfg']["eval_func"]
         else:
             eval_func = 'identification'
         eval_func = getattr(eval_functions, eval_func)
         valid_args = get_valid_args(
-            eval_func, model.engine_cfg, ['metric'])
+            eval_func, model.cfgs["evaluator_cfg"], ['metric'])
         try:
             dataset_name = model.cfgs['data_cfg']['test_dataset_name']
         except: