diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..3852cc6 --- /dev/null +++ b/environment.yml @@ -0,0 +1,22 @@ +name: opengait +channels: + - pytorch + - defaults + - anaconda + - conda-forge +dependencies: + - python + - pip + - tqdm + - numpy + - pandas + - matplotlib + - scikit-learn + - pytorch + - torchvision + - cudatoolkit=10.2 + - ipykernel + - pip: + - tensorboard + - seaborn + - py7zr diff --git a/misc/extractor.py b/misc/extractor.py new file mode 100644 index 0000000..9451e34 --- /dev/null +++ b/misc/extractor.py @@ -0,0 +1,42 @@ +import argparse +import os +from pathlib import Path + +import py7zr +from tqdm import tqdm + + +def extractall(base_path: Path, output_path: Path) -> None: + """Extract all archives in base_path to output_path. + + Args: + base_path (Path): Path to the directory containing the archives. + output_path (Path): Path to the directory to extract the archives to. + """ + + os.makedirs(output_path, exist_ok=True) + for file_path in tqdm(list(base_path.rglob('Silhouette_*.7z'))): + if output_path.joinpath(file_path.stem).exists(): + continue + with py7zr.SevenZipFile(file_path, password='OUMVLP_20180214') as archive: + total_items = len( + [f for f in archive.getnames() if f.endswith('.png')] + ) + archive.extractall(output_path) + + extracted_files = len( + list(output_path.joinpath(file_path.stem).rglob('*.png'))) + + assert extracted_files == total_items, f'{extracted_files} != {total_items}' + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='OUMVLP extractor') + parser.add_argument('-b', '--base_path', type=str, + required=True, help='Base path to OUMVLP .7z files') + parser.add_argument('-o', '--output_path', type=str, + required=True, help='Output path for extracted files') + + args = parser.parse_args() + + extractall(Path(args.base_path), Path(args.output_path)) diff --git a/misc/pretreatment.py b/misc/pretreatment.py index 3cf8ce2..15bcbf2 100644 --- 
a/misc/pretreatment.py +++ b/misc/pretreatment.py @@ -1,201 +1,141 @@ -# modified from https://github.com/AbnerHqC/GaitSet/blob/master/pretreatment.py - +# This source is based on https://github.com/AbnerHqC/GaitSet/blob/master/pretreatment.py +import argparse +import logging +import multiprocessing as mp import os +import pickle +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import Tuple + import cv2 import numpy as np -from warnings import warn -from time import sleep -import argparse -import pickle - -from multiprocessing import Pool -from multiprocessing import TimeoutError as MP_TimeoutError - -START = "START" -FINISH = "FINISH" -WARNING = "WARNING" -FAIL = "FAIL" +from tqdm import tqdm -def boolean_string(s): - if s.upper() not in {'FALSE', 'TRUE'}: - raise ValueError('Not a valid boolean string') - return s.upper() == 'TRUE' +def imgs2pickle(img_groups: Tuple, output_path: Path, img_size: int = 64, verbose: bool = False) -> None: + """Reads a group of images and saves the data in pickle format. + + Args: + img_groups (Tuple): Tuple of (sid, seq, view) and list of image paths. + output_path (Path): Output path. + img_size (int, optional): Image resizing size. Defaults to 64. + verbose (bool, optional): Display debug info. Defaults to False. 
+ """ + sinfo = img_groups[0] + img_paths = img_groups[1] + to_pickle = [] + for img_file in sorted(img_paths): + if verbose: + logging.debug(f'Reading sid {sinfo[0]}, seq {sinfo[1]}, view {sinfo[2]} from {img_file}') + + img = cv2.imread(str(img_file), cv2.IMREAD_GRAYSCALE) + + if img.sum() <= 10000: + if verbose: + logging.debug(f'Image sum: {img.sum()}') + logging.warning(f'{img_file} has no data.') + continue + + # Get the upper and lower points + y_sum = img.sum(axis=1) + y_top = (y_sum != 0).argmax(axis=0) + y_btm = (y_sum != 0).cumsum(axis=0).argmax(axis=0) + img = img[y_top: y_btm + 1, :] + + # As the height of a person is larger than the width, + # use the height to calculate resize ratio. + ratio = img.shape[1] / img.shape[0] + img = cv2.resize(img, (int(img_size * ratio), img_size), interpolation=cv2.INTER_CUBIC) + + # Get the median of the x-axis and take it as the person's x-center. + x_csum = img.sum(axis=0).cumsum() + x_center = None + for idx, csum in enumerate(x_csum): + if csum > img.sum() / 2: + x_center = idx + break + + if not x_center: + logging.warning(f'{img_file} has no center.') + continue + + # Get the left and right points + half_width = img_size // 2 + left = x_center - half_width + right = x_center + half_width + if left <= 0 or right >= img.shape[1]: + left += half_width + right += half_width + _ = np.zeros((img.shape[0], half_width)) + img = np.concatenate([_, img, _], axis=1) + + to_pickle.append(img[:, left: right].astype('uint8')) + + if to_pickle: + to_pickle = np.asarray(to_pickle) + dst_path = os.path.join(output_path, *sinfo) + os.makedirs(dst_path, exist_ok=True) + pkl_path = os.path.join(dst_path, f'{sinfo[2]}.pkl') + if verbose: + logging.debug(f'Saving {pkl_path}...') + pickle.dump(to_pickle, open(pkl_path, 'wb')) + logging.info(f'Saved {len(to_pickle)} valid frames to {pkl_path}.') -parser = argparse.ArgumentParser(description='Test') -parser.add_argument('--input_path', default='', type=str, - help='Root path of raw 
dataset.') -parser.add_argument('--output_path', default='', type=str, - help='Root path for output.') -parser.add_argument('--log_file', default='./pretreatment.log', type=str, - help='Log file path. Default: ./pretreatment.log') -parser.add_argument('--log', default=False, type=boolean_string, - help='If set as True, all logs will be saved. ' - 'Otherwise, only warnings and errors will be saved.' - 'Default: False') -parser.add_argument('--worker_num', default=1, type=int, - help='How many subprocesses to use for data pretreatment. ' - 'Default: 1') -parser.add_argument('--img_size', default=64, type=int, - help='image size') -opt = parser.parse_args() - -INPUT_PATH = opt.input_path -OUTPUT_PATH = opt.output_path -IF_LOG = opt.log -LOG_PATH = opt.log_file -WORKERS = opt.worker_num - -T_H = opt.img_size -T_W = opt.img_size - -def log2str(pid, comment, logs): - str_log = '' - if type(logs) is str: - logs = [logs] - for log in logs: - str_log += "# JOB %d : --%s-- %s\n" % ( - pid, comment, log) - return str_log - -def log_print(pid, comment, logs): - str_log = log2str(pid, comment, logs) - if comment in [WARNING, FAIL]: - with open(LOG_PATH, 'a') as log_f: - log_f.write(str_log) - if comment in [START, FINISH]: - if pid % 500 != 0: - return - print(str_log, end='') - -def cut_img(img, seq_info, frame_name, pid): - # A silhouette contains too little white pixels - # might be not valid for identification. - if img.sum() <= 10000: - message = 'seq:%s, frame:%s, no data, %d.' % ( - '-'.join(seq_info), frame_name, img.sum()) - warn(message) - log_print(pid, WARNING, message) - return None - # Get the top and bottom point - y = img.sum(axis=1) - y_top = (y != 0).argmax(axis=0) - y_btm = (y != 0).cumsum(axis=0).argmax(axis=0) - img = img[y_top:y_btm + 1, :] - # As the height of a person is larger than the width, - # use the height to calculate resize ratio. 
- _r = img.shape[1] / img.shape[0] - _t_w = int(T_H * _r) - img = cv2.resize(img, (_t_w, T_H), interpolation=cv2.INTER_CUBIC) - # Get the median of x axis and regard it as the x center of the person. - sum_point = img.sum() - sum_column = img.sum(axis=0).cumsum() - x_center = -1 - for i in range(sum_column.size): - if sum_column[i] > sum_point / 2: - x_center = i - break - if x_center < 0: - message = 'seq:%s, frame:%s, no center.' % ( - '-'.join(seq_info), frame_name) - warn(message) - log_print(pid, WARNING, message) - return None - h_T_W = int(T_W / 2) - left = x_center - h_T_W - right = x_center + h_T_W - if left <= 0 or right >= img.shape[1]: - left += h_T_W - right += h_T_W - _ = np.zeros((img.shape[0], h_T_W)) - img = np.concatenate([_, img, _], axis=1) - img = img[:, left:right] - return img.astype('uint8') + if len(to_pickle) < 5: + logging.warning(f'{sinfo} has less than 5 valid data.') -def cut_pickle(seq_info, pid): - seq_name = '-'.join(seq_info) - log_print(pid, START, seq_name) - seq_path = os.path.join(INPUT_PATH, *seq_info) - out_dir = os.path.join(OUTPUT_PATH, *seq_info) - frame_list = os.listdir(seq_path) - frame_list.sort() - count_frame = 0 - all_imgs = [] - view = seq_info[-1] - for _frame_name in frame_list: - frame_path = os.path.join(seq_path, _frame_name) - img = cv2.imread(frame_path)[:, :, 0] - img = cut_img(img, seq_info, _frame_name, pid) - if img is not None: - # Save the cut img - all_imgs.append(img) - count_frame += 1 - all_imgs = np.asarray(all_imgs) +def pretreat(input_path: Path, output_path: Path, img_size: int = 64, workers: int = 4, verbose: bool = False) -> None: + """Reads a dataset and saves the data in pickle format. - if count_frame > 0: - os.makedirs(out_dir) - all_imgs_pkl = os.path.join(out_dir, '{}.pkl'.format(view)) - pickle.dump(all_imgs, open(all_imgs_pkl, 'wb')) + Args: + input_path (Path): Dataset root path. + output_path (Path): Output path. + img_size (int, optional): Image resizing size. Defaults to 64. 
+ workers (int, optional): Number of thread workers. Defaults to 4. + verbose (bool, optional): Display debug info. Defaults to False. + """ + img_groups = defaultdict(list) + logging.info(f'Listing {input_path}') + total_files = 0 + for img_path in input_path.rglob('*.png'): + if verbose: + logging.debug(f'Adding {img_path}') + *_, sid, seq, view, _ = img_path.as_posix().split('/') + img_groups[(sid, seq, view)].append(img_path) + total_files += 1 - # Warn if the sequence contains less than 5 frames - if count_frame < 5: - message = 'seq:%s, less than 5 valid data.' % ( - '-'.join(seq_info)) - warn(message) - log_print(pid, WARNING, message) + logging.info(f'Total files listed: {total_files}') - log_print(pid, FINISH, - 'Contain %d valid frames. Saved to %s.' - % (count_frame, out_dir)) + progress = tqdm(total=len(img_groups), desc='Pretreating', unit='folder') + + with mp.Pool(workers) as pool: + logging.info(f'Start pretreating {input_path}') + for _ in pool.imap_unordered(partial(imgs2pickle, output_path=output_path, img_size=img_size, verbose=verbose), img_groups.items()): + progress.update(1) + logging.info('Done') -pool = Pool(WORKERS) -results = list() -pid = 0 +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='OpenGait dataset pretreatment module.') + parser.add_argument('-r', '--input_path', default='', type=str, help='Root path of raw dataset.') + parser.add_argument('-o', '--output_path', default='', type=str, help='Output path of pickled dataset.') + parser.add_argument('-l', '--log_file', default='./pretreatment.log', type=str, help='Log file path. Default: ./pretreatment.log') + parser.add_argument('-n', '--n_workers', default=4, type=int, help='Number of thread workers. Default: 4') + parser.add_argument('-i', '--img_size', default=64, type=int, help='Image resizing size.
Default 64') + parser.add_argument('-v', '--verbose', default=False, action='store_true', help='Display debug info.') + args = parser.parse_args() -print('Pretreatment Start.\n' - 'Input path: %s\n' - 'Output path: %s\n' - 'Log file: %s\n' - 'Worker num: %d' % ( - INPUT_PATH, OUTPUT_PATH, LOG_PATH, WORKERS)) + logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w', format='[%(asctime)s - %(levelname)s]: %(message)s') + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + logging.info('Verbose mode is on.') + for k, v in args.__dict__.items(): + logging.debug(f'{k}: {v}') -id_list = os.listdir(INPUT_PATH) -id_list.sort() -# Walk the input path -for _id in id_list: - seq_type = os.listdir(os.path.join(INPUT_PATH, _id)) - seq_type.sort() - for _seq_type in seq_type: - view = os.listdir(os.path.join(INPUT_PATH, _id, _seq_type)) - view.sort() - for _view in view: - seq_info = [_id, _seq_type, _view] - out_dir = os.path.join(OUTPUT_PATH, *seq_info) - # os.makedirs(out_dir) - results.append( - pool.apply_async( - cut_pickle, - args=(seq_info, pid))) - sleep(0.02) - pid += 1 - -pool.close() -unfinish = 1 -while unfinish > 0: - unfinish = 0 - for i, res in enumerate(results): - try: - res.get(timeout=0.1) - except Exception as e: - if type(e) == MP_TimeoutError: - unfinish += 1 - continue - else: - print('\n\n\nERROR OCCUR: PID ##%d##, ERRORTYPE: %s\n\n\n', - i, type(e)) - raise e -pool.join() + pretreat(input_path=Path(args.input_path), output_path=Path(args.output_path), img_size=args.img_size, workers=args.n_workers, verbose=args.verbose) diff --git a/misc/rearrange_OUMVLP.py b/misc/rearrange_OUMVLP.py index 6e155ff..d9ba412 100644 --- a/misc/rearrange_OUMVLP.py +++ b/misc/rearrange_OUMVLP.py @@ -1,59 +1,45 @@ +import argparse import os import shutil +from pathlib import Path + from tqdm import tqdm -import argparse -parser = argparse.ArgumentParser(description='Test') -parser.add_argument('--input_path', 
default='/home1/data/OUMVLP_raw', type=str, - help='Root path of raw dataset.') -parser.add_argument('--output_path', default='/home1/data/OUMVLP_rearranged', type=str, - help='Root path for output.') +TOTAL_SUBJECTS = 10307 -opt = parser.parse_args() - -INPUT_PATH = opt.input_path -OUTPUT_PATH = opt.output_path +def sanitize(name: str) -> (str, str): + return name.split('_')[1].split('-') -def mv_dir(src, dst): - shutil.copytree(src, dst) - print(src, dst) +def rearrange(input_path: Path, output_path: Path) -> None: + os.makedirs(output_path, exist_ok=True) + + for folder in input_path.iterdir(): + print(f'Rearranging {folder}') + view, seq = sanitize(folder.name) + progress = tqdm(total=TOTAL_SUBJECTS) + for sid in folder.iterdir(): + src = os.path.join(input_path, f'Silhouette_{view}-{seq}', sid.name) + dst = os.path.join(output_path, sid.name, seq, view) + os.makedirs(dst, exist_ok=True) + for subfile in os.listdir(src): + if subfile not in os.listdir(dst) and subfile.endswith('.png'): + os.symlink(os.path.join(src, subfile), + os.path.join(dst, subfile)) + # else: + # os.remove(os.path.join(src, subfile)) + progress.update(1) -sils_name_list = os.listdir(INPUT_PATH) -name_space = 'Silhouette_' -views = sorted(list( - set([each.replace(name_space, '').split('-')[0] for each in sils_name_list]))) -seqs = sorted(list( - set([each.replace(name_space, '').split('-')[1] for each in sils_name_list]))) -ids = list() -for each in sils_name_list: - ids.extend(os.listdir(os.path.join(INPUT_PATH, each))) +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='OUMVLP rearrange tool') + parser.add_argument('-i', '--input_path', required=True, type=str, + help='Root path of raw dataset.') + parser.add_argument('-o', '--output_path', default='OUMVLP_rearranged', type=str, + help='Root path for output.') + args = parser.parse_args() -progress = tqdm(total=len(set(ids))) - - -results = list() -pid = 0 -for _id in sorted(set(ids)): - progress.update(1) - for 
_view in views: - for _seq in seqs: - seq_info = [_id, _seq, _view] - name = name_space + _view + '-' + _seq + '/' + _id - src = os.path.join(INPUT_PATH, name) - dst = os.path.join(OUTPUT_PATH, *seq_info) - if os.path.exists(src): - try: - if os.path.exists(dst): - pass - else: - os.makedirs(dst) - for subfile in os.listdir(src): - os.symlink(os.path.join(src, subfile), - os.path.join(dst, subfile)) - except OSError as err: - print(err) + rearrange(Path(args.input_path), Path(args.output_path))