diff --git a/config/baseline/baseline.yaml b/config/baseline/baseline.yaml index 415bd8a..f03335b 100644 --- a/config/baseline/baseline.yaml +++ b/config/baseline/baseline.yaml @@ -46,9 +46,6 @@ model_cfg: - M - BC-256 - BC-256 - # - M - # - BC-512 - # - BC-512 type: Plain SeparateFCs: in_channels: 256 @@ -81,6 +78,7 @@ trainer_cfg: enable_float16: true # half_percesion float for memory reduction and speedup fix_BN: false log_iter: 100 + with_test: true restore_ckpt_strict: true restore_hint: 0 save_iter: 10000 diff --git a/config/gaitgl/gaitgl.yaml b/config/gaitgl/gaitgl.yaml index 31e1ffc..d5365d9 100644 --- a/config/gaitgl/gaitgl.yaml +++ b/config/gaitgl/gaitgl.yaml @@ -47,8 +47,8 @@ scheduler_cfg: scheduler: MultiStepLR trainer_cfg: - enable_distributed: true - enable_float16: false + enable_float16: true + with_test: true log_iter: 100 restore_ckpt_strict: true restore_hint: 0 diff --git a/config/gaitpart/gaitpart.yaml b/config/gaitpart/gaitpart.yaml index 6143be3..4af787b 100644 --- a/config/gaitpart/gaitpart.yaml +++ b/config/gaitpart/gaitpart.yaml @@ -59,6 +59,7 @@ scheduler_cfg: trainer_cfg: enable_float16: true log_iter: 100 + with_test: true restore_ckpt_strict: true restore_hint: 0 save_iter: 10000 diff --git a/config/gaitset/gaitset.yaml b/config/gaitset/gaitset.yaml index cc84145..da7a0b7 100644 --- a/config/gaitset/gaitset.yaml +++ b/config/gaitset/gaitset.yaml @@ -57,12 +57,13 @@ scheduler_cfg: trainer_cfg: enable_float16: true log_iter: 100 + with_test: true restore_ckpt_strict: true restore_hint: 0 save_iter: 10000 save_name: GaitSet sync_BN: false - total_iter: 42000 + total_iter: 40000 sampler: batch_shuffle: false batch_size: diff --git a/config/gln/gln_phase1.yaml b/config/gln/gln_phase1.yaml index e2e81f5..db36c5f 100644 --- a/config/gln/gln_phase1.yaml +++ b/config/gln/gln_phase1.yaml @@ -74,10 +74,9 @@ scheduler_cfg: scheduler: MultiStepLR trainer_cfg: - enable_distributed: true enable_float16: true fix_layers: false - with_test: false + with_test: true log_iter: 100 optimizer_reset: false restore_ckpt_strict: true diff --git a/opengait/modeling/losses/softmax.py b/opengait/modeling/losses/softmax.py index a955a02..1d9502a 100644 --- a/opengait/modeling/losses/softmax.py +++ b/opengait/modeling/losses/softmax.py @@ -14,31 +14,29 @@ class CrossEntropyLoss(BaseLoss): def forward(self, logits, labels): """ - logits: [n, p, c] + logits: [n, c, p] labels: [n] """ - logits = logits.permute(1, 0, 2).contiguous() # [n, p, c] -> [p, n, c] - p, _, c = logits.size() - log_preds = F.log_softmax(logits * self.scale, dim=-1) # [p, n, c] + n, c, p = logits.size() + log_preds = F.log_softmax(logits * self.scale, dim=1) # [n, c, p] one_hot_labels = self.label2one_hot( - labels, c).unsqueeze(0).repeat(p, 1, 1) # [p, n, c] + labels, c).unsqueeze(2).repeat(1, 1, p) # [n, c, p] loss = self.compute_loss(log_preds, one_hot_labels) self.info.update({'loss': loss.detach().clone()}) if self.log_accuracy: - pred = logits.argmax(dim=-1) # [p, n] - accu = (pred == labels.unsqueeze(0)).float().mean() + pred = logits.argmax(dim=1) # [n, p] + accu = (pred == labels.unsqueeze(1)).float().mean() self.info.update({'accuracy': accu}) return loss, self.info def compute_loss(self, predis, labels): - softmax_loss = -(labels * predis).sum(-1) # [p, n] - losses = softmax_loss.mean(-1) + softmax_loss = -(labels * predis).sum(1) # [n, p] + losses = softmax_loss.mean(0) # [p] if self.label_smooth: - smooth_loss = - predis.mean(dim=-1) # [p, n] - smooth_loss = smooth_loss.mean() # [p] - smooth_loss = smooth_loss * self.eps - losses = smooth_loss + losses * (1. - self.eps) + smooth_loss = - predis.mean(dim=1) # [n, p] + smooth_loss = smooth_loss.mean(0) # [p] + losses = smooth_loss * self.eps + losses * (1. - self.eps) return losses def label2one_hot(self, label, class_num): diff --git a/opengait/modeling/losses/triplet.py b/opengait/modeling/losses/triplet.py index 60348f9..8445d9a 100644 --- a/opengait/modeling/losses/triplet.py +++ b/opengait/modeling/losses/triplet.py @@ -11,14 +11,13 @@ class TripletLoss(BaseLoss): @gather_and_scale_wrapper def forward(self, embeddings, labels): - # embeddings: [n, p, c], label: [n] + # embeddings: [n, c, p], label: [n] embeddings = embeddings.permute( - 1, 0, 2).contiguous() # [n, p, c] -> [p, n, c] - embeddings = embeddings.float() + 2, 0, 1).contiguous().float() # [n, c, p] -> [p, n, c] ref_embed, ref_label = embeddings, labels dist = self.ComputeDistance(embeddings, ref_embed) # [p, n1, n2] - mean_dist = dist.mean(1).mean(1) + mean_dist = dist.mean((1, 2)) # [p] ap_dist, an_dist = self.Convert2Triplets(labels, ref_label, dist) dist_diff = (ap_dist - an_dist).view(dist.size(0), -1) loss = F.relu(dist_diff + self.margin) @@ -50,7 +49,7 @@ class TripletLoss(BaseLoss): """ x2 = torch.sum(x ** 2, -1).unsqueeze(2) # [p, n_x, 1] y2 = torch.sum(y ** 2, -1).unsqueeze(1) # [p, 1, n_y] - inner = x.matmul(y.transpose(-1, -2)) # [p, n_x, n_y] + inner = x.matmul(y.transpose(1, 2)) # [p, n_x, n_y] dist = x2 + y2 - 2 * inner dist = torch.sqrt(F.relu(dist)) # [p, n_x, n_y] return dist @@ -60,9 +59,10 @@ class TripletLoss(BaseLoss): row_labels: tensor with size [n_r] clo_label : tensor with size [n_c] """ - matches = (row_labels.unsqueeze(1) == clo_label.unsqueeze(0)).bool() # [n_r, n_c] - diffenc = torch.logical_not(matches) # [n_r, n_c] - p, n, m = dist.size() + matches = (row_labels.unsqueeze(1) == + clo_label.unsqueeze(0)).bool() # [n_r, n_c] + diffenc = torch.logical_not(matches) # [n_r, n_c] + p, n, _ = dist.size() ap_dist = dist[:, matches].view(p, n, -1, 1) an_dist = dist[:, diffenc].view(p, n, 1, -1) return ap_dist, an_dist diff --git a/opengait/modeling/models/baseline.py b/opengait/modeling/models/baseline.py index febcfeb..4e1c72f 100644 --- a/opengait/modeling/models/baseline.py +++ b/opengait/modeling/models/baseline.py @@ -19,26 +19,21 @@ class Baseline(BaseModel): sils = ipts[0] if len(sils.size()) == 4: - sils = sils.unsqueeze(2) + sils = sils.unsqueeze(1) del ipts - outs = self.Backbone(sils) # [n, s, c, h, w] + outs = self.Backbone(sils) # [n, c, s, h, w] # Temporal Pooling, TP - outs = self.TP(outs, seqL, dim=1)[0] # [n, c, h, w] + outs = self.TP(outs, seqL, options={"dim": 2})[0] # [n, c, h, w] # Horizontal Pooling Matching, HPM feat = self.HPP(outs) # [n, c, p] - feat = feat.permute(2, 0, 1).contiguous() # [p, n, c] - embed_1 = self.FCs(feat) # [p, n, c] - embed_2, logits = self.BNNecks(embed_1) # [p, n, c] - - embed_1 = embed_1.permute(1, 0, 2).contiguous() # [n, p, c] - embed_2 = embed_2.permute(1, 0, 2).contiguous() # [n, p, c] - logits = logits.permute(1, 0, 2).contiguous() # [n, p, c] + embed_1 = self.FCs(feat) # [n, c, p] + embed_2, logits = self.BNNecks(embed_1) # [n, c, p] embed = embed_1 - n, s, _, h, w = sils.size() + n, _, s, h, w = sils.size() retval = { 'training_feat': { 'triplet': {'embeddings': embed_1, 'labels': labs}, diff --git a/opengait/modeling/models/gaitgl.py b/opengait/modeling/models/gaitgl.py index 75e6ddc..19953f5 100644 --- a/opengait/modeling/models/gaitgl.py +++ b/opengait/modeling/models/gaitgl.py @@ -75,7 +75,7 @@ class GaitGL(BaseModel): class_num = model_cfg['class_num'] dataset_name = self.cfgs['data_cfg']['dataset_name'] - if dataset_name in ['OUMVLP','GREW']: + if dataset_name in ['OUMVLP', 'GREW']: # For OUMVLP and GREW self.conv3d = nn.Sequential( BasicConv3d(1, in_c[0], kernel_size=(3, 3, 3), @@ -135,12 +135,11 @@ class GaitGL(BaseModel): self.GLConvB2 = GLConv(in_c[2], in_c[2], halving=3, fm_sign=True, kernel_size=( 3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)) - self.TP = PackSequenceWrapper(torch.max) self.HPP = GeMHPP() - + self.Head0 = SeparateFCs(64, in_c[-1], in_c[-1]) - + if 'SeparateBNNecks' in model_cfg.keys(): self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks']) self.Bn_head = False @@ -171,22 +170,18 @@ class GaitGL(BaseModel): outs = self.GLConvA1(outs) outs = self.GLConvB2(outs) # [n, c, s, h, w] - outs = self.TP(outs, dim=2, seq_dim=2, seqL=seqL)[0] # [n, c, h, w] + outs = self.TP(outs, seqL=seqL, options={"dim": 2})[0] # [n, c, h, w] outs = self.HPP(outs) # [n, c, p] - outs = outs.permute(2, 0, 1).contiguous() # [p, n, c] - gait = self.Head0(outs) # [p, n, c] - - if self.Bn_head: # Original GaitGL Head - gait = gait.permute(1, 2, 0).contiguous() # [n, c, p] + gait = self.Head0(outs) # [n, c, p] + + if self.Bn_head: # Original GaitGL Head bnft = self.Bn(gait) # [n, c, p] - logi = self.Head1(bnft.permute(2, 0, 1).contiguous()) # [p, n, c] - embed = bnft.permute(0, 2, 1).contiguous() # [n, p, c] - else: # BNNechk as Head - bnft, logi = self.BNNecks(gait) # [p, n, c] - embed = gait.permute(1, 0, 2).contiguous() # [n, p, c] - - logi = logi.permute(1, 0, 2).contiguous() # [n, p, c] + logi = self.Head1(bnft) # [n, c, p] + embed = bnft + else: # BNNechk as Head + bnft, logi = self.BNNecks(gait) # [n, c, p] + embed = gait n, _, s, h, w = sils.size() retval = { diff --git a/opengait/modeling/models/gaitpart.py b/opengait/modeling/models/gaitpart.py index 242dcaa..3dcef4e 100644 --- a/opengait/modeling/models/gaitpart.py +++ b/opengait/modeling/models/gaitpart.py @@ -45,12 +45,12 @@ class TemporalFeatureAggregator(nn.Module): def forward(self, x): """ - Input: x, [n, s, c, p] - Output: ret, [n, p, c] + Input: x, [n, c, s, p] + Output: ret, [n, c, p] """ - n, s, c, p = x.size() - x = x.permute(3, 0, 2, 1).contiguous() # [p, n, c, s] - feature = x.split(1, 0) # [[n, c, s], ...] + n, c, s, p = x.size() + x = x.permute(3, 0, 1, 2).contiguous() # [p, n, c, s] + feature = x.split(1, 0) # [[1, n, c, s], ...] x = x.view(-1, c, s) # MTB1: ConvNet1d & Sigmoid @@ -73,7 +73,7 @@ class TemporalFeatureAggregator(nn.Module): # Temporal Pooling ret = self.TP(feature3x1 + feature3x3, dim=-1)[0] # [p, n, c] - ret = ret.permute(1, 0, 2).contiguous() # [n, p, c] + ret = ret.permute(1, 2, 0).contiguous() # [n, p, c] return ret @@ -102,17 +102,16 @@ class GaitPart(BaseModel): sils = ipts[0] if len(sils.size()) == 4: - sils = sils.unsqueeze(2) + sils = sils.unsqueeze(1) del ipts - out = self.Backbone(sils) # [n, s, c, h, w] - out = self.HPP(out) # [n, s, c, p] - out = self.TFA(out, seqL) # [n, p, c] + out = self.Backbone(sils) # [n, c, s, h, w] + out = self.HPP(out) # [n, c, s, p] + out = self.TFA(out, seqL) # [n, c, p] - embs = self.Head(out.permute(1, 0, 2).contiguous()) # [p, n, c] - embs = embs.permute(1, 0, 2).contiguous() # [n, p, c] + embs = self.Head(out) # [n, c, p] - n, s, _, h, w = sils.size() + n, _, s, h, w = sils.size() retval = { 'training_feat': { 'triplet': {'embeddings': embs, 'labels': labs} diff --git a/opengait/modeling/models/gaitset.py b/opengait/modeling/models/gaitset.py index 9613bef..3ef45bc 100644 --- a/opengait/modeling/models/gaitset.py +++ b/opengait/modeling/models/gaitset.py @@ -49,30 +49,28 @@ class GaitSet(BaseModel): ipts, labs, _, _, seqL = inputs sils = ipts[0] # [n, s, h, w] if len(sils.size()) == 4: - sils = sils.unsqueeze(2) + sils = sils.unsqueeze(1) del ipts outs = self.set_block1(sils) - gl = self.set_pooling(outs, seqL, dim=1)[0] + gl = self.set_pooling(outs, seqL, options={"dim": 2})[0] gl = self.gl_block2(gl) outs = self.set_block2(outs) - gl = gl + self.set_pooling(outs, seqL, dim=1)[0] + gl = gl + self.set_pooling(outs, seqL, options={"dim": 2})[0] gl = self.gl_block3(gl) outs = self.set_block3(outs) - outs = self.set_pooling(outs, seqL, dim=1)[0] + outs = self.set_pooling(outs, seqL, options={"dim": 2})[0] gl = gl + outs # Horizontal Pooling Matching, HPM feature1 = self.HPP(outs) # [n, c, p] feature2 = self.HPP(gl) # [n, c, p] feature = torch.cat([feature1, feature2], -1) # [n, c, p] - feature = feature.permute(2, 0, 1).contiguous() # [p, n, c] embs = self.Head(feature) - embs = embs.permute(1, 0, 2).contiguous() # [n, p, c] - n, s, _, h, w = sils.size() + n, _, s, h, w = sils.size() retval = { 'training_feat': { 'triplet': {'embeddings': embs, 'labels': labs} diff --git a/opengait/modeling/models/gln.py b/opengait/modeling/models/gln.py index 8e660a0..3f57f01 100644 --- a/opengait/modeling/models/gln.py +++ b/opengait/modeling/models/gln.py @@ -89,12 +89,12 @@ class GLN(BaseModel): sils = ipts[0] # [n, s, h, w] del ipts if len(sils.size()) == 4: - sils = sils.unsqueeze(2) - n, s, _, h, w = sils.size() + sils = sils.unsqueeze(1) + n, _, s, h, w = sils.size() ### stage 0 sil ### sil_0_outs = self.sil_stage_0(sils) - stage_0_sil_set = self.set_pooling(sil_0_outs, seqL, dim=1)[0] + stage_0_sil_set = self.set_pooling(sil_0_outs, seqL, options={"dim": 2})[0] ### stage 1 sil ### sil_1_ipts = self.MaxP_sil(sil_0_outs) @@ -105,13 +105,13 @@ class GLN(BaseModel): sil_2_outs = self.sil_stage_2(sil_2_ipts) ### stage 1 set ### - set_1_ipts = self.set_pooling(sil_1_ipts, seqL, dim=1)[0] - stage_1_sil_set = self.set_pooling(sil_1_outs, seqL, dim=1)[0] + set_1_ipts = self.set_pooling(sil_1_ipts, seqL, options={"dim": 2})[0] + stage_1_sil_set = self.set_pooling(sil_1_outs, seqL, options={"dim": 2})[0] set_1_outs = self.set_stage_1(set_1_ipts) + stage_1_sil_set ### stage 2 set ### set_2_ipts = self.MaxP_set(set_1_outs) - stage_2_sil_set = self.set_pooling(sil_2_outs, seqL, dim=1)[0] + stage_2_sil_set = self.set_pooling(sil_2_outs, seqL, options={"dim": 2})[0] set_2_outs = self.set_stage_2(set_2_ipts) + stage_2_sil_set set1 = torch.cat((stage_0_sil_set, stage_0_sil_set), dim=1) @@ -133,11 +133,9 @@ class GLN(BaseModel): set2 = self.HPP(set2) set3 = self.HPP(set3) - feature = torch.cat([set1, set2, set3], - - 1).permute(2, 0, 1).contiguous() + feature = torch.cat([set1, set2, set3], -1) feature = self.Head(feature) - feature = feature.permute(1, 0, 2).contiguous() # n p c # compact_bloack if not self.pretrain: diff --git a/opengait/modeling/modules.py b/opengait/modeling/modules.py index a5d2b6b..38e81c2 100644 --- a/opengait/modeling/modules.py +++ b/opengait/modeling/modules.py @@ -38,14 +38,14 @@ class SetBlockWrapper(nn.Module): def forward(self, x, *args, **kwargs): """ - In x: [n, s, c, h, w] - Out x: [n, s, ...] + In x: [n, c_in, s, h_in, w_in] + Out x: [n, c_out, s, h_out, w_out] """ - n, s, c, h, w = x.size() - x = self.forward_block(x.view(-1, c, h, w), *args, **kwargs) - input_size = x.size() - output_size = [n, s] + [*input_size[1:]] - return x.view(*output_size) + n, c, s, h, w = x.size() + x = self.forward_block(x.transpose( + 1, 2).view(-1, c, h, w), *args, **kwargs) + output_size = x.size() + return x.reshape(n, s, *output_size[1:]).transpose(1, 2).contiguous() class PackSequenceWrapper(nn.Module): @@ -53,26 +53,20 @@ class PackSequenceWrapper(nn.Module): super(PackSequenceWrapper, self).__init__() self.pooling_func = pooling_func - def forward(self, seqs, seqL, seq_dim=1, **kwargs): + def forward(self, seqs, seqL, dim=2, options={}): """ - In seqs: [n, s, ...] + In seqs: [n, c, s, ...] Out rets: [n, ...] """ if seqL is None: - return self.pooling_func(seqs, **kwargs) + return self.pooling_func(seqs, **options) seqL = seqL[0].data.cpu().numpy().tolist() start = [0] + np.cumsum(seqL).tolist()[:-1] rets = [] for curr_start, curr_seqL in zip(start, seqL): - narrowed_seq = seqs.narrow(seq_dim, curr_start, curr_seqL) - # save the memory - # splited_narrowed_seq = torch.split(narrowed_seq, 256, dim=1) - # ret = [] - # for seq_to_pooling in splited_narrowed_seq: - # ret.append(self.pooling_func(seq_to_pooling, keepdim=True, **kwargs) - # [0] if self.is_tuple_result else self.pooling_func(seq_to_pooling, **kwargs)) - rets.append(self.pooling_func(narrowed_seq, **kwargs)) + narrowed_seq = seqs.narrow(dim, curr_start, curr_seqL) + rets.append(self.pooling_func(narrowed_seq, **options)) if len(rets) > 0 and is_list_or_tuple(rets[0]): return [torch.cat([ret[j] for ret in rets]) for j in range(len(rets[0]))] @@ -101,13 +95,15 @@ class SeparateFCs(nn.Module): def forward(self, x): """ - x: [p, n, c] + x: [n, c_in, p] + out: [n, c_out, p] """ + x = x.permute(2, 0, 1).contiguous() if self.norm: out = x.matmul(F.normalize(self.fc_bin, dim=1)) else: out = x.matmul(self.fc_bin) - return out + return out.permute(1, 2, 0).contiguous() class SeparateBNNecks(nn.Module): @@ -133,24 +129,24 @@ class SeparateBNNecks(nn.Module): def forward(self, x): """ - x: [p, n, c] + x: [n, c, p] """ if self.parallel_BN1d: - p, n, c = x.size() - x = x.transpose(0, 1).contiguous().view(n, -1) # [n, p*c] + n, c, p = x.size() + x = x.view(n, -1) # [n, c*p] x = self.bn1d(x) - x = x.view(n, p, c).permute(1, 0, 2).contiguous() + x = x.view(n, c, p) else: - x = torch.cat([bn(_.squeeze(0)).unsqueeze(0) - for _, bn in zip(x.split(1, 0), self.bn1d)], 0) # [p, n, c] + x = torch.cat([bn(_x) for _x, bn in zip( + x.split(1, 2), self.bn1d)], 2) # [p, n, c] + feature = x.permute(2, 0, 1).contiguous() if self.norm: - feature = F.normalize(x, dim=-1) # [p, n, c] + feature = F.normalize(feature, dim=-1) # [p, n, c] logits = feature.matmul(F.normalize( self.fc_bin, dim=1)) # [p, n, c] else: - feature = x logits = feature.matmul(self.fc_bin) - return feature, logits + return feature.permute(1, 2, 0).contiguous(), logits.permute(1, 2, 0).contiguous() class FocalConv2d(nn.Module): diff --git a/opengait/utils/evaluation.py b/opengait/utils/evaluation.py index fde4d17..3005ad7 100644 --- a/opengait/utils/evaluation.py +++ b/opengait/utils/evaluation.py @@ -10,20 +10,20 @@ def cuda_dist(x, y, metric='euc'): x = torch.from_numpy(x).cuda() y = torch.from_numpy(y).cuda() if metric == 'cos': - x = F.normalize(x, p=2, dim=2) # n p c - y = F.normalize(y, p=2, dim=2) # n p c - num_bin = x.size(1) + x = F.normalize(x, p=2, dim=1) # n c p + y = F.normalize(y, p=2, dim=1) # n c p + num_bin = x.size(2) n_x = x.size(0) n_y = y.size(0) dist = torch.zeros(n_x, n_y).cuda() for i in range(num_bin): - _x = x[:, i, ...] - _y = y[:, i, ...] + _x = x[:, :, i] + _y = y[:, :, i] if metric == 'cos': dist += torch.matmul(_x, _y.transpose(0, 1)) else: _dist = torch.sum(_x ** 2, 1).unsqueeze(1) + torch.sum(_y ** 2, 1).unsqueeze( - 1).transpose(0, 1) - 2 * torch.matmul(_x, _y.transpose(0, 1)) + 0) - 2 * torch.matmul(_x, _y.transpose(0, 1)) dist += torch.sqrt(F.relu(_dist)) return 1 - dist/num_bin if metric == 'cos' else dist / num_bin