a faster triplet using mask indexing

This commit is contained in:
wj1tr0y
2022-04-16 19:34:36 +08:00
committed by Junhao Liang
parent ff398acbc7
commit e714fa9075
+6 -9
View File
@@ -20,7 +20,7 @@ class TripletLoss(BaseLoss):
dist = self.ComputeDistance(embeddings, ref_embed) # [p, n1, n2]
mean_dist = dist.mean(1).mean(1)
ap_dist, an_dist = self.Convert2Triplets(labels, ref_label, dist)
dist_diff = ap_dist - an_dist
dist_diff = (ap_dist - an_dist).view(dist.size(0), -1)
loss = F.relu(dist_diff + self.margin)
hard_loss = torch.max(loss, -1)[0]
@@ -60,12 +60,9 @@ class TripletLoss(BaseLoss):
row_labels: tensor with size [n_r]
clo_label : tensor with size [n_c]
"""
matches = (row_labels.unsqueeze(1) ==
clo_label.unsqueeze(0)).byte() # [n_r, n_c]
diffenc = matches ^ 1 # [n_r, n_c]
mask = matches.unsqueeze(2) * diffenc.unsqueeze(1)
a_idx, p_idx, n_idx = torch.where(mask)
ap_dist = dist[:, a_idx, p_idx]
an_dist = dist[:, a_idx, n_idx]
matches = (row_labels.unsqueeze(1) == clo_label.unsqueeze(0)).bool() # [n_r, n_c]
diffenc = torch.logical_not(matches) # [n_r, n_c]
p, n, m = dist.size()
ap_dist = dist[:, matches].view(p, n, -1, 1)
an_dist = dist[:, diffenc].view(p, n, 1, -1)
return ap_dist, an_dist