OpenGait release(pre-beta version).
This commit is contained in:
@@ -0,0 +1,113 @@
|
||||
import math
|
||||
import random
|
||||
import numpy as np
|
||||
from utils import get_msg_mgr
|
||||
|
||||
class CollateFn(object):
    """Collate callable that samples frames from each sequence of a batch.

    ``sample_config['sample_type']`` has the form ``'<sampler>_<order>'`` with
    sampler in {'fixed', 'unfixed', 'all'} and order in
    {'ordered', 'unordered'}.
    """

    def __init__(self, label_set, sample_config):
        """Read the sampling options from ``sample_config``.

        Args:
            label_set: sequence of labels; a label's position in it becomes
                the numeric label emitted for a sample.
            sample_config: mapping holding 'sample_type' and, depending on the
                sampler, 'frames_num_fixed', 'frames_num_max',
                'frames_num_min', 'frames_skip_num' and the optional
                'frames_all_limit'.

        Raises:
            ValueError: if 'sample_type' is malformed or names an unknown
                sampler / order.
        """
        self.label_set = label_set
        sample_type = sample_config['sample_type']
        parts = sample_type.split('_')
        if len(parts) < 2:
            raise ValueError(
                "sample_type must look like '<sampler>_<order>', got %r"
                % (sample_type,))
        self.sampler, order = parts[0], parts[1]
        if self.sampler not in ('fixed', 'unfixed', 'all'):
            raise ValueError(
                "Unknown sampler %r, expected 'fixed', 'unfixed' or 'all'."
                % (self.sampler,))
        if order not in ('ordered', 'unordered'):
            raise ValueError(
                "Unknown order %r, expected 'ordered' or 'unordered'."
                % (order,))
        self.ordered = order == 'ordered'

        # fixed cases
        if self.sampler == 'fixed':
            self.frames_num_fixed = sample_config['frames_num_fixed']

        # unfixed cases
        if self.sampler == 'unfixed':
            self.frames_num_max = sample_config['frames_num_max']
            self.frames_num_min = sample_config['frames_num_min']

        if self.sampler != 'all' and self.ordered:
            self.frames_skip_num = sample_config['frames_skip_num']

        # -1 means "no limit"; only honoured for the 'all' sampler.
        self.frames_all_limit = -1
        if self.sampler == 'all' and 'frames_all_limit' in sample_config:
            self.frames_all_limit = sample_config['frames_all_limit']

    def _sample_frames(self, seqs, feature_num, seq_id):
        """Pick frame indices for one sample and gather them per feature.

        Args:
            seqs: per-feature frame containers; ``seqs[0]`` defines the length.
            feature_num: number of features to gather from ``seqs``.
            seq_id: (label index, type, view) triple, used only for logging.

        Returns:
            A list of ``feature_num`` lists holding the selected frames.

        Raises:
            ValueError: if the sequence is empty and frames must be drawn.
        """
        seq_len = len(seqs[0])
        indices = list(range(seq_len))

        if self.sampler in ('fixed', 'unfixed'):
            if seq_len == 0:
                # Nothing can be drawn from an empty sequence; fail with a
                # clear message instead of a ZeroDivisionError (ordered) or
                # an opaque numpy error (unordered).
                get_msg_mgr().log_debug(
                    'Find no frames in the sequence %s-%s-%s.'
                    % tuple(str(v) for v in seq_id))
                raise ValueError(
                    'Cannot sample frames from the empty sequence %s-%s-%s.'
                    % tuple(str(v) for v in seq_id))

            if self.sampler == 'fixed':
                frames_num = self.frames_num_fixed
            else:
                frames_num = random.choice(
                    range(self.frames_num_min, self.frames_num_max + 1))

            if self.ordered:
                fs_n = frames_num + self.frames_skip_num
                if seq_len < fs_n:
                    # Tile the sequence until the window fits.
                    it = math.ceil(fs_n / seq_len)
                    seq_len = seq_len * it
                    indices = indices * it

                start = random.choice(range(0, seq_len - fs_n + 1))
                window = list(range(start, start + fs_n))
                picked = sorted(np.random.choice(
                    window, frames_num, replace=False))
                indices = [indices[i] for i in picked]
            else:
                # Sample with replacement only when the sequence is too short.
                indices = np.random.choice(
                    indices, frames_num, replace=seq_len < frames_num)

        if -1 < self.frames_all_limit < len(indices):
            indices = indices[:self.frames_all_limit]
        return [[seqs[i][j] for j in indices] for i in range(feature_num)]

    def __call__(self, batch):
        """Collate ``batch`` into [frames, labels, types, views, seq_lengths].

        Each batch element is ``(seqs, info)`` where ``info[0:3]`` are the
        label, type and view. For the 'fixed' sampler the frames are arranged
        as [feature][sample] arrays and the last entry is ``None``; otherwise
        all samples are concatenated per feature and the last entry is an
        array with the per-sample frame counts.
        """
        batch_size = len(batch)
        feature_num = len(batch[0][0])
        seqs_batch, labs_batch, typs_batch, vies_batch = [], [], [], []

        for bt in batch:
            seqs_batch.append(bt[0])
            labs_batch.append(self.label_set.index(bt[1][0]))
            typs_batch.append(bt[1][1])
            vies_batch.append(bt[1][2])

        # f: feature_num
        # b: batch_size
        # p: batch_size_per_gpu
        # g: gpus_num
        fras_batch = [
            self._sample_frames(seqs, feature_num, (lab, typ, vie))
            for seqs, lab, typ, vie in zip(
                seqs_batch, labs_batch, typs_batch, vies_batch)
        ]  # [b, f]
        batch = [fras_batch, labs_batch, typs_batch, vies_batch, None]

        if self.sampler == "fixed":
            fras_batch = [[np.asarray(fras_batch[i][j]) for i in range(batch_size)]
                          for j in range(feature_num)]  # [f, b]
        else:
            seqL_batch = [[len(fras_batch[i][0])
                           for i in range(batch_size)]]  # [1, p]
            fras_batch = [
                [np.concatenate(
                    [fras_batch[i][k] for i in range(batch_size)], 0)]
                for k in range(feature_num)
            ]  # [f, g]

            batch[-1] = np.asarray(seqL_batch)

        batch[0] = fras_batch
        return batch
|
||||
@@ -0,0 +1,123 @@
|
||||
import os
|
||||
import pickle
|
||||
import os.path as osp
|
||||
import torch.utils.data as tordata
|
||||
import json
|
||||
from utils import get_msg_mgr
|
||||
|
||||
|
||||
class DataSet(tordata.Dataset):
    """Dataset of gait sequences stored as ``.pkl`` files on disk.

    The directory layout read by the parser is
    ``<dataset_root>/<label>/<type>/<view>/<files>``.
    """

    def __init__(self, data_cfg, training):
        """
            seqs_info: the list with each element indicating
                            a certain gait sequence presented as [label, type, view, paths];

            Args:
                data_cfg: mapping with 'dataset_root', 'dataset_partition',
                    'cache' and optionally 'data_in_use'.
                training: use the TRAIN_SET partition when true, otherwise
                    the TEST_SET partition.
        """
        self.__dataset_parser(data_cfg, training)
        self.cache = data_cfg['cache']
        self.label_list = [seq_info[0] for seq_info in self.seqs_info]
        self.types_list = [seq_info[1] for seq_info in self.seqs_info]
        self.views_list = [seq_info[2] for seq_info in self.seqs_info]

        self.label_set = sorted(set(self.label_list))
        self.types_set = sorted(set(self.types_list))
        self.views_set = sorted(set(self.views_list))
        # Per-sequence cache slots, filled by __getitem__ when caching is on.
        self.seqs_data = [None] * len(self)
        # label -> positions of its sequences inside seqs_info
        self.indices_dict = {label: [] for label in self.label_set}
        for i, seq_info in enumerate(self.seqs_info):
            self.indices_dict[seq_info[0]].append(i)
        if self.cache:
            self.__load_all_data()

    def __len__(self):
        """Return the number of parsed sequences."""
        return len(self.seqs_info)

    def __loader__(self, paths):
        """Load every ``.pkl`` file in ``paths`` (in sorted order).

        Returns:
            A list with one unpickled object per path.

        Raises:
            ValueError: if a path does not end with '.pkl'.
            AssertionError: if the loaded objects differ in length.
        """
        data_list = []
        for pth in sorted(paths):
            if not pth.endswith('.pkl'):
                raise ValueError('- Loader - just support .pkl !!!')
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point the dataset root at trusted files.
            with open(pth, 'rb') as f:
                data_list.append(pickle.load(f))
            # if len(_) >= 200:
            #     _ = _[:200]
        for data in data_list:
            if len(data) != len(data_list[0]):
                raise AssertionError(
                    'All files of one sequence must hold the same number of '
                    'items, got %d and %d.' % (len(data), len(data_list[0])))

        return data_list

    def __getitem__(self, idx):
        """Return ``(data_list, seq_info)`` for sequence ``idx``.

        With caching enabled the loaded data is stored on first access and
        reused afterwards.
        """
        if not self.cache:
            data_lst = self.__loader__(self.seqs_info[idx][-1])
        elif self.seqs_data[idx] is None:
            data_lst = self.__loader__(self.seqs_info[idx][-1])
            self.seqs_data[idx] = data_lst
        else:
            data_lst = self.seqs_data[idx]
        seq_info = self.seqs_info[idx]
        return data_lst, seq_info

    def __load_all_data(self):
        """Touch every item once so the cache is fully populated."""
        for idx in range(len(self)):
            self.__getitem__(idx)

    def __dataset_parser(self, data_config, training):
        """Scan the dataset root and build ``self.seqs_info``.

        Labels listed in the partition file but absent on disk are dropped;
        labels on disk but absent from both partitions are logged as missing.
        """
        dataset_root = data_config['dataset_root']
        try:
            data_in_use = data_config['data_in_use']  # [n], true or false
        except KeyError:
            data_in_use = None

        with open(data_config['dataset_partition'], "rb") as f:
            partition = json.load(f)
        train_set = partition["TRAIN_SET"]
        test_set = partition["TEST_SET"]
        label_list = os.listdir(dataset_root)
        # Sets keep the membership tests below linear instead of quadratic.
        available = set(label_list)
        train_set = [label for label in train_set if label in available]
        test_set = [label for label in test_set if label in available]
        partitioned = set(train_set) | set(test_set)
        miss_pids = [label for label in label_list if label not in partitioned]
        msg_mgr = get_msg_mgr()

        def log_pid_list(pid_list):
            # Abbreviate long lists to first, second and last entry.
            if len(pid_list) >= 3:
                msg_mgr.log_info('[%s, %s, ..., %s]' %
                                 (pid_list[0], pid_list[1], pid_list[-1]))
            else:
                msg_mgr.log_info(pid_list)

        if len(miss_pids) > 0:
            msg_mgr.log_debug('-------- Miss Pid List --------')
            msg_mgr.log_debug(miss_pids)
        if training:
            msg_mgr.log_info("-------- Train Pid List --------")
            log_pid_list(train_set)
        else:
            msg_mgr.log_info("-------- Test Pid List --------")
            log_pid_list(test_set)

        def get_seqs_info_list(label_set):
            seqs_info_list = []
            for lab in label_set:
                for typ in sorted(os.listdir(osp.join(dataset_root, lab))):
                    for vie in sorted(os.listdir(osp.join(dataset_root, lab, typ))):
                        seq_info = [lab, typ, vie]
                        seq_path = osp.join(dataset_root, *seq_info)
                        seq_files = sorted(os.listdir(seq_path))
                        if seq_files != []:
                            seq_files = [osp.join(seq_path, name)
                                         for name in seq_files]
                            if data_in_use is not None:
                                # Keep only the files flagged as in use.
                                seq_files = [pth for pth, use_bl in zip(
                                    seq_files, data_in_use) if use_bl]
                            seqs_info_list.append([*seq_info, seq_files])
                        else:
                            msg_mgr.log_debug('Find no .pkl file in %s-%s-%s.'%(lab, typ, vie))
            return seqs_info_list

        self.seqs_info = get_seqs_info_list(
            train_set) if training else get_seqs_info_list(test_set)
|
||||
@@ -0,0 +1,87 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.utils.data as tordata
|
||||
|
||||
|
||||
class TripletSampler(tordata.sampler.Sampler):
    """Endless batch sampler drawing ``batch_size[0]`` labels and up to
    ``batch_size[1]`` sequences per label, then sharding across ranks."""

    def __init__(self, dataset, batch_size, batch_shuffle=False):
        """Store the sampling options and the distributed layout.

        Args:
            dataset: object exposing ``label_set`` and ``indices_dict``.
            batch_size: pair ``(labels_per_batch, sequences_per_label)``.
            batch_shuffle: shuffle the assembled batch before sharding.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.batch_shuffle = batch_shuffle

        # Fall back to a single-process layout when no process group exists.
        if dist.is_available() and dist.is_initialized():
            self.world_size = dist.get_world_size()
            self.rank = dist.get_rank()
        else:
            self.world_size = 1
            self.rank = 0

    def __iter__(self):
        """Yield this rank's slice of a freshly sampled batch, forever."""
        while True:
            sample_indices = []
            pid_list = sync_random_sample_list(
                self.dataset.label_set, self.batch_size[0])

            for pid in pid_list:
                indices = self.dataset.indices_dict[pid]
                indices = sync_random_sample_list(
                    indices, k=self.batch_size[1])
                sample_indices += indices

            if self.batch_shuffle:
                sample_indices = sync_random_sample_list(
                    sample_indices, len(sample_indices))

            batch_total = self.batch_size[0] * self.batch_size[1]
            total_size = int(math.ceil(batch_total / self.world_size)
                             ) * self.world_size
            # Labels with too few sequences leave the batch short; refill it
            # by repeating indices from the front.
            sample_indices += sample_indices[:(batch_total - len(sample_indices))]

            # Interleaved sharding: rank r takes r, r + world_size, ...
            sample_indices = sample_indices[self.rank:total_size:self.world_size]
            yield sample_indices

    def __len__(self):
        """Return the dataset size (not a number of batches)."""
        return len(self.dataset)
