Enable the find_unused_parameters flag for DDP

This commit is contained in:
jdyjjj
2023-11-21 19:26:42 +08:00
parent 112208ef74
commit 388974ab2a
3 changed files with 4 additions and 3 deletions
+2 -2
View File
@@ -190,13 +190,13 @@ class DDPPassthrough(DDP):
return getattr(self.module, name)
def get_ddp_module(module, find_unused_parameters=False, **kwargs):
    """Wrap ``module`` in ``DDPPassthrough`` on the current CUDA device.

    Args:
        module: the ``torch.nn.Module`` to wrap.
        find_unused_parameters: forwarded to DDP; set ``True`` when some of
            the module's parameters may not receive gradients during a
            forward pass.
        **kwargs: additional keyword arguments forwarded to
            ``DDPPassthrough``.

    Returns:
        The DDP-wrapped module, or ``module`` unchanged when it has no
        parameters (e.g. a parameter-free loss module), since DDP refuses
        to wrap a module without parameters.
    """
    if next(module.parameters(), None) is None:
        # DDP cannot wrap a module that has no parameters (e.g. some loss
        # modules), so return it untouched.
        return module
    device = torch.cuda.current_device()
    module = DDPPassthrough(module, device_ids=[device], output_device=device,
                            find_unused_parameters=find_unused_parameters,
                            **kwargs)
    return module