Fix the bug of fix_bn

This commit is contained in:
wj1tr0y
2022-04-16 19:29:06 +08:00
committed by Junhao Liang
parent a4ead0b40d
commit ff398acbc7
2 changed files with 4 additions and 4 deletions
+2
View File
@@ -44,6 +44,8 @@ def run_model(cfgs, training):
model = Model(cfgs, training)
if training and cfgs['trainer_cfg']['sync_BN']:
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
if cfgs['trainer_cfg']['fix_BN']:
model.fix_BN()
model = get_ddp_module(model)
msg_mgr.log_info(params_count(model))
msg_mgr.log_info("Model Initialization Finished!")
+2 -4
View File
@@ -167,10 +167,6 @@ class BaseModel(MetaModel, nn.Module):
if restore_hint != 0:
self.resume_ckpt(restore_hint)
if training:
if cfgs['trainer_cfg']['fix_BN']:
self.fix_BN()
def get_backbone(self, backbone_cfg):
"""Get the backbone of the model."""
if is_dict(backbone_cfg):
@@ -427,6 +423,8 @@ class BaseModel(MetaModel, nn.Module):
model.eval()
result_dict = BaseModel.run_test(model)
model.train()
if model.cfgs['trainer_cfg']['fix_BN']:
model.fix_BN()
model.msg_mgr.write_to_tensorboard(result_dict)
model.msg_mgr.reset_time()
if model.iteration >= model.engine_cfg['total_iter']: