Solve the problem of dimension misuse. (#59)
* commit for fix dimension * fix dimension for all method * restore config * clean up baseline config * add contiguous * rm comment
This commit is contained in:
@@ -10,20 +10,20 @@ def cuda_dist(x, y, metric='euc'):
|
||||
x = torch.from_numpy(x).cuda()
|
||||
y = torch.from_numpy(y).cuda()
|
||||
if metric == 'cos':
|
||||
x = F.normalize(x, p=2, dim=2) # n p c
|
||||
y = F.normalize(y, p=2, dim=2) # n p c
|
||||
num_bin = x.size(1)
|
||||
x = F.normalize(x, p=2, dim=1) # n c p
|
||||
y = F.normalize(y, p=2, dim=1) # n c p
|
||||
num_bin = x.size(2)
|
||||
n_x = x.size(0)
|
||||
n_y = y.size(0)
|
||||
dist = torch.zeros(n_x, n_y).cuda()
|
||||
for i in range(num_bin):
|
||||
_x = x[:, i, ...]
|
||||
_y = y[:, i, ...]
|
||||
_x = x[:, :, i]
|
||||
_y = y[:, :, i]
|
||||
if metric == 'cos':
|
||||
dist += torch.matmul(_x, _y.transpose(0, 1))
|
||||
else:
|
||||
_dist = torch.sum(_x ** 2, 1).unsqueeze(1) + torch.sum(_y ** 2, 1).unsqueeze(
|
||||
1).transpose(0, 1) - 2 * torch.matmul(_x, _y.transpose(0, 1))
|
||||
0) - 2 * torch.matmul(_x, _y.transpose(0, 1))
|
||||
dist += torch.sqrt(F.relu(_dist))
|
||||
return 1 - dist/num_bin if metric == 'cos' else dist / num_bin
|
||||
|
||||
|
||||
Reference in New Issue
Block a user