From ff398acbc70be8395127e2524259392bb520f743 Mon Sep 17 00:00:00 2001 From: wj1tr0y Date: Sat, 16 Apr 2022 19:29:06 +0800 Subject: [PATCH] Fix the bug of fix_bn --- opengait/main.py | 2 ++ opengait/modeling/base_model.py | 6 ++---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/opengait/main.py b/opengait/main.py index 3aff63b..35d372c 100644 --- a/opengait/main.py +++ b/opengait/main.py @@ -44,6 +44,8 @@ def run_model(cfgs, training): model = Model(cfgs, training) if training and cfgs['trainer_cfg']['sync_BN']: model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + if cfgs['trainer_cfg']['fix_BN']: + model.fix_BN() model = get_ddp_module(model) msg_mgr.log_info(params_count(model)) msg_mgr.log_info("Model Initialization Finished!") diff --git a/opengait/modeling/base_model.py b/opengait/modeling/base_model.py index c945712..8a0670b 100644 --- a/opengait/modeling/base_model.py +++ b/opengait/modeling/base_model.py @@ -167,10 +167,6 @@ class BaseModel(MetaModel, nn.Module): if restore_hint != 0: self.resume_ckpt(restore_hint) - if training: - if cfgs['trainer_cfg']['fix_BN']: - self.fix_BN() - def get_backbone(self, backbone_cfg): """Get the backbone of the model.""" if is_dict(backbone_cfg): @@ -427,6 +423,8 @@ class BaseModel(MetaModel, nn.Module): model.eval() result_dict = BaseModel.run_test(model) model.train() + if model.cfgs['trainer_cfg']['fix_BN']: + model.fix_BN() model.msg_mgr.write_to_tensorboard(result_dict) model.msg_mgr.reset_time() if model.iteration >= model.engine_cfg['total_iter']: