refactor get_backbone and a little change
This commit is contained in:
@@ -79,7 +79,7 @@ scheduler_cfg:
|
||||
scheduler: MultiStepLR
|
||||
trainer_cfg:
|
||||
enable_float16: true # half_percesion float for memory reduction and speedup
|
||||
fix_BN: true
|
||||
fix_BN: false
|
||||
log_iter: 100
|
||||
restore_ckpt_strict: true
|
||||
restore_hint: 0
|
||||
|
||||
@@ -4,6 +4,11 @@ import numpy as np
|
||||
from utils import is_list, is_dict, get_valid_args
|
||||
|
||||
|
||||
class NoOperation():
|
||||
def __call__(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class BaseSilTransform():
|
||||
def __init__(self, disvor=255.0, img_shape=None):
|
||||
self.disvor = disvor
|
||||
|
||||
@@ -171,29 +171,22 @@ class BaseModel(MetaModel, nn.Module):
|
||||
if cfgs['trainer_cfg']['fix_BN']:
|
||||
self.fix_BN()
|
||||
|
||||
def get_backbone(self, model_cfg):
|
||||
def get_backbone(self, backbone_cfg):
|
||||
"""Get the backbone of the model."""
|
||||
def _get_backbone(backbone_cfg):
|
||||
if is_dict(backbone_cfg):
|
||||
Backbone = get_attr_from([backbones], backbone_cfg['type'])
|
||||
valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
|
||||
return Backbone(**valid_args)
|
||||
if is_list(backbone_cfg):
|
||||
Backbone = nn.ModuleList([_get_backbone(cfg)
|
||||
Backbone = nn.ModuleList([self.get_backbone(cfg)
|
||||
for cfg in backbone_cfg])
|
||||
return Backbone
|
||||
raise ValueError(
|
||||
"Error type for -Backbone-Cfg-, supported: (A list of) dict.")
|
||||
|
||||
if 'backbone_cfg' in model_cfg.keys():
|
||||
backbone_cfg = model_cfg['backbone_cfg']
|
||||
Backbone = _get_backbone(backbone_cfg)
|
||||
else:
|
||||
Backbone = None
|
||||
return Backbone
|
||||
|
||||
def build_network(self, model_cfg):
|
||||
self.Backbone = self.get_backbone(model_cfg)
|
||||
if 'backbone_cfg' in model_cfg.keys():
|
||||
self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
|
||||
|
||||
def init_parameters(self):
|
||||
for m in self.modules():
|
||||
@@ -280,7 +273,6 @@ class BaseModel(MetaModel, nn.Module):
|
||||
self.msg_mgr.log_warning(
|
||||
"Restore NO Scheduler from %s !!!" % save_name)
|
||||
self.msg_mgr.log_info("Restore Parameters from %s !!!" % save_name)
|
||||
del checkpoint
|
||||
|
||||
def resume_ckpt(self, restore_hint):
|
||||
if isinstance(restore_hint, int):
|
||||
|
||||
@@ -5,11 +5,9 @@ from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWra
|
||||
|
||||
|
||||
class Baseline(BaseModel):
|
||||
def __init__(self, cfgs, is_training):
|
||||
super().__init__(cfgs, is_training)
|
||||
|
||||
def build_network(self, model_cfg):
|
||||
self.Backbone = self.get_backbone(model_cfg)
|
||||
self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
|
||||
self.Backbone = SetBlockWrapper(self.Backbone)
|
||||
self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
|
||||
self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
|
||||
|
||||
@@ -88,7 +88,7 @@ class GaitPart(BaseModel):
|
||||
|
||||
def build_network(self, model_cfg):
|
||||
|
||||
self.Backbone = self.get_backbone(model_cfg)
|
||||
self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
|
||||
head_cfg = model_cfg['SeparateFCs']
|
||||
self.Head = SeparateFCs(**model_cfg['SeparateFCs'])
|
||||
self.Backbone = SetBlockWrapper(self.Backbone)
|
||||
|
||||
Reference in New Issue
Block a user