1.0.0 official release (#18)

* fix bug in fix_BN

* gaitgl OUMVLP support.

* update ./doc/3.advance_usage.md Cross-Dataset Evalution & Data Agumentation

* update config

* update docs.3

* update docs.3

* add loss doc and gather input decorator

* refine the create model doc

* support rearrange directory of unzipped OUMVLP

* fix some bugs in loss_aggregator.py

* refine docs and little fix

* add oumvlp pretreatment description

* pretreatment dataset fix oumvlp description

* add gaitgl oumvlp result

* assert gaitgl input size

* add pipeline

* update the readme.

* update pipeline and readme

* Corrigendum.

* add logo and remove path

* update new logo

* Update README.md

* modify logo size

Co-authored-by: 12131100 <12131100@mail.sustech.edu.cn>
Co-authored-by: noahshen98 <77523610+noahshen98@users.noreply.github.com>
Co-authored-by: Noah <595311942@qq.com>
This commit is contained in:
Junhao Liang
2021-12-08 20:05:28 +08:00
committed by GitHub
parent 6e71c7ac34
commit bb6cd5149a
39 changed files with 11401 additions and 230 deletions
+84 -30
View File
@@ -1,3 +1,14 @@
"""The base model definition.
This module defines the abstract meta model class and base model class. In the base model,
we define the basic model functions, like get_loader, build_network, and run_train, etc.
The api of the base model is run_train and run_test, they are used in `lib/main.py`.
Typical usage:
BaseModel.run_train(model)
BaseModel.run_test(model)
"""
import torch
import numpy as np
import os.path as osp
@@ -13,7 +24,6 @@ from abc import abstractmethod
from . import backbones
from .loss_aggregator import LossAggregator
from modeling.modules import fix_BN
from data.transform import get_transform
from data.collate_fn import CollateFn
from data.dataset import DataSet
@@ -28,80 +38,97 @@ __all__ = ['BaseModel']
class MetaModel(metaclass=ABCMeta):
"""The necessary functions for the base model.
This class defines the necessary functions for the base model, in the base model, we have implemented them.
"""
@abstractmethod
def get_loader(self, data_cfg):
'''
Build your data Loader here.
Inputs: data_cfg, dict
Return: Loader
'''
"""Based on the given data_cfg, we get the data loader."""
raise NotImplementedError
@abstractmethod
def build_network(self, model_cfg):
'''
Build your Model here.
Inputs: model_cfg, dict
Return: Network, nn.Module(s)
'''
"""Build your network here."""
raise NotImplementedError
@abstractmethod
def init_parameters(self):
"""Initialize the parameters of your network."""
raise NotImplementedError
@abstractmethod
def get_optimizer(self, optimizer_cfg):
'''
Build your Optimizer here.
Inputs: optimizer_cfg, dict
Return: Optimizer, a optimizer object
'''
"""Based on the given optimizer_cfg, we get the optimizer."""
raise NotImplementedError
@abstractmethod
def get_scheduler(self, scheduler_cfg):
'''
Build your Scheduler.
Inputs: scheduler_cfg, dict
Optimizer, your optimizer
Return: Scheduler, a scheduler object
'''
"""Based on the given scheduler_cfg, we get the scheduler."""
raise NotImplementedError
@abstractmethod
def save_ckpt(self, iteration):
"""Save the checkpoint, including model parameter, optimizer and scheduler."""
raise NotImplementedError
@abstractmethod
def resume_ckpt(self, restore_hint):
"""Resume the model from the checkpoint, including model parameter, optimizer and scheduler."""
raise NotImplementedError
@abstractmethod
def inputs_pretreament(self, inputs):
"""Transform the input data based on transform setting."""
raise NotImplementedError
@abstractmethod
def train_step(self, loss_num) -> bool:
"""Do one training step."""
raise NotImplementedError
@abstractmethod
def inference(self):
"""Do inference (calculate features.)."""
raise NotImplementedError
@abstractmethod
def run_train(model):
"""Run a whole train schedule."""
raise NotImplementedError
@abstractmethod
def run_test(model):
"""Run a whole test schedule."""
raise NotImplementedError
class BaseModel(MetaModel, nn.Module):
"""Base model.
This class inherites the MetaModel class, and implements the basic model functions, like get_loader, build_network, etc.
Attributes:
msg_mgr: the massage manager.
cfgs: the configs.
iteration: the current iteration of the model.
engine_cfg: the configs of the engine(train or test).
save_path: the path to save the checkpoints.
"""
def __init__(self, cfgs, training):
"""Initialize the base model.
Complete the model initialization, including the data loader, the network, the optimizer, the scheduler, the loss.
Args:
cfgs:
All of the configs.
training:
Whether the model is in training mode.
"""
super(BaseModel, self).__init__()
self.msg_mgr = get_msg_mgr()
self.cfgs = cfgs
@@ -132,8 +159,6 @@ class BaseModel(MetaModel, nn.Module):
"cuda", self.device))
if training:
if cfgs['trainer_cfg']['fix_BN']:
fix_BN(self)
self.loss_aggregator = LossAggregator(cfgs['loss_cfg'])
self.optimizer = self.get_optimizer(self.cfgs['optimizer_cfg'])
self.scheduler = self.get_scheduler(cfgs['scheduler_cfg'])
@@ -142,7 +167,12 @@ class BaseModel(MetaModel, nn.Module):
if restore_hint != 0:
self.resume_ckpt(restore_hint)
if training:
if cfgs['trainer_cfg']['fix_BN']:
self.fix_BN()
def get_backbone(self, model_cfg):
"""Get the backbone of the model."""
def _get_backbone(backbone_cfg):
if is_dict(backbone_cfg):
Backbone = get_attr_from([backbones], backbone_cfg['type'])
@@ -266,7 +296,20 @@ class BaseModel(MetaModel, nn.Module):
"Error type for -Restore_Hint-, supported: int or string.")
self._load_ckpt(save_name)
def fix_BN(self):
for module in self.modules():
classname = module.__class__.__name__
if classname.find('BatchNorm') != -1:
module.eval()
def inputs_pretreament(self, inputs):
"""Conduct transforms on input data.
Args:
inputs: the input data.
Returns:
tuple: training data including inputs, labels, and some meta data.
"""
seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
trf_cfgs = self.engine_cfg['transform']
seq_trfs = get_transform(trf_cfgs)
@@ -293,9 +336,13 @@ class BaseModel(MetaModel, nn.Module):
return ipts, labs, typs, vies, seqL
def train_step(self, loss_sum) -> bool:
'''
Conduct loss_sum.backward(), self.optimizer.step() and self.scheduler.step().
'''
"""Conduct loss_sum.backward(), self.optimizer.step() and self.scheduler.step().
Args:
loss_sum:The loss of the current batch.
Returns:
bool: True if the training is finished, False otherwise.
"""
self.optimizer.zero_grad()
if loss_sum <= 1e-9:
@@ -322,6 +369,13 @@ class BaseModel(MetaModel, nn.Module):
return True
def inference(self, rank):
"""Inference all the test data.
Args:
rank: the rank of the current process.Transform
Returns:
Odict: contains the inference results.
"""
total_size = len(self.test_loader)
if rank == 0:
pbar = tqdm(total=total_size, desc='Transforming')
@@ -355,9 +409,7 @@ class BaseModel(MetaModel, nn.Module):
@ staticmethod
def run_train(model):
'''
Accept the instance object(model) here, and then run the train loop handler.
'''
"""Accept the instance object(model) here, and then run the train loop."""
for inputs in model.train_loader:
ipts = model.inputs_pretreament(inputs)
with autocast(enabled=model.engine_cfg['enable_float16']):
@@ -390,6 +442,8 @@ class BaseModel(MetaModel, nn.Module):
@ staticmethod
def run_test(model):
"""Accept the instance object(model) here, and then run the test loop."""
rank = torch.distributed.get_rank()
with torch.no_grad():
info_dict = model.inference(rank)