OpenGait release(pre-beta version).
This commit is contained in:
@@ -0,0 +1,17 @@
|
||||
from inspect import isclass
|
||||
from pkgutil import iter_modules
|
||||
from pathlib import Path
|
||||
from importlib import import_module
|
||||
|
||||
# iterate through the modules in the current package
|
||||
package_dir = Path(__file__).resolve().parent
|
||||
for (_, module_name, _) in iter_modules([package_dir]):
|
||||
|
||||
# import the module and iterate through its attributes
|
||||
module = import_module(f"{__name__}.{module_name}")
|
||||
for attribute_name in dir(module):
|
||||
attribute = getattr(module, attribute_name)
|
||||
|
||||
if isclass(attribute):
|
||||
# Add the class to this package's variables
|
||||
globals()[attribute_name] = attribute
|
||||
@@ -0,0 +1,13 @@
|
||||
import torch.nn as nn
|
||||
from utils import Odict
|
||||
|
||||
class BasicLoss(nn.Module):
|
||||
def __init__(self, loss_term_weights=1.0):
|
||||
super(BasicLoss, self).__init__()
|
||||
|
||||
self.loss_term_weights = loss_term_weights
|
||||
self.pair_based_loss = True
|
||||
self.info = Odict()
|
||||
|
||||
def forward(self, logits, labels):
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,51 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base import BasicLoss
|
||||
|
||||
|
||||
class CrossEntropyLoss(BasicLoss):
|
||||
def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weights=1.0, log_accuracy=False):
|
||||
super(CrossEntropyLoss, self).__init__()
|
||||
self.scale = scale
|
||||
self.label_smooth = label_smooth
|
||||
self.eps = eps
|
||||
self.log_accuracy = log_accuracy
|
||||
|
||||
self.loss_term_weights = loss_term_weights
|
||||
self.pair_based_loss = False
|
||||
|
||||
def forward(self, logits, labels):
|
||||
"""
|
||||
logits: [n, p, c]
|
||||
labels: [n]
|
||||
"""
|
||||
logits = logits.permute(1, 0, 2).contiguous() # [n, p, c] -> [p, n, c]
|
||||
p, _, c = logits.size()
|
||||
log_preds = F.log_softmax(logits * self.scale, dim=-1) # [p, n, c]
|
||||
one_hot_labels = self.label2one_hot(
|
||||
labels, c).unsqueeze(0).repeat(p, 1, 1) # [p, n, c]
|
||||
loss = self.compute_loss(log_preds, one_hot_labels)
|
||||
self.info.update({'loss': loss})
|
||||
if self.log_accuracy:
|
||||
pred = logits.argmax(dim=-1) # [p, n]
|
||||
accu = (pred == labels.unsqueeze(0)).float().mean()
|
||||
self.info.update({'accuracy': accu})
|
||||
return loss, self.info
|
||||
|
||||
def compute_loss(self, predis, labels):
|
||||
softmax_loss = -(labels * predis).sum(-1) # [p, n]
|
||||
losses = softmax_loss.mean(-1)
|
||||
|
||||
if self.label_smooth:
|
||||
smooth_loss = - predis.mean(dim=-1) # [p, n]
|
||||
smooth_loss = smooth_loss.mean() # [p]
|
||||
smooth_loss = smooth_loss * self.eps
|
||||
losses = smooth_loss + losses * (1. - self.eps)
|
||||
return losses
|
||||
|
||||
def label2one_hot(self, label, class_num):
|
||||
label = label.unsqueeze(-1)
|
||||
batch_size = label.size(0)
|
||||
device = label.device
|
||||
return torch.zeros(batch_size, class_num).to(device).scatter(1, label, 1)
|
||||
@@ -0,0 +1,76 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base import BasicLoss
|
||||
from utils import ddp_all_gather
|
||||
|
||||
|
||||
class TripletLoss(BasicLoss):
|
||||
def __init__(self, margin, loss_term_weights=1.0):
|
||||
super(TripletLoss, self).__init__()
|
||||
self.margin = margin
|
||||
|
||||
self.loss_term_weights = loss_term_weights
|
||||
self.pair_based_loss = True
|
||||
|
||||
def forward(self, embeddings, labels):
|
||||
# embeddings: [n, p, c], label: [n]
|
||||
embeddings = ddp_all_gather(embeddings)
|
||||
labels = ddp_all_gather(labels)
|
||||
embeddings = embeddings.permute(
|
||||
1, 0, 2).contiguous() # [n, p, c] -> [p, n, c]
|
||||
embeddings = embeddings.float()
|
||||
|
||||
ref_embed, ref_label = embeddings, labels
|
||||
dist = self.ComputeDistance(embeddings, ref_embed) # [p, n1, n2]
|
||||
mean_dist = dist.mean(1).mean(1)
|
||||
ap_dist, an_dist = self.Convert2Triplets(labels, ref_label, dist)
|
||||
dist_diff = ap_dist - an_dist
|
||||
loss = F.relu(dist_diff + self.margin)
|
||||
|
||||
hard_loss = torch.max(loss, -1)[0]
|
||||
loss_avg, loss_num = self.AvgNonZeroReducer(loss)
|
||||
|
||||
self.info.update({
|
||||
'loss': loss_avg,
|
||||
'hard_loss': hard_loss,
|
||||
'loss_num': loss_num,
|
||||
'mean_dist': mean_dist})
|
||||
|
||||
return loss_avg, self.info
|
||||
|
||||
def AvgNonZeroReducer(self, loss):
|
||||
eps = 1.0e-9
|
||||
loss_sum = loss.sum(-1)
|
||||
loss_num = (loss != 0).sum(-1).float()
|
||||
|
||||
loss_avg = loss_sum / (loss_num + eps)
|
||||
loss_avg[loss_num == 0] = 0
|
||||
return loss_avg, loss_num
|
||||
|
||||
def ComputeDistance(self, x, y):
|
||||
"""
|
||||
x: [p, n_x, c]
|
||||
y: [p, n_y, c]
|
||||
"""
|
||||
x2 = torch.sum(x ** 2, -1).unsqueeze(2) # [p, n_x, 1]
|
||||
y2 = torch.sum(y ** 2, -1).unsqueeze(1) # [p, 1, n_y]
|
||||
inner = x.matmul(y.transpose(-1, -2)) # [p, n_x, n_y]
|
||||
dist = x2 + y2 - 2 * inner
|
||||
dist = torch.sqrt(F.relu(dist)) # [p, n_x, n_y]
|
||||
return dist
|
||||
|
||||
def Convert2Triplets(self, row_labels, clo_label, dist):
|
||||
"""
|
||||
row_labels: tensor with size [n_r]
|
||||
clo_label : tensor with size [n_c]
|
||||
"""
|
||||
matches = (row_labels.unsqueeze(1) ==
|
||||
clo_label.unsqueeze(0)).byte() # [n_r, n_c]
|
||||
diffenc = matches ^ 1 # [n_r, n_c]
|
||||
mask = matches.unsqueeze(2) * diffenc.unsqueeze(1)
|
||||
a_idx, p_idx, n_idx = torch.where(mask)
|
||||
|
||||
ap_dist = dist[:, a_idx, p_idx]
|
||||
an_dist = dist[:, a_idx, n_idx]
|
||||
return ap_dist, an_dist
|
||||
Reference in New Issue
Block a user