|
||||
|
||||
|
||||
def sync_random_sample_list(obj_list, k):
    """Draw up to ``k`` distinct items from ``obj_list``.

    The random index permutation is broadcast from rank 0 so every process of
    an initialised process group picks the same items. Without a process
    group the local permutation is used as is.

    Args:
        obj_list: indexable sequence to sample from.
        k: maximum number of items to return.

    Returns:
        A list with ``min(k, len(obj_list))`` items of ``obj_list``.
    """
    idx = torch.randperm(len(obj_list))[:k]
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        if torch.cuda.is_available():
            idx = idx.cuda()
        torch.distributed.broadcast(idx, src=0)
    idx = idx.tolist()
    return [obj_list[i] for i in idx]
|
||||
|
||||
|
||||
class InferenceSampler(tordata.sampler.Sampler):
    """Batch sampler that walks the dataset once, in order, sharded by rank."""

    def __init__(self, dataset, batch_size):
        """Precompute the index batches served to the current rank.

        Args:
            dataset: sized dataset to iterate over.
            batch_size: global batch size; must be a multiple of the world
                size.

        Raises:
            AssertionError: if ``batch_size`` is not divisible by the world
                size.
        """
        self.dataset = dataset
        self.batch_size = batch_size

        self.size = len(dataset)
        indices = list(range(self.size))

        # Fall back to a single-process layout when no process group exists.
        if dist.is_available() and dist.is_initialized():
            world_size = dist.get_world_size()
            rank = dist.get_rank()
        else:
            world_size, rank = 1, 0

        if batch_size % world_size != 0:
            raise AssertionError("Batch size({}) needs to be divisible by world size({})".format(
                batch_size, world_size))

        if batch_size != 1:
            # Pad to a whole number of batches by cycling through the
            # indices, so the padding is complete even when the dataset is
            # smaller than the amount of padding required.
            padded_size = math.ceil(self.size / batch_size) * batch_size
            indices = [indices[i % self.size] for i in range(padded_size)]
            self.size = padded_size

        batch_size_per_rank = int(self.batch_size / world_size)
        indx_batch_per_rank = [
            indices[i * batch_size_per_rank:(i + 1) * batch_size_per_rank]
            for i in range(int(self.size / batch_size_per_rank))
        ]

        # Rank r serves chunks r, r + world_size, ...
        self.idx_batch_this_rank = indx_batch_per_rank[rank::world_size]

    def __iter__(self):
        """Yield the precomputed index batches of this rank."""
        yield from self.idx_batch_this_rank

    def __len__(self):
        """Return the dataset size (number of samples, not of batches)."""
        return len(self.dataset)
|
||||
@@ -0,0 +1,58 @@
|
||||
from data import transform as base_transform
|
||||
import numpy as np
|
||||
|
||||
from utils import is_list, is_dict, get_valid_args
|
||||
|
||||
|
||||
class BaseSilTransform():
    """Optionally reshape a frame stack, then divide it by ``disvor``."""

    def __init__(self, disvor=255.0, img_shape=None):
        """Keep the divisor and the optional per-frame target shape."""
        self.img_shape = img_shape
        self.disvor = disvor

    def __call__(self, x):
        """Return ``x / disvor``; with ``img_shape`` set, ``x`` is first
        reshaped to ``(x.shape[0], *img_shape)``."""
        if self.img_shape is None:
            return x / self.disvor
        target_shape = [x.shape[0], *self.img_shape]
        reshaped = x.reshape(*target_shape)
        return reshaped / self.disvor
|
||||
|
||||
|
||||
class BaseSilCuttingTransform():
    """Crop the same margin from both ends of the last axis, then divide
    by ``disvor``."""

    def __init__(self, img_w=64, disvor=255.0, cutting=None):
        """Store the options.

        Args:
            img_w: width used to derive the default margin,
                ``(img_w // 64) * 10``.
            disvor: value the cropped input is divided by.
            cutting: explicit margin; overrides the derived one when not None.
        """
        self.img_w = img_w
        self.disvor = disvor
        self.cutting = cutting

    def __call__(self, x):
        """Return the cropped input divided by ``disvor``.

        A margin of 0 leaves the width untouched; slicing with ``0:-0`` would
        otherwise produce an empty array.
        """
        if self.cutting is not None:
            cutting = self.cutting
        else:
            cutting = int(self.img_w // 64) * 10
        if cutting != 0:
            x = x[..., cutting:-cutting]
        return x / self.disvor
|
||||
|
||||
|
||||
class BaseRgbTransform():
    """Normalise input with per-channel mean and standard deviation."""

    def __init__(self, mean=None, std=None):
        """Store ``mean`` and ``std`` as arrays of shape (1, 3, 1, 1).

        When omitted, the defaults below (each scaled by 255) are used.
        """
        channel_mean = [0.485*255, 0.456*255, 0.406*255] if mean is None else mean
        channel_std = [0.229*255, 0.224*255, 0.225*255] if std is None else std
        broadcast_shape = (1, 3, 1, 1)
        self.mean = np.array(channel_mean).reshape(broadcast_shape)
        self.std = np.array(channel_std).reshape(broadcast_shape)

    def __call__(self, x):
        """Return ``(x - mean) / std``."""
        centered = x - self.mean
        return centered / self.std
|
||||
|
||||
|
||||
def get_transform(trf_cfg=None):
    """Build transform callable(s) from a configuration.

    Args:
        trf_cfg: ``None`` (identity transform), a dict whose 'type' names an
            attribute of this transform module and whose remaining valid
            entries become constructor arguments, or a list of such configs.

    Returns:
        A callable for ``None`` / dict input, or a list of the results for
        list input.

    Raises:
        ValueError: if ``trf_cfg`` is none of the supported kinds.
    """
    if trf_cfg is None:
        return lambda x: x
    if is_dict(trf_cfg):
        transform = getattr(base_transform, trf_cfg['type'])
        valid_trf_arg = get_valid_args(transform, trf_cfg, ['type'])
        return transform(**valid_trf_arg)
    if is_list(trf_cfg):
        return [get_transform(cfg) for cfg in trf_cfg]
    # Raising a plain string is itself a TypeError; raise a real exception.
    raise ValueError("Error type for -Transform-Cfg-")
|
||||
+66
@@ -0,0 +1,66 @@
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from modeling import models
|
||||
from utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr
|
||||
|
||||
# Command-line interface of the entry script; parsed once at import time.
parser = argparse.ArgumentParser(description='Main program for opengait.')
parser.add_argument(
    '--local_rank',
    type=int,
    default=0,
    help="passed by torch.distributed.launch module",
)
parser.add_argument(
    '--cfgs',
    type=str,
    default='config/default.yaml',
    help="path of config file",
)
parser.add_argument(
    '--phase',
    default='train',
    choices=['train', 'test'],
    help="choose train or test phase",
)
parser.add_argument(
    '--log_to_file',
    action='store_true',
    help="log to file, default path is: output/<dataset>/<model>/<save_name>/<logs>/<Datetime>.txt",
)
# No type is given, so a value passed on the command line arrives as a string.
parser.add_argument('--iter', default=0, help="iter to restore")
opt = parser.parse_args()
|
||||
|
||||
|
||||
def initialization(cfgs, training):
    """Set up the message manager and seed the process.

    Picks the trainer or evaluator config according to ``training``, points
    the message manager at ``output/<dataset_name>/<model>/<save_name>``,
    logs the chosen config and seeds with the distributed rank.
    """
    msg_mgr = get_msg_mgr()
    engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg']

    dataset_name = cfgs['data_cfg']['dataset_name']
    model_name = cfgs['model_cfg']['model']
    output_path = os.path.join(
        'output/', dataset_name, model_name, engine_cfg['save_name'])

    # Log interval only matters while training.
    log_iter = engine_cfg['log_iter'] if training else 0
    # A non-integer restore hint is passed on as 0.
    restore_hint = engine_cfg['restore_hint']
    if not isinstance(restore_hint, (int)):
        restore_hint = 0
    msg_mgr.init_manager(output_path, opt.log_to_file, log_iter, restore_hint)

    msg_mgr.log_info(engine_cfg)

    # Each rank is seeded with its own rank number.
    init_seeds(torch.distributed.get_rank())
|
||||
|
||||
|
||||
def run_model(cfgs, training):
    """Instantiate the configured model and run its train or test loop.

    The model class is looked up by ``cfgs['model_cfg']['model']`` in the
    ``models`` package. When training with ``sync_BN`` enabled its batch-norm
    layers are converted to SyncBatchNorm before the model is wrapped by
    ``get_ddp_module``.
    """
    msg_mgr = get_msg_mgr()
    model_cfg = cfgs['model_cfg']
    msg_mgr.log_info(model_cfg)

    model_cls = getattr(models, model_cfg['model'])
    model = model_cls(cfgs, training)
    if training and cfgs['trainer_cfg']['sync_BN']:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = get_ddp_module(model)

    msg_mgr.log_info(params_count(model))
    msg_mgr.log_info("Model Initialization Finished!")

    # The loops are static entry points on the model class.
    runner = model_cls.run_train if training else model_cls.run_test
    runner(model)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # One process per visible GPU is expected (NCCL backend, env:// init).
    torch.distributed.init_process_group('nccl', init_method='env://')
    world_size = torch.distributed.get_world_size()
    gpu_count = torch.cuda.device_count()
    if world_size != gpu_count:
        # Arguments are passed in the order the message names them.
        raise AssertionError("Expect number of available GPUs({}) equals to the world size({}).".format(
            gpu_count, world_size))
    cfgs = config_loader(opt.cfgs)
    # --iter overrides the restore hint of both engine configs.
    if opt.iter != 0:
        cfgs['evaluator_cfg']['restore_hint'] = int(opt.iter)
        cfgs['trainer_cfg']['restore_hint'] = int(opt.iter)

    training = (opt.phase == 'train')
    initialization(cfgs, training)
    run_model(cfgs, training)
|
||||
@@ -0,0 +1,17 @@
|
||||
from inspect import isclass
|
||||
from pkgutil import iter_modules
|
||||
from pathlib import Path
|
||||
from importlib import import_module
|
||||
|
||||
# Import every module of this package and re-export each class found in it,
# so classes can be looked up as attributes of the package itself.
package_dir = Path(__file__).resolve().parent
# pkgutil documents its path entries as strings; pass str() rather than a
# Path object, which some Python releases do not handle.
for (_, module_name, _) in iter_modules([str(package_dir)]):

    # import the module and iterate through its attributes
    module = import_module(f"{__name__}.{module_name}")
    for attribute_name in dir(module):
        attribute = getattr(module, attribute_name)

        if isclass(attribute):
            # Add the class to this package's variables
            globals()[attribute_name] = attribute
|
||||
@@ -0,0 +1,42 @@
|
||||
import torch.nn as nn
|
||||
from ..modules import BasicConv2d, FocalConv2d
|
||||
|
||||
class Plain(nn.Module):
    """Sequential convolutional backbone described by a list of layer specs.

    Each spec is 'M' (2x2 max pooling), 'BC-<out>' (BasicConv2d) or
    'FC-<out>-<halving>' (FocalConv2d). The first spec must be a convolution;
    it uses a 5x5 kernel, later convolutions use 3x3. Every convolution is
    followed by an in-place LeakyReLU.
    """

    def __init__(self, layers_cfg, in_channels=1):
        super(Plain, self).__init__()
        self.layers_cfg = layers_cfg
        self.in_channels = in_channels

        self.feature = self.make_layers()

    def forward(self, seqs):
        """Run ``seqs`` through the layer stack."""
        return self.feature(seqs)

    # torchvision/models/vgg.py
    def make_layers(self):
        """Translate ``self.layers_cfg`` into an ``nn.Sequential``."""

        def build_conv(spec, channels_in, kernel_size, stride, padding):
            # spec is '<type>-<out_channels>[-<halving>]'
            fields = spec.split('-')
            kind = fields[0]
            if kind not in ['BC', 'FC']:
                raise AssertionError
            channels_out = int(fields[1])

            if kind == 'FC':
                return FocalConv2d(channels_in, channels_out, kernel_size=kernel_size, stride=stride, padding=padding, halving=int(fields[2]))
            return BasicConv2d(channels_in, channels_out, kernel_size=kernel_size, stride=stride, padding=padding)

        first_spec = self.layers_cfg[0]
        stack = [build_conv(first_spec, self.in_channels, 5, 1, 2), nn.LeakyReLU(inplace=True)]
        prev_channels = int(first_spec.split('-')[1])
        for spec in self.layers_cfg[1:]:
            if spec == 'M':
                stack.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            stack.append(build_conv(spec, prev_channels, 3, 1, 1))
            stack.append(nn.LeakyReLU(inplace=True))
            prev_channels = int(spec.split('-')[1])
        return nn.Sequential(*stack)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,417 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.utils.data as tordata
|
||||
|
||||
from tqdm import tqdm
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.cuda.amp import GradScaler
|
||||
from abc import ABCMeta
|
||||
from abc import abstractmethod
|
||||
|
||||
from . import backbones
|
||||
from .loss_aggregator import LossAggregator
|
||||
from modeling.modules import fix_BN
|
||||
from data.transform import get_transform
|
||||
from data.collate_fn import CollateFn
|
||||
from data.dataset import DataSet
|
||||
import data.sampler as Samplers
|
||||
from utils import Odict, mkdir, ddp_all_gather
|
||||
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
|
||||
from utils import evaluation as eval_functions
|
||||
from utils import NoOp
|
||||
from utils import get_msg_mgr
|
||||
|
||||
# The public class of this module is BaseModel; exporting a name that does
# not exist makes `from <module> import *` fail.
__all__ = ['BaseModel']
|
||||
|
||||
|
||||
class MetaModel(metaclass=ABCMeta):
    """Abstract interface listing the hooks a model has to provide.

    Every method is abstract and raises NotImplementedError if called.
    """

    @abstractmethod
    def get_loader(self, data_cfg):
        """Create the data loader.

        Inputs: data_cfg, dict
        Return: Loader
        """
        raise NotImplementedError

    @abstractmethod
    def build_network(self, model_cfg):
        """Create the network.

        Inputs: model_cfg, dict
        Return: Network, nn.Module(s)
        """
        raise NotImplementedError

    @abstractmethod
    def init_parameters(self):
        """Initialise the network parameters."""
        raise NotImplementedError

    @abstractmethod
    def get_optimizer(self, optimizer_cfg):
        """Create the optimizer.

        Inputs: optimizer_cfg, dict
        Return: Optimizer, an optimizer object
        """
        raise NotImplementedError

    @abstractmethod
    def get_scheduler(self, scheduler_cfg):
        """Create the learning-rate scheduler.

        Inputs: scheduler_cfg, dict
                Optimizer, your optimizer
        Return: Scheduler, a scheduler object
        """
        raise NotImplementedError

    @abstractmethod
    def save_ckpt(self, iteration):
        """Write a checkpoint for ``iteration``."""
        raise NotImplementedError

    @abstractmethod
    def resume_ckpt(self, restore_hint):
        """Restore state from the checkpoint named by ``restore_hint``."""
        raise NotImplementedError

    @abstractmethod
    def inputs_pretreament(self, inputs):
        """Turn one raw batch into model inputs."""
        raise NotImplementedError

    @abstractmethod
    def train_step(self, loss_num):
        """Perform one optimisation step for the given loss."""
        raise NotImplementedError

    @abstractmethod
    def inference(self):
        """Run the model over the test data."""
        raise NotImplementedError

    @abstractmethod
    def run_train(model):
        """Training loop; receives the model instance as its only argument."""
        raise NotImplementedError

    @abstractmethod
    def run_test(model):
        """Test loop; receives the model instance as its only argument."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class BaseModel(MetaModel, nn.Module):
|
||||
|
||||
def __init__(self, cfgs, training):
|
||||
super(BaseModel, self).__init__()
|
||||
self.msg_mgr = get_msg_mgr()
|
||||
self.cfgs = cfgs
|
||||
self.iteration = 0
|
||||
self.engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg']
|
||||
if self.engine_cfg is None:
|
||||
raise Exception("Initialize a model without -Engine-Cfgs-")
|
||||
|
||||
if training and self.engine_cfg['enable_float16']:
|
||||
self.Scaler = GradScaler()
|
||||
self.save_path = osp.join('output/', cfgs['data_cfg']['dataset_name'],
|
||||
cfgs['model_cfg']['model'], self.engine_cfg['save_name'])
|
||||
|
||||
self.build_network(cfgs['model_cfg'])
|
||||
self.init_parameters()
|
||||
|
||||
self.msg_mgr.log_info(cfgs['data_cfg'])
|
||||
if training:
|
||||
self.train_loader = self.get_loader(
|
||||
cfgs['data_cfg'], train=True)
|
||||
if not training or self.engine_cfg['with_test']:
|
||||
self.test_loader = self.get_loader(
|
||||
cfgs['data_cfg'], train=False)
|
||||
|
||||
self.device = torch.distributed.get_rank()
|
||||
torch.cuda.set_device(self.device)
|
||||
self.to(device=torch.device(
|
||||
"cuda", self.device))
|
||||
|
||||
if training:
|
||||
if cfgs['trainer_cfg']['fix_BN']:
|
||||
fix_BN(self)
|
||||
self.loss_aggregator = LossAggregator(cfgs['loss_cfg'])
|
||||
self.optimizer = self.get_optimizer(self.cfgs['optimizer_cfg'])
|
||||
self.scheduler = self.get_scheduler(cfgs['scheduler_cfg'])
|
||||
self.train(training)
|
||||
restore_hint = self.engine_cfg['restore_hint']
|
||||
if restore_hint != 0:
|
||||
self.resume_ckpt(restore_hint)
|
||||
|
||||
def get_backbone(self, model_cfg):
|
||||
def _get_backbone(backbone_cfg):
|
||||
if is_dict(backbone_cfg):
|
||||
Backbone = get_attr_from([backbones], backbone_cfg['type'])
|
||||
valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
|
||||
return Backbone(**valid_args)
|
||||
if is_list(backbone_cfg):
|
||||
Backbone = nn.ModuleList([_get_backbone(cfg)
|
||||
for cfg in backbone_cfg])
|
||||
return Backbone
|
||||
raise ValueError(
|
||||
"Error type for -Backbone-Cfg-, supported: (A list of) dict.")
|
||||
|
||||
if 'backbone_cfg' in model_cfg.keys():
|
||||
backbone_cfg = model_cfg['backbone_cfg']
|
||||
Backbone = _get_backbone(backbone_cfg)
|
||||
else:
|
||||
Backbone = None
|
||||
return Backbone
|
||||
|
||||
def build_network(self, model_cfg):
|
||||
self.Backbone = self.get_backbone(model_cfg)
|
||||
|
||||
def init_parameters(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.Conv1d)):
|
||||
nn.init.xavier_uniform_(m.weight.data)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias.data, 0.0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.xavier_uniform_(m.weight.data)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias.data, 0.0)
|
||||
elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)):
|
||||
if m.affine:
|
||||
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
||||
nn.init.constant_(m.bias.data, 0.0)
|
||||
|
||||
def get_loader(self, data_cfg, train=True):
|
||||
sampler_cfg = self.cfgs['trainer_cfg']['sampler'] if train else self.cfgs['evaluator_cfg']['sampler']
|
||||
dataset = DataSet(data_cfg, train)
|
||||
|
||||
Sampler = get_attr_from([Samplers], sampler_cfg['type'])
|
||||
vaild_args = get_valid_args(Sampler, sampler_cfg, free_keys=[
|
||||
'sample_type', 'type'])
|
||||
sampler = Sampler(dataset, **vaild_args)
|
||||
|
||||
loader = tordata.DataLoader(
|
||||
dataset=dataset,
|
||||
batch_sampler=sampler,
|
||||
collate_fn=CollateFn(dataset.label_set, sampler_cfg),
|
||||
num_workers=data_cfg['num_workers'])
|
||||
return loader
|
||||
|
||||
def get_optimizer(self, optimizer_cfg):
|
||||
self.msg_mgr.log_info(optimizer_cfg)
|
||||
optimizer = get_attr_from([optim], optimizer_cfg['solver'])
|
||||
valid_arg = get_valid_args(optimizer, optimizer_cfg, ['solver'])
|
||||
optimizer = optimizer(
|
||||
filter(lambda p: p.requires_grad, self.parameters()), **valid_arg)
|
||||
return optimizer
|
||||
|
||||
def get_scheduler(self, scheduler_cfg):
|
||||
self.msg_mgr.log_info(scheduler_cfg)
|
||||
Scheduler = get_attr_from(
|
||||
[optim.lr_scheduler], scheduler_cfg['scheduler'])
|
||||
valid_arg = get_valid_args(Scheduler, scheduler_cfg, ['scheduler'])
|
||||
scheduler = Scheduler(self.optimizer, **valid_arg)
|
||||
return scheduler
|
||||
|
||||
def save_ckpt(self, iteration):
|
||||
if torch.distributed.get_rank() == 0:
|
||||
mkdir(osp.join(self.save_path, "checkpoints/"))
|
||||
save_name = self.engine_cfg['save_name']
|
||||
checkpoint = {
|
||||
'model': self.state_dict(),
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'scheduler': self.scheduler.state_dict(),
|
||||
'iteration': iteration}
|
||||
torch.save(checkpoint,
|
||||
osp.join(self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, iteration)))
|
||||
|
||||
def _load_ckpt(self, save_name):
|
||||
load_ckpt_strict = self.engine_cfg['restore_ckpt_strict']
|
||||
|
||||
checkpoint = torch.load(save_name, map_location=torch.device(
|
||||
"cuda", self.device))
|
||||
model_state_dict = checkpoint['model']
|
||||
|
||||
if not load_ckpt_strict:
|
||||
self.msg_mgr.log_info("-------- Restored Params List --------")
|
||||
self.msg_mgr.log_info(sorted(set(model_state_dict.keys()).intersection(
|
||||
set(self.state_dict().keys()))))
|
||||
|
||||
self.load_state_dict(model_state_dict, strict=load_ckpt_strict)
|
||||
if self.training:
|
||||
if not self.engine_cfg["optimizer_reset"] and 'optimizer' in checkpoint:
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
else:
|
||||
self.msg_mgr.log_warning(
|
||||
"Restore NO Optimizer from %s !!!" % save_name)
|
||||
if not self.engine_cfg["scheduler_reset"] and 'scheduler' in checkpoint:
|
||||
self.scheduler.load_state_dict(
|
||||
checkpoint['scheduler'])
|
||||
else:
|
||||
self.msg_mgr.log_warning(
|
||||
"Restore NO Scheduler from %s !!!" % save_name)
|
||||
self.msg_mgr.log_info("Restore Parameters from %s !!!" % save_name)
|
||||
del checkpoint
|
||||
|
||||
def resume_ckpt(self, restore_hint):
|
||||
if isinstance(restore_hint, int):
|
||||
save_name = self.engine_cfg['save_name']
|
||||
save_name = osp.join(
|
||||
self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, restore_hint))
|
||||
self.iteration = restore_hint
|
||||
elif isinstance(restore_hint, str):
|
||||
save_name = restore_hint
|
||||
self.iteration = 0
|
||||
else:
|
||||
raise ValueError(
|
||||
"Error type for -Restore_Hint-, supported: int or string.")
|
||||
self._load_ckpt(save_name)
|
||||
|
||||
def inputs_pretreament(self, inputs):
|
||||
seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
|
||||
trf_cfgs = self.engine_cfg['transform']
|
||||
seq_trfs = get_transform(trf_cfgs)
|
||||
|
||||
requires_grad = bool(self.training)
|
||||
seqs = [np2var(np.asarray([trf(fra) for fra in seq]), requires_grad=requires_grad).float()
|
||||
for trf, seq in zip(seq_trfs, seqs_batch)]
|
||||
|
||||
typs = typs_batch
|
||||
vies = vies_batch
|
||||
|
||||
labs = list2var(labs_batch).long()
|
||||
|
||||
if seqL_batch is not None:
|
||||
seqL_batch = np2var(seqL_batch).int()
|
||||
seqL = seqL_batch
|
||||
|
||||
if seqL is not None:
|
||||
seqL_sum = int(seqL.sum().data.cpu().numpy())
|
||||
ipts = [_[:, :seqL_sum] for _ in seqs]
|
||||
else:
|
||||
ipts = seqs
|
||||
del seqs
|
||||
return ipts, labs, typs, vies, seqL
|
||||
|
||||
def train_step(self, loss_sum):
|
||||
'''
|
||||
Conduct loss_sum.backward(), self.optimizer.step() and self.scheduler.step().
|
||||
'''
|
||||
|
||||
skip_lr_sched = False
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
if loss_sum <= 1e-9:
|
||||
self.msg_mgr.log_warning("Find the loss sum less than 1e-9 but the training process will continue!)")
|
||||
|
||||
if self.engine_cfg['enable_float16']:
|
||||
self.Scaler.scale(loss_sum).backward()
|
||||
self.Scaler.step(self.optimizer)
|
||||
scale = self.Scaler.get_scale()
|
||||
self.Scaler.update()
|
||||
skip_lr_sched = (scale != self.Scaler.get_scale())
|
||||
# Warning caused by optimizer skip when NaN
|
||||
# https://discuss.pytorch.org/t/optimizer-step-before-lr-scheduler-step-error-using-gradscaler/92930/5
|
||||
|
||||
#for debug
|
||||
# for name, param in self.named_parameters():
|
||||
# if param.grad is None:
|
||||
# print(name)
|
||||
else:
|
||||
loss_sum.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
if not skip_lr_sched:
|
||||
self.iteration += 1
|
||||
self.scheduler.step()
|
||||
|
||||
def inference(self, rank):
|
||||
total_size = len(self.test_loader)
|
||||
if rank == 0:
|
||||
pbar = tqdm(total=total_size, desc='Transforming')
|
||||
else:
|
||||
pbar = NoOp()
|
||||
batch_size = self.test_loader.batch_sampler.batch_size
|
||||
rest_size = total_size
|
||||
info_dict = Odict()
|
||||
for inputs in self.test_loader:
|
||||
ipts = self.inputs_pretreament(inputs)
|
||||
with autocast(enabled=self.engine_cfg['enable_float16']):
|
||||
retval = self.forward(ipts)
|
||||
inference_feat = retval['inference_feat']
|
||||
for k, v in inference_feat.items():
|
||||
inference_feat[k] = ddp_all_gather(v, requires_grad=False)
|
||||
del retval
|
||||
for k, v in inference_feat.items():
|
||||
inference_feat[k] = ts2np(v)
|
||||
info_dict.append(inference_feat)
|
||||
rest_size -= batch_size
|
||||
if rest_size >= 0:
|
||||
update_size = batch_size
|
||||
else:
|
||||
update_size = total_size % batch_size
|
||||
pbar.update(update_size)
|
||||
pbar.close()
|
||||
for k, v in info_dict.items():
|
||||
v = np.concatenate(v)[:total_size]
|
||||
info_dict[k] = v
|
||||
return info_dict
|
||||
|
||||
@ staticmethod
|
||||
def run_train(model):
|
||||
'''
|
||||
Accept the instace object(model) here, and then run the train loop handler.
|
||||
'''
|
||||
for inputs in model.train_loader:
|
||||
ipts = model.inputs_pretreament(inputs)
|
||||
with autocast(enabled=model.engine_cfg['enable_float16']):
|
||||
retval = model(ipts)
|
||||
training_feat, visual_summary = retval['training_feat'], retval['visual_summary']
|
||||
del retval
|
||||
loss_sum, loss_info = model.loss_aggregator(training_feat)
|
||||
model.train_step(loss_sum)
|
||||
|
||||
visual_summary.update(loss_info)
|
||||
visual_summary['scalar/learning_rate'] = model.optimizer.param_groups[0]['lr']
|
||||
|
||||
model.msg_mgr.train_step(loss_info, visual_summary)
|
||||
if model.iteration % model.engine_cfg['save_iter'] == 0:
|
||||
# save the checkpoint
|
||||
model.save_ckpt(model.iteration)
|
||||
|
||||
# run test if with_test = true
|
||||
if model.engine_cfg['with_test']:
|
||||
model.msg_mgr.log_info("Running test...")
|
||||
model.eval()
|
||||
result_dict = BaseModel.run_test(model)
|
||||
model.train()
|
||||
model.msg_mgr.write_to_tensorboard(result_dict)
|
||||
model.msg_mgr.reset_time()
|
||||
if model.iteration >= model.engine_cfg['total_iter']:
|
||||
break
|
||||
|
||||
@staticmethod
def run_test(model):
    """Run inference on the test loader and evaluate the result on rank 0.

    Every rank runs ``model.inference`` under ``torch.no_grad()``. Only
    rank 0 attaches the dataset's labels/types/views to the collected
    features and calls the evaluation function.

    Args:
        model: the model instance to evaluate.

    Returns:
        Whatever the selected evaluation function returns on rank 0;
        ``None`` on every other rank.
    """
    rank = torch.distributed.get_rank()
    with torch.no_grad():
        info_dict = model.inference(rank)
    if rank != 0:
        return None

    dataset = model.test_loader.dataset
    info_dict.update({
        'labels': dataset.label_list,
        'types': dataset.types_list,
        'views': dataset.views_list})

    # Fall back to the 'identification' evaluator when none is configured.
    if 'eval_func' in model.engine_cfg:
        eval_func = model.engine_cfg['eval_func']
    else:
        eval_func = 'identification'
    eval_func = getattr(eval_functions, eval_func)
    valid_args = get_valid_args(
        eval_func, model.engine_cfg, ['metric'])

    # Prefer a dedicated test dataset name; only a missing key triggers the
    # fallback (a bare except would also hide unrelated errors).
    try:
        dataset_name = model.cfgs['data_cfg']['test_dataset_name']
    except KeyError:
        dataset_name = model.cfgs['data_cfg']['dataset_name']
    return eval_func(info_dict, dataset_name, **valid_args)
|
||||
@@ -0,0 +1,50 @@
|
||||
import torch
|
||||
from . import losses
|
||||
from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module
|
||||
from utils import Odict
|
||||
from utils import get_msg_mgr
|
||||
|
||||
|
||||
class LossAggregator():
    """Build the configured losses and sum them for one training step.

    ``loss_cfg`` is either a single loss config dict or an iterable of such
    dicts. Each config's ``log_prefix`` is the key under which the model must
    publish the matching inputs in its ``training_feat`` dictionary.
    """

    def __init__(self, loss_cfg) -> None:
        loss_cfgs = [loss_cfg] if is_dict(loss_cfg) else loss_cfg
        self.losses = {cfg['log_prefix']: self._build_loss_(cfg)
                       for cfg in loss_cfgs}

    def _build_loss_(self, loss_cfg):
        """Instantiate the loss class named by ``loss_cfg['type']``."""
        Loss = get_attr_from([losses], loss_cfg['type'])
        valid_loss_arg = get_valid_args(
            Loss, loss_cfg, ['type', 'pair_based_loss'])
        return get_ddp_module(Loss(**valid_loss_arg))

    def __call__(self, training_feats):
        """Compute every loss and return ``(loss_sum, loss_info)``.

        Args:
            training_feats: mapping from key to either a dict of keyword
                arguments for the loss registered under that key, or a loss
                tensor that is averaged and added as-is.

        Returns:
            A tuple of the summed loss and an ``Odict`` of values to log.

        Raises:
            ValueError: if a dict entry has no registered loss, or an entry
                is neither a dict nor a tensor.
        """
        loss_sum = .0
        loss_info = Odict()

        for k, v in training_feats.items():
            if k in self.losses:
                loss_func = self.losses[k]
                loss, info = loss_func(**v)
                for name, value in info.items():
                    loss_info['scalar/%s/%s' % (k, name)] = value
                loss = loss.mean() * loss_func.loss_term_weights
                if loss_func.pair_based_loss:
                    loss = loss * torch.distributed.get_world_size()
                loss_sum += loss
            elif isinstance(v, dict):
                # The original message contained '%s' but never filled it in.
                raise ValueError(
                    "The key %s in -Trainng-Feat- should be stated as the log_prefix of a certain loss defined in your loss_cfg." % k
                )
            elif is_tensor(v):
                mean_value = v.mean()
                loss_info['scalar/%s' % k] = mean_value
                loss_sum += mean_value
                get_msg_mgr().log_debug(
                    "Please check whether %s needed in training." % k)
            else:
                raise ValueError(
                    "Error type for -Trainng-Feat-, supported: A feature dict or loss tensor.")

        return loss_sum, loss_info
|
||||
@@ -0,0 +1,17 @@
|
||||
from inspect import isclass
|
||||
from pkgutil import iter_modules
|
||||
from pathlib import Path
|
||||
from importlib import import_module
|
||||
|
||||
# Import every module that sits next to this file and re-export each class
# it exposes, so classes can be looked up directly on the package.
package_dir = Path(__file__).resolve().parent
# pkgutil.iter_modules is documented to take a list of path *strings*;
# pass str(...) rather than a Path object so discovery does not depend on
# how a given interpreter release treats path-like entries.
for (_, module_name, _) in iter_modules([str(package_dir)]):

    # import the module and iterate through its attributes
    module = import_module(f"{__name__}.{module_name}")
    for attribute_name in dir(module):
        attribute = getattr(module, attribute_name)

        if isclass(attribute):
            # Add the class to this package's variables
            globals()[attribute_name] = attribute
|
||||
@@ -0,0 +1,13 @@
|
||||
import torch.nn as nn
|
||||
from utils import Odict
|
||||
|
||||
class BasicLoss(nn.Module):
    """Common base for loss modules.

    Holds the weight applied to the loss term, a flag telling the aggregator
    whether the loss is pair based, and an ``Odict`` used to collect values
    for logging. Subclasses implement ``forward``.
    """

    def __init__(self, loss_term_weights=1.0):
        super().__init__()

        self.info = Odict()
        self.pair_based_loss = True
        self.loss_term_weights = loss_term_weights

    def forward(self, logits, labels):
        """Compute the loss; must be provided by subclasses."""
        raise NotImplementedError
|
||||
@@ -0,0 +1,51 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base import BasicLoss
|
||||
|
||||
|
||||
class CrossEntropyLoss(BasicLoss):
    """Part-wise softmax cross-entropy with optional label smoothing.

    Args:
        scale: factor the logits are multiplied by before the softmax.
        label_smooth: whether to blend in the uniform-target smoothing term.
        eps: smoothing weight used when ``label_smooth`` is true.
        loss_term_weights: weight of this term in the aggregated loss.
        log_accuracy: whether to also log the classification accuracy.
    """

    def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weights=1.0, log_accuracy=False):
        super(CrossEntropyLoss, self).__init__()
        self.scale = scale
        self.label_smooth = label_smooth
        self.eps = eps
        self.log_accuracy = log_accuracy

        self.loss_term_weights = loss_term_weights
        self.pair_based_loss = False

    def forward(self, logits, labels):
        """
            logits: [n, p, c]
            labels: [n]

            Returns the per-part loss ([p]) and the logging dict.
        """
        logits = logits.permute(1, 0, 2).contiguous()  # [n, p, c] -> [p, n, c]
        p, _, c = logits.size()
        log_preds = F.log_softmax(logits * self.scale, dim=-1)  # [p, n, c]
        one_hot_labels = self.label2one_hot(
            labels, c).unsqueeze(0).repeat(p, 1, 1)  # [p, n, c]
        loss = self.compute_loss(log_preds, one_hot_labels)
        self.info.update({'loss': loss})
        if self.log_accuracy:
            pred = logits.argmax(dim=-1)  # [p, n]
            accu = (pred == labels.unsqueeze(0)).float().mean()
            self.info.update({'accuracy': accu})
        return loss, self.info

    def compute_loss(self, predis, labels):
        """Return the per-part loss ([p]) from log-probabilities and one-hot labels."""
        softmax_loss = -(labels * predis).sum(-1)  # [p, n]
        losses = softmax_loss.mean(-1)  # [p]

        if self.label_smooth:
            smooth_loss = - predis.mean(dim=-1)  # [p, n]
            # Reduce over the batch only, so the term stays per-part ([p]) as
            # documented. A global mean collapsed it to one scalar shared by
            # all parts; the mean over parts is unchanged by this.
            smooth_loss = smooth_loss.mean(-1)  # [p]
            smooth_loss = smooth_loss * self.eps
            losses = smooth_loss + losses * (1. - self.eps)
        return losses

    def label2one_hot(self, label, class_num):
        """Convert integer labels [n] to a one-hot float matrix [n, class_num]."""
        label = label.unsqueeze(-1)
        batch_size = label.size(0)
        # Allocate on the label's device directly instead of on CPU + copy.
        one_hot = torch.zeros(batch_size, class_num, device=label.device)
        return one_hot.scatter(1, label, 1)
|
||||
@@ -0,0 +1,76 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base import BasicLoss
|
||||
from utils import ddp_all_gather
|
||||
|
||||
|
||||
class TripletLoss(BasicLoss):
    """Triplet margin loss evaluated separately for every part.

    Embeddings and labels are first gathered with ``ddp_all_gather``; every
    (anchor, positive, negative) combination that the labels allow is then
    scored, and the loss of each part is averaged over its non-zero triplets.
    """

    def __init__(self, margin, loss_term_weights=1.0):
        super().__init__()
        self.margin = margin

        self.loss_term_weights = loss_term_weights
        self.pair_based_loss = True

    def forward(self, embeddings, labels):
        # embeddings: [n, p, c], label: [n]
        all_embeds = ddp_all_gather(embeddings)
        all_labels = ddp_all_gather(labels)
        # [n, p, c] -> [p, n, c], computed in float32
        part_major = all_embeds.permute(1, 0, 2).contiguous().float()

        dist = self.ComputeDistance(part_major, part_major)  # [p, n1, n2]
        mean_dist = dist.mean(1).mean(1)
        ap_dist, an_dist = self.Convert2Triplets(all_labels, all_labels, dist)
        loss = F.relu(ap_dist - an_dist + self.margin)

        hard_loss = torch.max(loss, -1)[0]
        loss_avg, loss_num = self.AvgNonZeroReducer(loss)

        self.info.update({
            'loss': loss_avg,
            'hard_loss': hard_loss,
            'loss_num': loss_num,
            'mean_dist': mean_dist})

        return loss_avg, self.info

    def AvgNonZeroReducer(self, loss):
        """Average ``loss`` over its non-zero entries along the last dim.

        Returns the averages (0 where no entry is non-zero) and the counts.
        """
        eps = 1.0e-9
        nonzero_count = (loss != 0).sum(-1).float()
        averaged = loss.sum(-1) / (nonzero_count + eps)
        averaged[nonzero_count == 0] = 0
        return averaged, nonzero_count

    def ComputeDistance(self, x, y):
        """
            x: [p, n_x, c]
            y: [p, n_y, c]

            Returns the pairwise Euclidean distances, [p, n_x, n_y].
        """
        sq_x = (x ** 2).sum(-1).unsqueeze(2)  # [p, n_x, 1]
        sq_y = (y ** 2).sum(-1).unsqueeze(1)  # [p, 1, n_y]
        cross = x.matmul(y.transpose(-1, -2))  # [p, n_x, n_y]
        # relu clamps small negative values caused by rounding before sqrt
        return torch.sqrt(F.relu(sq_x + sq_y - 2 * cross))

    def Convert2Triplets(self, row_labels, clo_label, dist):
        """
            row_labels: tensor with size [n_r]
            clo_label : tensor with size [n_c]

            Returns (anchor-positive, anchor-negative) distances for every
            index triple whose positive shares the anchor's label and whose
            negative does not.
        """
        same_id = row_labels.unsqueeze(1) == clo_label.unsqueeze(0)  # [n_r, n_c]
        diff_id = ~same_id  # [n_r, n_c]
        triplet_mask = same_id.unsqueeze(2) & diff_id.unsqueeze(1)
        a_idx, p_idx, n_idx = torch.where(triplet_mask)

        return dist[:, a_idx, p_idx], dist[:, a_idx, n_idx]
|
||||
@@ -0,0 +1,17 @@
|
||||
from inspect import isclass
|
||||
from pkgutil import iter_modules
|
||||
from pathlib import Path
|
||||
from importlib import import_module
|
||||
|
||||
# Import every module that sits next to this file and re-export each class
# it exposes, so classes can be looked up directly on the package.
package_dir = Path(__file__).resolve().parent
# pkgutil.iter_modules is documented to take a list of path *strings*;
# pass str(...) rather than a Path object so discovery does not depend on
# how a given interpreter release treats path-like entries.
for (_, module_name, _) in iter_modules([str(package_dir)]):

    # import the module and iterate through its attributes
    module = import_module(f"{__name__}.{module_name}")
    for attribute_name in dir(module):
        attribute = getattr(module, attribute_name)

        if isclass(attribute):
            # Add the class to this package's variables
            globals()[attribute_name] = attribute
|
||||
@@ -0,0 +1,56 @@
|
||||
import torch
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks
|
||||
|
||||
|
||||
class Baseline(BaseModel):
    """Baseline model: frame-wise backbone, temporal max pooling, horizontal
    pooling, then separate FC and BN-neck heads per part."""

    def __init__(self, cfgs, is_training):
        super().__init__(cfgs, is_training)

    def build_network(self, model_cfg):
        """Create the sub-modules described by ``model_cfg``."""
        self.Backbone = SetBlockWrapper(self.get_backbone(model_cfg))
        self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
        self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])

    def forward(self, inputs):
        """Return the training features, visual summary and inference features."""
        ipts, labs, _, _, seqL = inputs

        sils = ipts[0]
        del ipts
        # add the channel axis when the silhouettes come without one
        if len(sils.size()) == 4:
            sils = sils.unsqueeze(2)

        frame_feat = self.Backbone(sils)  # [n, s, c, h, w]

        # Temporal Pooling, TP
        set_feat = self.TP(frame_feat, seqL, dim=1)[0]  # [n, c, h, w]
        # Horizontal Pooling Matching, HPM
        part_feat = self.HPP(set_feat)  # [n, c, p]
        part_feat = part_feat.permute(2, 0, 1).contiguous()  # [p, n, c]

        embed_1 = self.FCs(part_feat)  # [p, n, c]
        _, logits = self.BNNecks(embed_1)  # [p, n, c]

        embed_1 = embed_1.permute(1, 0, 2).contiguous()  # [n, p, c]
        logits = logits.permute(1, 0, 2).contiguous()  # [n, p, c]

        n, s, _, h, w = sils.size()
        return {
            'training_feat': {
                'triplet': {'embeddings': embed_1, 'labels': labs},
                'softmax': {'logits': logits, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': embed_1
            }
        }
|
||||
@@ -0,0 +1,151 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from ..modules import SeparateFCs, BasicConv3d, PackSequenceWrapper
|
||||
|
||||
|
||||
class GLConv(nn.Module):
    """Global/local 3D convolution pair.

    The global branch convolves the whole feature map. The local branch
    splits the map into horizontal strips (``2**halving`` of them when the
    height divides evenly), convolves each strip with shared weights and
    concatenates the results along the height axis.
    """

    def __init__(self, in_channels, out_channels, halving, fm_sign=False, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False, **kwargs):
        super().__init__()
        self.halving = halving
        self.fm_sign = fm_sign
        conv_args = (in_channels, out_channels, kernel_size, stride, padding, bias)
        self.global_conv3d = BasicConv3d(*conv_args, **kwargs)
        self.local_conv3d = BasicConv3d(*conv_args, **kwargs)

    def forward(self, x):
        '''
            x: [n, c, s, h, w]
        '''
        gob_feat = self.global_conv3d(x)
        if self.halving == 0:
            lcl_feat = self.local_conv3d(x)
        else:
            strip_height = int(x.size(3) // 2**self.halving)
            strips = x.split(strip_height, 3)
            lcl_feat = torch.cat(
                [self.local_conv3d(strip) for strip in strips], 3)

        if self.fm_sign:
            # stack both branches along the height axis
            return F.leaky_relu(torch.cat([gob_feat, lcl_feat], dim=3))
        # element-wise fusion of both branches
        return F.leaky_relu(gob_feat) + F.leaky_relu(lcl_feat)
|
||||
|
||||
|
||||
class GeMHPP(nn.Module):
    """Horizontal pooling with generalized-mean (GeM) pooling per strip.

    Args:
        bin_num: list with the number of horizontal strips per pyramid
            level; defaults to ``[64]``.
        p: initial value of the learnable GeM exponent.
        eps: lower clamp applied to the input before exponentiation.
    """

    def __init__(self, bin_num=None, p=6.5, eps=1.0e-6):
        super(GeMHPP, self).__init__()
        # A list literal as default would be one object shared by every
        # instance; build a fresh default per instance instead.
        if bin_num is None:
            bin_num = [64]
        self.bin_num = bin_num
        self.p = nn.Parameter(
            torch.ones(1)*p)
        self.eps = eps

    def gem(self, ipts):
        """GeM-pool the last axis of ``ipts`` ([n, c, b, k] -> [n, c, b, 1])."""
        powered = ipts.clamp(min=self.eps).pow(self.p)
        return F.avg_pool2d(powered, (1, ipts.size(-1))).pow(1. / self.p)

    def forward(self, x):
        """
            x  : [n, c, h, w]
            ret: [n, c, p]
        """
        n, c = x.size()[:2]
        features = []
        for b in self.bin_num:
            z = x.view(n, c, b, -1)
            z = self.gem(z).squeeze(-1)
            features.append(z)
        return torch.cat(features, -1)
|
||||
|
||||
|
||||
class GaitGL(BaseModel):
    """
        GaitGL: Gait Recognition via Effective Global-Local Feature Representation and Local Temporal Aggregation
        Arxiv : https://arxiv.org/pdf/2011.01461.pdf
    """

    def __init__(self, *args, **kargs):
        super().__init__(*args, **kargs)

    def build_network(self, model_cfg):
        """Create the convolutional trunk, pooling layers and heads."""
        in_c = model_cfg['channels']
        class_num = model_cfg['class_num']
        full_kernel = dict(kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))

        # For CASIA-B
        self.conv3d = nn.Sequential(
            BasicConv3d(1, in_c[0], **full_kernel),
            nn.LeakyReLU(inplace=True)
        )
        # temporal-only convolution with stride 3 along the sequence axis
        self.LTA = nn.Sequential(
            BasicConv3d(in_c[0], in_c[0], kernel_size=(3, 1, 1),
                        stride=(3, 1, 1), padding=(0, 0, 0)),
            nn.LeakyReLU(inplace=True)
        )

        self.GLConvA0 = GLConv(in_c[0], in_c[1], halving=3, fm_sign=False, **full_kernel)
        self.MaxPool0 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))

        self.GLConvA1 = GLConv(in_c[1], in_c[2], halving=3, fm_sign=False, **full_kernel)
        self.GLConvB2 = GLConv(in_c[2], in_c[2], halving=3, fm_sign=True, **full_kernel)

        self.Head0 = SeparateFCs(64, in_c[2], in_c[2])
        self.Bn = nn.BatchNorm1d(in_c[2])
        self.Head1 = SeparateFCs(64, in_c[2], class_num)

        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = GeMHPP()

    def forward(self, inputs):
        """Return the training features, visual summary and inference features."""
        ipts, labs, _, _, seqL = inputs
        # sequence lengths are only used while training
        if not self.training:
            seqL = None

        sils = ipts[0].unsqueeze(1)
        del ipts
        n, _, s, h, w = sils.size()
        if s < 3:
            # repeat along the sequence axis so at least 3 frames are available
            repeat = 3 if s == 1 else 2
            sils = sils.repeat(1, 1, repeat, 1, 1)

        outs = self.LTA(self.conv3d(sils))
        outs = self.MaxPool0(self.GLConvA0(outs))
        outs = self.GLConvB2(self.GLConvA1(outs))  # [n, c, s, h, w]

        outs = self.TP(outs, dim=2, seq_dim=2, seqL=seqL)[0]  # [n, c, h, w]
        outs = self.HPP(outs)  # [n, c, p]
        outs = outs.permute(2, 0, 1).contiguous()  # [p, n, c]

        gait = self.Head0(outs)  # [p, n, c]
        gait = gait.permute(1, 2, 0).contiguous()  # [n, c, p]
        bnft = self.Bn(gait)  # [n, c, p]
        logi = self.Head1(bnft.permute(2, 0, 1).contiguous())  # [p, n, c]

        gait = gait.permute(0, 2, 1).contiguous()  # [n, p, c]
        bnft = bnft.permute(0, 2, 1).contiguous()  # [n, p, c]
        logi = logi.permute(1, 0, 2).contiguous()  # [n, p, c]

        # sizes are re-read because the silhouettes may have been repeated
        n, _, s, h, w = sils.size()
        return {
            'training_feat': {
                'triplet': {'embeddings': bnft, 'labels': labs},
                'softmax': {'logits': logi, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': bnft
            }
        }
|
||||
@@ -0,0 +1,127 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..base_model import BaseModel
|
||||
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs
|
||||
from utils import clones
|
||||
|
||||
|
||||
class BasicConv1d(nn.Module):
    """Bias-free ``nn.Conv1d``; extra keyword arguments go to the convolution."""

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size, bias=False, **kwargs)

    def forward(self, x):
        """Apply the convolution to ``x``."""
        return self.conv(x)
|
||||
|
||||
|
||||
class TemporalFeatureAggregator(nn.Module):
    """Aggregate per-part frame features over time with two gated branches.

    Each branch (MTB1, MTB2) pools the sequence with avg+max pooling and
    re-weights the result with sigmoid scores produced by a per-part 1D
    conv net; the two branches are summed and max-pooled over time.
    """

    def __init__(self, in_channels, squeeze=4, parts_num=16):
        super().__init__()
        hidden_dim = int(in_channels // squeeze)
        self.parts_num = parts_num

        # MTB1
        conv3x1 = nn.Sequential(
            BasicConv1d(in_channels, hidden_dim, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            BasicConv1d(hidden_dim, in_channels, 1))
        self.conv1d3x1 = clones(conv3x1, parts_num)
        self.avg_pool3x1 = nn.AvgPool1d(3, stride=1, padding=1)
        self.max_pool3x1 = nn.MaxPool1d(3, stride=1, padding=1)

        # MTB2
        conv3x3 = nn.Sequential(
            BasicConv1d(in_channels, hidden_dim, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            BasicConv1d(hidden_dim, in_channels, 3, padding=1))
        self.conv1d3x3 = clones(conv3x3, parts_num)
        self.avg_pool3x3 = nn.AvgPool1d(5, stride=1, padding=2)
        self.max_pool3x3 = nn.MaxPool1d(5, stride=1, padding=2)

        # Temporal Pooling, TP
        self.TP = torch.max

    def forward(self, x):
        """
          Input:  x,   [n, s, c, p]
          Output: ret, [n, p, c]
        """
        n, s, c, p = x.size()
        x = x.permute(3, 0, 2, 1).contiguous()  # [p, n, c, s]
        per_part = x.split(1, 0)  # [[1, n, c, s], ...]
        flat = x.view(-1, c, s)  # [p*n, c, s]

        # MTB1: ConvNet1d & Sigmoid
        logits3x1 = torch.cat(
            [conv(part.squeeze(0)).unsqueeze(0)
             for conv, part in zip(self.conv1d3x1, per_part)], 0)
        scores3x1 = torch.sigmoid(logits3x1)
        # MTB1: Template Function
        pooled3x1 = self.avg_pool3x1(flat) + self.max_pool3x1(flat)
        feature3x1 = pooled3x1.view(p, n, c, s) * scores3x1

        # MTB2: ConvNet1d & Sigmoid
        logits3x3 = torch.cat(
            [conv(part.squeeze(0)).unsqueeze(0)
             for conv, part in zip(self.conv1d3x3, per_part)], 0)
        scores3x3 = torch.sigmoid(logits3x3)
        # MTB2: Template Function
        pooled3x3 = self.avg_pool3x3(flat) + self.max_pool3x3(flat)
        feature3x3 = pooled3x3.view(p, n, c, s) * scores3x3

        # Temporal Pooling
        ret = self.TP(feature3x1 + feature3x3, dim=-1)[0]  # [p, n, c]
        return ret.permute(1, 0, 2).contiguous()  # [n, p, c]
|
||||
|
||||
|
||||
class GaitPart(BaseModel):
    """
        GaitPart: Temporal Part-based Model for Gait Recognition
        Paper:    https://openaccess.thecvf.com/content_CVPR_2020/papers/Fan_GaitPart_Temporal_Part-Based_Model_for_Gait_Recognition_CVPR_2020_paper.pdf
        Github:   https://github.com/ChaoFan96/GaitPart
    """

    def __init__(self, *args, **kargs):
        super().__init__(*args, **kargs)

    def build_network(self, model_cfg):
        """Create the backbone, part pooling, temporal aggregator and head."""
        head_cfg = model_cfg['SeparateFCs']

        self.Backbone = SetBlockWrapper(self.get_backbone(model_cfg))
        self.Head = SeparateFCs(**head_cfg)
        self.HPP = SetBlockWrapper(
            HorizontalPoolingPyramid(bin_num=model_cfg['bin_num']))
        self.TFA = PackSequenceWrapper(TemporalFeatureAggregator(
            in_channels=head_cfg['in_channels'], parts_num=head_cfg['parts_num']))

    def forward(self, inputs):
        """Return the training features, visual summary and inference features."""
        ipts, labs, _, _, seqL = inputs

        sils = ipts[0]
        del ipts
        # add the channel axis when the silhouettes come without one
        if len(sils.size()) == 4:
            sils = sils.unsqueeze(2)

        out = self.Backbone(sils)  # [n, s, c, h, w]
        out = self.HPP(out)  # [n, s, c, p]
        out = self.TFA(out, seqL)  # [n, p, c]

        embs = self.Head(out.permute(1, 0, 2).contiguous())  # [p, n, c]
        embs = embs.permute(1, 0, 2).contiguous()  # [n, p, c]

        n, s, _, h, w = sils.size()
        return {
            'training_feat': {
                'triplet': {'embeddings': embs, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': embs
            }
        }
|
||||
@@ -0,0 +1,87 @@
|
||||
import torch
|
||||
import copy
|
||||
import torch.nn as nn
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from ..modules import SeparateFCs, BasicConv2d, SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper
|
||||
|
||||
|
||||
class GaitSet(BaseModel):
    """
        GaitSet: Regarding Gait as a Set for Cross-View Gait Recognition
        Arxiv:  https://arxiv.org/abs/1811.06186
        Github: https://github.com/AbnerHqC/GaitSet
    """

    def build_network(self, model_cfg):
        """Create the frame-level blocks, their set-level copies and the head."""
        in_c = model_cfg['in_channels']

        block1 = nn.Sequential(BasicConv2d(in_c[0], in_c[1], 5, 1, 2),
                               nn.LeakyReLU(inplace=True),
                               BasicConv2d(in_c[1], in_c[1], 3, 1, 1),
                               nn.LeakyReLU(inplace=True),
                               nn.MaxPool2d(kernel_size=2, stride=2))
        block2 = nn.Sequential(BasicConv2d(in_c[1], in_c[2], 3, 1, 1),
                               nn.LeakyReLU(inplace=True),
                               BasicConv2d(in_c[2], in_c[2], 3, 1, 1),
                               nn.LeakyReLU(inplace=True),
                               nn.MaxPool2d(kernel_size=2, stride=2))
        block3 = nn.Sequential(BasicConv2d(in_c[2], in_c[3], 3, 1, 1),
                               nn.LeakyReLU(inplace=True),
                               BasicConv2d(in_c[3], in_c[3], 3, 1, 1),
                               nn.LeakyReLU(inplace=True))

        # frame-level blocks run on every frame through the wrapper
        self.set_block1 = SetBlockWrapper(block1)
        self.set_block2 = SetBlockWrapper(block2)
        self.set_block3 = SetBlockWrapper(block3)

        # set-level blocks are independent copies of the unwrapped blocks
        self.gl_block2 = copy.deepcopy(block2)
        self.gl_block3 = copy.deepcopy(block3)

        self.set_pooling = PackSequenceWrapper(torch.max)

        self.Head = SeparateFCs(**model_cfg['SeparateFCs'])

        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])

    def forward(self, inputs):
        """Return the training features, visual summary and inference features."""
        ipts, labs, _, _, seqL = inputs
        sils = ipts[0]  # [n, s, h, w]
        del ipts
        if len(sils.size()) == 4:
            sils = sils.unsqueeze(2)

        outs = self.set_block1(sils)
        gl = self.gl_block2(self.set_pooling(outs, seqL, dim=1)[0])

        outs = self.set_block2(outs)
        gl = self.gl_block3(gl + self.set_pooling(outs, seqL, dim=1)[0])

        outs = self.set_block3(outs)
        outs = self.set_pooling(outs, seqL, dim=1)[0]
        gl = gl + outs

        # Horizontal Pooling Matching, HPM
        feature1 = self.HPP(outs)  # [n, c, p]
        feature2 = self.HPP(gl)  # [n, c, p]
        feature = torch.cat([feature1, feature2], -1)  # [n, c, p]
        feature = feature.permute(2, 0, 1).contiguous()  # [p, n, c]
        embs = self.Head(feature)
        embs = embs.permute(1, 0, 2).contiguous()  # [n, p, c]

        n, s, _, h, w = sils.size()
        return {
            'training_feat': {
                'triplet': {'embeddings': embs, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': embs
            }
        }
|
||||
@@ -0,0 +1,172 @@
|
||||
import torch
|
||||
import copy
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from ..modules import SeparateFCs, BasicConv2d, SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper
|
||||
|
||||
|
||||
class GLN(BaseModel):
    """
        http://home.ustc.edu.cn/~saihui/papers/eccv2020_gln.pdf
        Gait Lateral Network: Learning Discriminative and Compact Representations for Gait Recognition
    """

    def build_network(self, model_cfg):
        """Create the silhouette/set stages, lateral layers, head and, unless
        ``Lateral_pretraining`` is set, the compact block."""
        in_channels = model_cfg['in_channels']
        self.bin_num = model_cfg['bin_num']
        self.hidden_dim = model_cfg['hidden_dim']
        lateral_dim = model_cfg['lateral_dim']
        reduce_dim = self.hidden_dim
        self.pretrain = model_cfg['Lateral_pretraining']

        def conv_stage(c_in, c_out, kernel, padding):
            # two conv + LeakyReLU pairs; only the first kernel size varies
            return nn.Sequential(BasicConv2d(c_in, c_out, kernel, 1, padding),
                                 nn.LeakyReLU(inplace=True),
                                 BasicConv2d(c_out, c_out, 3, 1, 1),
                                 nn.LeakyReLU(inplace=True))

        def lateral_conv(c_in):
            return nn.Conv2d(
                c_in, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)

        stage_0 = conv_stage(in_channels[0], in_channels[1], 5, 2)
        stage_1 = conv_stage(in_channels[1], in_channels[2], 3, 1)
        stage_2 = conv_stage(in_channels[2], in_channels[3], 3, 1)

        # frame-level stages run on every frame through the wrapper
        self.sil_stage_0 = SetBlockWrapper(stage_0)
        self.sil_stage_1 = SetBlockWrapper(stage_1)
        self.sil_stage_2 = SetBlockWrapper(stage_2)

        # set-level stages are independent copies of the unwrapped stages
        self.set_stage_1 = copy.deepcopy(stage_1)
        self.set_stage_2 = copy.deepcopy(stage_2)

        self.set_pooling = PackSequenceWrapper(torch.max)

        self.MaxP_sil = SetBlockWrapper(nn.MaxPool2d(kernel_size=2, stride=2))
        self.MaxP_set = nn.MaxPool2d(kernel_size=2, stride=2)

        self.lateral_layer1 = lateral_conv(in_channels[1]*2)
        self.lateral_layer2 = lateral_conv(in_channels[2]*2)
        self.lateral_layer3 = lateral_conv(in_channels[3]*2)

        self.smooth_layer1 = lateral_conv(lateral_dim)
        self.smooth_layer2 = lateral_conv(lateral_dim)
        self.smooth_layer3 = lateral_conv(lateral_dim)

        self.HPP = HorizontalPoolingPyramid()
        self.Head = SeparateFCs(**model_cfg['SeparateFCs'])

        if not self.pretrain:
            flat_dim = sum(self.bin_num)*3*self.hidden_dim
            self.encoder_bn = nn.BatchNorm1d(flat_dim)
            self.encoder_bn.bias.requires_grad_(False)

            self.reduce_dp = nn.Dropout(p=model_cfg['dropout'])
            self.reduce_ac = nn.ReLU(inplace=True)
            self.reduce_fc = nn.Linear(flat_dim, reduce_dim, bias=False)

            self.reduce_bn = nn.BatchNorm1d(reduce_dim)
            self.reduce_bn.bias.requires_grad_(False)

            self.reduce_cls = nn.Linear(reduce_dim, model_cfg['class_num'], bias=False)

    def upsample_add(self, x, y):
        """Upsample ``x`` by 2 (nearest neighbour) and add ``y``."""
        return F.interpolate(x, scale_factor=2, mode='nearest') + y

    def forward(self, inputs):
        """Return the training features, visual summary and inference features."""
        ipts, labs, _, _, seqL = inputs
        # sequence lengths are only used while training
        if not self.training:
            seqL = None
        sils = ipts[0]  # [n, s, h, w]
        del ipts
        if len(sils.size()) == 4:
            sils = sils.unsqueeze(2)
        n, s, _, h, w = sils.size()

        ### stage 0 sil ###
        sil_0_outs = self.sil_stage_0(sils)
        stage_0_sil_set = self.set_pooling(sil_0_outs, seqL, dim=1)[0]

        ### stage 1 sil ###
        sil_1_ipts = self.MaxP_sil(sil_0_outs)
        sil_1_outs = self.sil_stage_1(sil_1_ipts)

        ### stage 2 sil ###
        sil_2_ipts = self.MaxP_sil(sil_1_outs)
        sil_2_outs = self.sil_stage_2(sil_2_ipts)

        ### stage 1 set ###
        set_1_ipts = self.set_pooling(sil_1_ipts, seqL, dim=1)[0]
        stage_1_sil_set = self.set_pooling(sil_1_outs, seqL, dim=1)[0]
        set_1_outs = self.set_stage_1(set_1_ipts) + stage_1_sil_set

        ### stage 2 set ###
        set_2_ipts = self.MaxP_set(set_1_outs)
        stage_2_sil_set = self.set_pooling(sil_2_outs, seqL, dim=1)[0]
        set_2_outs = self.set_stage_2(set_2_ipts) + stage_2_sil_set

        # stage 0 concatenates its pooled silhouette feature with itself
        set1 = torch.cat((stage_0_sil_set, stage_0_sil_set), dim=1)
        set2 = torch.cat((stage_1_sil_set, set_1_outs), dim=1)
        set3 = torch.cat((stage_2_sil_set, set_2_outs), dim=1)

        # lateral (top-down merge from the deepest stage)
        set3 = self.lateral_layer3(set3)
        set2 = self.upsample_add(set3, self.lateral_layer2(set2))
        set1 = self.upsample_add(set2, self.lateral_layer1(set1))

        set3 = self.smooth_layer3(set3)
        set2 = self.smooth_layer2(set2)
        set1 = self.smooth_layer1(set1)

        set1 = self.HPP(set1)
        set2 = self.HPP(set2)
        set3 = self.HPP(set3)

        feature = torch.cat([set1, set2, set3], -1)
        feature = feature.permute(2, 0, 1).contiguous()

        feature = self.Head(feature)
        feature = feature.permute(1, 0, 2).contiguous()  # n p c

        # compact block
        if not self.pretrain:
            bn_feature = self.encoder_bn(feature.view(n, -1))
            bn_feature = bn_feature.view(*feature.shape).contiguous()

            reduce_feature = self.reduce_ac(self.reduce_dp(bn_feature))
            reduce_feature = self.reduce_fc(reduce_feature.view(n, -1))

            bn_reduce_feature = self.reduce_bn(reduce_feature)
            logits = self.reduce_cls(bn_reduce_feature).unsqueeze(1)  # n c

            # reshaped but not used below; see the alternatives noted on
            # the inference embeddings
            reduce_feature = reduce_feature.unsqueeze(1).contiguous()
            bn_reduce_feature = bn_reduce_feature.unsqueeze(1).contiguous()

        retval = {
            'training_feat': {
                'triplet': {'embeddings': feature, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': feature  # reduce_feature # bn_reduce_feature
            }
        }
        # the softmax term only exists once the compact block is in use
        if not self.pretrain:
            retval['training_feat']['softmax'] = {'logits': logits, 'labels': labs}
        return retval
|
||||
@@ -0,0 +1,200 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from utils import clones, is_list_or_tuple
|
||||
|
||||
|
||||
class HorizontalPoolingPyramid():
    """
        Horizontal Pyramid Matching for Person Re-identification
        Arxiv: https://arxiv.org/abs/1804.05275
        Github: https://github.com/SHI-Labs/Horizontal-Pyramid-Matching
    """

    def __init__(self, bin_num=None):
        # default pyramid: 16, 8, 4, 2 and 1 strips
        self.bin_num = [16, 8, 4, 2, 1] if bin_num is None else bin_num

    def __call__(self, x):
        """
            x  : [n, c, h, w]
            ret: [n, c, p]
        """
        n, c = x.size()[:2]
        pooled = []
        for bins in self.bin_num:
            strips = x.view(n, c, bins, -1)
            # each strip is summarised by its mean plus its maximum
            pooled.append(strips.mean(-1) + strips.max(-1)[0])
        return torch.cat(pooled, -1)
|
||||
|
||||
|
||||
class SetBlockWrapper(nn.Module):
    """Apply a frame-level block to every frame of a batch of sequences."""

    def __init__(self, forward_block):
        super().__init__()
        self.forward_block = forward_block

    def forward(self, x, *args, **kwargs):
        """
            In  x: [n, s, c, h, w]
            Out x: [n, s, ...]
        """
        n, s, c, h, w = x.size()
        # fold the sequence axis into the batch axis, run the block, unfold
        frames = self.forward_block(x.view(-1, c, h, w), *args, **kwargs)
        out_shape = [n, s] + list(frames.size()[1:])
        return frames.view(*out_shape)
|
||||
|
||||
|
||||
class PackSequenceWrapper(nn.Module):
    """Apply a pooling function to each sequence packed back-to-back along one axis."""

    def __init__(self, pooling_func):
        super(PackSequenceWrapper, self).__init__()
        self.pooling_func = pooling_func

    def forward(self, seqs, seqL, seq_dim=1, **kwargs):
        """
            In  seqs: [n, s, ...]
            Out rets: [n, ...]

            When seqL is None the pooling function is applied to seqs as a
            whole. Otherwise seqL[0] gives the length of each packed sequence
            and the pooling function is applied to every slice separately.
        """
        if seqL is None:
            return self.pooling_func(seqs, **kwargs)

        lengths = seqL[0].data.cpu().numpy().tolist()
        # Start offset of every sequence = cumulative length of the ones before it.
        offsets = [0] + np.cumsum(lengths).tolist()[:-1]

        # NOTE: each slice is pooled in one go; pooling in chunks would be a
        # way to trade speed for memory.
        pooled = [
            self.pooling_func(seqs.narrow(seq_dim, offset, length), **kwargs)
            for offset, length in zip(offsets, lengths)
        ]

        if len(pooled) > 0 and is_list_or_tuple(pooled[0]):
            # Pooling functions such as torch.max return several tensors:
            # concatenate each output position across sequences.
            return [torch.cat([item[j] for item in pooled])
                    for j in range(len(pooled[0]))]
        return torch.cat(pooled)
|
||||
|
||||
|
||||
class BasicConv2d(nn.Module):
    """Bias-free 2D convolution; extra keyword arguments go straight to nn.Conv2d."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, **kwargs):
        super(BasicConv2d, self).__init__()
        conv_options = dict(stride=stride, padding=padding, bias=False, **kwargs)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, **conv_options)

    def forward(self, x):
        return self.conv(x)
