60 lines
1.6 KiB
Python
60 lines
1.6 KiB
Python
import pickle
|
|
import numpy as np
|
|
import os
|
|
import json
|
|
import argparse
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the cross-joint detector.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse falls back to ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``dataset_name`` (str, default ``''``) and
        ``output_path`` (str).

    Raises:
        SystemExit: via argparse when ``--output_path`` is missing or the
            arguments are otherwise invalid.
    """
    parser = argparse.ArgumentParser(description='Detect cross joints')
    parser.add_argument('--dataset_name', dest='dataset_name',
                        help='select dataset',
                        default='', type=str)
    # Required: the script walks this directory, and a missing value would
    # otherwise only fail later inside os.walk with an unhelpful TypeError.
    parser.add_argument('--output_path', dest='output_path',
                        help='path of output',
                        default=None, type=str, required=True)
    return parser.parse_args(argv)
|
|
|
|
|
|
def create_dir_not_exist(path):
    """Create directory *path*, including any missing parents, if absent.

    If something already exists at *path* (a directory or any other kind of
    entry), nothing is done.

    Args:
        path: Directory path to create.
    """
    if not os.path.exists(path):
        # makedirs creates intermediate directories; exist_ok tolerates another
        # process creating the directory between the check above and this call.
        os.makedirs(path, exist_ok=True)
|
|
|
|
|
|
def load_Jtr(file_path):
    """Unpickle *file_path* and return its ``"Jtr"`` entry as a NumPy array.

    Args:
        file_path: Path to a pickle file whose top-level object supports
            ``obj["Jtr"]`` (presumably a dict of fitting results — confirm
            against whatever writes these files).

    Returns:
        ``np.ndarray`` built with ``np.array`` from the stored ``"Jtr"`` value.

    Note:
        ``pickle.load`` can execute arbitrary code; only use on trusted files.
    """
    with open(file_path, 'rb') as handle:
        payload = pickle.load(handle)
    stored_joints = payload["Jtr"]
    return np.array(stored_joints)
|
|
|
|
|
|
def has_cross(joints: np.ndarray):
    """Report whether either monitored joint pair has flipped its column-0 order.

    Two checks are made. Each compares the signed column-0 offset of one pair
    of joints against that of another pair:

    * joints 1 vs 2 against joints 10 vs 11;
    * joints 13 vs 14 against joints 22 vs 23.

    A check fires when the two offsets have strictly opposite signs
    (product < 0); a zero offset never counts. The second check is only
    evaluated if the first does not fire.

    NOTE(review): the indices look like a 24-joint SMPL-style layout with
    column 0 as the x axis — confirm against the data producer.

    Args:
        joints: Indexable as ``joints[j][0]`` for the joint ids above.

    Returns:
        The first check's comparison result if truthy, otherwise the second
        check's (a Python bool or NumPy bool).
    """
    def column0_gap(first, second):
        # Signed difference of column 0 between two joints.
        return joints[first][0] - joints[second][0]

    first_check = column0_gap(1, 2) * column0_gap(10, 11) < 0
    if first_check:
        return first_check
    return column0_gap(13, 14) * column0_gap(22, 23) < 0
|
|
|
|
|
|
def cross_frames(Jtr: np.ndarray):
    """List the frame indices of *Jtr* for which ``has_cross`` is truthy.

    Args:
        Jtr: Array whose first axis indexes frames; ``Jtr[i]`` is handed to
            ``has_cross`` for each ``i`` in ``range(Jtr.shape[0])``.

    Returns:
        list of int frame indices, in ascending order.
    """
    frame_count = Jtr.shape[0]
    return [idx for idx in range(frame_count) if has_cross(Jtr[idx])]
|
|
|
|
|
|
def cross_detector(dir_path):
    """Walk *dir_path* recursively and collect crossed frames per file.

    Every file found under *dir_path* (at any depth) is loaded with
    ``load_Jtr`` and analysed with ``cross_frames``.

    Args:
        dir_path: Root directory to walk.

    Returns:
        dict mapping each file's path relative to *dir_path* to the list of
        frame indices returned by ``cross_frames``. Files directly inside
        *dir_path* are keyed by their bare file name.
    """
    ans = {}
    for root, _dirs, files in os.walk(dir_path):
        for file in files:
            # Join with the directory os.walk is currently visiting; joining
            # with dir_path broke for every file inside a sub-directory.
            file_path = os.path.join(root, file)
            # Relative key: identical to the bare name for top-level files, and
            # keeps same-named files in different sub-directories distinct.
            key = os.path.relpath(file_path, dir_path)
            ans[key] = cross_frames(load_Jtr(file_path))
    return ans
|
|
|
|
|
|
if __name__ == "__main__":
    args = parse_args()
    d = cross_detector(args.output_path)
    out_file = "./fit/output/cross_detection/{}.json".format(args.dataset_name)
    # open() does not create missing directories, so create the destination
    # (and any parents) first.
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    # The context manager guarantees the JSON is flushed and the file closed.
    with open(out_file, 'w') as f:
        json.dump(d, f)