rename lib to opengait

This commit is contained in:
darkliang
2022-04-12 11:28:09 +08:00
parent ecab4bf380
commit 213b3a658f
33 changed files with 39 additions and 39 deletions
+17
View File
@@ -0,0 +1,17 @@
from inspect import isclass
from pkgutil import iter_modules
from pathlib import Path
from importlib import import_module

# Auto-import every class defined in this package's modules so callers can
# look them up directly on the package (e.g. via get_attr_from).
package_dir = Path(__file__).resolve().parent
# NOTE: pkgutil.iter_modules() documents *string* path entries; passing a
# pathlib.Path breaks the import machinery on some Python versions, so
# convert explicitly.
for (_, module_name, _) in iter_modules([str(package_dir)]):
    # Import the module and iterate through its attributes.
    module = import_module(f"{__name__}.{module_name}")
    for attribute_name in dir(module):
        attribute = getattr(module, attribute_name)
        if isclass(attribute):
            # Add the class to this package's namespace.
            globals()[attribute_name] = attribute
+63
View File
@@ -0,0 +1,63 @@
"""The plain backbone.
The plain backbone only contains the BasicConv2d, FocalConv2d and MaxPool2d and LeakyReLU layers.
"""
import torch.nn as nn
from ..modules import BasicConv2d, FocalConv2d
class Plain(nn.Module):
    """Plain CNN backbone assembled from a compact layer-spec list.

    Each convolution is implicitly followed by a LeakyReLU; pooling layers
    are not. The first convolution uses kernel 5 / stride 1 / padding 2, all
    later convolutions use kernel 3 / stride 1 / padding 1.

    Spec tokens (strings in ``layers_cfg``):
        - ``BC-64``:    BasicConv2d with 64 output channels; the input channel
          count is taken from the previous layer.
        - ``M``:        nn.MaxPool2d(kernel_size=2, stride=2).
        - ``FC-128-1``: FocalConv2d with 128 output channels and halving 1
          (the feature map is divided into 2^1 = 2 parts).

    Use it in your configuration file.
    """

    def __init__(self, layers_cfg, in_channels=1):
        super(Plain, self).__init__()
        self.layers_cfg = layers_cfg
        self.in_channels = in_channels
        self.feature = self.make_layers()

    def forward(self, seqs):
        return self.feature(seqs)

    def make_layers(self):
        """Build the nn.Sequential feature extractor from ``self.layers_cfg``.

        Reference: torchvision/models/vgg.py
        """
        def build_conv(token, in_c, kernel_size, stride, padding):
            # Parse a "BC-<out>" or "FC-<out>-<halving>" token into a layer.
            parts = token.split('-')
            kind = parts[0]
            if kind not in ['BC', 'FC']:
                raise ValueError('Only support BC or FC, but got {}'.format(kind))
            out_c = int(parts[1])
            if kind == 'BC':
                return BasicConv2d(in_c, out_c, kernel_size=kernel_size,
                                   stride=stride, padding=padding)
            return FocalConv2d(in_c, out_c, kernel_size=kernel_size,
                               stride=stride, padding=padding,
                               halving=int(parts[2]))

        head = self.layers_cfg[0]
        layers = [build_conv(head, self.in_channels, 5, 1, 2),
                  nn.LeakyReLU(inplace=True)]
        in_c = int(head.split('-')[1])
        for token in self.layers_cfg[1:]:
            if token == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers.extend([build_conv(token, in_c, 3, 1, 1),
                               nn.LeakyReLU(inplace=True)])
                in_c = int(token.split('-')[1])
        return nn.Sequential(*layers)
+462
View File
@@ -0,0 +1,462 @@
"""The base model definition.
This module defines the abstract meta model class and base model class. In the base model,
we define the basic model functions, like get_loader, build_network, and run_train, etc.
The api of the base model is run_train and run_test, they are used in `opengait/main.py`.
Typical usage:
BaseModel.run_train(model)
BaseModel.run_test(model)
"""
import torch
import numpy as np
import os.path as osp
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as tordata
from tqdm import tqdm
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
from abc import ABCMeta
from abc import abstractmethod
from . import backbones
from .loss_aggregator import LossAggregator
from data.transform import get_transform
from data.collate_fn import CollateFn
from data.dataset import DataSet
import data.sampler as Samplers
from utils import Odict, mkdir, ddp_all_gather
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from utils import evaluation as eval_functions
from utils import NoOp
from utils import get_msg_mgr
__all__ = ['BaseModel']
class MetaModel(metaclass=ABCMeta):
    """The necessary functions for the base model.

    This class defines the abstract interface that the base model must
    provide; the BaseModel class implements every method declared here.
    """
    @abstractmethod
    def get_loader(self, data_cfg):
        """Based on the given data_cfg, build and return the data loader."""
        raise NotImplementedError

    @abstractmethod
    def build_network(self, model_cfg):
        """Build your network here."""
        raise NotImplementedError

    @abstractmethod
    def init_parameters(self):
        """Initialize the parameters of your network."""
        raise NotImplementedError

    @abstractmethod
    def get_optimizer(self, optimizer_cfg):
        """Based on the given optimizer_cfg, build and return the optimizer."""
        raise NotImplementedError

    @abstractmethod
    def get_scheduler(self, scheduler_cfg):
        """Based on the given scheduler_cfg, build and return the scheduler."""
        raise NotImplementedError

    @abstractmethod
    def save_ckpt(self, iteration):
        """Save the checkpoint, including model parameter, optimizer and scheduler."""
        raise NotImplementedError

    @abstractmethod
    def resume_ckpt(self, restore_hint):
        """Resume the model from the checkpoint, including model parameter, optimizer and scheduler."""
        raise NotImplementedError

    @abstractmethod
    def inputs_pretreament(self, inputs):
        """Transform the input data based on the transform setting."""
        # NOTE: "pretreament" is a historical misspelling kept for API stability.
        raise NotImplementedError

    @abstractmethod
    def train_step(self, loss_num) -> bool:
        """Do one training step (backward pass, optimizer and scheduler update)."""
        raise NotImplementedError

    @abstractmethod
    def inference(self):
        """Do inference (calculate features.)."""
        raise NotImplementedError

    @abstractmethod
    def run_train(model):
        """Run a whole train schedule (implemented as a staticmethod in BaseModel)."""
        raise NotImplementedError

    @abstractmethod
    def run_test(model):
        """Run a whole test schedule (implemented as a staticmethod in BaseModel)."""
        raise NotImplementedError
