fix loss info and typo (Closes GH-22).
This commit is contained in:
@@ -23,11 +23,11 @@ evaluator_cfg:
|
|||||||
img_w: 64
|
img_w: 64
|
||||||
|
|
||||||
loss_cfg:
|
loss_cfg:
|
||||||
- loss_term_weights: 1.0
|
- loss_term_weight: 1.0
|
||||||
margin: 0.2
|
margin: 0.2
|
||||||
type: TripletLoss
|
type: TripletLoss
|
||||||
log_prefix: triplet
|
log_prefix: triplet
|
||||||
- loss_term_weights: 0.1
|
- loss_term_weight: 0.1
|
||||||
scale: 16
|
scale: 16
|
||||||
type: CrossEntropyLoss
|
type: CrossEntropyLoss
|
||||||
log_prefix: softmax
|
log_prefix: softmax
|
||||||
|
|||||||
@@ -23,11 +23,11 @@ evaluator_cfg:
|
|||||||
# img_w: 128
|
# img_w: 128
|
||||||
|
|
||||||
loss_cfg:
|
loss_cfg:
|
||||||
- loss_term_weights: 1.0
|
- loss_term_weight: 1.0
|
||||||
margin: 0.2
|
margin: 0.2
|
||||||
type: TripletLoss
|
type: TripletLoss
|
||||||
log_prefix: triplet
|
log_prefix: triplet
|
||||||
- loss_term_weights: 0.1
|
- loss_term_weight: 0.1
|
||||||
scale: 16
|
scale: 16
|
||||||
type: CrossEntropyLoss
|
type: CrossEntropyLoss
|
||||||
log_prefix: softmax
|
log_prefix: softmax
|
||||||
|
|||||||
+1
-1
@@ -23,7 +23,7 @@ evaluator_cfg:
|
|||||||
metric: euc # cos
|
metric: euc # cos
|
||||||
|
|
||||||
loss_cfg:
|
loss_cfg:
|
||||||
loss_term_weights: 1.0
|
loss_term_weight: 1.0
|
||||||
margin: 0.2
|
margin: 0.2
|
||||||
type: TripletLoss
|
type: TripletLoss
|
||||||
log_prefix: triplet
|
log_prefix: triplet
|
||||||
|
|||||||
+2
-2
@@ -19,11 +19,11 @@ evaluator_cfg:
|
|||||||
type: InferenceSampler
|
type: InferenceSampler
|
||||||
|
|
||||||
loss_cfg:
|
loss_cfg:
|
||||||
- loss_term_weights: 1.0
|
- loss_term_weight: 1.0
|
||||||
margin: 0.2
|
margin: 0.2
|
||||||
type: TripletLoss
|
type: TripletLoss
|
||||||
log_prefix: triplet
|
log_prefix: triplet
|
||||||
- loss_term_weights: 1.0
|
- loss_term_weight: 1.0
|
||||||
scale: 1
|
scale: 1
|
||||||
type: CrossEntropyLoss
|
type: CrossEntropyLoss
|
||||||
log_accuracy: true
|
log_accuracy: true
|
||||||
|
|||||||
@@ -19,11 +19,11 @@ evaluator_cfg:
|
|||||||
type: InferenceSampler
|
type: InferenceSampler
|
||||||
|
|
||||||
loss_cfg:
|
loss_cfg:
|
||||||
- loss_term_weights: 1.0
|
- loss_term_weight: 1.0
|
||||||
margin: 0.2
|
margin: 0.2
|
||||||
type: TripletLoss
|
type: TripletLoss
|
||||||
log_prefix: triplet
|
log_prefix: triplet
|
||||||
- loss_term_weights: 1.0
|
- loss_term_weight: 1.0
|
||||||
scale: 1
|
scale: 1
|
||||||
type: CrossEntropyLoss
|
type: CrossEntropyLoss
|
||||||
log_accuracy: true
|
log_accuracy: true
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ evaluator_cfg:
|
|||||||
metric: euc # cos
|
metric: euc # cos
|
||||||
|
|
||||||
loss_cfg:
|
loss_cfg:
|
||||||
loss_term_weights: 1.0
|
loss_term_weight: 1.0
|
||||||
margin: 0.2
|
margin: 0.2
|
||||||
type: TripletLoss
|
type: TripletLoss
|
||||||
log_prefix: triplet
|
log_prefix: triplet
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ evaluator_cfg:
|
|||||||
metric: euc # cos
|
metric: euc # cos
|
||||||
|
|
||||||
loss_cfg:
|
loss_cfg:
|
||||||
loss_term_weights: 1.0
|
loss_term_weight: 1.0
|
||||||
margin: 0.2
|
margin: 0.2
|
||||||
type: TripletLoss
|
type: TripletLoss
|
||||||
log_prefix: triplet
|
log_prefix: triplet
|
||||||
|
|||||||
+1
-1
@@ -17,7 +17,7 @@ evaluator_cfg:
|
|||||||
metric: euc # cos
|
metric: euc # cos
|
||||||
|
|
||||||
loss_cfg:
|
loss_cfg:
|
||||||
loss_term_weights: 1.0
|
loss_term_weight: 1.0
|
||||||
margin: 0.2
|
margin: 0.2
|
||||||
type: TripletLoss
|
type: TripletLoss
|
||||||
log_prefix: triplet
|
log_prefix: triplet
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ evaluator_cfg:
|
|||||||
metric: euc # cos
|
metric: euc # cos
|
||||||
|
|
||||||
loss_cfg:
|
loss_cfg:
|
||||||
loss_term_weights: 1.0
|
loss_term_weight: 1.0
|
||||||
margin: 0.2
|
margin: 0.2
|
||||||
type: TripletLoss
|
type: TripletLoss
|
||||||
log_prefix: triplet
|
log_prefix: triplet
|
||||||
|
|||||||
@@ -23,11 +23,11 @@ evaluator_cfg:
|
|||||||
type: BaseSilCuttingTransform
|
type: BaseSilCuttingTransform
|
||||||
|
|
||||||
loss_cfg:
|
loss_cfg:
|
||||||
- loss_term_weights: 1.0
|
- loss_term_weight: 1.0
|
||||||
margin: 0.2
|
margin: 0.2
|
||||||
type: TripletLoss
|
type: TripletLoss
|
||||||
log_prefix: triplet
|
log_prefix: triplet
|
||||||
# - loss_term_weights: 0.1
|
# - loss_term_weight: 0.1
|
||||||
# scale: 1
|
# scale: 1
|
||||||
# type: CrossEntropyLoss
|
# type: CrossEntropyLoss
|
||||||
# log_prefix: softmax
|
# log_prefix: softmax
|
||||||
|
|||||||
@@ -22,11 +22,11 @@ evaluator_cfg:
|
|||||||
type: BaseSilCuttingTransform
|
type: BaseSilCuttingTransform
|
||||||
|
|
||||||
loss_cfg:
|
loss_cfg:
|
||||||
- loss_term_weights: 1.0
|
- loss_term_weight: 1.0
|
||||||
margin: 0.2
|
margin: 0.2
|
||||||
type: TripletLoss
|
type: TripletLoss
|
||||||
log_prefix: triplet
|
log_prefix: triplet
|
||||||
- loss_term_weights: 0.1
|
- loss_term_weight: 0.1
|
||||||
scale: 16
|
scale: 16
|
||||||
type: CrossEntropyLoss
|
type: CrossEntropyLoss
|
||||||
log_prefix: softmax
|
log_prefix: softmax
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
* Loss function
|
* Loss function
|
||||||
> * Args
|
> * Args
|
||||||
> * type: Loss function type, support `TripletLoss` and `CrossEntropyLoss`.
|
> * type: Loss function type, support `TripletLoss` and `CrossEntropyLoss`.
|
||||||
> * loss_term_weights: loss weight.
|
> * loss_term_weight: loss weight.
|
||||||
> * log_prefix: the prefix of loss log.
|
> * log_prefix: the prefix of loss log.
|
||||||
|
|
||||||
----
|
----
|
||||||
@@ -107,11 +107,11 @@ evaluator_cfg:
|
|||||||
img_w: 64
|
img_w: 64
|
||||||
|
|
||||||
loss_cfg:
|
loss_cfg:
|
||||||
- loss_term_weights: 1.0
|
- loss_term_weight: 1.0
|
||||||
margin: 0.2
|
margin: 0.2
|
||||||
type: TripletLoss
|
type: TripletLoss
|
||||||
log_prefix: triplet
|
log_prefix: triplet
|
||||||
- loss_term_weights: 0.1
|
- loss_term_weight: 0.1
|
||||||
scale: 16
|
scale: 16
|
||||||
type: CrossEntropyLoss
|
type: CrossEntropyLoss
|
||||||
log_prefix: softmax
|
log_prefix: softmax
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ class LossAggregator():
|
|||||||
loss, info = loss_func(**v)
|
loss, info = loss_func(**v)
|
||||||
for name, value in info.items():
|
for name, value in info.items():
|
||||||
loss_info['scalar/%s/%s' % (k, name)] = value
|
loss_info['scalar/%s/%s' % (k, name)] = value
|
||||||
loss = loss.mean() * loss_func.loss_term_weights
|
loss = loss.mean() * loss_func.loss_term_weight
|
||||||
loss_sum += loss
|
loss_sum += loss
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -29,14 +29,19 @@ class BaseLoss(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Base class for all losses.
|
Base class for all losses.
|
||||||
|
|
||||||
Your loss should also subclass this class.
|
Your loss should also subclass this class.
|
||||||
|
|
||||||
Attribute:
|
|
||||||
loss_term_weights: the weight of the loss.
|
|
||||||
info: the loss info.
|
|
||||||
"""
|
"""
|
||||||
loss_term_weights = 1.0
|
|
||||||
info = Odict()
|
def __init__(self, loss_term_weight=1.0):
|
||||||
|
"""
|
||||||
|
Initialize the base class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss_term_weight: the weight of the loss term.
|
||||||
|
"""
|
||||||
|
super(BaseLoss, self).__init__()
|
||||||
|
self.loss_term_weight = loss_term_weight
|
||||||
|
self.info = Odict()
|
||||||
|
|
||||||
def forward(self, logits, labels):
|
def forward(self, logits, labels):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -5,15 +5,13 @@ from .base import BaseLoss
|
|||||||
|
|
||||||
|
|
||||||
class CrossEntropyLoss(BaseLoss):
|
class CrossEntropyLoss(BaseLoss):
|
||||||
def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weights=1.0, log_accuracy=False):
|
def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weight=1.0, log_accuracy=False):
|
||||||
super(CrossEntropyLoss, self).__init__()
|
super(CrossEntropyLoss, self).__init__(loss_term_weight)
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.label_smooth = label_smooth
|
self.label_smooth = label_smooth
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.log_accuracy = log_accuracy
|
self.log_accuracy = log_accuracy
|
||||||
|
|
||||||
self.loss_term_weights = loss_term_weights
|
|
||||||
|
|
||||||
def forward(self, logits, labels):
|
def forward(self, logits, labels):
|
||||||
"""
|
"""
|
||||||
logits: [n, p, c]
|
logits: [n, p, c]
|
||||||
|
|||||||
@@ -5,12 +5,10 @@ from .base import BaseLoss, gather_and_scale_wrapper
|
|||||||
|
|
||||||
|
|
||||||
class TripletLoss(BaseLoss):
|
class TripletLoss(BaseLoss):
|
||||||
def __init__(self, margin, loss_term_weights=1.0):
|
def __init__(self, margin, loss_term_weight=1.0):
|
||||||
super(TripletLoss, self).__init__()
|
super(TripletLoss, self).__init__(loss_term_weight)
|
||||||
self.margin = margin
|
self.margin = margin
|
||||||
|
|
||||||
self.loss_term_weights = loss_term_weights
|
|
||||||
|
|
||||||
@gather_and_scale_wrapper
|
@gather_and_scale_wrapper
|
||||||
def forward(self, embeddings, labels):
|
def forward(self, embeddings, labels):
|
||||||
# embeddings: [n, p, c], label: [n]
|
# embeddings: [n, p, c], label: [n]
|
||||||
|
|||||||
Reference in New Issue
Block a user