Add some clearer error messages

This commit is contained in:
darkliang
2021-12-29 17:14:03 +08:00
parent bddd552907
commit 19ed960b0a
4 changed files with 4 additions and 4 deletions
+1 -1
View File
@@ -46,7 +46,7 @@ class DataSet(tordata.Dataset):
data_list.append(_) data_list.append(_)
for data in data_list: for data in data_list:
if len(data) != len(data_list[0]): if len(data) != len(data_list[0]):
raise AssertionError raise ValueError('Each input data should have the same length.')
return data_list return data_list
+1 -1
View File
@@ -63,7 +63,7 @@ class InferenceSampler(tordata.sampler.Sampler):
rank = dist.get_rank() rank = dist.get_rank()
if batch_size % world_size != 0: if batch_size % world_size != 0:
raise AssertionError("World size({}) need be divisible by batch_size({})".format( raise ValueError("World size({}) is not divisible by batch_size({})".format(
world_size, batch_size)) world_size, batch_size))
if batch_size != 1: if batch_size != 1:
+1 -1
View File
@@ -57,7 +57,7 @@ def run_model(cfgs, training):
if __name__ == '__main__': if __name__ == '__main__':
torch.distributed.init_process_group('nccl', init_method='env://') torch.distributed.init_process_group('nccl', init_method='env://')
if torch.distributed.get_world_size() != torch.cuda.device_count(): if torch.distributed.get_world_size() != torch.cuda.device_count():
raise AssertionError("Expect number of availuable GPUs({}) equals to the world size({}).".format( raise ValueError("Expect number of availuable GPUs({}) equals to the world size({}).".format(
torch.distributed.get_world_size(), torch.cuda.device_count())) torch.distributed.get_world_size(), torch.cuda.device_count()))
cfgs = config_loader(opt.cfgs) cfgs = config_loader(opt.cfgs)
if opt.iter != 0: if opt.iter != 0:
+1 -1
View File
@@ -43,7 +43,7 @@ class Plain(nn.Module):
cfg = cfg.split('-') cfg = cfg.split('-')
typ = cfg[0] typ = cfg[0]
if typ not in ['BC', 'FC']: if typ not in ['BC', 'FC']:
raise AssertionError raise ValueError('Only support BC or FC, but got {}'.format(typ))
out_c = int(cfg[1]) out_c = int(cfg[1])
if typ == 'BC': if typ == 'BC':