1.0.0 official release (#18)
* fix bug in fix_BN * gaitgl OUMVLP support. * update ./doc/3.advance_usage.md Cross-Dataset Evalution & Data Agumentation * update config * update docs.3 * update docs.3 * add loss doc and gather input decorator * refine the create model doc * support rearrange directory of unzipped OUMVLP * fix some bugs in loss_aggregator.py * refine docs and little fix * add oumvlp pretreatment description * pretreatment dataset fix oumvlp description * add gaitgl oumvlp result * assert gaitgl input size * add pipeline * update the readme. * update pipeline and readme * Corrigendum. * add logo and remove path * update new logo * Update README.md * modify logo size Co-authored-by: 12131100 <12131100@mail.sustech.edu.cn> Co-authored-by: noahshen98 <77523610+noahshen98@users.noreply.github.com> Co-authored-by: Noah <595311942@qq.com>
This commit is contained in:
@@ -1,13 +1,54 @@
|
||||
from ctypes import ArgumentError
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from utils import Odict
|
||||
import functools
|
||||
from utils import ddp_all_gather
|
||||
|
||||
class BasicLoss(nn.Module):
|
||||
def __init__(self, loss_term_weights=1.0):
|
||||
super(BasicLoss, self).__init__()
|
||||
|
||||
self.loss_term_weights = loss_term_weights
|
||||
self.pair_based_loss = True
|
||||
self.info = Odict()
|
||||
|
||||
def gather_and_scale_wrapper(func):
|
||||
"""Internal wrapper: gather the input from multple cards to one card, and scale the loss by the number of cards.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def inner(*args, **kwds):
|
||||
try:
|
||||
|
||||
for k, v in kwds.items():
|
||||
kwds[k] = ddp_all_gather(v)
|
||||
|
||||
loss, loss_info = func(*args, **kwds)
|
||||
loss *= torch.distributed.get_world_size()
|
||||
return loss, loss_info
|
||||
except:
|
||||
raise ArgumentError
|
||||
return inner
|
||||
|
||||
|
||||
class BaseLoss(nn.Module):
|
||||
"""
|
||||
Base class for all losses.
|
||||
|
||||
Your loss should also subclass this class.
|
||||
|
||||
Attribute:
|
||||
loss_term_weights: the weight of the loss.
|
||||
info: the loss info.
|
||||
"""
|
||||
loss_term_weights = 1.0
|
||||
info = Odict()
|
||||
|
||||
def forward(self, logits, labels):
|
||||
raise NotImplementedError
|
||||
"""
|
||||
The default forward function.
|
||||
|
||||
This function should be overridden by the subclass.
|
||||
|
||||
Args:
|
||||
logits: the logits of the model.
|
||||
labels: the labels of the data.
|
||||
|
||||
Returns:
|
||||
tuple of loss and info.
|
||||
"""
|
||||
return .0, self.info
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base import BasicLoss
|
||||
from .base import BaseLoss
|
||||
|
||||
|
||||
class CrossEntropyLoss(BasicLoss):
|
||||
class CrossEntropyLoss(BaseLoss):
|
||||
def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weights=1.0, log_accuracy=False):
|
||||
super(CrossEntropyLoss, self).__init__()
|
||||
self.scale = scale
|
||||
@@ -13,7 +13,6 @@ class CrossEntropyLoss(BasicLoss):
|
||||
self.log_accuracy = log_accuracy
|
||||
|
||||
self.loss_term_weights = loss_term_weights
|
||||
self.pair_based_loss = False
|
||||
|
||||
def forward(self, logits, labels):
|
||||
"""
|
||||
@@ -26,7 +25,7 @@ class CrossEntropyLoss(BasicLoss):
|
||||
one_hot_labels = self.label2one_hot(
|
||||
labels, c).unsqueeze(0).repeat(p, 1, 1) # [p, n, c]
|
||||
loss = self.compute_loss(log_preds, one_hot_labels)
|
||||
self.info.update({'loss': loss})
|
||||
self.info.update({'loss': loss.detach().clone()})
|
||||
if self.log_accuracy:
|
||||
pred = logits.argmax(dim=-1) # [p, n]
|
||||
accu = (pred == labels.unsqueeze(0)).float().mean()
|
||||
|
||||
@@ -1,22 +1,19 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base import BasicLoss
|
||||
from utils import ddp_all_gather
|
||||
from .base import BaseLoss, gather_and_scale_wrapper
|
||||
|
||||
|
||||
class TripletLoss(BasicLoss):
|
||||
class TripletLoss(BaseLoss):
|
||||
def __init__(self, margin, loss_term_weights=1.0):
|
||||
super(TripletLoss, self).__init__()
|
||||
self.margin = margin
|
||||
|
||||
self.loss_term_weights = loss_term_weights
|
||||
self.pair_based_loss = True
|
||||
|
||||
@gather_and_scale_wrapper
|
||||
def forward(self, embeddings, labels):
|
||||
# embeddings: [n, p, c], label: [n]
|
||||
embeddings = ddp_all_gather(embeddings)
|
||||
labels = ddp_all_gather(labels)
|
||||
embeddings = embeddings.permute(
|
||||
1, 0, 2).contiguous() # [n, p, c] -> [p, n, c]
|
||||
embeddings = embeddings.float()
|
||||
@@ -32,10 +29,10 @@ class TripletLoss(BasicLoss):
|
||||
loss_avg, loss_num = self.AvgNonZeroReducer(loss)
|
||||
|
||||
self.info.update({
|
||||
'loss': loss_avg,
|
||||
'hard_loss': hard_loss,
|
||||
'loss_num': loss_num,
|
||||
'mean_dist': mean_dist})
|
||||
'loss': loss_avg.detach().clone(),
|
||||
'hard_loss': hard_loss.detach().clone(),
|
||||
'loss_num': loss_num.detach().clone(),
|
||||
'mean_dist': mean_dist.detach().clone()})
|
||||
|
||||
return loss_avg, self.info
|
||||
|
||||
|
||||
Reference in New Issue
Block a user