2c29afadf3
* pose * pose * pose * pose * 你的提交消息 * pose * pose * Delete train1.sh * pretreatment * configs * pose * reference * Update gaittr.py * naming * naming * Update transform.py * update for datasets * update README * update name and README * update * Update transform.py
20 lines
722 B
Python
20 lines
722 B
Python
'''
|
|
Modified from https://github.com/BNU-IVC/FastPoseGait/blob/main/fastposegait/modeling/losses/supconloss_Lp.py
|
|
'''
|
|
|
|
from .base import BaseLoss, gather_and_scale_wrapper
|
|
from pytorch_metric_learning import losses, distances
|
|
|
|
class SupConLoss_Lp(BaseLoss):
    """Supervised contrastive loss measured with an Lp distance.

    Thin wrapper that delegates the actual computation to
    ``pytorch_metric_learning.losses.SupConLoss`` configured with an
    ``LpDistance`` metric (constructed with the library's default arguments).
    """

    def __init__(self, temperature=0.01):
        """Build the distance metric and the underlying loss.

        Args:
            temperature: Temperature forwarded unchanged to ``SupConLoss``.
        """
        super().__init__()
        # The metric is kept on the instance as well as being handed to the loss.
        self.distance = distances.LpDistance()
        self.train_loss = losses.SupConLoss(
            temperature=temperature,
            distance=self.distance,
        )

    # NOTE(review): the wrapper is defined in ``.base``; presumably it gathers
    # the inputs across workers and rescales the loss -- confirm there.
    @gather_and_scale_wrapper
    def forward(self, features, labels=None, mask=None):
        """Compute the loss and record a detached copy of it in ``self.info``.

        Args:
            features: Embeddings passed as the first argument to the loss.
            labels: Labels passed as the second argument to the loss.
            mask: Accepted for interface compatibility; not used here.

        Returns:
            A ``(loss, info)`` tuple, where ``info`` is ``self.info`` after its
            ``'loss'`` entry has been updated.
        """
        loss = self.train_loss(features, labels)
        # Store a detached clone so the logged value is cut off from autograd.
        logged_value = loss.detach().clone()
        self.info.update({'loss': logged_value})
        return loss, self.info
|
|
|