Enable the `find_unused_parameters` flag for DDP
This commit is contained in:
@@ -190,13 +190,13 @@ class DDPPassthrough(DDP):
|
||||
return getattr(self.module, name)
|
||||
|
||||
|
||||
def get_ddp_module(module, find_unused_parameters=False, **kwargs):
    """Wrap ``module`` in a ``DDPPassthrough`` for distributed training.

    Args:
        module: the ``torch.nn.Module`` to wrap.
        find_unused_parameters: forwarded to DDP; set ``True`` when the
            forward pass may leave some parameters without gradients
            (defaults to ``False``, preserving the previous behavior).
        **kwargs: extra keyword arguments forwarded to ``DDPPassthrough``.

    Returns:
        The DDP-wrapped module, or ``module`` unchanged when it has no
        parameters.
    """
    if not list(module.parameters()):
        # DDP cannot wrap a module with no parameters — e.g. a loss module
        # that holds no learnable state — so return it as-is.
        return module
    device = torch.cuda.current_device()
    module = DDPPassthrough(module, device_ids=[device], output_device=device,
                            find_unused_parameters=find_unused_parameters,
                            **kwargs)
    return module
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user