Add some clearer error messages
This commit is contained in:
+1
-1
@@ -46,7 +46,7 @@ class DataSet(tordata.Dataset):
|
|||||||
data_list.append(_)
|
data_list.append(_)
|
||||||
for data in data_list:
|
for data in data_list:
|
||||||
if len(data) != len(data_list[0]):
|
if len(data) != len(data_list[0]):
|
||||||
raise AssertionError
|
raise ValueError('Each input data should have the same length.')
|
||||||
|
|
||||||
return data_list
|
return data_list
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -63,7 +63,7 @@ class InferenceSampler(tordata.sampler.Sampler):
|
|||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
|
|
||||||
if batch_size % world_size != 0:
|
if batch_size % world_size != 0:
|
||||||
raise AssertionError("World size({}) need be divisible by batch_size({})".format(
|
raise ValueError("World size({}) is not divisible by batch_size({})".format(
|
||||||
world_size, batch_size))
|
world_size, batch_size))
|
||||||
|
|
||||||
if batch_size != 1:
|
if batch_size != 1:
|
||||||
|
|||||||
+1
-1
@@ -57,7 +57,7 @@ def run_model(cfgs, training):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
torch.distributed.init_process_group('nccl', init_method='env://')
|
torch.distributed.init_process_group('nccl', init_method='env://')
|
||||||
if torch.distributed.get_world_size() != torch.cuda.device_count():
|
if torch.distributed.get_world_size() != torch.cuda.device_count():
|
||||||
raise AssertionError("Expect number of availuable GPUs({}) equals to the world size({}).".format(
|
raise ValueError("Expect number of availuable GPUs({}) equals to the world size({}).".format(
|
||||||
torch.distributed.get_world_size(), torch.cuda.device_count()))
|
torch.distributed.get_world_size(), torch.cuda.device_count()))
|
||||||
cfgs = config_loader(opt.cfgs)
|
cfgs = config_loader(opt.cfgs)
|
||||||
if opt.iter != 0:
|
if opt.iter != 0:
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class Plain(nn.Module):
|
|||||||
cfg = cfg.split('-')
|
cfg = cfg.split('-')
|
||||||
typ = cfg[0]
|
typ = cfg[0]
|
||||||
if typ not in ['BC', 'FC']:
|
if typ not in ['BC', 'FC']:
|
||||||
raise AssertionError
|
raise ValueError('Only support BC or FC, but got {}'.format(typ))
|
||||||
out_c = int(cfg[1])
|
out_c = int(cfg[1])
|
||||||
|
|
||||||
if typ == 'BC':
|
if typ == 'BC':
|
||||||
|
|||||||
Reference in New Issue
Block a user