Enable the find_unused_parameters flag for DDP

jdyjjj
2023-11-21 19:26:42 +08:00
parent 112208ef74
commit 388974ab2a
3 changed files with 4 additions and 3 deletions
+1
@@ -46,6 +46,7 @@ scheduler_cfg:
   scheduler: MultiStepLR
 trainer_cfg:
+  find_unused_parameters: false
   enable_float16: true
   with_test: false
   fix_BN: false
+1 -1
@@ -46,7 +46,7 @@ def run_model(cfgs, training):
         model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
     if cfgs['trainer_cfg']['fix_BN']:
         model.fix_BN()
-    model = get_ddp_module(model)
+    model = get_ddp_module(model, cfgs['trainer_cfg']['find_unused_parameters'])
     msg_mgr.log_info(params_count(model))
     msg_mgr.log_info("Model Initialization Finished!")
+2 -2
@@ -190,13 +190,13 @@ class DDPPassthrough(DDP):
         return getattr(self.module, name)
 
 
-def get_ddp_module(module, **kwargs):
+def get_ddp_module(module, find_unused_parameters=False, **kwargs):
     if len(list(module.parameters())) == 0:
         # for the case that the loss module has no parameters
         return module
     device = torch.cuda.current_device()
     module = DDPPassthrough(module, device_ids=[device], output_device=device,
-                            find_unused_parameters=False, **kwargs)
+                            find_unused_parameters=find_unused_parameters, **kwargs)
     return module
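
For context on what the propagated flag changes at the PyTorch level, here is a minimal, self-contained sketch. The BranchyNet module and the process-group setup are illustrative assumptions, not code from this commit: with find_unused_parameters=True, DDP's reducer tolerates parameters that receive no gradient in a given iteration, such as a conditionally skipped branch, at the cost of an extra traversal to mark them ready.

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

class BranchyNet(nn.Module):
    """Toy model whose aux_head is only exercised on some iterations."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.aux_head = nn.Linear(8, 8)

    def forward(self, x, use_aux=False):
        out = self.backbone(x)
        if use_aux:
            out = out + self.aux_head(x)
        return out

# Assumes torch.distributed.init_process_group(backend='nccl') has run
# and this rank's GPU has been selected with torch.cuda.set_device.
device = torch.cuda.current_device()
model = DDP(BranchyNet().to(device), device_ids=[device], output_device=device,
            find_unused_parameters=True)
# With the default find_unused_parameters=False, a backward pass in which
# use_aux=False raises a runtime error because aux_head never receives a
# gradient; with True, DDP marks those parameters as ready and proceeds.

Defaulting the new trainer_cfg key to false preserves the cheaper reducer path for models in which every parameter contributes to the loss.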