class BaseModel(MetaModel, nn.Module):
    """Base model.

    This class inherits the MetaModel class and implements the basic model
    functions, like get_loader, build_network, etc.

    Attributes:
        msg_mgr: the message manager.
        cfgs: the configs.
        iteration: the current iteration of the model.
        engine_cfg: the configs of the engine (train or test).
        save_path: the path to save the checkpoints.
    """

    def __init__(self, cfgs, training):
        """Initialize the base model.

        Complete the model initialization, including the data loader, the
        network, the optimizer, the scheduler, and the loss.

        Args:
            cfgs:
                All of the configs.
            training:
                Whether the model is in training mode.
        """
        super(BaseModel, self).__init__()
        self.msg_mgr = get_msg_mgr()
        self.cfgs = cfgs
        self.iteration = 0
        self.engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg']
        if self.engine_cfg is None:
            raise Exception("Initialize a model without -Engine-Cfgs-")

        # Mixed-precision training needs a gradient scaler to avoid underflow.
        if training and self.engine_cfg['enable_float16']:
            self.Scaler = GradScaler()
        self.save_path = osp.join('output/', cfgs['data_cfg']['dataset_name'],
                                  cfgs['model_cfg']['model'], self.engine_cfg['save_name'])

        self.build_network(cfgs['model_cfg'])
        self.init_parameters()

        self.msg_mgr.log_info(cfgs['data_cfg'])
        if training:
            self.train_loader = self.get_loader(
                cfgs['data_cfg'], train=True)
        if not training or self.engine_cfg['with_test']:
            self.test_loader = self.get_loader(
                cfgs['data_cfg'], train=False)

        # One process per GPU under DDP: the rank doubles as the device index.
        self.device = torch.distributed.get_rank()
        torch.cuda.set_device(self.device)
        self.to(device=torch.device(
            "cuda", self.device))

        if training:
            self.loss_aggregator = LossAggregator(cfgs['loss_cfg'])
            self.optimizer = self.get_optimizer(self.cfgs['optimizer_cfg'])
            self.scheduler = self.get_scheduler(cfgs['scheduler_cfg'])
        self.train(training)
        restore_hint = self.engine_cfg['restore_hint']
        if restore_hint != 0:
            self.resume_ckpt(restore_hint)
        if training:
            if cfgs['trainer_cfg']['fix_BN']:
                self.fix_BN()

    def get_backbone(self, backbone_cfg):
        """Get the backbone of the model.

        Args:
            backbone_cfg: a dict describing one backbone, or a list of such
                dicts (producing an nn.ModuleList of backbones).
        Returns:
            The instantiated backbone module(s).
        Raises:
            ValueError: if backbone_cfg is neither a dict nor a list.
        """
        if is_dict(backbone_cfg):
            Backbone = get_attr_from([backbones], backbone_cfg['type'])
            valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
            return Backbone(**valid_args)
        if is_list(backbone_cfg):
            Backbone = nn.ModuleList([self.get_backbone(cfg)
                                      for cfg in backbone_cfg])
            return Backbone
        raise ValueError(
            "Error type for -Backbone-Cfg-, supported: (A list of) dict.")

    def build_network(self, model_cfg):
        """Default network builder: instantiate the configured backbone, if any."""
        if 'backbone_cfg' in model_cfg.keys():
            self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])

    def init_parameters(self):
        """Initialize conv/linear layers with Xavier and affine BN layers with N(1.0, 0.02)."""
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.Conv1d)):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)):
                if m.affine:
                    nn.init.normal_(m.weight.data, 1.0, 0.02)
                    nn.init.constant_(m.bias.data, 0.0)

    def get_loader(self, data_cfg, train=True):
        """Build the DataLoader for training or evaluation.

        Args:
            data_cfg: the data config.
            train: True for the training loader, False for the test loader.
        Returns:
            A torch DataLoader driven by the configured batch sampler.
        """
        sampler_cfg = self.cfgs['trainer_cfg']['sampler'] if train else self.cfgs['evaluator_cfg']['sampler']
        dataset = DataSet(data_cfg, train)

        Sampler = get_attr_from([Samplers], sampler_cfg['type'])
        valid_args = get_valid_args(Sampler, sampler_cfg, free_keys=[
            'sample_type', 'type'])
        sampler = Sampler(dataset, **valid_args)

        loader = tordata.DataLoader(
            dataset=dataset,
            batch_sampler=sampler,
            collate_fn=CollateFn(dataset.label_set, sampler_cfg),
            num_workers=data_cfg['num_workers'])
        return loader

    def get_optimizer(self, optimizer_cfg):
        """Instantiate the torch optimizer named by optimizer_cfg['solver']
        over the trainable parameters only."""
        self.msg_mgr.log_info(optimizer_cfg)
        optimizer = get_attr_from([optim], optimizer_cfg['solver'])
        valid_arg = get_valid_args(optimizer, optimizer_cfg, ['solver'])
        optimizer = optimizer(
            filter(lambda p: p.requires_grad, self.parameters()), **valid_arg)
        return optimizer

    def get_scheduler(self, scheduler_cfg):
        """Instantiate the LR scheduler named by scheduler_cfg['scheduler']."""
        self.msg_mgr.log_info(scheduler_cfg)
        Scheduler = get_attr_from(
            [optim.lr_scheduler], scheduler_cfg['scheduler'])
        valid_arg = get_valid_args(Scheduler, scheduler_cfg, ['scheduler'])
        scheduler = Scheduler(self.optimizer, **valid_arg)
        return scheduler

    def save_ckpt(self, iteration):
        """Save model/optimizer/scheduler state; only rank 0 writes to disk."""
        if torch.distributed.get_rank() == 0:
            mkdir(osp.join(self.save_path, "checkpoints/"))
            save_name = self.engine_cfg['save_name']
            checkpoint = {
                'model': self.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict(),
                'iteration': iteration}
            torch.save(checkpoint,
                       osp.join(self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, iteration)))

    def _load_ckpt(self, save_name):
        """Load a checkpoint file and restore the model state.

        When training, the optimizer/scheduler states are restored too unless
        the corresponding *_reset flags are set in engine_cfg.

        Args:
            save_name: path of the checkpoint file.
        """
        load_ckpt_strict = self.engine_cfg['restore_ckpt_strict']

        checkpoint = torch.load(save_name, map_location=torch.device(
            "cuda", self.device))
        model_state_dict = checkpoint['model']

        if not load_ckpt_strict:
            self.msg_mgr.log_info("-------- Restored Params List --------")
            self.msg_mgr.log_info(sorted(set(model_state_dict.keys()).intersection(
                set(self.state_dict().keys()))))

        self.load_state_dict(model_state_dict, strict=load_ckpt_strict)
        if self.training:
            if not self.engine_cfg["optimizer_reset"] and 'optimizer' in checkpoint:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                self.msg_mgr.log_warning(
                    "Restore NO Optimizer from %s !!!" % save_name)
            if not self.engine_cfg["scheduler_reset"] and 'scheduler' in checkpoint:
                self.scheduler.load_state_dict(
                    checkpoint['scheduler'])
            else:
                self.msg_mgr.log_warning(
                    "Restore NO Scheduler from %s !!!" % save_name)
        self.msg_mgr.log_info("Restore Parameters from %s !!!" % save_name)

    def resume_ckpt(self, restore_hint):
        """Resume from a checkpoint.

        Args:
            restore_hint: an int iteration number (resolved to a file under
                save_path) or an explicit checkpoint path string.
        Raises:
            ValueError: if restore_hint is neither int nor str.
        """
        if isinstance(restore_hint, int):
            save_name = self.engine_cfg['save_name']
            save_name = osp.join(
                self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, restore_hint))
            self.iteration = restore_hint
        elif isinstance(restore_hint, str):
            save_name = restore_hint
            self.iteration = 0
        else:
            raise ValueError(
                "Error type for -Restore_Hint-, supported: int or string.")
        self._load_ckpt(save_name)

    def fix_BN(self):
        """Freeze all BatchNorm layers (keep running stats fixed during training)."""
        for module in self.modules():
            classname = module.__class__.__name__
            if classname.find('BatchNorm') != -1:
                module.eval()

    def inputs_pretreament(self, inputs):
        """Conduct transforms on input data.

        Args:
            inputs: the input data.
        Returns:
            tuple: training data including inputs, labels, and some meta data.
        """
        seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
        trf_cfgs = self.engine_cfg['transform']
        seq_trfs = get_transform(trf_cfgs)

        # Gradients only flow through the inputs while training.
        requires_grad = bool(self.training)
        seqs = [np2var(np.asarray([trf(fra) for fra in seq]), requires_grad=requires_grad).float()
                for trf, seq in zip(seq_trfs, seqs_batch)]

        typs = typs_batch
        vies = vies_batch

        labs = list2var(labs_batch).long()

        if seqL_batch is not None:
            seqL_batch = np2var(seqL_batch).int()
        seqL = seqL_batch

        if seqL is not None:
            # Trim the padded time axis to the total number of valid frames.
            seqL_sum = int(seqL.sum().data.cpu().numpy())
            ipts = [_[:, :seqL_sum] for _ in seqs]
        else:
            ipts = seqs
        del seqs
        return ipts, labs, typs, vies, seqL

    def train_step(self, loss_sum) -> bool:
        """Conduct loss_sum.backward(), self.optimizer.step() and self.scheduler.step().

        Args:
            loss_sum: The loss of the current batch.
        Returns:
            bool: True if the optimizer actually stepped; False if the step
            was skipped (float16 overflow detected by the GradScaler).
        """
        self.optimizer.zero_grad()
        if loss_sum <= 1e-9:
            self.msg_mgr.log_warning(
                "Find the loss sum less than 1e-9 but the training process will continue!")

        if self.engine_cfg['enable_float16']:
            self.Scaler.scale(loss_sum).backward()
            self.Scaler.step(self.optimizer)
            scale = self.Scaler.get_scale()
            self.Scaler.update()
            # Warning caused by optimizer skip when NaN
            # https://discuss.pytorch.org/t/optimizer-step-before-lr-scheduler-step-error-using-gradscaler/92930/5
            if scale != self.Scaler.get_scale():
                self.msg_mgr.log_debug("Training step skip. Expected the former scale equals to the present, got {} and {}".format(
                    scale, self.Scaler.get_scale()))
                return False
        else:
            loss_sum.backward()
            self.optimizer.step()

        self.iteration += 1
        self.scheduler.step()
        return True

    def inference(self, rank):
        """Inference all the test data.

        Args:
            rank: the rank of the current process.
        Returns:
            Odict: contains the inference results.
        """
        total_size = len(self.test_loader)
        if rank == 0:
            pbar = tqdm(total=total_size, desc='Transforming')
        else:
            pbar = NoOp()
        batch_size = self.test_loader.batch_sampler.batch_size
        rest_size = total_size
        info_dict = Odict()
        for inputs in self.test_loader:
            ipts = self.inputs_pretreament(inputs)
            with autocast(enabled=self.engine_cfg['enable_float16']):
                retval = self.forward(ipts)
                inference_feat = retval['inference_feat']
                for k, v in inference_feat.items():
                    # Collect the features of all DDP processes on each card.
                    inference_feat[k] = ddp_all_gather(v, requires_grad=False)
                del retval
            for k, v in inference_feat.items():
                inference_feat[k] = ts2np(v)
            info_dict.append(inference_feat)
            rest_size -= batch_size
            if rest_size >= 0:
                update_size = batch_size
            else:
                # The last batch may be smaller than batch_size.
                update_size = total_size % batch_size
            pbar.update(update_size)
        pbar.close()
        for k, v in info_dict.items():
            # ddp_all_gather may append duplicated samples; keep total_size only.
            v = np.concatenate(v)[:total_size]
            info_dict[k] = v
        return info_dict

    @staticmethod
    def run_train(model):
        """Accept the instance object (model) here, and then run the train loop."""
        for inputs in model.train_loader:
            ipts = model.inputs_pretreament(inputs)
            with autocast(enabled=model.engine_cfg['enable_float16']):
                retval = model(ipts)
                training_feat, visual_summary = retval['training_feat'], retval['visual_summary']
                del retval
            loss_sum, loss_info = model.loss_aggregator(training_feat)
            ok = model.train_step(loss_sum)
            if not ok:
                # The float16 step was skipped: do not log/save for this batch.
                continue

            visual_summary.update(loss_info)
            visual_summary['scalar/learning_rate'] = model.optimizer.param_groups[0]['lr']

            model.msg_mgr.train_step(loss_info, visual_summary)
            if model.iteration % model.engine_cfg['save_iter'] == 0:
                # save the checkpoint
                model.save_ckpt(model.iteration)

                # run test if with_test = true
                if model.engine_cfg['with_test']:
                    model.msg_mgr.log_info("Running test...")
                    model.eval()
                    result_dict = BaseModel.run_test(model)
                    model.train()
                    model.msg_mgr.write_to_tensorboard(result_dict)
                    model.msg_mgr.reset_time()
            if model.iteration >= model.engine_cfg['total_iter']:
                break

    @staticmethod
    def run_test(model):
        """Accept the instance object (model) here, and then run the test loop.

        Only rank 0 evaluates and returns a result dict; other ranks return None.
        """
        rank = torch.distributed.get_rank()
        with torch.no_grad():
            info_dict = model.inference(rank)
        if rank == 0:
            loader = model.test_loader
            label_list = loader.dataset.label_list
            types_list = loader.dataset.types_list
            views_list = loader.dataset.views_list

            info_dict.update({
                'labels': label_list, 'types': types_list, 'views': views_list})

            if 'eval_func' in model.cfgs["evaluator_cfg"].keys():
                eval_func = model.cfgs['evaluator_cfg']["eval_func"]
            else:
                eval_func = 'identification'
            eval_func = getattr(eval_functions, eval_func)
            valid_args = get_valid_args(
                eval_func, model.cfgs["evaluator_cfg"], ['metric'])
            try:
                dataset_name = model.cfgs['data_cfg']['test_dataset_name']
            except KeyError:
                # BUG FIX: was a bare `except:`; only a missing key should
                # trigger the fallback to the training dataset name.
                dataset_name = model.cfgs['data_cfg']['dataset_name']
            return eval_func(info_dict, dataset_name, **valid_args)
