From 19ed960b0a4c681add8da839ee2e8154aa795ebc Mon Sep 17 00:00:00 2001 From: darkliang <11710911@mail.sustech.edu.cn> Date: Wed, 29 Dec 2021 17:14:03 +0800 Subject: [PATCH] Add some clearer error messages --- lib/data/dataset.py | 2 +- lib/data/sampler.py | 2 +- lib/main.py | 2 +- lib/modeling/backbones/plain.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/data/dataset.py b/lib/data/dataset.py index a9ceca8..2ecd759 100644 --- a/lib/data/dataset.py +++ b/lib/data/dataset.py @@ -46,7 +46,7 @@ class DataSet(tordata.Dataset): data_list.append(_) for data in data_list: if len(data) != len(data_list[0]): - raise AssertionError + raise ValueError('Each input data should have the same length.') return data_list diff --git a/lib/data/sampler.py b/lib/data/sampler.py index d91d3b6..a8e4e05 100644 --- a/lib/data/sampler.py +++ b/lib/data/sampler.py @@ -63,7 +63,7 @@ class InferenceSampler(tordata.sampler.Sampler): rank = dist.get_rank() if batch_size % world_size != 0: - raise AssertionError("World size({}) need be divisible by batch_size({})".format( + raise ValueError("World size({}) is not divisible by batch_size({})".format( world_size, batch_size)) if batch_size != 1: diff --git a/lib/main.py b/lib/main.py index e8b033f..027ff2b 100644 --- a/lib/main.py +++ b/lib/main.py @@ -57,7 +57,7 @@ def run_model(cfgs, training): if __name__ == '__main__': torch.distributed.init_process_group('nccl', init_method='env://') if torch.distributed.get_world_size() != torch.cuda.device_count(): - raise AssertionError("Expect number of availuable GPUs({}) equals to the world size({}).".format( + raise ValueError("Expect number of availuable GPUs({}) equals to the world size({}).".format( torch.distributed.get_world_size(), torch.cuda.device_count())) cfgs = config_loader(opt.cfgs) if opt.iter != 0: diff --git a/lib/modeling/backbones/plain.py b/lib/modeling/backbones/plain.py index e66da97..17a4bcb 100644 --- a/lib/modeling/backbones/plain.py +++ b/lib/modeling/backbones/plain.py @@ -43,7 +43,7 @@ class Plain(nn.Module): cfg = cfg.split('-') typ = cfg[0] if typ not in ['BC', 'FC']: - raise AssertionError + raise ValueError('Only support BC or FC, but got {}'.format(typ)) out_c = int(cfg[1]) if typ == 'BC':