1.0.0 official release (#18)
* fix bug in fix_BN * gaitgl OUMVLP support. * update ./doc/3.advance_usage.md Cross-Dataset Evaluation & Data Augmentation * update config * update docs.3 * update docs.3 * add loss doc and gather input decorator * refine the create model doc * support rearrange directory of unzipped OUMVLP * fix some bugs in loss_aggregator.py * refine docs and little fix * add oumvlp pretreatment description * pretreatment dataset fix oumvlp description * add gaitgl oumvlp result * assert gaitgl input size * add pipeline * update the readme. * update pipeline and readme * Corrigendum. * add logo and remove path * update new logo * Update README.md * modify logo size Co-authored-by: 12131100 <12131100@mail.sustech.edu.cn> Co-authored-by: noahshen98 <77523610+noahshen98@users.noreply.github.com> Co-authored-by: Noah <595311942@qq.com>
This commit is contained in:
+84
-30
@@ -1,3 +1,14 @@
|
||||
"""The base model definition.
|
||||
|
||||
This module defines the abstract meta model class and base model class. In the base model,
|
||||
we define the basic model functions, like get_loader, build_network, and run_train, etc.
|
||||
The API of the base model is run_train and run_test, which are used in `lib/main.py`.
|
||||
|
||||
Typical usage:
|
||||
|
||||
BaseModel.run_train(model)
|
||||
BaseModel.run_test(model)
|
||||
"""
|
||||
import torch
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
@@ -13,7 +24,6 @@ from abc import abstractmethod
|
||||
|
||||
from . import backbones
|
||||
from .loss_aggregator import LossAggregator
|
||||
from modeling.modules import fix_BN
|
||||
from data.transform import get_transform
|
||||
from data.collate_fn import CollateFn
|
||||
from data.dataset import DataSet
|
||||
@@ -28,80 +38,97 @@ __all__ = ['BaseModel']
|
||||
|
||||
|
||||
class MetaModel(metaclass=ABCMeta):
    """The necessary functions for the base model.

    This class declares the functions every model must provide. It only
    defines the interface: each method is abstract and raises
    NotImplementedError until a subclass overrides it.
    """

    @abstractmethod
    def get_loader(self, data_cfg):
        """Based on the given data_cfg, we get the data loader.

        Args:
            data_cfg: dict, the data config.
        Returns:
            The data loader.
        """
        raise NotImplementedError

    @abstractmethod
    def build_network(self, model_cfg):
        """Build your network here.

        Args:
            model_cfg: dict, the model config.
        Returns:
            The network, i.e. nn.Module(s).
        """
        raise NotImplementedError

    @abstractmethod
    def init_parameters(self):
        """Initialize the parameters of your network."""
        raise NotImplementedError

    @abstractmethod
    def get_optimizer(self, optimizer_cfg):
        """Based on the given optimizer_cfg, we get the optimizer.

        Args:
            optimizer_cfg: dict, the optimizer config.
        Returns:
            An optimizer object.
        """
        raise NotImplementedError

    @abstractmethod
    def get_scheduler(self, scheduler_cfg):
        """Based on the given scheduler_cfg, we get the scheduler.

        Args:
            scheduler_cfg: dict, the scheduler config.
        Returns:
            A scheduler object.
        """
        raise NotImplementedError

    @abstractmethod
    def save_ckpt(self, iteration):
        """Save the checkpoint, including model parameter, optimizer and scheduler."""
        raise NotImplementedError

    @abstractmethod
    def resume_ckpt(self, restore_hint):
        """Resume the model from the checkpoint, including model parameter, optimizer and scheduler."""
        raise NotImplementedError

    @abstractmethod
    def inputs_pretreament(self, inputs):
        """Transform the input data based on transform setting."""
        raise NotImplementedError

    @abstractmethod
    def train_step(self, loss_num) -> bool:
        """Do one training step.

        NOTE(review): BaseModel.train_step names this parameter `loss_sum`;
        the name here is kept unchanged for backward compatibility.
        """
        raise NotImplementedError

    @abstractmethod
    def inference(self):
        """Do inference (calculate features).

        NOTE(review): BaseModel.inference additionally takes a `rank`
        argument; the abstract signature is kept unchanged here.
        """
        raise NotImplementedError

    @abstractmethod
    def run_train(model):
        """Run a whole train schedule."""
        raise NotImplementedError

    @abstractmethod
    def run_test(model):
        """Run a whole test schedule."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class BaseModel(MetaModel, nn.Module):
|
||||
"""Base model.
|
||||
|
||||
This class inherits from the MetaModel class, and implements the basic model functions, like get_loader, build_network, etc.
|
||||
|
||||
Attributes:
|
||||
msg_mgr: the message manager.
|
||||
cfgs: the configs.
|
||||
iteration: the current iteration of the model.
|
||||
engine_cfg: the configs of the engine(train or test).
|
||||
save_path: the path to save the checkpoints.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, cfgs, training):
|
||||
"""Initialize the base model.
|
||||
|
||||
Complete the model initialization, including the data loader, the network, the optimizer, the scheduler, the loss.
|
||||
|
||||
Args:
|
||||
cfgs:
|
||||
All of the configs.
|
||||
training:
|
||||
Whether the model is in training mode.
|
||||
"""
|
||||
|
||||
super(BaseModel, self).__init__()
|
||||
self.msg_mgr = get_msg_mgr()
|
||||
self.cfgs = cfgs
|
||||
@@ -132,8 +159,6 @@ class BaseModel(MetaModel, nn.Module):
|
||||
"cuda", self.device))
|
||||
|
||||
if training:
|
||||
if cfgs['trainer_cfg']['fix_BN']:
|
||||
fix_BN(self)
|
||||
self.loss_aggregator = LossAggregator(cfgs['loss_cfg'])
|
||||
self.optimizer = self.get_optimizer(self.cfgs['optimizer_cfg'])
|
||||
self.scheduler = self.get_scheduler(cfgs['scheduler_cfg'])
|
||||
@@ -142,7 +167,12 @@ class BaseModel(MetaModel, nn.Module):
|
||||
if restore_hint != 0:
|
||||
self.resume_ckpt(restore_hint)
|
||||
|
||||
if training:
|
||||
if cfgs['trainer_cfg']['fix_BN']:
|
||||
self.fix_BN()
|
||||
|
||||
def get_backbone(self, model_cfg):
|
||||
"""Get the backbone of the model."""
|
||||
def _get_backbone(backbone_cfg):
|
||||
if is_dict(backbone_cfg):
|
||||
Backbone = get_attr_from([backbones], backbone_cfg['type'])
|
||||
@@ -266,7 +296,20 @@ class BaseModel(MetaModel, nn.Module):
|
||||
"Error type for -Restore_Hint-, supported: int or string.")
|
||||
self._load_ckpt(save_name)
|
||||
|
||||
def fix_BN(self):
    """Switch every BatchNorm layer of this model to evaluation mode.

    Iterates over ``self.modules()`` and calls ``eval()`` on each module
    whose class name contains ``'BatchNorm'``. All other modules are left
    untouched.
    """
    for module in self.modules():
        # Match on the class name so any class named *BatchNorm* is caught.
        if 'BatchNorm' in module.__class__.__name__:
            module.eval()
