bb6cd5149a
* fix bug in fix_BN * gaitgl OUMVLP support. * update ./doc/3.advance_usage.md Cross-Dataset Evalution & Data Agumentation * update config * update docs.3 * update docs.3 * add loss doc and gather input decorator * refine the create model doc * support rearrange directory of unzipped OUMVLP * fix some bugs in loss_aggregator.py * refine docs and little fix * add oumvlp pretreatment description * pretreatment dataset fix oumvlp description * add gaitgl oumvlp result * assert gaitgl input size * add pipeline * update the readme. * update pipeline and readme * Corrigendum. * add logo and remove path * update new logo * Update README.md * modify logo size Co-authored-by: 12131100 <12131100@mail.sustech.edu.cn> Co-authored-by: noahshen98 <77523610+noahshen98@users.noreply.github.com> Co-authored-by: Noah <595311942@qq.com>
55 lines
1.3 KiB
Python
55 lines
1.3 KiB
Python
from ctypes import ArgumentError
|
|
import torch.nn as nn
|
|
import torch
|
|
from utils import Odict
|
|
import functools
|
|
from utils import ddp_all_gather
|
|
|
|
|
|
def gather_and_scale_wrapper(func):
    """Internal wrapper: gather the inputs from multiple cards onto one card, and scale the loss by the number of cards.

    Every *keyword* argument passed to the wrapped callable is run through
    ``ddp_all_gather`` before ``func`` is invoked. Positional arguments
    (e.g. ``self`` when decorating a ``forward`` method) are forwarded
    untouched. The loss returned by ``func`` is multiplied by
    ``torch.distributed.get_world_size()``.

    Args:
        func: callable returning a ``(loss, loss_info)`` pair.

    Returns:
        The wrapped callable, returning ``(scaled_loss, loss_info)``.

    Raises:
        ArgumentError: if gathering the inputs, calling ``func`` or querying
            the world size fails. The original exception is chained as
            ``__cause__`` so the real error is not lost.
    """

    @functools.wraps(func)
    def inner(*args, **kwds):
        try:
            # Only keyword arguments are gathered, so callers must pass the
            # tensors that need gathering by name.
            gathered = {k: ddp_all_gather(v) for k, v in kwds.items()}
            loss, loss_info = func(*args, **gathered)
            # Out-of-place multiply: an in-place ``*=`` on an autograd tensor
            # fails at backward time if the op that produced ``loss`` saved
            # its output for the gradient computation.
            loss = loss * torch.distributed.get_world_size()
            return loss, loss_info
        except Exception as err:
            # Keep the historical exception type for any caller that catches
            # it, but attach a message and the real cause instead of
            # discarding it. KeyboardInterrupt/SystemExit are no longer caught.
            raise ArgumentError(
                f"gather_and_scale_wrapper: failed to gather inputs or compute "
                f"the scaled loss for {getattr(func, '__qualname__', func)!r}"
            ) from err

    return inner
|
|
|
|
|
|
class BaseLoss(nn.Module):
    """
    Base class for all losses.

    Your loss should also subclass this class and call ``super().__init__()``
    from its own ``__init__``.

    Attributes:
        loss_term_weights: the weight of the loss (class-level default 1.0).
        info: the loss info; a fresh ``Odict`` is created for every instance.
    """

    # An immutable float, so sharing the default at class level is harmless.
    loss_term_weights = 1.0

    def __init__(self):
        super().__init__()
        # Per-instance container. A class-level Odict would be one object
        # shared by every loss, so entries written by one loss (e.g. 'loss')
        # would leak into or overwrite another loss's reported info.
        self.info = Odict()

    def forward(self, logits, labels):
        """
        The default forward function.

        This function should be overridden by the subclass.

        Args:
            logits: the logits of the model.
            labels: the labels of the data.

        Returns:
            tuple of loss and info. The default returns ``0.0`` and this
            instance's (empty unless updated) ``info``.
        """
        return .0, self.info
|