diff --git a/opengait/modeling/models/gaitgl.py b/opengait/modeling/models/gaitgl.py index d4218b5..75e6ddc 100644 --- a/opengait/modeling/models/gaitgl.py +++ b/opengait/modeling/models/gaitgl.py @@ -181,17 +181,17 @@ class GaitGL(BaseModel): gait = gait.permute(1, 2, 0).contiguous() # [n, c, p] bnft = self.Bn(gait) # [n, c, p] logi = self.Head1(bnft.permute(2, 0, 1).contiguous()) # [p, n, c] - bnft = bnft.permute(0, 2, 1).contiguous() # [n, p, c] + embed = bnft.permute(0, 2, 1).contiguous() # [n, p, c] else: # BNNechk as Head - bnft, logi = self.BNNecks(gait) # [p, n, c] - bnft = bnft.permute(1, 0, 2).contiguous() # [n, p, c] + bnft, logi = self.BNNecks(gait) # [p, n, c] + embed = gait.permute(1, 0, 2).contiguous() # [n, p, c] logi = logi.permute(1, 0, 2).contiguous() # [n, p, c] n, _, s, h, w = sils.size() retval = { 'training_feat': { - 'triplet': {'embeddings': bnft, 'labels': labs}, + 'triplet': {'embeddings': embed, 'labels': labs}, 'softmax': {'logits': logi, 'labels': labs} }, 'visual_summary': {