diff --git a/lib/modeling/base_model.py b/lib/modeling/base_model.py
index 27d171d..91510b9 100644
--- a/lib/modeling/base_model.py
+++ b/lib/modeling/base_model.py
@@ -83,7 +83,7 @@ class MetaModel(metaclass=ABCMeta):
         raise NotImplementedError

     @abstractmethod
-    def train_step(self, loss_num):
+    def train_step(self, loss_num) -> bool:
         raise NotImplementedError

     @abstractmethod
@@ -292,37 +292,34 @@ class BaseModel(MetaModel, nn.Module):
         del seqs
         return ipts, labs, typs, vies, seqL

-    def train_step(self, loss_sum):
+    def train_step(self, loss_sum) -> bool:
         '''
         Conduct loss_sum.backward(), self.optimizer.step() and self.scheduler.step().
         '''

-
-        skip_lr_sched = False
         self.optimizer.zero_grad()
         if loss_sum <= 1e-9:
-            self.msg_mgr.log_warning("Find the loss sum less than 1e-9 but the training process will continue!)")
-
+            self.msg_mgr.log_warning(
+                "Found the loss sum less than 1e-9, but the training process will continue!")
+
         if self.engine_cfg['enable_float16']:
             self.Scaler.scale(loss_sum).backward()
             self.Scaler.step(self.optimizer)
             scale = self.Scaler.get_scale()
             self.Scaler.update()
-            skip_lr_sched = (scale != self.Scaler.get_scale())
             # Warning caused by optimizer skip when NaN
             # https://discuss.pytorch.org/t/optimizer-step-before-lr-scheduler-step-error-using-gradscaler/92930/5
-
-            # for debug
-            # for name, param in self.named_parameters():
-            #     if param.grad is None:
-            #         print(name)
+            if scale != self.Scaler.get_scale():
+                self.msg_mgr.log_debug("Training step skipped. Expected the former scale to equal the present one, got {} and {}".format(
+                    scale, self.Scaler.get_scale()))
+                return False
         else:
             loss_sum.backward()
             self.optimizer.step()
-
-        if not skip_lr_sched:
-            self.iteration += 1
-            self.scheduler.step()
+
+        self.iteration += 1
+        self.scheduler.step()
+        return True

     def inference(self, rank):
         total_size = len(self.test_loader)
@@ -368,7 +365,9 @@ class BaseModel(MetaModel, nn.Module):
             training_feat, visual_summary = retval['training_feat'], retval['visual_summary']
             del retval
             loss_sum, loss_info = model.loss_aggregator(training_feat)
-            model.train_step(loss_sum)
+            ok = model.train_step(loss_sum)
+            if not ok:
+                continue

             visual_summary.update(loss_info)
             visual_summary['scalar/learning_rate'] = model.optimizer.param_groups[0]['lr']
@@ -403,13 +402,13 @@ class BaseModel(MetaModel, nn.Module):
             info_dict.update({
                 'labels': label_list, 'types': types_list, 'views': views_list})

-            if 'eval_func' in model.engine_cfg.keys():
-                eval_func = model.engine_cfg['eval_func']
+            if 'eval_func' in model.cfgs['evaluator_cfg'].keys():
+                eval_func = model.cfgs['evaluator_cfg']['eval_func']
             else:
                 eval_func = 'identification'
             eval_func = getattr(eval_functions, eval_func)
             valid_args = get_valid_args(
-                eval_func, model.engine_cfg, ['metric'])
+                eval_func, model.cfgs['evaluator_cfg'], ['metric'])
             try:
                 dataset_name = model.cfgs['data_cfg']['test_dataset_name']
             except:
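
Note on the scale-comparison trick in the hunk above: when float16 gradients contain inf/NaN, GradScaler.step() silently skips optimizer.step(), and the following GradScaler.update() then shrinks the loss scale, so a changed scale is the documented signal (see the linked PyTorch forum thread) that the step was skipped and scheduler.step() must not run. Below is a minimal, self-contained sketch of the same pattern; the model, optimizer, and run_step helper are illustrative stand-ins rather than part of this patch, and a CUDA device is assumed since GradScaler is a no-op on CPU:

    import torch
    import torch.nn as nn

    device = "cuda"  # GradScaler only scales on CUDA
    model = nn.Linear(8, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000)
    scaler = torch.cuda.amp.GradScaler()

    def run_step(x, y) -> bool:
        """Return True if the optimizer actually stepped, False if AMP skipped it."""
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = nn.functional.cross_entropy(model(x), y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)             # silently skipped if grads hold inf/NaN
        scale_before = scaler.get_scale()
        scaler.update()                    # shrinks the scale when the step was skipped
        if scale_before > scaler.get_scale():
            return False                   # step skipped: do not advance the scheduler
        scheduler.step()
        return True

    # usage sketch:
    # ok = run_step(torch.randn(4, 8, device=device),
    #               torch.randint(0, 2, (4,), device=device))

The sketch compares with > rather than the patch's !=, because update() also *grows* the scale after a long run of successful iterations (every growth_interval steps by default), and a growth event should not be mistaken for a skip. Either way, the caller treats a False return as "this iteration did not happen" and moves on to the next batch, exactly as the patched training loop does with its continue; this avoids the "lr_scheduler.step() before optimizer.step()" warning and keeps self.iteration and the learning-rate schedule aligned with the number of real optimizer steps.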