Support HAA4D

This commit is contained in:
Yiming Dou
2022-10-01 21:42:12 +08:00
parent fd2ff6616f
commit 6ac10e22e0
12 changed files with 232 additions and 170 deletions

View File

@ -47,6 +47,7 @@ The SMPL human body layer for Pytorch is from the [smplpytorch](https://github.c
- [UTD-MHAD](https://personal.utdallas.edu/~kehtar/UTD-MHAD.html) - [UTD-MHAD](https://personal.utdallas.edu/~kehtar/UTD-MHAD.html)
- [Human3.6M](http://vision.imar.ro/human3.6m/description.php) - [Human3.6M](http://vision.imar.ro/human3.6m/description.php)
- [NTU](https://rose1.ntu.edu.sg/dataset/actionRecognition/) - [NTU](https://rose1.ntu.edu.sg/dataset/actionRecognition/)
- [HAA4D](https://cse.hkust.edu.hk/haa4d/dataset.html)
- Set the **DATASET.PATH** in the corresponding configuration file to the location of dataset. - Set the **DATASET.PATH** in the corresponding configuration file to the location of dataset.

View File

@ -1,4 +1,6 @@
import torch import torch
import random
import numpy as np
from smplpytorch.pytorch.smpl_layer import SMPL_Layer from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from display_utils import display_model from display_utils import display_model
@ -15,7 +17,7 @@ if __name__ == '__main__':
model_root='smplpytorch/native/models') model_root='smplpytorch/native/models')
# Generate random pose and shape parameters # Generate random pose and shape parameters
pose_params = torch.rand(batch_size, 72) * 0.2 pose_params = torch.rand(batch_size, 72) * 0.01
shape_params = torch.rand(batch_size, 10) * 0.03 shape_params = torch.rand(batch_size, 10) * 0.03
# GPU mode # GPU mode
@ -26,7 +28,6 @@ if __name__ == '__main__':
# Forward from the SMPL layer # Forward from the SMPL layer
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params) verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
print(Jtr)
# Draw output vertices and joints # Draw output vertices and joints
display_model( display_model(

View File

@ -1,3 +1,4 @@
from xml.parsers.expat import model
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection from mpl_toolkits.mplot3d.art3d import Poly3DCollection
@ -21,7 +22,8 @@ def display_model(
if ax is None: if ax is None:
fig = plt.figure() fig = plt.figure()
ax = fig.add_subplot(111, projection='3d') ax = fig.add_subplot(111, projection='3d')
verts, joints = model_info['verts'][batch_idx], model_info['joints'][batch_idx] verts = model_info['verts'][batch_idx]
joints = model_info['joints'][batch_idx]
if model_faces is None: if model_faces is None:
ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2) ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2)
elif not only_joint: elif not only_joint:

24
fit/configs/HAA4D.json Normal file
View File

@ -0,0 +1,24 @@
{
"MODEL": {
"GENDER": "neutral"
},
"TRAIN": {
"LEARNING_RATE": 1e-2,
"MAX_EPOCH": 1000,
"WRITE": 10,
"OPTIMIZE_SCALE": 1,
"OPTIMIZE_SHAPE": 1
},
"USE_GPU": 1,
"DATASET": {
"NAME": "HAA4D",
"PATH": "",
"TARGET_PATH": "",
"DATA_MAP": [
[0,0],[1,4],[2,1],[4,5],[5,2],[7,6],[8,3],
[12,9],[18,12],[19,15],[20,13],[21,16],
[15,10],[6,1]
]
},
"DEBUG": 0
}

View File

@ -20,19 +20,19 @@
0 0
], ],
[ [
1, 2,
12 12
], ],
[ [
2, 1,
16 16
], ],
[ [
4, 5,
13 13
], ],
[ [
5, 4,
17 17
], ],
[ [
@ -40,52 +40,51 @@
1 1
], ],
[ [
7, 8,
14 14
], ],
[ [
8, 7,
18 18
], ],
[ [
9, 9,
20 20
], ],
[ [
12, 12,
2 2
], ],
[ [
13, 14,
4 4
], ],
[ [
14, 13,
8 8
], ],
[ [
18, 19,
5 5
], ],
[ [
19, 18,
9 9
], ],
[ [
20, 21,
6 6
], ],
[ [
21, 20,
10 10
], ],
[ [
22, 23,
22 22
], ],
[ [
23, 22,
24 24
] ]
] ]

View File

