update readme
This commit is contained in:
@ -51,7 +51,7 @@ The SMPL human body layer for Pytorch is from the [smplpytorch](https://github.c
|
|||||||
|
|
||||||
### 1. Executing Code
|
### 1. Executing Code
|
||||||
|
|
||||||
You can start the fitting procedure by the following code and the configuration file in *fit/configs* corresponding to the dataset_name will be loaded:
|
You can start the fitting procedure by the following code and the configuration file in *fit/configs* corresponding to the dataset_name will be loaded (the dataset_path can also be set in the configuration file):
|
||||||
|
|
||||||
```
|
```
|
||||||
python fit/tools/main.py --dataset_name [DATASET NAME] --dataset_path [DATASET PATH]
|
python fit/tools/main.py --dataset_name [DATASET NAME] --dataset_path [DATASET PATH]
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
from easydict import EasyDict as edict
|
from easydict import EasyDict as edict
|
||||||
@ -9,13 +10,13 @@ import logging
|
|||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
|
||||||
from train import train
|
|
||||||
from transform import transform
|
|
||||||
from save import save_pic, save_params
|
|
||||||
from load import load
|
from load import load
|
||||||
import numpy as np
|
from save import save_pic, save_params
|
||||||
torch.backends.cudnn.benchmark=True
|
from transform import transform
|
||||||
|
from train import train
|
||||||
|
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description='Fit SMPL')
|
parser = argparse.ArgumentParser(description='Fit SMPL')
|
||||||
@ -31,8 +32,9 @@ def parse_args():
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def get_config(args):
|
def get_config(args):
|
||||||
config_path='fit/configs/{}.json'.format(args.dataset_name)
|
config_path = 'fit/configs/{}.json'.format(args.dataset_name)
|
||||||
with open(config_path, 'r') as f:
|
with open(config_path, 'r') as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
cfg = edict(data.copy())
|
cfg = edict(data.copy())
|
||||||
@ -40,6 +42,7 @@ def get_config(args):
|
|||||||
cfg.DATASET.PATH = args.dataset_path
|
cfg.DATASET.PATH = args.dataset_path
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
def set_device(USE_GPU):
|
def set_device(USE_GPU):
|
||||||
if USE_GPU and torch.cuda.is_available():
|
if USE_GPU and torch.cuda.is_available():
|
||||||
device = torch.device('cuda')
|
device = torch.device('cuda')
|
||||||
@ -47,6 +50,7 @@ def set_device(USE_GPU):
|
|||||||
device = torch.device('cpu')
|
device = torch.device('cpu')
|
||||||
return device
|
return device
|
||||||
|
|
||||||
|
|
||||||
def get_logger(cur_path):
|
def get_logger(cur_path):
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(level=logging.INFO)
|
logger.setLevel(level=logging.INFO)
|
||||||
@ -69,6 +73,7 @@ def get_logger(cur_path):
|
|||||||
|
|
||||||
return logger, writer
|
return logger, writer
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
@ -84,26 +89,23 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
device = set_device(USE_GPU=cfg.USE_GPU)
|
device = set_device(USE_GPU=cfg.USE_GPU)
|
||||||
logger.info('using device: {}'.format(device))
|
logger.info('using device: {}'.format(device))
|
||||||
|
|
||||||
smpl_layer = SMPL_Layer(
|
smpl_layer = SMPL_Layer(
|
||||||
center_idx = 0,
|
center_idx=0,
|
||||||
gender=cfg.MODEL.GENDER,
|
gender=cfg.MODEL.GENDER,
|
||||||
model_root='smplpytorch/native/models')
|
model_root='smplpytorch/native/models')
|
||||||
|
|
||||||
file_num = 0
|
file_num = 0
|
||||||
for root,dirs,files in os.walk(cfg.DATASET.PATH):
|
for root, dirs, files in os.walk(cfg.DATASET.PATH):
|
||||||
for file in files:
|
for file in files:
|
||||||
file_num += 1
|
file_num += 1
|
||||||
logger.info('Processing file: {} [{} / {}]'.format(file,file_num,len(files)))
|
logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
|
||||||
target = torch.from_numpy(transform(args.dataset_name,
|
target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name,
|
||||||
load(args.dataset_name,
|
os.path.join(root, file)))).float()
|
||||||
os.path.join(root,file)))).float()
|
|
||||||
|
res = train(smpl_layer, target,
|
||||||
|
logger, writer, device,
|
||||||
res = train(smpl_layer,target,
|
args, cfg)
|
||||||
logger,writer,device,
|
|
||||||
args,cfg)
|
|
||||||
|
|
||||||
# save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
|
# save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
|
||||||
save_params(res,file,logger, args.dataset_name)
|
save_params(res, file, logger, args.dataset_name)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user