fix loss info and typo (Closes GH-22).
This commit is contained in:
@@ -5,12 +5,10 @@ from .base import BaseLoss, gather_and_scale_wrapper
|
||||
|
||||
|
||||
class TripletLoss(BaseLoss):
|
||||
def __init__(self, margin, loss_term_weights=1.0):
|
||||
super(TripletLoss, self).__init__()
|
||||
def __init__(self, margin, loss_term_weight=1.0):
|
||||
super(TripletLoss, self).__init__(loss_term_weight)
|
||||
self.margin = margin
|
||||
|
||||
self.loss_term_weights = loss_term_weights
|
||||
|
||||
@gather_and_scale_wrapper
|
||||
def forward(self, embeddings, labels):
|
||||
# embeddings: [n, p, c], label: [n]
|
||||
|
||||
Reference in New Issue
Block a user