Solve the problem of dimension misuse. (#59)

* commit for fix dimension

* fix dimension for all method

* restore config

* clean up baseline config

* add contiguous

* rm comment
This commit is contained in:
Junhao Liang
2022-06-28 12:27:16 +08:00
committed by GitHub
parent 715e7448fa
commit 14fa5212d4
14 changed files with 99 additions and 121 deletions
+6 -11
View File
@@ -19,26 +19,21 @@ class Baseline(BaseModel):
sils = ipts[0]
if len(sils.size()) == 4:
sils = sils.unsqueeze(2)
sils = sils.unsqueeze(1)
del ipts
outs = self.Backbone(sils) # [n, s, c, h, w]
outs = self.Backbone(sils) # [n, c, s, h, w]
# Temporal Pooling, TP
outs = self.TP(outs, seqL, dim=1)[0] # [n, c, h, w]
outs = self.TP(outs, seqL, options={"dim": 2})[0] # [n, c, h, w]
# Horizontal Pooling Matching, HPM
feat = self.HPP(outs) # [n, c, p]
feat = feat.permute(2, 0, 1).contiguous() # [p, n, c]
embed_1 = self.FCs(feat) # [p, n, c]
embed_2, logits = self.BNNecks(embed_1) # [p, n, c]
embed_1 = embed_1.permute(1, 0, 2).contiguous() # [n, p, c]
embed_2 = embed_2.permute(1, 0, 2).contiguous() # [n, p, c]
logits = logits.permute(1, 0, 2).contiguous() # [n, p, c]
embed_1 = self.FCs(feat) # [n, c, p]
embed_2, logits = self.BNNecks(embed_1) # [n, c, p]
embed = embed_1
n, s, _, h, w = sils.size()
n, _, s, h, w = sils.size()
retval = {
'training_feat': {
'triplet': {'embeddings': embed_1, 'labels': labs},