"""The base model definition.
|
|
|
|
This module defines the abstract meta model class and base model class. In the base model,
|
|
we define the basic model functions, like get_loader, build_network, and run_train, etc.
|
|
The api of the base model is run_train and run_test, they are used in `opengait/main.py`.
|
|
|
|
Typical usage:
|
|
|
|
BaseModel.run_train(model)
|
|
BaseModel.run_test(model)
|
|
"""
import torch
import numpy as np
import os.path as osp
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as tordata

from tqdm import tqdm
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
from abc import ABCMeta
from abc import abstractmethod

from . import backbones
from .loss_aggregator import LossAggregator
from data.transform import get_transform
from data.collate_fn import CollateFn
from data.dataset import DataSet
import data.sampler as Samplers
from utils import Odict, mkdir, ddp_all_gather
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from evaluation import evaluator as eval_functions
from utils import NoOp
from utils import get_msg_mgr

__all__ = ['BaseModel']


class MetaModel(metaclass=ABCMeta):
    """The necessary functions for the base model.

    This class declares the functions that every model must provide; the base
    model implements all of them.
    """
    @abstractmethod
    def get_loader(self, data_cfg):
        """Build the data loader based on the given data_cfg."""
        raise NotImplementedError

    @abstractmethod
    def build_network(self, model_cfg):
        """Build your network here."""
        raise NotImplementedError

    @abstractmethod
    def init_parameters(self):
        """Initialize the parameters of your network."""
        raise NotImplementedError

    @abstractmethod
    def get_optimizer(self, optimizer_cfg):
        """Build the optimizer based on the given optimizer_cfg."""
        raise NotImplementedError

    @abstractmethod
    def get_scheduler(self, scheduler_cfg):
        """Build the learning-rate scheduler based on the given scheduler_cfg."""
        raise NotImplementedError

    @abstractmethod
    def save_ckpt(self, iteration):
        """Save the checkpoint, including model parameters, optimizer, and scheduler."""
        raise NotImplementedError

    @abstractmethod
    def resume_ckpt(self, restore_hint):
        """Resume the model from a checkpoint, including model parameters, optimizer, and scheduler."""
        raise NotImplementedError

    @abstractmethod
    def inputs_pretreament(self, inputs):
        """Transform the input data according to the transform settings."""
        raise NotImplementedError

    @abstractmethod
    def train_step(self, loss_num) -> bool:
        """Do one training step."""
        raise NotImplementedError

    @abstractmethod
    def inference(self):
        """Do inference (calculate features)."""
        raise NotImplementedError

    @abstractmethod
    def run_train(model):
        """Run a whole training schedule."""
        raise NotImplementedError

    @abstractmethod
    def run_test(model):
        """Run a whole test schedule."""
        raise NotImplementedError


class BaseModel(MetaModel, nn.Module):
    """Base model.

    This class inherits from MetaModel and implements the basic model
    functions, like get_loader, build_network, etc.

    Attributes:
        msg_mgr: the message manager.
        cfgs: the configs.
        iteration: the current iteration of the model.
        engine_cfg: the configs of the engine (train or test).
        save_path: the path to save the checkpoints.
    """

    def __init__(self, cfgs, training):
        """Initialize the base model.

        Complete the model initialization, including the data loader, the
        network, the optimizer, the scheduler, and the loss.

        Args:
            cfgs:
                All of the configs.
            training:
                Whether the model is in training mode.
        """

        super(BaseModel, self).__init__()
        self.msg_mgr = get_msg_mgr()
        self.cfgs = cfgs
        self.iteration = 0
        self.engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg']
        if self.engine_cfg is None:
            raise Exception("Initialize a model without -Engine-Cfgs-")

        if training and self.engine_cfg['enable_float16']:
            self.Scaler = GradScaler()
        self.save_path = osp.join('output/', cfgs['data_cfg']['dataset_name'],
                                  cfgs['model_cfg']['model'], self.engine_cfg['save_name'])

        self.build_network(cfgs['model_cfg'])
        self.init_parameters()

        self.msg_mgr.log_info(cfgs['data_cfg'])
        if training:
            self.train_loader = self.get_loader(
                cfgs['data_cfg'], train=True)
        if not training or self.engine_cfg['with_test']:
            self.test_loader = self.get_loader(
                cfgs['data_cfg'], train=False)

        self.device = torch.distributed.get_rank()
        torch.cuda.set_device(self.device)
        self.to(device=torch.device(
            "cuda", self.device))

        if training:
            self.loss_aggregator = LossAggregator(cfgs['loss_cfg'])
            self.optimizer = self.get_optimizer(self.cfgs['optimizer_cfg'])
            self.scheduler = self.get_scheduler(cfgs['scheduler_cfg'])
        self.train(training)
        restore_hint = self.engine_cfg['restore_hint']
        if restore_hint != 0:
            self.resume_ckpt(restore_hint)
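
    # For reference, a minimal trainer_cfg sketch covering only the keys that
    # __init__ reads directly; the sampler entry is consumed later by
    # get_loader. All values below are illustrative, not defaults:
    #
    #     trainer_cfg:
    #         enable_float16: false
    #         with_test: false
    #         restore_hint: 0
    #         save_name: Baseline
    #         sampler:
    #             type: TripletSampler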

    def get_backbone(self, backbone_cfg):
        """Get the backbone of the model."""
        if is_dict(backbone_cfg):
            Backbone = get_attr_from([backbones], backbone_cfg['type'])
            valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
            return Backbone(**valid_args)
        if is_list(backbone_cfg):
            Backbone = nn.ModuleList([self.get_backbone(cfg)
                                      for cfg in backbone_cfg])
            return Backbone
        raise ValueError(
            "Error type for -Backbone-Cfg-, supported: (A list of) dict.")

    def build_network(self, model_cfg):
        if 'backbone_cfg' in model_cfg.keys():
            self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])

    def init_parameters(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.Conv1d)):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)):
                if m.affine:
                    nn.init.normal_(m.weight.data, 1.0, 0.02)
                    nn.init.constant_(m.bias.data, 0.0)

    def get_loader(self, data_cfg, train=True):
        sampler_cfg = self.cfgs['trainer_cfg']['sampler'] if train else self.cfgs['evaluator_cfg']['sampler']
        dataset = DataSet(data_cfg, train)

        Sampler = get_attr_from([Samplers], sampler_cfg['type'])
        valid_args = get_valid_args(Sampler, sampler_cfg, free_keys=[
            'sample_type', 'type'])
        sampler = Sampler(dataset, **valid_args)

        loader = tordata.DataLoader(
            dataset=dataset,
            batch_sampler=sampler,
            collate_fn=CollateFn(dataset.label_set, sampler_cfg),
            num_workers=data_cfg['num_workers'])
        return loader
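
    # A sampler_cfg sketch: 'type' names a batch sampler class in
    # data/sampler.py, while 'sample_type' is consumed by CollateFn rather
    # than the sampler itself, which is why both sit in free_keys above.
    # Values here are illustrative only:
    #
    #     sampler:
    #         type: TripletSampler
    #         sample_type: fixed_unordered
    #         batch_size: [8, 16]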

    def get_optimizer(self, optimizer_cfg):
        self.msg_mgr.log_info(optimizer_cfg)
        optimizer = get_attr_from([optim], optimizer_cfg['solver'])
        valid_arg = get_valid_args(optimizer, optimizer_cfg, ['solver'])
        optimizer = optimizer(
            filter(lambda p: p.requires_grad, self.parameters()), **valid_arg)
        return optimizer
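
    # An optimizer_cfg sketch: 'solver' names a class in torch.optim and the
    # remaining keys go to its constructor (values illustrative):
    #
    #     optimizer_cfg:
    #         solver: SGD
    #         lr: 0.1
    #         momentum: 0.9
    #         weight_decay: 0.0005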

    def get_scheduler(self, scheduler_cfg):
        self.msg_mgr.log_info(scheduler_cfg)
        Scheduler = get_attr_from(
            [optim.lr_scheduler], scheduler_cfg['scheduler'])
        valid_arg = get_valid_args(Scheduler, scheduler_cfg, ['scheduler'])
        scheduler = Scheduler(self.optimizer, **valid_arg)
        return scheduler
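
    # A scheduler_cfg sketch: 'scheduler' names a class in
    # torch.optim.lr_scheduler (values illustrative):
    #
    #     scheduler_cfg:
    #         scheduler: MultiStepLR
    #         milestones: [10000, 20000, 30000]
    #         gamma: 0.1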

    def save_ckpt(self, iteration):
        if torch.distributed.get_rank() == 0:
            mkdir(osp.join(self.save_path, "checkpoints/"))
            save_name = self.engine_cfg['save_name']
            checkpoint = {
                'model': self.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict(),
                'iteration': iteration}
            torch.save(checkpoint,
                       osp.join(self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, iteration)))
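
    # Given the save_path built in __init__, a checkpoint at iteration 10000
    # lands at (the bracketed parts are placeholders from the configs):
    #     output/<dataset_name>/<model>/<save_name>/checkpoints/<save_name>-10000.pt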

    def _load_ckpt(self, save_name):
        load_ckpt_strict = self.engine_cfg['restore_ckpt_strict']

        checkpoint = torch.load(save_name, map_location=torch.device(
            "cuda", self.device))
        model_state_dict = checkpoint['model']

        if not load_ckpt_strict:
            self.msg_mgr.log_info("-------- Restored Params List --------")
            self.msg_mgr.log_info(sorted(set(model_state_dict.keys()).intersection(
                set(self.state_dict().keys()))))

        self.load_state_dict(model_state_dict, strict=load_ckpt_strict)
        if self.training:
            if not self.engine_cfg["optimizer_reset"] and 'optimizer' in checkpoint:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                self.msg_mgr.log_warning(
                    "Restore NO Optimizer from %s !!!" % save_name)
            if not self.engine_cfg["scheduler_reset"] and 'scheduler' in checkpoint:
                self.scheduler.load_state_dict(
                    checkpoint['scheduler'])
            else:
                self.msg_mgr.log_warning(
                    "Restore NO Scheduler from %s !!!" % save_name)
        self.msg_mgr.log_info("Restore Parameters from %s !!!" % save_name)

    def resume_ckpt(self, restore_hint):
        if isinstance(restore_hint, int):
            save_name = self.engine_cfg['save_name']
            save_name = osp.join(
                self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, restore_hint))
            self.iteration = restore_hint
        elif isinstance(restore_hint, str):
            save_name = restore_hint
            self.iteration = 0
        else:
            raise ValueError(
                "Error type for -Restore_Hint-, supported: int or string.")
        self._load_ckpt(save_name)
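
    # Two ways to restore (the path below is illustrative): an int resumes
    # that iteration from this run's own save_path, while a str loads an
    # arbitrary checkpoint file and restarts the iteration counter at 0.
    #
    #     model.resume_ckpt(10000)
    #     model.resume_ckpt('output/SomeDataset/SomeModel/run1/checkpoints/run1-10000.pt')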

    def fix_BN(self):
        """Freeze all BatchNorm layers by switching them to eval mode."""
        for module in self.modules():
            classname = module.__class__.__name__
            if classname.find('BatchNorm') != -1:
                module.eval()

    def inputs_pretreament(self, inputs):
        """Conduct transforms on the input data.

        Args:
            inputs: the input data.
        Returns:
            tuple: training data including inputs, labels, and some meta data.
        """
        seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
        trf_cfgs = self.engine_cfg['transform']
        seq_trfs = get_transform(trf_cfgs)
        if len(seqs_batch) != len(seq_trfs):
            raise ValueError(
                "The number of types of input data and transforms should be the same. But got {} and {}".format(len(seqs_batch), len(seq_trfs)))
        requires_grad = bool(self.training)
        seqs = [np2var(np.asarray([trf(fra) for fra in seq]), requires_grad=requires_grad).float()
                for trf, seq in zip(seq_trfs, seqs_batch)]

        typs = typs_batch
        vies = vies_batch

        labs = list2var(labs_batch).long()

        if seqL_batch is not None:
            seqL_batch = np2var(seqL_batch).int()
        seqL = seqL_batch

        if seqL is not None:
            seqL_sum = int(seqL.sum().data.cpu().numpy())
            ipts = [_[:, :seqL_sum] for _ in seqs]
        else:
            ipts = seqs
        del seqs
        return ipts, labs, typs, vies, seqL
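
    # For intuition, with silhouette input the returned tuple roughly looks
    # like this (shapes are a sketch; they depend on the dataset and sampler):
    #   ipts: [Tensor of shape [N, T, H, W]]  - transformed sequences
    #   labs: LongTensor of shape [N]         - subject labels
    #   typs / vies: N meta-data strings      - walking condition / view
    #   seqL: IntTensor of per-sequence lengths, or None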

    def train_step(self, loss_sum) -> bool:
        """Conduct loss_sum.backward(), self.optimizer.step() and self.scheduler.step().

        Args:
            loss_sum: The loss of the current batch.
        Returns:
            bool: True if the step was applied, False if it was skipped.
        """

        self.optimizer.zero_grad()
        if loss_sum <= 1e-9:
            self.msg_mgr.log_warning(
                "Found the loss sum less than 1e-9, but the training process will continue!")

        if self.engine_cfg['enable_float16']:
            self.Scaler.scale(loss_sum).backward()
            self.Scaler.step(self.optimizer)
            scale = self.Scaler.get_scale()
            self.Scaler.update()
            # If the scale changed, the scaler skipped optimizer.step() (e.g.,
            # on NaN/Inf gradients), so the scheduler must not step either.
            # https://discuss.pytorch.org/t/optimizer-step-before-lr-scheduler-step-error-using-gradscaler/92930/5
            if scale != self.Scaler.get_scale():
                self.msg_mgr.log_debug("Training step skip. Expected the former scale equals to the present, got {} and {}".format(
                    scale, self.Scaler.get_scale()))
                return False
        else:
            loss_sum.backward()
            self.optimizer.step()

        self.iteration += 1
        self.scheduler.step()
        return True

    def inference(self, rank):
        """Inference all the test data.

        Args:
            rank: the rank of the current process.
        Returns:
            Odict: contains the inference results.
        """
        total_size = len(self.test_loader)
        if rank == 0:
            pbar = tqdm(total=total_size, desc='Transforming')
        else:
            pbar = NoOp()
        batch_size = self.test_loader.batch_sampler.batch_size
        rest_size = total_size
        info_dict = Odict()
        for inputs in self.test_loader:
            ipts = self.inputs_pretreament(inputs)
            with autocast(enabled=self.engine_cfg['enable_float16']):
                retval = self.forward(ipts)
                inference_feat = retval['inference_feat']
                for k, v in inference_feat.items():
                    inference_feat[k] = ddp_all_gather(v, requires_grad=False)
                del retval
            for k, v in inference_feat.items():
                inference_feat[k] = ts2np(v)
            info_dict.append(inference_feat)
            rest_size -= batch_size
            if rest_size >= 0:
                update_size = batch_size
            else:
                update_size = total_size % batch_size
            pbar.update(update_size)
        pbar.close()
        for k, v in info_dict.items():
            v = np.concatenate(v)[:total_size]
            info_dict[k] = v
        return info_dict

    @staticmethod
    def run_train(model):
        """Accept the model instance here, and then run the training loop."""
        for inputs in model.train_loader:
            ipts = model.inputs_pretreament(inputs)
            with autocast(enabled=model.engine_cfg['enable_float16']):
                retval = model(ipts)
                training_feat, visual_summary = retval['training_feat'], retval['visual_summary']
                del retval
            loss_sum, loss_info = model.loss_aggregator(training_feat)
            ok = model.train_step(loss_sum)
            if not ok:
                continue

            visual_summary.update(loss_info)
            visual_summary['scalar/learning_rate'] = model.optimizer.param_groups[0]['lr']

            model.msg_mgr.train_step(loss_info, visual_summary)
            if model.iteration % model.engine_cfg['save_iter'] == 0:
                # save the checkpoint
                model.save_ckpt(model.iteration)

                # run test if with_test = true
                if model.engine_cfg['with_test']:
                    model.msg_mgr.log_info("Running test...")
                    model.eval()
                    result_dict = BaseModel.run_test(model)
                    model.train()
                    if model.cfgs['trainer_cfg']['fix_BN']:
                        model.fix_BN()
                    model.msg_mgr.write_to_tensorboard(result_dict)
                    model.msg_mgr.reset_time()
            if model.iteration >= model.engine_cfg['total_iter']:
                break

    @staticmethod
    def run_test(model):
        """Accept the model instance here, and then run the test loop."""

        rank = torch.distributed.get_rank()
        with torch.no_grad():
            info_dict = model.inference(rank)
        if rank == 0:
            loader = model.test_loader
            label_list = loader.dataset.label_list
            types_list = loader.dataset.types_list
            views_list = loader.dataset.views_list

            info_dict.update({
                'labels': label_list, 'types': types_list, 'views': views_list})

            if 'eval_func' in model.cfgs["evaluator_cfg"].keys():
                eval_func = model.cfgs['evaluator_cfg']["eval_func"]
            else:
                eval_func = 'identification'
            eval_func = getattr(eval_functions, eval_func)
            valid_args = get_valid_args(
                eval_func, model.cfgs["evaluator_cfg"], ['metric'])
            try:
                dataset_name = model.cfgs['data_cfg']['test_dataset_name']
            except KeyError:
                dataset_name = model.cfgs['data_cfg']['dataset_name']
            return eval_func(info_dict, dataset_name, **valid_args)