refactor get_backbone and a little change
This commit is contained in:
@@ -79,7 +79,7 @@ scheduler_cfg:
|
|||||||
scheduler: MultiStepLR
|
scheduler: MultiStepLR
|
||||||
trainer_cfg:
|
trainer_cfg:
|
||||||
enable_float16: true # half_percesion float for memory reduction and speedup
|
enable_float16: true # half_percesion float for memory reduction and speedup
|
||||||
fix_BN: true
|
fix_BN: false
|
||||||
log_iter: 100
|
log_iter: 100
|
||||||
restore_ckpt_strict: true
|
restore_ckpt_strict: true
|
||||||
restore_hint: 0
|
restore_hint: 0
|
||||||
|
|||||||
@@ -4,6 +4,11 @@ import numpy as np
|
|||||||
from utils import is_list, is_dict, get_valid_args
|
from utils import is_list, is_dict, get_valid_args
|
||||||
|
|
||||||
|
|
||||||
|
class NoOperation():
|
||||||
|
def __call__(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class BaseSilTransform():
|
class BaseSilTransform():
|
||||||
def __init__(self, disvor=255.0, img_shape=None):
|
def __init__(self, disvor=255.0, img_shape=None):
|
||||||
self.disvor = disvor
|
self.disvor = disvor
|
||||||
|
|||||||
+13
-21
@@ -171,29 +171,22 @@ class BaseModel(MetaModel, nn.Module):
|
|||||||
if cfgs['trainer_cfg']['fix_BN']:
|
if cfgs['trainer_cfg']['fix_BN']:
|
||||||
self.fix_BN()
|
self.fix_BN()
|
||||||
|
|
||||||
def get_backbone(self, model_cfg):
|
def get_backbone(self, backbone_cfg):
|
||||||
"""Get the backbone of the model."""
|
"""Get the backbone of the model."""
|
||||||
def _get_backbone(backbone_cfg):
|
if is_dict(backbone_cfg):
|
||||||
if is_dict(backbone_cfg):
|
Backbone = get_attr_from([backbones], backbone_cfg['type'])
|
||||||
Backbone = get_attr_from([backbones], backbone_cfg['type'])
|
valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
|
||||||
valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
|
return Backbone(**valid_args)
|
||||||
return Backbone(**valid_args)
|
if is_list(backbone_cfg):
|
||||||
if is_list(backbone_cfg):
|
Backbone = nn.ModuleList([self.get_backbone(cfg)
|
||||||
Backbone = nn.ModuleList([_get_backbone(cfg)
|
for cfg in backbone_cfg])
|
||||||
for cfg in backbone_cfg])
|
return Backbone
|
||||||
return Backbone
|
raise ValueError(
|
||||||
raise ValueError(
|
"Error type for -Backbone-Cfg-, supported: (A list of) dict.")
|
||||||
"Error type for -Backbone-Cfg-, supported: (A list of) dict.")
|
|
||||||
|
|
||||||
if 'backbone_cfg' in model_cfg.keys():
|
|
||||||
backbone_cfg = model_cfg['backbone_cfg']
|
|
||||||
Backbone = _get_backbone(backbone_cfg)
|
|
||||||
else:
|
|
||||||
Backbone = None
|
|
||||||
return Backbone
|
|
||||||
|
|
||||||
def build_network(self, model_cfg):
|
def build_network(self, model_cfg):
|
||||||
self.Backbone = self.get_backbone(model_cfg)
|
if 'backbone_cfg' in model_cfg.keys():
|
||||||
|
self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
|
||||||
|
|
||||||
def init_parameters(self):
|
def init_parameters(self):
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
@@ -280,7 +273,6 @@ class BaseModel(MetaModel, nn.Module):
|
|||||||
self.msg_mgr.log_warning(
|
self.msg_mgr.log_warning(
|
||||||
"Restore NO Scheduler from %s !!!" % save_name)
|
"Restore NO Scheduler from %s !!!" % save_name)
|
||||||
self.msg_mgr.log_info("Restore Parameters from %s !!!" % save_name)
|
self.msg_mgr.log_info("Restore Parameters from %s !!!" % save_name)
|
||||||
del checkpoint
|
|
||||||
|
|
||||||
def resume_ckpt(self, restore_hint):
|
def resume_ckpt(self, restore_hint):
|
||||||
if isinstance(restore_hint, int):
|
if isinstance(restore_hint, int):
|
||||||
|
|||||||
@@ -5,11 +5,9 @@ from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWra
|
|||||||
|
|
||||||
|
|
||||||
class Baseline(BaseModel):
|
class Baseline(BaseModel):
|
||||||
def __init__(self, cfgs, is_training):
|
|
||||||
super().__init__(cfgs, is_training)
|
|
||||||
|
|
||||||
def build_network(self, model_cfg):
|
def build_network(self, model_cfg):
|
||||||
self.Backbone = self.get_backbone(model_cfg)
|
self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
|
||||||
self.Backbone = SetBlockWrapper(self.Backbone)
|
self.Backbone = SetBlockWrapper(self.Backbone)
|
||||||
self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
|
self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
|
||||||
self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
|
self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ class GaitPart(BaseModel):
|
|||||||
|
|
||||||
def build_network(self, model_cfg):
|
def build_network(self, model_cfg):
|
||||||
|
|
||||||
self.Backbone = self.get_backbone(model_cfg)
|
self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
|
||||||
head_cfg = model_cfg['SeparateFCs']
|
head_cfg = model_cfg['SeparateFCs']
|
||||||
self.Head = SeparateFCs(**model_cfg['SeparateFCs'])
|
self.Head = SeparateFCs(**model_cfg['SeparateFCs'])
|
||||||
self.Backbone = SetBlockWrapper(self.Backbone)
|
self.Backbone = SetBlockWrapper(self.Backbone)
|
||||||
|
|||||||
Reference in New Issue
Block a user