From 1588fde52dc9e709e5acc3a33a7fd6b4c3142c35 Mon Sep 17 00:00:00 2001 From: darkliang <11710911@mail.sustech.edu.cn> Date: Sat, 12 Nov 2022 16:52:58 +0800 Subject: [PATCH] fix gaitedge config and add check for num of transform --- configs/gaitedge/phase2_e2e.yaml | 4 ++-- configs/gaitedge/phase2_gaitedge.yaml | 4 ++-- opengait/modeling/base_model.py | 4 +++- opengait/modeling/models/gaitedge.py | 2 +- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/configs/gaitedge/phase2_e2e.yaml b/configs/gaitedge/phase2_e2e.yaml index 7db604e..6fea2c9 100644 --- a/configs/gaitedge/phase2_e2e.yaml +++ b/configs/gaitedge/phase2_e2e.yaml @@ -70,8 +70,8 @@ trainer_cfg: scheduler_reset: true sync_BN: true restore_hint: - - /home/leeeung/workspace/OpenGait/output/CASIA-B_new/Segmentation/Segmentation/checkpoints/Segmentation-25000.pt - - /home/leeeung/OpenGait/output/CASIA-B_new/GaitGL/GaitGL/checkpoints/GaitGL-80000.pt + - Segmentation-25000.pt + - GaitGL-80000.pt save_iter: 2000 save_name: GaitGL_E2E total_iter: 20000 diff --git a/configs/gaitedge/phase2_gaitedge.yaml b/configs/gaitedge/phase2_gaitedge.yaml index 32979e1..efc5fe9 100644 --- a/configs/gaitedge/phase2_gaitedge.yaml +++ b/configs/gaitedge/phase2_gaitedge.yaml @@ -70,8 +70,8 @@ trainer_cfg: optimizer_reset: true scheduler_reset: true sync_BN: true - restore_hint: - - Segmentation-30000.pt + restore_hint: + - Segmentation-25000.pt - GaitGL-80000.pt save_iter: 2000 save_name: GaitEdge diff --git a/opengait/modeling/base_model.py b/opengait/modeling/base_model.py index 8a0670b..cf3fdea 100644 --- a/opengait/modeling/base_model.py +++ b/opengait/modeling/base_model.py @@ -301,7 +301,9 @@ class BaseModel(MetaModel, nn.Module): seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs trf_cfgs = self.engine_cfg['transform'] seq_trfs = get_transform(trf_cfgs) - + if len(seqs_batch) != len(seq_trfs): + raise ValueError( + "The number of types of input data and transform should be same. But got {} and {}".format(len(seqs_batch), len(seq_trfs))) requires_grad = bool(self.training) seqs = [np2var(np.asarray([trf(fra) for fra in seq]), requires_grad=requires_grad).float() for trf, seq in zip(seq_trfs, seqs_batch)] diff --git a/opengait/modeling/models/gaitedge.py b/opengait/modeling/models/gaitedge.py index 9d34dd7..0f52f10 100644 --- a/opengait/modeling/models/gaitedge.py +++ b/opengait/modeling/models/gaitedge.py @@ -50,7 +50,7 @@ class GaitEdge(GaitGL): self.is_edge = model_cfg['edge'] self.seg_lr = model_cfg['seg_lr'] self.kernel = torch.ones( - (model_cfg['kernel_size'], model_cfg['kernel_size'])).cuda() + (model_cfg['kernel_size'], model_cfg['kernel_size'])) def finetune_parameters(self): fine_tune_params = list()