From 6ac10e22e051df851f06254f4c35863c8df9cbc6 Mon Sep 17 00:00:00 2001
From: Yiming Dou
Date: Sat, 1 Oct 2022 21:42:12 +0800
Subject: [PATCH] Support HAA4D

---
 README.md              |   1 +
 demo.py                |   5 +-
 display_utils.py       |   4 +-
 fit/configs/HAA4D.json |  24 ++++
 fit/configs/NTU.json   |  29 +++--
 fit/tools/label.py     | 242 ++++++++++++++++++++---------------------
 fit/tools/load.py      |   2 +
 fit/tools/main.py      |  43 ++++----
 fit/tools/save.py      |  22 +++-
 fit/tools/train.py     |  21 +++-
 fit/tools/transform.py |   3 +-
 make_gif.py            |   6 +-
 12 files changed, 232 insertions(+), 170 deletions(-)
 create mode 100644 fit/configs/HAA4D.json

diff --git a/README.md b/README.md
index 904c272..5b8ac30 100644
--- a/README.md
+++ b/README.md
@@ -47,6 +47,7 @@ The SMPL human body layer for Pytorch is from the [smplpytorch](https://github.c
   - [UTD-MHAD](https://personal.utdallas.edu/~kehtar/UTD-MHAD.html)
   - [Human3.6M](http://vision.imar.ro/human3.6m/description.php)
   - [NTU](https://rose1.ntu.edu.sg/dataset/actionRecognition/)
+  - [HAA4D](https://cse.hkust.edu.hk/haa4d/dataset.html)
 - Set the **DATASET.PATH** in the corresponding configuration file to the location of dataset.

diff --git a/demo.py b/demo.py
index d36bce4..1b4c964 100644
--- a/demo.py
+++ b/demo.py
@@ -1,4 +1,6 @@
 import torch
+import random
+import numpy as np
 from smplpytorch.pytorch.smpl_layer import SMPL_Layer
 from display_utils import display_model

@@ -15,7 +17,7 @@ if __name__ == '__main__':
         model_root='smplpytorch/native/models')

     # Generate random pose and shape parameters
-    pose_params = torch.rand(batch_size, 72) * 0.2
+    pose_params = torch.rand(batch_size, 72) * 0.01
     shape_params = torch.rand(batch_size, 10) * 0.03

     # GPU mode
@@ -26,7 +28,6 @@

     # Forward from the SMPL layer
     verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
-    print(Jtr)

     # Draw output vertices and joints
     display_model(
diff --git a/display_utils.py b/display_utils.py
index ce06d41..e7d857a 100644
--- a/display_utils.py
+++ b/display_utils.py
@@ -1,3 +1,4 @@
+from xml.parsers.expat import model
 from matplotlib import pyplot as plt
 from mpl_toolkits.mplot3d import Axes3D
 from mpl_toolkits.mplot3d.art3d import Poly3DCollection
@@ -21,7 +22,8 @@ def display_model(
     if ax is None:
         fig = plt.figure()
         ax = fig.add_subplot(111, projection='3d')
-    verts, joints = model_info['verts'][batch_idx], model_info['joints'][batch_idx]
+    verts = model_info['verts'][batch_idx]
+    joints = model_info['joints'][batch_idx]
     if model_faces is None:
         ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2)
     elif not only_joint:
diff --git a/fit/configs/HAA4D.json b/fit/configs/HAA4D.json
new file mode 100644
index 0000000..702ce06
--- /dev/null
+++ b/fit/configs/HAA4D.json
@@ -0,0 +1,24 @@
+{
+    "MODEL": {
+        "GENDER": "neutral"
+    },
+    "TRAIN": {
+        "LEARNING_RATE": 1e-2,
+        "MAX_EPOCH": 1000,
+        "WRITE": 10,
+        "OPTIMIZE_SCALE": 1,
+        "OPTIMIZE_SHAPE": 1
+    },
+    "USE_GPU": 1,
+    "DATASET": {
+        "NAME": "HAA4D",
+        "PATH": "../NTU RGB+D/result",
+        "TARGET_PATH": "",
+        "DATA_MAP": [
+            [0,0],[1,4],[2,1],[4,5],[5,2],[7,6],[8,3],
+            [12,9],[18,12],[19,15],[20,13],[21,16],
+            [15,10],[6,1]
+        ]
+    },
+    "DEBUG": 0
+}
\ No newline at end of file
diff --git a/fit/configs/NTU.json b/fit/configs/NTU.json
index 32a73b5..457659c 100644
--- a/fit/configs/NTU.json
+++ b/fit/configs/NTU.json
@@ -20,19 +20,19 @@
         0
       ],
       [
-        1,
+        2,
         12
       ],
      [
-        2,
+        1,
         16
      ],
      [
-        4,
+        5,
         13
      ],
      [
-        5,
+        4,
         17
      ],
      [
@@ -40,52 +40,51 @@
        6,
        1
      ],
      [
-        7,
+        8,
        14
      ],
      [
-        8,
+        7,
        18
      ],
      [
        9,
        20
      ],
-
      [
        12,
        2
      ],
      [
-        13,
+        14,
        4
      ],
      [
-        14,
+        13,
        8
      ],
      [
-        18,
+        19,
        5
      ],
      [
-        19,
+        18,
        9
      ],
      [
-        20,
+        21,
        6
      ],
      [
-        21,
+        20,
        10
      ],
      [
-        22,
+        23,
        22
      ],
      [
-        23,
+        22,
        24
      ]
    ]
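A note on the DATA_MAP edits above: each pair is (SMPL joint index, dataset joint index), and this hunk swaps several left/right indices in the NTU-to-SMPL correspondence. The fitter consumes the two columns as index lists and compares only the matched joints (see the fit/tools/train.py hunk further down). A minimal sketch of that lookup, with illustrative tensor shapes rather than values taken from the repo:

import torch
import torch.nn.functional as F

# Truncated map for illustration: (smpl_idx, dataset_idx) pairs.
data_map = [[0, 0], [2, 12], [1, 16]]
smpl_index = torch.tensor([m[0] for m in data_map])
dataset_index = torch.tensor([m[1] for m in data_map])

Jtr = torch.rand(8, 24, 3)     # SMPL joints: (frames, 24, 3)
target = torch.rand(8, 25, 3)  # NTU skeleton: (frames, 25, 3)

# Only mapped joints enter the loss; a swapped index pair silently
# fits the wrong limb, which is what this hunk corrects.
loss = F.smooth_l1_loss(Jtr.index_select(1, smpl_index),
                        target.index_select(1, dataset_index))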
diff --git a/fit/tools/label.py b/fit/tools/label.py
index 5672c01..40f7985 100644
--- a/fit/tools/label.py
+++ b/fit/tools/label.py
@@ -1156,127 +1156,127 @@ CMU_Mocap = {
     "41_09": "Climb"
 }

-NTU={
-    "1":" drink water",
-    "2":" eat meal/snack",
-    "3":" brushing teeth",
-    "4":" brushing hair",
-    "5":" drop",
-    "6":" pickup",
-    "7":" throw",
-    "8":" sitting down",
-    "9":" standing up (from sitting position)",
-    "10":" clapping",
-    "11":" reading",
-    "12":" writing",
-    "13":" tear up paper",
-    "14":" wear jacket",
-    "15":" take off jacket",
-    "16":" wear a shoe",
-    "17":" take off a shoe",
-    "18":" wear on glasses",
-    "19":" take off glasses",
-    "20":" put on a hat/cap",
-    "21":" take off a hat/cap",
-    "22":" cheer up",
-    "23":" hand waving",
-    "24":" kicking something",
-    "25":" reach into pocket",
-    "26":" hopping (one foot jumping)",
-    "27":" jump up",
-    "28":" make a phone call/answer phone",
-    "29":" playing with phone/tablet",
-    "30":" typing on a keyboard",
-    "31":" pointing to something with finger",
-    "32":" taking a selfie",
-    "33":" check time (from watch)",
-    "34":" rub two hands together",
-    "35":" nod head/bow",
-    "36":" shake head",
-    "37":" wipe face",
-    "38":" salute",
-    "39":" put the palms together",
-    "40":" cross hands in front (say stop)",
-    "41":" sneeze/cough",
-    "42":" staggering",
-    "43":" falling",
-    "44":" touch head (headache)",
-    "45":" touch chest (stomachache/heart pain)",
-    "46":" touch back (backache)",
-    "47":" touch neck (neckache)",
-    "48":" nausea or vomiting condition",
-    "49":" use a fan (with hand or paper)/feeling warm",
-    "50":" punching/slapping other person",
-    "51":" kicking other person",
-    "52":" pushing other person",
-    "53":" pat on back of other person",
-    "54":" point finger at the other person",
-    "55":" hugging other person",
-    "56":" giving something to other person",
-    "57":" touch other person's pocket",
-    "58":" handshaking",
-    "59":" walking towards each other",
-    "60":" walking apart from each other",
-    "61":" put on headphone",
-    "62":" take off headphone",
-    "63":" shoot at the basket",
-    "64":" bounce ball",
-    "65":" tennis bat swing",
-    "66":" juggling table tennis balls",
-    "67":" hush (quite)",
-    "68":" flick hair",
-    "69":" thumb up",
-    "70":" thumb down",
-    "71":" make ok sign",
-    "72":" make victory sign",
-    "73":" staple book",
-    "74":" counting money",
-    "75":" cutting nails",
-    "76":" cutting paper (using scissors)",
-    "77":" snapping fingers",
-    "78":" open bottle",
-    "79":" sniff (smell)",
-    "80":" squat down",
-    "81":" toss a coin",
-    "82":" fold paper",
-    "83":" ball up paper",
-    "84":" play magic cube",
-    "85":" apply cream on face",
-    "86":" apply cream on hand back",
-    "87":" put on bag",
-    "88":" take off bag",
-    "89":" put something into a bag",
-    "90":" take something out of a bag",
-    "91":" open a box",
-    "92":" move heavy objects",
-    "93":" shake fist",
-    "94":" throw up cap/hat",
-    "95":" hands up (both hands)",
-    "96":" cross arms",
-    "97":" arm circles",
-    "98":" arm swings",
-    "99":" running on the spot",
-    "100":" butt kicks (kick backward)",
-    "101":" cross toe touch",
-    "102":" side kick",
-    "103":" yawn",
-    "104":" stretch oneself",
-    "105":" blow nose",
-    "106":" hit other person with something",
-    "107":" wield knife towards other person",
-    "108":" knock over other person (hit with body)",
-    "109":" grab other person’s stuff",
-    "110":" shoot at other person with a gun",
-    "111":" step on foot",
"112":" high-five", - "113":" cheers and drink", - "114":" carry something with other person", - "115":" take a photo of other person", - "116":" follow other person", - "117":" whisper in other person’s ear", - "118":" exchange things with other person", - "119":" support somebody with hand", - "120":" finger-guessing game (playing rock-paper-scissors)", +NTU = { + "1":"drink water", + "2":"eat meal/snack", + "3":"brushing teeth", + "4":"brushing hair", + "5":"drop", + "6":"pickup", + "7":"throw", + "8":"sitting down", + "9":"standing up (from sitting position)", + "10":"clapping", + "11":"reading", + "12":"writing", + "13":"tear up paper", + "14":"wear jacket", + "15":"take off jacket", + "16":"wear a shoe", + "17":"take off a shoe", + "18":"wear on glasses", + "19":"take off glasses", + "20":"put on a hat/cap", + "21":"take off a hat/cap", + "22":"cheer up", + "23":"hand waving", + "24":"kicking something", + "25":"reach into pocket", + "26":"hopping (one foot jumping)", + "27":"jump up", + "28":"make a phone call/answer phone", + "29":"playing with phone/tablet", + "30":"typing on a keyboard", + "31":"pointing to something with finger", + "32":"taking a selfie", + "33":"check time (from watch)", + "34":"rub two hands together", + "35":"nod head/bow", + "36":"shake head", + "37":"wipe face", + "38":"salute", + "39":"put the palms together", + "40":"cross hands in front (say stop)", + "41":"sneeze/cough", + "42":"staggering", + "43":"falling", + "44":"touch head (headache)", + "45":"touch chest (stomachache/heart pain)", + "46":"touch back (backache)", + "47":"touch neck (neckache)", + "48":"nausea or vomiting condition", + "49":"use a fan (with hand or paper)/feeling warm", + "50":"punching/slapping other person", + "51":"kicking other person", + "52":"pushing other person", + "53":"pat on back of other person", + "54":"point finger at the other person", + "55":"hugging other person", + "56":"giving something to other person", + "57":"touch other person's pocket", + "58":"handshaking", + "59":"walking towards each other", + "60":"walking apart from each other", + "61":"put on headphone", + "62":"take off headphone", + "63":"shoot at the basket", + "64":"bounce ball", + "65":"tennis bat swing", + "66":"juggling table tennis balls", + "67":"hush (quite)", + "68":"flick hair", + "69":"thumb up", + "70":"thumb down", + "71":"make ok sign", + "72":"make victory sign", + "73":"staple book", + "74":"counting money", + "75":"cutting nails", + "76":"cutting paper (using scissors)", + "77":"snapping fingers", + "78":"open bottle", + "79":"sniff (smell)", + "80":"squat down", + "81":"toss a coin", + "82":"fold paper", + "83":"ball up paper", + "84":"play magic cube", + "85":"apply cream on face", + "86":"apply cream on hand back", + "87":"put on bag", + "88":"take off bag", + "89":"put something into a bag", + "90":"take something out of a bag", + "91":"open a box", + "92":"move heavy objects", + "93":"shake fist", + "94":"throw up cap/hat", + "95":"hands up (both hands)", + "96":"cross arms", + "97":"arm circles", + "98":"arm swings", + "99":"running on the spot", + "100":"butt kicks (kick backward)", + "101":"cross toe touch", + "102":"side kick", + "103":"yawn", + "104":"stretch oneself", + "105":"blow nose", + "106":"hit other person with something", + "107":"wield knife towards other person", + "108":"knock over other person (hit with body)", + "109":"grab other person’s stuff", + "110":"shoot at other person with a gun", + "111":"step on foot", + "112":"high-five", + "113":"cheers and 
drink", + "114":"carry something with other person", + "115":"take a photo of other person", + "116":"follow other person", + "117":"whisper in other person’s ear", + "118":"exchange things with other person", + "119":"support somebody with hand", + "120":"finger-guessing game (playing rock-paper-scissors)", } def get_label(file_name, dataset_name): diff --git a/fit/tools/load.py b/fit/tools/load.py index b546515..9d34aee 100644 --- a/fit/tools/load.py +++ b/fit/tools/load.py @@ -19,5 +19,7 @@ def load(name, path): elif name == "Human3.6M": return np.load(path, allow_pickle=True)[0::5] # down_sample elif name == "NTU": + return np.load(path, allow_pickle=True)[0::2] + elif name == "HAA4D": return np.load(path, allow_pickle=True) diff --git a/fit/tools/main.py b/fit/tools/main.py index 8eadb3a..3a4715a 100644 --- a/fit/tools/main.py +++ b/fit/tools/main.py @@ -1,32 +1,33 @@ +import os +import sys +sys.path.append(os.getcwd()) +from meters import Meters +from smplpytorch.pytorch.smpl_layer import SMPL_Layer +from train import train +from transform import transform +from save import save_pic, save_params +from load import load import torch import numpy as np from tensorboardX import SummaryWriter from easydict import EasyDict as edict import time -import sys -import os import logging import argparse import json -sys.path.append(os.getcwd()) -from load import load -from save import save_pic, save_params -from transform import transform -from train import train -from smplpytorch.pytorch.smpl_layer import SMPL_Layer -from meters import Meters torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True + def parse_args(): parser = argparse.ArgumentParser(description='Fit SMPL') parser.add_argument('--exp', dest='exp', help='Define exp name', default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str) - parser.add_argument('--dataset_name', dest='dataset_name', + parser.add_argument('--dataset_name', '-n', dest='dataset_name', help='select dataset', default='', type=str) parser.add_argument('--dataset_path', dest='dataset_path', @@ -97,15 +98,18 @@ if __name__ == "__main__": center_idx=0, gender=cfg.MODEL.GENDER, model_root='smplpytorch/native/models') - - meters=Meters() + + meters = Meters() file_num = 0 for root, dirs, files in os.walk(cfg.DATASET.PATH): - for file in files: + for file in sorted(files): + if not 'baseball_swing' in file: + continue file_num += 1 - logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files))) - target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name, - os.path.join(root, file)))).float() + logger.info( + 'Processing file: {} [{} / {}]'.format(file, file_num, len(files))) + target = torch.from_numpy(transform(args.dataset_name, load(args.dataset_name, + os.path.join(root, file)))).float() logger.info("target shape:{}".format(target.shape)) res = train(smpl_layer, target, logger, writer, device, @@ -115,7 +119,8 @@ if __name__ == "__main__": logger.info("avg_loss:{:.4f}".format(meters.avg)) save_params(res, file, logger, args.dataset_name) - # save_pic(res,smpl_layer,file,logger,args.dataset_name,target) - + save_pic(res, smpl_layer, file, logger, args.dataset_name, target) + torch.cuda.empty_cache() - logger.info("Fitting finished! Average loss: {:.9f}".format(meters.avg)) + logger.info( + "Fitting finished! 
diff --git a/fit/tools/save.py b/fit/tools/save.py
index 41e7103..05d97a1 100644
--- a/fit/tools/save.py
+++ b/fit/tools/save.py
@@ -19,7 +19,7 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
     _, _, verts, Jtr = res
     file_name = re.split('[/.]', file)[-2]
     fit_path = "fit/output/{}/picture/{}".format(dataset_name, file_name)
-    create_dir_not_exist(fit_path)
+    os.makedirs(fit_path, exist_ok=True)
     logger.info('Saving pictures at {}'.format(fit_path))
     for i in tqdm(range(Jtr.shape[0])):
         display_model(
@@ -28,7 +28,7 @@
             model_faces=smpl_layer.th_faces,
             with_joints=True,
             kintree_table=smpl_layer.kintree_table,
-            savepath=os.path.join(fit_path+"/frame_{}".format(i)),
+            savepath=os.path.join(fit_path+"/frame_{:0>4d}".format(i)),
             batch_idx=i,
             show=False,
             only_joint=True)
@@ -55,3 +55,21 @@ def save_params(res, file, logger, dataset_name):

     with open(os.path.join((fit_path), "{}_params.pkl".format(file_name)), 'wb') as f:
         pickle.dump(params, f)
+
+
+def save_single_pic(res, smpl_layer, epoch, logger, dataset_name, target):
+    _, _, verts, Jtr = res
+    fit_path = "fit/output/{}/picture".format(dataset_name)
+    create_dir_not_exist(fit_path)
+    logger.info('Saving pictures at {}'.format(fit_path))
+    display_model(
+        {'verts': verts.cpu().detach(),
+         'joints': Jtr.cpu().detach()},
+        model_faces=smpl_layer.th_faces,
+        with_joints=True,
+        kintree_table=smpl_layer.kintree_table,
+        savepath=fit_path+"/epoch_{:0>4d}".format(epoch),
+        batch_idx=60,
+        show=False,
+        only_joint=False)
+    logger.info('Picture saved')
\ No newline at end of file
diff --git a/fit/tools/train.py b/fit/tools/train.py
index 68f27a3..827c7db 100644
--- a/fit/tools/train.py
+++ b/fit/tools/train.py
@@ -6,6 +6,7 @@ import os
 from tqdm import tqdm

 sys.path.append(os.getcwd())
+from save import save_single_pic


@@ -25,8 +26,10 @@ def init(smpl_layer, target, device, cfg):
     params["shape_params"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SHAPE)
     params["scale"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SCALE)

-    optimizer = optim.Adam([params["pose_params"], params["shape_params"], params["scale"]],
-                           lr=cfg.TRAIN.LEARNING_RATE)
+    optim_params = [{'params': params["pose_params"], 'lr': cfg.TRAIN.LEARNING_RATE},
+                    {'params': params["shape_params"], 'lr': cfg.TRAIN.LEARNING_RATE},
+                    {'params': params["scale"], 'lr': cfg.TRAIN.LEARNING_RATE*10},]
+    optimizer = optim.Adam(optim_params)

     index = {}
     smpl_index = []
@@ -50,12 +53,15 @@
     pose_params = params["pose_params"]
     shape_params = params["shape_params"]
     scale = params["scale"]
+
+    with torch.no_grad():
+        verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
+        params["scale"]*=(torch.max(torch.abs(target))/torch.max(torch.abs(Jtr)))

     for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
-    # for epoch in range(cfg.TRAIN.MAX_EPOCH):
         verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
-        loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100,
-                                target.index_select(1, index["dataset_index"]) * 100 * scale)
+        loss = F.smooth_l1_loss(scale*Jtr.index_select(1, index["smpl_index"]),
+                                target.index_select(1, index["dataset_index"]))
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
@@ -67,12 +73,15 @@
             logger.info("Early stop at epoch {} !".format(epoch))
             break

-        if epoch % cfg.TRAIN.WRITE == 0:
+        if epoch % cfg.TRAIN.WRITE == 0 or epoch<10:
             # logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
             #     epoch, float(loss),float(scale)))
+            print("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
+                epoch, float(loss),float(scale)))
             writer.add_scalar('loss', float(loss), epoch)
             writer.add_scalar('learning_rate', float(
                 optimizer.state_dict()['param_groups'][0]['lr']), epoch)
+    # save_single_pic(res,smpl_layer,epoch,logger,args.dataset_name,target)

     logger.info('Train ended, min_loss = {:.4f}'.format(
         float(meters.min_loss)))
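Two changes in the train.py hunks above are worth spelling out. First, Adam now takes parameter groups, so the global scale can move at a 10x learning rate while pose and shape keep the base rate. Second, before the optimization loop the scale is warm-started in closed form from the extent ratio between the target skeleton and the SMPL joints, and the loss now multiplies scale into the SMPL side instead of the target side (dropping the old *100 amplification). A sketch with illustrative shapes, not the repo's actual tensors:

import torch
from torch import optim

pose_params = torch.zeros(1, 72, requires_grad=True)
shape_params = torch.zeros(1, 10, requires_grad=True)
scale = torch.ones(1, requires_grad=True)
base_lr = 1e-2  # LEARNING_RATE in the configs

# Parameter groups: the scale gets a 10x higher learning rate.
optimizer = optim.Adam([
    {'params': [pose_params], 'lr': base_lr},
    {'params': [shape_params], 'lr': base_lr},
    {'params': [scale], 'lr': base_lr * 10},
])

# Closed-form warm start for scale, done outside the autograd graph
# so no epochs are spent just matching the overall size.
Jtr = torch.rand(8, 24, 3)     # stand-in for the SMPL joint output
target = torch.rand(8, 25, 3)  # stand-in for the dataset skeleton
with torch.no_grad():
    scale *= target.abs().max() / Jtr.abs().max()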
scale={:.4f}".format( # epoch, float(loss),float(scale))) + print("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format( + epoch, float(loss),float(scale))) writer.add_scalar('loss', float(loss), epoch) writer.add_scalar('learning_rate', float( optimizer.state_dict()['param_groups'][0]['lr']), epoch) + # save_single_pic(res,smpl_layer,epoch,logger,args.dataset_name,target) logger.info('Train ended, min_loss = {:.4f}'.format( float(meters.min_loss))) diff --git a/fit/tools/transform.py b/fit/tools/transform.py index d4e5c6b..d4c0975 100644 --- a/fit/tools/transform.py +++ b/fit/tools/transform.py @@ -5,7 +5,8 @@ rotate = { 'CMU_Mocap': [0.05, 0.05, 0.05], 'UTD_MHAD': [-1., 1., -1.], 'Human3.6M': [-0.001, -0.001, 0.001], - 'NTU': [-1., 1., -1.] + 'NTU': [1., 1., -1.], + 'HAA4D': [1., -1., -1.], } diff --git a/make_gif.py b/make_gif.py index 2d022d9..b7f8766 100644 --- a/make_gif.py +++ b/make_gif.py @@ -1,7 +1,7 @@ import matplotlib.pyplot as plt import imageio, os images = [] -filenames = sorted(fn for fn in os.listdir('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02') ) +filenames = sorted(fn for fn in os.listdir('D:/OneDrive - sjtu.edu.cn/MVIG/Action-Dataset/Pose_to_SMPL/fit/output/NTU/picture') ) for filename in filenames: - images.append(imageio.imread('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02/'+filename)) -imageio.mimsave('fit_mesh.gif', images, duration=0.2) \ No newline at end of file + images.append(imageio.imread('D:/OneDrive - sjtu.edu.cn/MVIG/Action-Dataset/Pose_to_SMPL/fit/output/NTU/picture/'+filename)) +imageio.mimsave('clapping_example.gif', images, duration=0.2) \ No newline at end of file