+80
View File
@@ -0,0 +1,80 @@
"""The loss aggregator."""
import torch
from . import losses
from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module
from utils import Odict
from utils import get_msg_mgr
class LossAggregator():
    """The loss aggregator.

    This class is used to aggregate the losses.
    For example, if you have two losses, one is triplet loss, the other is
    cross entropy loss, you can aggregate them as follows:
        loss_sum = triplet_loss + cross_entropy_loss

    Attributes:
        losses: A dict of loss modules, keyed by each loss's ``log_prefix``.
    """

    def __init__(self, loss_cfg) -> None:
        """Initialize the loss aggregator.

        Args:
            loss_cfg: Config of losses. List for multiple losses.
        """
        self.losses = {loss_cfg['log_prefix']: self._build_loss_(loss_cfg)} if is_dict(loss_cfg) \
            else {cfg['log_prefix']: self._build_loss_(cfg) for cfg in loss_cfg}

    def _build_loss_(self, loss_cfg):
        """Build a single loss module from loss_cfg.

        Args:
            loss_cfg: Config of one loss.
        Returns:
            The instantiated loss, moved to CUDA and DDP-wrapped.
        """
        Loss = get_attr_from([losses], loss_cfg['type'])
        valid_loss_arg = get_valid_args(
            Loss, loss_cfg, ['type', 'gather_and_scale'])
        loss = get_ddp_module(Loss(**valid_loss_arg).cuda())
        return loss

    def __call__(self, training_feats):
        """Compute the weighted sum of all losses.

        The input is a dict of features. The key is the name of a loss and
        the value holds its inputs (features and labels). If a key is not
        among the built losses and its value is a torch.Tensor, it is treated
        as an already-computed loss and added to loss_sum directly.

        Args:
            training_feats: A dict of features. The same as the output["training_feat"] of the model.
        Returns:
            tuple of (loss_sum, loss_info).
        Raises:
            ValueError: if a dict value has no matching loss, or a value is
                neither a feature dict nor a tensor.
        """
        loss_sum = .0
        loss_info = Odict()
        for k, v in training_feats.items():
            if k in self.losses:
                loss_func = self.losses[k]
                loss, info = loss_func(**v)
                for name, value in info.items():
                    loss_info['scalar/%s/%s' % (k, name)] = value
                loss = loss.mean() * loss_func.loss_term_weight
                loss_sum += loss
            else:
                if isinstance(v, dict):
                    # BUG FIX: report the offending key (k); the original
                    # formatted the value dict (v) into a "The key %s" message.
                    raise ValueError(
                        "The key %s in -Training-Feat- should be stated as the log_prefix of a certain loss defined in your loss_cfg." % k
                    )
                elif is_tensor(v):
                    _ = v.mean()
                    loss_info['scalar/%s' % k] = _
                    loss_sum += _
                    get_msg_mgr().log_debug(
                        "Please check whether %s needed in training." % k)
                else:
                    raise ValueError(
                        "Error type for -Training-Feat-, supported: A feature dict or loss tensor.")
        return loss_sum, loss_info
+17
View File
@@ -0,0 +1,17 @@
from inspect import isclass
from pkgutil import iter_modules
from pathlib import Path
from importlib import import_module

# Auto-import every class defined in this package's modules so callers can
# look them up directly on the package (e.g. via get_attr_from).
package_dir = Path(__file__).resolve().parent
# NOTE: pkgutil.iter_modules() documents *string* path entries; passing a
# pathlib.Path breaks the import machinery on some Python versions, so
# convert explicitly.
for (_, module_name, _) in iter_modules([str(package_dir)]):
    # Import the module and iterate through its attributes.
    module = import_module(f"{__name__}.{module_name}")
    for attribute_name in dir(module):
        attribute = getattr(module, attribute_name)
        if isclass(attribute):
            # Add the class to this package's namespace.
            globals()[attribute_name] = attribute
+59
View File
@@ -0,0 +1,59 @@
from ctypes import ArgumentError
import torch.nn as nn
import torch
from utils import Odict
import functools
from utils import ddp_all_gather
def gather_and_scale_wrapper(func):
    """Internal wrapper: gather the input from multiple cards to one card, and
    scale the loss by the number of cards.

    Args:
        func: a loss forward function whose keyword arguments are tensors to
            be all-gathered across DDP processes.
    Returns:
        The wrapped function returning (scaled_loss, loss_info).
    Raises:
        ctypes.ArgumentError: if gathering or the wrapped call fails (kept for
            backward compatibility; the original exception is chained).
    """
    @functools.wraps(func)
    def inner(*args, **kwds):
        try:
            for k, v in kwds.items():
                kwds[k] = ddp_all_gather(v)

            loss, loss_info = func(*args, **kwds)
            # Each process sees the full gathered batch, so DDP's gradient
            # averaging must be compensated by the world size.
            loss *= torch.distributed.get_world_size()
            return loss, loss_info
        except Exception as err:
            # BUG FIX: was a bare `except:` that discarded the original error
            # (and swallowed KeyboardInterrupt). Chain the cause instead.
            raise ArgumentError from err
    return inner
