fix a bug of train with test
This commit is contained in:
@@ -144,7 +144,7 @@ class BaseModel(MetaModel, nn.Module):
|
|||||||
|
|
||||||
self.build_network(cfgs['model_cfg'])
|
self.build_network(cfgs['model_cfg'])
|
||||||
self.init_parameters()
|
self.init_parameters()
|
||||||
self.seq_trfs = get_transform(self.engine_cfg['transform'])
|
self.trainer_trfs = get_transform(cfgs['trainer_cfg']['transform'])
|
||||||
|
|
||||||
self.msg_mgr.log_info(cfgs['data_cfg'])
|
self.msg_mgr.log_info(cfgs['data_cfg'])
|
||||||
if training:
|
if training:
|
||||||
@@ -153,6 +153,8 @@ class BaseModel(MetaModel, nn.Module):
|
|||||||
if not training or self.engine_cfg['with_test']:
|
if not training or self.engine_cfg['with_test']:
|
||||||
self.test_loader = self.get_loader(
|
self.test_loader = self.get_loader(
|
||||||
cfgs['data_cfg'], train=False)
|
cfgs['data_cfg'], train=False)
|
||||||
|
self.evaluator_trfs = get_transform(
|
||||||
|
cfgs['evaluator_cfg']['transform'])
|
||||||
|
|
||||||
self.device = torch.distributed.get_rank()
|
self.device = torch.distributed.get_rank()
|
||||||
torch.cuda.set_device(self.device)
|
torch.cuda.set_device(self.device)
|
||||||
@@ -300,7 +302,7 @@ class BaseModel(MetaModel, nn.Module):
|
|||||||
tuple: training data including inputs, labels, and some meta data.
|
tuple: training data including inputs, labels, and some meta data.
|
||||||
"""
|
"""
|
||||||
seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
|
seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
|
||||||
seq_trfs = self.seq_trfs
|
seq_trfs = self.trainer_trfs if self.training else self.evaluator_trfs
|
||||||
if len(seqs_batch) != len(seq_trfs):
|
if len(seqs_batch) != len(seq_trfs):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The number of types of input data and transform should be same. But got {} and {}".format(len(seqs_batch), len(seq_trfs)))
|
"The number of types of input data and transform should be same. But got {} and {}".format(len(seqs_batch), len(seq_trfs)))
|
||||||
|
|||||||
Reference in New Issue
Block a user