Reconstruct LossAggregator and fix some typos in config files (#100)
* fix Gait3D configs typo * Use ModuleDict to reconstruct LossAggregator * fix typo
This commit is contained in:
@@ -83,7 +83,7 @@ trainer_cfg:
|
|||||||
enable_float16: true # half_percesion float for memory reduction and speedup
|
enable_float16: true # half_percesion float for memory reduction and speedup
|
||||||
fix_BN: false
|
fix_BN: false
|
||||||
log_iter: 100
|
log_iter: 100
|
||||||
with_test: 10000
|
with_test: true
|
||||||
restore_ckpt_strict: true
|
restore_ckpt_strict: true
|
||||||
restore_hint: 0
|
restore_hint: 0
|
||||||
save_iter: 10000
|
save_iter: 10000
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ trainer_cfg:
|
|||||||
enable_float16: true # half_percesion float for memory reduction and speedup
|
enable_float16: true # half_percesion float for memory reduction and speedup
|
||||||
fix_BN: false
|
fix_BN: false
|
||||||
log_iter: 100
|
log_iter: 100
|
||||||
with_test: 10000
|
with_test: true
|
||||||
restore_ckpt_strict: true
|
restore_ckpt_strict: true
|
||||||
restore_hint: 0
|
restore_hint: 0
|
||||||
save_iter: 10000
|
save_iter: 10000
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
"""The loss aggregator."""
|
"""The loss aggregator."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
from . import losses
|
from . import losses
|
||||||
from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module
|
from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module
|
||||||
from utils import Odict
|
from utils import Odict
|
||||||
from utils import get_msg_mgr
|
from utils import get_msg_mgr
|
||||||
|
|
||||||
|
|
||||||
class LossAggregator():
|
class LossAggregator(nn.Module):
|
||||||
"""The loss aggregator.
|
"""The loss aggregator.
|
||||||
|
|
||||||
This class is used to aggregate the losses.
|
This class is used to aggregate the losses.
|
||||||
@@ -18,16 +19,21 @@ class LossAggregator():
|
|||||||
Attributes:
|
Attributes:
|
||||||
losses: A dict of losses.
|
losses: A dict of losses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, loss_cfg) -> None:
|
def __init__(self, loss_cfg) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the loss aggregator.
|
Initialize the loss aggregator.
|
||||||
|
|
||||||
|
LossAggregator can be indexed like a regular Python dictionary,
|
||||||
|
but modules it contains are properly registered, and will be visible by all Module methods.
|
||||||
|
All parameters registered in losses can be accessed by the method 'self.parameters()',
|
||||||
|
thus they can be trained properly.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
loss_cfg: Config of losses. List for multiple losses.
|
loss_cfg: Config of losses. List for multiple losses.
|
||||||
"""
|
"""
|
||||||
self.losses = {loss_cfg['log_prefix']: self._build_loss_(loss_cfg)} if is_dict(loss_cfg) \
|
super().__init__()
|
||||||
else {cfg['log_prefix']: self._build_loss_(cfg) for cfg in loss_cfg}
|
self.losses = nn.ModuleDict({loss_cfg['log_prefix']: self._build_loss_(loss_cfg)} if is_dict(loss_cfg) \
|
||||||
|
else {cfg['log_prefix']: self._build_loss_(cfg) for cfg in loss_cfg})
|
||||||
|
|
||||||
def _build_loss_(self, loss_cfg):
|
def _build_loss_(self, loss_cfg):
|
||||||
"""Build the losses from loss_cfg.
|
"""Build the losses from loss_cfg.
|
||||||
@@ -41,7 +47,7 @@ class LossAggregator():
|
|||||||
loss = get_ddp_module(Loss(**valid_loss_arg).cuda())
|
loss = get_ddp_module(Loss(**valid_loss_arg).cuda())
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def __call__(self, training_feats):
|
def forward(self, training_feats):
|
||||||
"""Compute the sum of all losses.
|
"""Compute the sum of all losses.
|
||||||
|
|
||||||
The input is a dict of features. The key is the name of loss and the value is the feature and label. If the key not in
|
The input is a dict of features. The key is the name of loss and the value is the feature and label. If the key not in
|
||||||
|
|||||||
Reference in New Issue
Block a user