Solve the problem of dimension misuse. (#59)
* commit for fix dimension * fix dimension for all method * restore config * clean up baseline config * add contiguous * rm comment
This commit is contained in:
@@ -89,12 +89,12 @@ class GLN(BaseModel):
|
||||
sils = ipts[0] # [n, s, h, w]
|
||||
del ipts
|
||||
if len(sils.size()) == 4:
|
||||
sils = sils.unsqueeze(2)
|
||||
n, s, _, h, w = sils.size()
|
||||
sils = sils.unsqueeze(1)
|
||||
n, _, s, h, w = sils.size()
|
||||
|
||||
### stage 0 sil ###
|
||||
sil_0_outs = self.sil_stage_0(sils)
|
||||
stage_0_sil_set = self.set_pooling(sil_0_outs, seqL, dim=1)[0]
|
||||
stage_0_sil_set = self.set_pooling(sil_0_outs, seqL, options={"dim": 2})[0]
|
||||
|
||||
### stage 1 sil ###
|
||||
sil_1_ipts = self.MaxP_sil(sil_0_outs)
|
||||
@@ -105,13 +105,13 @@ class GLN(BaseModel):
|
||||
sil_2_outs = self.sil_stage_2(sil_2_ipts)
|
||||
|
||||
### stage 1 set ###
|
||||
set_1_ipts = self.set_pooling(sil_1_ipts, seqL, dim=1)[0]
|
||||
stage_1_sil_set = self.set_pooling(sil_1_outs, seqL, dim=1)[0]
|
||||
set_1_ipts = self.set_pooling(sil_1_ipts, seqL, options={"dim": 2})[0]
|
||||
stage_1_sil_set = self.set_pooling(sil_1_outs, seqL, options={"dim": 2})[0]
|
||||
set_1_outs = self.set_stage_1(set_1_ipts) + stage_1_sil_set
|
||||
|
||||
### stage 2 set ###
|
||||
set_2_ipts = self.MaxP_set(set_1_outs)
|
||||
stage_2_sil_set = self.set_pooling(sil_2_outs, seqL, dim=1)[0]
|
||||
stage_2_sil_set = self.set_pooling(sil_2_outs, seqL, options={"dim": 2})[0]
|
||||
set_2_outs = self.set_stage_2(set_2_ipts) + stage_2_sil_set
|
||||
|
||||
set1 = torch.cat((stage_0_sil_set, stage_0_sil_set), dim=1)
|
||||
@@ -133,11 +133,9 @@ class GLN(BaseModel):
|
||||
set2 = self.HPP(set2)
|
||||
set3 = self.HPP(set3)
|
||||
|
||||
feature = torch.cat([set1, set2, set3], -
|
||||
1).permute(2, 0, 1).contiguous()
|
||||
feature = torch.cat([set1, set2, set3], -1)
|
||||
|
||||
feature = self.Head(feature)
|
||||
feature = feature.permute(1, 0, 2).contiguous() # n p c
|
||||
|
||||
# compact_bloack
|
||||
if not self.pretrain:
|
||||
|
||||
Reference in New Issue
Block a user