Support HAA4D

This commit is contained in:
Yiming Dou
2022-10-01 21:42:12 +08:00
parent fd2ff6616f
commit 6ac10e22e0
12 changed files with 232 additions and 170 deletions

View File

@ -47,6 +47,7 @@ The SMPL human body layer for Pytorch is from the [smplpytorch](https://github.c
- [UTD-MHAD](https://personal.utdallas.edu/~kehtar/UTD-MHAD.html)
- [Human3.6M](http://vision.imar.ro/human3.6m/description.php)
- [NTU](https://rose1.ntu.edu.sg/dataset/actionRecognition/)
- [HAA4D](https://cse.hkust.edu.hk/haa4d/dataset.html)
- Set the **DATASET.PATH** in the corresponding configuration file to the location of the dataset.

View File

@ -1,4 +1,6 @@
import torch
import random
import numpy as np
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from display_utils import display_model
@ -15,7 +17,7 @@ if __name__ == '__main__':
model_root='smplpytorch/native/models')
# Generate random pose and shape parameters
pose_params = torch.rand(batch_size, 72) * 0.2
pose_params = torch.rand(batch_size, 72) * 0.01
shape_params = torch.rand(batch_size, 10) * 0.03
# GPU mode
@ -26,7 +28,6 @@ if __name__ == '__main__':
# Forward from the SMPL layer
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
print(Jtr)
# Draw output vertices and joints
display_model(

View File

@ -1,3 +1,4 @@
from xml.parsers.expat import model
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
@ -21,7 +22,8 @@ def display_model(
if ax is None:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
verts, joints = model_info['verts'][batch_idx], model_info['joints'][batch_idx]
verts = model_info['verts'][batch_idx]
joints = model_info['joints'][batch_idx]
if model_faces is None:
ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2)
elif not only_joint:

24
fit/configs/HAA4D.json Normal file
View File

@ -0,0 +1,24 @@
{
"MODEL": {
"GENDER": "neutral"
},
"TRAIN": {
"LEARNING_RATE": 1e-2,
"MAX_EPOCH": 1000,
"WRITE": 10,
"OPTIMIZE_SCALE":1,
"OPTIMIZE_SHAPE":1
},
"USE_GPU": 1,
"DATASET": {
"NAME": "HAA4D",
"PATH": "../NTU RGB+D/result",
"TARGET_PATH": "",
"DATA_MAP": [
[0,0],[1,4],[2,1],[4,5],[5,2],[7,6],[8,3],
[12,9],[18,12],[19,15],[20,13],[21,16],
[15,10],[6,1]
]
},
"DEBUG": 0
}

View File

@ -20,19 +20,19 @@
0
],
[
1,
2,
12
],
[
2,
1,
16
],
[
4,
5,
13
],
[
5,
4,
17
],
[
@ -40,52 +40,51 @@
1
],
[
7,
8,
14
],
[
8,
7,
18
],
[
9,
20
],
[
12,
2
],
[
13,
14,
4
],
[
14,
13,
8
],
[
18,
19,
5
],
[
19,
18,
9
],
[
20,
21,
6
],
[
21,
20,
10
],
[
22,
23,
22
],
[
23,
22,
24
]
]

View File

@ -19,5 +19,7 @@ def load(name, path):
elif name == "Human3.6M":
return np.load(path, allow_pickle=True)[0::5] # down_sample
elif name == "NTU":
return np.load(path, allow_pickle=True)[0::2]
elif name == "HAA4D":
return np.load(path, allow_pickle=True)

View File

@ -1,32 +1,33 @@
import os
import sys
sys.path.append(os.getcwd())
from meters import Meters
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from train import train
from transform import transform
from save import save_pic, save_params
from load import load
import torch
import numpy as np
from tensorboardX import SummaryWriter
from easydict import EasyDict as edict
import time
import sys
import os
import logging
import argparse
import json
sys.path.append(os.getcwd())
from load import load
from save import save_pic, save_params
from transform import transform
from train import train
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from meters import Meters
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
def parse_args():
parser = argparse.ArgumentParser(description='Fit SMPL')
parser.add_argument('--exp', dest='exp',
help='Define exp name',
default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
parser.add_argument('--dataset_name', dest='dataset_name',
parser.add_argument('--dataset_name', '-n', dest='dataset_name',
help='select dataset',
default='', type=str)
parser.add_argument('--dataset_path', dest='dataset_path',
@ -101,9 +102,12 @@ if __name__ == "__main__":
meters = Meters()
file_num = 0
for root, dirs, files in os.walk(cfg.DATASET.PATH):
for file in files:
for file in sorted(files):
if not 'baseball_swing' in file:
continue
file_num += 1
logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
logger.info(
'Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
target = torch.from_numpy(transform(args.dataset_name, load(args.dataset_name,
os.path.join(root, file)))).float()
logger.info("target shape:{}".format(target.shape))
@ -115,7 +119,8 @@ if __name__ == "__main__":
logger.info("avg_loss:{:.4f}".format(meters.avg))
save_params(res, file, logger, args.dataset_name)
# save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
save_pic(res, smpl_layer, file, logger, args.dataset_name, target)
torch.cuda.empty_cache()
logger.info("Fitting finished! Average loss: {:.9f}".format(meters.avg))
logger.info(
"Fitting finished! Average loss: {:.9f}".format(meters.avg))

View File

@ -19,7 +19,7 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
_, _, verts, Jtr = res
file_name = re.split('[/.]', file)[-2]
fit_path = "fit/output/{}/picture/{}".format(dataset_name, file_name)
create_dir_not_exist(fit_path)
os.makedirs(fit_path,exist_ok=True)
logger.info('Saving pictures at {}'.format(fit_path))
for i in tqdm(range(Jtr.shape[0])):
display_model(
@ -28,7 +28,7 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
model_faces=smpl_layer.th_faces,
with_joints=True,
kintree_table=smpl_layer.kintree_table,
savepath=os.path.join(fit_path+"/frame_{}".format(i)),
savepath=os.path.join(fit_path+"/frame_{:0>4d}".format(i)),
batch_idx=i,
show=False,
only_joint=True)
@ -55,3 +55,21 @@ def save_params(res, file, logger, dataset_name):
with open(os.path.join((fit_path),
"{}_params.pkl".format(file_name)), 'wb') as f:
pickle.dump(params, f)
def save_single_pic(res, smpl_layer, epoch, logger, dataset_name, target):
    """Render one frame of the current fit and save it as a per-epoch picture.

    Args:
        res: 4-tuple; only the last two items (vertices, joints) are used.
        smpl_layer: Provides ``th_faces`` and ``kintree_table``, which are
            forwarded to ``display_model``.
        epoch: Integer epoch number, zero-padded to 4 digits in the file name.
        logger: Logger used for progress messages.
        dataset_name: Selects the output folder
            ``fit/output/<dataset_name>/picture``.
        target: Unused here; kept so the signature stays unchanged for callers.
    """
    _, _, verts, Jtr = res
    fit_path = "fit/output/{}/picture".format(dataset_name)
    # Same directory-creation call that save_pic uses; it also creates missing
    # parent folders and is a no-op when the directory already exists.
    os.makedirs(fit_path, exist_ok=True)
    logger.info('Saving pictures at {}'.format(fit_path))
    # Frame 60 is the preferred snapshot, but display_model indexes the first
    # dimension with batch_idx, so fall back to the last frame for sequences
    # shorter than 61 frames instead of raising an IndexError.
    frame_idx = min(60, verts.shape[0] - 1)
    display_model(
        {'verts': verts.cpu().detach(),
         'joints': Jtr.cpu().detach()},
        model_faces=smpl_layer.th_faces,
        with_joints=True,
        kintree_table=smpl_layer.kintree_table,
        savepath=fit_path + "/epoch_{:0>4d}".format(epoch),
        batch_idx=frame_idx,
        show=False,
        only_joint=False)
    logger.info('Picture saved')

View File

@ -6,6 +6,7 @@ import os
from tqdm import tqdm
sys.path.append(os.getcwd())
from save import save_single_pic
@ -25,8 +26,10 @@ def init(smpl_layer, target, device, cfg):
params["shape_params"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SHAPE)
params["scale"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SCALE)
optimizer = optim.Adam([params["pose_params"], params["shape_params"], params["scale"]],
lr=cfg.TRAIN.LEARNING_RATE)
optim_params = [{'params': params["pose_params"], 'lr': cfg.TRAIN.LEARNING_RATE},
{'params': params["shape_params"], 'lr': cfg.TRAIN.LEARNING_RATE},
{'params': params["scale"], 'lr': cfg.TRAIN.LEARNING_RATE*10},]
optimizer = optim.Adam(optim_params)
index = {}
smpl_index = []
@ -51,11 +54,14 @@ def train(smpl_layer, target,
shape_params = params["shape_params"]
scale = params["scale"]
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
# for epoch in range(cfg.TRAIN.MAX_EPOCH):
with torch.no_grad():
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100,
target.index_select(1, index["dataset_index"]) * 100 * scale)
params["scale"]*=(torch.max(torch.abs(target))/torch.max(torch.abs(Jtr)))
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
loss = F.smooth_l1_loss(scale*Jtr.index_select(1, index["smpl_index"]),
target.index_select(1, index["dataset_index"]))
optimizer.zero_grad()
loss.backward()
optimizer.step()
@ -67,12 +73,15 @@ def train(smpl_layer, target,
logger.info("Early stop at epoch {} !".format(epoch))
break
if epoch % cfg.TRAIN.WRITE == 0:
if epoch % cfg.TRAIN.WRITE == 0 or epoch<10:
# logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
# epoch, float(loss),float(scale)))
print("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
epoch, float(loss),float(scale)))
writer.add_scalar('loss', float(loss), epoch)
writer.add_scalar('learning_rate', float(
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
# save_single_pic(res,smpl_layer,epoch,logger,args.dataset_name,target)
logger.info('Train ended, min_loss = {:.4f}'.format(
float(meters.min_loss)))

View File

@ -5,7 +5,8 @@ rotate = {
'CMU_Mocap': [0.05, 0.05, 0.05],
'UTD_MHAD': [-1., 1., -1.],
'Human3.6M': [-0.001, -0.001, 0.001],
'NTU': [-1., 1., -1.]
'NTU': [1., 1., -1.],
'HAA4D': [1., -1., -1.],
}

View File

@ -1,7 +1,7 @@
import matplotlib.pyplot as plt
import imageio, os
images = []
filenames = sorted(fn for fn in os.listdir('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02') )
filenames = sorted(fn for fn in os.listdir('D:/OneDrive - sjtu.edu.cn/MVIG/Action-Dataset/Pose_to_SMPL/fit/output/NTU/picture') )
for filename in filenames:
images.append(imageio.imread('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02/'+filename))
imageio.mimsave('fit_mesh.gif', images, duration=0.2)
images.append(imageio.imread('D:/OneDrive - sjtu.edu.cn/MVIG/Action-Dataset/Pose_to_SMPL/fit/output/NTU/picture/'+filename))
imageio.mimsave('clapping_example.gif', images, duration=0.2)