From 538c40ba8724ebb942ca8ce3838a8cbb1a892f1f Mon Sep 17 00:00:00 2001 From: darkliang <11710911@mail.sustech.edu.cn> Date: Thu, 17 Mar 2022 16:12:35 +0800 Subject: [PATCH] Fixed an issue where the world_size is not divisible by the batch_size --- lib/data/sampler.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/lib/data/sampler.py b/lib/data/sampler.py index a8e4e05..6f6cc22 100644 --- a/lib/data/sampler.py +++ b/lib/data/sampler.py @@ -8,9 +8,15 @@ class TripletSampler(tordata.sampler.Sampler): def __init__(self, dataset, batch_size, batch_shuffle=False): self.dataset = dataset self.batch_size = batch_size + if len(self.batch_size) != 2: + raise ValueError( + "batch_size should be (P x K) not {}".format(batch_size)) self.batch_shuffle = batch_shuffle self.world_size = dist.get_world_size() + if (self.batch_size[0]*self.batch_size[1]) % self.world_size != 0: + raise ValueError("World size ({}) is not divisible by batch_size ({} x {})".format( + self.world_size, batch_size[0], batch_size[1])) self.rank = dist.get_rank() def __iter__(self): @@ -63,7 +69,7 @@ class InferenceSampler(tordata.sampler.Sampler): rank = dist.get_rank() if batch_size % world_size != 0: - raise ValueError("World size({}) is not divisible by batch_size({})".format( + raise ValueError("World size ({}) is not divisible by batch_size ({})".format( world_size, batch_size)) if batch_size != 1: