OpenGait release (pre-beta version).
This commit is contained in:
@@ -0,0 +1,113 @@
|
||||
import math
|
||||
import random
|
||||
import numpy as np
|
||||
from utils import get_msg_mgr
|
||||
|
||||
class CollateFn(object):
|
||||
def __init__(self, label_set, sample_config):
|
||||
self.label_set = label_set
|
||||
sample_type = sample_config['sample_type']
|
||||
sample_type = sample_type.split('_')
|
||||
self.sampler = sample_type[0]
|
||||
self.ordered = sample_type[1]
|
||||
if self.sampler not in ['fixed', 'unfixed', 'all']:
|
||||
raise ValueError
|
||||
if self.ordered not in ['ordered', 'unordered']:
|
||||
raise ValueError
|
||||
self.ordered = sample_type[1] == 'ordered'
|
||||
|
||||
# fixed cases
|
||||
if self.sampler == 'fixed':
|
||||
self.frames_num_fixed = sample_config['frames_num_fixed']
|
||||
|
||||
# unfixed cases
|
||||
if self.sampler == 'unfixed':
|
||||
self.frames_num_max = sample_config['frames_num_max']
|
||||
self.frames_num_min = sample_config['frames_num_min']
|
||||
|
||||
if self.sampler != 'all' and self.ordered:
|
||||
self.frames_skip_num = sample_config['frames_skip_num']
|
||||
|
||||
self.frames_all_limit = -1
|
||||
if self.sampler == 'all' and 'frames_all_limit' in sample_config:
|
||||
self.frames_all_limit = sample_config['frames_all_limit']
|
||||
|
||||
def __call__(self, batch):
|
||||
batch_size = len(batch)
|
||||
feature_num = len(batch[0][0])
|
||||
seqs_batch, labs_batch, typs_batch, vies_batch = [], [], [], []
|
||||
|
||||
for bt in batch:
|
||||
seqs_batch.append(bt[0])
|
||||
labs_batch.append(self.label_set.index(bt[1][0]))
|
||||
typs_batch.append(bt[1][1])
|
||||
vies_batch.append(bt[1][2])
|
||||
|
||||
global count
|
||||
count = 0
|
||||
|
||||
def sample_frames(seqs):
|
||||
global count
|
||||
sampled_fras = [[] for i in range(feature_num)]
|
||||
seq_len = len(seqs[0])
|
||||
indices = list(range(seq_len))
|
||||
|
||||
if self.sampler in ['fixed', 'unfixed']:
|
||||
if self.sampler == 'fixed':
|
||||
frames_num = self.frames_num_fixed
|
||||
else:
|
||||
frames_num = random.choice(
|
||||
list(range(self.frames_num_min, self.frames_num_max+1)))
|
||||
|
||||
if self.ordered:
|
||||
fs_n = frames_num + self.frames_skip_num
|
||||
if seq_len < fs_n:
|
||||
it = math.ceil(fs_n / seq_len)
|
||||
seq_len = seq_len * it
|
||||
indices = indices * it
|
||||
|
||||
start = random.choice(list(range(0, seq_len - fs_n + 1)))
|
||||
end = start + fs_n
|
||||
idx_lst = list(range(seq_len))
|
||||
idx_lst = idx_lst[start:end]
|
||||
idx_lst = sorted(np.random.choice(
|
||||
idx_lst, frames_num, replace=False))
|
||||
indices = [indices[i] for i in idx_lst]
|
||||
else:
|
||||
replace = seq_len < frames_num
|
||||
|
||||
if seq_len == 0:
|
||||
get_msg_mgr().log_debug('Find no frames in the sequence %s-%s-%s.'
|
||||
%(str(labs_batch[count]), str(typs_batch[count]), str(vies_batch[count])))
|
||||
|
||||
count += 1
|
||||
indices = np.random.choice(
|
||||
indices, frames_num, replace=replace)
|
||||
|
||||
for i in range(feature_num):
|
||||
for j in indices[:self.frames_all_limit] if self.frames_all_limit > -1 and len(indices) > self.frames_all_limit else indices:
|
||||
sampled_fras[i].append(seqs[i][j])
|
||||
return sampled_fras
|
||||
|
||||
# f: feature_num
|
||||
# b: batch_size
|
||||
# p: batch_size_per_gpu
|
||||
# g: gpus_num
|
||||
fras_batch = [sample_frames(seqs) for seqs in seqs_batch] # [b, f]
|
||||
batch = [fras_batch, labs_batch, typs_batch, vies_batch, None]
|
||||
|
||||
if self.sampler == "fixed":
|
||||
fras_batch = [[np.asarray(fras_batch[i][j]) for i in range(batch_size)]
|
||||
for j in range(feature_num)] # [f, b]
|
||||
else:
|
||||
seqL_batch = [[len(fras_batch[i][0])
|
||||
for i in range(batch_size)]] # [1, p]
|
||||
|
||||
def my_cat(k): return np.concatenate(
|
||||
[fras_batch[i][k] for i in range(batch_size)], 0)
|
||||
fras_batch = [[my_cat(k)] for k in range(feature_num)] # [f, g]
|
||||
|
||||
batch[-1] = np.asarray(seqL_batch)
|
||||
|
||||
batch[0] = fras_batch
|
||||
return batch
|
||||
@@ -0,0 +1,123 @@
|
||||
import os
|
||||
import pickle
|
||||
import os.path as osp
|
||||
import torch.utils.data as tordata
|
||||
import json
|
||||
from utils import get_msg_mgr
|
||||
|
||||
|
||||
class DataSet(tordata.Dataset):
    """Gait dataset where each item is one sequence.

    __getitem__ returns (data_list, seq_info) with data_list holding one
    loaded array per feature and seq_info being [label, type, view, paths].
    """

    def __init__(self, data_cfg, training):
        """
        seqs_info: the list with each element indicating
            a certain gait sequence presented as [label, type, view, paths];
        """
        self.__dataset_parser(data_cfg, training)
        self.cache = data_cfg['cache']
        self.label_list = [seq_info[0] for seq_info in self.seqs_info]
        self.types_list = [seq_info[1] for seq_info in self.seqs_info]
        self.views_list = [seq_info[2] for seq_info in self.seqs_info]

        self.label_set = sorted(set(self.label_list))
        self.types_set = sorted(set(self.types_list))
        self.views_set = sorted(set(self.views_list))
        # Lazily filled cache of loaded sequences, index-aligned to seqs_info.
        self.seqs_data = [None] * len(self)
        # label -> list of sequence indices belonging to that label.
        self.indices_dict = {label: [] for label in self.label_set}
        for i, seq_info in enumerate(self.seqs_info):
            self.indices_dict[seq_info[0]].append(i)
        if self.cache:
            self.__load_all_data()

    def __len__(self):
        return len(self.seqs_info)

    def __loader__(self, paths):
        """Load every .pkl file in `paths` (sorted), one per feature."""
        paths = sorted(paths)
        data_list = []
        for pth in paths:
            if not pth.endswith('.pkl'):
                raise ValueError('- Loader - just support .pkl !!!')
            with open(pth, 'rb') as f:
                data = pickle.load(f)
            data_list.append(data)
        # All features of one sequence must contain the same frame count.
        for data in data_list:
            if len(data) != len(data_list[0]):
                raise AssertionError(
                    'Each feature of a sequence must have the same length.')
        return data_list

    def __getitem__(self, idx):
        if not self.cache:
            data_lst = self.__loader__(self.seqs_info[idx][-1])
        elif self.seqs_data[idx] is None:
            # First access with caching enabled: load and memoize.
            data_lst = self.__loader__(self.seqs_info[idx][-1])
            self.seqs_data[idx] = data_lst
        else:
            data_lst = self.seqs_data[idx]
        seq_info = self.seqs_info[idx]
        return data_lst, seq_info

    def __load_all_data(self):
        # Touch every item once so __getitem__ fills the cache.
        for idx in range(len(self)):
            self.__getitem__(idx)

    def __dataset_parser(self, data_config, training):
        dataset_root = data_config['dataset_root']
        # Optional per-feature mask; None means "use every feature".
        # (Replaces a bare `except:` that silently swallowed all errors.)
        data_in_use = data_config.get('data_in_use')  # [n], true or false

        with open(data_config['dataset_partition'], "rb") as f:
            partition = json.load(f)
        train_set = partition["TRAIN_SET"]
        test_set = partition["TEST_SET"]
        label_list = os.listdir(dataset_root)
        # Keep only partition labels actually present on disk.
        train_set = [label for label in train_set if label in label_list]
        test_set = [label for label in test_set if label in label_list]
        miss_pids = [label for label in label_list if label not in (
            train_set + test_set)]
        msg_mgr = get_msg_mgr()

        def log_pid_list(pid_list):
            # Abbreviate long id lists to first, second and last entries.
            if len(pid_list) >= 3:
                msg_mgr.log_info('[%s, %s, ..., %s]' %
                                 (pid_list[0], pid_list[1], pid_list[-1]))
            else:
                msg_mgr.log_info(pid_list)

        if len(miss_pids) > 0:
            msg_mgr.log_debug('-------- Miss Pid List --------')
            msg_mgr.log_debug(miss_pids)
        if training:
            msg_mgr.log_info("-------- Train Pid List --------")
            log_pid_list(train_set)
        else:
            msg_mgr.log_info("-------- Test Pid List --------")
            log_pid_list(test_set)

        def get_seqs_info_list(label_set):
            # Walk <root>/<label>/<type>/<view> and collect the file lists.
            seqs_info_list = []
            for lab in label_set:
                for typ in sorted(os.listdir(osp.join(dataset_root, lab))):
                    for vie in sorted(os.listdir(osp.join(dataset_root, lab, typ))):
                        seq_info = [lab, typ, vie]
                        seq_path = osp.join(dataset_root, *seq_info)
                        seq_dirs = sorted(os.listdir(seq_path))
                        if seq_dirs != []:
                            seq_dirs = [osp.join(seq_path, seq_dir)
                                        for seq_dir in seq_dirs]
                            if data_in_use is not None:
                                seq_dirs = [seq_dir for seq_dir, use_bl in zip(
                                    seq_dirs, data_in_use) if use_bl]
                            seqs_info_list.append([*seq_info, seq_dirs])
                        else:
                            msg_mgr.log_debug('Find no .pkl file in %s-%s-%s.'%(lab, typ, vie))
            return seqs_info_list

        self.seqs_info = get_seqs_info_list(
            train_set) if training else get_seqs_info_list(test_set)
