From 388974ab2a1e95cf0b9c6cb9508c0e14c6fec3cf Mon Sep 17 00:00:00 2001 From: jdyjjj <1410234026@qq.com> Date: Tue, 21 Nov 2023 19:26:42 +0800 Subject: [PATCH] enable find_unused_parameters flag for DDP --- configs/default.yaml | 1 + opengait/main.py | 2 +- opengait/utils/common.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/configs/default.yaml b/configs/default.yaml index 960fabf..abd357a 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -46,6 +46,7 @@ scheduler_cfg: scheduler: MultiStepLR trainer_cfg: + find_unused_parameters: false enable_float16: true with_test: false fix_BN: false diff --git a/opengait/main.py b/opengait/main.py index 598dcaa..8358d6d 100644 --- a/opengait/main.py +++ b/opengait/main.py @@ -46,7 +46,7 @@ def run_model(cfgs, training): model = nn.SyncBatchNorm.convert_sync_batchnorm(model) if cfgs['trainer_cfg']['fix_BN']: model.fix_BN() - model = get_ddp_module(model) + model = get_ddp_module(model, cfgs['trainer_cfg']['find_unused_parameters']) msg_mgr.log_info(params_count(model)) msg_mgr.log_info("Model Initialization Finished!") diff --git a/opengait/utils/common.py b/opengait/utils/common.py index 4429174..dfd39e0 100644 --- a/opengait/utils/common.py +++ b/opengait/utils/common.py @@ -190,13 +190,13 @@ class DDPPassthrough(DDP): return getattr(self.module, name) -def get_ddp_module(module, **kwargs): +def get_ddp_module(module, find_unused_parameters=False, **kwargs): if len(list(module.parameters())) == 0: # for the case that loss module has not parameters. return module device = torch.cuda.current_device() module = DDPPassthrough(module, device_ids=[device], output_device=device, - find_unused_parameters=False, **kwargs) + find_unused_parameters=find_unused_parameters, **kwargs) return module