|
||||
|
||||
|
||||
class SeparateFCs(nn.Module):
    """One independent fully-connected projection per part."""

    def __init__(self, parts_num, in_channels, out_channels, norm=False):
        super(SeparateFCs, self).__init__()
        self.p = parts_num
        # A separate [in_channels, out_channels] weight matrix for every part.
        initial_weight = nn.init.xavier_uniform_(
            torch.zeros(parts_num, in_channels, out_channels))
        self.fc_bin = nn.Parameter(initial_weight)
        self.norm = norm

    def forward(self, x):
        """
            x: [p, n, c]
        """
        # With norm enabled the weights are L2-normalised along the input-channel axis.
        weight = F.normalize(self.fc_bin, dim=1) if self.norm else self.fc_bin
        return x.matmul(weight)
|
||||
|
||||
|
||||
class SeparateBNNecks(nn.Module):
    """
        Per-part batch normalisation followed by a per-part linear classifier.

        Reference: Bag of Tricks and a Strong Baseline for Deep Person Re-Identification
        CVPR Workshop:  https://openaccess.thecvf.com/content_CVPRW_2019/papers/TRMTMCT/Luo_Bag_of_Tricks_and_a_Strong_Baseline_for_Deep_Person_CVPRW_2019_paper.pdf
        Github: https://github.com/michuanhaohao/reid-strong-baseline
    """

    def __init__(self, parts_num, in_channels, class_num, norm=True, parallel_BN1d=True):
        super(SeparateBNNecks, self).__init__()
        self.p = parts_num
        self.class_num = class_num
        self.norm = norm
        initial_weight = nn.init.xavier_uniform_(
            torch.zeros(parts_num, in_channels, class_num))
        self.fc_bin = nn.Parameter(initial_weight)
        self.parallel_BN1d = parallel_BN1d
        if parallel_BN1d:
            # A single BatchNorm1d over the concatenated features of all parts.
            self.bn1d = nn.BatchNorm1d(in_channels * parts_num)
        else:
            # One independent BatchNorm1d per part.
            self.bn1d = clones(nn.BatchNorm1d(in_channels), parts_num)

    def forward(self, x):
        """
            x: [p, n, c]
            returns (feature [p, n, c], logits [p, n, class_num])
        """
        if self.parallel_BN1d:
            p, n, c = x.size()
            flat = x.transpose(0, 1).contiguous().view(n, -1)  # [n, p*c]
            flat = self.bn1d(flat)
            x = flat.view(n, p, c).permute(1, 0, 2).contiguous()
        else:
            normalised_parts = [
                bn(part.squeeze(0)).unsqueeze(0)
                for part, bn in zip(x.split(1, 0), self.bn1d)
            ]
            x = torch.cat(normalised_parts, 0)  # [p, n, c]

        if not self.norm:
            return x, x.matmul(self.fc_bin)

        # Normalise both the features and the classifier weights before the product.
        feature = F.normalize(x, dim=-1)  # [p, n, c]
        logits = feature.matmul(F.normalize(self.fc_bin, dim=1))
        return feature, logits