@ -19,5 +19,7 @@ def load(name, path):
elif name == "Human3.6M": elif name == "Human3.6M":
return np.load(path, allow_pickle=True)[0::5] # down_sample return np.load(path, allow_pickle=True)[0::5] # down_sample
elif name == "NTU": elif name == "NTU":
return np.load(path, allow_pickle=True)[0::2]
elif name == "HAA4D":
return np.load(path, allow_pickle=True) return np.load(path, allow_pickle=True)

View File

@ -1,32 +1,33 @@
import os
import sys
sys.path.append(os.getcwd())
from meters import Meters
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from train import train
from transform import transform
from save import save_pic, save_params
from load import load
import torch import torch
import numpy as np import numpy as np
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from easydict import EasyDict as edict from easydict import EasyDict as edict
import time import time
import sys
import os
import logging import logging
import argparse import argparse
import json import json
sys.path.append(os.getcwd())
from load import load
from save import save_pic, save_params
from transform import transform
from train import train
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from meters import Meters
torch.backends.cudnn.enabled = True torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Fit SMPL') parser = argparse.ArgumentParser(description='Fit SMPL')
parser.add_argument('--exp', dest='exp', parser.add_argument('--exp', dest='exp',
help='Define exp name', help='Define exp name',
default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str) default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
parser.add_argument('--dataset_name', dest='dataset_name', parser.add_argument('--dataset_name', '-n', dest='dataset_name',
help='select dataset', help='select dataset',
default='', type=str) default='', type=str)
parser.add_argument('--dataset_path', dest='dataset_path', parser.add_argument('--dataset_path', dest='dataset_path',
@ -101,9 +102,12 @@ if __name__ == "__main__":
meters = Meters() meters = Meters()
file_num = 0 file_num = 0
for root, dirs, files in os.walk(cfg.DATASET.PATH): for root, dirs, files in os.walk(cfg.DATASET.PATH):
for file in files: for file in sorted(files):
if not 'baseball_swing' in file:
continue
file_num += 1 file_num += 1
logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files))) logger.info(
'Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
target = torch.from_numpy(transform(args.dataset_name, load(args.dataset_name, target = torch.from_numpy(transform(args.dataset_name, load(args.dataset_name,
os.path.join(root, file)))).float() os.path.join(root, file)))).float()
logger.info("target shape:{}".format(target.shape)) logger.info("target shape:{}".format(target.shape))
@ -115,7 +119,8 @@ if __name__ == "__main__":
logger.info("avg_loss:{:.4f}".format(meters.avg)) logger.info("avg_loss:{:.4f}".format(meters.avg))
save_params(res, file, logger, args.dataset_name) save_params(res, file, logger, args.dataset_name)
# save_pic(res,smpl_layer,file,logger,args.dataset_name,target) save_pic(res, smpl_layer, file, logger, args.dataset_name, target)
torch.cuda.empty_cache() torch.cuda.empty_cache()
logger.info("Fitting finished! Average loss: {:.9f}".format(meters.avg)) logger.info(
"Fitting finished! Average loss: {:.9f}".format(meters.avg))

View File

@ -19,7 +19,7 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
_, _, verts, Jtr = res _, _, verts, Jtr = res
file_name = re.split('[/.]', file)[-2] file_name = re.split('[/.]', file)[-2]
fit_path = "fit/output/{}/picture/{}".format(dataset_name, file_name) fit_path = "fit/output/{}/picture/{}".format(dataset_name, file_name)
create_dir_not_exist(fit_path) os.makedirs(fit_path,exist_ok=True)
logger.info('Saving pictures at {}'.format(fit_path)) logger.info('Saving pictures at {}'.format(fit_path))
for i in tqdm(range(Jtr.shape[0])): for i in tqdm(range(Jtr.shape[0])):
display_model( display_model(
@ -28,7 +28,7 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
model_faces=smpl_layer.th_faces, model_faces=smpl_layer.th_faces,
with_joints=True, with_joints=True,
kintree_table=smpl_layer.kintree_table, kintree_table=smpl_layer.kintree_table,
savepath=os.path.join(fit_path+"/frame_{}".format(i)), savepath=os.path.join(fit_path+"/frame_{:0>4d}".format(i)),
batch_idx=i, batch_idx=i,
show=False, show=False,
only_joint=True) only_joint=True)
@ -55,3 +55,21 @@ def save_params(res, file, logger, dataset_name):
with open(os.path.join((fit_path), with open(os.path.join((fit_path),
"{}_params.pkl".format(file_name)), 'wb') as f: "{}_params.pkl".format(file_name)), 'wb') as f:
pickle.dump(params, f) pickle.dump(params, f)
def save_single_pic(res, smpl_layer, epoch, logger, dataset_name, target):
_, _, verts, Jtr = res
fit_path = "fit/output/{}/picture".format(dataset_name)
create_dir_not_exist(fit_path)
logger.info('Saving pictures at {}'.format(fit_path))
display_model(
{'verts': verts.cpu().detach(),
'joints': Jtr.cpu().detach()},
model_faces=smpl_layer.th_faces,
with_joints=True,
kintree_table=smpl_layer.kintree_table,
savepath=fit_path+"/epoch_{:0>4d}".format(epoch),
batch_idx=60,
show=False,
only_joint=False)
logger.info('Picture saved')

View File

@ -6,6 +6,7 @@ import os
from tqdm import tqdm from tqdm import tqdm
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from save import save_single_pic
@ -25,8 +26,10 @@ def init(smpl_layer, target, device, cfg):
params["shape_params"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SHAPE) params["shape_params"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SHAPE)
params["scale"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SCALE) params["scale"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SCALE)
optimizer = optim.Adam([params["pose_params"], params["shape_params"], params["scale"]], optim_params = [{'params': params["pose_params"], 'lr': cfg.TRAIN.LEARNING_RATE},
lr=cfg.TRAIN.LEARNING_RATE) {'params': params["shape_params"], 'lr': cfg.TRAIN.LEARNING_RATE},
{'params': params["scale"], 'lr': cfg.TRAIN.LEARNING_RATE*10},]
optimizer = optim.Adam(optim_params)
index = {} index = {}
smpl_index = [] smpl_index = []
@ -51,11 +54,14 @@ def train(smpl_layer, target,
shape_params = params["shape_params"] shape_params = params["shape_params"]
scale = params["scale"] scale = params["scale"]
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)): with torch.no_grad():
# for epoch in range(cfg.TRAIN.MAX_EPOCH):
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params) verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100, params["scale"]*=(torch.max(torch.abs(target))/torch.max(torch.abs(Jtr)))
target.index_select(1, index["dataset_index"]) * 100 * scale)
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
loss = F.smooth_l1_loss(scale*Jtr.index_select(1, index["smpl_index"]),
target.index_select(1, index["dataset_index"]))
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
@ -67,12 +73,15 @@ def train(smpl_layer, target,
logger.info("Early stop at epoch {} !".format(epoch)) logger.info("Early stop at epoch {} !".format(epoch))
break break
if epoch % cfg.TRAIN.WRITE == 0: if epoch % cfg.TRAIN.WRITE == 0 or epoch<10:
# logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format( # logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
# epoch, float(loss),float(scale))) # epoch, float(loss),float(scale)))
print("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
epoch, float(loss),float(scale)))
writer.add_scalar('loss', float(loss), epoch) writer.add_scalar('loss', float(loss), epoch)
writer.add_scalar('learning_rate', float( writer.add_scalar('learning_rate', float(
optimizer.state_dict()['param_groups'][0]['lr']), epoch) optimizer.state_dict()['param_groups'][0]['lr']), epoch)
# save_single_pic(res,smpl_layer,epoch,logger,args.dataset_name,target)
logger.info('Train ended, min_loss = {:.4f}'.format( logger.info('Train ended, min_loss = {:.4f}'.format(
float(meters.min_loss))) float(meters.min_loss)))

