fix gaitedge config and add check
for num of transform
This commit is contained in:
@@ -70,8 +70,8 @@ trainer_cfg:
|
|||||||
scheduler_reset: true
|
scheduler_reset: true
|
||||||
sync_BN: true
|
sync_BN: true
|
||||||
restore_hint:
|
restore_hint:
|
||||||
- /home/leeeung/workspace/OpenGait/output/CASIA-B_new/Segmentation/Segmentation/checkpoints/Segmentation-25000.pt
|
- Segmentation-25000.pt
|
||||||
- /home/leeeung/OpenGait/output/CASIA-B_new/GaitGL/GaitGL/checkpoints/GaitGL-80000.pt
|
- GaitGL-80000.pt
|
||||||
save_iter: 2000
|
save_iter: 2000
|
||||||
save_name: GaitGL_E2E
|
save_name: GaitGL_E2E
|
||||||
total_iter: 20000
|
total_iter: 20000
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ trainer_cfg:
|
|||||||
scheduler_reset: true
|
scheduler_reset: true
|
||||||
sync_BN: true
|
sync_BN: true
|
||||||
restore_hint:
|
restore_hint:
|
||||||
- Segmentation-30000.pt
|
- Segmentation-25000.pt
|
||||||
- GaitGL-80000.pt
|
- GaitGL-80000.pt
|
||||||
save_iter: 2000
|
save_iter: 2000
|
||||||
save_name: GaitEdge
|
save_name: GaitEdge
|
||||||
|
|||||||
@@ -301,7 +301,9 @@ class BaseModel(MetaModel, nn.Module):
|
|||||||
seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
|
seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
|
||||||
trf_cfgs = self.engine_cfg['transform']
|
trf_cfgs = self.engine_cfg['transform']
|
||||||
seq_trfs = get_transform(trf_cfgs)
|
seq_trfs = get_transform(trf_cfgs)
|
||||||
|
if len(seqs_batch) != len(seq_trfs):
|
||||||
|
raise ValueError(
|
||||||
|
"The number of types of input data and transform should be same. But got {} and {}".format(len(seqs_batch), len(seq_trfs)))
|
||||||
requires_grad = bool(self.training)
|
requires_grad = bool(self.training)
|
||||||
seqs = [np2var(np.asarray([trf(fra) for fra in seq]), requires_grad=requires_grad).float()
|
seqs = [np2var(np.asarray([trf(fra) for fra in seq]), requires_grad=requires_grad).float()
|
||||||
for trf, seq in zip(seq_trfs, seqs_batch)]
|
for trf, seq in zip(seq_trfs, seqs_batch)]
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ class GaitEdge(GaitGL):
|
|||||||
self.is_edge = model_cfg['edge']
|
self.is_edge = model_cfg['edge']
|
||||||
self.seg_lr = model_cfg['seg_lr']
|
self.seg_lr = model_cfg['seg_lr']
|
||||||
self.kernel = torch.ones(
|
self.kernel = torch.ones(
|
||||||
(model_cfg['kernel_size'], model_cfg['kernel_size'])).cuda()
|
(model_cfg['kernel_size'], model_cfg['kernel_size']))
|
||||||
|
|
||||||
def finetune_parameters(self):
|
def finetune_parameters(self):
|
||||||
fine_tune_params = list()
|
fine_tune_params = list()
|
||||||
|
|||||||
Reference in New Issue
Block a user