OpenGait release(pre-beta version).

This commit is contained in:
梁峻豪
2021-10-18 14:30:21 +08:00
commit 57ee4a448e
41 changed files with 3772 additions and 0 deletions
+50
View File
@@ -0,0 +1,50 @@
import torch
from . import losses
from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module
from utils import Odict
from utils import get_msg_mgr
class LossAggregator():
def __init__(self, loss_cfg) -> None:
self.losses = {loss_cfg['log_prefix']: self._build_loss_(loss_cfg)} if is_dict(loss_cfg) \
else {cfg['log_prefix']: self._build_loss_(cfg) for cfg in loss_cfg}
def _build_loss_(self, loss_cfg):
Loss = get_attr_from([losses], loss_cfg['type'])
valid_loss_arg = get_valid_args(
Loss, loss_cfg, ['type', 'pair_based_loss'])
loss = get_ddp_module(Loss(**valid_loss_arg))
return loss
def __call__(self, training_feats):
loss_sum = .0
loss_info = Odict()
for k, v in training_feats.items():
if k in self.losses:
loss_func = self.losses[k]
loss, info = loss_func(**v)
for name, value in info.items():
loss_info['scalar/%s/%s' % (k, name)] = value
loss = loss.mean() * loss_func.loss_term_weights
if loss_func.pair_based_loss:
loss = loss * torch.distributed.get_world_size()
loss_sum += loss
else:
if isinstance(v, dict):
raise ValueError(
"The key %s in -Trainng-Feat- should be stated as the log_prefix of a certain loss defined in your loss_cfg."
)
elif is_tensor(v):
_ = v.mean()
loss_info['scalar/%s' % k] = _
loss_sum += _
get_msg_mgr().log_debug(
"Please check whether %s needed in training." % k)
else:
raise ValueError(
"Error type for -Trainng-Feat-, supported: A feature dict or loss tensor.")
return loss_sum, loss_info