From 5f0f4ad3e3d6723af8915138ad822e2589c40d89 Mon Sep 17 00:00:00 2001 From: Gustavo Siqueira Date: Sat, 29 Jan 2022 10:41:49 -0300 Subject: [PATCH] Update pretreatment.py Code refactoring to solve deadlock issue and improve code organization --- misc/pretreatment.py | 311 +++++++++++++++++-------------------------- 1 file changed, 125 insertions(+), 186 deletions(-) diff --git a/misc/pretreatment.py b/misc/pretreatment.py index 3cf8ce2..94bb803 100644 --- a/misc/pretreatment.py +++ b/misc/pretreatment.py @@ -1,201 +1,140 @@ -# modified from https://github.com/AbnerHqC/GaitSet/blob/master/pretreatment.py - +# This source is based on https://github.com/AbnerHqC/GaitSet/blob/master/pretreatment.py +import argparse +import logging +import multiprocessing as mp import os +import pickle +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import Tuple + import cv2 import numpy as np -from warnings import warn -from time import sleep -import argparse -import pickle - -from multiprocessing import Pool -from multiprocessing import TimeoutError as MP_TimeoutError - -START = "START" -FINISH = "FINISH" -WARNING = "WARNING" -FAIL = "FAIL" +from tqdm import tqdm -def boolean_string(s): - if s.upper() not in {'FALSE', 'TRUE'}: - raise ValueError('Not a valid boolean string') - return s.upper() == 'TRUE' +def imgs2pickle(img_groups: Tuple, output_path: Path, img_size: int = 64, verbose: bool = False) -> None: + """Reads a group of images and saves the data in pickle format. + + Args: + img_groups (Tuple): Tuple of (sid, seq, view) and list of image paths. + output_path (Path): Output path. + img_size (int, optional): Image resizing size. Defaults to 64. + verbose (bool, optional): Display debug info. Defaults to False. 
+ """ + sinfo = img_groups[0] + img_paths = img_groups[1] + to_pickle = [] + for img_file in sorted(img_paths): + if verbose: + logging.debug(f'Reading sid {sinfo[0]}, seq {sinfo[1]}, view {sinfo[2]} from {img_file}') + + img = cv2.imread(str(img_file), cv2.IMREAD_GRAYSCALE) + + if img.sum() <= 10000: + if verbose: + logging.debug(f'Image sum: {img.sum()}') + logging.warning(f'{img_file} has no data.') + continue + + # Get the upper and lower points + y_sum = img.sum(axis=1) + y_top = (y_sum != 0).argmax(axis=0) + y_btm = (y_sum != 0).cumsum(axis=0).argmax(axis=0) + img = img[y_top: y_btm + 1, :] + + # As the height of a person is larger than the width, + # use the height to calculate resize ratio. + ratio = img.shape[1] // img.shape[0] + img = cv2.resize(img, (img_size * ratio, img_size), interpolation=cv2.INTER_CUBIC) + + # Get the median of the x-axis and take it as the person's x-center. + x_csum = img.sum(axis=0).cumsum() + x_center = None + for idx, csum in enumerate(x_csum): + if csum > img.sum() / 2: + x_center = idx + break + + if not x_center: + logging.warning(f'{img_file} has no center.') + continue + + # Get the left and right points + half_width = img_size // 2 + left = x_center - half_width + right = x_center + half_width + if left <= 0 or right >= img.shape[1]: + left += half_width + right += half_width + _ = np.zeros((img.shape[0], half_width)) + img = np.concatenate([_, img, _], axis=1) + + to_pickle.append(img[:, left: right].astype('uint8')) + + if to_pickle: + to_pickle = np.asarray(to_pickle) + dst_path = os.path.join(output_path, *sinfo) + os.makedirs(dst_path, exist_ok=True) + pkl_path = os.path.join(dst_path, f'{sinfo[2]}.pkl') + if verbose: + logging.debug(f'Saving {pkl_path}...') + pickle.dump(to_pickle, open(pkl_path, 'wb')) + + if len(to_pickle) < 5: + logging.warning(f'{sinfo} has less than 5 valid data.') + + logging.info(f'Saved {len(to_pickle)} valid frames to {pkl_path}.') -parser = argparse.ArgumentParser(description='Test') 
-parser.add_argument('--input_path', default='', type=str, - help='Root path of raw dataset.') -parser.add_argument('--output_path', default='', type=str, - help='Root path for output.') -parser.add_argument('--log_file', default='./pretreatment.log', type=str, - help='Log file path. Default: ./pretreatment.log') -parser.add_argument('--log', default=False, type=boolean_string, - help='If set as True, all logs will be saved. ' - 'Otherwise, only warnings and errors will be saved.' - 'Default: False') -parser.add_argument('--worker_num', default=1, type=int, - help='How many subprocesses to use for data pretreatment. ' - 'Default: 1') -parser.add_argument('--img_size', default=64, type=int, - help='image size') -opt = parser.parse_args() +def pretreat(input_path: Path, output_path: Path, img_size: int = 64, workers: int = 4, verbose: bool = False) -> None: + """Reads a dataset and saves the data in pickle format. -INPUT_PATH = opt.input_path -OUTPUT_PATH = opt.output_path -IF_LOG = opt.log -LOG_PATH = opt.log_file -WORKERS = opt.worker_num + Args: + input_path (Path): Dataset root path. + output_path (Path): Output path. + img_size (int, optional): Image resizing size. Defaults to 64. + workers (int, optional): Number of thread workers. Defaults to 4. + verbose (bool, optional): Display debug info. Defaults to False. 
+ """ + img_groups = defaultdict(list) + logging.info(f'Listing {input_path}') + total_files = 0 + for img_path in input_path.rglob('*.png'): + if verbose: + logging.debug(f'Adding {img_path}') + *_, sid, seq, view, _ = img_path.as_posix().split(os.sep) + img_groups[(sid, seq, view)].append(img_path) + total_files += 1 -T_H = opt.img_size -T_W = opt.img_size + logging.info(f'Total files listed: {total_files}') -def log2str(pid, comment, logs): - str_log = '' - if type(logs) is str: - logs = [logs] - for log in logs: - str_log += "# JOB %d : --%s-- %s\n" % ( - pid, comment, log) - return str_log + progress = tqdm(total=len(img_groups), desc='Pretreating', unit='folder') -def log_print(pid, comment, logs): - str_log = log2str(pid, comment, logs) - if comment in [WARNING, FAIL]: - with open(LOG_PATH, 'a') as log_f: - log_f.write(str_log) - if comment in [START, FINISH]: - if pid % 500 != 0: - return - print(str_log, end='') - -def cut_img(img, seq_info, frame_name, pid): - # A silhouette contains too little white pixels - # might be not valid for identification. - if img.sum() <= 10000: - message = 'seq:%s, frame:%s, no data, %d.' % ( - '-'.join(seq_info), frame_name, img.sum()) - warn(message) - log_print(pid, WARNING, message) - return None - # Get the top and bottom point - y = img.sum(axis=1) - y_top = (y != 0).argmax(axis=0) - y_btm = (y != 0).cumsum(axis=0).argmax(axis=0) - img = img[y_top:y_btm + 1, :] - # As the height of a person is larger than the width, - # use the height to calculate resize ratio. - _r = img.shape[1] / img.shape[0] - _t_w = int(T_H * _r) - img = cv2.resize(img, (_t_w, T_H), interpolation=cv2.INTER_CUBIC) - # Get the median of x axis and regard it as the x center of the person. - sum_point = img.sum() - sum_column = img.sum(axis=0).cumsum() - x_center = -1 - for i in range(sum_column.size): - if sum_column[i] > sum_point / 2: - x_center = i - break - if x_center < 0: - message = 'seq:%s, frame:%s, no center.' 
% ( - '-'.join(seq_info), frame_name) - warn(message) - log_print(pid, WARNING, message) - return None - h_T_W = int(T_W / 2) - left = x_center - h_T_W - right = x_center + h_T_W - if left <= 0 or right >= img.shape[1]: - left += h_T_W - right += h_T_W - _ = np.zeros((img.shape[0], h_T_W)) - img = np.concatenate([_, img, _], axis=1) - img = img[:, left:right] - return img.astype('uint8') + with mp.Pool(workers) as pool: + logging.info(f'Start pretreating {input_path}') + for _ in pool.imap_unordered(partial(imgs2pickle, output_path=output_path, img_size=img_size, verbose=verbose), img_groups.items()): + progress.update(1) + logging.info('Done') -def cut_pickle(seq_info, pid): - seq_name = '-'.join(seq_info) - log_print(pid, START, seq_name) - seq_path = os.path.join(INPUT_PATH, *seq_info) - out_dir = os.path.join(OUTPUT_PATH, *seq_info) - frame_list = os.listdir(seq_path) - frame_list.sort() - count_frame = 0 - all_imgs = [] - view = seq_info[-1] - for _frame_name in frame_list: - frame_path = os.path.join(seq_path, _frame_name) - img = cv2.imread(frame_path)[:, :, 0] - img = cut_img(img, seq_info, _frame_name, pid) - if img is not None: - # Save the cut img - all_imgs.append(img) - count_frame += 1 +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='OpenGait dataset pretreatment module.') + parser.add_argument('-r', '--root_path', default='', type=str, help='Root path of raw dataset.') + parser.add_argument('-o', '--output_path', default='', type=str, help='Output path of pickled dataset.') + parser.add_argument('-l', '--log_file', default='./pretreatment.log', type=str, help='Log file path. Default: ./pretreatment.log') + parser.add_argument('-n', '--n_workers', default=4, type=int, help='Number of thread workers. Default: 4') + parser.add_argument('-i', '--img_size', default=64, type=int, help='Image resizing size. 
Default 64') + parser.add_argument('-v', '--verbose', default=False, action='store_true', help='Display debug info.') + args = parser.parse_args() - all_imgs = np.asarray(all_imgs) + logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w', format='[%(asctime)s - %(levelname)s]: %(message)s') + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + logging.info('Verbose mode is on.') + for k, v in args.__dict__.items(): + logging.debug(f'{k}: {v}') - if count_frame > 0: - os.makedirs(out_dir) - all_imgs_pkl = os.path.join(out_dir, '{}.pkl'.format(view)) - pickle.dump(all_imgs, open(all_imgs_pkl, 'wb')) - - # Warn if the sequence contains less than 5 frames - if count_frame < 5: - message = 'seq:%s, less than 5 valid data.' % ( - '-'.join(seq_info)) - warn(message) - log_print(pid, WARNING, message) - - log_print(pid, FINISH, - 'Contain %d valid frames. Saved to %s.' - % (count_frame, out_dir)) - - -pool = Pool(WORKERS) -results = list() -pid = 0 - -print('Pretreatment Start.\n' - 'Input path: %s\n' - 'Output path: %s\n' - 'Log file: %s\n' - 'Worker num: %d' % ( - INPUT_PATH, OUTPUT_PATH, LOG_PATH, WORKERS)) - -id_list = os.listdir(INPUT_PATH) -id_list.sort() -# Walk the input path -for _id in id_list: - seq_type = os.listdir(os.path.join(INPUT_PATH, _id)) - seq_type.sort() - for _seq_type in seq_type: - view = os.listdir(os.path.join(INPUT_PATH, _id, _seq_type)) - view.sort() - for _view in view: - seq_info = [_id, _seq_type, _view] - out_dir = os.path.join(OUTPUT_PATH, *seq_info) - # os.makedirs(out_dir) - results.append( - pool.apply_async( - cut_pickle, - args=(seq_info, pid))) - sleep(0.02) - pid += 1 - -pool.close() -unfinish = 1 -while unfinish > 0: - unfinish = 0 - for i, res in enumerate(results): - try: - res.get(timeout=0.1) - except Exception as e: - if type(e) == MP_TimeoutError: - unfinish += 1 - continue - else: - print('\n\n\nERROR OCCUR: PID ##%d##, ERRORTYPE: %s\n\n\n', - i, type(e)) - raise e -pool.join() + 
pretreat(input_path=Path(args.root_path), output_path=Path(args.output_path), img_size=args.img_size, workers=args.n_workers, verbose=args.verbose)