fix a bug when training with testing enabled (with_test)

darkliang
2023-10-09 18:45:26 +08:00
parent d579ca9135
commit 5f0c27a622
+4 -2
@@ -144,7 +144,7 @@ class BaseModel(MetaModel, nn.Module):
         self.build_network(cfgs['model_cfg'])
         self.init_parameters()
-        self.seq_trfs = get_transform(self.engine_cfg['transform'])
+        self.trainer_trfs = get_transform(cfgs['trainer_cfg']['transform'])
         self.msg_mgr.log_info(cfgs['data_cfg'])
         if training:
@@ -153,6 +153,8 @@ class BaseModel(MetaModel, nn.Module):
         if not training or self.engine_cfg['with_test']:
             self.test_loader = self.get_loader(
                 cfgs['data_cfg'], train=False)
+            self.evaluator_trfs = get_transform(
+                cfgs['evaluator_cfg']['transform'])
         self.device = torch.distributed.get_rank()
         torch.cuda.set_device(self.device)
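For context, the constructor-side change above amounts to this pattern: training transforms are always built, while evaluation transforms are built only when a test pass will actually run. Below is a minimal sketch, not the real OpenGait code: get_transform is stubbed out, and the engine_cfg selection and cfgs layout are assumptions inferred from the hunks above.

import torch.nn as nn

def get_transform(transform_cfg):
    # Stub for illustration: one identity callable per configured entry.
    # The real get_transform builds transform objects from the config.
    return [lambda x: x for _ in transform_cfg]

class BaseModelSketch(nn.Module):
    def __init__(self, cfgs, training):
        super().__init__()
        # Assumption: the active engine config depends on the run mode.
        engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg']
        # Training transforms are always available.
        self.trainer_trfs = get_transform(cfgs['trainer_cfg']['transform'])
        # Evaluation transforms exist only when a test pass will run:
        # at test time, or during training when with_test is enabled.
        if not training or engine_cfg['with_test']:
            self.evaluator_trfs = get_transform(
                cfgs['evaluator_cfg']['transform'])

cfgs = {
    'trainer_cfg': {'transform': ['BaseSilTransform'], 'with_test': True},
    'evaluator_cfg': {'transform': ['BaseSilTransform']},
}
model = BaseModelSketch(cfgs, training=True)
assert hasattr(model, 'evaluator_trfs')  # built because with_test is True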
@@ -300,7 +302,7 @@ class BaseModel(MetaModel, nn.Module):
             tuple: training data including inputs, labels, and some meta data.
         """
         seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
-        seq_trfs = self.seq_trfs
+        seq_trfs = self.trainer_trfs if self.training else self.evaluator_trfs
         if len(seqs_batch) != len(seq_trfs):
             raise ValueError(
                 "The number of types of input data and transform should be same. But got {} and {}".format(len(seqs_batch), len(seq_trfs)))