|
||||
|
||||
|
||||
class FocalConv2d(nn.Module):
    """Bias-free 2D convolution applied independently to horizontal slices of the input."""

    def __init__(self, in_channels, out_channels, kernel_size, halving, **kwargs):
        super(FocalConv2d, self).__init__()
        self.halving = halving
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size, bias=False, **kwargs)

    def forward(self, x):
        # halving == 0 means an ordinary convolution over the whole map.
        if self.halving == 0:
            return self.conv(x)
        # Otherwise cut the height axis into slices of h // 2**halving rows,
        # convolve each slice on its own and stitch the results back together.
        slice_height = int(x.size(2) // 2**self.halving)
        slices = x.split(slice_height, 2)
        return torch.cat([self.conv(piece) for piece in slices], 2)
|
||||
|
||||
|
||||
class BasicConv3d(nn.Module):
    """Thin wrapper around nn.Conv3d with a 3x3x3, stride-1, padding-1, bias-free default."""

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False, **kwargs):
        super(BasicConv3d, self).__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride,
                            padding=padding, bias=bias, **kwargs)
        self.conv3d = nn.Conv3d(in_channels, out_channels, **conv_options)

    def forward(self, ipts):
        '''
            ipts: [n, c, s, h, w]
            outs: [n, c, s, h, w]
        '''
        return self.conv3d(ipts)
