refactor: load the pose selection index for each frame instead of computing it via real-time matching.
- Add a new script (`pose_index_extractor.py`) that generates the pose selection index for the OUMVLP dataset. - Update `pretreatment.py` to load the generated index during pose extraction.
This commit is contained in:
@@ -0,0 +1,195 @@
|
||||
import re
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import argparse
|
||||
import pickle as pk
|
||||
from typing import Tuple
|
||||
from pathlib import Path
|
||||
from functools import partial
|
||||
from collections import defaultdict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import multiprocessing as mp
|
||||
|
||||
"""
|
||||
This script tries to match all the potential poses detected in a frame with the silhouette of the same frame in OUMVLP dataset,
|
||||
and selects the pose that best matches the silhouette as the final pose for that frame, save its index in a pickle file which
|
||||
is used when extracting pose pkls.
|
||||
|
||||
More info please refer to https://github.com/ShiqiYu/OpenGait/pull/280
|
||||
"""
|
||||
|
||||
def pose_silu_match_score(pose: np.ndarray, silu: np.ndarray) -> float:
|
||||
"""
|
||||
Calculate the matching score between a 2D pose and a silhouette image using the sum of all joints' pixel intensity.
|
||||
|
||||
Args:
|
||||
pose (np.ndarray): 2D pose, shape (n_joints, 3)
|
||||
silu (np.ndarray): silhouette image, shape (H, W, 3)
|
||||
|
||||
Returns:
|
||||
float: matching score
|
||||
"""
|
||||
pose_coord = pose[:,:2].astype(np.int32)
|
||||
|
||||
H, W, *_ = silu.shape
|
||||
valid_joints = (pose_coord[:, 1] >=0) & (pose_coord[:, 1] < H) & \
|
||||
(pose_coord[:, 0] >=0) & (pose_coord[:, 0] < W)
|
||||
|
||||
if np.sum(valid_joints) == len(pose_coord):
|
||||
# only calculate score for points that are inside the silu img
|
||||
# use the sum of all joints' pixel intensity as the score
|
||||
return np.sum(silu[pose_coord[:, 1], pose_coord[:, 0]])
|
||||
else:
|
||||
# if pose coord is out of bound, return -inf
|
||||
return -np.inf
|
||||
|
||||
|
||||
def perseq_pipeline(txt_groups: Tuple, rearrange_silu_root: Path, output_path: Path, verbose: bool = False) -> None:
|
||||
"""
|
||||
Generate and save the pose selection index pickle file for a given sequence.
|
||||
|
||||
Args:
|
||||
txt_groups (Tuple): Tuple of (sid, seq, view) and list of pose json paths.
|
||||
rearrange_silu_root (Path): Root dir of rearranged silu dataset.
|
||||
output_path (Path): Output path.
|
||||
verbose (bool, optional): Display debug info. Defaults to False.
|
||||
"""
|
||||
|
||||
# resolve seq info
|
||||
sinfo = txt_groups[0]
|
||||
txt_paths = txt_groups[1]
|
||||
pick_idx = dict()
|
||||
|
||||
# prepare output dir & resume last work
|
||||
dst_path = os.path.join(output_path, *sinfo)
|
||||
os.makedirs(dst_path, exist_ok=True)
|
||||
pkl_path = os.path.join(dst_path, 'pose_selection_idx.pkl')
|
||||
if os.path.exists(pkl_path):
|
||||
logging.debug(f'Pose index file {pkl_path} already exists, skipping...')
|
||||
return
|
||||
|
||||
# extract
|
||||
for txt_file in sorted(txt_paths):
|
||||
# get the frame index (digit str before extension) of current frame
|
||||
try:
|
||||
frame_idx = re.findall(r'(\d+).json', os.path.basename(txt_file))[0]
|
||||
except IndexError:
|
||||
# adapt to different name format for json files in ID 00001
|
||||
frame_idx = re.findall(r'\d{4}', os.path.basename(txt_file))[0]
|
||||
|
||||
with open(txt_file) as f:
|
||||
jsondata = json.load(f)
|
||||
|
||||
person_num = len(jsondata['people'])
|
||||
|
||||
# if no person or 1 person detected in this frame
|
||||
# we don't need to do the matching, just use the first or skip this frame when extracting pose pkl
|
||||
# see datasets/pretreatment.py#Line: 167~168 and Line: 173
|
||||
if person_num <= 1:
|
||||
continue
|
||||
|
||||
# multiple people detected in this frame
|
||||
else:
|
||||
# load the reference silu image
|
||||
img_name = f'{frame_idx}.png'
|
||||
img_path = os.path.join(rearrange_silu_root, *sinfo, img_name)
|
||||
if not os.path.exists(img_path):
|
||||
logging.warning(
|
||||
f'Pose reference silu({img_path}) of seq({'-'.join(sinfo)}) not exists, the matching for frame {frame_idx} is skipped. '
|
||||
+ 'This means that the first person in the frame will be used as the pose data, and this may cause performance degradation.'
|
||||
)
|
||||
continue
|
||||
silu_img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
|
||||
|
||||
# determine which pose has the highest matching score
|
||||
person_poses = [np.array(p["pose_keypoints_2d"]).reshape(-1,3) for p in jsondata['people']]
|
||||
max_score_idx = np.argmax([pose_silu_match_score(p, silu_img) for p in person_poses])
|
||||
|
||||
# use the pose with the highest matching score to be the pkl data
|
||||
pick_idx[frame_idx] = max_score_idx
|
||||
|
||||
# dump the index dict
|
||||
if verbose:
|
||||
logging.debug(f'Saving {pkl_path}... ')
|
||||
with open(pkl_path, 'wb') as f:
|
||||
pk.dump(pick_idx, f)
|
||||
logging.debug(f'Saved {len(pick_idx)} indexs to {pkl_path}.')
|
||||
|
||||
|
||||
def main(rearrange_pose_root: Path, rearrange_silu_root: Path, output_path: Path, workers: int = 4, verbose: bool = False) -> None:
|
||||
"""Reads a dataset and saves the data in pickle format.
|
||||
|
||||
Args:
|
||||
rearrange_pose_root (Path): Root path of the rearranged oumvlp pose dataset.
|
||||
rearrange_silu_root (Path): Root path of the rearranged oumvlp silu dataset.
|
||||
output_path (Path): The selection index output path. The final structure is: output_path/sid/seq/view/pose_selection_idx.pkl
|
||||
workers (int, optional): Number of thread workers. Defaults to 4.
|
||||
verbose (bool, optional): Display debug info. Defaults to False.
|
||||
"""
|
||||
txt_groups = defaultdict(list)
|
||||
logging.info(f'Listing {rearrange_pose_root}')
|
||||
total_files = 0
|
||||
|
||||
for json_path in rearrange_pose_root.rglob('*.json'):
|
||||
if verbose:
|
||||
logging.debug(f'Adding {json_path}')
|
||||
*_, sid, seq, view, _ = json_path.as_posix().split(os.path.sep)
|
||||
txt_groups[(sid, seq, view)].append(json_path)
|
||||
total_files += 1
|
||||
|
||||
logging.info(f'Total files listed: {total_files}')
|
||||
|
||||
progress = tqdm(total=len(txt_groups), desc='Extracting Matching Pose Index', unit='seq')
|
||||
|
||||
with mp.Pool(workers) as pool:
|
||||
logging.info(f'Start extracting pose indexes for {rearrange_pose_root}')
|
||||
for _ in pool.imap_unordered(
|
||||
partial(perseq_pipeline, rearrange_silu_root=rearrange_silu_root, output_path=output_path, verbose=verbose),
|
||||
txt_groups.items()
|
||||
):
|
||||
progress.update(1)
|
||||
|
||||
logging.info('Done')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='OUMVLP pose selection index extraction module.')
|
||||
parser.add_argument('-p', '--rearrange_pose_root', required=True, type=str, help='Root path of the rearranged oumvlp pose dataset.')
|
||||
parser.add_argument('-s', '--rearrange_silu_root', required=True, type=str, help='Root path of the rearranged oumvlp silu dataset.')
|
||||
parser.add_argument('-o', '--output_path', required=True, type=str, help='Output path of pickled dataset.')
|
||||
parser.add_argument('-l', '--log_file', default='./pretreatment.log', type=str, help='Log file path. Default: ./pretreatment.log')
|
||||
parser.add_argument('-n', '--n_workers', default=4, type=int, help='Number of thread workers. Default: 4')
|
||||
parser.add_argument('-v', '--verbose', default=False, action='store_true', help='Display debug info.')
|
||||
args = parser.parse_args()
|
||||
|
||||
# logging and verbose mode
|
||||
logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w', format='[%(asctime)s - %(levelname)s]: %(message)s')
|
||||
|
||||
if args.verbose:
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
logging.info('Verbose mode is on.')
|
||||
for k, v in args.__dict__.items():
|
||||
logging.debug(f'{k}: {v}')
|
||||
|
||||
# arguments validation
|
||||
args.rearrange_pose_root = os.path.abspath(args.rearrange_pose_root)
|
||||
assert os.path.exists(args.rearrange_pose_root), f"The specified oumvlp pose root({args.rearrange_pose_root}) does not exist."
|
||||
|
||||
args.rearrange_silu_root = os.path.abspath(args.rearrange_silu_root)
|
||||
assert os.path.exists(args.rearrange_silu_root), f"The specified oumvlp silu root({args.rearrange_silu_root}) does not exist."
|
||||
|
||||
args.output_path = os.path.abspath(args.output_path)
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
|
||||
# run
|
||||
main(
|
||||
rearrange_pose_root=Path(args.rearrange_pose_root),
|
||||
rearrange_silu_root=Path(args.rearrange_silu_root),
|
||||
output_path=Path(args.output_path),
|
||||
workers=args.n_workers,
|
||||
verbose=args.verbose
|
||||
)
|
||||
+46
-40
@@ -133,58 +133,54 @@ def txts2pickle(txt_groups: Tuple, output_path: Path, verbose: bool = False, dat
|
||||
Reads a group of images and saves the data in pickle format.
|
||||
|
||||
Args:
|
||||
img_groups (Tuple): Tuple of (sid, seq, view) and list of image paths.
|
||||
txt_groups (Tuple): Tuple of (sid, seq, view) and list of image paths.
|
||||
output_path (Path): Output path.
|
||||
img_size (int, optional): Image resizing size. Defaults to 64.
|
||||
verbose (bool, optional): Display debug info. Defaults to False.
|
||||
dataset (str, optional): Dataset name. Defaults to 'CASIAB'.
|
||||
kwargs (dict, optional): Additional arguments. It receives 'oumvlp_index_dir' when dataset is 'OUMVLP'.
|
||||
"""
|
||||
def pose_silu_match_score(pose: np.ndarray, silu: np.ndarray):
|
||||
pose_coord = pose[:,:2].astype(np.int32)
|
||||
|
||||
H, W, *_ = silu.shape
|
||||
valid_joints = (pose_coord[:, 1] >=0) & (pose_coord[:, 1] < H) & \
|
||||
(pose_coord[:, 0] >=0) & (pose_coord[:, 0] < W)
|
||||
if np.sum(valid_joints) == len(pose_coord):
|
||||
# only calculate score for points that are inside the silu img
|
||||
# use the sum of all joints' pixel intensity as the score
|
||||
return np.sum(silu[pose_coord[:, 1], pose_coord[:, 0]])
|
||||
else:
|
||||
# if pose coord is out of bound, return -inf
|
||||
return -np.inf
|
||||
|
||||
sinfo = txt_groups[0]
|
||||
txt_paths = txt_groups[1]
|
||||
to_pickle = []
|
||||
if dataset == 'OUMVLP':
|
||||
oumvlp_rearrange_silu_path = kwargs.get('oumvlp_rearrange_silu_path', None)
|
||||
# load pose selection index
|
||||
idx_file = os.path.join(kwargs['oumvlp_index_dir'], *sinfo, 'pose_selection_idx.pkl')
|
||||
try:
|
||||
with open(idx_file, 'rb') as f:
|
||||
frame_wise_idx = pickle.load(f) # dict, structure is {txt_file_name(str): selected_pose_idx(int)}
|
||||
except FileNotFoundError:
|
||||
logging.warning(
|
||||
f'No pose selection index found for sequence: {sinfo}, will use the first detected pose for each frame. '
|
||||
+ 'This may cause performance degradation, see https://github.com/ShiqiYu/OpenGait/pull/280 for more details. '
|
||||
+ 'You can avoid this warning by re-get the index files following Step4-2 in datasets/OUMVLP/README.md.'
|
||||
)
|
||||
frame_wise_idx = dict()
|
||||
|
||||
# apply selection index for each frame in current sequence
|
||||
for txt_file in sorted(txt_paths):
|
||||
try:
|
||||
with open(txt_file) as f:
|
||||
jsondata = json.load(f)
|
||||
person_num = len(jsondata['people'])
|
||||
if person_num==0:
|
||||
|
||||
# no person detected in this frame
|
||||
if len(jsondata['people'])==0:
|
||||
continue
|
||||
elif person_num == 1:
|
||||
data = np.array(jsondata["people"][0]["pose_keypoints_2d"]).reshape(-1,3)
|
||||
else:
|
||||
# load the reference silu image
|
||||
img_name = re.findall(r'\d{4}', os.path.basename(txt_file))[-1] + '.png'
|
||||
img_path = os.path.join(oumvlp_rearrange_silu_path, *sinfo, img_name)
|
||||
if not os.path.exists(img_path):
|
||||
logging.warning(f'Pose reference silu({img_path}) not exists.')
|
||||
continue
|
||||
silu_img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
|
||||
|
||||
# determine which pose has the highest matching score
|
||||
person_poses = [np.array(p["pose_keypoints_2d"]).reshape(-1,3) for p in jsondata['people']]
|
||||
max_score_idx = np.argmax([pose_silu_match_score(p, silu_img) for p in person_poses])
|
||||
# get the frame index (digit str before extension) of current frame
|
||||
try:
|
||||
frame_idx = re.findall(r'(\d+).json', os.path.basename(txt_file))[0]
|
||||
except IndexError:
|
||||
# adapt to different name format for json files in ID 00001
|
||||
frame_idx = re.findall(r'\d{4}', os.path.basename(txt_file))[0]
|
||||
|
||||
# use the pose with the highest matching score to be the pkl data
|
||||
data = person_poses[max_score_idx]
|
||||
# use the first person if no index file or less than one pose in current frame
|
||||
pose_idx = frame_wise_idx.get(frame_idx, 0)
|
||||
|
||||
data = np.array(jsondata["people"][pose_idx]["pose_keypoints_2d"]).reshape(-1,3)
|
||||
to_pickle.append(data)
|
||||
except:
|
||||
print(txt_file)
|
||||
print(f"Fail to extract pkl for frame({txt_file}), seq({sinfo}).")
|
||||
else:
|
||||
for txt_file in sorted(txt_paths):
|
||||
if verbose:
|
||||
@@ -206,16 +202,16 @@ def txts2pickle(txt_groups: Tuple, output_path: Path, verbose: bool = False, dat
|
||||
logging.warning(f'{sinfo} has less than 5 valid data.')
|
||||
|
||||
|
||||
|
||||
def pretreat_pose(input_path: Path, output_path: Path, workers: int = 4, verbose: bool = False, dataset='CASIAB', **kwargs) -> None:
|
||||
"""Reads a dataset and saves the data in pickle format.
|
||||
|
||||
Args:
|
||||
input_path (Path): Dataset root path.
|
||||
output_path (Path): Output path.
|
||||
img_size (int, optional): Image resizing size. Defaults to 64.
|
||||
workers (int, optional): Number of thread workers. Defaults to 4.
|
||||
verbose (bool, optional): Display debug info. Defaults to False.
|
||||
dataset (str, optional): Dataset name. Defaults to 'CASIAB'.
|
||||
kwargs (dict, optional): Additional arguments. It receives 'oumvlp_index_dir' when dataset is 'OUMVLP'.
|
||||
"""
|
||||
txt_groups = defaultdict(list)
|
||||
logging.info(f'Listing {input_path}')
|
||||
@@ -260,8 +256,9 @@ if __name__ == '__main__':
|
||||
parser.add_argument('-d', '--dataset', default='CASIAB', type=str, help='Dataset for pretreatment.')
|
||||
parser.add_argument('-v', '--verbose', default=False, action='store_true', help='Display debug info.')
|
||||
parser.add_argument('-p', '--pose', default=False, action='store_true', help='Processing pose.')
|
||||
parser.add_argument('--oumvlp_rearrange_silu_path', default='', type=str,
|
||||
help='Root path of the rearranged oumvlp silu dataset. This argument is only used in extracting oumvlp pose pkl.')
|
||||
parser.add_argument('-oid', '--oumvlp_index_dir', default='', type=str,
|
||||
help='Path of the directory containing all index files for extracting oumvlp pose pkl, which is necessary to promise the temporal consistency of extracted pose sequence. '
|
||||
+ 'Note: this argument is only used when extracting oumvlp pose pkl, more info please refer to Step4-2 in datasets/OUMVLP/README.md. ')
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w', format='[%(asctime)s - %(levelname)s]: %(message)s')
|
||||
@@ -273,14 +270,23 @@ if __name__ == '__main__':
|
||||
logging.debug(f'{k}: {v}')
|
||||
if args.pose:
|
||||
if args.dataset.lower() == "oumvlp":
|
||||
assert args.oumvlp_rearrange_silu_path, "Please specify the path to the rearranged OUMVLP dataset using `--oumvlp_rearrange_silu_path` argument."
|
||||
assert args.oumvlp_index_dir, (
|
||||
"When extracting the oumvlp pose pkl, please specify the path of the directory containing all index files using the `--oumvlp_index_dir` argument."
|
||||
+ "If you don't know what it is, please refer to Step4-2 in datasets/OUMVLP/README.md for more details."
|
||||
)
|
||||
|
||||
args.oumvlp_index_dir = os.path.abspath(args.oumvlp_index_dir)
|
||||
assert os.path.exists(args.oumvlp_index_dir), f"The specified oumvlp index files' directory({args.oumvlp_index_dir}) does not exist."
|
||||
|
||||
logging.info(f'Using the oumvlp index files in {args.oumvlp_index_dir}')
|
||||
|
||||
pretreat_pose(
|
||||
input_path=Path(args.input_path),
|
||||
output_path=Path(args.output_path),
|
||||
workers=args.n_workers,
|
||||
verbose=args.verbose,
|
||||
dataset=args.dataset,
|
||||
oumvlp_rearrange_silu_path=os.path.abspath(args.oumvlp_rearrange_silu_path)
|
||||
oumvlp_index_dir=args.oumvlp_index_dir
|
||||
)
|
||||
else:
|
||||
pretreat(input_path=Path(args.input_path), output_path=Path(args.output_path), img_size=args.img_size, workers=args.n_workers, verbose=args.verbose, dataset=args.dataset)
|
||||
|
||||
Reference in New Issue
Block a user