fix a bug of train with test
This commit is contained in:
@@ -144,7 +144,7 @@ class BaseModel(MetaModel, nn.Module):
|
||||
|
||||
self.build_network(cfgs['model_cfg'])
|
||||
self.init_parameters()
|
||||
self.seq_trfs = get_transform(self.engine_cfg['transform'])
|
||||
self.trainer_trfs = get_transform(cfgs['trainer_cfg']['transform'])
|
||||
|
||||
self.msg_mgr.log_info(cfgs['data_cfg'])
|
||||
if training:
|
||||
@@ -153,6 +153,8 @@ class BaseModel(MetaModel, nn.Module):
|
||||
if not training or self.engine_cfg['with_test']:
|
||||
self.test_loader = self.get_loader(
|
||||
cfgs['data_cfg'], train=False)
|
||||
self.evaluator_trfs = get_transform(
|
||||
cfgs['evaluator_cfg']['transform'])
|
||||
|
||||
self.device = torch.distributed.get_rank()
|
||||
torch.cuda.set_device(self.device)
|
||||
@@ -300,7 +302,7 @@ class BaseModel(MetaModel, nn.Module):
|
||||
tuple: training data including inputs, labels, and some meta data.
|
||||
"""
|
||||
seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
|
||||
seq_trfs = self.seq_trfs
|
||||
seq_trfs = self.trainer_trfs if self.training else self.evaluator_trfs
|
||||
if len(seqs_batch) != len(seq_trfs):
|
||||
raise ValueError(
|
||||
"The number of types of input data and transform should be same. But got {} and {}".format(len(seqs_batch), len(seq_trfs)))
|
||||
|
||||
Reference in New Issue
Block a user