Fix dimension misuse. (#59)

* fix dimension handling

* fix dimensions for all methods

* restore config

* clean up baseline config

* add contiguous

* rm comment
Junhao Liang, 2022-06-28 12:27:16 +08:00 (committed by GitHub)
parent 715e7448fa
commit 14fa5212d4
14 changed files with 99 additions and 121 deletions
+11 -13
@@ -14,31 +14,29 @@ class CrossEntropyLoss(BaseLoss):
     def forward(self, logits, labels):
         """
-            logits: [n, p, c]
+            logits: [n, c, p]
             labels: [n]
         """
-        logits = logits.permute(1, 0, 2).contiguous()  # [n, p, c] -> [p, n, c]
-        p, _, c = logits.size()
-        log_preds = F.log_softmax(logits * self.scale, dim=-1)  # [p, n, c]
+        n, c, p = logits.size()
+        log_preds = F.log_softmax(logits * self.scale, dim=1)  # [n, c, p]
         one_hot_labels = self.label2one_hot(
-            labels, c).unsqueeze(0).repeat(p, 1, 1)  # [p, n, c]
+            labels, c).unsqueeze(2).repeat(1, 1, p)  # [n, c, p]
         loss = self.compute_loss(log_preds, one_hot_labels)
         self.info.update({'loss': loss.detach().clone()})
         if self.log_accuracy:
-            pred = logits.argmax(dim=-1)  # [p, n]
-            accu = (pred == labels.unsqueeze(0)).float().mean()
+            pred = logits.argmax(dim=1)  # [n, p]
+            accu = (pred == labels.unsqueeze(1)).float().mean()
             self.info.update({'accuracy': accu})
         return loss, self.info
 
     def compute_loss(self, predis, labels):
-        softmax_loss = -(labels * predis).sum(-1)  # [p, n]
-        losses = softmax_loss.mean(-1)
+        softmax_loss = -(labels * predis).sum(1)  # [n, p]
+        losses = softmax_loss.mean(0)  # [p]
         if self.label_smooth:
-            smooth_loss = - predis.mean(dim=-1)  # [p, n]
-            smooth_loss = smooth_loss.mean()  # [p]
-            smooth_loss = smooth_loss * self.eps
-            losses = smooth_loss + losses * (1. - self.eps)
+            smooth_loss = - predis.mean(dim=1)  # [n, p]
+            smooth_loss = smooth_loss.mean(0)  # [p]
+            losses = smooth_loss * self.eps + losses * (1. - self.eps)
         return losses
 
     def label2one_hot(self, label, class_num):
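A quick way to sanity-check the new [n, c, p] convention: without label smoothing, this path should match torch.nn.functional.cross_entropy, which natively accepts [n, c, p] inputs with [n, p] targets. A minimal sketch, using made-up sizes and F.one_hot in place of the class's label2one_hot helper:

import torch
import torch.nn.functional as F

n, c, p = 4, 10, 16   # batch, classes, parts (illustrative sizes)
scale = 16            # stand-in for self.scale
logits = torch.randn(n, c, p)
labels = torch.randint(0, c, (n,))

# Rewritten path: softmax over the class dim (dim=1), mean over the batch dim.
log_preds = F.log_softmax(logits * scale, dim=1)                     # [n, c, p]
one_hot = F.one_hot(labels, c).float().unsqueeze(2).repeat(1, 1, p)  # [n, c, p]
losses = -(one_hot * log_preds).sum(1).mean(0)                       # [p]

# Reference: PyTorch's K-dimensional cross_entropy over the same tensors.
ref = F.cross_entropy(logits * scale, labels.unsqueeze(1).repeat(1, p),
                      reduction='none').mean(0)                      # [p]
assert torch.allclose(losses, ref, atol=1e-5)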
+8 -8
@@ -11,14 +11,13 @@ class TripletLoss(BaseLoss):
     @gather_and_scale_wrapper
     def forward(self, embeddings, labels):
-        # embeddings: [n, p, c], label: [n]
+        # embeddings: [n, c, p], label: [n]
         embeddings = embeddings.permute(
-            1, 0, 2).contiguous()  # [n, p, c] -> [p, n, c]
-        embeddings = embeddings.float()
+            2, 0, 1).contiguous().float()  # [n, c, p] -> [p, n, c]
 
         ref_embed, ref_label = embeddings, labels
         dist = self.ComputeDistance(embeddings, ref_embed)  # [p, n1, n2]
-        mean_dist = dist.mean(1).mean(1)
+        mean_dist = dist.mean((1, 2))  # [p]
         ap_dist, an_dist = self.Convert2Triplets(labels, ref_label, dist)
         dist_diff = (ap_dist - an_dist).view(dist.size(0), -1)
         loss = F.relu(dist_diff + self.margin)
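For reference, permute(2, 0, 1) is the op that maps the new [n, c, p] layout to the per-part [p, n, c] layout the rest of the loss works in, and .contiguous() materializes the strided view so later reshapes are safe. A small illustration with made-up sizes:

import torch

n, c, p = 4, 256, 16
embeddings = torch.randn(n, c, p)                    # new convention: [n, c, p]

per_part = embeddings.permute(2, 0, 1).contiguous()  # [p, n, c]
assert per_part.shape == (p, n, c)
# Slice k is the [n, c] embedding matrix of part k:
assert torch.equal(per_part[3], embeddings[:, :, 3])
# Without .contiguous(), the permuted tensor is a non-contiguous view
# and a subsequent .view(...) would raise a RuntimeError.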
@@ -50,7 +49,7 @@ class TripletLoss(BaseLoss):
         """
         x2 = torch.sum(x ** 2, -1).unsqueeze(2)  # [p, n_x, 1]
         y2 = torch.sum(y ** 2, -1).unsqueeze(1)  # [p, 1, n_y]
-        inner = x.matmul(y.transpose(-1, -2))  # [p, n_x, n_y]
+        inner = x.matmul(y.transpose(1, 2))  # [p, n_x, n_y]
         dist = x2 + y2 - 2 * inner
         dist = torch.sqrt(F.relu(dist))  # [p, n_x, n_y]
         return dist
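ComputeDistance is the standard expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x·y batched over the part dim; transpose(1, 2) just states the intended [p, n_y, c] -> [p, c, n_y] swap explicitly instead of via negative dims. It can be checked against torch.cdist, which accepts batched [p, n, c] inputs (sketch with illustrative sizes):

import torch
import torch.nn.functional as F

p, n_x, n_y, c = 16, 8, 8, 256
x = torch.randn(p, n_x, c)
y = torch.randn(p, n_y, c)

x2 = torch.sum(x ** 2, -1).unsqueeze(2)         # [p, n_x, 1]
y2 = torch.sum(y ** 2, -1).unsqueeze(1)         # [p, 1, n_y]
inner = x.matmul(y.transpose(1, 2))             # [p, n_x, n_y]
dist = torch.sqrt(F.relu(x2 + y2 - 2 * inner))  # relu clamps tiny negatives

assert torch.allclose(dist, torch.cdist(x, y), atol=1e-4)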
@@ -60,9 +59,10 @@ class TripletLoss(BaseLoss):
         row_labels: tensor with size [n_r]
         clo_label : tensor with size [n_c]
         """
-        matches = (row_labels.unsqueeze(1) == clo_label.unsqueeze(0)).bool()  # [n_r, n_c]
+        matches = (row_labels.unsqueeze(1) ==
+                   clo_label.unsqueeze(0)).bool()  # [n_r, n_c]
         diffenc = torch.logical_not(matches)  # [n_r, n_c]
-        p, n, m = dist.size()
+        p, n, _ = dist.size()
         ap_dist = dist[:, matches].view(p, n, -1, 1)
         an_dist = dist[:, diffenc].view(p, n, 1, -1)
         return ap_dist, an_dist
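Note that the boolean indexing assumes every anchor has the same number of positives and negatives (true under the usual P identities x K samples batch sampler): dist[:, matches] flattens the masked entries, and .view(p, n, -1, 1) re-splits them per anchor, after which ap_dist and an_dist broadcast against each other to enumerate all triplets. A toy balanced batch, shapes only:

import torch

p = 2
labels = torch.tensor([0, 0, 1, 1])   # balanced: 2 samples per identity
n = labels.numel()
dist = torch.rand(p, n, n)            # [p, n, n] pairwise distances

matches = labels.unsqueeze(1) == labels.unsqueeze(0)  # [n, n], incl. self-match
diffenc = torch.logical_not(matches)                  # [n, n]

ap_dist = dist[:, matches].view(p, n, -1, 1)   # [p, n, n_pos, 1]
an_dist = dist[:, diffenc].view(p, n, 1, -1)   # [p, n, 1, n_neg]
assert ap_dist.shape == (p, n, 2, 1) and an_dist.shape == (p, n, 1, 2)

dist_diff = ap_dist - an_dist                  # [p, n, n_pos, n_neg]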