diff --git a/display_utils.py b/display_utils.py index 8b05b71..ce06d41 100644 --- a/display_utils.py +++ b/display_utils.py @@ -24,14 +24,13 @@ def display_model( verts, joints = model_info['verts'][batch_idx], model_info['joints'][batch_idx] if model_faces is None: ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2) - else: + elif not only_joint: mesh = Poly3DCollection(verts[model_faces], alpha=0.2) face_color = (141 / 255, 184 / 255, 226 / 255) edge_color = (50 / 255, 50 / 255, 50 / 255) mesh.set_edgecolor(edge_color) mesh.set_facecolor(face_color) - if not only_joint: - ax.add_collection3d(mesh) + ax.add_collection3d(mesh) if with_joints: draw_skeleton(joints, kintree_table=kintree_table, ax=ax) ax.set_xlabel('X') @@ -43,10 +42,11 @@ def display_model( ax.view_init(azim=-90, elev=100) fig.subplots_adjust(left=0, right=1, bottom=0, top=1) if savepath: - print('Saving figure at {}.'.format(savepath)) + # print('Saving figure at {}.'.format(savepath)) plt.savefig(savepath, bbox_inches='tight', pad_inches=0) if show: plt.show() + plt.close() return ax diff --git a/fit.gif b/fit.gif new file mode 100644 index 0000000..02351ad Binary files /dev/null and b/fit.gif differ diff --git a/fit.rar b/fit.rar new file mode 100644 index 0000000..526d51f Binary files /dev/null and b/fit.rar differ diff --git a/fit/configs/config.json b/fit/configs/config.json index add2030..b7d31c4 100644 --- a/fit/configs/config.json +++ b/fit/configs/config.json @@ -4,7 +4,7 @@ }, "TRAIN": { "LEARNING_RATE":2e-2, - "MAX_EPOCH": 5, + "MAX_EPOCH": 500, "WRITE": 1, "SAVE": 10, "BATCH_SIZE": 1, diff --git a/fit/tools/main.py b/fit/tools/main.py index 37bc664..6a56086 100644 --- a/fit/tools/main.py +++ b/fit/tools/main.py @@ -26,7 +26,7 @@ from display_utils import display_model from smplpytorch.pytorch.smpl_layer import SMPL_Layer from train import train from transform import transform -from save import save_pic +from save import save_pic,save_params torch.backends.cudnn.benchmark=True def parse_args(): @@ -104,7 +104,7 @@ if __name__ == "__main__": logger.info('Processing file: {}'.format(file)) target_path=os.path.join(root,file) - target = np.array(transform(np.load(cfg.TARGET_PATH))) + target = np.array(transform(np.load(target_path))) logger.info('File shape: {}'.format(target.shape)) target = torch.from_numpy(target).float() @@ -112,5 +112,6 @@ if __name__ == "__main__": logger,writer,device, args,cfg) - # save_pic(target,res,smpl_layer,file) + # save_pic(target,res,smpl_layer,file,logger) + save_params(res,file,logger) \ No newline at end of file diff --git a/fit/tools/save.py b/fit/tools/save.py index 0257ec9..2dc5e2a 100644 --- a/fit/tools/save.py +++ b/fit/tools/save.py @@ -1,6 +1,10 @@ import sys import os import re +from tqdm import tqdm +import numpy as np +import json + sys.path.append(os.getcwd()) from display_utils import display_model @@ -9,14 +13,15 @@ def create_dir_not_exist(path): if not os.path.exists(path): os.mkdir(path) -def save_pic(target, res, smpl_layer, file): +def save_pic(target, res, smpl_layer, file,logger): pose_params, shape_params, verts, Jtr = res name=re.split('[/.]',file)[-2] gt_path="fit/output/HumanAct12/picture/gt/{}".format(name) fit_path="fit/output/HumanAct12/picture/fit/{}".format(name) create_dir_not_exist(gt_path) create_dir_not_exist(fit_path) - for i in range(target.shape[0]): + logger.info('Saving pictures at {} and {}'.format(gt_path,fit_path)) + for i in tqdm(range(target.shape[0])): display_model( {'verts': verts.cpu().detach(), 'joints': target.cpu().detach()}, @@ -36,3 +41,26 @@ def save_pic(target, res, smpl_layer, file): savepath=os.path.join(fit_path+"/frame_{}".format(i)), batch_idx=i, show=False) + logger.info('Pictures saved') + +def save_params(res,file,logger): + pose_params, shape_params, verts, Jtr = res + name=re.split('[/.]',file)[-2] + fit_path="fit/output/HumanAct12/params/" + create_dir_not_exist(fit_path) + logger.info('Saving params at {}'.format(fit_path)) + pose_params=pose_params.cpu().detach() + pose_params=pose_params.numpy().tolist() + shape_params=shape_params.cpu().detach() + shape_params=shape_params.numpy().tolist() + Jtr=Jtr.cpu().detach() + Jtr=Jtr.numpy().tolist() + params={} + params["pose_params"]=pose_params + params["shape_params"]=shape_params + params["Jtr"]=Jtr + f=open(os.path.join((fit_path), + "{}_params.json".format(name)),'w') + json.dump(params,f) + logger.info('Params saved') + diff --git a/fit/tools/train.py b/fit/tools/train.py index 8c7cbc9..a75bd20 100644 --- a/fit/tools/train.py +++ b/fit/tools/train.py @@ -30,7 +30,7 @@ def train(smpl_layer, target, args, cfg): res = [] pose_params = torch.rand(target.shape[0], 72) * 0.0 - shape_params = torch.rand(target.shape[0], 10) * 0.1 + shape_params = torch.rand(target.shape[0], 10) * 0.03 scale = torch.ones([1]) smpl_layer = smpl_layer.to(device) @@ -41,9 +41,10 @@ def train(smpl_layer, target, pose_params.requires_grad = True shape_params.requires_grad = True - scale.requires_grad = True + scale.requires_grad = False + smpl_layer.requires_grad = False - optimizer = optim.Adam([pose_params], + optimizer = optim.Adam([pose_params, shape_params], lr=cfg.TRAIN.LEARNING_RATE) min_loss = float('inf') @@ -62,5 +63,5 @@ def train(smpl_layer, target, writer.add_scalar('loss', float(loss), epoch) writer.add_scalar('learning_rate', float( optimizer.state_dict()['param_groups'][0]['lr']), epoch) - logger.info('Train ended, loss = {:.9f}'.format(float(loss))) + logger.info('Train ended, min_loss = {:.9f}'.format(float(min_loss))) return res diff --git a/gif.gif b/gif.gif new file mode 100644 index 0000000..2032450 Binary files /dev/null and b/gif.gif differ diff --git a/gif.rar b/gif.rar new file mode 100644 index 0000000..c5b3256 Binary files /dev/null and b/gif.rar differ diff --git a/gt.gif b/gt.gif new file mode 100644 index 0000000..15f61ad Binary files /dev/null and b/gt.gif differ diff --git a/make_gif.py b/make_gif.py index 47950e5..6f99373 100644 --- a/make_gif.py +++ b/make_gif.py @@ -1,7 +1,7 @@ import matplotlib.pyplot as plt import imageio, os images = [] -filenames = sorted(fn for fn in os.listdir('./output/') ) +filenames = sorted(fn for fn in os.listdir('./fit/output/HumanAct12/picture/fit/P01G01R01F0449T0505A0201') ) for filename in filenames: - images.append(imageio.imread('./output/'+filename)) -imageio.mimsave('./output/gif.gif', images, duration=0.5) \ No newline at end of file + images.append(imageio.imread('./fit/output/HumanAct12/picture/fit/P01G01R01F0449T0505A0201/'+filename)) +imageio.mimsave('./fit.gif', images, duration=0.3) \ No newline at end of file