|
||||
|
||||
|
||||
def RmBN2dAffine(model):
    """Freeze the affine parameters (weight and bias) of every BatchNorm2d in model."""
    bn_layers = (m for m in model.modules() if isinstance(m, nn.BatchNorm2d))
    for layer in bn_layers:
        layer.weight.requires_grad = False
        layer.bias.requires_grad = False
|
||||
|
||||
|
||||
def fix_BN(model):
    """Switch every sub-module whose class name contains 'BatchNorm2d' to eval mode."""
    for layer in model.modules():
        if 'BatchNorm2d' in layer.__class__.__name__:
            layer.eval()
|
||||
@@ -0,0 +1,10 @@
|
||||
from .common import get_ddp_module, ddp_all_gather
|
||||
from .common import Odict, Ntuple
|
||||
from .common import get_valid_args
|
||||
from .common import is_list_or_tuple, is_str, is_list, is_dict, is_tensor, is_array, config_loader, init_seeds, handler, params_count
|
||||
from .common import ts2np, ts2var, np2var, list2var
|
||||
from .common import mkdir, clones
|
||||
from .common import MergeCfgsDict
|
||||
from .common import get_attr_from
|
||||
from .common import NoOp
|
||||
from .msg_manager import get_msg_mgr
|
||||
@@ -0,0 +1,201 @@
|
||||
import copy
|
||||
import os
|
||||
import inspect
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.autograd as autograd
|
||||
import yaml
|
||||
import random
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from collections import OrderedDict, namedtuple
|
||||
|
||||
|
||||
class NoOp:
    """Stand-in object: any attribute looked up on it is a callable that does nothing."""

    def __getattr__(self, *args):
        def no_op(*args, **kwargs):
            # Accept any call signature and return None.
            return None
        return no_op
