Fixed an issue
where the world_size is not divisible by the batch_size
This commit is contained in:
+7
-1
@@ -8,9 +8,15 @@ class TripletSampler(tordata.sampler.Sampler):
|
||||
def __init__(self, dataset, batch_size, batch_shuffle=False):
|
||||
self.dataset = dataset
|
||||
self.batch_size = batch_size
|
||||
if len(self.batch_size) != 2:
|
||||
raise ValueError(
|
||||
"batch_size should be (P x K) not {}".format(batch_size))
|
||||
self.batch_shuffle = batch_shuffle
|
||||
|
||||
self.world_size = dist.get_world_size()
|
||||
if (self.batch_size[0]*self.batch_size[1]) % self.world_size != 0:
|
||||
raise ValueError("World size ({}) is not divisible by batch_size ({} x {})".format(
|
||||
self.world_size, batch_size[0], batch_size[1]))
|
||||
self.rank = dist.get_rank()
|
||||
|
||||
def __iter__(self):
|
||||
@@ -63,7 +69,7 @@ class InferenceSampler(tordata.sampler.Sampler):
|
||||
rank = dist.get_rank()
|
||||
|
||||
if batch_size % world_size != 0:
|
||||
raise ValueError("World size({}) is not divisible by batch_size({})".format(
|
||||
raise ValueError("World size ({}) is not divisible by batch_size ({})".format(
|
||||
world_size, batch_size))
|
||||
|
||||
if batch_size != 1:
|
||||
|
||||
Reference in New Issue
Block a user