update code
This commit is contained in:
@ -24,14 +24,13 @@ def display_model(
|
|||||||
verts, joints = model_info['verts'][batch_idx], model_info['joints'][batch_idx]
|
verts, joints = model_info['verts'][batch_idx], model_info['joints'][batch_idx]
|
||||||
if model_faces is None:
|
if model_faces is None:
|
||||||
ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2)
|
ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2)
|
||||||
else:
|
elif not only_joint:
|
||||||
mesh = Poly3DCollection(verts[model_faces], alpha=0.2)
|
mesh = Poly3DCollection(verts[model_faces], alpha=0.2)
|
||||||
face_color = (141 / 255, 184 / 255, 226 / 255)
|
face_color = (141 / 255, 184 / 255, 226 / 255)
|
||||||
edge_color = (50 / 255, 50 / 255, 50 / 255)
|
edge_color = (50 / 255, 50 / 255, 50 / 255)
|
||||||
mesh.set_edgecolor(edge_color)
|
mesh.set_edgecolor(edge_color)
|
||||||
mesh.set_facecolor(face_color)
|
mesh.set_facecolor(face_color)
|
||||||
if not only_joint:
|
ax.add_collection3d(mesh)
|
||||||
ax.add_collection3d(mesh)
|
|
||||||
if with_joints:
|
if with_joints:
|
||||||
draw_skeleton(joints, kintree_table=kintree_table, ax=ax)
|
draw_skeleton(joints, kintree_table=kintree_table, ax=ax)
|
||||||
ax.set_xlabel('X')
|
ax.set_xlabel('X')
|
||||||
@ -43,10 +42,11 @@ def display_model(
|
|||||||
ax.view_init(azim=-90, elev=100)
|
ax.view_init(azim=-90, elev=100)
|
||||||
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
|
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
|
||||||
if savepath:
|
if savepath:
|
||||||
print('Saving figure at {}.'.format(savepath))
|
# print('Saving figure at {}.'.format(savepath))
|
||||||
plt.savefig(savepath, bbox_inches='tight', pad_inches=0)
|
plt.savefig(savepath, bbox_inches='tight', pad_inches=0)
|
||||||
if show:
|
if show:
|
||||||
plt.show()
|
plt.show()
|
||||||
|
plt.close()
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
},
|
},
|
||||||
"TRAIN": {
|
"TRAIN": {
|
||||||
"LEARNING_RATE":2e-2,
|
"LEARNING_RATE":2e-2,
|
||||||
"MAX_EPOCH": 5,
|
"MAX_EPOCH": 500,
|
||||||
"WRITE": 1,
|
"WRITE": 1,
|
||||||
"SAVE": 10,
|
"SAVE": 10,
|
||||||
"BATCH_SIZE": 1,
|
"BATCH_SIZE": 1,
|
||||||
|
|||||||
@ -26,7 +26,7 @@ from display_utils import display_model
|
|||||||
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
||||||
from train import train
|
from train import train
|
||||||
from transform import transform
|
from transform import transform
|
||||||
from save import save_pic
|
from save import save_pic,save_params
|
||||||
torch.backends.cudnn.benchmark=True
|
torch.backends.cudnn.benchmark=True
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
@ -104,7 +104,7 @@ if __name__ == "__main__":
|
|||||||
logger.info('Processing file: {}'.format(file))
|
logger.info('Processing file: {}'.format(file))
|
||||||
target_path=os.path.join(root,file)
|
target_path=os.path.join(root,file)
|
||||||
|
|
||||||
target = np.array(transform(np.load(cfg.TARGET_PATH)))
|
target = np.array(transform(np.load(target_path)))
|
||||||
logger.info('File shape: {}'.format(target.shape))
|
logger.info('File shape: {}'.format(target.shape))
|
||||||
target = torch.from_numpy(target).float()
|
target = torch.from_numpy(target).float()
|
||||||
|
|
||||||
@ -112,5 +112,6 @@ if __name__ == "__main__":
|
|||||||
logger,writer,device,
|
logger,writer,device,
|
||||||
args,cfg)
|
args,cfg)
|
||||||
|
|
||||||
# save_pic(target,res,smpl_layer,file)
|
# save_pic(target,res,smpl_layer,file,logger)
|
||||||
|
save_params(res,file,logger)
|
||||||
|
|
||||||
@ -1,6 +1,10 @@
|
|||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
from display_utils import display_model
|
from display_utils import display_model
|
||||||
@ -9,14 +13,15 @@ def create_dir_not_exist(path):
|
|||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
os.mkdir(path)
|
os.mkdir(path)
|
||||||
|
|
||||||
def save_pic(target, res, smpl_layer, file):
|
def save_pic(target, res, smpl_layer, file,logger):
|
||||||
pose_params, shape_params, verts, Jtr = res
|
pose_params, shape_params, verts, Jtr = res
|
||||||
name=re.split('[/.]',file)[-2]
|
name=re.split('[/.]',file)[-2]
|
||||||
gt_path="fit/output/HumanAct12/picture/gt/{}".format(name)
|
gt_path="fit/output/HumanAct12/picture/gt/{}".format(name)
|
||||||
fit_path="fit/output/HumanAct12/picture/fit/{}".format(name)
|
fit_path="fit/output/HumanAct12/picture/fit/{}".format(name)
|
||||||
create_dir_not_exist(gt_path)
|
create_dir_not_exist(gt_path)
|
||||||
create_dir_not_exist(fit_path)
|
create_dir_not_exist(fit_path)
|
||||||
for i in range(target.shape[0]):
|
logger.info('Saving pictures at {} and {}'.format(gt_path,fit_path))
|
||||||
|
for i in tqdm(range(target.shape[0])):
|
||||||
display_model(
|
display_model(
|
||||||
{'verts': verts.cpu().detach(),
|
{'verts': verts.cpu().detach(),
|
||||||
'joints': target.cpu().detach()},
|
'joints': target.cpu().detach()},
|
||||||
@ -36,3 +41,26 @@ def save_pic(target, res, smpl_layer, file):
|
|||||||
savepath=os.path.join(fit_path+"/frame_{}".format(i)),
|
savepath=os.path.join(fit_path+"/frame_{}".format(i)),
|
||||||
batch_idx=i,
|
batch_idx=i,
|
||||||
show=False)
|
show=False)
|
||||||
|
logger.info('Pictures saved')
|
||||||
|
|
||||||
|
def save_params(res,file,logger):
|
||||||
|
pose_params, shape_params, verts, Jtr = res
|
||||||
|
name=re.split('[/.]',file)[-2]
|
||||||
|
fit_path="fit/output/HumanAct12/params/"
|
||||||
|
create_dir_not_exist(fit_path)
|
||||||
|
logger.info('Saving params at {}'.format(fit_path))
|
||||||
|
pose_params=pose_params.cpu().detach()
|
||||||
|
pose_params=pose_params.numpy().tolist()
|
||||||
|
shape_params=shape_params.cpu().detach()
|
||||||
|
shape_params=shape_params.numpy().tolist()
|
||||||
|
Jtr=Jtr.cpu().detach()
|
||||||
|
Jtr=Jtr.numpy().tolist()
|
||||||
|
params={}
|
||||||
|
params["pose_params"]=pose_params
|
||||||
|
params["shape_params"]=shape_params
|
||||||
|
params["Jtr"]=Jtr
|
||||||
|
f=open(os.path.join((fit_path),
|
||||||
|
"{}_params.json".format(name)),'w')
|
||||||
|
json.dump(params,f)
|
||||||
|
logger.info('Params saved')
|
||||||
|
|
||||||
|
|||||||
@ -30,7 +30,7 @@ def train(smpl_layer, target,
|
|||||||
args, cfg):
|
args, cfg):
|
||||||
res = []
|
res = []
|
||||||
pose_params = torch.rand(target.shape[0], 72) * 0.0
|
pose_params = torch.rand(target.shape[0], 72) * 0.0
|
||||||
shape_params = torch.rand(target.shape[0], 10) * 0.1
|
shape_params = torch.rand(target.shape[0], 10) * 0.03
|
||||||
scale = torch.ones([1])
|
scale = torch.ones([1])
|
||||||
|
|
||||||
smpl_layer = smpl_layer.to(device)
|
smpl_layer = smpl_layer.to(device)
|
||||||
@ -41,9 +41,10 @@ def train(smpl_layer, target,
|
|||||||
|
|
||||||
pose_params.requires_grad = True
|
pose_params.requires_grad = True
|
||||||
shape_params.requires_grad = True
|
shape_params.requires_grad = True
|
||||||
scale.requires_grad = True
|
scale.requires_grad = False
|
||||||
|
smpl_layer.requires_grad = False
|
||||||
|
|
||||||
optimizer = optim.Adam([pose_params],
|
optimizer = optim.Adam([pose_params, shape_params],
|
||||||
lr=cfg.TRAIN.LEARNING_RATE)
|
lr=cfg.TRAIN.LEARNING_RATE)
|
||||||
|
|
||||||
min_loss = float('inf')
|
min_loss = float('inf')
|
||||||
@ -62,5 +63,5 @@ def train(smpl_layer, target,
|
|||||||
writer.add_scalar('loss', float(loss), epoch)
|
writer.add_scalar('loss', float(loss), epoch)
|
||||||
writer.add_scalar('learning_rate', float(
|
writer.add_scalar('learning_rate', float(
|
||||||
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
|
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
|
||||||
logger.info('Train ended, loss = {:.9f}'.format(float(loss)))
|
logger.info('Train ended, min_loss = {:.9f}'.format(float(min_loss)))
|
||||||
return res
|
return res
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import imageio, os
|
import imageio, os
|
||||||
images = []
|
images = []
|
||||||
filenames = sorted(fn for fn in os.listdir('./output/') )
|
filenames = sorted(fn for fn in os.listdir('./fit/output/HumanAct12/picture/fit/P01G01R01F0449T0505A0201') )
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
images.append(imageio.imread('./output/'+filename))
|
images.append(imageio.imread('./fit/output/HumanAct12/picture/fit/P01G01R01F0449T0505A0201/'+filename))
|
||||||
imageio.mimsave('./output/gif.gif', images, duration=0.5)
|
imageio.mimsave('./fit.gif', images, duration=0.3)
|
||||||
Reference in New Issue
Block a user