From c9f61e4808f0a62e184f2bbc3b36f41f942886f4 Mon Sep 17 00:00:00 2001
From: darkliang <12132342@mail.sustech.edu.cn>
Date: Tue, 19 Mar 2024 15:22:00 +0800
Subject: [PATCH] Force the input batch size to be equal to the number of GPUs
 during testing

---
 opengait/modeling/base_model.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/opengait/modeling/base_model.py b/opengait/modeling/base_model.py
index 57fc747..152c6cf 100644
--- a/opengait/modeling/base_model.py
+++ b/opengait/modeling/base_model.py
@@ -439,6 +439,9 @@ class BaseModel(MetaModel, nn.Module):
     def run_test(model):
         """Accept the instance object(model) here, and then run the test loop."""
+        if torch.distributed.get_world_size() != model.engine_cfg['sampler']['batch_size']:
+            raise ValueError("The batch size ({}) must be equal to the number of GPUs ({}) in testing mode!".format(
+                model.engine_cfg['sampler']['batch_size'], torch.distributed.get_world_size()))
         rank = torch.distributed.get_rank()
         with torch.no_grad():
             info_dict = model.inference(rank)