|
||||
|
||||
def inputs_pretreament(self, inputs):
|
||||
"""Conduct transforms on input data.
|
||||
|
||||
Args:
|
||||
inputs: the input data.
|
||||
Returns:
|
||||
tuple: training data including inputs, labels, and some meta data.
|
||||
"""
|
||||
seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
|
||||
trf_cfgs = self.engine_cfg['transform']
|
||||
seq_trfs = get_transform(trf_cfgs)
|
||||
@@ -293,9 +336,13 @@ class BaseModel(MetaModel, nn.Module):
|
||||
return ipts, labs, typs, vies, seqL
|
||||
|
||||
def train_step(self, loss_sum) -> bool:
|
||||
'''
|
||||
Conduct loss_sum.backward(), self.optimizer.step() and self.scheduler.step().
|
||||
'''
|
||||
"""Conduct loss_sum.backward(), self.optimizer.step() and self.scheduler.step().
|
||||
|
||||
Args:
|
||||
loss_sum: The loss of the current batch.
|
||||
Returns:
|
||||
bool: True if the training is finished, False otherwise.
|
||||
"""
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
if loss_sum <= 1e-9:
|
||||
@@ -322,6 +369,13 @@ class BaseModel(MetaModel, nn.Module):
|
||||
return True
|
||||
|
||||
def inference(self, rank):
|
||||
"""Inference all the test data.
|
||||
|
||||
Args:
|
||||
rank: the rank of the current process.
|
||||
Returns:
|
||||
Odict: contains the inference results.
|
||||
"""
|
||||
total_size = len(self.test_loader)
|
||||
if rank == 0:
|
||||
pbar = tqdm(total=total_size, desc='Transforming')
|
||||
@@ -355,9 +409,7 @@ class BaseModel(MetaModel, nn.Module):
|
||||
|
||||
@ staticmethod
|
||||
def run_train(model):
|
||||
'''
|
||||
Accept the instance object(model) here, and then run the train loop handler.
|
||||
'''
|
||||
"""Accept the instance object(model) here, and then run the train loop."""
|
||||
for inputs in model.train_loader:
|
||||
ipts = model.inputs_pretreament(inputs)
|
||||
with autocast(enabled=model.engine_cfg['enable_float16']):
|
||||
@@ -390,6 +442,8 @@ class BaseModel(MetaModel, nn.Module):
|
||||
|
||||
@ staticmethod
|
||||
def run_test(model):
|
||||
"""Accept the instance object(model) here, and then run the test loop."""
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
with torch.no_grad():
|
||||
info_dict = model.inference(rank)
|
||||
|
||||
Reference in New Issue
Block a user