Fixed an issue

where the world_size is not divisible by the batch_size
This commit is contained in:
darkliang
2022-03-17 16:12:35 +08:00
parent 1e988dd8a6
commit 538c40ba87
+7 -1
View File
@@ -8,9 +8,15 @@ class TripletSampler(tordata.sampler.Sampler):
def __init__(self, dataset, batch_size, batch_shuffle=False):
self.dataset = dataset
self.batch_size = batch_size
if len(self.batch_size) != 2:
raise ValueError(
"batch_size should be (P x K) not {}".format(batch_size))
self.batch_shuffle = batch_shuffle
self.world_size = dist.get_world_size()
if (self.batch_size[0]*self.batch_size[1]) % self.world_size != 0:
raise ValueError("World size ({}) is not divisible by batch_size ({} x {})".format(
self.world_size, batch_size[0], batch_size[1]))
self.rank = dist.get_rank()
def __iter__(self):
@@ -63,7 +69,7 @@ class InferenceSampler(tordata.sampler.Sampler):
rank = dist.get_rank()
if batch_size % world_size != 0:
raise ValueError("World size({}) is not divisible by batch_size({})".format(
raise ValueError("World size ({}) is not divisible by batch_size ({})".format(
world_size, batch_size))
if batch_size != 1: