import json
import os
import os.path as osp
import pickle
import random
from typing import TypeVar

import torch.utils.data as tordata

from opengait.utils import get_msg_mgr

T = TypeVar("T")


class DataSet(tordata.Dataset):
    def __init__(self, data_cfg, training):
        """Index a gait dataset on disk and optionally cache it in memory.

        seqs_info: the list with each element indicating a certain gait
            sequence presented as [label, type, view, paths];
        """
        self.__dataset_parser(data_cfg, training)
        self.cache = data_cfg['cache']
        self.label_list = [seq_info[0] for seq_info in self.seqs_info]
        self.types_list = [seq_info[1] for seq_info in self.seqs_info]
        self.views_list = [seq_info[2] for seq_info in self.seqs_info]
        # Deduplicated, sorted vocabularies of the three metadata fields.
        self.label_set = sorted(set(self.label_list))
        self.types_set = sorted(set(self.types_list))
        self.views_set = sorted(set(self.views_list))
        # Lazily-filled per-sequence cache slots (used when self.cache is True).
        self.seqs_data = [None] * len(self)
        # label -> list of sequence indices carrying that label.
        self.indices_dict = {label: [] for label in self.label_set}
        for i, seq_info in enumerate(self.seqs_info):
            self.indices_dict[seq_info[0]].append(i)
        if self.cache:
            self.__load_all_data()

    def __len__(self):
        """Number of indexed gait sequences."""
        return len(self.seqs_info)

    def __loader__(self, paths):
        """Load every .pkl file in *paths* and validate the loaded payloads.

        Raises ValueError if a path is not a .pkl file, if the payloads do
        not all have the same length, or if any payload is empty.
        """
        paths = sorted(paths)
        data_list = []
        for pth in paths:
            # Guard first: only pickle files are supported.
            if not pth.endswith('.pkl'):
                raise ValueError('- Loader - just support .pkl !!!')
            # `with` closes the file; no explicit close() needed.
            with open(pth, 'rb') as f:
                data_list.append(pickle.load(f))
        for idx, data in enumerate(data_list):
            if len(data) != len(data_list[0]):
                raise ValueError(
                    'Each input data({}) should have the same length.'.format(paths[idx]))
            if len(data) == 0:
                raise ValueError(
                    'Each input data({}) should have at least one element.'.format(paths[idx]))
        return data_list

    def __getitem__(self, idx):
        """Return (data_list, seq_info) for sequence *idx*, caching if enabled."""
        if not self.cache:
            data_list = self.__loader__(self.seqs_info[idx][-1])
        elif self.seqs_data[idx] is None:
            # First access under caching: load once, then memoize.
            data_list = self.__loader__(self.seqs_info[idx][-1])
            self.seqs_data[idx] = data_list
        else:
            data_list = self.seqs_data[idx]
        seq_info = self.seqs_info[idx]
        return data_list, seq_info

    def __load_all_data(self):
        """Eagerly populate the cache by touching every item once."""
        for idx in range(len(self)):
            self.__getitem__(idx)
@staticmethod def _sample_items( items: list[T], subset_size: int | None, subset_seed: int, subset_name: str, msg_mgr=None, ) -> list[T]: if subset_size is None: return items if subset_size <= 0: raise ValueError(f"{subset_name} must be positive, got {subset_size}") if subset_size >= len(items): return items sampled_items = random.Random(subset_seed).sample(items, subset_size) sampled_items.sort() if msg_mgr is not None: msg_mgr.log_info( "Using %s subset: %d / %d items (seed=%d)", subset_name, len(sampled_items), len(items), subset_seed, ) return sampled_items def __dataset_parser(self, data_config, training): dataset_root = data_config['dataset_root'] try: data_in_use = data_config['data_in_use'] # [n], true or false except: data_in_use = None with open(data_config['dataset_partition'], "rb") as f: partition = json.load(f) train_set = partition["TRAIN_SET"] test_set = partition["TEST_SET"] label_list = os.listdir(dataset_root) train_set = [label for label in train_set if label in label_list] test_set = [label for label in test_set if label in label_list] msg_mgr = get_msg_mgr() train_set = self._sample_items( train_set, data_config.get("train_pid_subset_size"), int(data_config.get("train_pid_subset_seed", 0)), "train pid", msg_mgr, ) miss_pids = [label for label in label_list if label not in ( train_set + test_set)] def log_pid_list(pid_list): if len(pid_list) >= 3: msg_mgr.log_info('[%s, %s, ..., %s]' % (pid_list[0], pid_list[1], pid_list[-1])) else: msg_mgr.log_info(pid_list) if len(miss_pids) > 0: msg_mgr.log_debug('-------- Miss Pid List --------') msg_mgr.log_debug(miss_pids) if training: msg_mgr.log_info("-------- Train Pid List --------") log_pid_list(train_set) else: msg_mgr.log_info("-------- Test Pid List --------") log_pid_list(test_set) def get_seqs_info_list(label_set): seqs_info_list = [] for lab in label_set: for typ in sorted(os.listdir(osp.join(dataset_root, lab))): for vie in sorted(os.listdir(osp.join(dataset_root, lab, typ))): seq_info = [lab, 
typ, vie] seq_path = osp.join(dataset_root, *seq_info) seq_dirs = sorted(os.listdir(seq_path)) if seq_dirs != []: seq_dirs = [osp.join(seq_path, dir) for dir in seq_dirs] if data_in_use is not None: seq_dirs = [dir for dir, use_bl in zip( seq_dirs, data_in_use) if use_bl] seqs_info_list.append([*seq_info, seq_dirs]) else: msg_mgr.log_debug( 'Find no .pkl file in %s-%s-%s.' % (lab, typ, vie)) return seqs_info_list if training: self.seqs_info = get_seqs_info_list(train_set) else: self.seqs_info = get_seqs_info_list(test_set) self.seqs_info = self._sample_items( self.seqs_info, data_config.get("test_seq_subset_size"), int(data_config.get("test_seq_subset_seed", 0)), "test sequence", msg_mgr, )