refactor get_backbone and a little change

This commit is contained in:
darkliang
2021-12-22 22:16:56 +08:00
parent b74d6f7959
commit 51bb8b72b4
5 changed files with 21 additions and 26 deletions
+1 -3
View File
@@ -5,11 +5,9 @@ from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWra
class Baseline(BaseModel):
def __init__(self, cfgs, is_training):
super().__init__(cfgs, is_training)
def build_network(self, model_cfg):
self.Backbone = self.get_backbone(model_cfg)
self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
self.Backbone = SetBlockWrapper(self.Backbone)
self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
+1 -1
View File
@@ -88,7 +88,7 @@ class GaitPart(BaseModel):
def build_network(self, model_cfg):
self.Backbone = self.get_backbone(model_cfg)
self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
head_cfg = model_cfg['SeparateFCs']
self.Head = SeparateFCs(**model_cfg['SeparateFCs'])
self.Backbone = SetBlockWrapper(self.Backbone)