|
||||
|
||||
|
||||
class Odict(OrderedDict):
    """OrderedDict whose values are lists that grow through append()."""

    def append(self, odict):
        """Merge odict into self, extending the list stored under each key.

        Non-list values (on either side) are wrapped in a list first; keys
        that are new to self simply store the incoming list.
        """
        for key, value in odict.items():
            incoming = value if is_list(value) else [value]
            if key not in self.keys():
                self[key] = incoming
            elif is_list(self[key]):
                self[key] += incoming
            else:
                self[key] = [self[key]] + incoming
|
||||
|
||||
|
||||
def Ntuple(description, keys, values):
    """Build a namedtuple type called `description` and return one instance of it.

    A single (non list/tuple) key is accepted together with its single value.
    """
    if not is_list_or_tuple(keys):
        keys, values = [keys], [values]
    tuple_type = namedtuple(description, keys)
    return tuple_type._make(values)
|
||||
|
||||
|
||||
def get_valid_args(obj, input_args, free_keys=()):
    """Keep only the entries of input_args that obj accepts as named parameters.

    Args:
        obj: a plain function, or a class (its __init__ signature is used).
        input_args: mapping of candidate keyword arguments.
        free_keys: keys that may be dropped silently, without being reported.

    Returns:
        A new dict with the accepted subset of input_args.

    Raises:
        ValueError: if obj is neither a function nor a class.

    Keys that are neither accepted nor listed in free_keys are reported once
    through logging.info.
    """
    if inspect.isfunction(obj):
        target = obj
    elif inspect.isclass(obj):
        target = obj.__init__
    else:
        raise ValueError('Just support function and class object!')

    # getargspec() was removed in Python 3.11 and rejected keyword-only
    # parameters; getfullargspec() handles both positional and keyword-only names.
    spec = inspect.getfullargspec(target)
    expected_keys = set(spec.args) | set(spec.kwonlyargs)

    expected_args = {}
    unexpect_keys = []
    for k, v in input_args.items():
        if k in expected_keys:
            expected_args[k] = v
        elif k not in free_keys:
            unexpect_keys.append(k)

    if unexpect_keys:
        logging.info("Find Unexpected Args(%s) in the Configuration of - %s -" %
                     (', '.join(unexpect_keys), obj.__name__))
    return expected_args
|
||||
|
||||
|
||||
def get_attr_from(sources, name):
    """Return attribute `name` from the first object in `sources` that has it.

    Args:
        sources: non-empty list/tuple of objects (e.g. modules), searched in order.
        name: attribute name to look up.

    Raises:
        AttributeError: if no source has the attribute (raised by the last source).
        IndexError: if sources is empty.
    """
    # Only a missing attribute moves on to the next source; any other error
    # (including KeyboardInterrupt) is no longer swallowed.
    for source in sources[:-1]:
        try:
            return getattr(source, name)
        except AttributeError:
            continue
    return getattr(sources[-1], name)
|
||||
|
||||
|
||||
def is_list_or_tuple(x):
    """Return True when x is a list or a tuple (subclasses included)."""
    return isinstance(x, list) or isinstance(x, tuple)
|
||||
|
||||
|
||||
def is_str(x):
    """Return True when x is a str instance."""
    string_types = (str,)
    return isinstance(x, string_types)
|
||||
|
||||
|
||||
def is_list(x):
    """Return True when x is a Python list or an nn.ModuleList."""
    return isinstance(x, (list, nn.ModuleList))
|
||||
|
||||
|
||||
def is_dict(x):
    """Return True when x is a dict, including OrderedDict and Odict instances."""
    # OrderedDict subclasses dict and Odict subclasses OrderedDict, so a single
    # check against dict covers all three.
    return isinstance(x, dict)
|
||||
|
||||
|
||||
def is_tensor(x):
    """Return True when x is a torch.Tensor."""
    tensor_type = torch.Tensor
    return isinstance(x, tensor_type)
|
||||
|
||||
|
||||
def is_array(x):
    """Return True when x is a numpy ndarray."""
    array_type = np.ndarray
    return isinstance(x, array_type)
|
||||
|
||||
|
||||
def ts2np(x):
    """Copy tensor x to host memory and return its data as a numpy array."""
    host_tensor = x.cpu()
    return host_tensor.data.numpy()
|
||||
|
||||
|
||||
def ts2var(x, **kwargs):
    """Wrap tensor x in an autograd.Variable (kwargs forwarded) and move it to CUDA."""
    variable = autograd.Variable(x, **kwargs)
    return variable.cuda()
|
||||
|
||||
|
||||
def np2var(x, **kwargs):
    """Convert numpy array x to a tensor and pass it, with kwargs, to ts2var."""
    tensor = torch.from_numpy(x)
    return ts2var(tensor, **kwargs)
|
||||
|
||||
|
||||
def list2var(x, **kwargs):
    """Convert x to a numpy array and pass it, with kwargs, to np2var."""
    array = np.array(x)
    return np2var(array, **kwargs)
