update readme
This commit is contained in:
@ -51,7 +51,7 @@ The SMPL human body layer for Pytorch is from the [smplpytorch](https://github.c
|
||||
|
||||
### 1. Executing Code
|
||||
|
||||
You can start the fitting procedure by the following code and the configuration file in *fit/configs* corresponding to the dataset_name will be loaded:
|
||||
You can start the fitting procedure by the following code and the configuration file in *fit/configs* corresponding to the dataset_name will be loaded (the dataset_path can also be set in the configuration file):
|
||||
|
||||
```
|
||||
python fit/tools/main.py --dataset_name [DATASET NAME] --dataset_path [DATASET PATH]
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensorboardX import SummaryWriter
|
||||
from easydict import EasyDict as edict
|
||||
@ -9,13 +10,13 @@ import logging
|
||||
import argparse
|
||||
import json
|
||||
sys.path.append(os.getcwd())
|
||||
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
||||
from train import train
|
||||
from transform import transform
|
||||
from save import save_pic, save_params
|
||||
from load import load
|
||||
import numpy as np
|
||||
torch.backends.cudnn.benchmark=True
|
||||
from save import save_pic, save_params
|
||||
from transform import transform
|
||||
from train import train
|
||||
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Fit SMPL')
|
||||
@ -31,8 +32,9 @@ def parse_args():
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def get_config(args):
|
||||
config_path='fit/configs/{}.json'.format(args.dataset_name)
|
||||
config_path = 'fit/configs/{}.json'.format(args.dataset_name)
|
||||
with open(config_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
cfg = edict(data.copy())
|
||||
@ -40,6 +42,7 @@ def get_config(args):
|
||||
cfg.DATASET.PATH = args.dataset_path
|
||||
return cfg
|
||||
|
||||
|
||||
def set_device(USE_GPU):
|
||||
if USE_GPU and torch.cuda.is_available():
|
||||
device = torch.device('cuda')
|
||||
@ -47,6 +50,7 @@ def set_device(USE_GPU):
|
||||
device = torch.device('cpu')
|
||||
return device
|
||||
|
||||
|
||||
def get_logger(cur_path):
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(level=logging.INFO)
|
||||
@ -69,6 +73,7 @@ def get_logger(cur_path):
|
||||
|
||||
return logger, writer
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
@ -86,24 +91,21 @@ if __name__ == "__main__":
|
||||
logger.info('using device: {}'.format(device))
|
||||
|
||||
smpl_layer = SMPL_Layer(
|
||||
center_idx = 0,
|
||||
center_idx=0,
|
||||
gender=cfg.MODEL.GENDER,
|
||||
model_root='smplpytorch/native/models')
|
||||
|
||||
file_num = 0
|
||||
for root,dirs,files in os.walk(cfg.DATASET.PATH):
|
||||
for root, dirs, files in os.walk(cfg.DATASET.PATH):
|
||||
for file in files:
|
||||
file_num += 1
|
||||
logger.info('Processing file: {} [{} / {}]'.format(file,file_num,len(files)))
|
||||
target = torch.from_numpy(transform(args.dataset_name,
|
||||
load(args.dataset_name,
|
||||
os.path.join(root,file)))).float()
|
||||
logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
|
||||
target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name,
|
||||
os.path.join(root, file)))).float()
|
||||
|
||||
|
||||
res = train(smpl_layer,target,
|
||||
logger,writer,device,
|
||||
args,cfg)
|
||||
res = train(smpl_layer, target,
|
||||
logger, writer, device,
|
||||
args, cfg)
|
||||
|
||||
# save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
|
||||
save_params(res,file,logger, args.dataset_name)
|
||||
|
||||
save_params(res, file, logger, args.dataset_name)
|
||||
|
||||
Reference in New Issue
Block a user