From d579ca9135d17e761b42ed04e02972cc16b601e2 Mon Sep 17 00:00:00 2001 From: darkliang <12132342@mail.sustech.edu.cn> Date: Mon, 9 Oct 2023 18:32:23 +0800 Subject: [PATCH] remove dependency of pytorch_metric_learning --- opengait/modeling/losses/supconloss.py | 52 +++++++++++++++++------ opengait/modeling/losses/supconloss_Lp.py | 19 --------- 2 files changed, 40 insertions(+), 31 deletions(-) delete mode 100644 opengait/modeling/losses/supconloss_Lp.py diff --git a/opengait/modeling/losses/supconloss.py b/opengait/modeling/losses/supconloss.py index f35bced..5ba0130 100644 --- a/opengait/modeling/losses/supconloss.py +++ b/opengait/modeling/losses/supconloss.py @@ -6,13 +6,29 @@ import torch.nn as nn import torch from .base import BaseLoss, gather_and_scale_wrapper + class SupConLoss_Re(BaseLoss): def __init__(self, temperature=0.01): super(SupConLoss_Re, self).__init__() self.train_loss = SupConLoss(temperature=temperature) + @gather_and_scale_wrapper def forward(self, features, labels=None, mask=None): - loss = self.train_loss(features,labels) + loss = self.train_loss(features, labels) + self.info.update({ + 'loss': loss.detach().clone()}) + return loss, self.info + + +class SupConLoss_Lp(BaseLoss): + def __init__(self, temperature=0.01): + super(SupConLoss_Lp, self).__init__() + self.train_loss = SupConLoss( + temperature=temperature, base_temperature=temperature, reduce_zero=True, p=2) + + @gather_and_scale_wrapper + def forward(self, features, labels=None, mask=None): + loss = self.train_loss(features.unsqueeze(1), labels) self.info.update({ 'loss': loss.detach().clone()}) return loss, self.info @@ -21,12 +37,15 @@ class SupConLoss_Re(BaseLoss): class SupConLoss(nn.Module): """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. It also supports the unsupervised contrastive loss in SimCLR""" + def __init__(self, temperature=0.01, contrast_mode='all', - base_temperature=0.07): + base_temperature=0.07, reduce_zero=False, p=None): super(SupConLoss, self).__init__() self.temperature = temperature self.contrast_mode = contrast_mode self.base_temperature = base_temperature + self.reduce_zero = reduce_zero + self.p = p def forward(self, features, labels=None, mask=None): """Compute loss for model. If both `labels` and `mask` are None, @@ -74,13 +93,21 @@ class SupConLoss(nn.Module): else: raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) - # compute logits - anchor_dot_contrast = torch.div( - torch.matmul(anchor_feature, contrast_feature.T), - self.temperature) + # compute distance mat + if self.p is None: + mat = torch.matmul( + anchor_feature, contrast_feature.T) + else: + anchor_feature = torch.nn.functional.normalize( + anchor_feature, p=self.p, dim=1) + contrast_feature = torch.nn.functional.normalize( + contrast_feature, p=self.p, dim=1) + mat = -torch.cdist( + anchor_feature, contrast_feature, p=self.p) + mat = mat/self.temperature # for numerical stability - logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) - logits = anchor_dot_contrast - logits_max.detach() + logits_max, _ = torch.max(mat, dim=1, keepdim=True) + logits = mat - logits_max.detach() # tile mask mask = mask.repeat(anchor_count, contrast_count) @@ -98,10 +125,11 @@ class SupConLoss(nn.Module): log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) # compute mean of log-likelihood over positive - mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) - + mean_log_prob_pos = (mask * log_prob).sum(1) / \ + (mask.sum(1)+torch.finfo(mat.dtype).tiny) # loss loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos - loss = loss.view(anchor_count, batch_size).mean() + if self.reduce_zero: + loss = loss[loss > 0] - return loss \ No newline at end of file + return loss.mean() diff --git a/opengait/modeling/losses/supconloss_Lp.py b/opengait/modeling/losses/supconloss_Lp.py deleted file mode 100644 index e7a5669..0000000 --- a/opengait/modeling/losses/supconloss_Lp.py +++ /dev/null @@ -1,19 +0,0 @@ -''' -Modifed fromhttps://github.com/BNU-IVC/FastPoseGait/blob/main/fastposegait/modeling/losses/supconloss_Lp.py -''' - -from .base import BaseLoss, gather_and_scale_wrapper -from pytorch_metric_learning import losses, distances - -class SupConLoss_Lp(BaseLoss): - def __init__(self, temperature=0.01): - super(SupConLoss_Lp, self).__init__() - self.distance = distances.LpDistance() - self.train_loss = losses.SupConLoss(temperature=temperature, distance=self.distance) - @gather_and_scale_wrapper - def forward(self, features, labels=None, mask=None): - loss = self.train_loss(features,labels) - self.info.update({ - 'loss': loss.detach().clone()}) - return loss, self.info -