import torch
import numpy as np
import os.path as osp
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as tordata
from tqdm import tqdm
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
from abc import ABCMeta
from abc import abstractmethod

from . import backbones
from .loss_aggregator import LossAggregator
from modeling.modules import fix_BN
from data.transform import get_transform
from data.collate_fn import CollateFn
from data.dataset import DataSet
import data.sampler as Samplers
from utils import Odict, mkdir, ddp_all_gather
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from utils import evaluation as eval_functions
from utils import NoOp
from utils import get_msg_mgr

__all__ = ['BaseModel']


class MetaModel(metaclass=ABCMeta):
    @abstractmethod
    def get_loader(self, data_cfg):
        '''
        Build your data loader here.

        Inputs:
            data_cfg: dict
        Return:
            Loader
        '''
        raise NotImplementedError

    @abstractmethod
    def build_network(self, model_cfg):
        '''
        Build your model here.

        Inputs:
            model_cfg: dict
        Return:
            Network, nn.Module(s)
        '''
        raise NotImplementedError

    @abstractmethod
    def init_parameters(self):
        raise NotImplementedError

    @abstractmethod
    def get_optimizer(self, optimizer_cfg):
        '''
        Build your optimizer here.

        Inputs:
            optimizer_cfg: dict
        Return:
            Optimizer, an optimizer object
        '''
        raise NotImplementedError

    @abstractmethod
    def get_scheduler(self, scheduler_cfg):
        '''
        Build your scheduler here.

        Inputs:
            scheduler_cfg: dict
            Optimizer: your optimizer (BaseModel reads it from self.optimizer)
        Return:
            Scheduler, a scheduler object
        '''
        raise NotImplementedError

    @abstractmethod
    def save_ckpt(self, iteration):
        raise NotImplementedError

    @abstractmethod
    def resume_ckpt(self, restore_hint):
        raise NotImplementedError

    @abstractmethod
    def inputs_pretreament(self, inputs):
        raise NotImplementedError

    @abstractmethod
    def train_step(self, loss_sum) -> bool:
        raise NotImplementedError

    @abstractmethod
    def inference(self):
        raise NotImplementedError

    @abstractmethod
    def run_train(model):
        raise NotImplementedError

    @abstractmethod
    def run_test(model):
        raise NotImplementedError


class BaseModel(MetaModel, nn.Module):
    def __init__(self, cfgs, training):
        super(BaseModel, self).__init__()
        self.msg_mgr = get_msg_mgr()
        self.cfgs = cfgs
        self.iteration = 0
        self.engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg']
        if self.engine_cfg is None:
            raise Exception("Initialize a model without -Engine-Cfgs-")

        if training and self.engine_cfg['enable_float16']:
            self.Scaler = GradScaler()
        self.save_path = osp.join('output/', cfgs['data_cfg']['dataset_name'],
                                  cfgs['model_cfg']['model'], self.engine_cfg['save_name'])

        self.build_network(cfgs['model_cfg'])
        self.init_parameters()

        self.msg_mgr.log_info(cfgs['data_cfg'])
        if training:
            self.train_loader = self.get_loader(cfgs['data_cfg'], train=True)
        if not training or self.engine_cfg['with_test']:
            self.test_loader = self.get_loader(cfgs['data_cfg'], train=False)

        self.device = torch.distributed.get_rank()
        torch.cuda.set_device(self.device)
        self.to(device=torch.device("cuda", self.device))

        if training:
            if cfgs['trainer_cfg']['fix_BN']:
                fix_BN(self)
            self.loss_aggregator = LossAggregator(cfgs['loss_cfg'])
            self.optimizer = self.get_optimizer(self.cfgs['optimizer_cfg'])
            self.scheduler = self.get_scheduler(cfgs['scheduler_cfg'])
        self.train(training)

        restore_hint = self.engine_cfg['restore_hint']
        if restore_hint != 0:
            self.resume_ckpt(restore_hint)
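    # A hedged sketch of the config consumed by get_backbone() below. The
    # 'type' key must name a class exported by the local `backbones` package;
    # every remaining key is forwarded as a constructor kwarg via
    # get_valid_args(). The concrete names here ('Plain', 'layers_cfg') are
    # illustrative assumptions, not guaranteed entries of this repository:
    #
    #   model_cfg['backbone_cfg'] = {'type': 'Plain', 'layers_cfg': [...]}
    #   # or a list of such dicts, which yields an nn.ModuleList:
    #   model_cfg['backbone_cfg'] = [{'type': 'Plain'}, {'type': 'Plain'}]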
    def get_backbone(self, model_cfg):
        def _get_backbone(backbone_cfg):
            if is_dict(backbone_cfg):
                Backbone = get_attr_from([backbones], backbone_cfg['type'])
                valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
                return Backbone(**valid_args)
            if is_list(backbone_cfg):
                Backbone = nn.ModuleList(
                    [_get_backbone(cfg) for cfg in backbone_cfg])
                return Backbone
            raise ValueError(
                "Error type for -Backbone-Cfg-, supported: (A list of) dict.")

        if 'backbone_cfg' in model_cfg.keys():
            Backbone = _get_backbone(model_cfg['backbone_cfg'])
        else:
            Backbone = None
        return Backbone

    def build_network(self, model_cfg):
        self.Backbone = self.get_backbone(model_cfg)

    def init_parameters(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.Conv1d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)):
                if m.affine:
                    nn.init.normal_(m.weight.data, 1.0, 0.02)
                    nn.init.constant_(m.bias.data, 0.0)

    def get_loader(self, data_cfg, train=True):
        sampler_cfg = self.cfgs['trainer_cfg']['sampler'] if train else self.cfgs['evaluator_cfg']['sampler']
        dataset = DataSet(data_cfg, train)

        Sampler = get_attr_from([Samplers], sampler_cfg['type'])
        valid_args = get_valid_args(
            Sampler, sampler_cfg, free_keys=['sample_type', 'type'])
        sampler = Sampler(dataset, **valid_args)

        loader = tordata.DataLoader(
            dataset=dataset,
            batch_sampler=sampler,
            collate_fn=CollateFn(dataset.label_set, sampler_cfg),
            num_workers=data_cfg['num_workers'])
        return loader

    def get_optimizer(self, optimizer_cfg):
        self.msg_mgr.log_info(optimizer_cfg)
        optimizer = get_attr_from([optim], optimizer_cfg['solver'])
        valid_arg = get_valid_args(optimizer, optimizer_cfg, ['solver'])
        optimizer = optimizer(
            filter(lambda p: p.requires_grad, self.parameters()), **valid_arg)
        return optimizer

    def get_scheduler(self, scheduler_cfg):
        self.msg_mgr.log_info(scheduler_cfg)
        Scheduler = get_attr_from(
            [optim.lr_scheduler], scheduler_cfg['scheduler'])
        valid_arg = get_valid_args(Scheduler, scheduler_cfg, ['scheduler'])
        scheduler = Scheduler(self.optimizer, **valid_arg)
        return scheduler

    def save_ckpt(self, iteration):
        if torch.distributed.get_rank() == 0:
            mkdir(osp.join(self.save_path, "checkpoints/"))
            save_name = self.engine_cfg['save_name']
            checkpoint = {
                'model': self.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict(),
                'iteration': iteration}
            torch.save(checkpoint,
                       osp.join(self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, iteration)))
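    # Checkpoint layout, derived from save_path in __init__ and the format
    # string above (the concrete dataset/model names are placeholders):
    #   output/<dataset_name>/<model>/<save_name>/checkpoints/<save_name>-<iteration, zero-padded to 5>.pt
    # resume_ckpt() below reconstructs exactly this path when restore_hint is
    # an int, so an int hint only works for checkpoints saved by this class.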
    def _load_ckpt(self, save_name):
        load_ckpt_strict = self.engine_cfg['restore_ckpt_strict']

        checkpoint = torch.load(
            save_name, map_location=torch.device("cuda", self.device))
        model_state_dict = checkpoint['model']

        if not load_ckpt_strict:
            self.msg_mgr.log_info("-------- Restored Params List --------")
            self.msg_mgr.log_info(sorted(set(model_state_dict.keys()).intersection(
                set(self.state_dict().keys()))))

        self.load_state_dict(model_state_dict, strict=load_ckpt_strict)
        if self.training:
            if not self.engine_cfg["optimizer_reset"] and 'optimizer' in checkpoint:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                self.msg_mgr.log_warning(
                    "Restore NO Optimizer from %s !!!" % save_name)
            if not self.engine_cfg["scheduler_reset"] and 'scheduler' in checkpoint:
                self.scheduler.load_state_dict(checkpoint['scheduler'])
            else:
                self.msg_mgr.log_warning(
                    "Restore NO Scheduler from %s !!!" % save_name)
        self.msg_mgr.log_info("Restore Parameters from %s !!!" % save_name)
        del checkpoint

    def resume_ckpt(self, restore_hint):
        if isinstance(restore_hint, int):
            save_name = self.engine_cfg['save_name']
            save_name = osp.join(
                self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, restore_hint))
            self.iteration = restore_hint
        elif isinstance(restore_hint, str):
            save_name = restore_hint
            self.iteration = 0
        else:
            raise ValueError(
                "Error type for -Restore_Hint-, supported: int or string.")
        self._load_ckpt(save_name)

    def inputs_pretreament(self, inputs):
        seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
        trf_cfgs = self.engine_cfg['transform']
        seq_trfs = get_transform(trf_cfgs)

        requires_grad = bool(self.training)
        seqs = [np2var(np.asarray([trf(fra) for fra in seq]), requires_grad=requires_grad).float()
                for trf, seq in zip(seq_trfs, seqs_batch)]

        typs = typs_batch
        vies = vies_batch

        labs = list2var(labs_batch).long()

        if seqL_batch is not None:
            seqL_batch = np2var(seqL_batch).int()
        seqL = seqL_batch

        if seqL is not None:
            seqL_sum = int(seqL.sum().data.cpu().numpy())
            ipts = [_[:, :seqL_sum] for _ in seqs]
        else:
            ipts = seqs
        del seqs
        return ipts, labs, typs, vies, seqL

    def train_step(self, loss_sum) -> bool:
        '''
        Conduct loss_sum.backward(), self.optimizer.step() and self.scheduler.step().
        '''
        self.optimizer.zero_grad()
        if loss_sum <= 1e-9:
            self.msg_mgr.log_warning(
                "Find the loss sum less than 1e-9 but the training process will continue!")

        if self.engine_cfg['enable_float16']:
            self.Scaler.scale(loss_sum).backward()
            self.Scaler.step(self.optimizer)
            scale = self.Scaler.get_scale()
            self.Scaler.update()
            # Warning caused by optimizer skip when NaN
            # https://discuss.pytorch.org/t/optimizer-step-before-lr-scheduler-step-error-using-gradscaler/92930/5
            if scale != self.Scaler.get_scale():
                self.msg_mgr.log_debug(
                    "Training step skip. Expected the former scale equals to the present, got {} and {}".format(
                        scale, self.Scaler.get_scale()))
                return False
        else:
            loss_sum.backward()
            self.optimizer.step()

        self.iteration += 1
        self.scheduler.step()
        return True

    def inference(self, rank):
        total_size = len(self.test_loader)
        if rank == 0:
            pbar = tqdm(total=total_size, desc='Transforming')
        else:
            pbar = NoOp()
        batch_size = self.test_loader.batch_sampler.batch_size
        rest_size = total_size
        info_dict = Odict()
        for inputs in self.test_loader:
            ipts = self.inputs_pretreament(inputs)
            with autocast(enabled=self.engine_cfg['enable_float16']):
                retval = self.forward(ipts)
                inference_feat = retval['inference_feat']
                for k, v in inference_feat.items():
                    inference_feat[k] = ddp_all_gather(v, requires_grad=False)
                del retval
            for k, v in inference_feat.items():
                inference_feat[k] = ts2np(v)
            info_dict.append(inference_feat)
            rest_size -= batch_size
            if rest_size >= 0:
                update_size = batch_size
            else:
                update_size = total_size % batch_size
            pbar.update(update_size)
        pbar.close()
        for k, v in info_dict.items():
            v = np.concatenate(v)[:total_size]
            info_dict[k] = v
        return info_dict
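    # Contract assumed by the static runners below, read off run_train() and
    # inference(): a subclass's forward(ipts) returns a dict shaped like
    #   {'training_feat': {...}, 'visual_summary': {...}, 'inference_feat': {...}}
    # 'training_feat' feeds the loss aggregator, 'visual_summary' feeds the
    # tensorboard writer, and 'inference_feat' is all-gathered across ranks at
    # test time. The inner structure of each sub-dict is model-specific.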
    @staticmethod
    def run_train(model):
        '''
        Accept the instance object (model) here, and then run the train loop handler.
        '''
        for inputs in model.train_loader:
            ipts = model.inputs_pretreament(inputs)
            with autocast(enabled=model.engine_cfg['enable_float16']):
                retval = model(ipts)
                training_feat, visual_summary = retval['training_feat'], retval['visual_summary']
                del retval
            loss_sum, loss_info = model.loss_aggregator(training_feat)
            ok = model.train_step(loss_sum)
            if not ok:
                continue

            visual_summary.update(loss_info)
            visual_summary['scalar/learning_rate'] = model.optimizer.param_groups[0]['lr']

            model.msg_mgr.train_step(loss_info, visual_summary)
            if model.iteration % model.engine_cfg['save_iter'] == 0:
                # save the checkpoint
                model.save_ckpt(model.iteration)

                # run test if with_test = true
                if model.engine_cfg['with_test']:
                    model.msg_mgr.log_info("Running test...")
                    model.eval()
                    result_dict = BaseModel.run_test(model)
                    model.train()
                    model.msg_mgr.write_to_tensorboard(result_dict)
                    model.msg_mgr.reset_time()
            if model.iteration >= model.engine_cfg['total_iter']:
                break

    @staticmethod
    def run_test(model):
        rank = torch.distributed.get_rank()
        with torch.no_grad():
            info_dict = model.inference(rank)
        if rank == 0:
            loader = model.test_loader
            label_list = loader.dataset.label_list
            types_list = loader.dataset.types_list
            views_list = loader.dataset.views_list

            info_dict.update({
                'labels': label_list, 'types': types_list, 'views': views_list})

            if 'eval_func' in model.cfgs["evaluator_cfg"].keys():
                eval_func = model.cfgs['evaluator_cfg']["eval_func"]
            else:
                eval_func = 'identification'
            eval_func = getattr(eval_functions, eval_func)
            valid_args = get_valid_args(
                eval_func, model.cfgs["evaluator_cfg"], ['metric'])

            try:
                dataset_name = model.cfgs['data_cfg']['test_dataset_name']
            except KeyError:
                dataset_name = model.cfgs['data_cfg']['dataset_name']
            return eval_func(info_dict, dataset_name, **valid_args)
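
# A minimal usage sketch, assuming a concrete BaseModel subclass and a
# launcher that has already initialized torch.distributed (the class name
# below is hypothetical; the real entry point lives elsewhere in the repo):
#
#   torch.distributed.init_process_group('nccl')
#   model = MyGaitModel(cfgs, training=True)   # a BaseModel subclass
#   BaseModel.run_train(model)                 # or BaseModel.run_test(model)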