Fixed an issue

where the world_size is not divisible by the batch_size
This commit is contained in:
darkliang
2022-03-17 16:12:35 +08:00
parent 1e988dd8a6
commit 538c40ba87
+6
View File
@@ -8,9 +8,15 @@ class TripletSampler(tordata.sampler.Sampler):
def __init__(self, dataset, batch_size, batch_shuffle=False): def __init__(self, dataset, batch_size, batch_shuffle=False):
self.dataset = dataset self.dataset = dataset
self.batch_size = batch_size self.batch_size = batch_size
if len(self.batch_size) != 2:
raise ValueError(
"batch_size should be (P x K) not {}".format(batch_size))
self.batch_shuffle = batch_shuffle self.batch_shuffle = batch_shuffle
self.world_size = dist.get_world_size() self.world_size = dist.get_world_size()
if (self.batch_size[0]*self.batch_size[1]) % self.world_size != 0:
raise ValueError("World size ({}) is not divisible by batch_size ({} x {})".format(
self.world_size, batch_size[0], batch_size[1]))
self.rank = dist.get_rank() self.rank = dist.get_rank()
def __iter__(self): def __iter__(self):