From c3dc13d2c845b6adfe77e96d4d7a107d3fb806d8 Mon Sep 17 00:00:00 2001 From: Iridoudou <2534936416@qq.com> Date: Mon, 9 Aug 2021 11:00:12 +0800 Subject: [PATCH] saving format changed to pkl --- fit/tools/label.py | 75 ++++++++++++++++++++++++++++++++++++++++++++++ fit/tools/main.py | 6 +++- fit/tools/save.py | 15 +++++----- 3 files changed, 88 insertions(+), 8 deletions(-) create mode 100644 fit/tools/label.py diff --git a/fit/tools/label.py b/fit/tools/label.py new file mode 100644 index 0000000..08920bb --- /dev/null +++ b/fit/tools/label.py @@ -0,0 +1,75 @@ +HumanAct12 = { + "A0101": "warm_up_wristankle", + "A0102": "warm_up_pectoral", + "A0103": "warm_up_eblowback", + "A0104": "warm_up_bodylean_right_arm", + "A0105": "warm_up_bodylean_left_arm", + "A0106": "warm_up_bow_right", + "A0107": "warm_up_bow_left", + "A0201": "walk", + "A0301": "run", + "A0401": "jump_handsup", + "A0402": "jump_vertical", + "A0501": "drink_bottle_righthand", + "A0502": "drink_bottle_lefthand", + "A0503": "drink_cup_righthand", + "A0504": "drink_cup_lefthand", + "A0505": "drink_both_hands", + "A0601": "lift_dumbbell with _right hand", + "A0602": "lift_dumbbell with _left hand", + "A0603": "Lift dumbells with both hands", + "A0604": "lift_dumbbell over head", + "A0605": "lift_dumbells with both hands and bend legs", + "A0701": "sit", + "A0801": "eat_finger_right", + "A0802": "eat_pie/hamburger", + "A0803": "Eat with left hand", + "A0901": "Turn steering wheel", + "A1001": "Take out phone, call and put phone back", + "A1002": "Call with left hand", + "A1101": "boxing_left_right", + "A1102": "boxing_left_upwards", + "A1103": "boxing_right_upwards", + "A1104": "boxing_right_left", + "A1201": "throw_right_hand", + "A1202": "throw_both_hands" +} + +UTD_MHAD = { + "1": " right arm swipe to the left(swipt_left)", + "2": " right arm swipe to the right(swipt_right)", + "3": " right hand wave(wave)", + "4": " two hand front clap(clap)", + "5": " right arm throw(throw)", + "6": " cross arms in the chest(arm_cross)", + "7": " basketball shooting(basketball_shoot)", + "8": " draw x(draw_x)", + "9": " draw circle(clockwise)(draw_circle_CW)", + "10": " draw circle(counter clockwise)(draw_circle_CCW)", + "11": " draw triangle(draw_triangle)", + "12": " bowling(right hand)(bowling)", + "13": " front boxing(boxing)", + "14": " baseball swing from right(baseball_swing)", + "15": " tennis forehand swing(tennis_swing)", + "16": " arm curl(two arms)(arm_curl)", + "17": " tennis serve(tennis_serve)", + "18": " two hand push(push)", + "19": " knock on door(knock)", + "20": " hand catch(catch)", + "21": " pick up and throw(pickup_throw)", + "22": " jogging(jog)", + "23": " walking(walk)", + "24": " sit to stand(sit2stand)", + "25": " stand to sit(stand2sit)", + "26": " forward lunge(left foot forward)(lunge)", + "27": " squat(squat)" +} + + +def get_label(file_name, dataset_name): + if dataset_name == 'HumanAct12': + key = file_name[-5:] + return HumanAct12[key] + elif dataset_name == 'UTD_MHAD': + key = file_name.split('_')[0][1:] + return UTD_MHAD[key] \ No newline at end of file diff --git a/fit/tools/main.py b/fit/tools/main.py index 3147ce1..c8527e3 100644 --- a/fit/tools/main.py +++ b/fit/tools/main.py @@ -9,7 +9,6 @@ import logging import argparse import json sys.path.append(os.getcwd()) -from display_utils import display_model from smplpytorch.pytorch.smpl_layer import SMPL_Layer from train import train from transform import transform @@ -25,6 +24,9 @@ def parse_args(): parser.add_argument('--dataset_name', dest='dataset_name', help='select dataset', default='', type=str) + parser.add_argument('--dataset_path', dest='dataset_path', + help='path of dataset', + default=None, type=str) args = parser.parse_args() return args @@ -33,6 +35,8 @@ def get_config(args): with open(config_path, 'r') as f: data = json.load(f) cfg = edict(data.copy()) + if not args.dataset_path == None: + cfg.DATASET.PATH = args.dataset_path return cfg def set_device(USE_GPU): diff --git a/fit/tools/save.py b/fit/tools/save.py index 471f791..b8a017d 100644 --- a/fit/tools/save.py +++ b/fit/tools/save.py @@ -3,11 +3,11 @@ import os import re from tqdm import tqdm import numpy as np -import json - +import pickle sys.path.append(os.getcwd()) from display_utils import display_model +from label import get_label def create_dir_not_exist(path): @@ -38,18 +38,19 @@ def save_pic(res, smpl_layer, file, logger, dataset_name): def save_params(res, file, logger, dataset_name): pose_params, shape_params, verts, Jtr = res file_name = re.split('[/.]', file)[-2] - fit_path = "fit/output/{}/params/".format(dataset_name) + fit_path = "fit/output/{}/".format(dataset_name) create_dir_not_exist(fit_path) logger.info('Saving params at {}'.format(fit_path)) + label=get_label(file_name, dataset_name) pose_params = (pose_params.cpu().detach()).numpy().tolist() shape_params = (shape_params.cpu().detach()).numpy().tolist() Jtr = (Jtr.cpu().detach()).numpy().tolist() verts = (verts.cpu().detach()).numpy().tolist() params = {} + params["label"] = label params["pose_params"] = pose_params params["shape_params"] = shape_params params["Jtr"] = Jtr - params["mesh"] = verts - f = open(os.path.join((fit_path), - "{}_params.json".format(file_name)), 'w') - json.dump(params, f) + with open(os.path.join((fit_path), + "{}_params.pkl".format(file_name)), 'wb') as f: + pickle.dump(params, f)