update code

Author: Iridoudou
Date: 2021-08-05 16:22:06 +08:00
parent 791f02f280
commit bc0766fc76
11 changed files with 47 additions and 17 deletions

View File

@@ -24,14 +24,13 @@ def display_model(
     verts, joints = model_info['verts'][batch_idx], model_info['joints'][batch_idx]
     if model_faces is None:
         ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2)
-    else:
+    elif not only_joint:
         mesh = Poly3DCollection(verts[model_faces], alpha=0.2)
         face_color = (141 / 255, 184 / 255, 226 / 255)
         edge_color = (50 / 255, 50 / 255, 50 / 255)
         mesh.set_edgecolor(edge_color)
         mesh.set_facecolor(face_color)
-        if not only_joint:
-            ax.add_collection3d(mesh)
+        ax.add_collection3d(mesh)
     if with_joints:
         draw_skeleton(joints, kintree_table=kintree_table, ax=ax)
     ax.set_xlabel('X')
@@ -43,10 +42,11 @@ def display_model(
     ax.view_init(azim=-90, elev=100)
     fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
     if savepath:
-        print('Saving figure at {}.'.format(savepath))
+        # print('Saving figure at {}.'.format(savepath))
         plt.savefig(savepath, bbox_inches='tight', pad_inches=0)
     if show:
         plt.show()
+    plt.close()
     return ax
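
Note on the plt.close() added above: display_model is called once per frame by save_pic further down, and without closing the figure Matplotlib keeps every figure alive, so memory grows with the number of saved frames. A minimal standalone sketch of the save-then-close pattern (not code from this repo):

import matplotlib
matplotlib.use('Agg')  # render off-screen, no window needed
import matplotlib.pyplot as plt

for i in range(100):
    plt.figure()
    plt.plot([0, 1], [0, i])
    plt.savefig('frame_{}.png'.format(i), bbox_inches='tight', pad_inches=0)
    plt.close()  # free the current figure so memory stays flat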

BIN fit.gif (new file, 4.1 MiB)

BIN fit.rar (new file)

View File

@@ -4,7 +4,7 @@
     },
     "TRAIN": {
         "LEARNING_RATE":2e-2,
-        "MAX_EPOCH": 5,
+        "MAX_EPOCH": 500,
         "WRITE": 1,
         "SAVE": 10,
         "BATCH_SIZE": 1,

View File

@@ -26,7 +26,7 @@ from display_utils import display_model
 from smplpytorch.pytorch.smpl_layer import SMPL_Layer
 from train import train
 from transform import transform
-from save import save_pic
+from save import save_pic,save_params
 torch.backends.cudnn.benchmark=True
 def parse_args():
@@ -104,7 +104,7 @@ if __name__ == "__main__":
             logger.info('Processing file: {}'.format(file))
             target_path=os.path.join(root,file)
-            target = np.array(transform(np.load(cfg.TARGET_PATH)))
+            target = np.array(transform(np.load(target_path)))
             logger.info('File shape: {}'.format(target.shape))
             target = torch.from_numpy(target).float()
@@ -112,5 +112,6 @@ if __name__ == "__main__":
                         logger,writer,device,
                         args,cfg)
-            # save_pic(target,res,smpl_layer,file)
+            # save_pic(target,res,smpl_layer,file,logger)
+            save_params(res,file,logger)

View File

@@ -1,6 +1,10 @@
 import sys
 import os
 import re
+from tqdm import tqdm
+import numpy as np
+import json
 sys.path.append(os.getcwd())
 from display_utils import display_model
@@ -9,14 +13,15 @@ def create_dir_not_exist(path):
     if not os.path.exists(path):
         os.mkdir(path)
-def save_pic(target, res, smpl_layer, file):
+def save_pic(target, res, smpl_layer, file,logger):
     pose_params, shape_params, verts, Jtr = res
     name=re.split('[/.]',file)[-2]
     gt_path="fit/output/HumanAct12/picture/gt/{}".format(name)
     fit_path="fit/output/HumanAct12/picture/fit/{}".format(name)
     create_dir_not_exist(gt_path)
     create_dir_not_exist(fit_path)
-    for i in range(target.shape[0]):
+    logger.info('Saving pictures at {} and {}'.format(gt_path,fit_path))
+    for i in tqdm(range(target.shape[0])):
         display_model(
             {'verts': verts.cpu().detach(),
              'joints': target.cpu().detach()},
@@ -36,3 +41,26 @@ def save_pic(target, res, smpl_layer, file):
             savepath=os.path.join(fit_path+"/frame_{}".format(i)),
             batch_idx=i,
             show=False)
+    logger.info('Pictures saved')
+
+def save_params(res,file,logger):
+    pose_params, shape_params, verts, Jtr = res
+    name=re.split('[/.]',file)[-2]
+    fit_path="fit/output/HumanAct12/params/"
+    create_dir_not_exist(fit_path)
+    logger.info('Saving params at {}'.format(fit_path))
+    pose_params=pose_params.cpu().detach()
+    pose_params=pose_params.numpy().tolist()
+    shape_params=shape_params.cpu().detach()
+    shape_params=shape_params.numpy().tolist()
+    Jtr=Jtr.cpu().detach()
+    Jtr=Jtr.numpy().tolist()
+    params={}
+    params["pose_params"]=pose_params
+    params["shape_params"]=shape_params
+    params["Jtr"]=Jtr
+    f=open(os.path.join((fit_path),
+           "{}_params.json".format(name)),'w')
+    json.dump(params,f)
+    logger.info('Params saved')
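
The new save_params writes one JSON file per sequence with the keys pose_params, shape_params, and Jtr. A minimal sketch of reading such a file back; the sequence name P01G01R01F0449T0505A0201 is only an example borrowed from the gif script at the end of this commit:

import json
import numpy as np

path = 'fit/output/HumanAct12/params/P01G01R01F0449T0505A0201_params.json'
with open(path) as f:
    params = json.load(f)

pose_params = np.array(params['pose_params'])    # (num_frames, 72) axis-angle SMPL pose
shape_params = np.array(params['shape_params'])  # (num_frames, 10) SMPL shape coefficients
Jtr = np.array(params['Jtr'])                    # fitted joint positions per frame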

View File

@@ -30,7 +30,7 @@ def train(smpl_layer, target,
           args, cfg):
     res = []
     pose_params = torch.rand(target.shape[0], 72) * 0.0
-    shape_params = torch.rand(target.shape[0], 10) * 0.1
+    shape_params = torch.rand(target.shape[0], 10) * 0.03
     scale = torch.ones([1])
     smpl_layer = smpl_layer.to(device)
@@ -41,9 +41,10 @@ def train(smpl_layer, target,
     pose_params.requires_grad = True
     shape_params.requires_grad = True
-    scale.requires_grad = True
+    scale.requires_grad = False
+    smpl_layer.requires_grad = False
-    optimizer = optim.Adam([pose_params],
+    optimizer = optim.Adam([pose_params, shape_params],
                            lr=cfg.TRAIN.LEARNING_RATE)
     min_loss = float('inf')
@@ -62,5 +63,5 @@ def train(smpl_layer, target,
         writer.add_scalar('loss', float(loss), epoch)
         writer.add_scalar('learning_rate', float(
             optimizer.state_dict()['param_groups'][0]['lr']), epoch)
-    logger.info('Train ended, loss = {:.9f}'.format(float(loss)))
+    logger.info('Train ended, min_loss = {:.9f}'.format(float(min_loss)))
     return res
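
After this change only pose_params and shape_params are passed to Adam; scale stays at 1 and smpl_layer is not updated. A minimal standalone sketch of that fitting setup with a dummy target and an illustrative MSE loss; the repo's actual loss term and SMPL_Layer arguments may differ:

import torch
import torch.optim as optim
from smplpytorch.pytorch.smpl_layer import SMPL_Layer

num_frames = 60                                  # illustrative value
target = torch.zeros(num_frames, 24, 3)          # dummy joint targets
smpl_layer = SMPL_Layer(model_root='smplpytorch/native/models')  # needs the SMPL model files

pose_params = torch.zeros(num_frames, 72, requires_grad=True)
shape_params = (torch.rand(num_frames, 10) * 0.03).requires_grad_()
scale = torch.ones(1)                            # fixed, requires_grad stays False

optimizer = optim.Adam([pose_params, shape_params], lr=2e-2)
for epoch in range(500):                         # MAX_EPOCH from the config change above
    optimizer.zero_grad()
    verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
    loss = torch.nn.functional.mse_loss(scale * Jtr, target)
    loss.backward()
    optimizer.step()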

BIN gif.gif (new file, 3.2 MiB)

BIN gif.rar (new file)

BIN gt.gif (new file, 4.4 MiB)

View File

@@ -1,7 +1,7 @@
 import matplotlib.pyplot as plt
 import imageio, os
 images = []
-filenames = sorted(fn for fn in os.listdir('./output/') )
+filenames = sorted(fn for fn in os.listdir('./fit/output/HumanAct12/picture/fit/P01G01R01F0449T0505A0201') )
 for filename in filenames:
-    images.append(imageio.imread('./output/'+filename))
+    images.append(imageio.imread('./fit/output/HumanAct12/picture/fit/P01G01R01F0449T0505A0201/'+filename))
-imageio.mimsave('./output/gif.gif', images, duration=0.5)
+imageio.mimsave('./fit.gif', images, duration=0.3)