Solve the problem of dimension misuse. (#59)

* commit for fix dimension

* fix dimension for all method

* restore config

* clean up baseline config

* add contiguous

* rm comment
This commit is contained in:
Junhao Liang
2022-06-28 12:27:16 +08:00
committed by GitHub
parent 715e7448fa
commit 14fa5212d4
14 changed files with 99 additions and 121 deletions
+25 -29
View File
@@ -38,14 +38,14 @@ class SetBlockWrapper(nn.Module):
def forward(self, x, *args, **kwargs):
"""
In x: [n, s, c, h, w]
Out x: [n, s, ...]
In x: [n, c_in, s, h_in, w_in]
Out x: [n, c_out, s, h_out, w_out]
"""
n, s, c, h, w = x.size()
x = self.forward_block(x.view(-1, c, h, w), *args, **kwargs)
input_size = x.size()
output_size = [n, s] + [*input_size[1:]]
return x.view(*output_size)
n, c, s, h, w = x.size()
x = self.forward_block(x.transpose(
1, 2).view(-1, c, h, w), *args, **kwargs)
output_size = x.size()
return x.reshape(n, s, *output_size[1:]).transpose(1, 2).contiguous()
class PackSequenceWrapper(nn.Module):
@@ -53,26 +53,20 @@ class PackSequenceWrapper(nn.Module):
super(PackSequenceWrapper, self).__init__()
self.pooling_func = pooling_func
def forward(self, seqs, seqL, seq_dim=1, **kwargs):
def forward(self, seqs, seqL, dim=2, options={}):
"""
In seqs: [n, s, ...]
In seqs: [n, c, s, ...]
Out rets: [n, ...]
"""
if seqL is None:
return self.pooling_func(seqs, **kwargs)
return self.pooling_func(seqs, **options)
seqL = seqL[0].data.cpu().numpy().tolist()
start = [0] + np.cumsum(seqL).tolist()[:-1]
rets = []
for curr_start, curr_seqL in zip(start, seqL):
narrowed_seq = seqs.narrow(seq_dim, curr_start, curr_seqL)
# save the memory
# splited_narrowed_seq = torch.split(narrowed_seq, 256, dim=1)
# ret = []
# for seq_to_pooling in splited_narrowed_seq:
# ret.append(self.pooling_func(seq_to_pooling, keepdim=True, **kwargs)
# [0] if self.is_tuple_result else self.pooling_func(seq_to_pooling, **kwargs))
rets.append(self.pooling_func(narrowed_seq, **kwargs))
narrowed_seq = seqs.narrow(dim, curr_start, curr_seqL)
rets.append(self.pooling_func(narrowed_seq, **options))
if len(rets) > 0 and is_list_or_tuple(rets[0]):
return [torch.cat([ret[j] for ret in rets])
for j in range(len(rets[0]))]
@@ -101,13 +95,15 @@ class SeparateFCs(nn.Module):
def forward(self, x):
"""
x: [p, n, c]
x: [n, c_in, p]
out: [n, c_out, p]
"""
x = x.permute(2, 0, 1).contiguous()
if self.norm:
out = x.matmul(F.normalize(self.fc_bin, dim=1))
else:
out = x.matmul(self.fc_bin)
return out
return out.permute(1, 2, 0).contiguous()
class SeparateBNNecks(nn.Module):
@@ -133,24 +129,24 @@ class SeparateBNNecks(nn.Module):
def forward(self, x):
"""
x: [p, n, c]
x: [n, c, p]
"""
if self.parallel_BN1d:
p, n, c = x.size()
x = x.transpose(0, 1).contiguous().view(n, -1) # [n, p*c]
n, c, p = x.size()
x = x.view(n, -1) # [n, c*p]
x = self.bn1d(x)
x = x.view(n, p, c).permute(1, 0, 2).contiguous()
x = x.view(n, c, p)
else:
x = torch.cat([bn(_.squeeze(0)).unsqueeze(0)
for _, bn in zip(x.split(1, 0), self.bn1d)], 0) # [p, n, c]
x = torch.cat([bn(_x) for _x, bn in zip(
x.split(1, 2), self.bn1d)], 2) # [p, n, c]
feature = x.permute(2, 0, 1).contiguous()
if self.norm:
feature = F.normalize(x, dim=-1) # [p, n, c]
feature = F.normalize(feature, dim=-1) # [p, n, c]
logits = feature.matmul(F.normalize(
self.fc_bin, dim=1)) # [p, n, c]
else:
feature = x
logits = feature.matmul(self.fc_bin)
return feature, logits
return feature.permute(1, 2, 0).contiguous(), logits.permute(1, 2, 0).contiguous()
class FocalConv2d(nn.Module):