Support HAA4D

Yiming Dou
2022-10-01 21:42:12 +08:00
parent fd2ff6616f
commit 6ac10e22e0
12 changed files with 232 additions and 170 deletions

View File

@@ -47,6 +47,7 @@ The SMPL human body layer for Pytorch is from the [smplpytorch](https://github.c
 - [UTD-MHAD](https://personal.utdallas.edu/~kehtar/UTD-MHAD.html)
 - [Human3.6M](http://vision.imar.ro/human3.6m/description.php)
 - [NTU](https://rose1.ntu.edu.sg/dataset/actionRecognition/)
+- [HAA4D](https://cse.hkust.edu.hk/haa4d/dataset.html)
 - Set the **DATASET.PATH** in the corresponding configuration file to the location of dataset.
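
A minimal sketch of loading such a configuration, assuming the `json` + `easydict` pattern the fitting script in this commit uses (the paths below are illustrative):

```python
# Minimal sketch, assuming the json + easydict config pattern used by the
# fitting script in this commit; the paths are illustrative.
import json
from easydict import EasyDict as edict

with open("fit/configs/HAA4D.json") as f:
    cfg = edict(json.load(f))

cfg.DATASET.PATH = "/path/to/HAA4D"  # point this at your local dataset copy
print(cfg.MODEL.GENDER, cfg.TRAIN.LEARNING_RATE)
```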

View File

@@ -1,4 +1,6 @@
 import torch
+import random
+import numpy as np
 from smplpytorch.pytorch.smpl_layer import SMPL_Layer
 from display_utils import display_model
@@ -15,7 +17,7 @@ if __name__ == '__main__':
         model_root='smplpytorch/native/models')
 
     # Generate random pose and shape parameters
-    pose_params = torch.rand(batch_size, 72) * 0.2
+    pose_params = torch.rand(batch_size, 72) * 0.01
     shape_params = torch.rand(batch_size, 10) * 0.03
 
     # GPU mode
@@ -26,7 +28,6 @@ if __name__ == '__main__':
     # Forward from the SMPL layer
     verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
-    print(Jtr)
 
     # Draw output vertices and joints
     display_model(

View File

@@ -1,3 +1,4 @@
+from xml.parsers.expat import model
 from matplotlib import pyplot as plt
 from mpl_toolkits.mplot3d import Axes3D
 from mpl_toolkits.mplot3d.art3d import Poly3DCollection
@@ -21,7 +22,8 @@ def display_model(
     if ax is None:
         fig = plt.figure()
         ax = fig.add_subplot(111, projection='3d')
-    verts, joints = model_info['verts'][batch_idx], model_info['joints'][batch_idx]
+    verts = model_info['verts'][batch_idx]
+    joints = model_info['joints'][batch_idx]
     if model_faces is None:
         ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2)
     elif not only_joint:

fit/configs/HAA4D.json (new file, +24 lines)
View File

@@ -0,0 +1,24 @@
+{
+    "MODEL": {
+        "GENDER": "neutral"
+    },
+    "TRAIN": {
+        "LEARNING_RATE": 1e-2,
+        "MAX_EPOCH": 1000,
+        "WRITE": 10,
+        "OPTIMIZE_SCALE":1,
+        "OPTIMIZE_SHAPE":1
+    },
+    "USE_GPU": 1,
+    "DATASET": {
+        "NAME": "NTU",
+        "PATH": "../NTU RGB+D/result",
+        "TARGET_PATH": "",
+        "DATA_MAP": [
+            [0,0],[1,4],[2,1],[4,5],[5,2],[7,6],[8,3],
+            [12,9],[18,12],[19,15],[20,13],[21,16],
+            [15,10],[6,1]
+        ]
+    },
+    "DEBUG": 0
+}
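
The `DATA_MAP` pairs appear to align SMPL joints with dataset joints. Below is a hedged sketch of how such pairs could become the index tensors the training loop selects on; treating the first column as the SMPL joint and the second as the dataset joint is an inference from the NTU map in the next file, not documented behaviour:

```python
# Hedged sketch: DATA_MAP pairs -> index tensors for joint alignment.
# Column order ([smpl_joint, dataset_joint]) is an assumption.
import torch

data_map = [[0, 0], [1, 4], [2, 1], [4, 5]]  # excerpt from the config above
smpl_index = torch.tensor([m[0] for m in data_map])
dataset_index = torch.tensor([m[1] for m in data_map])

# With Jtr of shape (T, 24, 3) and target of shape (T, J, 3), only the
# mapped joints enter the loss, mirroring the index_select calls in train.py:
#   F.smooth_l1_loss(scale * Jtr.index_select(1, smpl_index),
#                    target.index_select(1, dataset_index))
```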

View File

@@ -20,19 +20,19 @@
         0
     ],
     [
-        1,
+        2,
         12
     ],
     [
-        2,
+        1,
         16
     ],
     [
-        4,
+        5,
         13
     ],
     [
-        5,
+        4,
         17
     ],
     [
@@ -40,52 +40,51 @@
         1
     ],
     [
-        7,
+        8,
         14
     ],
     [
-        8,
+        7,
         18
     ],
     [
         9,
         20
     ],
     [
         12,
         2
     ],
     [
-        13,
+        14,
         4
     ],
     [
-        14,
+        13,
         8
     ],
     [
-        18,
+        19,
         5
     ],
     [
-        19,
+        18,
         9
     ],
     [
-        20,
+        21,
         6
     ],
     [
-        21,
+        20,
         10
     ],
     [
-        22,
+        23,
         22
     ],
     [
-        23,
+        22,
         24
     ]
 ]

View File

@@ -1156,127 +1156,127 @@ CMU_Mocap = {
     "41_09": "Climb"
 }
 
-NTU={
-    "1":" drink water",
-    "2":" eat meal/snack",
-    "3":" brushing teeth",
-    "4":" brushing hair",
-    "5":" drop",
-    "6":" pickup",
-    "7":" throw",
-    "8":" sitting down",
-    "9":" standing up (from sitting position)",
-    "10":" clapping",
-    "11":" reading",
-    "12":" writing",
-    "13":" tear up paper",
-    "14":" wear jacket",
-    "15":" take off jacket",
-    "16":" wear a shoe",
-    "17":" take off a shoe",
-    "18":" wear on glasses",
-    "19":" take off glasses",
-    "20":" put on a hat/cap",
-    "21":" take off a hat/cap",
-    "22":" cheer up",
-    "23":" hand waving",
-    "24":" kicking something",
-    "25":" reach into pocket",
-    "26":" hopping (one foot jumping)",
-    "27":" jump up",
-    "28":" make a phone call/answer phone",
-    "29":" playing with phone/tablet",
-    "30":" typing on a keyboard",
-    "31":" pointing to something with finger",
-    "32":" taking a selfie",
-    "33":" check time (from watch)",
-    "34":" rub two hands together",
-    "35":" nod head/bow",
-    "36":" shake head",
-    "37":" wipe face",
-    "38":" salute",
-    "39":" put the palms together",
-    "40":" cross hands in front (say stop)",
-    "41":" sneeze/cough",
-    "42":" staggering",
-    "43":" falling",
-    "44":" touch head (headache)",
-    "45":" touch chest (stomachache/heart pain)",
-    "46":" touch back (backache)",
-    "47":" touch neck (neckache)",
-    "48":" nausea or vomiting condition",
-    "49":" use a fan (with hand or paper)/feeling warm",
-    "50":" punching/slapping other person",
-    "51":" kicking other person",
-    "52":" pushing other person",
-    "53":" pat on back of other person",
-    "54":" point finger at the other person",
-    "55":" hugging other person",
-    "56":" giving something to other person",
-    "57":" touch other person's pocket",
-    "58":" handshaking",
-    "59":" walking towards each other",
-    "60":" walking apart from each other",
-    "61":" put on headphone",
-    "62":" take off headphone",
-    "63":" shoot at the basket",
-    "64":" bounce ball",
-    "65":" tennis bat swing",
-    "66":" juggling table tennis balls",
-    "67":" hush (quite)",
-    "68":" flick hair",
-    "69":" thumb up",
-    "70":" thumb down",
-    "71":" make ok sign",
-    "72":" make victory sign",
-    "73":" staple book",
-    "74":" counting money",
-    "75":" cutting nails",
-    "76":" cutting paper (using scissors)",
-    "77":" snapping fingers",
-    "78":" open bottle",
-    "79":" sniff (smell)",
-    "80":" squat down",
-    "81":" toss a coin",
-    "82":" fold paper",
-    "83":" ball up paper",
-    "84":" play magic cube",
-    "85":" apply cream on face",
-    "86":" apply cream on hand back",
-    "87":" put on bag",
-    "88":" take off bag",
-    "89":" put something into a bag",
-    "90":" take something out of a bag",
-    "91":" open a box",
-    "92":" move heavy objects",
-    "93":" shake fist",
-    "94":" throw up cap/hat",
-    "95":" hands up (both hands)",
-    "96":" cross arms",
-    "97":" arm circles",
-    "98":" arm swings",
-    "99":" running on the spot",
-    "100":" butt kicks (kick backward)",
-    "101":" cross toe touch",
-    "102":" side kick",
-    "103":" yawn",
-    "104":" stretch oneself",
-    "105":" blow nose",
-    "106":" hit other person with something",
-    "107":" wield knife towards other person",
-    "108":" knock over other person (hit with body)",
-    "109":" grab other persons stuff",
-    "110":" shoot at other person with a gun",
-    "111":" step on foot",
-    "112":" high-five",
-    "113":" cheers and drink",
-    "114":" carry something with other person",
-    "115":" take a photo of other person",
-    "116":" follow other person",
-    "117":" whisper in other persons ear",
-    "118":" exchange things with other person",
-    "119":" support somebody with hand",
-    "120":" finger-guessing game (playing rock-paper-scissors)",
+NTU = {
+    "1":"drink water",
+    "2":"eat meal/snack",
+    "3":"brushing teeth",
+    "4":"brushing hair",
+    "5":"drop",
+    "6":"pickup",
+    "7":"throw",
+    "8":"sitting down",
+    "9":"standing up (from sitting position)",
+    "10":"clapping",
+    "11":"reading",
+    "12":"writing",
+    "13":"tear up paper",
+    "14":"wear jacket",
+    "15":"take off jacket",
+    "16":"wear a shoe",
+    "17":"take off a shoe",
+    "18":"wear on glasses",
+    "19":"take off glasses",
+    "20":"put on a hat/cap",
+    "21":"take off a hat/cap",
+    "22":"cheer up",
+    "23":"hand waving",
+    "24":"kicking something",
+    "25":"reach into pocket",
+    "26":"hopping (one foot jumping)",
+    "27":"jump up",
+    "28":"make a phone call/answer phone",
+    "29":"playing with phone/tablet",
+    "30":"typing on a keyboard",
+    "31":"pointing to something with finger",
+    "32":"taking a selfie",
+    "33":"check time (from watch)",
+    "34":"rub two hands together",
+    "35":"nod head/bow",
+    "36":"shake head",
+    "37":"wipe face",
+    "38":"salute",
+    "39":"put the palms together",
+    "40":"cross hands in front (say stop)",
+    "41":"sneeze/cough",
+    "42":"staggering",
+    "43":"falling",
+    "44":"touch head (headache)",
+    "45":"touch chest (stomachache/heart pain)",
+    "46":"touch back (backache)",
+    "47":"touch neck (neckache)",
+    "48":"nausea or vomiting condition",
+    "49":"use a fan (with hand or paper)/feeling warm",
+    "50":"punching/slapping other person",
+    "51":"kicking other person",
+    "52":"pushing other person",
+    "53":"pat on back of other person",
+    "54":"point finger at the other person",
+    "55":"hugging other person",
+    "56":"giving something to other person",
+    "57":"touch other person's pocket",
+    "58":"handshaking",
+    "59":"walking towards each other",
+    "60":"walking apart from each other",
+    "61":"put on headphone",
+    "62":"take off headphone",
+    "63":"shoot at the basket",
+    "64":"bounce ball",
+    "65":"tennis bat swing",
+    "66":"juggling table tennis balls",
+    "67":"hush (quite)",
+    "68":"flick hair",
+    "69":"thumb up",
+    "70":"thumb down",
+    "71":"make ok sign",
+    "72":"make victory sign",
+    "73":"staple book",
+    "74":"counting money",
+    "75":"cutting nails",
+    "76":"cutting paper (using scissors)",
+    "77":"snapping fingers",
+    "78":"open bottle",
+    "79":"sniff (smell)",
+    "80":"squat down",
+    "81":"toss a coin",
+    "82":"fold paper",
+    "83":"ball up paper",
+    "84":"play magic cube",
+    "85":"apply cream on face",
+    "86":"apply cream on hand back",
+    "87":"put on bag",
+    "88":"take off bag",
+    "89":"put something into a bag",
+    "90":"take something out of a bag",
+    "91":"open a box",
+    "92":"move heavy objects",
+    "93":"shake fist",
+    "94":"throw up cap/hat",
+    "95":"hands up (both hands)",
+    "96":"cross arms",
+    "97":"arm circles",
+    "98":"arm swings",
+    "99":"running on the spot",
+    "100":"butt kicks (kick backward)",
+    "101":"cross toe touch",
+    "102":"side kick",
+    "103":"yawn",
+    "104":"stretch oneself",
+    "105":"blow nose",
+    "106":"hit other person with something",
+    "107":"wield knife towards other person",
+    "108":"knock over other person (hit with body)",
+    "109":"grab other persons stuff",
+    "110":"shoot at other person with a gun",
+    "111":"step on foot",
+    "112":"high-five",
+    "113":"cheers and drink",
+    "114":"carry something with other person",
+    "115":"take a photo of other person",
+    "116":"follow other person",
+    "117":"whisper in other persons ear",
+    "118":"exchange things with other person",
+    "119":"support somebody with hand",
+    "120":"finger-guessing game (playing rock-paper-scissors)",
 }
 
 def get_label(file_name, dataset_name):
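
`get_label`'s body is not part of this diff; for NTU, a hedged sketch of the lookup, assuming the standard NTU RGB+D naming where the trailing `A###` field is the 1-based action id:

```python
# Hedged sketch only -- NOT the repo's get_label. Assumes NTU RGB+D file
# names like "S001C001P001R001A010.skeleton", where A010 -> action 10.
import re

def ntu_label_from_filename(file_name):
    action_id = int(re.search(r"A(\d{3})", file_name).group(1))
    return NTU[str(action_id)]  # e.g. A010 -> "clapping"
```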

View File

@@ -19,5 +19,7 @@ def load(name, path):
     elif name == "Human3.6M":
         return np.load(path, allow_pickle=True)[0::5] # down_sample
     elif name == "NTU":
+        return np.load(path, allow_pickle=True)[0::2]
+    elif name == "HAA4D":
         return np.load(path, allow_pickle=True)
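
Consolidated, the loader after this change reads as below; the down-sampling rates come straight from the hunk above, and the branches for the other datasets are omitted:

```python
# Consolidated view of load() after this change; branches for the other
# datasets are omitted here.
import numpy as np

def load(name, path):
    if name == "Human3.6M":
        return np.load(path, allow_pickle=True)[0::5]  # keep every 5th frame
    elif name == "NTU":
        return np.load(path, allow_pickle=True)[0::2]  # keep every 2nd frame
    elif name == "HAA4D":
        return np.load(path, allow_pickle=True)        # no down-sampling
```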

View File

@@ -1,32 +1,33 @@
-import os
-import sys
-sys.path.append(os.getcwd())
-from meters import Meters
-from smplpytorch.pytorch.smpl_layer import SMPL_Layer
-from train import train
-from transform import transform
-from save import save_pic, save_params
-from load import load
 import torch
 import numpy as np
 from tensorboardX import SummaryWriter
 from easydict import EasyDict as edict
 import time
+import sys
+import os
 import logging
 import argparse
 import json
+
+sys.path.append(os.getcwd())
+from load import load
+from save import save_pic, save_params
+from transform import transform
+from train import train
+from smplpytorch.pytorch.smpl_layer import SMPL_Layer
+from meters import Meters
 
 torch.backends.cudnn.enabled = True
 torch.backends.cudnn.benchmark = True
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Fit SMPL')
     parser.add_argument('--exp', dest='exp',
                         help='Define exp name',
                         default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
-    parser.add_argument('--dataset_name', dest='dataset_name',
+    parser.add_argument('--dataset_name', '-n', dest='dataset_name',
                         help='select dataset',
                         default='', type=str)
     parser.add_argument('--dataset_path', dest='dataset_path',
@@ -98,13 +99,16 @@ if __name__ == "__main__":
         gender=cfg.MODEL.GENDER,
         model_root='smplpytorch/native/models')
-    meters=Meters()
+    meters = Meters()
     file_num = 0
     for root, dirs, files in os.walk(cfg.DATASET.PATH):
-        for file in files:
+        for file in sorted(files):
+            if not 'baseball_swing' in file:
+                continue
             file_num += 1
-            logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
-            target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name,
+            logger.info(
+                'Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
+            target = torch.from_numpy(transform(args.dataset_name, load(args.dataset_name,
                 os.path.join(root, file)))).float()
             logger.info("target shape:{}".format(target.shape))
             res = train(smpl_layer, target,
@@ -115,7 +119,8 @@ if __name__ == "__main__":
             logger.info("avg_loss:{:.4f}".format(meters.avg))
             save_params(res, file, logger, args.dataset_name)
-            # save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
+            save_pic(res, smpl_layer, file, logger, args.dataset_name, target)
             torch.cuda.empty_cache()
-    logger.info("Fitting finished! Average loss: {:.9f}".format(meters.avg))
+    logger.info(
+        "Fitting finished! Average loss: {:.9f}".format(meters.avg))

View File

@@ -19,7 +19,7 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
     _, _, verts, Jtr = res
     file_name = re.split('[/.]', file)[-2]
     fit_path = "fit/output/{}/picture/{}".format(dataset_name, file_name)
-    create_dir_not_exist(fit_path)
+    os.makedirs(fit_path,exist_ok=True)
     logger.info('Saving pictures at {}'.format(fit_path))
     for i in tqdm(range(Jtr.shape[0])):
         display_model(
@@ -28,7 +28,7 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
             model_faces=smpl_layer.th_faces,
             with_joints=True,
             kintree_table=smpl_layer.kintree_table,
-            savepath=os.path.join(fit_path+"/frame_{}".format(i)),
+            savepath=os.path.join(fit_path+"/frame_{:0>4d}".format(i)),
             batch_idx=i,
             show=False,
             only_joint=True)
@@ -55,3 +55,21 @@ def save_params(res, file, logger, dataset_name):
     with open(os.path.join((fit_path),
               "{}_params.pkl".format(file_name)), 'wb') as f:
         pickle.dump(params, f)
+
+
+def save_single_pic(res, smpl_layer, epoch, logger, dataset_name, target):
+    _, _, verts, Jtr = res
+    fit_path = "fit/output/{}/picture".format(dataset_name)
+    create_dir_not_exist(fit_path)
+    logger.info('Saving pictures at {}'.format(fit_path))
+    display_model(
+        {'verts': verts.cpu().detach(),
+         'joints': Jtr.cpu().detach()},
+        model_faces=smpl_layer.th_faces,
+        with_joints=True,
+        kintree_table=smpl_layer.kintree_table,
+        savepath=fit_path+"/epoch_{:0>4d}".format(epoch),
+        batch_idx=60,
+        show=False,
+        only_joint=False)
+    logger.info('Picture saved')

View File

@@ -6,6 +6,7 @@ import os
 from tqdm import tqdm
 
 sys.path.append(os.getcwd())
+from save import save_single_pic
@@ -25,8 +26,10 @@ def init(smpl_layer, target, device, cfg):
     params["shape_params"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SHAPE)
     params["scale"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SCALE)
 
-    optimizer = optim.Adam([params["pose_params"], params["shape_params"], params["scale"]],
-                           lr=cfg.TRAIN.LEARNING_RATE)
+    optim_params = [{'params': params["pose_params"], 'lr': cfg.TRAIN.LEARNING_RATE},
+                    {'params': params["shape_params"], 'lr': cfg.TRAIN.LEARNING_RATE},
+                    {'params': params["scale"], 'lr': cfg.TRAIN.LEARNING_RATE*10},]
+    optimizer = optim.Adam(optim_params)
 
     index = {}
     smpl_index = []
@@ -51,11 +54,14 @@ def train(smpl_layer, target,
     shape_params = params["shape_params"]
     scale = params["scale"]
 
-    for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
-    # for epoch in range(cfg.TRAIN.MAX_EPOCH):
+    with torch.no_grad():
         verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
-        loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100,
-                                target.index_select(1, index["dataset_index"]) * 100 * scale)
+        params["scale"]*=(torch.max(torch.abs(target))/torch.max(torch.abs(Jtr)))
+
+    for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
+        verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
+        loss = F.smooth_l1_loss(scale*Jtr.index_select(1, index["smpl_index"]),
+                                target.index_select(1, index["dataset_index"]))
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
@@ -67,12 +73,15 @@ def train(smpl_layer, target,
             logger.info("Early stop at epoch {} !".format(epoch))
             break
-        if epoch % cfg.TRAIN.WRITE == 0:
+        if epoch % cfg.TRAIN.WRITE == 0 or epoch<10:
             # logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
             # epoch, float(loss),float(scale)))
+            print("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
+                epoch, float(loss),float(scale)))
             writer.add_scalar('loss', float(loss), epoch)
             writer.add_scalar('learning_rate', float(
                 optimizer.state_dict()['param_groups'][0]['lr']), epoch)
+            # save_single_pic(res,smpl_layer,epoch,logger,args.dataset_name,target)
     logger.info('Train ended, min_loss = {:.4f}'.format(
         float(meters.min_loss)))
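
Two optimisation changes land above: each parameter tensor gets its own Adam learning rate (`scale` trains ten times faster), and `scale` is initialised once, outside the autograd graph, to the ratio of the target's extent to the model's, so early epochs are not spent resizing the skeleton. A self-contained sketch of both ideas; the tensors are stand-ins, not the repo's data:

```python
# Self-contained sketch of the two changes above; the tensors are stand-ins.
import torch
from torch import optim

pose_params = torch.zeros(1, 72, requires_grad=True)
shape_params = torch.zeros(1, 10, requires_grad=True)
scale = torch.ones(1, requires_grad=True)

# Per-parameter-group learning rates: scale gets a 10x rate.
optimizer = optim.Adam([
    {'params': pose_params, 'lr': 1e-2},
    {'params': shape_params, 'lr': 1e-2},
    {'params': scale, 'lr': 1e-1},
])

# One-off scale initialisation outside autograd: match the target's
# overall extent before gradient descent refines it.
target = torch.randn(1, 17, 3)  # stand-in for dataset joints
Jtr = torch.randn(1, 24, 3)     # stand-in for SMPL joints
with torch.no_grad():
    scale *= torch.max(torch.abs(target)) / torch.max(torch.abs(Jtr))
```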

View File

@@ -5,7 +5,8 @@ rotate = {
     'CMU_Mocap': [0.05, 0.05, 0.05],
     'UTD_MHAD': [-1., 1., -1.],
     'Human3.6M': [-0.001, -0.001, 0.001],
-    'NTU': [-1., 1., -1.]
+    'NTU': [1., 1., -1.],
+    'HAA4D': [1., -1., -1.],
 }
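
The `rotate` table holds per-axis factors for each dataset; below is a hedged guess at how `transform` applies them, since the function body is not shown in this diff:

```python
# Hedged guess at how the per-dataset factors are applied; transform()'s
# actual body is not shown in this diff.
import numpy as np

def apply_rotate(joints, name):
    # joints: (T, J, 3) array; element-wise per-axis flip/scale to move
    # the dataset's coordinate frame toward SMPL's.
    return joints * np.asarray(rotate[name], dtype=np.float32)
```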

View File

@@ -1,7 +1,7 @@
 import matplotlib.pyplot as plt
 import imageio, os
 images = []
-filenames = sorted(fn for fn in os.listdir('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02') )
+filenames = sorted(fn for fn in os.listdir('D:/OneDrive - sjtu.edu.cn/MVIG/Action-Dataset/Pose_to_SMPL/fit/output/NTU/picture') )
 for filename in filenames:
-    images.append(imageio.imread('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02/'+filename))
-imageio.mimsave('fit_mesh.gif', images, duration=0.2)
+    images.append(imageio.imread('D:/OneDrive - sjtu.edu.cn/MVIG/Action-Dataset/Pose_to_SMPL/fit/output/NTU/picture/'+filename))
+imageio.mimsave('clapping_example.gif', images, duration=0.2)
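
The rewritten script hard-codes an absolute OneDrive path; a more portable sketch of the same gif assembly, with the frame directory passed in:

```python
# Portable sketch of the same gif assembly; the directory argument below
# is illustrative.
import os
import imageio

def make_gif(frame_dir, out_path="clapping_example.gif", duration=0.2):
    frames = sorted(os.listdir(frame_dir))  # zero-padded frame names sort correctly
    images = [imageio.imread(os.path.join(frame_dir, fn)) for fn in frames]
    imageio.mimsave(out_path, images, duration=duration)

make_gif("fit/output/NTU/picture")
```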