a faster triplet using mask indexing
This commit is contained in:
@@ -20,7 +20,7 @@ class TripletLoss(BaseLoss):
|
||||
dist = self.ComputeDistance(embeddings, ref_embed) # [p, n1, n2]
|
||||
mean_dist = dist.mean(1).mean(1)
|
||||
ap_dist, an_dist = self.Convert2Triplets(labels, ref_label, dist)
|
||||
dist_diff = ap_dist - an_dist
|
||||
dist_diff = (ap_dist - an_dist).view(dist.size(0), -1)
|
||||
loss = F.relu(dist_diff + self.margin)
|
||||
|
||||
hard_loss = torch.max(loss, -1)[0]
|
||||
@@ -60,12 +60,9 @@ class TripletLoss(BaseLoss):
|
||||
row_labels: tensor with size [n_r]
|
||||
clo_label : tensor with size [n_c]
|
||||
"""
|
||||
matches = (row_labels.unsqueeze(1) ==
|
||||
clo_label.unsqueeze(0)).byte() # [n_r, n_c]
|
||||
diffenc = matches ^ 1 # [n_r, n_c]
|
||||
mask = matches.unsqueeze(2) * diffenc.unsqueeze(1)
|
||||
a_idx, p_idx, n_idx = torch.where(mask)
|
||||
|
||||
ap_dist = dist[:, a_idx, p_idx]
|
||||
an_dist = dist[:, a_idx, n_idx]
|
||||
matches = (row_labels.unsqueeze(1) == clo_label.unsqueeze(0)).bool() # [n_r, n_c]
|
||||
diffenc = torch.logical_not(matches) # [n_r, n_c]
|
||||
p, n, m = dist.size()
|
||||
ap_dist = dist[:, matches].view(p, n, -1, 1)
|
||||
an_dist = dist[:, diffenc].view(p, n, 1, -1)
|
||||
return ap_dist, an_dist
|
||||
|
||||
Reference in New Issue
Block a user