diff --git a/datasets/pretreatment.py b/datasets/pretreatment.py index 60cb750..9fd7162 100644 --- a/datasets/pretreatment.py +++ b/datasets/pretreatment.py @@ -3,6 +3,7 @@ import argparse import logging import multiprocessing as mp import os +import re import pickle from collections import defaultdict from functools import partial @@ -127,7 +128,7 @@ def pretreat(input_path: Path, output_path: Path, img_size: int = 64, workers: i progress.update(1) logging.info('Done') -def txts2pickle(txt_groups: Tuple, output_path: Path, verbose: bool = False, dataset='CASIAB') -> None: +def txts2pickle(txt_groups: Tuple, output_path: Path, verbose: bool = False, dataset='CASIAB', **kwargs) -> None: """ Reads a group of images and saves the data in pickle format. @@ -137,18 +138,50 @@ def txts2pickle(txt_groups: Tuple, output_path: Path, verbose: bool = False, dat img_size (int, optional): Image resizing size. Defaults to 64. verbose (bool, optional): Display debug info. Defaults to False. """ - + def pose_silu_match_score(pose: np.ndarray, silu: np.ndarray): + pose_coord = pose[:,:2].astype(np.int32) + + H, W, *_ = silu.shape + valid_joints = (pose_coord[:, 1] >=0) & (pose_coord[:, 1] < H) & \ + (pose_coord[:, 0] >=0) & (pose_coord[:, 0] < W) + if np.sum(valid_joints) == len(pose_coord): + # only calculate score for points that are inside the silu img + # use the sum of all joints' pixel intensity as the score + return np.sum(silu[pose_coord[:, 1], pose_coord[:, 0]]) + else: + # if pose coord is out of bound, return -inf + return -np.inf + sinfo = txt_groups[0] txt_paths = txt_groups[1] to_pickle = [] if dataset == 'OUMVLP': + oumvlp_rearrange_silu_path = kwargs.get('oumvlp_rearrange_silu_path', None) for txt_file in sorted(txt_paths): try: with open(txt_file) as f: jsondata = json.load(f) - if len(jsondata['people'])==0: + person_num = len(jsondata['people']) + if person_num==0: continue - data = np.array(jsondata["people"][0]["pose_keypoints_2d"]).reshape(-1,3) + elif person_num == 1: + data = np.array(jsondata["people"][0]["pose_keypoints_2d"]).reshape(-1,3) + else: + # load the reference silu image + img_name = re.findall(r'\d{4}', os.path.basename(txt_file))[-1] + '.png' + img_path = os.path.join(oumvlp_rearrange_silu_path, *sinfo, img_name) + if not os.path.exists(img_path): + logging.warning(f'Pose reference silu({img_path}) not exists.') + continue + silu_img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) + + # determine which pose has the highest matching score + person_poses = [np.array(p["pose_keypoints_2d"]).reshape(-1,3) for p in jsondata['people']] + max_score_idx = np.argmax([pose_silu_match_score(p, silu_img) for p in person_poses]) + + # use the pose with the highest matching score to be the pkl data + data = person_poses[max_score_idx] + to_pickle.append(data) except: print(txt_file) @@ -174,7 +207,7 @@ def txts2pickle(txt_groups: Tuple, output_path: Path, verbose: bool = False, dat -def pretreat_pose(input_path: Path, output_path: Path, workers: int = 4, verbose: bool = False, dataset='CASIAB') -> None: +def pretreat_pose(input_path: Path, output_path: Path, workers: int = 4, verbose: bool = False, dataset='CASIAB', **kwargs) -> None: """Reads a dataset and saves the data in pickle format. Args: @@ -208,7 +241,10 @@ def pretreat_pose(input_path: Path, output_path: Path, workers: int = 4, verbose with mp.Pool(workers) as pool: logging.info(f'Start pretreating {input_path}') - for _ in pool.imap_unordered(partial(txts2pickle, output_path=output_path, verbose=verbose, dataset=args.dataset), txt_groups.items()): + for _ in pool.imap_unordered( + partial(txts2pickle, output_path=output_path, verbose=verbose, dataset=args.dataset, **kwargs), + txt_groups.items() + ): progress.update(1) logging.info('Done') @@ -224,6 +260,8 @@ if __name__ == '__main__': parser.add_argument('-d', '--dataset', default='CASIAB', type=str, help='Dataset for pretreatment.') parser.add_argument('-v', '--verbose', default=False, action='store_true', help='Display debug info.') parser.add_argument('-p', '--pose', default=False, action='store_true', help='Processing pose.') + parser.add_argument('--oumvlp_rearrange_silu_path', default='', type=str, + help='Root path of the rearranged oumvlp silu dataset. This argument is only used in extracting oumvlp pose pkl.') args = parser.parse_args() logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w', format='[%(asctime)s - %(levelname)s]: %(message)s') @@ -234,6 +272,15 @@ if __name__ == '__main__': for k, v in args.__dict__.items(): logging.debug(f'{k}: {v}') if args.pose: - pretreat_pose(input_path=Path(args.input_path), output_path=Path(args.output_path), workers=args.n_workers, verbose=args.verbose, dataset=args.dataset) + if args.dataset.lower() == "oumvlp": + assert args.oumvlp_rearrange_silu_path, "Please specify the path to the rearranged OUMVLP dataset using `--oumvlp_rearrange_silu_path` argument." + pretreat_pose( + input_path=Path(args.input_path), + output_path=Path(args.output_path), + workers=args.n_workers, + verbose=args.verbose, + dataset=args.dataset, + oumvlp_rearrange_silu_path=os.path.abspath(args.oumvlp_rearrange_silu_path) + ) else: pretreat(input_path=Path(args.input_path), output_path=Path(args.output_path), img_size=args.img_size, workers=args.n_workers, verbose=args.verbose, dataset=args.dataset)