fix loss info and typo (Closes GH-22).

This commit is contained in:
darkliang
2021-12-09 15:11:06 +08:00
parent 0d869161c2
commit 49cbc44069
16 changed files with 37 additions and 36 deletions
+2 -4
View File
@@ -5,12 +5,10 @@ from .base import BaseLoss, gather_and_scale_wrapper
class TripletLoss(BaseLoss):
def __init__(self, margin, loss_term_weights=1.0):
super(TripletLoss, self).__init__()
def __init__(self, margin, loss_term_weight=1.0):
super(TripletLoss, self).__init__(loss_term_weight)
self.margin = margin
self.loss_term_weights = loss_term_weights
@gather_and_scale_wrapper
def forward(self, embeddings, labels):
# embeddings: [n, p, c], label: [n]