refactor get_backbone and a little change
This commit is contained in:
+13
-21
@@ -171,29 +171,22 @@ class BaseModel(MetaModel, nn.Module):
|
||||
if cfgs['trainer_cfg']['fix_BN']:
|
||||
self.fix_BN()
|
||||
|
||||
def get_backbone(self, model_cfg):
|
||||
def get_backbone(self, backbone_cfg):
|
||||
"""Get the backbone of the model."""
|
||||
def _get_backbone(backbone_cfg):
|
||||
if is_dict(backbone_cfg):
|
||||
Backbone = get_attr_from([backbones], backbone_cfg['type'])
|
||||
valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
|
||||
return Backbone(**valid_args)
|
||||
if is_list(backbone_cfg):
|
||||
Backbone = nn.ModuleList([_get_backbone(cfg)
|
||||
for cfg in backbone_cfg])
|
||||
return Backbone
|
||||
raise ValueError(
|
||||
"Error type for -Backbone-Cfg-, supported: (A list of) dict.")
|
||||
|
||||
if 'backbone_cfg' in model_cfg.keys():
|
||||
backbone_cfg = model_cfg['backbone_cfg']
|
||||
Backbone = _get_backbone(backbone_cfg)
|
||||
else:
|
||||
Backbone = None
|
||||
return Backbone
|
||||
if is_dict(backbone_cfg):
|
||||
Backbone = get_attr_from([backbones], backbone_cfg['type'])
|
||||
valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
|
||||
return Backbone(**valid_args)
|
||||
if is_list(backbone_cfg):
|
||||
Backbone = nn.ModuleList([self.get_backbone(cfg)
|
||||
for cfg in backbone_cfg])
|
||||
return Backbone
|
||||
raise ValueError(
|
||||
"Error type for -Backbone-Cfg-, supported: (A list of) dict.")
|
||||
|
||||
def build_network(self, model_cfg):
|
||||
self.Backbone = self.get_backbone(model_cfg)
|
||||
if 'backbone_cfg' in model_cfg.keys():
|
||||
self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
|
||||
|
||||
def init_parameters(self):
|
||||
for m in self.modules():
|
||||
@@ -280,7 +273,6 @@ class BaseModel(MetaModel, nn.Module):
|
||||
self.msg_mgr.log_warning(
|
||||
"Restore NO Scheduler from %s !!!" % save_name)
|
||||
self.msg_mgr.log_info("Restore Parameters from %s !!!" % save_name)
|
||||
del checkpoint
|
||||
|
||||
def resume_ckpt(self, restore_hint):
|
||||
if isinstance(restore_hint, int):
|
||||
|
||||
Reference in New Issue
Block a user