class BaseLoss(nn.Module):
    """Common base class for all loss modules.

    Your loss should subclass this and override :meth:`forward`.

    Attributes:
        loss_term_weight: multiplicative weight of this loss term.
        info: ordered dict collecting scalars to be logged for this loss.
    """

    def __init__(self, loss_term_weight=1.0):
        """Store the term weight and create the (initially empty) log dict.

        Args:
            loss_term_weight: the weight of the loss term.
        """
        super(BaseLoss, self).__init__()
        self.loss_term_weight = loss_term_weight
        self.info = Odict()

    def forward(self, logits, labels):
        """Default forward: a zero loss. Subclasses are expected to override.

        Args:
            logits: the logits of the model.
            labels: the labels of the data.
        Returns:
            tuple of (loss value, info dict).
        """
        return 0.0, self.info
+48
View File
@@ -0,0 +1,48 @@
import torch
import torch.nn.functional as F
from .base import BaseLoss
class CrossEntropyLoss(BaseLoss):
    """Part-wise cross-entropy loss with optional label smoothing.

    Logits are provided per horizontal part ([n, p, c]); the loss is averaged
    over samples (and, via the aggregator's .mean(), over parts).
    """

    def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weight=1.0, log_accuracy=False):
        """
        Args:
            scale: multiplier applied to logits before log-softmax.
            label_smooth: whether to apply label smoothing.
            eps: smoothing strength (weight of the uniform term).
            loss_term_weight: weight of this loss term in the aggregator.
            log_accuracy: if True, also record training accuracy in self.info.
        """
        super(CrossEntropyLoss, self).__init__(loss_term_weight)
        self.scale = scale
        self.label_smooth = label_smooth
        self.eps = eps
        self.log_accuracy = log_accuracy

    def forward(self, logits, labels):
        """
        logits: [n, p, c]
        labels: [n]
        """
        logits = logits.permute(1, 0, 2).contiguous()  # [n, p, c] -> [p, n, c]
        p, _, c = logits.size()
        log_preds = F.log_softmax(logits * self.scale, dim=-1)  # [p, n, c]
        # Broadcast the one-hot targets to every part.
        one_hot_labels = self.label2one_hot(
            labels, c).unsqueeze(0).repeat(p, 1, 1)  # [p, n, c]
        loss = self.compute_loss(log_preds, one_hot_labels)
        self.info.update({'loss': loss.detach().clone()})
        if self.log_accuracy:
            pred = logits.argmax(dim=-1)  # [p, n]
            accu = (pred == labels.unsqueeze(0)).float().mean()
            self.info.update({'accuracy': accu})
        return loss, self.info

    def compute_loss(self, predis, labels):
        """Per-part NLL with optional label smoothing; predis/labels: [p, n, c]."""
        softmax_loss = -(labels * predis).sum(-1)  # [p, n]
        losses = softmax_loss.mean(-1)  # [p]
        if self.label_smooth:
            # Uniform-target term of label smoothing.
            smooth_loss = - predis.mean(dim=-1)  # [p, n]
            # NOTE: .mean() collapses to a scalar (not [p] as the original
            # comment claimed); it then broadcasts over the per-part losses.
            smooth_loss = smooth_loss.mean()
            smooth_loss = smooth_loss * self.eps
            losses = smooth_loss + losses * (1. - self.eps)
        return losses

    def label2one_hot(self, label, class_num):
        """[n] integer labels -> [n, class_num] one-hot on the label's device."""
        label = label.unsqueeze(-1)
        batch_size = label.size(0)
        device = label.device
        return torch.zeros(batch_size, class_num).to(device).scatter(1, label, 1)
+71
View File
@@ -0,0 +1,71 @@
import torch
import torch.nn.functional as F
from .base import BaseLoss, gather_and_scale_wrapper
class TripletLoss(BaseLoss):
    """Batch-all triplet loss, computed independently per horizontal part."""

    def __init__(self, margin, loss_term_weight=1.0):
        """
        Args:
            margin: margin added to (d(anchor, pos) - d(anchor, neg)).
            loss_term_weight: weight of this loss term in the aggregator.
        """
        super(TripletLoss, self).__init__(loss_term_weight)
        self.margin = margin

    @gather_and_scale_wrapper
    def forward(self, embeddings, labels):
        # embeddings: [n, p, c], label: [n]
        embeddings = embeddings.permute(
            1, 0, 2).contiguous()  # [n, p, c] -> [p, n, c]
        embeddings = embeddings.float()

        # The batch is compared against itself (reference = batch).
        ref_embed, ref_label = embeddings, labels
        dist = self.ComputeDistance(embeddings, ref_embed)  # [p, n1, n2]
        mean_dist = dist.mean(1).mean(1)  # [p]
        ap_dist, an_dist = self.Convert2Triplets(labels, ref_label, dist)
        dist_diff = ap_dist - an_dist  # [p, n_triplets]
        loss = F.relu(dist_diff + self.margin)

        hard_loss = torch.max(loss, -1)[0]  # hardest triplet per part
        loss_avg, loss_num = self.AvgNonZeroReducer(loss)

        self.info.update({
            'loss': loss_avg.detach().clone(),
            'hard_loss': hard_loss.detach().clone(),
            'loss_num': loss_num.detach().clone(),
            'mean_dist': mean_dist.detach().clone()})

        return loss_avg, self.info

    def AvgNonZeroReducer(self, loss):
        """Average only the non-zero (active) triplets per part; loss: [p, t]."""
        eps = 1.0e-9
        loss_sum = loss.sum(-1)
        loss_num = (loss != 0).sum(-1).float()

        loss_avg = loss_sum / (loss_num + eps)
        # Avoid the eps-biased value when no triplet is active.
        loss_avg[loss_num == 0] = 0
        return loss_avg, loss_num

    def ComputeDistance(self, x, y):
        """Pairwise Euclidean distances per part.

        x: [p, n_x, c]
        y: [p, n_y, c]
        """
        x2 = torch.sum(x ** 2, -1).unsqueeze(2)  # [p, n_x, 1]
        y2 = torch.sum(y ** 2, -1).unsqueeze(1)  # [p, 1, n_y]
        inner = x.matmul(y.transpose(-1, -2))  # [p, n_x, n_y]
        dist = x2 + y2 - 2 * inner
        # relu clamps tiny negative values caused by floating-point error.
        dist = torch.sqrt(F.relu(dist))  # [p, n_x, n_y]
        return dist

    def Convert2Triplets(self, row_labels, clo_label, dist):
        """Enumerate all (anchor, positive, negative) combinations.

        row_labels: tensor with size [n_r]
        clo_label : tensor with size [n_c]
        """
        matches = (row_labels.unsqueeze(1) ==
                   clo_label.unsqueeze(0)).byte()  # [n_r, n_c]
        # NOTE(review): .byte() masks are deprecated in newer PyTorch;
        # .bool() should be equivalent here — confirm before changing.
        diffenc = matches ^ 1  # [n_r, n_c]
        # mask[a, p, n] == 1 iff label(a) == label(p) and label(a) != label(n).
        mask = matches.unsqueeze(2) * diffenc.unsqueeze(1)
        a_idx, p_idx, n_idx = torch.where(mask)
        ap_dist = dist[:, a_idx, p_idx]
        an_dist = dist[:, a_idx, n_idx]
        return ap_dist, an_dist
+17
View File
@@ -0,0 +1,17 @@
from inspect import isclass
from pkgutil import iter_modules
from pathlib import Path
from importlib import import_module

# Auto-import every class defined in this package's modules so callers can
# look them up directly on the package (e.g. via get_attr_from).
package_dir = Path(__file__).resolve().parent
# NOTE: pkgutil.iter_modules() documents *string* path entries; passing a
# pathlib.Path breaks the import machinery on some Python versions, so
# convert explicitly.
for (_, module_name, _) in iter_modules([str(package_dir)]):
    # Import the module and iterate through its attributes.
    module = import_module(f"{__name__}.{module_name}")
    for attribute_name in dir(module):
        attribute = getattr(module, attribute_name)
        if isclass(attribute):
            # Add the class to this package's namespace.
            globals()[attribute_name] = attribute
