a faster triplet using mask indexing
This commit is contained in:
@@ -20,7 +20,7 @@ class TripletLoss(BaseLoss):
|
|||||||
dist = self.ComputeDistance(embeddings, ref_embed) # [p, n1, n2]
|
dist = self.ComputeDistance(embeddings, ref_embed) # [p, n1, n2]
|
||||||
mean_dist = dist.mean(1).mean(1)
|
mean_dist = dist.mean(1).mean(1)
|
||||||
ap_dist, an_dist = self.Convert2Triplets(labels, ref_label, dist)
|
ap_dist, an_dist = self.Convert2Triplets(labels, ref_label, dist)
|
||||||
dist_diff = ap_dist - an_dist
|
dist_diff = (ap_dist - an_dist).view(dist.size(0), -1)
|
||||||
loss = F.relu(dist_diff + self.margin)
|
loss = F.relu(dist_diff + self.margin)
|
||||||
|
|
||||||
hard_loss = torch.max(loss, -1)[0]
|
hard_loss = torch.max(loss, -1)[0]
|
||||||
@@ -60,12 +60,9 @@ class TripletLoss(BaseLoss):
|
|||||||
row_labels: tensor with size [n_r]
|
row_labels: tensor with size [n_r]
|
||||||
clo_label : tensor with size [n_c]
|
clo_label : tensor with size [n_c]
|
||||||
"""
|
"""
|
||||||
matches = (row_labels.unsqueeze(1) ==
|
matches = (row_labels.unsqueeze(1) == clo_label.unsqueeze(0)).bool() # [n_r, n_c]
|
||||||
clo_label.unsqueeze(0)).byte() # [n_r, n_c]
|
diffenc = torch.logical_not(matches) # [n_r, n_c]
|
||||||
diffenc = matches ^ 1 # [n_r, n_c]
|
p, n, m = dist.size()
|
||||||
mask = matches.unsqueeze(2) * diffenc.unsqueeze(1)
|
ap_dist = dist[:, matches].view(p, n, -1, 1)
|
||||||
a_idx, p_idx, n_idx = torch.where(mask)
|
an_dist = dist[:, diffenc].view(p, n, 1, -1)
|
||||||
|
|
||||||
ap_dist = dist[:, a_idx, p_idx]
|
|
||||||
an_dist = dist[:, a_idx, n_idx]
|
|
||||||
return ap_dist, an_dist
|
return ap_dist, an_dist
|
||||||
|
|||||||
Reference in New Issue
Block a user