diff --git a/opengait/modeling/base_model.py b/opengait/modeling/base_model.py index 98aa103..57fc747 100644 --- a/opengait/modeling/base_model.py +++ b/opengait/modeling/base_model.py @@ -144,7 +144,7 @@ class BaseModel(MetaModel, nn.Module): self.build_network(cfgs['model_cfg']) self.init_parameters() - self.seq_trfs = get_transform(self.engine_cfg['transform']) + self.trainer_trfs = get_transform(cfgs['trainer_cfg']['transform']) self.msg_mgr.log_info(cfgs['data_cfg']) if training: @@ -153,6 +153,8 @@ class BaseModel(MetaModel, nn.Module): if not training or self.engine_cfg['with_test']: self.test_loader = self.get_loader( cfgs['data_cfg'], train=False) + self.evaluator_trfs = get_transform( + cfgs['evaluator_cfg']['transform']) self.device = torch.distributed.get_rank() torch.cuda.set_device(self.device) @@ -300,7 +302,7 @@ class BaseModel(MetaModel, nn.Module): tuple: training data including inputs, labels, and some meta data. """ seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs - seq_trfs = self.seq_trfs + seq_trfs = self.trainer_trfs if self.training else self.evaluator_trfs if len(seqs_batch) != len(seq_trfs): raise ValueError( "The number of types of input data and transform should be same. But got {} and {}".format(len(seqs_batch), len(seq_trfs)))