fix: silu pose mismatch in oumvlp pose pkl extraction

This commit is contained in:
Ahzyuan
2025-06-05 12:18:57 +08:00
parent c42f2f8c07
commit 923a410cd3
+54 -7
View File
@@ -3,6 +3,7 @@ import argparse
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import os import os
import re
import pickle import pickle
from collections import defaultdict from collections import defaultdict
from functools import partial from functools import partial
@@ -127,7 +128,7 @@ def pretreat(input_path: Path, output_path: Path, img_size: int = 64, workers: i
progress.update(1) progress.update(1)
logging.info('Done') logging.info('Done')
def txts2pickle(txt_groups: Tuple, output_path: Path, verbose: bool = False, dataset='CASIAB') -> None: def txts2pickle(txt_groups: Tuple, output_path: Path, verbose: bool = False, dataset='CASIAB', **kwargs) -> None:
""" """
Reads a group of images and saves the data in pickle format. Reads a group of images and saves the data in pickle format.
@@ -137,18 +138,50 @@ def txts2pickle(txt_groups: Tuple, output_path: Path, verbose: bool = False, dat
img_size (int, optional): Image resizing size. Defaults to 64. img_size (int, optional): Image resizing size. Defaults to 64.
verbose (bool, optional): Display debug info. Defaults to False. verbose (bool, optional): Display debug info. Defaults to False.
""" """
def pose_silu_match_score(pose: np.ndarray, silu: np.ndarray):
pose_coord = pose[:,:2].astype(np.int32)
H, W, *_ = silu.shape
valid_joints = (pose_coord[:, 1] >=0) & (pose_coord[:, 1] < H) & \
(pose_coord[:, 0] >=0) & (pose_coord[:, 0] < W)
if np.sum(valid_joints) == len(pose_coord):
# only calculate score for points that are inside the silu img
# use the sum of all joints' pixel intensity as the score
return np.sum(silu[pose_coord[:, 1], pose_coord[:, 0]])
else:
# if pose coord is out of bound, return -inf
return -np.inf
sinfo = txt_groups[0] sinfo = txt_groups[0]
txt_paths = txt_groups[1] txt_paths = txt_groups[1]
to_pickle = [] to_pickle = []
if dataset == 'OUMVLP': if dataset == 'OUMVLP':
oumvlp_rearrange_silu_path = kwargs.get('oumvlp_rearrange_silu_path', None)
for txt_file in sorted(txt_paths): for txt_file in sorted(txt_paths):
try: try:
with open(txt_file) as f: with open(txt_file) as f:
jsondata = json.load(f) jsondata = json.load(f)
if len(jsondata['people'])==0: person_num = len(jsondata['people'])
if person_num==0:
continue continue
data = np.array(jsondata["people"][0]["pose_keypoints_2d"]).reshape(-1,3) elif person_num == 1:
data = np.array(jsondata["people"][0]["pose_keypoints_2d"]).reshape(-1,3)
else:
# load the reference silu image
img_name = re.findall(r'\d{4}', os.path.basename(txt_file))[-1] + '.png'
img_path = os.path.join(oumvlp_rearrange_silu_path, *sinfo, img_name)
if not os.path.exists(img_path):
logging.warning(f'Pose reference silu({img_path}) not exists.')
continue
silu_img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
# determine which pose has the highest matching score
person_poses = [np.array(p["pose_keypoints_2d"]).reshape(-1,3) for p in jsondata['people']]
max_score_idx = np.argmax([pose_silu_match_score(p, silu_img) for p in person_poses])
# use the pose with the highest matching score to be the pkl data
data = person_poses[max_score_idx]
to_pickle.append(data) to_pickle.append(data)
except: except:
print(txt_file) print(txt_file)
@@ -174,7 +207,7 @@ def txts2pickle(txt_groups: Tuple, output_path: Path, verbose: bool = False, dat
def pretreat_pose(input_path: Path, output_path: Path, workers: int = 4, verbose: bool = False, dataset='CASIAB') -> None: def pretreat_pose(input_path: Path, output_path: Path, workers: int = 4, verbose: bool = False, dataset='CASIAB', **kwargs) -> None:
"""Reads a dataset and saves the data in pickle format. """Reads a dataset and saves the data in pickle format.
Args: Args:
@@ -208,7 +241,10 @@ def pretreat_pose(input_path: Path, output_path: Path, workers: int = 4, verbose
with mp.Pool(workers) as pool: with mp.Pool(workers) as pool:
logging.info(f'Start pretreating {input_path}') logging.info(f'Start pretreating {input_path}')
for _ in pool.imap_unordered(partial(txts2pickle, output_path=output_path, verbose=verbose, dataset=args.dataset), txt_groups.items()): for _ in pool.imap_unordered(
partial(txts2pickle, output_path=output_path, verbose=verbose, dataset=args.dataset, **kwargs),
txt_groups.items()
):
progress.update(1) progress.update(1)
logging.info('Done') logging.info('Done')
@@ -224,6 +260,8 @@ if __name__ == '__main__':
parser.add_argument('-d', '--dataset', default='CASIAB', type=str, help='Dataset for pretreatment.') parser.add_argument('-d', '--dataset', default='CASIAB', type=str, help='Dataset for pretreatment.')
parser.add_argument('-v', '--verbose', default=False, action='store_true', help='Display debug info.') parser.add_argument('-v', '--verbose', default=False, action='store_true', help='Display debug info.')
parser.add_argument('-p', '--pose', default=False, action='store_true', help='Processing pose.') parser.add_argument('-p', '--pose', default=False, action='store_true', help='Processing pose.')
parser.add_argument('--oumvlp_rearrange_silu_path', default='', type=str,
help='Root path of the rearranged oumvlp silu dataset. This argument is only used in extracting oumvlp pose pkl.')
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w', format='[%(asctime)s - %(levelname)s]: %(message)s') logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w', format='[%(asctime)s - %(levelname)s]: %(message)s')
@@ -234,6 +272,15 @@ if __name__ == '__main__':
for k, v in args.__dict__.items(): for k, v in args.__dict__.items():
logging.debug(f'{k}: {v}') logging.debug(f'{k}: {v}')
if args.pose: if args.pose:
pretreat_pose(input_path=Path(args.input_path), output_path=Path(args.output_path), workers=args.n_workers, verbose=args.verbose, dataset=args.dataset) if args.dataset.lower() == "oumvlp":
assert args.oumvlp_rearrange_silu_path, "Please specify the path to the rearranged OUMVLP dataset using `--oumvlp_rearrange_silu_path` argument."
pretreat_pose(
input_path=Path(args.input_path),
output_path=Path(args.output_path),
workers=args.n_workers,
verbose=args.verbose,
dataset=args.dataset,
oumvlp_rearrange_silu_path=os.path.abspath(args.oumvlp_rearrange_silu_path)
)
else: else:
pretreat(input_path=Path(args.input_path), output_path=Path(args.output_path), img_size=args.img_size, workers=args.n_workers, verbose=args.verbose, dataset=args.dataset) pretreat(input_path=Path(args.input_path), output_path=Path(args.output_path), img_size=args.img_size, workers=args.n_workers, verbose=args.verbose, dataset=args.dataset)