|
||||
|
||||
|
||||
def mkdir(path):
    """Create directory `path` (with parents) unless something already exists there.

    An existing path of any kind is left untouched. exist_ok=True closes the
    race where another process creates the same directory between the
    existence check and the makedirs call.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
|
||||
|
||||
|
||||
def MergeCfgsDict(src, dst):
    """Recursively merge src into dst in place; values from src win.

    A value is merged key by key only when it is a plain dict in src and the
    same key already holds a dict in dst; in every other case the src value
    replaces whatever dst had.
    """
    for key, value in src.items():
        # The exact-type check on value already implies is_dict(value).
        can_merge = (key in dst.keys()
                     and type(value) is dict
                     and is_dict(dst[key]))
        if can_merge:
            MergeCfgsDict(value, dst[key])
        else:
            dst[key] = value
|
||||
|
||||
|
||||
def clones(module, N):
    "Return an nn.ModuleList holding N independent deep copies of module."
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)
|
||||
|
||||
|
||||
def config_loader(path):
    """Load the YAML config at `path` on top of ./config/default.yaml.

    Returns the default configuration dict after the entries from `path`
    have been merged into it with MergeCfgsDict.
    """
    with open(path, 'r') as stream:
        user_cfgs = yaml.safe_load(stream)
    with open("./config/default.yaml", 'r') as stream:
        merged_cfgs = yaml.safe_load(stream)
    # Values from the user config override the defaults.
    MergeCfgsDict(user_cfgs, merged_cfgs)
    return merged_cfgs
|
||||
|
||||
|
||||
def init_seeds(seed=0, cuda_deterministic=True):
    """Seed the Python, numpy and torch RNGs and configure cuDNN determinism.

    cuda_deterministic=True picks the slower, more reproducible cuDNN
    settings; False picks the faster, less reproducible ones.
    """
    for seed_fn in (random.seed, np.random.seed,
                    torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
    torch.backends.cudnn.deterministic = bool(cuda_deterministic)
    torch.backends.cudnn.benchmark = not cuda_deterministic
|
||||
|
||||
|
||||
def handler(signum, frame):
    """Signal handler: log the interrupt, then kill every process whose command line matches main.py."""
    logging.info('Ctrl+c/z pressed')
    kill_command = "kill $(ps aux | grep main.py | grep -v grep | awk '{print $2}') "
    os.system(kill_command)
    logging.info('process group flush!')
|
||||
|
||||
|
||||
def ddp_all_gather(features, dim=0, requires_grad=True):
    '''
        Gather `features` from every process and concatenate them along `dim`.

        inputs: [n, ...]
    '''
    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()

    # One receive buffer per process, shaped like the local tensor.
    gathered = [torch.ones_like(features) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, features.contiguous())

    if requires_grad:
        # Put the original local tensor back into this rank's slot --
        # presumably so gradients can still flow through the local slice.
        gathered[rank] = features
    return torch.cat(gathered, dim=dim)
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/issues/16885
|
||||
class DDPPassthrough(DDP):
    """DistributedDataParallel that falls back to the wrapped module for unknown attributes."""

    def __getattr__(self, name):
        try:
            attribute = super().__getattr__(name)
        except AttributeError:
            # Not an attribute of the DDP wrapper itself: look it up on the wrapped module.
            attribute = getattr(self.module, name)
        return attribute
|
||||
|
||||
|
||||
def get_ddp_module(module, **kwargs):
    """Wrap module in DDPPassthrough on the current CUDA device.

    Modules without any parameter (e.g. a parameter-free loss) are returned
    unchanged. Extra keyword arguments are forwarded to DDPPassthrough.
    """
    has_parameters = next(iter(module.parameters()), None) is not None
    if not has_parameters:
        # for the case that loss module has not parameters.
        return module
    device = torch.cuda.current_device()
    return DDPPassthrough(module, device_ids=[device], output_device=device,
                          find_unused_parameters=False, **kwargs)
|
||||
|
||||
|
||||
def params_count(net):
    """Return a 'Parameters Count: <x>M' string with the number of parameters of net in millions."""
    total = 0
    for param in net.parameters():
        total += param.numel()
    return 'Parameters Count: {:.5f}M'.format(total / 1e6)
|
||||
@@ -0,0 +1,143 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from utils import get_msg_mgr
|
||||
|
||||
|
||||
def cuda_dist(x, y, metric='euc'):
    """Pairwise distance between two sets of part-based embeddings.

    Args:
        x: numpy array indexed as [n_x, p, c].
        y: numpy array indexed as [n_y, p, c].
        metric: 'cos' for cosine distance (1 - mean cosine similarity);
            any other value gives the mean Euclidean distance.

    Returns:
        Tensor [n_x, n_y] with the distance averaged over the p parts. It
        lives on the CUDA device when one is available, otherwise on the CPU.
    """
    # Fall back to the CPU so evaluation also runs on hosts without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.from_numpy(x).to(device)
    y = torch.from_numpy(y).to(device)
    use_cos = metric == 'cos'
    if use_cos:
        x = F.normalize(x, p=2, dim=2)  # n p c
        y = F.normalize(y, p=2, dim=2)  # n p c
    num_bin = x.size(1)
    n_x = x.size(0)
    n_y = y.size(0)
    dist = torch.zeros(n_x, n_y, device=device)
    for i in range(num_bin):
        _x = x[:, i, ...]
        _y = y[:, i, ...]
        if use_cos:
            dist += torch.matmul(_x, _y.transpose(0, 1))
        else:
            # |a - b|^2 = |a|^2 + |b|^2 - 2 a.b ; relu clamps the tiny
            # negative values that rounding can produce before the sqrt.
            _dist = torch.sum(_x ** 2, 1).unsqueeze(1) + torch.sum(_y ** 2, 1).unsqueeze(
                1).transpose(0, 1) - 2 * torch.matmul(_x, _y.transpose(0, 1))
            dist += torch.sqrt(F.relu(_dist))
    return 1 - dist / num_bin if use_cos else dist / num_bin
|
||||
|
||||
# Exclude identical-view cases
|
||||
|
||||
|
||||
def de_diag(acc, each_angle=False):
    """Average each row of the square matrix acc while ignoring its diagonal entry.

    Returns one value per row when each_angle is True, otherwise the mean of
    those per-row values.
    """
    off_diagonal = acc - np.diag(np.diag(acc))
    per_row = np.sum(off_diagonal, 1) / (acc.shape[1] - 1.)
    return per_row if each_angle else np.mean(per_row)
|
||||
|
||||
# Modified From https://github.com/AbnerHqC/GaitSet/blob/master/model/utils/evaluator.py
|
||||
|
||||
|
||||
def identification(data, dataset, metric='euc'):
    """Cross-view identification accuracy for CASIA-B / OUMVLP style protocols.

    Args:
        data: dict with 'embeddings' (numpy array, one row per sequence),
            'labels', 'types' and 'views' (one entry per sequence).
        dataset: 'CASIA-B' or 'OUMVLP'; selects the probe/gallery sequence types.
        metric: forwarded to cuda_dist ('euc' or 'cos').

    Returns:
        dict of "scalar/test_accuracy/..." entries. For datasets other than
        OUMVLP these are the rank-1 [probe_view, gallery_view] accuracy
        matrices of the NM / BG / CL probe sets; for OUMVLP a single NM
        accuracy that excludes identical-view cases.

    Raises:
        KeyError: if dataset has no probe or gallery definition.
    """
    msg_mgr = get_msg_mgr()

    feature, label, seq_type, view = data['embeddings'], data['labels'], data['types'], data['views']
    label = np.array(label)
    view_list = sorted(set(view))
    view_num = len(view_list)

    probe_seq_dict = {'CASIA-B': [['nm-05', 'nm-06'], ['bg-01', 'bg-02'], ['cl-01', 'cl-02']],
                      'OUMVLP': [['00']]}

    gallery_seq_dict = {'CASIA-B': [['nm-01', 'nm-02', 'nm-03', 'nm-04']],
                        'OUMVLP': [['01']]}
    # `(a or b)` evaluates to `a` alone, so both dicts are checked explicitly.
    if dataset not in probe_seq_dict or dataset not in gallery_seq_dict:
        raise KeyError("DataSet %s hasn't been supported !" % dataset)
    num_rank = 5
    # acc[p, v1, v2, r]: rank-(r+1) accuracy of probe set p, probe view v1, gallery view v2.
    acc = np.zeros([len(probe_seq_dict[dataset]),
                    view_num, view_num, num_rank]) - 1.
    for (p, probe_seq) in enumerate(probe_seq_dict[dataset]):
        for gallery_seq in gallery_seq_dict[dataset]:
            for (v1, probe_view) in enumerate(view_list):
                for (v2, gallery_view) in enumerate(view_list):
                    gseq_mask = np.isin(seq_type, gallery_seq) & np.isin(
                        view, [gallery_view])
                    gallery_x = feature[gseq_mask, :]
                    gallery_y = label[gseq_mask]

                    pseq_mask = np.isin(seq_type, probe_seq) & np.isin(
                        view, [probe_view])
                    probe_x = feature[pseq_mask, :]
                    probe_y = label[pseq_mask]

                    dist = cuda_dist(probe_x, gallery_x, metric)
                    idx = dist.sort(1)[1].cpu().numpy()
                    # A probe counts as a hit at rank r when its label appears
                    # among its r nearest gallery samples.
                    acc[p, v1, v2, :] = np.round(
                        np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]], 1) > 0,
                               0) * 100 / dist.shape[0], 2)
    result_dict = {}
    if 'OUMVLP' not in dataset:
        i = 0  # only rank-1 is reported
        msg_mgr.log_info(
            '===Rank-%d (Include identical-view cases)===' % (i + 1))
        msg_mgr.log_info('NM: %.3f,\tBG: %.3f,\tCL: %.3f' % (
            np.mean(acc[0, :, :, i]),
            np.mean(acc[1, :, :, i]),
            np.mean(acc[2, :, :, i])))
        msg_mgr.log_info(
            '===Rank-%d (Exclude identical-view cases)===' % (i + 1))
        msg_mgr.log_info('NM: %.3f,\tBG: %.3f,\tCL: %.3f' % (
            de_diag(acc[0, :, :, i]),
            de_diag(acc[1, :, :, i]),
            de_diag(acc[2, :, :, i])))
        result_dict["scalar/test_accuracy/NM"] = acc[0, :, :, i]
        # BG is probe set 1 (previously reported the NM matrix by mistake).
        result_dict["scalar/test_accuracy/BG"] = acc[1, :, :, i]
        result_dict["scalar/test_accuracy/CL"] = acc[2, :, :, i]
        np.set_printoptions(precision=2, floatmode='fixed')
        msg_mgr.log_info(
            '===Rank-%d of each angle (Exclude identical-view cases)===' % (i + 1))
        msg_mgr.log_info('NM: {}'.format(de_diag(acc[0, :, :, i], True)))
        msg_mgr.log_info('BG: {}'.format(de_diag(acc[1, :, :, i], True)))
        msg_mgr.log_info('CL: {}'.format(de_diag(acc[2, :, :, i], True)))
    else:
        msg_mgr.log_info('===Rank-1 (Include identical-view cases)===')
        msg_mgr.log_info('NM: %.3f ' % (np.mean(acc[0, :, :, 0])))
        msg_mgr.log_info('===Rank-1 (Exclude identical-view cases)===')
        msg_mgr.log_info('NM: %.3f ' % (np.mean(de_diag(acc[0, :, :, 0]))))
        result_dict["scalar/test_accuracy/NM"] = np.mean(
            de_diag(acc[0, :, :, 0]))
    return result_dict
|
||||
|
||||
|
||||
def identification_real_scene(data, dataset, metric='euc'):
    """Rank-1 / rank-5 identification accuracy with a single probe/gallery split.

    data provides 'embeddings', 'labels' and 'types'; dataset selects which
    sequence types form the gallery and which the probe set. The two
    accuracies are logged and returned as "scalar/test_accuracy/Rank-1" and
    "scalar/test_accuracy/Rank-5".
    """
    msg_mgr = get_msg_mgr()
    feature, label, seq_type = data['embeddings'], data['labels'], data['types']
    label = np.array(label)

    gallery_seq_type = {'0001-1000': ['1', '2'],
                        "HID2021": ['0'], '0001-1000-test': ['0']}
    probe_seq_type = {'0001-1000': ['3', '4', '5', '6'],
                      "HID2021": ['1'], '0001-1000-test': ['1']}

    num_rank = 5

    gallery_mask = np.isin(seq_type, gallery_seq_type[dataset])
    gallery_x, gallery_y = feature[gallery_mask, :], label[gallery_mask]

    probe_mask = np.isin(seq_type, probe_seq_type[dataset])
    probe_x, probe_y = feature[probe_mask, :], label[probe_mask]

    dist = cuda_dist(probe_x, gallery_x, metric)
    # Gallery indices ordered from nearest to farthest for every probe.
    idx = dist.cpu().sort(1)[1].numpy()
    # A probe is a hit at rank r when its label is among its r nearest gallery samples.
    hits = np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]], 1) > 0
    acc = np.round(np.sum(hits, 0) * 100 / dist.shape[0], 2)

    rank1, rank5 = np.mean(acc[0]), np.mean(acc[4])
    msg_mgr.log_info('==Rank-1==')
    msg_mgr.log_info('%.3f' % (rank1))
    msg_mgr.log_info('==Rank-5==')
    msg_mgr.log_info('%.3f' % (rank5))
    return {"scalar/test_accuracy/Rank-1": rank1, "scalar/test_accuracy/Rank-5": rank5}
|
||||
@@ -0,0 +1,119 @@
|
||||
import time
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
import torchvision.utils as vutils
|
||||
import os.path as osp
|
||||
from time import strftime, localtime
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from .common import is_list, is_tensor, ts2np, mkdir, Odict, NoOp
|
||||
import logging
|
||||
|
||||
|
||||
class MessageManager:
    """Accumulates per-iteration training info, logs it periodically and feeds TensorBoard."""

    def __init__(self):
        self.info_dict = Odict()
        # Summary keys must start with one of these prefixes ("<prefix>/<name>").
        self.writer_hparams = ['image', 'scalar']
        self.time = time.time()

    def init_manager(self, save_path, log_to_file, log_iter, iteration=0):
        """Create the TensorBoard writer and configure the 'opengait' logger.

        Args:
            save_path: root directory; "summary/" (and "logs/" when
                log_to_file is set) are created underneath it.
            log_to_file: also write log records to a timestamped text file.
            log_iter: log the accumulated info every log_iter iterations.
            iteration: iteration to start counting from.
        """
        self.iteration = iteration
        self.log_iter = log_iter
        mkdir(osp.join(save_path, "summary/"))
        self.writer = SummaryWriter(
            osp.join(save_path, "summary/"), purge_step=self.iteration)

        # init logger
        self.logger = logging.getLogger('opengait')
        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False
        formatter = logging.Formatter(
            fmt='[%(asctime)s] [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
        if log_to_file:
            mkdir(osp.join(save_path, "logs/"))
            vlog = logging.FileHandler(
                osp.join(save_path, "logs/", strftime('%Y-%m-%d-%H-%M-%S', localtime())+'.txt'))
            vlog.setLevel(logging.INFO)
            vlog.setFormatter(formatter)
            self.logger.addHandler(vlog)

        console = logging.StreamHandler()
        console.setFormatter(formatter)
        console.setLevel(logging.DEBUG)
        self.logger.addHandler(console)

    def append(self, info):
        """Add the values of info to the accumulated info (tensors become numpy arrays).

        Note: info itself is updated in place with the converted value lists.
        """
        for k, v in info.items():
            v = [v] if not is_list(v) else v
            v = [ts2np(_) if is_tensor(_) else _ for _ in v]
            info[k] = v
        self.info_dict.append(info)

    def flush(self):
        """Drop the accumulated info and flush the TensorBoard writer."""
        self.info_dict.clear()
        self.writer.flush()

    def write_to_tensorboard(self, summary):
        """Write every "<type>/<name>" entry of summary via the writer's add_<type> method."""
        for k, v in summary.items():
            module_name = k.split('/')[0]
            if module_name not in self.writer_hparams:
                self.log_warning(
                    'Not Expected --Summary-- type [{}] appear!!!{}'.format(k, self.writer_hparams))
                continue
            board_name = k.replace(module_name + "/", '')
            writer_module = getattr(self.writer, 'add_' + module_name)
            v = v.detach() if is_tensor(v) else v
            v = vutils.make_grid(
                v, normalize=True, scale_each=True) if 'image' in module_name else v
            if module_name == 'scalar':
                try:
                    v = v.mean()
                except Exception:
                    # Values that cannot be averaged (e.g. plain Python
                    # numbers) are written unchanged.
                    pass
            writer_module(board_name, v, self.iteration)

    def log_training_info(self):
        """Log iteration number, elapsed time and the mean of every accumulated scalar."""
        now = time.time()
        string = "Iteration {:0>5}, Cost {:.2f}s".format(
            self.iteration, now-self.time)
        for k, v in self.info_dict.items():
            if 'scalar' not in k:
                continue
            k = k.replace('scalar/', '').replace('/', '_')
            string += ", {0}={1:.4f}".format(k, np.mean(v))
        self.log_info(string)
        self.reset_time()

    def reset_time(self):
        """Restart the timer used for the "Cost" field of the training log."""
        self.time = time.time()

    def train_step(self, info, summary):
        """Record one training iteration; every log_iter iterations log, flush and write summary."""
        self.iteration += 1
        self.append(info)
        if self.iteration % self.log_iter == 0:
            self.log_training_info()
            self.flush()
            self.write_to_tensorboard(summary)

    def log_debug(self, *args, **kwargs):
        self.logger.debug(*args, **kwargs)

    def log_info(self, *args, **kwargs):
        self.logger.info(*args, **kwargs)

    def log_warning(self, *args, **kwargs):
        self.logger.warning(*args, **kwargs)
|
||||
|
||||
|
||||
# Module-level singletons handed out by get_msg_mgr().
msg_mgr = MessageManager()
noop = NoOp()


def get_msg_mgr():
    """Return the shared MessageManager on rank 0 and the do-nothing stand-in on every other rank."""
    on_secondary_rank = torch.distributed.get_rank() > 0
    return noop if on_secondary_rank else msg_mgr
|
||||
Reference in New Issue
Block a user