fix bug within gaitgl model

This commit is contained in:
Noah
2022-04-20 17:26:01 +08:00
parent 3c71141b3e
commit 7172b53af1
+4 -4
View File
@@ -181,17 +181,17 @@ class GaitGL(BaseModel):
gait = gait.permute(1, 2, 0).contiguous() # [n, c, p]
bnft = self.Bn(gait) # [n, c, p]
logi = self.Head1(bnft.permute(2, 0, 1).contiguous()) # [p, n, c]
bnft = bnft.permute(0, 2, 1).contiguous() # [n, p, c]
embed = bnft.permute(0, 2, 1).contiguous() # [n, p, c]
else: # BNNechk as Head
bnft, logi = self.BNNecks(gait) # [p, n, c]
bnft = bnft.permute(1, 0, 2).contiguous() # [n, p, c]
bnft, logi = self.BNNecks(gait) # [p, n, c]
embed = gait.permute(1, 0, 2).contiguous() # [n, p, c]
logi = logi.permute(1, 0, 2).contiguous() # [n, p, c]
n, _, s, h, w = sils.size()
retval = {
'training_feat': {
'triplet': {'embeddings': bnft, 'labels': labs},
'triplet': {'embeddings': embed, 'labels': labs},
'softmax': {'logits': logi, 'labels': labs}
},
'visual_summary': {