diff --git a/.gitignore b/.gitignore
index 372a5d8..6aac4b0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,3 +10,6 @@ dist/
 
 image.png
 smplpytorch/native/models/*.pkl
+
+exp/
+output/
\ No newline at end of file
diff --git a/demo.py b/demo.py
index f506974..d36bce4 100644
--- a/demo.py
+++ b/demo.py
@@ -5,13 +5,13 @@ from display_utils import display_model
 
 
 if __name__ == '__main__':
-    cuda = False
+    cuda = True
     batch_size = 1
 
     # Create the SMPL layer
     smpl_layer = SMPL_Layer(
         center_idx=0,
-        gender='neutral',
+        gender='male',
         model_root='smplpytorch/native/models')
 
     # Generate random pose and shape parameters
@@ -26,6 +26,7 @@ if __name__ == '__main__':
 
     # Forward from the SMPL layer
     verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
+    print(Jtr)
 
     # Draw output vertices and joints
     display_model(
diff --git a/display_utils.py b/display_utils.py
index 2a4b13a..8b05b71 100644
--- a/display_utils.py
+++ b/display_utils.py
@@ -12,7 +12,8 @@ def display_model(
         ax=None,
         batch_idx=0,
         show=True,
-        savepath=None):
+        savepath=None,
+        only_joint=False):
     """
     Displays mesh batch_idx in batch of model_info, model_info as returned by
     generate_random_model
@@ -20,8 +21,7 @@ def display_model(
     if ax is None:
         fig = plt.figure()
         ax = fig.add_subplot(111, projection='3d')
-    verts, joints = model_info['verts'][batch_idx], model_info['joints'][
-        batch_idx]
+    verts, joints = model_info['verts'][batch_idx], model_info['joints'][batch_idx]
     if model_faces is None:
         ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2)
     else:
@@ -30,7 +30,8 @@ def display_model(
         edge_color = (50 / 255, 50 / 255, 50 / 255)
         mesh.set_edgecolor(edge_color)
         mesh.set_facecolor(face_color)
-        ax.add_collection3d(mesh)
+        if not only_joint:
+            ax.add_collection3d(mesh)
     if with_joints:
         draw_skeleton(joints, kintree_table=kintree_table, ax=ax)
     ax.set_xlabel('X')
diff --git a/fit/configs/config.json b/fit/configs/config.json
new file mode 100644
index 0000000..add2030
--- /dev/null
+++ b/fit/configs/config.json
@@ -0,0 +1,25 @@
+{
+    "MODEL": {
+        "GENDER": "male"
+    },
+    "TRAIN": {
+        "LEARNING_RATE": 2e-2,
+        "MAX_EPOCH": 5,
+        "WRITE": 1,
+        "SAVE": 10,
+        "BATCH_SIZE": 1,
+        "MOMENTUM": 0.9,
+        "lr_scheduler": {
+            "T_0": 10,
+            "T_mult": 2,
+            "eta_min": 1e-2
+        },
+        "loss_func": ""
+    },
+    "USE_GPU": 1,
+    "DATA_LOADER": {
+        "NUM_WORKERS": 1
+    },
+    "TARGET_PATH": "../Action2Motion/HumanAct12/HumanAct12/P01G01R01F0069T0143A0102.npy",
+    "DATASET_PATH": "../Action2Motion/HumanAct12/HumanAct12/"
+}
\ No newline at end of file
diff --git a/fit/tools/main.py b/fit/tools/main.py
new file mode 100644
index 0000000..37bc664
--- /dev/null
+++ b/fit/tools/main.py
@@ -0,0 +1,116 @@
+import matplotlib as plt
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules import module
+from torch.optim import lr_scheduler
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torch.utils.data import sampler
+import torchvision.datasets as dset
+import torchvision.transforms as T
+import numpy as np
+from tensorboardX import SummaryWriter
+from easydict import EasyDict as edict
+import time
+import inspect
+import sys
+import os
+import logging
+
+import argparse
+import json
+from tqdm import tqdm
+sys.path.append(os.getcwd())
+from display_utils import display_model
+from smplpytorch.pytorch.smpl_layer import SMPL_Layer
+from train import train
+from transform import transform
+from save import save_pic
+torch.backends.cudnn.benchmark = True
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Fit SMPL')
+    parser.add_argument('--exp', dest='exp',
+                        help='Define exp name',
+                        default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
+    parser.add_argument('--config_path', dest='config_path',
+                        help='Select configuration file',
+                        default='fit/configs/config.json', type=str)
+    parser.add_argument('--dataset_path', dest='dataset_path',
+                        help='select dataset',
+                        default='', type=str)
+    args = parser.parse_args()
+    return args
+
+def get_config(args):
+    with open(args.config_path, 'r') as f:
+        data = json.load(f)
+    cfg = edict(data.copy())
+    return cfg
+
+def set_device(USE_GPU):
+    if USE_GPU and torch.cuda.is_available():
+        device = torch.device('cuda')
+    else:
+        device = torch.device('cpu')
+    return device
+
+def get_logger(cur_path):
+    logger = logging.getLogger(__name__)
+    logger.setLevel(level=logging.INFO)
+
+    handler = logging.FileHandler(os.path.join(cur_path, "log.txt"))
+    handler.setLevel(logging.INFO)
+    formatter = logging.Formatter(
+        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+
+    handler = logging.StreamHandler()
+    handler.setLevel(logging.INFO)
+    formatter = logging.Formatter(
+        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+
+    writer = SummaryWriter(os.path.join(cur_path, 'tb'))
+
+    return logger, writer
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    cur_path = os.path.join(os.getcwd(), 'exp', args.exp)
+    assert not os.path.exists(cur_path), 'Duplicate exp name'
+    os.makedirs(cur_path)
+
+    cfg = get_config(args)
+    json.dump(dict(cfg), open(os.path.join(cur_path, 'config.json'), 'w'))
+
+    logger, writer = get_logger(cur_path)
+    logger.info("Start print log")
+
+    device = set_device(USE_GPU=cfg.USE_GPU)
+    logger.info('using device: {}'.format(device))
+
+    smpl_layer = SMPL_Layer(
+        center_idx=0,
+        gender='neutral',
+        model_root='smplpytorch/native/models')
+
+    for root, dirs, files in os.walk(cfg.DATASET_PATH):
+        for file in files:
+            logger.info('Processing file: {}'.format(file))
+            target_path = os.path.join(root, file)
+
+            target = np.array(transform(np.load(target_path)))
+            logger.info('File shape: {}'.format(target.shape))
+            target = torch.from_numpy(target).float()
+
+            res = train(smpl_layer, target,
+                        logger, writer, device,
+                        args, cfg)
+
+            # save_pic(target, res, smpl_layer, file)
+
\ No newline at end of file
diff --git a/fit/tools/save.py b/fit/tools/save.py
new file mode 100644
index 0000000..0257ec9
--- /dev/null
+++ b/fit/tools/save.py
@@ -0,0 +1,38 @@
+import sys
+import os
+import re
+
+sys.path.append(os.getcwd())
+from display_utils import display_model
+
+def create_dir_not_exist(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+def save_pic(target, res, smpl_layer, file):
+    pose_params, shape_params, verts, Jtr = res
+    name = re.split('[/.]', file)[-2]
+    gt_path = "fit/output/HumanAct12/picture/gt/{}".format(name)
+    fit_path = "fit/output/HumanAct12/picture/fit/{}".format(name)
+    create_dir_not_exist(gt_path)
+    create_dir_not_exist(fit_path)
+    for i in range(target.shape[0]):
+        display_model(
+            {'verts': verts.cpu().detach(),
+             'joints': target.cpu().detach()},
+            model_faces=smpl_layer.th_faces,
+            with_joints=True,
+            kintree_table=smpl_layer.kintree_table,
+            savepath=os.path.join(gt_path, "frame_{}".format(i)),
+            batch_idx=i,
+            show=False,
+            only_joint=True)
+        display_model(
+            {'verts': verts.cpu().detach(),
+             'joints': Jtr.cpu().detach()},
+            model_faces=smpl_layer.th_faces,
+            with_joints=True,
+            kintree_table=smpl_layer.kintree_table,
+            savepath=os.path.join(fit_path, "frame_{}".format(i)),
+            batch_idx=i,
+            show=False)
diff --git a/fit/tools/train.py b/fit/tools/train.py
new file mode 100644
index 0000000..8c7cbc9
--- /dev/null
+++ b/fit/tools/train.py
@@ -0,0 +1,66 @@
+import matplotlib as plt
+from matplotlib.pyplot import show
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules import module
+from torch.optim import lr_scheduler
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision.datasets as dset
+import torchvision.transforms as T
+import numpy as np
+from tensorboardX import SummaryWriter
+from easydict import EasyDict as edict
+import time
+import inspect
+import sys
+import os
+import logging
+
+import argparse
+import json
+from tqdm import tqdm
+sys.path.append(os.getcwd())
+from smplpytorch.pytorch.smpl_layer import SMPL_Layer
+from display_utils import display_model
+
+def train(smpl_layer, target,
+          logger, writer, device,
+          args, cfg):
+    res = []
+    pose_params = torch.rand(target.shape[0], 72) * 0.0
+    shape_params = torch.rand(target.shape[0], 10) * 0.1
+    scale = torch.ones([1])
+
+    smpl_layer = smpl_layer.to(device)
+    pose_params = pose_params.to(device)
+    shape_params = shape_params.to(device)
+    target = target.to(device)
+    scale = scale.to(device)
+
+    pose_params.requires_grad = True
+    shape_params.requires_grad = True
+    scale.requires_grad = True
+
+    optimizer = optim.Adam([pose_params],
+                           lr=cfg.TRAIN.LEARNING_RATE)
+
+    min_loss = float('inf')
+    for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
+        verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
+        loss = F.smooth_l1_loss(Jtr * 100, target * 100)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        if float(loss) < min_loss:  # keep a snapshot of the best epoch
+            min_loss = float(loss)
+            res = [t.detach().clone() for t in (pose_params, shape_params, verts, Jtr)]
+        if epoch % cfg.TRAIN.WRITE == 0:
+            # logger.info("Epoch {}, lossPerBatch={:.9f}, scale={:.6f}".format(
+            #     epoch, float(loss), float(scale)))
+            writer.add_scalar('loss', float(loss), epoch)
+            writer.add_scalar('learning_rate', float(
+                optimizer.state_dict()['param_groups'][0]['lr']), epoch)
+    logger.info('Train ended, loss = {:.9f}'.format(float(loss)))
+    return res
diff --git a/fit/tools/transform.py b/fit/tools/transform.py
new file mode 100644
index 0000000..8d042fd
--- /dev/null
+++ b/fit/tools/transform.py
@@ -0,0 +1,12 @@
+import numpy as np
+
+
+def transform(arr: np.ndarray):
+    for i in range(arr.shape[0]):
+        origin = arr[i][0].copy()  # root joint of frame i
+        for j in range(arr.shape[1]):
+            arr[i][j] -= origin
+            arr[i][j][1] *= -1  # flip the y axis
+            arr[i][j][2] *= -1  # flip the z axis
+        arr[i][0] = [0.0, 0.0, 0.0]
+    return arr
diff --git a/image.png b/image.png
deleted file mode 100644
index c73c891..0000000
Binary files a/image.png and /dev/null differ
diff --git a/make_gif.py b/make_gif.py
new file mode 100644
index 0000000..47950e5
--- /dev/null
+++ b/make_gif.py
@@ -0,0 +1,7 @@
+import matplotlib.pyplot as plt
+import imageio, os
+images = []
+filenames = sorted(os.listdir('./output/'))
+for filename in filenames:
+    images.append(imageio.imread('./output/' + filename))
+imageio.mimsave('./output/gif.gif', images, duration=0.5)
\ No newline at end of file
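
Usage note: the new fitting entry point is meant to be launched from the repository root, e.g. python fit/tools/main.py --exp my_run, since fit/tools/main.py appends os.getcwd() to sys.path and resolves the default config path fit/configs/config.json against the working directory. Below is a minimal sketch of what transform() in fit/tools/transform.py does to one sequence, assuming the HumanAct12 .npy files hold (n_frames, n_joints, 3) joint positions (the shape its indexing implies); the sys.path line mirrors main.py's own import strategy, and the 24-joint count is only an illustrative stand-in:

import sys
import numpy as np
sys.path.append('fit/tools')  # transform.py is imported as a top-level module, as in main.py
from transform import transform

# Synthetic stand-in for one sequence: 4 frames, 24 joints, xyz coordinates.
seq = np.random.rand(4, 24, 3).astype(np.float32)
out = transform(seq.copy())  # transform() mutates its argument, so pass a copy

# Every frame is re-centered on joint 0, and the y and z axes are negated.
expected = seq - seq[:, :1]
expected[..., 1:] *= -1
assert np.allclose(out[:, 0], 0.0)
assert np.allclose(out, expected, atol=1e-6)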