Fixed an issue
where the world_size is not divisible by the batch_size
This commit is contained in:
@@ -8,9 +8,15 @@ class TripletSampler(tordata.sampler.Sampler):
|
|||||||
def __init__(self, dataset, batch_size, batch_shuffle=False):
|
def __init__(self, dataset, batch_size, batch_shuffle=False):
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
if len(self.batch_size) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
"batch_size should be (P x K) not {}".format(batch_size))
|
||||||
self.batch_shuffle = batch_shuffle
|
self.batch_shuffle = batch_shuffle
|
||||||
|
|
||||||
self.world_size = dist.get_world_size()
|
self.world_size = dist.get_world_size()
|
||||||
|
if (self.batch_size[0]*self.batch_size[1]) % self.world_size != 0:
|
||||||
|
raise ValueError("World size ({}) is not divisible by batch_size ({} x {})".format(
|
||||||
|
self.world_size, batch_size[0], batch_size[1]))
|
||||||
self.rank = dist.get_rank()
|
self.rank = dist.get_rank()
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user