|
||||
@@ -0,0 +1,87 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.utils.data as tordata
|
||||
|
||||
|
||||
class TripletSampler(tordata.sampler.Sampler):
    """Infinite sampler yielding P x K triplet batches split across ranks.

    batch_size is (P, K): P identities per batch, K sequences per identity.
    Every rank draws the same random choices (sync_random_sample_list
    broadcasts from rank 0), then keeps its own strided slice of the batch.
    """

    def __init__(self, dataset, batch_size, batch_shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size  # (P, K)
        self.batch_shuffle = batch_shuffle

        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()

    def __iter__(self):
        while True:
            sample_indices = []
            # P identities, chosen identically on every rank.
            pid_list = sync_random_sample_list(
                self.dataset.label_set, self.batch_size[0])

            # K sequence indices per chosen identity.
            for pid in pid_list:
                indices = self.dataset.indices_dict[pid]
                indices = sync_random_sample_list(
                    indices, k=self.batch_size[1])
                sample_indices += indices

            if self.batch_shuffle:
                sample_indices = sync_random_sample_list(
                    sample_indices, len(sample_indices))

            total_batch_size = self.batch_size[0] * self.batch_size[1]
            total_size = int(math.ceil(total_batch_size / self.world_size)
                             ) * self.world_size
            # Pad up to total_size so every rank receives exactly
            # total_size / world_size indices.  The original padded only up
            # to total_batch_size (a no-op), which left ranks with unequal
            # batch sizes whenever P*K was not divisible by world_size —
            # a deadlock risk in collective ops.
            sample_indices += sample_indices[:(total_size - len(sample_indices))]

            sample_indices = sample_indices[self.rank:total_size:self.world_size]
            yield sample_indices

    def __len__(self):
        return len(self.dataset)
|
||||
|
||||
|
||||
def sync_random_sample_list(obj_list, k):
    """Randomly sample k items from obj_list, identically on every rank.

    Rank 0's permutation is broadcast so that all ranks select the same
    items.  If k exceeds len(obj_list), all items are returned (randperm
    slicing caps at the list length).  When no process group is initialized
    (e.g. a single-process debug run), sampling is purely local; the
    original called broadcast unconditionally and crashed in that case.
    """
    idx = torch.randperm(len(obj_list))[:k]
    if torch.cuda.is_available():
        # NCCL broadcasts require CUDA tensors.
        idx = idx.cuda()
    if dist.is_available() and dist.is_initialized():
        dist.broadcast(idx, src=0)
    idx = idx.tolist()
    return [obj_list[i] for i in idx]
|
||||
|
||||
|
||||
class InferenceSampler(tordata.sampler.Sampler):
    """Deterministic sampler for inference: precomputes per-rank batches.

    Pads the index list (wrapping around) so the last batch is full, splits
    every global batch evenly across ranks, and assigns the resulting
    chunks to this rank round-robin.
    """

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

        self.size = len(dataset)
        indices = list(range(self.size))

        world_size = dist.get_world_size()
        rank = dist.get_rank()

        if batch_size % world_size != 0:
            # The check requires batch_size to be a multiple of world_size;
            # the original message stated the reverse relationship.
            raise AssertionError("Batch size({}) needs to be divisible by world size({})".format(
                batch_size, world_size))

        if batch_size != 1:
            # Pad with wrapped-around indices so size is a multiple of
            # batch_size (padded samples are evaluated twice).
            complete_size = math.ceil(self.size / batch_size) * batch_size
            indices += indices[:(complete_size - self.size)]
            self.size = complete_size

        batch_size_per_rank = int(self.batch_size / world_size)
        indx_batch_per_rank = []

        for i in range(int(self.size / batch_size_per_rank)):
            indx_batch_per_rank.append(
                indices[i*batch_size_per_rank:(i+1)*batch_size_per_rank])

        # Each global batch is world_size consecutive per-rank chunks;
        # this rank keeps every world_size-th chunk starting at its rank.
        self.idx_batch_this_rank = indx_batch_per_rank[rank::world_size]

    def __iter__(self):
        yield from self.idx_batch_this_rank

    def __len__(self):
        return len(self.dataset)
|
||||
@@ -0,0 +1,58 @@
|
||||
from data import transform as base_transform
|
||||
import numpy as np
|
||||
|
||||
from utils import is_list, is_dict, get_valid_args
|
||||
|
||||
|
||||
class BaseSilTransform():
    """Scale silhouette pixel values by 1/disvor, optionally reshaping.

    Args:
        disvor: divisor applied to the input values (default 255.0).
        img_shape: when given, each sample is reshaped from
            (num, ...) to (num, *img_shape) before scaling.
    """

    def __init__(self, disvor=255.0, img_shape=None):
        self.disvor = disvor
        self.img_shape = img_shape

    def __call__(self, x):
        if self.img_shape is not None:
            # Reinterpret the flattened frames as (num, *img_shape).
            x = x.reshape(x.shape[0], *self.img_shape)
        return x / self.disvor
|
||||
|
||||
|
||||
class BaseSilCuttingTransform():
    """Trim left/right margins off silhouettes, then scale by 1/disvor.

    Args:
        img_w: silhouette width; used to derive the default margin.
        disvor: divisor applied to pixel values (default 255.0).
        cutting: explicit margin in pixels; when None, defaults to
            10 pixels per 64 pixels of img_w.
    """

    def __init__(self, img_w=64, disvor=255.0, cutting=None):
        self.img_w = img_w
        self.disvor = disvor
        self.cutting = cutting

    def __call__(self, x):
        margin = (self.cutting if self.cutting is not None
                  else int(self.img_w // 64) * 10)
        return x[..., margin:-margin] / self.disvor
|
||||
|
||||
|
||||
class BaseRgbTransform():
    """Channel-wise standardization for RGB frame stacks.

    mean/std default to the ImageNet statistics rescaled to [0, 255] and
    are stored broadcastable against input of shape (num, 3, H, W).
    """

    def __init__(self, mean=None, std=None):
        mean = [0.485*255, 0.456*255, 0.406*255] if mean is None else mean
        std = [0.229*255, 0.224*255, 0.225*255] if std is None else std
        shape = (1, 3, 1, 1)  # broadcast over (num, 3, H, W)
        self.mean = np.array(mean).reshape(shape)
        self.std = np.array(std).reshape(shape)

    def __call__(self, x):
        return (x - self.mean) / self.std
|
||||
|
||||
|
||||
def get_transform(trf_cfg=None):
    """Build transform(s) from a config.

    - dict: instantiate the class named trf_cfg['type'] from data.transform
      with the remaining (valid) keys as keyword arguments;
    - None: return the identity transform;
    - list: return one transform per config entry (recursively).

    Raises:
        ValueError: for any other config type.  (The original used
        `raise "<str>"`, which is invalid in Python 3 and itself raised
        `TypeError: exceptions must derive from BaseException`.)
    """
    if is_dict(trf_cfg):
        transform = getattr(base_transform, trf_cfg['type'])
        valid_trf_arg = get_valid_args(transform, trf_cfg, ['type'])
        return transform(**valid_trf_arg)
    if trf_cfg is None:
        return lambda x: x
    if is_list(trf_cfg):
        return [get_transform(cfg) for cfg in trf_cfg]
    raise ValueError("Error type for -Transform-Cfg-")
|
||||
Reference in New Issue
Block a user