diff --git a/README.md b/README.md index d7fbf1f..a0b437c 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ The SMPL human body layer for Pytorch is from the [smplpytorch](https://github.c ### 1. Executing Code -You can start the fitting procedure by the following code and the configuration file in *fit/configs* corresponding to the dataset_name will be loaded: +You can start the fitting procedure by the following code and the configuration file in *fit/configs* corresponding to the dataset_name will be loaded (the dataset_path can also be set in the configuration file): ``` python fit/tools/main.py --dataset_name [DATASET NAME] --dataset_path [DATASET PATH] diff --git a/fit/tools/main.py b/fit/tools/main.py index e05473f..095f5d3 100644 --- a/fit/tools/main.py +++ b/fit/tools/main.py @@ -1,3 +1,4 @@ +import numpy as np import torch from tensorboardX import SummaryWriter from easydict import EasyDict as edict @@ -9,13 +10,13 @@ import logging import argparse import json sys.path.append(os.getcwd()) -from smplpytorch.pytorch.smpl_layer import SMPL_Layer -from train import train -from transform import transform -from save import save_pic, save_params from load import load -import numpy as np -torch.backends.cudnn.benchmark=True +from save import save_pic, save_params +from transform import transform +from train import train +from smplpytorch.pytorch.smpl_layer import SMPL_Layer +torch.backends.cudnn.benchmark = True + def parse_args(): parser = argparse.ArgumentParser(description='Fit SMPL') @@ -31,8 +32,9 @@ def parse_args(): args = parser.parse_args() return args + def get_config(args): - config_path='fit/configs/{}.json'.format(args.dataset_name) + config_path = 'fit/configs/{}.json'.format(args.dataset_name) with open(config_path, 'r') as f: data = json.load(f) cfg = edict(data.copy()) @@ -40,6 +42,7 @@ def get_config(args): cfg.DATASET.PATH = args.dataset_path return cfg + def set_device(USE_GPU): if USE_GPU and torch.cuda.is_available(): device = torch.device('cuda') @@ -47,6 +50,7 @@ def set_device(USE_GPU): device = torch.device('cpu') return device + def get_logger(cur_path): logger = logging.getLogger(__name__) logger.setLevel(level=logging.INFO) @@ -69,6 +73,7 @@ def get_logger(cur_path): return logger, writer + if __name__ == "__main__": args = parse_args() @@ -84,26 +89,23 @@ if __name__ == "__main__": device = set_device(USE_GPU=cfg.USE_GPU) logger.info('using device: {}'.format(device)) - + smpl_layer = SMPL_Layer( - center_idx = 0, + center_idx=0, gender=cfg.MODEL.GENDER, model_root='smplpytorch/native/models') - + file_num = 0 - for root,dirs,files in os.walk(cfg.DATASET.PATH): + for root, dirs, files in os.walk(cfg.DATASET.PATH): for file in files: file_num += 1 - logger.info('Processing file: {} [{} / {}]'.format(file,file_num,len(files))) - target = torch.from_numpy(transform(args.dataset_name, - load(args.dataset_name, - os.path.join(root,file)))).float() - - - res = train(smpl_layer,target, - logger,writer,device, - args,cfg) - + logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files))) + target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name, + os.path.join(root, file)))).float() + + res = train(smpl_layer, target, + logger, writer, device, + args, cfg) + # save_pic(res,smpl_layer,file,logger,args.dataset_name,target) - save_params(res,file,logger, args.dataset_name) - \ No newline at end of file + save_params(res, file, logger, args.dataset_name)