fix loss info and typo (Closes GH-22).

This commit is contained in:
darkliang
2021-12-09 15:11:06 +08:00
parent 0d869161c2
commit 49cbc44069
16 changed files with 37 additions and 36 deletions
+2 -2
View File
@@ -23,11 +23,11 @@ evaluator_cfg:
img_w: 64 img_w: 64
loss_cfg: loss_cfg:
- loss_term_weights: 1.0 - loss_term_weight: 1.0
margin: 0.2 margin: 0.2
type: TripletLoss type: TripletLoss
log_prefix: triplet log_prefix: triplet
- loss_term_weights: 0.1 - loss_term_weight: 0.1
scale: 16 scale: 16
type: CrossEntropyLoss type: CrossEntropyLoss
log_prefix: softmax log_prefix: softmax
+2 -2
View File
@@ -23,11 +23,11 @@ evaluator_cfg:
# img_w: 128 # img_w: 128
loss_cfg: loss_cfg:
- loss_term_weights: 1.0 - loss_term_weight: 1.0
margin: 0.2 margin: 0.2
type: TripletLoss type: TripletLoss
log_prefix: triplet log_prefix: triplet
- loss_term_weights: 0.1 - loss_term_weight: 0.1
scale: 16 scale: 16
type: CrossEntropyLoss type: CrossEntropyLoss
log_prefix: softmax log_prefix: softmax
+1 -1
View File
@@ -23,7 +23,7 @@ evaluator_cfg:
metric: euc # cos metric: euc # cos
loss_cfg: loss_cfg:
loss_term_weights: 1.0 loss_term_weight: 1.0
margin: 0.2 margin: 0.2
type: TripletLoss type: TripletLoss
log_prefix: triplet log_prefix: triplet
+2 -2
View File
@@ -19,11 +19,11 @@ evaluator_cfg:
type: InferenceSampler type: InferenceSampler
loss_cfg: loss_cfg:
- loss_term_weights: 1.0 - loss_term_weight: 1.0
margin: 0.2 margin: 0.2
type: TripletLoss type: TripletLoss
log_prefix: triplet log_prefix: triplet
- loss_term_weights: 1.0 - loss_term_weight: 1.0
scale: 1 scale: 1
type: CrossEntropyLoss type: CrossEntropyLoss
log_accuracy: true log_accuracy: true
+2 -2
View File
@@ -19,11 +19,11 @@ evaluator_cfg:
type: InferenceSampler type: InferenceSampler
loss_cfg: loss_cfg:
- loss_term_weights: 1.0 - loss_term_weight: 1.0
margin: 0.2 margin: 0.2
type: TripletLoss type: TripletLoss
log_prefix: triplet log_prefix: triplet
- loss_term_weights: 1.0 - loss_term_weight: 1.0
scale: 1 scale: 1
type: CrossEntropyLoss type: CrossEntropyLoss
log_accuracy: true log_accuracy: true
+1 -1
View File
@@ -17,7 +17,7 @@ evaluator_cfg:
metric: euc # cos metric: euc # cos
loss_cfg: loss_cfg:
loss_term_weights: 1.0 loss_term_weight: 1.0
margin: 0.2 margin: 0.2
type: TripletLoss type: TripletLoss
log_prefix: triplet log_prefix: triplet
+1 -1
View File
@@ -18,7 +18,7 @@ evaluator_cfg:
metric: euc # cos metric: euc # cos
loss_cfg: loss_cfg:
loss_term_weights: 1.0 loss_term_weight: 1.0
margin: 0.2 margin: 0.2
type: TripletLoss type: TripletLoss
log_prefix: triplet log_prefix: triplet
+1 -1
View File
@@ -17,7 +17,7 @@ evaluator_cfg:
metric: euc # cos metric: euc # cos
loss_cfg: loss_cfg:
loss_term_weights: 1.0 loss_term_weight: 1.0
margin: 0.2 margin: 0.2
type: TripletLoss type: TripletLoss
log_prefix: triplet log_prefix: triplet
+1 -1
View File
@@ -18,7 +18,7 @@ evaluator_cfg:
metric: euc # cos metric: euc # cos
loss_cfg: loss_cfg:
loss_term_weights: 1.0 loss_term_weight: 1.0
margin: 0.2 margin: 0.2
type: TripletLoss type: TripletLoss
log_prefix: triplet log_prefix: triplet
+2 -2
View File
@@ -23,11 +23,11 @@ evaluator_cfg:
type: BaseSilCuttingTransform type: BaseSilCuttingTransform
loss_cfg: loss_cfg:
- loss_term_weights: 1.0 - loss_term_weight: 1.0
margin: 0.2 margin: 0.2
type: TripletLoss type: TripletLoss
log_prefix: triplet log_prefix: triplet
# - loss_term_weights: 0.1 # - loss_term_weight: 0.1
# scale: 1 # scale: 1
# type: CrossEntropyLoss # type: CrossEntropyLoss
# log_prefix: softmax # log_prefix: softmax
+2 -2
View File
@@ -22,11 +22,11 @@ evaluator_cfg:
type: BaseSilCuttingTransform type: BaseSilCuttingTransform
loss_cfg: loss_cfg:
- loss_term_weights: 1.0 - loss_term_weight: 1.0
margin: 0.2 margin: 0.2
type: TripletLoss type: TripletLoss
log_prefix: triplet log_prefix: triplet
- loss_term_weights: 0.1 - loss_term_weight: 0.1
scale: 16 scale: 16
type: CrossEntropyLoss type: CrossEntropyLoss
log_prefix: softmax log_prefix: softmax
+3 -3
View File
@@ -16,7 +16,7 @@
* Loss function * Loss function
> * Args > * Args
> * type: Loss function type, support `TripletLoss` and `CrossEntropyLoss`. > * type: Loss function type, support `TripletLoss` and `CrossEntropyLoss`.
> * loss_term_weights: loss weight. > * loss_term_weight: loss weight.
> * log_prefix: the prefix of loss log. > * log_prefix: the prefix of loss log.
---- ----
@@ -107,11 +107,11 @@ evaluator_cfg:
img_w: 64 img_w: 64
loss_cfg: loss_cfg:
- loss_term_weights: 1.0 - loss_term_weight: 1.0
margin: 0.2 margin: 0.2
type: TripletLoss type: TripletLoss
log_prefix: triplet log_prefix: triplet
- loss_term_weights: 0.1 - loss_term_weight: 0.1
scale: 16 scale: 16
type: CrossEntropyLoss type: CrossEntropyLoss
log_prefix: softmax log_prefix: softmax
+1 -1
View File
@@ -59,7 +59,7 @@ class LossAggregator():
loss, info = loss_func(**v) loss, info = loss_func(**v)
for name, value in info.items(): for name, value in info.items():
loss_info['scalar/%s/%s' % (k, name)] = value loss_info['scalar/%s/%s' % (k, name)] = value
loss = loss.mean() * loss_func.loss_term_weights loss = loss.mean() * loss_func.loss_term_weight
loss_sum += loss loss_sum += loss
else: else:
+12 -7
View File
@@ -29,14 +29,19 @@ class BaseLoss(nn.Module):
""" """
Base class for all losses. Base class for all losses.
Your loss should also subclass this class. Your loss should also subclass this class.
Attribute:
loss_term_weights: the weight of the loss.
info: the loss info.
""" """
loss_term_weights = 1.0
info = Odict() def __init__(self, loss_term_weight=1.0):
"""
Initialize the base class.
Args:
loss_term_weight: the weight of the loss term.
"""
super(BaseLoss, self).__init__()
self.loss_term_weight = loss_term_weight
self.info = Odict()
def forward(self, logits, labels): def forward(self, logits, labels):
""" """
+2 -4
View File
@@ -5,15 +5,13 @@ from .base import BaseLoss
class CrossEntropyLoss(BaseLoss): class CrossEntropyLoss(BaseLoss):
def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weights=1.0, log_accuracy=False): def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weight=1.0, log_accuracy=False):
super(CrossEntropyLoss, self).__init__() super(CrossEntropyLoss, self).__init__(loss_term_weight)
self.scale = scale self.scale = scale
self.label_smooth = label_smooth self.label_smooth = label_smooth
self.eps = eps self.eps = eps
self.log_accuracy = log_accuracy self.log_accuracy = log_accuracy
self.loss_term_weights = loss_term_weights
def forward(self, logits, labels): def forward(self, logits, labels):
""" """
logits: [n, p, c] logits: [n, p, c]
+2 -4
View File
@@ -5,12 +5,10 @@ from .base import BaseLoss, gather_and_scale_wrapper
class TripletLoss(BaseLoss): class TripletLoss(BaseLoss):
def __init__(self, margin, loss_term_weights=1.0): def __init__(self, margin, loss_term_weight=1.0):
super(TripletLoss, self).__init__() super(TripletLoss, self).__init__(loss_term_weight)
self.margin = margin self.margin = margin
self.loss_term_weights = loss_term_weights
@gather_and_scale_wrapper @gather_and_scale_wrapper
def forward(self, embeddings, labels): def forward(self, embeddings, labels):
# embeddings: [n, p, c], label: [n] # embeddings: [n, p, c], label: [n]