Support skeleton (#155)
* pose * pose * pose * pose * 你的提交消息 * pose * pose * Delete train1.sh * pretreatment * configs * pose * reference * Update gaittr.py * naming * naming * Update transform.py * update for datasets * update README * update name and README * update * Update transform.py
This commit is contained in:
@@ -144,6 +144,7 @@ class BaseModel(MetaModel, nn.Module):
|
||||
|
||||
self.build_network(cfgs['model_cfg'])
|
||||
self.init_parameters()
|
||||
self.seq_trfs = get_transform(self.engine_cfg['transform'])
|
||||
|
||||
self.msg_mgr.log_info(cfgs['data_cfg'])
|
||||
if training:
|
||||
@@ -299,8 +300,7 @@ class BaseModel(MetaModel, nn.Module):
|
||||
tuple: training data including inputs, labels, and some meta data.
|
||||
"""
|
||||
seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
|
||||
trf_cfgs = self.engine_cfg['transform']
|
||||
seq_trfs = get_transform(trf_cfgs)
|
||||
seq_trfs = self.seq_trfs
|
||||
if len(seqs_batch) != len(seq_trfs):
|
||||
raise ValueError(
|
||||
"The number of types of input data and transform should be same. But got {} and {}".format(len(seqs_batch), len(seq_trfs)))
|
||||
|
||||
Reference in New Issue
Block a user