enable find_unused_parameters flag for DDP
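Note on the change: when `find_unused_parameters=True`, `torch.nn.parallel.DistributedDataParallel` traverses the autograd graph after each forward pass and marks parameters that did not contribute to the loss as ready for reduction. This is required for models whose forward pass only exercises a subset of their parameters, at the cost of extra per-iteration overhead. The wrapper below previously hard-coded `find_unused_parameters=False`; this commit threads the value through from `trainer_cfg` instead.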
@@ -46,6 +46,7 @@ scheduler_cfg:
   scheduler: MultiStepLR
 
 trainer_cfg:
+  find_unused_parameters: false
   enable_float16: true
   with_test: false
   fix_BN: false
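The new key defaults to `false`, matching the value that was previously hard-coded in `get_ddp_module`, so existing configs keep the same behavior unless they opt in.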
@@ -46,7 +46,7 @@ def run_model(cfgs, training):
     model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
     if cfgs['trainer_cfg']['fix_BN']:
         model.fix_BN()
-    model = get_ddp_module(model)
+    model = get_ddp_module(model, cfgs['trainer_cfg']['find_unused_parameters'])
     msg_mgr.log_info(params_count(model))
     msg_mgr.log_info("Model Initialization Finished!")
 
@@ -190,13 +190,13 @@ class DDPPassthrough(DDP):
         return getattr(self.module, name)
 
 
-def get_ddp_module(module, **kwargs):
+def get_ddp_module(module, find_unused_parameters=False, **kwargs):
     if len(list(module.parameters())) == 0:
         # for the case that the loss module has no parameters
         return module
     device = torch.cuda.current_device()
     module = DDPPassthrough(module, device_ids=[device], output_device=device,
-                            find_unused_parameters=False, **kwargs)
+                            find_unused_parameters=find_unused_parameters, **kwargs)
     return module
 
 
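To illustrate why the flag exists, here is a minimal, hypothetical sketch (not code from this repository): a model with two heads only uses one of them per forward pass, so DDP's reducer never sees gradients for the other head unless unused parameters are detected. The `TwoHead` module is invented for illustration, and the sketch assumes a single-process `gloo` group on CPU purely so it can run standalone.

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process process group so the example runs without a launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

class TwoHead(nn.Module):
    """Hypothetical model whose forward pass uses only one head at a time."""
    def __init__(self):
        super().__init__()
        self.head_a = nn.Linear(4, 1)
        self.head_b = nn.Linear(4, 1)  # untouched when use_a=True

    def forward(self, x, use_a=True):
        return self.head_a(x) if use_a else self.head_b(x)

# With find_unused_parameters=False, DDP's reducer would complain (typically
# at the start of the next iteration) that head_b's parameters never received
# gradients; with True, DDP marks them as ready despite being unused.
model = DDP(TwoHead(), find_unused_parameters=True)
loss = model(torch.randn(2, 4), use_a=True).sum()
loss.backward()  # succeeds even though head_b got no gradient
dist.destroy_process_group()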