fix gaitedge config and add check
for num of transform
This commit is contained in:
@@ -301,7 +301,9 @@ class BaseModel(MetaModel, nn.Module):
|
||||
seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
|
||||
trf_cfgs = self.engine_cfg['transform']
|
||||
seq_trfs = get_transform(trf_cfgs)
|
||||
|
||||
if len(seqs_batch) != len(seq_trfs):
|
||||
raise ValueError(
|
||||
"The number of types of input data and transform should be same. But got {} and {}".format(len(seqs_batch), len(seq_trfs)))
|
||||
requires_grad = bool(self.training)
|
||||
seqs = [np2var(np.asarray([trf(fra) for fra in seq]), requires_grad=requires_grad).float()
|
||||
for trf, seq in zip(seq_trfs, seqs_batch)]
|
||||
|
||||
Reference in New Issue
Block a user