From 16387a7afe28d382cf18f02aecbb4f5a89c85929 Mon Sep 17 00:00:00 2001
From: Iridoudou <2534936416@qq.com>
Date: Fri, 3 Sep 2021 11:06:26 +0800
Subject: [PATCH] support NTU

---
 README.md              |   1 +
 fit/configs/NTU.json   |  94 ++++++++++++++++++++++++++++++
 fit/tools/label.py     | 129 ++++++++++++++++++++++++++++++++++++++++-
 fit/tools/load.py      |   2 +
 fit/tools/main.py      |   3 +-
 fit/tools/save.py      |  21 ++-----
 fit/tools/train.py     |   4 +-
 fit/tools/transform.py |   5 +-
 8 files changed, 236 insertions(+), 23 deletions(-)
 create mode 100644 fit/configs/NTU.json

diff --git a/README.md b/README.md
index 593f4eb..d0695f2 100644
--- a/README.md
+++ b/README.md
@@ -46,6 +46,7 @@ The SMPL human body layer for Pytorch is from the [smplpytorch](https://github.c
   - [CMU Mocap](https://ericguo5513.github.io/action-to-motion/)
   - [UTD-MHAD](https://personal.utdallas.edu/~kehtar/UTD-MHAD.html)
   - [Human3.6M](http://vision.imar.ro/human3.6m/description.php)
+  - [NTU RGB+D](https://rose1.ntu.edu.sg/dataset/actionRecognition/)
 
 - Set the **DATASET.PATH** in the corresponding configuration file to the location of dataset.
 
diff --git a/fit/configs/NTU.json b/fit/configs/NTU.json
new file mode 100644
index 0000000..ca5bc81
--- /dev/null
+++ b/fit/configs/NTU.json
@@ -0,0 +1,94 @@
+{
+    "MODEL": {
+        "GENDER": "neutral"
+    },
+    "TRAIN": {
+        "LEARNING_RATE": 5e-2,
+        "MAX_EPOCH": 1000,
+        "WRITE": 10,
+        "OPTIMIZE_SCALE": 0,
+        "OPTIMIZE_SHAPE": 1
+    },
+    "USE_GPU": 1,
+    "DATASET": {
+        "NAME": "NTU",
+        "PATH": "../NTU RGB+D/skeleton_npy",
+        "TARGET_PATH": "",
+        "DATA_MAP": [
+            [
+                0,
+                0
+            ],
+            [
+                1,
+                12
+            ],
+            [
+                2,
+                16
+            ],
+            [
+                4,
+                13
+            ],
+            [
+                5,
+                17
+            ],
+            [
+                6,
+                1
+            ],
+            [
+                7,
+                14
+            ],
+            [
+                8,
+                18
+            ],
+            [
+                9,
+                20
+            ],
+
+            [
+                12,
+                2
+            ],
+            [
+                13,
+                4
+            ],
+            [
+                14,
+                8
+            ],
+            [
+                18,
+                5
+            ],
+            [
+                19,
+                9
+            ],
+            [
+                20,
+                6
+            ],
+            [
+                21,
+                10
+            ],
+            [
+                22,
+                22
+            ],
+            [
+                23,
+                24
+            ]
+        ]
+    },
+    "DEBUG": 0
+}
\ No newline at end of file
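Note on the new config: each DATA_MAP pair appears to map a SMPL joint index (first value, 0-23) to the index of the matching joint in the NTU RGB+D skeleton (second value, 0-24; NTU skeletons come from Kinect v2, which tracks 25 joints per body). A minimal sketch of how such a mapping could be consumed, assuming the skeleton is a NumPy array of shape (n_frames, 25, 3); the helper name and shapes are illustrative, not part of the patch:

    import numpy as np

    def select_mapped_joints(skeleton: np.ndarray, data_map):
        # skeleton: assumed shape (n_frames, 25, 3); data_map: [[smpl_idx, ntu_idx], ...]
        smpl_idx = [pair[0] for pair in data_map]   # SMPL joints that receive a target
        ntu_idx = [pair[1] for pair in data_map]    # matching Kinect v2 joints
        return smpl_idx, skeleton[:, ntu_idx, :]    # targets reordered to SMPL order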
diff --git a/fit/tools/label.py b/fit/tools/label.py
index 6b842e3..5672c01 100644
--- a/fit/tools/label.py
+++ b/fit/tools/label.py
@@ -1156,6 +1156,128 @@ CMU_Mocap = {
     "41_09": "Climb"
 }
 
+NTU = {
+    "1": "drink water",
+    "2": "eat meal/snack",
+    "3": "brushing teeth",
+    "4": "brushing hair",
+    "5": "drop",
+    "6": "pickup",
+    "7": "throw",
+    "8": "sitting down",
+    "9": "standing up (from sitting position)",
+    "10": "clapping",
+    "11": "reading",
+    "12": "writing",
+    "13": "tear up paper",
+    "14": "wear jacket",
+    "15": "take off jacket",
+    "16": "wear a shoe",
+    "17": "take off a shoe",
+    "18": "wear on glasses",
+    "19": "take off glasses",
+    "20": "put on a hat/cap",
+    "21": "take off a hat/cap",
+    "22": "cheer up",
+    "23": "hand waving",
+    "24": "kicking something",
+    "25": "reach into pocket",
+    "26": "hopping (one foot jumping)",
+    "27": "jump up",
+    "28": "make a phone call/answer phone",
+    "29": "playing with phone/tablet",
+    "30": "typing on a keyboard",
+    "31": "pointing to something with finger",
+    "32": "taking a selfie",
+    "33": "check time (from watch)",
+    "34": "rub two hands together",
+    "35": "nod head/bow",
+    "36": "shake head",
+    "37": "wipe face",
+    "38": "salute",
+    "39": "put the palms together",
+    "40": "cross hands in front (say stop)",
+    "41": "sneeze/cough",
+    "42": "staggering",
+    "43": "falling",
+    "44": "touch head (headache)",
+    "45": "touch chest (stomachache/heart pain)",
+    "46": "touch back (backache)",
+    "47": "touch neck (neckache)",
+    "48": "nausea or vomiting condition",
+    "49": "use a fan (with hand or paper)/feeling warm",
+    "50": "punching/slapping other person",
+    "51": "kicking other person",
+    "52": "pushing other person",
+    "53": "pat on back of other person",
+    "54": "point finger at the other person",
+    "55": "hugging other person",
+    "56": "giving something to other person",
+    "57": "touch other person's pocket",
+    "58": "handshaking",
+    "59": "walking towards each other",
+    "60": "walking apart from each other",
+    "61": "put on headphone",
+    "62": "take off headphone",
+    "63": "shoot at the basket",
+    "64": "bounce ball",
+    "65": "tennis bat swing",
+    "66": "juggling table tennis balls",
+    "67": "hush (quiet)",
+    "68": "flick hair",
+    "69": "thumb up",
+    "70": "thumb down",
+    "71": "make ok sign",
+    "72": "make victory sign",
+    "73": "staple book",
+    "74": "counting money",
+    "75": "cutting nails",
+    "76": "cutting paper (using scissors)",
+    "77": "snapping fingers",
+    "78": "open bottle",
+    "79": "sniff (smell)",
+    "80": "squat down",
+    "81": "toss a coin",
+    "82": "fold paper",
+    "83": "ball up paper",
+    "84": "play magic cube",
+    "85": "apply cream on face",
+    "86": "apply cream on hand back",
+    "87": "put on bag",
+    "88": "take off bag",
+    "89": "put something into a bag",
+    "90": "take something out of a bag",
+    "91": "open a box",
+    "92": "move heavy objects",
+    "93": "shake fist",
+    "94": "throw up cap/hat",
+    "95": "hands up (both hands)",
+    "96": "cross arms",
+    "97": "arm circles",
+    "98": "arm swings",
+    "99": "running on the spot",
+    "100": "butt kicks (kick backward)",
+    "101": "cross toe touch",
+    "102": "side kick",
+    "103": "yawn",
+    "104": "stretch oneself",
+    "105": "blow nose",
+    "106": "hit other person with something",
+    "107": "wield knife towards other person",
+    "108": "knock over other person (hit with body)",
+    "109": "grab other person's stuff",
+    "110": "shoot at other person with a gun",
+    "111": "step on foot",
+    "112": "high-five",
+    "113": "cheers and drink",
+    "114": "carry something with other person",
+    "115": "take a photo of other person",
+    "116": "follow other person",
+    "117": "whisper in other person's ear",
+    "118": "exchange things with other person",
+    "119": "support somebody with hand",
+    "120": "finger-guessing game (playing rock-paper-scissors)",
+}
 
 def get_label(file_name, dataset_name):
     if dataset_name == 'HumanAct12':
@@ -1165,5 +1287,8 @@ def get_label(file_name, dataset_name):
         key = file_name.split('_')[0][1:]
         return UTD_MHAD[key]
     elif dataset_name == 'CMU_Mocap':
-        key = file_name.split('.')[0]
-        return CMU_Mocap[key] if key in CMU_Mocap.keys() else ""
\ No newline at end of file
+        key = file_name.split(':')[0]
+        return CMU_Mocap[key] if key in CMU_Mocap.keys() else ""
+    elif dataset_name == 'NTU':
+        key = str(int(file_name[-3:]))
+        return NTU[key]
\ No newline at end of file
diff --git a/fit/tools/load.py b/fit/tools/load.py
index 4d72af0..b546515 100644
--- a/fit/tools/load.py
+++ b/fit/tools/load.py
@@ -18,4 +18,6 @@ def load(name, path):
         return np.load(path, allow_pickle=True)
     elif name == "Human3.6M":
         return np.load(path, allow_pickle=True)[0::5]  # down_sample
+    elif name == "NTU":
+        return np.load(path, allow_pickle=True)
 
diff --git a/fit/tools/main.py b/fit/tools/main.py
index 644b2e2..d76e175 100644
--- a/fit/tools/main.py
+++ b/fit/tools/main.py
@@ -114,7 +114,8 @@ if __name__ == "__main__":
             meters.reset_early_stop()
         logger.info("avg_loss:{:.4f}".format(meters.avg))
 
-        # save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
         save_params(res, file, logger, args.dataset_name)
+        save_pic(res, smpl_layer, file, logger, args.dataset_name, target)
+
         torch.cuda.empty_cache()
     logger.info("Fitting finished! Average loss: {:.9f}".format(meters.avg))
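The new NTU branch of get_label recovers the action id from the file name: NTU RGB+D skeleton files follow the pattern SsssCcccPpppRrrrAaaa (setup, camera, performer, replication, action), so the last three characters of the stem encode the action class. A small illustration, assuming the converted .npy files keep that naming; the file name below is made up:

    file_name = "S001C001P001R001A050"   # hypothetical NTU file stem
    key = str(int(file_name[-3:]))       # "050" -> 50 -> "50"
    # NTU[key] then yields "punching/slapping other person"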
diff --git a/fit/tools/save.py b/fit/tools/save.py
index 97894aa..41e7103 100644
--- a/fit/tools/save.py
+++ b/fit/tools/save.py
@@ -1,3 +1,5 @@
+from display_utils import display_model
+from label import get_label
 import sys
 import os
 import re
@@ -6,8 +8,6 @@ import numpy as np
 import pickle
 
 sys.path.append(os.getcwd())
-from label import get_label
-from display_utils import display_model
 
 
 def create_dir_not_exist(path):
@@ -18,10 +18,8 @@ def create_dir_not_exist(path):
 def save_pic(res, smpl_layer, file, logger, dataset_name, target):
     _, _, verts, Jtr = res
     file_name = re.split('[/.]', file)[-2]
-    fit_path = "fit/output/{}/picture/fit/{}".format(dataset_name, file_name)
-    # gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name, file_name)
+    fit_path = "fit/output/{}/picture/{}".format(dataset_name, file_name)
     create_dir_not_exist(fit_path)
-    # create_dir_not_exist(gt_path)
     logger.info('Saving pictures at {}'.format(fit_path))
     for i in tqdm(range(Jtr.shape[0])):
         display_model(
@@ -32,18 +30,8 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
             kintree_table=smpl_layer.kintree_table,
             savepath=os.path.join(fit_path+"/frame_{}".format(i)),
             batch_idx=i,
-            show=True,
+            show=False,
             only_joint=True)
-        # display_model(
-        #     {'verts': verts.cpu().detach(),
-        #      'joints': target.cpu().detach()},
-        #     model_faces=smpl_layer.th_faces,
-        #     with_joints=True,
-        #     kintree_table=smpl_layer.kintree_table,
-        #     savepath=os.path.join(gt_path+"/frame_{}".format(i)),
-        #     batch_idx=i,
-        #     show=False,
-        #     only_joint=True)
     logger.info('Pictures saved')
 
 
@@ -63,6 +51,7 @@ def save_params(res, file, logger, dataset_name):
     params["pose_params"] = pose_params
     params["shape_params"] = shape_params
     params["Jtr"] = Jtr
+    print("label:{}".format(label))
     with open(os.path.join((fit_path), "{}_params.pkl".format(file_name)), 'wb') as f:
         pickle.dump(params, f)
 
diff --git a/fit/tools/train.py b/fit/tools/train.py
index 1e8f855..68f27a3 100644
--- a/fit/tools/train.py
+++ b/fit/tools/train.py
@@ -68,8 +68,8 @@ def train(smpl_layer, target,
             break
 
         if epoch % cfg.TRAIN.WRITE == 0:
-            # logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f} EarlyStopSatis: {}".format(
-            #     epoch, float(loss),float(scale), early_stop.satis_num))
+            # logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
+            #     epoch, float(loss),float(scale)))
             writer.add_scalar('loss', float(loss), epoch)
             writer.add_scalar('learning_rate', float(
                 optimizer.state_dict()['param_groups'][0]['lr']), epoch)
diff --git a/fit/tools/transform.py b/fit/tools/transform.py
index c36317b..d4e5c6b 100644
--- a/fit/tools/transform.py
+++ b/fit/tools/transform.py
@@ -4,7 +4,8 @@ rotate = {
     'HumanAct12': [1., -1., -1.],
     'CMU_Mocap': [0.05, 0.05, 0.05],
     'UTD_MHAD': [-1., 1., -1.],
-    'Human3.6M': [-0.001, -0.001, 0.001]
+    'Human3.6M': [-0.001, -0.001, 0.001],
+    'NTU': [-1., 1., -1.]
 }
 
 
@@ -15,4 +16,4 @@ def transform(name, arr: np.ndarray):
             arr[i][j] -= origin
             for k in range(3):
                 arr[i][j][k] *= rotate[name][k]
-    return arr
\ No newline at end of file
+    return arr
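The 'NTU' entry added to rotate reuses the UTD_MHAD values [-1., 1., -1.], which is plausible since both datasets are captured with Kinect sensors: transform() subtracts an origin joint (computed outside this hunk) from every joint and then negates the first and third coordinates to bring the skeleton into the frame the SMPL fitting expects. A rough sketch of the preprocessing path for one NTU sequence, assuming the modules are importable as in fit/tools and the file layout matches DATASET.PATH in NTU.json; the file name is illustrative:

    from load import load
    from transform import transform
    from label import get_label

    path = "../NTU RGB+D/skeleton_npy/S001C001P001R001A050.npy"   # assumed layout
    joints = load("NTU", path)          # raw skeleton array via the new branch in load()
    joints = transform("NTU", joints)   # re-centre each frame, flip two axes
    label = get_label("S001C001P001R001A050", "NTU")  # -> "punching/slapping other person"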