rename lib to opengait
This commit is contained in:
@@ -0,0 +1,125 @@
|
||||
import os
|
||||
import pickle
|
||||
import os.path as osp
|
||||
import torch.utils.data as tordata
|
||||
import json
|
||||
from utils import get_msg_mgr
|
||||
|
||||
|
||||
class DataSet(tordata.Dataset):
    """Dataset of gait sequences stored as pickled files on disk.

    The dataset root is walked as ``<root>/<label>/<type>/<view>/<files>``.
    Every non-empty ``<view>`` directory yields one entry of ``seqs_info``.
    """

    def __init__(self, data_cfg, training):
        """Parse the dataset layout and build the lookup tables.

        Args:
            data_cfg: mapping read with the keys ``dataset_root``,
                ``dataset_partition``, ``cache`` and, optionally,
                ``data_in_use``.
            training: when truthy the ``TRAIN_SET`` labels of the partition
                file are used, otherwise the ``TEST_SET`` labels.

        Attributes set here:
            seqs_info: the list with each element indicating a certain gait
                sequence presented as [label, type, view, paths].
            label_list / types_list / views_list: per-sequence columns of
                ``seqs_info``; the ``*_set`` variants are their sorted
                unique values.
            seqs_data: per-sequence cache slots, ``None`` until loaded.
            indices_dict: label -> list of indices into ``seqs_info``.
        """
        self.__dataset_parser(data_cfg, training)
        self.cache = data_cfg['cache']
        self.label_list = [seq_info[0] for seq_info in self.seqs_info]
        self.types_list = [seq_info[1] for seq_info in self.seqs_info]
        self.views_list = [seq_info[2] for seq_info in self.seqs_info]

        self.label_set = sorted(set(self.label_list))
        self.types_set = sorted(set(self.types_list))
        self.views_set = sorted(set(self.views_list))
        self.seqs_data = [None] * len(self)
        self.indices_dict = {label: [] for label in self.label_set}
        for i, seq_info in enumerate(self.seqs_info):
            self.indices_dict[seq_info[0]].append(i)
        if self.cache:
            self.__load_all_data()

    def __len__(self):
        """Return the number of sequences in ``seqs_info``."""
        return len(self.seqs_info)

    def __loader__(self, paths):
        """Load every ``.pkl`` file in ``paths`` (processed in sorted order).

        Args:
            paths: iterable of file paths belonging to one sequence.

        Returns:
            A list with one unpickled object per path, ordered like
            ``sorted(paths)``.

        Raises:
            ValueError: if a path does not end with ``.pkl``, if the loaded
                objects do not all have the same length, or if any of them
                is empty.
        """
        paths = sorted(paths)
        data_list = []
        for pth in paths:
            if not pth.endswith('.pkl'):
                raise ValueError('- Loader - just support .pkl !!!')
            # NOTE(security): pickle.load executes arbitrary code embedded in
            # the file; only point the dataset at trusted data.
            with open(pth, 'rb') as f:
                data_list.append(pickle.load(f))
        for idx, data in enumerate(data_list):
            if len(data) != len(data_list[0]):
                raise ValueError(
                    'Each input data({}) should have the same length.'.format(paths[idx]))
            if len(data) == 0:
                raise ValueError(
                    'Each input data({}) should have at least one element.'.format(paths[idx]))
        return data_list

    def __getitem__(self, idx):
        """Return ``(data_list, seq_info)`` for sequence ``idx``.

        With caching enabled the loaded data is stored in ``seqs_data`` on
        first access and reused afterwards; otherwise it is read from disk
        on every call.
        """
        if not self.cache:
            data_list = self.__loader__(self.seqs_info[idx][-1])
        elif self.seqs_data[idx] is None:
            data_list = self.__loader__(self.seqs_info[idx][-1])
            self.seqs_data[idx] = data_list
        else:
            data_list = self.seqs_data[idx]
        seq_info = self.seqs_info[idx]
        return data_list, seq_info

    def __load_all_data(self):
        """Eagerly load every sequence (fills ``seqs_data`` when caching)."""
        for idx in range(len(self)):
            self.__getitem__(idx)

    def __dataset_parser(self, data_config, training):
        """Scan the dataset root and populate ``self.seqs_info``.

        Args:
            data_config: mapping providing ``dataset_root`` (directory),
                ``dataset_partition`` (path of a JSON file with the keys
                ``TRAIN_SET`` and ``TEST_SET``) and optionally
                ``data_in_use`` (per-file booleans selecting which files of
                each sequence directory are kept).
            training: selects the train (truthy) or test (falsy) label set.
        """
        dataset_root = data_config['dataset_root']
        # Only a missing key means "use everything"; any other failure
        # should surface instead of being silently swallowed.
        try:
            data_in_use = data_config['data_in_use']  # [n], true or false
        except KeyError:
            data_in_use = None

        with open(data_config['dataset_partition'], "rb") as f:
            partition = json.load(f)
        train_set = partition["TRAIN_SET"]
        test_set = partition["TEST_SET"]
        label_list = os.listdir(dataset_root)
        # Set lookups keep the filtering linear in the number of labels.
        present_labels = set(label_list)
        train_set = [label for label in train_set if label in present_labels]
        test_set = [label for label in test_set if label in present_labels]
        partitioned_labels = set(train_set + test_set)
        miss_pids = [label for label in label_list
                     if label not in partitioned_labels]
        msg_mgr = get_msg_mgr()

        def log_pid_list(pid_list):
            # Abbreviate long lists to first, second and last element.
            if len(pid_list) >= 3:
                msg_mgr.log_info('[%s, %s, ..., %s]' %
                                 (pid_list[0], pid_list[1], pid_list[-1]))
            else:
                msg_mgr.log_info(pid_list)

        if len(miss_pids) > 0:
            msg_mgr.log_debug('-------- Miss Pid List --------')
            msg_mgr.log_debug(miss_pids)
        if training:
            msg_mgr.log_info("-------- Train Pid List --------")
            log_pid_list(train_set)
        else:
            msg_mgr.log_info("-------- Test Pid List --------")
            log_pid_list(test_set)

        def get_seqs_info_list(label_set):
            # Walk <root>/<label>/<type>/<view> in sorted order and collect
            # [label, type, view, file_paths] for every non-empty view.
            seqs_info_list = []
            for lab in label_set:
                for typ in sorted(os.listdir(osp.join(dataset_root, lab))):
                    for vie in sorted(os.listdir(osp.join(dataset_root, lab, typ))):
                        seq_info = [lab, typ, vie]
                        seq_path = osp.join(dataset_root, *seq_info)
                        seq_dirs = sorted(os.listdir(seq_path))
                        if not seq_dirs:
                            msg_mgr.log_debug(
                                'Find no .pkl file in %s-%s-%s.' % (lab, typ, vie))
                            continue
                        seq_dirs = [osp.join(seq_path, name)
                                    for name in seq_dirs]
                        if data_in_use is not None:
                            # zip truncates to the shorter of the two lists.
                            seq_dirs = [path for path, use_bl in zip(
                                seq_dirs, data_in_use) if use_bl]
                        seqs_info_list.append([*seq_info, seq_dirs])
            return seqs_info_list

        self.seqs_info = get_seqs_info_list(
            train_set) if training else get_seqs_info_list(test_set)
|
||||
Reference in New Issue
Block a user