+54
View File
@@ -0,0 +1,54 @@
import torch
from ..base_model import BaseModel
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks
class Baseline(BaseModel):
    """Baseline gait model: backbone -> temporal pooling -> HPP -> separate FC/BNNeck heads."""

    def build_network(self, model_cfg):
        # SetBlockWrapper lets the 2D backbone run frame-wise over [n, s, ...].
        self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
        self.Backbone = SetBlockWrapper(self.Backbone)
        self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
        self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
        # Temporal Pooling via max over the sequence dimension.
        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])

    def forward(self, inputs):
        ipts, labs, _, _, seqL = inputs

        sils = ipts[0]
        if len(sils.size()) == 4:
            # [n, s, h, w] -> [n, s, 1, h, w]: add the channel dimension.
            sils = sils.unsqueeze(2)

        del ipts
        outs = self.Backbone(sils)  # [n, s, c, h, w]

        # Temporal Pooling, TP
        outs = self.TP(outs, seqL, dim=1)[0]  # [n, c, h, w]
        # Horizontal Pooling Matching, HPM
        feat = self.HPP(outs)  # [n, c, p]
        feat = feat.permute(2, 0, 1).contiguous()  # [p, n, c]
        embed_1 = self.FCs(feat)  # [p, n, c]
        embed_2, logits = self.BNNecks(embed_1)  # [p, n, c]
        embed_1 = embed_1.permute(1, 0, 2).contiguous()  # [n, p, c]
        # NOTE(review): embed_2 is computed but never returned — verify this
        # is intentional (the pre-BNNeck embedding is used for inference).
        embed_2 = embed_2.permute(1, 0, 2).contiguous()  # [n, p, c]
        logits = logits.permute(1, 0, 2).contiguous()  # [n, p, c]
        embed = embed_1

        n, s, _, h, w = sils.size()
        retval = {
            'training_feat': {
                'triplet': {'embeddings': embed_1, 'labels': labs},
                'softmax': {'logits': logits, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': embed
            }
        }
        return retval
+193
View File
@@ -0,0 +1,193 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..base_model import BaseModel
from ..modules import SeparateFCs, BasicConv3d, PackSequenceWrapper
class GLConv(nn.Module):
    """Global-Local 3D convolution block (GaitGL).

    Applies a global 3D conv over the whole feature map and a local 3D conv
    over 2**halving horizontal strips, then fuses the two activations: summed
    after LeakyReLU by default, or concatenated along the height axis (then
    activated) when ``fm_sign`` is True.
    """

    def __init__(self, in_channels, out_channels, halving, fm_sign=False, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False, **kwargs):
        super(GLConv, self).__init__()
        self.halving = halving
        self.fm_sign = fm_sign
        self.global_conv3d = BasicConv3d(
            in_channels, out_channels, kernel_size, stride, padding, bias, **kwargs)
        self.local_conv3d = BasicConv3d(
            in_channels, out_channels, kernel_size, stride, padding, bias, **kwargs)

    def forward(self, x):
        '''
        x: [n, c, s, h, w]
        '''
        global_feat = self.global_conv3d(x)
        if self.halving == 0:
            local_feat = self.local_conv3d(x)
        else:
            # Convolve each of the 2**halving horizontal strips separately.
            strip_h = int(x.size(3) // 2**self.halving)
            strips = x.split(strip_h, 3)
            local_feat = torch.cat(
                [self.local_conv3d(strip) for strip in strips], 3)
        if self.fm_sign:
            return F.leaky_relu(torch.cat([global_feat, local_feat], dim=3))
        return F.leaky_relu(global_feat) + F.leaky_relu(local_feat)
class GeMHPP(nn.Module):
    """Generalized-Mean Horizontal Pyramid Pooling.

    For each bin count b in ``bin_num``, the feature map is reshaped into b
    horizontal strips and each strip is reduced with GeM pooling (learnable
    exponent p); all strip descriptors are concatenated along the last axis.
    """

    def __init__(self, bin_num=[64], p=6.5, eps=1.0e-6):
        super(GeMHPP, self).__init__()
        self.bin_num = bin_num
        # Learnable GeM exponent, shared by all bins.
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def gem(self, ipts):
        # GeM: mean of clamp(x)^p over the last axis, then the p-th root.
        return F.avg_pool2d(ipts.clamp(min=self.eps).pow(self.p), (1, ipts.size(-1))).pow(1. / self.p)

    def forward(self, x):
        """
        x  : [n, c, h, w]
        ret: [n, c, p]
        """
        n, c = x.size()[:2]
        pooled = []
        for bins in self.bin_num:
            strips = x.view(n, c, bins, -1)
            pooled.append(self.gem(strips).squeeze(-1))
        return torch.cat(pooled, -1)
class GaitGL(BaseModel):
    """
    GaitGL: Gait Recognition via Effective Global-Local Feature Representation and Local Temporal Aggregation
    Arxiv : https://arxiv.org/pdf/2011.01461.pdf
    """

    def __init__(self, *args, **kargs):
        super(GaitGL, self).__init__(*args, **kargs)

    def build_network(self, model_cfg):
        # in_c: channel widths per stage; class_num: number of training identities.
        in_c = model_cfg['channels']
        class_num = model_cfg['class_num']
        dataset_name = self.cfgs['data_cfg']['dataset_name']

        if dataset_name == 'OUMVLP':
            # For OUMVLP: a deeper variant (two convs per stage, halving=1).
            self.conv3d = nn.Sequential(
                BasicConv3d(1, in_c[0], kernel_size=(3, 3, 3),
                            stride=(1, 1, 1), padding=(1, 1, 1)),
                nn.LeakyReLU(inplace=True),
                BasicConv3d(in_c[0], in_c[0], kernel_size=(3, 3, 3),
                            stride=(1, 1, 1), padding=(1, 1, 1)),
                nn.LeakyReLU(inplace=True),
            )
            # Local Temporal Aggregation: temporal conv with stride 3.
            self.LTA = nn.Sequential(
                BasicConv3d(in_c[0], in_c[0], kernel_size=(
                    3, 1, 1), stride=(3, 1, 1), padding=(0, 0, 0)),
                nn.LeakyReLU(inplace=True)
            )

            self.GLConvA0 = nn.Sequential(
                GLConv(in_c[0], in_c[1], halving=1, fm_sign=False, kernel_size=(
                    3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
                GLConv(in_c[1], in_c[1], halving=1, fm_sign=False, kernel_size=(
                    3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            )
            self.MaxPool0 = nn.MaxPool3d(
                kernel_size=(1, 2, 2), stride=(1, 2, 2))

            self.GLConvA1 = nn.Sequential(
                GLConv(in_c[1], in_c[2], halving=1, fm_sign=False, kernel_size=(
                    3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
                GLConv(in_c[2], in_c[2], halving=1, fm_sign=False, kernel_size=(
                    3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            )
            # The final block concatenates global/local maps (fm_sign=True).
            self.GLConvB2 = nn.Sequential(
                GLConv(in_c[2], in_c[3], halving=1, fm_sign=False, kernel_size=(
                    3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
                GLConv(in_c[3], in_c[3], halving=1, fm_sign=True, kernel_size=(
                    3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            )
        else:
            # For CASIA-B or other unstated datasets.
            self.conv3d = nn.Sequential(
                BasicConv3d(1, in_c[0], kernel_size=(3, 3, 3),
                            stride=(1, 1, 1), padding=(1, 1, 1)),
                nn.LeakyReLU(inplace=True)
            )
            # Local Temporal Aggregation: temporal conv with stride 3.
            self.LTA = nn.Sequential(
                BasicConv3d(in_c[0], in_c[0], kernel_size=(
                    3, 1, 1), stride=(3, 1, 1), padding=(0, 0, 0)),
                nn.LeakyReLU(inplace=True)
            )

            self.GLConvA0 = GLConv(in_c[0], in_c[1], halving=3, fm_sign=False, kernel_size=(
                3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            self.MaxPool0 = nn.MaxPool3d(
                kernel_size=(1, 2, 2), stride=(1, 2, 2))

            self.GLConvA1 = GLConv(in_c[1], in_c[2], halving=3, fm_sign=False, kernel_size=(
                3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            self.GLConvB2 = GLConv(in_c[2], in_c[2], halving=3, fm_sign=True, kernel_size=(
                3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))

        # 64 horizontal parts: embedding head, BN bottleneck and classifier.
        self.Head0 = SeparateFCs(64, in_c[-1], in_c[-1])
        self.Bn = nn.BatchNorm1d(in_c[-1])
        self.Head1 = SeparateFCs(64, in_c[-1], class_num)

        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = GeMHPP()

    def forward(self, inputs):
        ipts, labs, _, _, seqL = inputs
        # Sequence lengths are only meaningful for packed training batches.
        seqL = None if not self.training else seqL
        if not self.training and len(labs) != 1:
            raise ValueError(
                'The input size of each GPU must be 1 in testing mode, but got {}!'.format(len(labs)))
        sils = ipts[0].unsqueeze(1)  # add the channel dim: [n, 1, s, h, w]
        del ipts
        n, _, s, h, w = sils.size()
        if s < 3:
            # The LTA stage has temporal stride 3; repeat very short clips.
            repeat = 3 if s == 1 else 2
            sils = sils.repeat(1, 1, repeat, 1, 1)

        outs = self.conv3d(sils)
        outs = self.LTA(outs)

        outs = self.GLConvA0(outs)
        outs = self.MaxPool0(outs)

        outs = self.GLConvA1(outs)
        outs = self.GLConvB2(outs)  # [n, c, s, h, w]

        # NOTE(review): argument order here differs from Baseline's TP call
        # (self.TP(outs, seqL, dim=1)); verify PackSequenceWrapper's signature.
        outs = self.TP(outs, dim=2, seq_dim=2, seqL=seqL)[0]  # [n, c, h, w]
        outs = self.HPP(outs)  # [n, c, p]

        outs = outs.permute(2, 0, 1).contiguous()  # [p, n, c]
        gait = self.Head0(outs)  # [p, n, c]
        gait = gait.permute(1, 2, 0).contiguous()  # [n, c, p]
        bnft = self.Bn(gait)  # [n, c, p]
        logi = self.Head1(bnft.permute(2, 0, 1).contiguous())  # [p, n, c]

        gait = gait.permute(0, 2, 1).contiguous()  # [n, p, c]
        bnft = bnft.permute(0, 2, 1).contiguous()  # [n, p, c]
        logi = logi.permute(1, 0, 2).contiguous()  # [n, p, c]
        n, _, s, h, w = sils.size()
        retval = {
            'training_feat': {
                'triplet': {'embeddings': bnft, 'labels': labs},
                'softmax': {'logits': logi, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': bnft
            }
        }
        return retval
+127
View File
@@ -0,0 +1,127 @@
import torch
import torch.nn as nn
from ..base_model import BaseModel
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs
from utils import clones
class BasicConv1d(nn.Module):
    """A bias-free 1D convolution, thinly wrapping ``nn.Conv1d``.

    Bias is always disabled; normalization/activation is the caller's job.
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super(BasicConv1d, self).__init__()
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size, bias=False, **kwargs)

    def forward(self, x):
        return self.conv(x)
class TemporalFeatureAggregator(nn.Module):
    """Aggregate per-part frame features over time (GaitPart's MTB modules).

    Two micro-motion template builders (MTB1 with a 3-frame pooling window,
    MTB2 with a 5-frame pooling window) produce sigmoid attention scores that
    re-weight pooled temporal templates; the two re-weighted streams are
    summed and max-pooled over time. Each of the ``parts_num`` horizontal
    parts owns its own conv branch (see ``clones``).
    """

    def __init__(self, in_channels, squeeze=4, parts_num=16):
        super(TemporalFeatureAggregator, self).__init__()
        hidden_dim = int(in_channels // squeeze)  # bottleneck width
        self.parts_num = parts_num

        # MTB1: score via 3x1 bottleneck convs, pool over a 3-frame window
        conv3x1 = nn.Sequential(
            BasicConv1d(in_channels, hidden_dim, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            BasicConv1d(hidden_dim, in_channels, 1))
        self.conv1d3x1 = clones(conv3x1, parts_num)
        self.avg_pool3x1 = nn.AvgPool1d(3, stride=1, padding=1)
        self.max_pool3x1 = nn.MaxPool1d(3, stride=1, padding=1)

        # MTB2: score via stacked 3-wide convs, pool over a 5-frame window
        # (original comment said MTB1 here -- this branch is MTB2, see forward)
        conv3x3 = nn.Sequential(
            BasicConv1d(in_channels, hidden_dim, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            BasicConv1d(hidden_dim, in_channels, 3, padding=1))
        self.conv1d3x3 = clones(conv3x3, parts_num)
        self.avg_pool3x3 = nn.AvgPool1d(5, stride=1, padding=2)
        self.max_pool3x3 = nn.MaxPool1d(5, stride=1, padding=2)

        # Temporal Pooling, TP
        self.TP = torch.max

    def forward(self, x):
        """
        Input: x, [n, s, c, p]
        Output: ret, [n, p, c]
        """
        n, s, c, p = x.size()
        x = x.permute(3, 0, 2, 1).contiguous()  # [p, n, c, s]
        feature = x.split(1, 0)  # [[n, c, s], ...], one slice per part
        x = x.view(-1, c, s)  # fold parts into batch for the shared poolers

        # MTB1: ConvNet1d & Sigmoid
        logits3x1 = torch.cat([conv(_.squeeze(0)).unsqueeze(0)
                               for conv, _ in zip(self.conv1d3x1, feature)], 0)
        scores3x1 = torch.sigmoid(logits3x1)
        # MTB1: Template Function
        feature3x1 = self.avg_pool3x1(x) + self.max_pool3x1(x)
        feature3x1 = feature3x1.view(p, n, c, s)
        feature3x1 = feature3x1 * scores3x1

        # MTB2: ConvNet1d & Sigmoid
        logits3x3 = torch.cat([conv(_.squeeze(0)).unsqueeze(0)
                               for conv, _ in zip(self.conv1d3x3, feature)], 0)
        scores3x3 = torch.sigmoid(logits3x3)
        # MTB2: Template Function
        feature3x3 = self.avg_pool3x3(x) + self.max_pool3x3(x)
        feature3x3 = feature3x3.view(p, n, c, s)
        feature3x3 = feature3x3 * scores3x3

        # Temporal Pooling
        ret = self.TP(feature3x1 + feature3x3, dim=-1)[0]  # [p, n, c]
        ret = ret.permute(1, 0, 2).contiguous()  # [n, p, c]
        return ret
class GaitPart(BaseModel):
    """GaitPart: Temporal Part-based Model for Gait Recognition.

    Paper:  https://openaccess.thecvf.com/content_CVPR_2020/papers/Fan_GaitPart_Temporal_Part-Based_Model_for_Gait_Recognition_CVPR_2020_paper.pdf
    Github: https://github.com/ChaoFan96/GaitPart
    """
    # Fix: the original kept this text as a dead string inside __init__;
    # it is now a proper class docstring.

    def __init__(self, *args, **kargs):
        super(GaitPart, self).__init__(*args, **kargs)

    def build_network(self, model_cfg):
        """Build backbone, part pooling pyramid and temporal aggregator.

        Args:
            model_cfg: dict with 'backbone_cfg', 'SeparateFCs'
                (must contain 'in_channels' and 'parts_num') and 'bin_num'.
        """
        self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
        head_cfg = model_cfg['SeparateFCs']
        # Fix: reuse head_cfg instead of re-indexing model_cfg['SeparateFCs'].
        self.Head = SeparateFCs(**head_cfg)
        self.Backbone = SetBlockWrapper(self.Backbone)
        self.HPP = SetBlockWrapper(
            HorizontalPoolingPyramid(bin_num=model_cfg['bin_num']))
        self.TFA = PackSequenceWrapper(TemporalFeatureAggregator(
            in_channels=head_cfg['in_channels'], parts_num=head_cfg['parts_num']))

    def forward(self, inputs):
        """Return training/visual/inference features for a batch.

        Args:
            inputs: (ipts, labs, _, _, seqL); ipts[0] holds silhouettes
                [n, s, h, w] or [n, s, c, h, w].
        """
        ipts, labs, _, _, seqL = inputs
        sils = ipts[0]
        # Add a channel axis for plain [n, s, h, w] silhouettes.
        if len(sils.size()) == 4:
            sils = sils.unsqueeze(2)
        del ipts
        out = self.Backbone(sils)  # [n, s, c, h, w]
        out = self.HPP(out)  # [n, s, c, p]
        out = self.TFA(out, seqL)  # [n, p, c]

        embs = self.Head(out.permute(1, 0, 2).contiguous())  # [p, n, c]
        embs = embs.permute(1, 0, 2).contiguous()  # [n, p, c]
        n, s, _, h, w = sils.size()
        retval = {
            'training_feat': {
                'triplet': {'embeddings': embs, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': embs
            }
        }
        return retval
+87
View File
@@ -0,0 +1,87 @@
import torch
import copy
import torch.nn as nn
from ..base_model import BaseModel
from ..modules import SeparateFCs, BasicConv2d, SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper
class GaitSet(BaseModel):
    """
    GaitSet: Regarding Gait as a Set for Cross-View Gait Recognition
    Arxiv: https://arxiv.org/abs/1811.06186
    Github: https://github.com/AbnerHqC/GaitSet
    """

    def build_network(self, model_cfg):
        """Build the frame-level (set) stages, global branch, and HPM head."""
        c = model_cfg['in_channels']

        def stage(cin, cout, k, s, p, pool):
            # Two conv+LeakyReLU pairs, optionally followed by 2x2 max-pooling.
            layers = [BasicConv2d(cin, cout, k, s, p),
                      nn.LeakyReLU(inplace=True),
                      BasicConv2d(cout, cout, 3, 1, 1),
                      nn.LeakyReLU(inplace=True)]
            if pool:
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            return nn.Sequential(*layers)

        self.set_block1 = stage(c[0], c[1], 5, 1, 2, pool=True)
        self.set_block2 = stage(c[1], c[2], 3, 1, 1, pool=True)
        self.set_block3 = stage(c[2], c[3], 3, 1, 1, pool=False)

        # The global branch mirrors stages 2/3 with its own weights.
        self.gl_block2 = copy.deepcopy(self.set_block2)
        self.gl_block3 = copy.deepcopy(self.set_block3)

        self.set_block1 = SetBlockWrapper(self.set_block1)
        self.set_block2 = SetBlockWrapper(self.set_block2)
        self.set_block3 = SetBlockWrapper(self.set_block3)

        self.set_pooling = PackSequenceWrapper(torch.max)

        self.Head = SeparateFCs(**model_cfg['SeparateFCs'])

        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])

    def forward(self, inputs):
        """Fuse set-pooled and global features into part-wise embeddings."""
        ipts, labs, _, _, seqL = inputs
        sils = ipts[0]  # [n, s, h, w]
        if len(sils.size()) == 4:
            sils = sils.unsqueeze(2)
        del ipts

        frame_feat = self.set_block1(sils)
        global_feat = self.set_pooling(frame_feat, seqL, dim=1)[0]
        global_feat = self.gl_block2(global_feat)

        frame_feat = self.set_block2(frame_feat)
        global_feat = global_feat + self.set_pooling(frame_feat, seqL, dim=1)[0]
        global_feat = self.gl_block3(global_feat)

        frame_feat = self.set_block3(frame_feat)
        set_feat = self.set_pooling(frame_feat, seqL, dim=1)[0]
        global_feat = global_feat + set_feat

        # Horizontal Pooling Matching, HPM
        feature = torch.cat(
            [self.HPP(set_feat), self.HPP(global_feat)], -1)  # [n, c, p]
        feature = feature.permute(2, 0, 1).contiguous()  # [p, n, c]
        embs = self.Head(feature).permute(1, 0, 2).contiguous()  # [n, p, c]

        n, s, _, h, w = sils.size()
        return {
            'training_feat': {
                'triplet': {'embeddings': embs, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': embs
            }
        }
+171
View File
@@ -0,0 +1,171 @@
import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
from ..base_model import BaseModel
from ..modules import SeparateFCs, BasicConv2d, SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper
class GLN(BaseModel):
    """Gait Lateral Network: Learning Discriminative and Compact
    Representations for Gait Recognition.

    http://home.ustc.edu.cn/~saihui/papers/eccv2020_gln.pdf
    """

    def build_network(self, model_cfg):
        """Build silhouette/set stages, FPN-style lateral layers and head."""
        in_channels = model_cfg['in_channels']
        self.bin_num = model_cfg['bin_num']
        self.hidden_dim = model_cfg['hidden_dim']
        lateral_dim = model_cfg['lateral_dim']
        reduce_dim = self.hidden_dim
        # When True, only the pyramid is trained (stage 1); the compact
        # block below is not built.
        self.pretrain = model_cfg['Lateral_pretraining']

        # Frame-level (silhouette) stages.
        self.sil_stage_0 = nn.Sequential(BasicConv2d(in_channels[0], in_channels[1], 5, 1, 2),
                                         nn.LeakyReLU(inplace=True),
                                         BasicConv2d(
                                             in_channels[1], in_channels[1], 3, 1, 1),
                                         nn.LeakyReLU(inplace=True))

        self.sil_stage_1 = nn.Sequential(BasicConv2d(in_channels[1], in_channels[2], 3, 1, 1),
                                         nn.LeakyReLU(inplace=True),
                                         BasicConv2d(
                                             in_channels[2], in_channels[2], 3, 1, 1),
                                         nn.LeakyReLU(inplace=True))

        self.sil_stage_2 = nn.Sequential(BasicConv2d(in_channels[2], in_channels[3], 3, 1, 1),
                                         nn.LeakyReLU(inplace=True),
                                         BasicConv2d(
                                             in_channels[3], in_channels[3], 3, 1, 1),
                                         nn.LeakyReLU(inplace=True))

        # Set-level stages share the silhouette stages' architecture.
        self.set_stage_1 = copy.deepcopy(self.sil_stage_1)
        self.set_stage_2 = copy.deepcopy(self.sil_stage_2)

        self.set_pooling = PackSequenceWrapper(torch.max)

        self.MaxP_sil = SetBlockWrapper(nn.MaxPool2d(kernel_size=2, stride=2))
        self.MaxP_set = nn.MaxPool2d(kernel_size=2, stride=2)

        self.sil_stage_0 = SetBlockWrapper(self.sil_stage_0)
        self.sil_stage_1 = SetBlockWrapper(self.sil_stage_1)
        self.sil_stage_2 = SetBlockWrapper(self.sil_stage_2)

        # 1x1 lateral projections (inputs are sil/set concatenations, hence *2).
        self.lateral_layer1 = nn.Conv2d(
            in_channels[1]*2, lateral_dim, kernel_size=1, stride=1, padding=0, bias=False)
        self.lateral_layer2 = nn.Conv2d(
            in_channels[2]*2, lateral_dim, kernel_size=1, stride=1, padding=0, bias=False)
        self.lateral_layer3 = nn.Conv2d(
            in_channels[3]*2, lateral_dim, kernel_size=1, stride=1, padding=0, bias=False)

        self.smooth_layer1 = nn.Conv2d(
            lateral_dim, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.smooth_layer2 = nn.Conv2d(
            lateral_dim, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.smooth_layer3 = nn.Conv2d(
            lateral_dim, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)

        self.HPP = HorizontalPoolingPyramid()
        self.Head = SeparateFCs(**model_cfg['SeparateFCs'])

        if not self.pretrain:
            # Compact block: BN -> dropout -> ReLU -> FC reduce -> BN -> classifier.
            self.encoder_bn = nn.BatchNorm1d(sum(self.bin_num)*3*self.hidden_dim)
            self.encoder_bn.bias.requires_grad_(False)

            self.reduce_dp = nn.Dropout(p=model_cfg['dropout'])
            self.reduce_ac = nn.ReLU(inplace=True)
            self.reduce_fc = nn.Linear(sum(self.bin_num)*3*self.hidden_dim, reduce_dim, bias=False)

            self.reduce_bn = nn.BatchNorm1d(reduce_dim)
            self.reduce_bn.bias.requires_grad_(False)
            self.reduce_cls = nn.Linear(reduce_dim, model_cfg['class_num'], bias=False)

    def upsample_add(self, x, y):
        """Nearest-neighbor upsample x by 2 and add y (FPN top-down merge)."""
        return F.interpolate(x, scale_factor=2, mode='nearest') + y

    def forward(self, inputs):
        """Compute the lateral pyramid features and (optionally) class logits."""
        ipts, labs, _, _, seqL = inputs
        sils = ipts[0]  # [n, s, h, w]
        del ipts
        if len(sils.size()) == 4:
            sils = sils.unsqueeze(2)
        n, s, _, h, w = sils.size()

        ### stage 0 sil ###
        sil_0_outs = self.sil_stage_0(sils)
        stage_0_sil_set = self.set_pooling(sil_0_outs, seqL, dim=1)[0]

        ### stage 1 sil ###
        sil_1_ipts = self.MaxP_sil(sil_0_outs)
        sil_1_outs = self.sil_stage_1(sil_1_ipts)

        ### stage 2 sil ###
        sil_2_ipts = self.MaxP_sil(sil_1_outs)
        sil_2_outs = self.sil_stage_2(sil_2_ipts)

        ### stage 1 set ###
        set_1_ipts = self.set_pooling(sil_1_ipts, seqL, dim=1)[0]
        stage_1_sil_set = self.set_pooling(sil_1_outs, seqL, dim=1)[0]
        set_1_outs = self.set_stage_1(set_1_ipts) + stage_1_sil_set

        ### stage 2 set ###
        set_2_ipts = self.MaxP_set(set_1_outs)
        stage_2_sil_set = self.set_pooling(sil_2_outs, seqL, dim=1)[0]
        set_2_outs = self.set_stage_2(set_2_ipts) + stage_2_sil_set

        # There is no set branch at stage 0, so the sil features are doubled
        # to match lateral_layer1's in_channels[1]*2 input width.
        set1 = torch.cat((stage_0_sil_set, stage_0_sil_set), dim=1)
        set2 = torch.cat((stage_1_sil_set, set_1_outs), dim=1)
        set3 = torch.cat((stage_2_sil_set, set_2_outs), dim=1)

        # lateral: top-down merge then smoothing, as in FPN
        set3 = self.lateral_layer3(set3)
        set2 = self.upsample_add(set3, self.lateral_layer2(set2))
        set1 = self.upsample_add(set2, self.lateral_layer1(set1))

        set3 = self.smooth_layer3(set3)
        set2 = self.smooth_layer2(set2)
        set1 = self.smooth_layer1(set1)

        set1 = self.HPP(set1)
        set2 = self.HPP(set2)
        set3 = self.HPP(set3)

        feature = torch.cat([set1, set2, set3], -
                            1).permute(2, 0, 1).contiguous()
        feature = self.Head(feature)
        feature = feature.permute(1, 0, 2).contiguous()  # n p c

        # compact_bloack
        if not self.pretrain:
            bn_feature = self.encoder_bn(feature.view(n, -1))
            bn_feature = bn_feature.view(*feature.shape).contiguous()

            reduce_feature = self.reduce_dp(bn_feature)
            reduce_feature = self.reduce_ac(reduce_feature)
            reduce_feature = self.reduce_fc(reduce_feature.view(n, -1))

            bn_reduce_feature = self.reduce_bn(reduce_feature)
            logits = self.reduce_cls(bn_reduce_feature).unsqueeze(1)  # n c

            reduce_feature = reduce_feature.unsqueeze(1).contiguous()
            bn_reduce_feature = bn_reduce_feature.unsqueeze(1).contiguous()

        retval = {
            'training_feat': {
                # Fix: triplet supervision was assigned identically in both
                # branches of the original if/else -- hoisted out here.
                'triplet': {'embeddings': feature, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': feature  # reduce_feature # bn_reduce_feature
            }
        }
        if not self.pretrain:
            # The classifier head only exists (and is trained) after pretraining.
            retval['training_feat']['softmax'] = {'logits': logits, 'labels': labs}
        return retval
+193
View File
@@ -0,0 +1,193 @@
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from utils import clones, is_list_or_tuple
class HorizontalPoolingPyramid():
    """
    Horizontal Pyramid Matching for Person Re-identification
    Arxiv: https://arxiv.org/abs/1804.05275
    Github: https://github.com/SHI-Labs/Horizontal-Pyramid-Matching
    """

    def __init__(self, bin_num=None):
        # Default pyramid: 16 + 8 + 4 + 2 + 1 = 31 horizontal strips.
        self.bin_num = [16, 8, 4, 2, 1] if bin_num is None else bin_num

    def __call__(self, x):
        """
        x  : [n, c, h, w]
        ret: [n, c, p]
        """
        n, c = x.size()[:2]
        # For each pyramid level, cut the feature map into b strips and
        # combine average- and max-pooling over each strip.
        pooled = [strips.mean(-1) + strips.max(-1)[0]
                  for strips in (x.view(n, c, b, -1) for b in self.bin_num)]
        return torch.cat(pooled, -1)
class SetBlockWrapper(nn.Module):
    """Apply a frame-level module to a 5-D set tensor.

    Folds the sequence axis into the batch, runs the wrapped block, then
    restores the leading [n, s] dimensions on the output.
    """

    def __init__(self, forward_block):
        super(SetBlockWrapper, self).__init__()
        self.forward_block = forward_block

    def forward(self, x, *args, **kwargs):
        """
        In x: [n, s, c, h, w]
        Out x: [n, s, ...]
        """
        n, s, c, h, w = x.size()
        out = self.forward_block(x.view(-1, c, h, w), *args, **kwargs)
        return out.view(n, s, *out.size()[1:])
class PackSequenceWrapper(nn.Module):
    """Pool a packed batch of variable-length sequences.

    With ``seqL is None`` the pooling function is applied to the whole batch
    at once; otherwise each subsequence is pooled separately along
    ``seq_dim`` and the per-sequence results are concatenated.
    """

    def __init__(self, pooling_func):
        super(PackSequenceWrapper, self).__init__()
        self.pooling_func = pooling_func

    def forward(self, seqs, seqL, seq_dim=1, **kwargs):
        """
        In seqs: [n, s, ...]
        Out rets: [n, ...]
        """
        if seqL is None:
            return self.pooling_func(seqs, **kwargs)

        lengths = seqL[0].data.cpu().numpy().tolist()
        # Offsets of each subsequence within the packed axis.
        starts = [0] + np.cumsum(lengths).tolist()[:-1]
        rets = [self.pooling_func(seqs.narrow(seq_dim, begin, length), **kwargs)
                for begin, length in zip(starts, lengths)]
        if len(rets) > 0 and is_list_or_tuple(rets[0]):
            # Pooling funcs like torch.max return (values, indices) pairs:
            # concatenate each component across sequences.
            return [torch.cat([ret[j] for ret in rets])
                    for j in range(len(rets[0]))]
        return torch.cat(rets)
class BasicConv2d(nn.Module):
    """A bias-free 2D convolution, thinly wrapping ``nn.Conv2d``."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding, bias=False, **kwargs)

    def forward(self, x):
        return self.conv(x)
class SeparateFCs(nn.Module):
    """Part-wise fully connected layers: one [in, out] weight per part.

    With ``norm=True`` the weights are L2-normalized along the input
    dimension before the matmul.
    """

    def __init__(self, parts_num, in_channels, out_channels, norm=False):
        super(SeparateFCs, self).__init__()
        self.p = parts_num
        # One Xavier-initialized weight matrix per horizontal part.
        self.fc_bin = nn.Parameter(
            nn.init.xavier_uniform_(
                torch.zeros(parts_num, in_channels, out_channels)))
        self.norm = norm

    def forward(self, x):
        """
        x: [p, n, c]
        """
        weight = F.normalize(self.fc_bin, dim=1) if self.norm else self.fc_bin
        return x.matmul(weight)
class SeparateBNNecks(nn.Module):
    """Part-wise BNNeck: batch-normalize features, then classify per part.

    Bag of Tricks and a Strong Baseline for Deep Person Re-Identification
    CVPR Workshop: https://openaccess.thecvf.com/content_CVPRW_2019/papers/TRMTMCT/Luo_Bag_of_Tricks_and_a_Strong_Baseline_for_Deep_Person_CVPRW_2019_paper.pdf
    Github: https://github.com/michuanhaohao/reid-strong-baseline
    """

    def __init__(self, parts_num, in_channels, class_num, norm=True, parallel_BN1d=True):
        super(SeparateBNNecks, self).__init__()
        self.p = parts_num
        self.class_num = class_num
        self.norm = norm
        # One Xavier-initialized classifier weight per part.
        self.fc_bin = nn.Parameter(
            nn.init.xavier_uniform_(
                torch.zeros(parts_num, in_channels, class_num)))
        if parallel_BN1d:
            # One BN over the flattened [n, p*c] features.
            self.bn1d = nn.BatchNorm1d(in_channels * parts_num)
        else:
            # One independent BN per part.
            self.bn1d = clones(nn.BatchNorm1d(in_channels), parts_num)
        self.parallel_BN1d = parallel_BN1d

    def forward(self, x):
        """
        x: [p, n, c]
        """
        if self.parallel_BN1d:
            p, n, c = x.size()
            flat = x.transpose(0, 1).contiguous().view(n, -1)  # [n, p*c]
            x = self.bn1d(flat).view(n, p, c).permute(1, 0, 2).contiguous()
        else:
            x = torch.cat([bn(part.squeeze(0)).unsqueeze(0)
                           for part, bn in zip(x.split(1, 0), self.bn1d)], 0)  # [p, n, c]
        if self.norm:
            feature = F.normalize(x, dim=-1)  # [p, n, c]
            logits = feature.matmul(F.normalize(self.fc_bin, dim=1))  # [p, n, c]
        else:
            feature = x
            logits = feature.matmul(self.fc_bin)
        return feature, logits
class FocalConv2d(nn.Module):
    """Focal convolution: split the height axis into 2**halving strips and
    convolve each strip independently with shared weights.

    ``halving == 0`` degenerates to a standard convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, halving, **kwargs):
        super(FocalConv2d, self).__init__()
        self.halving = halving
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size, bias=False, **kwargs)

    def forward(self, x):
        if self.halving == 0:
            return self.conv(x)
        strip_h = int(x.size(2) // 2 ** self.halving)
        strips = x.split(strip_h, 2)
        return torch.cat([self.conv(strip) for strip in strips], 2)
class BasicConv3d(nn.Module):
    """A 3D convolution defaulting to a shape-preserving 3x3x3 kernel with no bias."""

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False, **kwargs):
        super(BasicConv3d, self).__init__()
        self.conv3d = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size,
                                stride=stride, padding=padding, bias=bias, **kwargs)

    def forward(self, ipts):
        '''
        ipts: [n, c, s, h, w]
        outs: [n, c, s, h, w]
        '''
        return self.conv3d(ipts)
def RmBN2dAffine(model):
    """Freeze the affine (weight/bias) parameters of every BatchNorm2d in *model*."""
    bn_layers = (m for m in model.modules() if isinstance(m, nn.BatchNorm2d))
    for bn in bn_layers:
        bn.weight.requires_grad = False
        bn.bias.requires_grad = False