fix: silu pose mismatch in oumvlp pose pkl extraction
This commit is contained in:
@@ -3,6 +3,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import pickle
|
import pickle
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@@ -127,7 +128,7 @@ def pretreat(input_path: Path, output_path: Path, img_size: int = 64, workers: i
|
|||||||
progress.update(1)
|
progress.update(1)
|
||||||
logging.info('Done')
|
logging.info('Done')
|
||||||
|
|
||||||
def txts2pickle(txt_groups: Tuple, output_path: Path, verbose: bool = False, dataset='CASIAB') -> None:
|
def txts2pickle(txt_groups: Tuple, output_path: Path, verbose: bool = False, dataset='CASIAB', **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Reads a group of images and saves the data in pickle format.
|
Reads a group of images and saves the data in pickle format.
|
||||||
|
|
||||||
@@ -137,18 +138,50 @@ def txts2pickle(txt_groups: Tuple, output_path: Path, verbose: bool = False, dat
|
|||||||
img_size (int, optional): Image resizing size. Defaults to 64.
|
img_size (int, optional): Image resizing size. Defaults to 64.
|
||||||
verbose (bool, optional): Display debug info. Defaults to False.
|
verbose (bool, optional): Display debug info. Defaults to False.
|
||||||
"""
|
"""
|
||||||
|
def pose_silu_match_score(pose: np.ndarray, silu: np.ndarray):
    """Score how well a 2D pose overlaps a silhouette image.

    The score is the sum of the silhouette pixel intensities sampled at the
    pose's joint positions, so a pose lying on bright pixels of ``silu``
    scores higher than one lying on dark pixels.

    Args:
        pose: 2-D array with one row per joint. Columns 0 and 1 are read as
            the x and y pixel coordinates; any further columns are ignored.
        silu: Image array indexed as ``silu[y, x]``; its first two
            dimensions are taken as height and width. ``None`` is accepted
            and treated as "no reference image available".

    Returns:
        float: The summed intensity at all joints, or ``-np.inf`` when
        ``silu`` is ``None`` or when at least one joint lies outside the
        image.
    """
    # A missing reference image cannot support any pose; give the lowest
    # score instead of raising on ``silu.shape``.
    if silu is None:
        return -np.inf

    xs = pose[:, 0]
    ys = pose[:, 1]
    height, width = silu.shape[:2]

    # Check bounds on the raw coordinates, *before* truncating to integers:
    # casting first would turn e.g. x = -0.5 into 0 and wrongly accept it,
    # and NaN / huge values have no well-defined integer cast. NaN fails
    # every comparison here and is therefore rejected as well.
    inside = (xs >= 0) & (xs < width) & (ys >= 0) & (ys < height)
    if not np.all(inside):
        # The score is all-or-nothing: a single out-of-bound joint
        # disqualifies the whole pose.
        return -np.inf

    rows = ys.astype(np.int32)
    cols = xs.astype(np.int32)
    # Use the sum of all joints' pixel intensities as the score.
    return float(np.sum(silu[rows, cols]))
|
||||||
|
|
||||||
sinfo = txt_groups[0]
|
sinfo = txt_groups[0]
|
||||||
txt_paths = txt_groups[1]
|
txt_paths = txt_groups[1]
|
||||||
to_pickle = []
|
to_pickle = []
|
||||||
if dataset == 'OUMVLP':
|
if dataset == 'OUMVLP':
|
||||||
|
oumvlp_rearrange_silu_path = kwargs.get('oumvlp_rearrange_silu_path', None)
|
||||||
for txt_file in sorted(txt_paths):
|
for txt_file in sorted(txt_paths):
|
||||||
try:
|
try:
|
||||||
with open(txt_file) as f:
|
with open(txt_file) as f:
|
||||||
jsondata = json.load(f)
|
jsondata = json.load(f)
|
||||||
if len(jsondata['people'])==0:
|
person_num = len(jsondata['people'])
|
||||||
|
if person_num==0:
|
||||||
continue
|
continue
|
||||||
data = np.array(jsondata["people"][0]["pose_keypoints_2d"]).reshape(-1,3)
|
elif person_num == 1:
|
||||||
|
data = np.array(jsondata["people"][0]["pose_keypoints_2d"]).reshape(-1,3)
|
||||||
|
else:
|
||||||
|
# load the reference silu image
|
||||||
|
img_name = re.findall(r'\d{4}', os.path.basename(txt_file))[-1] + '.png'
|
||||||
|
img_path = os.path.join(oumvlp_rearrange_silu_path, *sinfo, img_name)
|
||||||
|
if not os.path.exists(img_path):
|
||||||
|
logging.warning(f'Pose reference silu({img_path}) not exists.')
|
||||||
|
continue
|
||||||
|
silu_img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
|
||||||
|
|
||||||
|
# determine which pose has the highest matching score
|
||||||
|
person_poses = [np.array(p["pose_keypoints_2d"]).reshape(-1,3) for p in jsondata['people']]
|
||||||
|
max_score_idx = np.argmax([pose_silu_match_score(p, silu_img) for p in person_poses])
|
||||||
|
|
||||||
|
# use the pose with the highest matching score to be the pkl data
|
||||||
|
data = person_poses[max_score_idx]
|
||||||
|
|
||||||
to_pickle.append(data)
|
to_pickle.append(data)
|
||||||
except:
|
except:
|
||||||
print(txt_file)
|
print(txt_file)
|
||||||
@@ -174,7 +207,7 @@ def txts2pickle(txt_groups: Tuple, output_path: Path, verbose: bool = False, dat
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def pretreat_pose(input_path: Path, output_path: Path, workers: int = 4, verbose: bool = False, dataset='CASIAB') -> None:
|
def pretreat_pose(input_path: Path, output_path: Path, workers: int = 4, verbose: bool = False, dataset='CASIAB', **kwargs) -> None:
|
||||||
"""Reads a dataset and saves the data in pickle format.
|
"""Reads a dataset and saves the data in pickle format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -208,7 +241,10 @@ def pretreat_pose(input_path: Path, output_path: Path, workers: int = 4, verbose
|
|||||||
|
|
||||||
with mp.Pool(workers) as pool:
|
with mp.Pool(workers) as pool:
|
||||||
logging.info(f'Start pretreating {input_path}')
|
logging.info(f'Start pretreating {input_path}')
|
||||||
for _ in pool.imap_unordered(partial(txts2pickle, output_path=output_path, verbose=verbose, dataset=args.dataset), txt_groups.items()):
|
for _ in pool.imap_unordered(
|
||||||
|
partial(txts2pickle, output_path=output_path, verbose=verbose, dataset=args.dataset, **kwargs),
|
||||||
|
txt_groups.items()
|
||||||
|
):
|
||||||
progress.update(1)
|
progress.update(1)
|
||||||
logging.info('Done')
|
logging.info('Done')
|
||||||
|
|
||||||
@@ -224,6 +260,8 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('-d', '--dataset', default='CASIAB', type=str, help='Dataset for pretreatment.')
|
parser.add_argument('-d', '--dataset', default='CASIAB', type=str, help='Dataset for pretreatment.')
|
||||||
parser.add_argument('-v', '--verbose', default=False, action='store_true', help='Display debug info.')
|
parser.add_argument('-v', '--verbose', default=False, action='store_true', help='Display debug info.')
|
||||||
parser.add_argument('-p', '--pose', default=False, action='store_true', help='Processing pose.')
|
parser.add_argument('-p', '--pose', default=False, action='store_true', help='Processing pose.')
|
||||||
|
parser.add_argument('--oumvlp_rearrange_silu_path', default='', type=str,
|
||||||
|
help='Root path of the rearranged oumvlp silu dataset. This argument is only used in extracting oumvlp pose pkl.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w', format='[%(asctime)s - %(levelname)s]: %(message)s')
|
logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w', format='[%(asctime)s - %(levelname)s]: %(message)s')
|
||||||
@@ -234,6 +272,15 @@ if __name__ == '__main__':
|
|||||||
for k, v in args.__dict__.items():
|
for k, v in args.__dict__.items():
|
||||||
logging.debug(f'{k}: {v}')
|
logging.debug(f'{k}: {v}')
|
||||||
if args.pose:
|
if args.pose:
|
||||||
pretreat_pose(input_path=Path(args.input_path), output_path=Path(args.output_path), workers=args.n_workers, verbose=args.verbose, dataset=args.dataset)
|
if args.dataset.lower() == "oumvlp":
|
||||||
|
assert args.oumvlp_rearrange_silu_path, "Please specify the path to the rearranged OUMVLP dataset using `--oumvlp_rearrange_silu_path` argument."
|
||||||
|
pretreat_pose(
|
||||||
|
input_path=Path(args.input_path),
|
||||||
|
output_path=Path(args.output_path),
|
||||||
|
workers=args.n_workers,
|
||||||
|
verbose=args.verbose,
|
||||||
|
dataset=args.dataset,
|
||||||
|
oumvlp_rearrange_silu_path=os.path.abspath(args.oumvlp_rearrange_silu_path)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
pretreat(input_path=Path(args.input_path), output_path=Path(args.output_path), img_size=args.img_size, workers=args.n_workers, verbose=args.verbose, dataset=args.dataset)
|
pretreat(input_path=Path(args.input_path), output_path=Path(args.output_path), img_size=args.img_size, workers=args.n_workers, verbose=args.verbose, dataset=args.dataset)
|
||||||
|
|||||||
Reference in New Issue
Block a user