Solve the problem of dimension misuse. (#59)
* commit for fix dimension * fix dimension for all method * restore config * clean up baseline config * add contiguous * rm comment
This commit is contained in:
@@ -45,12 +45,12 @@ class TemporalFeatureAggregator(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Input: x, [n, s, c, p]
|
||||
Output: ret, [n, p, c]
|
||||
Input: x, [n, c, s, p]
|
||||
Output: ret, [n, c, p]
|
||||
"""
|
||||
n, s, c, p = x.size()
|
||||
x = x.permute(3, 0, 2, 1).contiguous() # [p, n, c, s]
|
||||
feature = x.split(1, 0) # [[n, c, s], ...]
|
||||
n, c, s, p = x.size()
|
||||
x = x.permute(3, 0, 1, 2).contiguous() # [p, n, c, s]
|
||||
feature = x.split(1, 0) # [[1, n, c, s], ...]
|
||||
x = x.view(-1, c, s)
|
||||
|
||||
# MTB1: ConvNet1d & Sigmoid
|
||||
@@ -73,7 +73,7 @@ class TemporalFeatureAggregator(nn.Module):
|
||||
|
||||
# Temporal Pooling
|
||||
ret = self.TP(feature3x1 + feature3x3, dim=-1)[0] # [p, n, c]
|
||||
ret = ret.permute(1, 0, 2).contiguous() # [n, p, c]
|
||||
ret = ret.permute(1, 2, 0).contiguous() # [n, p, c]
|
||||
return ret
|
||||
|
||||
|
||||
@@ -102,17 +102,16 @@ class GaitPart(BaseModel):
|
||||
|
||||
sils = ipts[0]
|
||||
if len(sils.size()) == 4:
|
||||
sils = sils.unsqueeze(2)
|
||||
sils = sils.unsqueeze(1)
|
||||
|
||||
del ipts
|
||||
out = self.Backbone(sils) # [n, s, c, h, w]
|
||||
out = self.HPP(out) # [n, s, c, p]
|
||||
out = self.TFA(out, seqL) # [n, p, c]
|
||||
out = self.Backbone(sils) # [n, c, s, h, w]
|
||||
out = self.HPP(out) # [n, c, s, p]
|
||||
out = self.TFA(out, seqL) # [n, c, p]
|
||||
|
||||
embs = self.Head(out.permute(1, 0, 2).contiguous()) # [p, n, c]
|
||||
embs = embs.permute(1, 0, 2).contiguous() # [n, p, c]
|
||||
embs = self.Head(out) # [n, c, p]
|
||||
|
||||
n, s, _, h, w = sils.size()
|
||||
n, _, s, h, w = sils.size()
|
||||
retval = {
|
||||
'training_feat': {
|
||||
'triplet': {'embeddings': embs, 'labels': labs}
|
||||
|
||||
Reference in New Issue
Block a user