Fix the bug of fix_bn
This commit is contained in:
@@ -44,6 +44,8 @@ def run_model(cfgs, training):
|
|||||||
model = Model(cfgs, training)
|
model = Model(cfgs, training)
|
||||||
if training and cfgs['trainer_cfg']['sync_BN']:
|
if training and cfgs['trainer_cfg']['sync_BN']:
|
||||||
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||||
|
if cfgs['trainer_cfg']['fix_BN']:
|
||||||
|
model.fix_BN()
|
||||||
model = get_ddp_module(model)
|
model = get_ddp_module(model)
|
||||||
msg_mgr.log_info(params_count(model))
|
msg_mgr.log_info(params_count(model))
|
||||||
msg_mgr.log_info("Model Initialization Finished!")
|
msg_mgr.log_info("Model Initialization Finished!")
|
||||||
|
|||||||
@@ -167,10 +167,6 @@ class BaseModel(MetaModel, nn.Module):
|
|||||||
if restore_hint != 0:
|
if restore_hint != 0:
|
||||||
self.resume_ckpt(restore_hint)
|
self.resume_ckpt(restore_hint)
|
||||||
|
|
||||||
if training:
|
|
||||||
if cfgs['trainer_cfg']['fix_BN']:
|
|
||||||
self.fix_BN()
|
|
||||||
|
|
||||||
def get_backbone(self, backbone_cfg):
|
def get_backbone(self, backbone_cfg):
|
||||||
"""Get the backbone of the model."""
|
"""Get the backbone of the model."""
|
||||||
if is_dict(backbone_cfg):
|
if is_dict(backbone_cfg):
|
||||||
@@ -427,6 +423,8 @@ class BaseModel(MetaModel, nn.Module):
|
|||||||
model.eval()
|
model.eval()
|
||||||
result_dict = BaseModel.run_test(model)
|
result_dict = BaseModel.run_test(model)
|
||||||
model.train()
|
model.train()
|
||||||
|
if model.cfgs['trainer_cfg']['fix_BN']:
|
||||||
|
model.fix_BN()
|
||||||
model.msg_mgr.write_to_tensorboard(result_dict)
|
model.msg_mgr.write_to_tensorboard(result_dict)
|
||||||
model.msg_mgr.reset_time()
|
model.msg_mgr.reset_time()
|
||||||
if model.iteration >= model.engine_cfg['total_iter']:
|
if model.iteration >= model.engine_cfg['total_iter']:
|
||||||
|
|||||||
Reference in New Issue
Block a user