Solve the problem of dimension misuse. (#59)

* commit for fix dimension

* fix dimension for all method

* restore config

* clean up baseline config

* add contiguous

* rm comment
This commit is contained in:
Junhao Liang
2022-06-28 12:27:16 +08:00
committed by GitHub
parent 715e7448fa
commit 14fa5212d4
14 changed files with 99 additions and 121 deletions
+6 -6
View File
@@ -10,20 +10,20 @@ def cuda_dist(x, y, metric='euc'):
x = torch.from_numpy(x).cuda()
y = torch.from_numpy(y).cuda()
if metric == 'cos':
x = F.normalize(x, p=2, dim=2) # n p c
y = F.normalize(y, p=2, dim=2) # n p c
num_bin = x.size(1)
x = F.normalize(x, p=2, dim=1) # n c p
y = F.normalize(y, p=2, dim=1) # n c p
num_bin = x.size(2)
n_x = x.size(0)
n_y = y.size(0)
dist = torch.zeros(n_x, n_y).cuda()
for i in range(num_bin):
_x = x[:, i, ...]
_y = y[:, i, ...]
_x = x[:, :, i]
_y = y[:, :, i]
if metric == 'cos':
dist += torch.matmul(_x, _y.transpose(0, 1))
else:
_dist = torch.sum(_x ** 2, 1).unsqueeze(1) + torch.sum(_y ** 2, 1).unsqueeze(
1).transpose(0, 1) - 2 * torch.matmul(_x, _y.transpose(0, 1))
0) - 2 * torch.matmul(_x, _y.transpose(0, 1))
dist += torch.sqrt(F.relu(_dist))
return 1 - dist/num_bin if metric == 'cos' else dist / num_bin