fix gaitedge config and add check

for num of transform
This commit is contained in:
darkliang
2022-11-12 16:52:58 +08:00
parent a71444b967
commit 1588fde52d
4 changed files with 8 additions and 6 deletions
+3 -1
View File
@@ -301,7 +301,9 @@ class BaseModel(MetaModel, nn.Module):
seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
trf_cfgs = self.engine_cfg['transform']
seq_trfs = get_transform(trf_cfgs)
if len(seqs_batch) != len(seq_trfs):
raise ValueError(
"The number of types of input data and transform should be same. But got {} and {}".format(len(seqs_batch), len(seq_trfs)))
requires_grad = bool(self.training)
seqs = [np2var(np.asarray([trf(fra) for fra in seq]), requires_grad=requires_grad).float()
for trf, seq in zip(seq_trfs, seqs_batch)]
+1 -1
View File
@@ -50,7 +50,7 @@ class GaitEdge(GaitGL):
self.is_edge = model_cfg['edge']
self.seg_lr = model_cfg['seg_lr']
self.kernel = torch.ones(
(model_cfg['kernel_size'], model_cfg['kernel_size'])).cuda()
(model_cfg['kernel_size'], model_cfg['kernel_size']))
def finetune_parameters(self):
fine_tune_params = list()