fix loss info and typo (Closes GH-22).
This commit is contained in:
@@ -59,7 +59,7 @@ class LossAggregator():
|
||||
loss, info = loss_func(**v)
|
||||
for name, value in info.items():
|
||||
loss_info['scalar/%s/%s' % (k, name)] = value
|
||||
loss = loss.mean() * loss_func.loss_term_weights
|
||||
loss = loss.mean() * loss_func.loss_term_weight
|
||||
loss_sum += loss
|
||||
|
||||
else:
|
||||
|
||||
@@ -29,14 +29,19 @@ class BaseLoss(nn.Module):
|
||||
"""
|
||||
Base class for all losses.
|
||||
|
||||
Your loss should also subclass this class.
|
||||
|
||||
Attribute:
|
||||
loss_term_weights: the weight of the loss.
|
||||
info: the loss info.
|
||||
Your loss should also subclass this class.
|
||||
"""
|
||||
loss_term_weights = 1.0
|
||||
info = Odict()
|
||||
|
||||
def __init__(self, loss_term_weight=1.0):
|
||||
"""
|
||||
Initialize the base class.
|
||||
|
||||
Args:
|
||||
loss_term_weight: the weight of the loss term.
|
||||
"""
|
||||
super(BaseLoss, self).__init__()
|
||||
self.loss_term_weight = loss_term_weight
|
||||
self.info = Odict()
|
||||
|
||||
def forward(self, logits, labels):
|
||||
"""
|
||||
|
||||
@@ -5,15 +5,13 @@ from .base import BaseLoss
|
||||
|
||||
|
||||
class CrossEntropyLoss(BaseLoss):
|
||||
def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weights=1.0, log_accuracy=False):
|
||||
super(CrossEntropyLoss, self).__init__()
|
||||
def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weight=1.0, log_accuracy=False):
|
||||
super(CrossEntropyLoss, self).__init__(loss_term_weight)
|
||||
self.scale = scale
|
||||
self.label_smooth = label_smooth
|
||||
self.eps = eps
|
||||
self.log_accuracy = log_accuracy
|
||||
|
||||
self.loss_term_weights = loss_term_weights
|
||||
|
||||
def forward(self, logits, labels):
|
||||
"""
|
||||
logits: [n, p, c]
|
||||
|
||||
@@ -5,12 +5,10 @@ from .base import BaseLoss, gather_and_scale_wrapper
|
||||
|
||||
|
||||
class TripletLoss(BaseLoss):
|
||||
def __init__(self, margin, loss_term_weights=1.0):
|
||||
super(TripletLoss, self).__init__()
|
||||
def __init__(self, margin, loss_term_weight=1.0):
|
||||
super(TripletLoss, self).__init__(loss_term_weight)
|
||||
self.margin = margin
|
||||
|
||||
self.loss_term_weights = loss_term_weights
|
||||
|
||||
@gather_and_scale_wrapper
|
||||
def forward(self, embeddings, labels):
|
||||
# embeddings: [n, p, c], label: [n]
|
||||
|
||||
Reference in New Issue
Block a user