Add some clearer error messages
This commit is contained in:
+1
-1
@@ -46,7 +46,7 @@ class DataSet(tordata.Dataset):
|
||||
data_list.append(_)
|
||||
for data in data_list:
|
||||
if len(data) != len(data_list[0]):
|
||||
raise AssertionError
|
||||
raise ValueError('Each input data should have the same length.')
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
+1
-1
@@ -63,7 +63,7 @@ class InferenceSampler(tordata.sampler.Sampler):
|
||||
rank = dist.get_rank()
|
||||
|
||||
if batch_size % world_size != 0:
|
||||
raise AssertionError("World size({}) need be divisible by batch_size({})".format(
|
||||
raise ValueError("World size({}) is not divisible by batch_size({})".format(
|
||||
world_size, batch_size))
|
||||
|
||||
if batch_size != 1:
|
||||
|
||||
Reference in New Issue
Block a user