Fix the bug of fix_bn
This commit is contained in:
@@ -44,6 +44,8 @@ def run_model(cfgs, training):
|
||||
model = Model(cfgs, training)
|
||||
if training and cfgs['trainer_cfg']['sync_BN']:
|
||||
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
if cfgs['trainer_cfg']['fix_BN']:
|
||||
model.fix_BN()
|
||||
model = get_ddp_module(model)
|
||||
msg_mgr.log_info(params_count(model))
|
||||
msg_mgr.log_info("Model Initialization Finished!")
|
||||
|
||||
@@ -167,10 +167,6 @@ class BaseModel(MetaModel, nn.Module):
|
||||
if restore_hint != 0:
|
||||
self.resume_ckpt(restore_hint)
|
||||
|
||||
if training:
|
||||
if cfgs['trainer_cfg']['fix_BN']:
|
||||
self.fix_BN()
|
||||
|
||||
def get_backbone(self, backbone_cfg):
|
||||
"""Get the backbone of the model."""
|
||||
if is_dict(backbone_cfg):
|
||||
@@ -427,6 +423,8 @@ class BaseModel(MetaModel, nn.Module):
|
||||
model.eval()
|
||||
result_dict = BaseModel.run_test(model)
|
||||
model.train()
|
||||
if model.cfgs['trainer_cfg']['fix_BN']:
|
||||
model.fix_BN()
|
||||
model.msg_mgr.write_to_tensorboard(result_dict)
|
||||
model.msg_mgr.reset_time()
|
||||
if model.iteration >= model.engine_cfg['total_iter']:
|
||||
|
||||
Reference in New Issue
Block a user