Fixed an issue
where the world_size is not divisible by the batch_size
This commit is contained in:
+7
-1
@@ -8,9 +8,15 @@ class TripletSampler(tordata.sampler.Sampler):
|
|||||||
def __init__(self, dataset, batch_size, batch_shuffle=False):
|
def __init__(self, dataset, batch_size, batch_shuffle=False):
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
if len(self.batch_size) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
"batch_size should be (P x K) not {}".format(batch_size))
|
||||||
self.batch_shuffle = batch_shuffle
|
self.batch_shuffle = batch_shuffle
|
||||||
|
|
||||||
self.world_size = dist.get_world_size()
|
self.world_size = dist.get_world_size()
|
||||||
|
if (self.batch_size[0]*self.batch_size[1]) % self.world_size != 0:
|
||||||
|
raise ValueError("World size ({}) is not divisible by batch_size ({} x {})".format(
|
||||||
|
self.world_size, batch_size[0], batch_size[1]))
|
||||||
self.rank = dist.get_rank()
|
self.rank = dist.get_rank()
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
@@ -63,7 +69,7 @@ class InferenceSampler(tordata.sampler.Sampler):
|
|||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
|
|
||||||
if batch_size % world_size != 0:
|
if batch_size % world_size != 0:
|
||||||
raise ValueError("World size({}) is not divisible by batch_size({})".format(
|
raise ValueError("World size ({}) is not divisible by batch_size ({})".format(
|
||||||
world_size, batch_size))
|
world_size, batch_size))
|
||||||
|
|
||||||
if batch_size != 1:
|
if batch_size != 1:
|
||||||
|
|||||||
Reference in New Issue
Block a user