refactor get_backbone and a little change
This commit is contained in:
@@ -5,11 +5,9 @@ from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWra
|
||||
|
||||
|
||||
class Baseline(BaseModel):
|
||||
def __init__(self, cfgs, is_training):
|
||||
super().__init__(cfgs, is_training)
|
||||
|
||||
def build_network(self, model_cfg):
|
||||
self.Backbone = self.get_backbone(model_cfg)
|
||||
self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
|
||||
self.Backbone = SetBlockWrapper(self.Backbone)
|
||||
self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
|
||||
self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
|
||||
|
||||
@@ -88,7 +88,7 @@ class GaitPart(BaseModel):
|
||||
|
||||
def build_network(self, model_cfg):
|
||||
|
||||
self.Backbone = self.get_backbone(model_cfg)
|
||||
self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
|
||||
head_cfg = model_cfg['SeparateFCs']
|
||||
self.Head = SeparateFCs(**model_cfg['SeparateFCs'])
|
||||
self.Backbone = SetBlockWrapper(self.Backbone)
|
||||
|
||||
Reference in New Issue
Block a user