Reconstruct LossAggregator and fix some typos in config files (#100)

* fix Gait3D configs typo

* Use ModuleDict to reconstruct LossAggregator

* fix typo
This commit is contained in:
Jilong Wang
2023-01-14 17:33:38 +08:00
committed by GitHub
parent bf8d03658e
commit 6c05509add
3 changed files with 13 additions and 7 deletions
+1 -1
View File
@@ -83,7 +83,7 @@ trainer_cfg:
enable_float16: true # half_percesion float for memory reduction and speedup enable_float16: true # half_percesion float for memory reduction and speedup
fix_BN: false fix_BN: false
log_iter: 100 log_iter: 100
with_test: 10000 with_test: true
restore_ckpt_strict: true restore_ckpt_strict: true
restore_hint: 0 restore_hint: 0
save_iter: 10000 save_iter: 10000
+1 -1
View File
@@ -85,7 +85,7 @@ trainer_cfg:
enable_float16: true # half_percesion float for memory reduction and speedup enable_float16: true # half_percesion float for memory reduction and speedup
fix_BN: false fix_BN: false
log_iter: 100 log_iter: 100
with_test: 10000 with_test: true
restore_ckpt_strict: true restore_ckpt_strict: true
restore_hint: 0 restore_hint: 0
save_iter: 10000 save_iter: 10000
+11 -5
View File
@@ -1,13 +1,14 @@
"""The loss aggregator.""" """The loss aggregator."""
import torch import torch
import torch.nn as nn
from . import losses from . import losses
from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module
from utils import Odict from utils import Odict
from utils import get_msg_mgr from utils import get_msg_mgr
class LossAggregator(): class LossAggregator(nn.Module):
"""The loss aggregator. """The loss aggregator.
This class is used to aggregate the losses. This class is used to aggregate the losses.
@@ -18,16 +19,21 @@ class LossAggregator():
Attributes: Attributes:
losses: A dict of losses. losses: A dict of losses.
""" """
def __init__(self, loss_cfg) -> None: def __init__(self, loss_cfg) -> None:
""" """
Initialize the loss aggregator. Initialize the loss aggregator.
LossAggregator can be indexed like a regular Python dictionary,
but modules it contains are properly registered, and will be visible by all Module methods.
All parameters registered in losses can be accessed by the method 'self.parameters()',
thus they can be trained properly.
Args: Args:
loss_cfg: Config of losses. List for multiple losses. loss_cfg: Config of losses. List for multiple losses.
""" """
self.losses = {loss_cfg['log_prefix']: self._build_loss_(loss_cfg)} if is_dict(loss_cfg) \ super().__init__()
else {cfg['log_prefix']: self._build_loss_(cfg) for cfg in loss_cfg} self.losses = nn.ModuleDict({loss_cfg['log_prefix']: self._build_loss_(loss_cfg)} if is_dict(loss_cfg) \
else {cfg['log_prefix']: self._build_loss_(cfg) for cfg in loss_cfg})
def _build_loss_(self, loss_cfg): def _build_loss_(self, loss_cfg):
"""Build the losses from loss_cfg. """Build the losses from loss_cfg.
@@ -41,7 +47,7 @@ class LossAggregator():
loss = get_ddp_module(Loss(**valid_loss_arg).cuda()) loss = get_ddp_module(Loss(**valid_loss_arg).cuda())
return loss return loss
def __call__(self, training_feats): def forward(self, training_feats):
"""Compute the sum of all losses. """Compute the sum of all losses.
The input is a dict of features. The key is the name of loss and the value is the feature and label. If the key not in The input is a dict of features. The key is the name of loss and the value is the feature and label. If the key not in