View File

@ -5,7 +5,8 @@ rotate = {
'CMU_Mocap': [0.05, 0.05, 0.05], 'CMU_Mocap': [0.05, 0.05, 0.05],
'UTD_MHAD': [-1., 1., -1.], 'UTD_MHAD': [-1., 1., -1.],
'Human3.6M': [-0.001, -0.001, 0.001], 'Human3.6M': [-0.001, -0.001, 0.001],
'NTU': [-1., 1., -1.] 'NTU': [1., 1., -1.],
'HAA4D': [1., -1., -1.],
} }

View File

@ -1,7 +1,7 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import imageio, os import imageio, os
images = [] images = []
filenames = sorted(fn for fn in os.listdir('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02') ) filenames = sorted(fn for fn in os.listdir('D:/OneDrive - sjtu.edu.cn/MVIG/Action-Dataset/Pose_to_SMPL/fit/output/NTU/picture') )
for filename in filenames: for filename in filenames:
images.append(imageio.imread('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02/'+filename)) images.append(imageio.imread('D:/OneDrive - sjtu.edu.cn/MVIG/Action-Dataset/Pose_to_SMPL/fit/output/NTU/picture/'+filename))
imageio.mimsave('fit_mesh.gif', images, duration=0.2) imageio.mimsave('clapping_example.gif', images, duration=0.2)