diff --git a/config/baseline.yaml b/config/baseline.yaml index f8b6a50..7254d16 100644 --- a/config/baseline.yaml +++ b/config/baseline.yaml @@ -23,11 +23,11 @@ evaluator_cfg: img_w: 64 loss_cfg: - - loss_term_weights: 1.0 + - loss_term_weight: 1.0 margin: 0.2 type: TripletLoss log_prefix: triplet - - loss_term_weights: 0.1 + - loss_term_weight: 0.1 scale: 16 type: CrossEntropyLoss log_prefix: softmax diff --git a/config/baseline_OUMVLP.yaml b/config/baseline_OUMVLP.yaml index c0883f2..6c2a6c3 100644 --- a/config/baseline_OUMVLP.yaml +++ b/config/baseline_OUMVLP.yaml @@ -23,11 +23,11 @@ evaluator_cfg: # img_w: 128 loss_cfg: - - loss_term_weights: 1.0 + - loss_term_weight: 1.0 margin: 0.2 type: TripletLoss log_prefix: triplet - - loss_term_weights: 0.1 + - loss_term_weight: 0.1 scale: 16 type: CrossEntropyLoss log_prefix: softmax diff --git a/config/default.yaml b/config/default.yaml index cffe0be..29ce250 100644 --- a/config/default.yaml +++ b/config/default.yaml @@ -23,7 +23,7 @@ evaluator_cfg: metric: euc # cos loss_cfg: - loss_term_weights: 1.0 + loss_term_weight: 1.0 margin: 0.2 type: TripletLoss log_prefix: triplet diff --git a/config/gaitgl.yaml b/config/gaitgl.yaml index 6ea78ea..b3b4a07 100644 --- a/config/gaitgl.yaml +++ b/config/gaitgl.yaml @@ -19,11 +19,11 @@ evaluator_cfg: type: InferenceSampler loss_cfg: - - loss_term_weights: 1.0 + - loss_term_weight: 1.0 margin: 0.2 type: TripletLoss log_prefix: triplet - - loss_term_weights: 1.0 + - loss_term_weight: 1.0 scale: 1 type: CrossEntropyLoss log_accuracy: true diff --git a/config/gaitgl_OUMVLP.yaml b/config/gaitgl_OUMVLP.yaml index ea3caca..8c6adb0 100644 --- a/config/gaitgl_OUMVLP.yaml +++ b/config/gaitgl_OUMVLP.yaml @@ -19,11 +19,11 @@ evaluator_cfg: type: InferenceSampler loss_cfg: - - loss_term_weights: 1.0 + - loss_term_weight: 1.0 margin: 0.2 type: TripletLoss log_prefix: triplet - - loss_term_weights: 1.0 + - loss_term_weight: 1.0 scale: 1 type: CrossEntropyLoss log_accuracy: true diff --git a/config/gaitpart.yaml b/config/gaitpart.yaml index 4d700dc..6143be3 100644 --- a/config/gaitpart.yaml +++ b/config/gaitpart.yaml @@ -17,7 +17,7 @@ evaluator_cfg: metric: euc # cos loss_cfg: - loss_term_weights: 1.0 + loss_term_weight: 1.0 margin: 0.2 type: TripletLoss log_prefix: triplet diff --git a/config/gaitpart_OUMVLP.yaml b/config/gaitpart_OUMVLP.yaml index 73512be..9209c25 100644 --- a/config/gaitpart_OUMVLP.yaml +++ b/config/gaitpart_OUMVLP.yaml @@ -18,7 +18,7 @@ evaluator_cfg: metric: euc # cos loss_cfg: - loss_term_weights: 1.0 + loss_term_weight: 1.0 margin: 0.2 type: TripletLoss log_prefix: triplet diff --git a/config/gaitset.yaml b/config/gaitset.yaml index b159d43..cc84145 100644 --- a/config/gaitset.yaml +++ b/config/gaitset.yaml @@ -17,7 +17,7 @@ evaluator_cfg: metric: euc # cos loss_cfg: - loss_term_weights: 1.0 + loss_term_weight: 1.0 margin: 0.2 type: TripletLoss log_prefix: triplet diff --git a/config/gaitset_OUMVLP.yaml b/config/gaitset_OUMVLP.yaml index 8722eb8..f79e725 100644 --- a/config/gaitset_OUMVLP.yaml +++ b/config/gaitset_OUMVLP.yaml @@ -18,7 +18,7 @@ evaluator_cfg: metric: euc # cos loss_cfg: - loss_term_weights: 1.0 + loss_term_weight: 1.0 margin: 0.2 type: TripletLoss log_prefix: triplet diff --git a/config/gln/gln_phase1.yaml b/config/gln/gln_phase1.yaml index 6acc2cb..7a20180 100644 --- a/config/gln/gln_phase1.yaml +++ b/config/gln/gln_phase1.yaml @@ -23,11 +23,11 @@ evaluator_cfg: type: BaseSilCuttingTransform loss_cfg: - - loss_term_weights: 1.0 + - loss_term_weight: 1.0 margin: 0.2 type: TripletLoss log_prefix: triplet - # - loss_term_weights: 0.1 + # - loss_term_weight: 0.1 # scale: 1 # type: CrossEntropyLoss # log_prefix: softmax diff --git a/config/gln/gln_phase2.yaml b/config/gln/gln_phase2.yaml index 4766864..3d9432c 100644 --- a/config/gln/gln_phase2.yaml +++ b/config/gln/gln_phase2.yaml @@ -22,11 +22,11 @@ evaluator_cfg: type: BaseSilCuttingTransform loss_cfg: - - loss_term_weights: 1.0 + - loss_term_weight: 1.0 margin: 0.2 type: TripletLoss log_prefix: triplet - - loss_term_weights: 0.1 + - loss_term_weight: 0.1 scale: 16 type: CrossEntropyLoss log_prefix: softmax diff --git a/docs/1.detailed_config.md b/docs/1.detailed_config.md index 43f53d3..0d6d56e 100644 --- a/docs/1.detailed_config.md +++ b/docs/1.detailed_config.md @@ -16,7 +16,7 @@ * Loss function > * Args > * type: Loss function type, support `TripletLoss` and `CrossEntropyLoss`. -> * loss_term_weights: loss weight. +> * loss_term_weight: loss weight. > * log_prefix: the prefix of loss log. ---- @@ -107,11 +107,11 @@ evaluator_cfg: img_w: 64 loss_cfg: - - loss_term_weights: 1.0 + - loss_term_weight: 1.0 margin: 0.2 type: TripletLoss log_prefix: triplet - - loss_term_weights: 0.1 + - loss_term_weight: 0.1 scale: 16 type: CrossEntropyLoss log_prefix: softmax diff --git a/lib/modeling/loss_aggregator.py b/lib/modeling/loss_aggregator.py index 7ccebb6..dd75287 100644 --- a/lib/modeling/loss_aggregator.py +++ b/lib/modeling/loss_aggregator.py @@ -59,7 +59,7 @@ class LossAggregator(): loss, info = loss_func(**v) for name, value in info.items(): loss_info['scalar/%s/%s' % (k, name)] = value - loss = loss.mean() * loss_func.loss_term_weights + loss = loss.mean() * loss_func.loss_term_weight loss_sum += loss else: diff --git a/lib/modeling/losses/base.py b/lib/modeling/losses/base.py index fbdd7d6..ba4d94f 100644 --- a/lib/modeling/losses/base.py +++ b/lib/modeling/losses/base.py @@ -29,14 +29,19 @@ class BaseLoss(nn.Module): """ Base class for all losses. - Your loss should also subclass this class. - - Attribute: - loss_term_weights: the weight of the loss. - info: the loss info. + Your loss should also subclass this class. """ - loss_term_weights = 1.0 - info = Odict() + + def __init__(self, loss_term_weight=1.0): + """ + Initialize the base class. + + Args: + loss_term_weight: the weight of the loss term. + """ + super(BaseLoss, self).__init__() + self.loss_term_weight = loss_term_weight + self.info = Odict() def forward(self, logits, labels): """ diff --git a/lib/modeling/losses/softmax.py b/lib/modeling/losses/softmax.py index ddf2293..a955a02 100644 --- a/lib/modeling/losses/softmax.py +++ b/lib/modeling/losses/softmax.py @@ -5,15 +5,13 @@ from .base import BaseLoss class CrossEntropyLoss(BaseLoss): - def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weights=1.0, log_accuracy=False): - super(CrossEntropyLoss, self).__init__() + def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weight=1.0, log_accuracy=False): + super(CrossEntropyLoss, self).__init__(loss_term_weight) self.scale = scale self.label_smooth = label_smooth self.eps = eps self.log_accuracy = log_accuracy - self.loss_term_weights = loss_term_weights - def forward(self, logits, labels): """ logits: [n, p, c] diff --git a/lib/modeling/losses/triplet.py b/lib/modeling/losses/triplet.py index 4b5dac2..3feb8ce 100644 --- a/lib/modeling/losses/triplet.py +++ b/lib/modeling/losses/triplet.py @@ -5,12 +5,10 @@ from .base import BaseLoss, gather_and_scale_wrapper class TripletLoss(BaseLoss): - def __init__(self, margin, loss_term_weights=1.0): - super(TripletLoss, self).__init__() + def __init__(self, margin, loss_term_weight=1.0): + super(TripletLoss, self).__init__(loss_term_weight) self.margin = margin - self.loss_term_weights = loss_term_weights - @gather_and_scale_wrapper def forward(self, embeddings, labels): # embeddings: [n, p, c], label: [n]