fix bug within gaitgl model

This commit is contained in:
Noah
2022-04-20 17:26:01 +08:00
parent 3c71141b3e
commit 7172b53af1
+3 -3
View File
@@ -181,17 +181,17 @@ class GaitGL(BaseModel):
gait = gait.permute(1, 2, 0).contiguous() # [n, c, p] gait = gait.permute(1, 2, 0).contiguous() # [n, c, p]
bnft = self.Bn(gait) # [n, c, p] bnft = self.Bn(gait) # [n, c, p]
logi = self.Head1(bnft.permute(2, 0, 1).contiguous()) # [p, n, c] logi = self.Head1(bnft.permute(2, 0, 1).contiguous()) # [p, n, c]
bnft = bnft.permute(0, 2, 1).contiguous() # [n, p, c] embed = bnft.permute(0, 2, 1).contiguous() # [n, p, c]
else: # BNNechk as Head else: # BNNechk as Head
bnft, logi = self.BNNecks(gait) # [p, n, c] bnft, logi = self.BNNecks(gait) # [p, n, c]
bnft = bnft.permute(1, 0, 2).contiguous() # [n, p, c] embed = gait.permute(1, 0, 2).contiguous() # [n, p, c]
logi = logi.permute(1, 0, 2).contiguous() # [n, p, c] logi = logi.permute(1, 0, 2).contiguous() # [n, p, c]
n, _, s, h, w = sils.size() n, _, s, h, w = sils.size()
retval = { retval = {
'training_feat': { 'training_feat': {
'triplet': {'embeddings': bnft, 'labels': labs}, 'triplet': {'embeddings': embed, 'labels': labs},
'softmax': {'logits': logi, 'labels': labs} 'softmax': {'logits': logi, 'labels': labs}
}, },
'visual_summary': { 'visual_summary': {