Force the input batch size to be equal to the number of GPUs during testing

This commit is contained in:
darkliang
2024-03-19 15:22:00 +08:00
parent da65481b66
commit c9f61e4808
+3
View File
@@ -439,6 +439,9 @@ class BaseModel(MetaModel, nn.Module):
def run_test(model):
"""Accept the instance object(model) here, and then run the test loop."""
if torch.distributed.get_world_size() != model.engine_cfg['sampler']['batch_size']:
raise ValueError("The batch size ({}) must be equal to the number of GPUs ({}) in testing mode!".format(
model.engine_cfg['sampler']['batch_size'], torch.distributed.get_world_size()))
rank = torch.distributed.get_rank()
with torch.no_grad():
info_dict = model.inference(rank)