OpenGait release(pre-beta version).

This commit is contained in:
梁峻豪
2021-10-18 14:30:21 +08:00
commit 57ee4a448e
41 changed files with 3772 additions and 0 deletions
+17
View File
@@ -0,0 +1,17 @@
from inspect import isclass
from pkgutil import iter_modules
from pathlib import Path
from importlib import import_module
# iterate through the modules in the current package
package_dir = Path(__file__).resolve().parent
for (_, module_name, _) in iter_modules([package_dir]):
# import the module and iterate through its attributes
module = import_module(f"{__name__}.{module_name}")
for attribute_name in dir(module):
attribute = getattr(module, attribute_name)
if isclass(attribute):
# Add the class to this package's variables
globals()[attribute_name] = attribute
+13
View File
@@ -0,0 +1,13 @@
import torch.nn as nn
from utils import Odict
class BasicLoss(nn.Module):
def __init__(self, loss_term_weights=1.0):
super(BasicLoss, self).__init__()
self.loss_term_weights = loss_term_weights
self.pair_based_loss = True
self.info = Odict()
def forward(self, logits, labels):
raise NotImplementedError
+51
View File
@@ -0,0 +1,51 @@
import torch
import torch.nn.functional as F
from .base import BasicLoss
class CrossEntropyLoss(BasicLoss):
def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weights=1.0, log_accuracy=False):
super(CrossEntropyLoss, self).__init__()
self.scale = scale
self.label_smooth = label_smooth
self.eps = eps
self.log_accuracy = log_accuracy
self.loss_term_weights = loss_term_weights
self.pair_based_loss = False
def forward(self, logits, labels):
"""
logits: [n, p, c]
labels: [n]
"""
logits = logits.permute(1, 0, 2).contiguous() # [n, p, c] -> [p, n, c]
p, _, c = logits.size()
log_preds = F.log_softmax(logits * self.scale, dim=-1) # [p, n, c]
one_hot_labels = self.label2one_hot(
labels, c).unsqueeze(0).repeat(p, 1, 1) # [p, n, c]
loss = self.compute_loss(log_preds, one_hot_labels)
self.info.update({'loss': loss})
if self.log_accuracy:
pred = logits.argmax(dim=-1) # [p, n]
accu = (pred == labels.unsqueeze(0)).float().mean()
self.info.update({'accuracy': accu})
return loss, self.info
def compute_loss(self, predis, labels):
softmax_loss = -(labels * predis).sum(-1) # [p, n]
losses = softmax_loss.mean(-1)
if self.label_smooth:
smooth_loss = - predis.mean(dim=-1) # [p, n]
smooth_loss = smooth_loss.mean() # [p]
smooth_loss = smooth_loss * self.eps
losses = smooth_loss + losses * (1. - self.eps)
return losses
def label2one_hot(self, label, class_num):
label = label.unsqueeze(-1)
batch_size = label.size(0)
device = label.device
return torch.zeros(batch_size, class_num).to(device).scatter(1, label, 1)
+76
View File
@@ -0,0 +1,76 @@
import torch
import torch.nn.functional as F
from .base import BasicLoss
from utils import ddp_all_gather
class TripletLoss(BasicLoss):
def __init__(self, margin, loss_term_weights=1.0):
super(TripletLoss, self).__init__()
self.margin = margin
self.loss_term_weights = loss_term_weights
self.pair_based_loss = True
def forward(self, embeddings, labels):
# embeddings: [n, p, c], label: [n]
embeddings = ddp_all_gather(embeddings)
labels = ddp_all_gather(labels)
embeddings = embeddings.permute(
1, 0, 2).contiguous() # [n, p, c] -> [p, n, c]
embeddings = embeddings.float()
ref_embed, ref_label = embeddings, labels
dist = self.ComputeDistance(embeddings, ref_embed) # [p, n1, n2]
mean_dist = dist.mean(1).mean(1)
ap_dist, an_dist = self.Convert2Triplets(labels, ref_label, dist)
dist_diff = ap_dist - an_dist
loss = F.relu(dist_diff + self.margin)
hard_loss = torch.max(loss, -1)[0]
loss_avg, loss_num = self.AvgNonZeroReducer(loss)
self.info.update({
'loss': loss_avg,
'hard_loss': hard_loss,
'loss_num': loss_num,
'mean_dist': mean_dist})
return loss_avg, self.info
def AvgNonZeroReducer(self, loss):
eps = 1.0e-9
loss_sum = loss.sum(-1)
loss_num = (loss != 0).sum(-1).float()
loss_avg = loss_sum / (loss_num + eps)
loss_avg[loss_num == 0] = 0
return loss_avg, loss_num
def ComputeDistance(self, x, y):
"""
x: [p, n_x, c]
y: [p, n_y, c]
"""
x2 = torch.sum(x ** 2, -1).unsqueeze(2) # [p, n_x, 1]
y2 = torch.sum(y ** 2, -1).unsqueeze(1) # [p, 1, n_y]
inner = x.matmul(y.transpose(-1, -2)) # [p, n_x, n_y]
dist = x2 + y2 - 2 * inner
dist = torch.sqrt(F.relu(dist)) # [p, n_x, n_y]
return dist
def Convert2Triplets(self, row_labels, clo_label, dist):
"""
row_labels: tensor with size [n_r]
clo_label : tensor with size [n_c]
"""
matches = (row_labels.unsqueeze(1) ==
clo_label.unsqueeze(0)).byte() # [n_r, n_c]
diffenc = matches ^ 1 # [n_r, n_c]
mask = matches.unsqueeze(2) * diffenc.unsqueeze(1)
a_idx, p_idx, n_idx = torch.where(mask)
ap_dist = dist[:, a_idx, p_idx]
an_dist = dist[:, a_idx, n_idx]
return ap_dist, an_dist