refactor get_backbone and a little change

This commit is contained in:
darkliang
2021-12-22 22:16:56 +08:00
parent b74d6f7959
commit 51bb8b72b4
5 changed files with 21 additions and 26 deletions
+13 -21
View File
@@ -171,29 +171,22 @@ class BaseModel(MetaModel, nn.Module):
if cfgs['trainer_cfg']['fix_BN']:
self.fix_BN()
def get_backbone(self, model_cfg):
def get_backbone(self, backbone_cfg):
"""Get the backbone of the model."""
def _get_backbone(backbone_cfg):
if is_dict(backbone_cfg):
Backbone = get_attr_from([backbones], backbone_cfg['type'])
valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
return Backbone(**valid_args)
if is_list(backbone_cfg):
Backbone = nn.ModuleList([_get_backbone(cfg)
for cfg in backbone_cfg])
return Backbone
raise ValueError(
"Error type for -Backbone-Cfg-, supported: (A list of) dict.")
if 'backbone_cfg' in model_cfg.keys():
backbone_cfg = model_cfg['backbone_cfg']
Backbone = _get_backbone(backbone_cfg)
else:
Backbone = None
return Backbone
if is_dict(backbone_cfg):
Backbone = get_attr_from([backbones], backbone_cfg['type'])
valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
return Backbone(**valid_args)
if is_list(backbone_cfg):
Backbone = nn.ModuleList([self.get_backbone(cfg)
for cfg in backbone_cfg])
return Backbone
raise ValueError(
"Error type for -Backbone-Cfg-, supported: (A list of) dict.")
def build_network(self, model_cfg):
self.Backbone = self.get_backbone(model_cfg)
if 'backbone_cfg' in model_cfg.keys():
self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
def init_parameters(self):
for m in self.modules():
@@ -280,7 +273,6 @@ class BaseModel(MetaModel, nn.Module):
self.msg_mgr.log_warning(
"Restore NO Scheduler from %s !!!" % save_name)
self.msg_mgr.log_info("Restore Parameters from %s !!!" % save_name)
del checkpoint
def resume_ckpt(self, restore_hint):
if isinstance(restore_hint, int):