Solve the problem of dimension misuse. (#59)
* commit for fix dimension * fix dimension for all method * restore config * clean up baseline config * add contiguous * rm comment
This commit is contained in:
@@ -11,14 +11,13 @@ class TripletLoss(BaseLoss):
|
||||
|
||||
@gather_and_scale_wrapper
|
||||
def forward(self, embeddings, labels):
|
||||
# embeddings: [n, p, c], label: [n]
|
||||
# embeddings: [n, c, p], label: [n]
|
||||
embeddings = embeddings.permute(
|
||||
1, 0, 2).contiguous() # [n, p, c] -> [p, n, c]
|
||||
embeddings = embeddings.float()
|
||||
2, 0, 1).contiguous().float() # [n, c, p] -> [p, n, c]
|
||||
|
||||
ref_embed, ref_label = embeddings, labels
|
||||
dist = self.ComputeDistance(embeddings, ref_embed) # [p, n1, n2]
|
||||
mean_dist = dist.mean(1).mean(1)
|
||||
mean_dist = dist.mean((1, 2)) # [p]
|
||||
ap_dist, an_dist = self.Convert2Triplets(labels, ref_label, dist)
|
||||
dist_diff = (ap_dist - an_dist).view(dist.size(0), -1)
|
||||
loss = F.relu(dist_diff + self.margin)
|
||||
@@ -50,7 +49,7 @@ class TripletLoss(BaseLoss):
|
||||
"""
|
||||
x2 = torch.sum(x ** 2, -1).unsqueeze(2) # [p, n_x, 1]
|
||||
y2 = torch.sum(y ** 2, -1).unsqueeze(1) # [p, 1, n_y]
|
||||
inner = x.matmul(y.transpose(-1, -2)) # [p, n_x, n_y]
|
||||
inner = x.matmul(y.transpose(1, 2)) # [p, n_x, n_y]
|
||||
dist = x2 + y2 - 2 * inner
|
||||
dist = torch.sqrt(F.relu(dist)) # [p, n_x, n_y]
|
||||
return dist
|
||||
@@ -60,9 +59,10 @@ class TripletLoss(BaseLoss):
|
||||
row_labels: tensor with size [n_r]
|
||||
clo_label : tensor with size [n_c]
|
||||
"""
|
||||
matches = (row_labels.unsqueeze(1) == clo_label.unsqueeze(0)).bool() # [n_r, n_c]
|
||||
diffenc = torch.logical_not(matches) # [n_r, n_c]
|
||||
p, n, m = dist.size()
|
||||
matches = (row_labels.unsqueeze(1) ==
|
||||
clo_label.unsqueeze(0)).bool() # [n_r, n_c]
|
||||
diffenc = torch.logical_not(matches) # [n_r, n_c]
|
||||
p, n, _ = dist.size()
|
||||
ap_dist = dist[:, matches].view(p, n, -1, 1)
|
||||
an_dist = dist[:, diffenc].view(p, n, 1, -1)
|
||||
return ap_dist, an_dist
|
||||
|
||||
Reference in New Issue
Block a user