diff --git a/opengait/modeling/base_model.py b/opengait/modeling/base_model.py index 57fc747..152c6cf 100644 --- a/opengait/modeling/base_model.py +++ b/opengait/modeling/base_model.py @@ -439,6 +439,9 @@ class BaseModel(MetaModel, nn.Module): def run_test(model): """Accept the instance object(model) here, and then run the test loop.""" + if torch.distributed.get_world_size() != model.engine_cfg['sampler']['batch_size']: + raise ValueError("The batch size ({}) must be equal to the number of GPUs ({}) in testing mode!".format( + model.engine_cfg['sampler']['batch_size'], torch.distributed.get_world_size())) rank = torch.distributed.get_rank() with torch.no_grad(): info_dict = model.inference(rank)