Support HAA4D

Yiming Dou
2022-10-01 21:42:12 +08:00
parent fd2ff6616f
commit 6ac10e22e0
12 changed files with 232 additions and 170 deletions

View File

@ -47,6 +47,7 @@ The SMPL human body layer for Pytorch is from the [smplpytorch](https://github.c
- [UTD-MHAD](https://personal.utdallas.edu/~kehtar/UTD-MHAD.html)
- [Human3.6M](http://vision.imar.ro/human3.6m/description.php)
- [NTU](https://rose1.ntu.edu.sg/dataset/actionRecognition/)
+- [HAA4D](https://cse.hkust.edu.hk/haa4d/dataset.html)
- Set the **DATASET.PATH** in the corresponding configuration file to the location of the dataset.
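For context, the **DATASET.PATH** value is what the fitting script's `os.walk()` loop iterates over. The exact config loader isn't shown in this diff, but given that the fit script imports `json` and `easydict`, it presumably looks roughly like this sketch:

```python
import json
from easydict import EasyDict as edict

# Hypothetical loader sketch; the real config handling lives in the fit script.
with open("fit/configs/HAA4D.json") as f:
    cfg = edict(json.load(f))

print(cfg.DATASET.PATH)  # the directory the fitting loop walks with os.walk()
```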

View File

@ -1,4 +1,6 @@
import torch
+import random
+import numpy as np
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from display_utils import display_model
@ -15,7 +17,7 @@ if __name__ == '__main__':
model_root='smplpytorch/native/models')
# Generate random pose and shape parameters
-    pose_params = torch.rand(batch_size, 72) * 0.2
+    pose_params = torch.rand(batch_size, 72) * 0.01
shape_params = torch.rand(batch_size, 10) * 0.03
# GPU mode
@ -26,7 +28,6 @@ if __name__ == '__main__':
# Forward from the SMPL layer
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
print(Jtr)
# Draw output vertices and joints
display_model(
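A note on the 0.2 → 0.01 change above: SMPL pose parameters are 24 per-joint axis-angle rotations in radians (72 values total), so the multiplier directly bounds how far the random pose strays from the rest pose. A small illustration:

```python
import torch

# 72 pose values = 24 joints x 3 axis-angle components, in radians.
# Scaling uniform noise by 0.01 keeps every component under 0.01 rad,
# so the sampled body stays very close to the rest (T) pose.
pose_params = torch.rand(1, 72) * 0.01
shape_params = torch.rand(1, 10) * 0.03  # 10 SMPL shape (beta) coefficients
print(pose_params.abs().max() <= 0.01)   # tensor(True)
```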

View File

@ -1,3 +1,4 @@
+from xml.parsers.expat import model
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
@ -21,7 +22,8 @@ def display_model(
if ax is None:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
-    verts, joints = model_info['verts'][batch_idx], model_info['joints'][batch_idx]
+    verts = model_info['verts'][batch_idx]
+    joints = model_info['joints'][batch_idx]
if model_faces is None:
ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2)
elif not only_joint:

fit/configs/HAA4D.json (new file, +24 lines)
View File

@ -0,0 +1,24 @@
{
    "MODEL": {
        "GENDER": "neutral"
    },
    "TRAIN": {
        "LEARNING_RATE": 1e-2,
        "MAX_EPOCH": 1000,
        "WRITE": 10,
        "OPTIMIZE_SCALE": 1,
        "OPTIMIZE_SHAPE": 1
    },
    "USE_GPU": 1,
    "DATASET": {
        "NAME": "HAA4D",
        "PATH": "../NTU RGB+D/result",
        "TARGET_PATH": "",
        "DATA_MAP": [
            [0, 0], [1, 4], [2, 1], [4, 5], [5, 2], [7, 6], [8, 3],
            [12, 9], [18, 12], [19, 15], [20, 13], [21, 16],
            [15, 10], [6, 1]
        ]
    },
    "DEBUG": 0
}
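Note that PATH still points at an NTU output folder; per the README it should be set to the local HAA4D location. DATA_MAP pairs up joints between the two skeletons. Assuming the first column indexes SMPL's 24 joints and the second the dataset skeleton (which matches how train() gathers both sides with index_select), it is consumed roughly like this sketch:

```python
import torch

data_map = [[0, 0], [1, 4], [2, 1], [4, 5], [5, 2], [7, 6], [8, 3],
            [12, 9], [18, 12], [19, 15], [20, 13], [21, 16],
            [15, 10], [6, 1]]

# Column order is an assumption: SMPL joint index first, dataset joint second.
smpl_index = torch.tensor([p[0] for p in data_map])
dataset_index = torch.tensor([p[1] for p in data_map])

# With Jtr shaped (frames, 24, 3) and target shaped (frames, n_joints, 3),
# the training loss compares only the matched joints:
#   F.smooth_l1_loss(scale * Jtr.index_select(1, smpl_index),
#                    target.index_select(1, dataset_index))
```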

View File

@ -20,19 +20,19 @@
        0
    ],
    [
-       1,
+       2,
        12
    ],
    [
-       2,
+       1,
        16
    ],
    [
-       4,
+       5,
        13
    ],
    [
-       5,
+       4,
        17
    ],
    [
@ -40,52 +40,51 @@
        1
    ],
    [
-       7,
+       8,
        14
    ],
    [
-       8,
+       7,
        18
    ],
    [
        9,
        20
    ],
    [
        12,
        2
    ],
    [
-       13,
+       14,
        4
    ],
    [
-       14,
+       13,
        8
    ],
    [
-       18,
+       19,
        5
    ],
    [
-       19,
+       18,
        9
    ],
    [
-       20,
+       21,
        6
    ],
    [
-       21,
+       20,
        10
    ],
    [
-       22,
+       23,
        22
    ],
    [
-       23,
+       22,
        24
    ]
]
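Since this commit shifts several first-column indices by one, a tiny sanity check over such a correspondence list can be worth running. The joint counts below are assumptions for illustration:

```python
def check_pairs(pairs, n_first=24, n_second=25):
    """Flag out-of-range or duplicated first-column joint indices."""
    seen = set()
    for a, b in pairs:
        assert 0 <= a < n_first and 0 <= b < n_second, (a, b)
        assert a not in seen, "joint {} mapped twice".format(a)
        seen.add(a)

check_pairs([[2, 12], [1, 16], [5, 13], [4, 17]])  # the first few pairs above
```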

View File

@ -1156,127 +1156,127 @@ CMU_Mocap = {
"41_09": "Climb"
}
-NTU={ ... }  # removed: the previous NTU dict -- the same 120 action labels as below, but every value carried a stray leading space
NTU = {
"1":"drink water",
"2":"eat meal/snack",
"3":"brushing teeth",
"4":"brushing hair",
"5":"drop",
"6":"pickup",
"7":"throw",
"8":"sitting down",
"9":"standing up (from sitting position)",
"10":"clapping",
"11":"reading",
"12":"writing",
"13":"tear up paper",
"14":"wear jacket",
"15":"take off jacket",
"16":"wear a shoe",
"17":"take off a shoe",
"18":"wear on glasses",
"19":"take off glasses",
"20":"put on a hat/cap",
"21":"take off a hat/cap",
"22":"cheer up",
"23":"hand waving",
"24":"kicking something",
"25":"reach into pocket",
"26":"hopping (one foot jumping)",
"27":"jump up",
"28":"make a phone call/answer phone",
"29":"playing with phone/tablet",
"30":"typing on a keyboard",
"31":"pointing to something with finger",
"32":"taking a selfie",
"33":"check time (from watch)",
"34":"rub two hands together",
"35":"nod head/bow",
"36":"shake head",
"37":"wipe face",
"38":"salute",
"39":"put the palms together",
"40":"cross hands in front (say stop)",
"41":"sneeze/cough",
"42":"staggering",
"43":"falling",
"44":"touch head (headache)",
"45":"touch chest (stomachache/heart pain)",
"46":"touch back (backache)",
"47":"touch neck (neckache)",
"48":"nausea or vomiting condition",
"49":"use a fan (with hand or paper)/feeling warm",
"50":"punching/slapping other person",
"51":"kicking other person",
"52":"pushing other person",
"53":"pat on back of other person",
"54":"point finger at the other person",
"55":"hugging other person",
"56":"giving something to other person",
"57":"touch other person's pocket",
"58":"handshaking",
"59":"walking towards each other",
"60":"walking apart from each other",
"61":"put on headphone",
"62":"take off headphone",
"63":"shoot at the basket",
"64":"bounce ball",
"65":"tennis bat swing",
"66":"juggling table tennis balls",
"67":"hush (quite)",
"68":"flick hair",
"69":"thumb up",
"70":"thumb down",
"71":"make ok sign",
"72":"make victory sign",
"73":"staple book",
"74":"counting money",
"75":"cutting nails",
"76":"cutting paper (using scissors)",
"77":"snapping fingers",
"78":"open bottle",
"79":"sniff (smell)",
"80":"squat down",
"81":"toss a coin",
"82":"fold paper",
"83":"ball up paper",
"84":"play magic cube",
"85":"apply cream on face",
"86":"apply cream on hand back",
"87":"put on bag",
"88":"take off bag",
"89":"put something into a bag",
"90":"take something out of a bag",
"91":"open a box",
"92":"move heavy objects",
"93":"shake fist",
"94":"throw up cap/hat",
"95":"hands up (both hands)",
"96":"cross arms",
"97":"arm circles",
"98":"arm swings",
"99":"running on the spot",
"100":"butt kicks (kick backward)",
"101":"cross toe touch",
"102":"side kick",
"103":"yawn",
"104":"stretch oneself",
"105":"blow nose",
"106":"hit other person with something",
"107":"wield knife towards other person",
"108":"knock over other person (hit with body)",
"109":"grab other persons stuff",
"110":"shoot at other person with a gun",
"111":"step on foot",
"112":"high-five",
"113":"cheers and drink",
"114":"carry something with other person",
"115":"take a photo of other person",
"116":"follow other person",
"117":"whisper in other persons ear",
"118":"exchange things with other person",
"119":"support somebody with hand",
"120":"finger-guessing game (playing rock-paper-scissors)",
}
def get_label(file_name, dataset_name):
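The body of get_label is truncated here. For NTU, skeleton files are conventionally named with the action class after an 'A' (e.g. S001C001P001R001A027), so a lookup consistent with the dict above might look like this hypothetical sketch:

```python
import re

NTU = {"27": "jump up"}  # tiny stub -- the full 120-entry dict is defined above

def get_label(file_name, dataset_name):
    # Hypothetical sketch: the real implementation is cut off in this diff.
    if dataset_name == "NTU":
        # The three digits after 'A' encode the action class.
        action_id = int(re.search(r"A(\d{3})", file_name).group(1))
        return NTU[str(action_id)]  # dict keys are unpadded: "A027" -> "27"
    raise NotImplementedError(dataset_name)

print(get_label("S001C001P001R001A027.skeleton", "NTU"))  # jump up
```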

View File

@ -19,5 +19,7 @@ def load(name, path):
    elif name == "Human3.6M":
        return np.load(path, allow_pickle=True)[0::5]  # down_sample
    elif name == "NTU":
        return np.load(path, allow_pickle=True)[0::2]
+    elif name == "HAA4D":
+        return np.load(path, allow_pickle=True)
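The per-dataset slices are plain temporal downsampling; a quick illustration of the strides used above:

```python
import numpy as np

frames = np.arange(10)   # stand-in for a (num_frames, num_joints, 3) array
print(frames[0::5])      # [0 5]         -- Human3.6M keeps every 5th frame
print(frames[0::2])      # [0 2 4 6 8]   -- NTU keeps every 2nd frame
# HAA4D sequences are loaded at full frame rate, with no slicing.
```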

View File

@ -1,32 +1,33 @@
-import os
-import sys
-sys.path.append(os.getcwd())
-from meters import Meters
-from smplpytorch.pytorch.smpl_layer import SMPL_Layer
-from train import train
-from transform import transform
-from save import save_pic, save_params
-from load import load
import torch
import numpy as np
from tensorboardX import SummaryWriter
from easydict import EasyDict as edict
import time
import sys
import os
import logging
import argparse
import json
+sys.path.append(os.getcwd())
+from load import load
+from save import save_pic, save_params
+from transform import transform
+from train import train
+from smplpytorch.pytorch.smpl_layer import SMPL_Layer
+from meters import Meters
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
def parse_args():
parser = argparse.ArgumentParser(description='Fit SMPL')
parser.add_argument('--exp', dest='exp',
help='Define exp name',
default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
-    parser.add_argument('--dataset_name', dest='dataset_name',
+    parser.add_argument('--dataset_name', '-n', dest='dataset_name',
help='select dataset',
default='', type=str)
parser.add_argument('--dataset_path', dest='dataset_path',
@ -98,14 +99,17 @@ if __name__ == "__main__":
gender=cfg.MODEL.GENDER,
model_root='smplpytorch/native/models')
-    meters=Meters()
+    meters = Meters()
    file_num = 0
    for root, dirs, files in os.walk(cfg.DATASET.PATH):
-        for file in files:
+        for file in sorted(files):
+            if not 'baseball_swing' in file:
+                continue
            file_num += 1
-            logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
-            target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name,
-                os.path.join(root, file)))).float()
+            logger.info(
+                'Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
+            target = torch.from_numpy(transform(args.dataset_name, load(args.dataset_name,
+                os.path.join(root, file)))).float()
logger.info("target shape:{}".format(target.shape))
res = train(smpl_layer, target,
logger, writer, device,
@ -115,7 +119,8 @@ if __name__ == "__main__":
logger.info("avg_loss:{:.4f}".format(meters.avg))
save_params(res, file, logger, args.dataset_name)
-            # save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
+            save_pic(res, smpl_layer, file, logger, args.dataset_name, target)
torch.cuda.empty_cache()
logger.info("Fitting finished! Average loss: {:.9f}".format(meters.avg))
logger.info(
"Fitting finished! Average loss: {:.9f}".format(meters.avg))

View File

@ -19,7 +19,7 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
_, _, verts, Jtr = res
file_name = re.split('[/.]', file)[-2]
fit_path = "fit/output/{}/picture/{}".format(dataset_name, file_name)
-    create_dir_not_exist(fit_path)
+    os.makedirs(fit_path, exist_ok=True)
logger.info('Saving pictures at {}'.format(fit_path))
for i in tqdm(range(Jtr.shape[0])):
display_model(
@ -28,7 +28,7 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
model_faces=smpl_layer.th_faces,
with_joints=True,
kintree_table=smpl_layer.kintree_table,
savepath=os.path.join(fit_path+"/frame_{}".format(i)),
savepath=os.path.join(fit_path+"/frame_{:0>4d}".format(i)),
batch_idx=i,
show=False,
only_joint=True)
@ -55,3 +55,21 @@ def save_params(res, file, logger, dataset_name):
with open(os.path.join((fit_path),
"{}_params.pkl".format(file_name)), 'wb') as f:
pickle.dump(params, f)
+def save_single_pic(res, smpl_layer, epoch, logger, dataset_name, target):
+    _, _, verts, Jtr = res
+    fit_path = "fit/output/{}/picture".format(dataset_name)
+    create_dir_not_exist(fit_path)
+    logger.info('Saving pictures at {}'.format(fit_path))
+    display_model(
+        {'verts': verts.cpu().detach(),
+         'joints': Jtr.cpu().detach()},
+        model_faces=smpl_layer.th_faces,
+        with_joints=True,
+        kintree_table=smpl_layer.kintree_table,
+        savepath=fit_path+"/epoch_{:0>4d}".format(epoch),
+        batch_idx=60,
+        show=False,
+        only_joint=False)
+    logger.info('Picture saved')
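The frame_{} → frame_{:0>4d} change above matters because the gif script later sorts file names lexicographically; zero-padding makes that order agree with frame order:

```python
names_plain = sorted("frame_{}".format(i) for i in (1, 2, 10))
names_padded = sorted("frame_{:0>4d}".format(i) for i in (1, 2, 10))
print(names_plain)   # ['frame_1', 'frame_10', 'frame_2']  -- 10 sorts before 2
print(names_padded)  # ['frame_0001', 'frame_0002', 'frame_0010']
```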

View File

@ -6,6 +6,7 @@ import os
from tqdm import tqdm
sys.path.append(os.getcwd())
+from save import save_single_pic
@ -25,8 +26,10 @@ def init(smpl_layer, target, device, cfg):
params["shape_params"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SHAPE)
params["scale"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SCALE)
-    optimizer = optim.Adam([params["pose_params"], params["shape_params"], params["scale"]],
-                           lr=cfg.TRAIN.LEARNING_RATE)
+    optim_params = [{'params': params["pose_params"], 'lr': cfg.TRAIN.LEARNING_RATE},
+                    {'params': params["shape_params"], 'lr': cfg.TRAIN.LEARNING_RATE},
+                    {'params': params["scale"], 'lr': cfg.TRAIN.LEARNING_RATE*10},]
+    optimizer = optim.Adam(optim_params)
index = {}
smpl_index = []
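The rewrite above switches Adam to per-parameter-group learning rates, so the global scale (a single scalar) can move ten times faster than the 72 pose and 10 shape parameters. A self-contained sketch of the same pattern, with toy tensors standing in for the real params dict:

```python
import torch
from torch import optim

pose_params = torch.zeros(1, 72, requires_grad=True)
shape_params = torch.zeros(1, 10, requires_grad=True)
scale = torch.ones(1, requires_grad=True)

base_lr = 1e-2
optimizer = optim.Adam([
    {"params": pose_params, "lr": base_lr},
    {"params": shape_params, "lr": base_lr},
    {"params": scale, "lr": base_lr * 10},  # the scalar gets a 10x rate
])
print([g["lr"] for g in optimizer.param_groups])  # [0.01, 0.01, 0.1]
```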
@ -51,11 +54,14 @@ def train(smpl_layer, target,
shape_params = params["shape_params"]
scale = params["scale"]
-    for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
+    # for epoch in range(cfg.TRAIN.MAX_EPOCH):
+    with torch.no_grad():
        verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
-        loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100,
-                                target.index_select(1, index["dataset_index"]) * 100 * scale)
+        params["scale"] *= (torch.max(torch.abs(target)) / torch.max(torch.abs(Jtr)))
+    for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
+        verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
+        loss = F.smooth_l1_loss(scale * Jtr.index_select(1, index["smpl_index"]),
+                                target.index_select(1, index["dataset_index"]))
optimizer.zero_grad()
loss.backward()
optimizer.step()
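The new pre-loop block is a one-shot initialization: under no_grad, scale is set so the SMPL joints and the target skeleton start at a comparable magnitude, which is why the old hard-coded * 100 factors could be dropped from the loss. In isolation, with toy shapes standing in for the real Jtr and target:

```python
import torch

Jtr = torch.randn(8, 24, 3)            # stand-in for SMPL joint positions
target = torch.randn(8, 25, 3) * 100   # stand-in for dataset joints
scale = torch.ones(1)

with torch.no_grad():
    scale *= torch.max(torch.abs(target)) / torch.max(torch.abs(Jtr))

# scale * Jtr now spans roughly the same numeric range as target.
print(float(torch.max(torch.abs(scale * Jtr))),
      float(torch.max(torch.abs(target))))
```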
@ -67,12 +73,15 @@ def train(smpl_layer, target,
logger.info("Early stop at epoch {} !".format(epoch))
break
-        if epoch % cfg.TRAIN.WRITE == 0:
+        if epoch % cfg.TRAIN.WRITE == 0 or epoch < 10:
            # logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
            #     epoch, float(loss), float(scale)))
+            print("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
+                epoch, float(loss), float(scale)))
            writer.add_scalar('loss', float(loss), epoch)
            writer.add_scalar('learning_rate', float(
                optimizer.state_dict()['param_groups'][0]['lr']), epoch)
+    # save_single_pic(res, smpl_layer, epoch, logger, args.dataset_name, target)
logger.info('Train ended, min_loss = {:.4f}'.format(
float(meters.min_loss)))

View File

@ -5,7 +5,8 @@ rotate = {
'CMU_Mocap': [0.05, 0.05, 0.05],
'UTD_MHAD': [-1., 1., -1.],
'Human3.6M': [-0.001, -0.001, 0.001],
-    'NTU': [-1., 1., -1.]
+    'NTU': [1., 1., -1.],
+    'HAA4D': [1., -1., -1.],
}
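The rotate table holds per-axis sign flips that bring each dataset's coordinate convention in line with SMPL's. The exact use lives in transform(), which this diff doesn't show, but it presumably amounts to an elementwise multiply along the last axis:

```python
import numpy as np

rotate = {
    'NTU': [1., 1., -1.],
    'HAA4D': [1., -1., -1.],
}

def flip_axes(joints, dataset_name):
    # joints: (frames, num_joints, 3); broadcast the 3 signs over all joints.
    return joints * np.asarray(rotate[dataset_name])

print(flip_axes(np.ones((2, 3, 3)), 'HAA4D')[0, 0])  # [ 1. -1. -1.]
```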

View File

@ -1,7 +1,7 @@
import matplotlib.pyplot as plt
import imageio, os
images = []
-filenames = sorted(fn for fn in os.listdir('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02'))
+filenames = sorted(fn for fn in os.listdir('D:/OneDrive - sjtu.edu.cn/MVIG/Action-Dataset/Pose_to_SMPL/fit/output/NTU/picture'))
for filename in filenames:
-    images.append(imageio.imread('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02/'+filename))
-imageio.mimsave('fit_mesh.gif', images, duration=0.2)
+    images.append(imageio.imread('D:/OneDrive - sjtu.edu.cn/MVIG/Action-Dataset/Pose_to_SMPL/fit/output/NTU/picture/'+filename))
+imageio.mimsave('clapping_example.gif', images, duration=0.2)
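A more portable variant of the same script, using a relative path and relying on the zero-padded frame names from save_pic to keep sorted() in frame order (the directory is an assumption; point it at your own output folder):

```python
import os
import imageio

frame_dir = 'fit/output/NTU/picture'  # adjust to your output directory
images = [imageio.imread(os.path.join(frame_dir, fn))
          for fn in sorted(os.listdir(frame_dir))]
imageio.mimsave('clapping_example.gif', images, duration=0.2)
```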