From 51bb8b72b4ccd4d4ed408338780b0e16b5deacfc Mon Sep 17 00:00:00 2001 From: darkliang <11710911@mail.sustech.edu.cn> Date: Wed, 22 Dec 2021 22:16:56 +0800 Subject: [PATCH] refactor get_backbone and a little change --- config/baseline.yaml | 2 +- lib/data/transform.py | 5 +++++ lib/modeling/base_model.py | 34 +++++++++++++-------------------- lib/modeling/models/baseline.py | 4 +--- lib/modeling/models/gaitpart.py | 2 +- 5 files changed, 21 insertions(+), 26 deletions(-) diff --git a/config/baseline.yaml b/config/baseline.yaml index 7254d16..ec2c9c8 100644 --- a/config/baseline.yaml +++ b/config/baseline.yaml @@ -79,7 +79,7 @@ scheduler_cfg: scheduler: MultiStepLR trainer_cfg: enable_float16: true # half_percesion float for memory reduction and speedup - fix_BN: true + fix_BN: false log_iter: 100 restore_ckpt_strict: true restore_hint: 0 diff --git a/lib/data/transform.py b/lib/data/transform.py index b65df68..3655123 100644 --- a/lib/data/transform.py +++ b/lib/data/transform.py @@ -4,6 +4,11 @@ import numpy as np from utils import is_list, is_dict, get_valid_args +class NoOperation(): + def __call__(self, x): + return x + + class BaseSilTransform(): def __init__(self, disvor=255.0, img_shape=None): self.disvor = disvor diff --git a/lib/modeling/base_model.py b/lib/modeling/base_model.py index 7b7a0df..d5082a9 100644 --- a/lib/modeling/base_model.py +++ b/lib/modeling/base_model.py @@ -171,29 +171,22 @@ class BaseModel(MetaModel, nn.Module): if cfgs['trainer_cfg']['fix_BN']: self.fix_BN() - def get_backbone(self, model_cfg): + def get_backbone(self, backbone_cfg): """Get the backbone of the model.""" - def _get_backbone(backbone_cfg): - if is_dict(backbone_cfg): - Backbone = get_attr_from([backbones], backbone_cfg['type']) - valid_args = get_valid_args(Backbone, backbone_cfg, ['type']) - return Backbone(**valid_args) - if is_list(backbone_cfg): - Backbone = nn.ModuleList([_get_backbone(cfg) - for cfg in backbone_cfg]) - return Backbone - raise ValueError( - "Error type for -Backbone-Cfg-, supported: (A list of) dict.") - - if 'backbone_cfg' in model_cfg.keys(): - backbone_cfg = model_cfg['backbone_cfg'] - Backbone = _get_backbone(backbone_cfg) - else: - Backbone = None - return Backbone + if is_dict(backbone_cfg): + Backbone = get_attr_from([backbones], backbone_cfg['type']) + valid_args = get_valid_args(Backbone, backbone_cfg, ['type']) + return Backbone(**valid_args) + if is_list(backbone_cfg): + Backbone = nn.ModuleList([self.get_backbone(cfg) + for cfg in backbone_cfg]) + return Backbone + raise ValueError( + "Error type for -Backbone-Cfg-, supported: (A list of) dict.") def build_network(self, model_cfg): - self.Backbone = self.get_backbone(model_cfg) + if 'backbone_cfg' in model_cfg.keys(): + self.Backbone = self.get_backbone(model_cfg['backbone_cfg']) def init_parameters(self): for m in self.modules(): @@ -280,7 +273,6 @@ class BaseModel(MetaModel, nn.Module): self.msg_mgr.log_warning( "Restore NO Scheduler from %s !!!" % save_name) self.msg_mgr.log_info("Restore Parameters from %s !!!" % save_name) - del checkpoint def resume_ckpt(self, restore_hint): if isinstance(restore_hint, int): diff --git a/lib/modeling/models/baseline.py b/lib/modeling/models/baseline.py index 2e1f507..febcfeb 100644 --- a/lib/modeling/models/baseline.py +++ b/lib/modeling/models/baseline.py @@ -5,11 +5,9 @@ from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWra class Baseline(BaseModel): - def __init__(self, cfgs, is_training): - super().__init__(cfgs, is_training) def build_network(self, model_cfg): - self.Backbone = self.get_backbone(model_cfg) + self.Backbone = self.get_backbone(model_cfg['backbone_cfg']) self.Backbone = SetBlockWrapper(self.Backbone) self.FCs = SeparateFCs(**model_cfg['SeparateFCs']) self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks']) diff --git a/lib/modeling/models/gaitpart.py b/lib/modeling/models/gaitpart.py index 984afbe..242dcaa 100644 --- a/lib/modeling/models/gaitpart.py +++ b/lib/modeling/models/gaitpart.py @@ -88,7 +88,7 @@ class GaitPart(BaseModel): def build_network(self, model_cfg): - self.Backbone = self.get_backbone(model_cfg) + self.Backbone = self.get_backbone(model_cfg['backbone_cfg']) head_cfg = model_cfg['SeparateFCs'] self.Head = SeparateFCs(**model_cfg['SeparateFCs']) self.Backbone = SetBlockWrapper(self.Backbone)