update readme

This commit is contained in:
Iridoudou
2021-08-10 10:40:13 +08:00
parent 75607bddec
commit 3bb691ee4d
2 changed files with 26 additions and 24 deletions

View File

@ -51,7 +51,7 @@ The SMPL human body layer for Pytorch is from the [smplpytorch](https://github.c
### 1. Executing Code ### 1. Executing Code
You can start the fitting procedure by the following code and the configuration file in *fit/configs* corresponding to the dataset_name will be loaded: You can start the fitting procedure by the following code and the configuration file in *fit/configs* corresponding to the dataset_name will be loaded (the dataset_path can also be set in the configuration file):
``` ```
python fit/tools/main.py --dataset_name [DATASET NAME] --dataset_path [DATASET PATH] python fit/tools/main.py --dataset_name [DATASET NAME] --dataset_path [DATASET PATH]

View File

@ -1,3 +1,4 @@
import numpy as np
import torch import torch
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from easydict import EasyDict as edict from easydict import EasyDict as edict
@ -9,13 +10,13 @@ import logging
import argparse import argparse
import json import json
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from train import train
from transform import transform
from save import save_pic, save_params
from load import load from load import load
import numpy as np from save import save_pic, save_params
torch.backends.cudnn.benchmark=True from transform import transform
from train import train
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
torch.backends.cudnn.benchmark = True
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Fit SMPL') parser = argparse.ArgumentParser(description='Fit SMPL')
@ -31,8 +32,9 @@ def parse_args():
args = parser.parse_args() args = parser.parse_args()
return args return args
def get_config(args): def get_config(args):
config_path='fit/configs/{}.json'.format(args.dataset_name) config_path = 'fit/configs/{}.json'.format(args.dataset_name)
with open(config_path, 'r') as f: with open(config_path, 'r') as f:
data = json.load(f) data = json.load(f)
cfg = edict(data.copy()) cfg = edict(data.copy())
@ -40,6 +42,7 @@ def get_config(args):
cfg.DATASET.PATH = args.dataset_path cfg.DATASET.PATH = args.dataset_path
return cfg return cfg
def set_device(USE_GPU): def set_device(USE_GPU):
if USE_GPU and torch.cuda.is_available(): if USE_GPU and torch.cuda.is_available():
device = torch.device('cuda') device = torch.device('cuda')
@ -47,6 +50,7 @@ def set_device(USE_GPU):
device = torch.device('cpu') device = torch.device('cpu')
return device return device
def get_logger(cur_path): def get_logger(cur_path):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO) logger.setLevel(level=logging.INFO)
@ -69,6 +73,7 @@ def get_logger(cur_path):
return logger, writer return logger, writer
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
@ -84,26 +89,23 @@ if __name__ == "__main__":
device = set_device(USE_GPU=cfg.USE_GPU) device = set_device(USE_GPU=cfg.USE_GPU)
logger.info('using device: {}'.format(device)) logger.info('using device: {}'.format(device))
smpl_layer = SMPL_Layer( smpl_layer = SMPL_Layer(
center_idx = 0, center_idx=0,
gender=cfg.MODEL.GENDER, gender=cfg.MODEL.GENDER,
model_root='smplpytorch/native/models') model_root='smplpytorch/native/models')
file_num = 0 file_num = 0
for root,dirs,files in os.walk(cfg.DATASET.PATH): for root, dirs, files in os.walk(cfg.DATASET.PATH):
for file in files: for file in files:
file_num += 1 file_num += 1
logger.info('Processing file: {} [{} / {}]'.format(file,file_num,len(files))) logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
target = torch.from_numpy(transform(args.dataset_name, target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name,
load(args.dataset_name, os.path.join(root, file)))).float()
os.path.join(root,file)))).float()
res = train(smpl_layer, target,
logger, writer, device,
res = train(smpl_layer,target, args, cfg)
logger,writer,device,
args,cfg)
# save_pic(res,smpl_layer,file,logger,args.dataset_name,target) # save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
save_params(res,file,logger, args.dataset_name) save_params(res, file, logger, args.dataset_name)