Files
Pose_to_SMPL_an_230402/fit/tools/main.py
2021-08-10 10:40:13 +08:00

112 lines
3.5 KiB
Python

import numpy as np
import torch
from tensorboardX import SummaryWriter
from easydict import EasyDict as edict
import time
import sys
import os
import logging
import argparse
import json
sys.path.append(os.getcwd())
from load import load
from save import save_pic, save_params
from transform import transform
from train import train
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
torch.backends.cudnn.benchmark = True
def parse_args():
parser = argparse.ArgumentParser(description='Fit SMPL')
parser.add_argument('--exp', dest='exp',
help='Define exp name',
default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
parser.add_argument('--dataset_name', dest='dataset_name',
help='select dataset',
default='', type=str)
parser.add_argument('--dataset_path', dest='dataset_path',
help='path of dataset',
default=None, type=str)
args = parser.parse_args()
return args
def get_config(args):
config_path = 'fit/configs/{}.json'.format(args.dataset_name)
with open(config_path, 'r') as f:
data = json.load(f)
cfg = edict(data.copy())
if not args.dataset_path == None:
cfg.DATASET.PATH = args.dataset_path
return cfg
def set_device(USE_GPU):
if USE_GPU and torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
return device
def get_logger(cur_path):
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
handler = logging.FileHandler(os.path.join(cur_path, "log.txt"))
handler.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
writer = SummaryWriter(os.path.join(cur_path, 'tb'))
return logger, writer
if __name__ == "__main__":
args = parse_args()
cur_path = os.path.join(os.getcwd(), 'exp', args.exp)
assert not os.path.exists(cur_path), 'Duplicate exp name'
os.mkdir(cur_path)
cfg = get_config(args)
json.dump(dict(cfg), open(os.path.join(cur_path, 'config.json'), 'w'))
logger, writer = get_logger(cur_path)
logger.info("Start print log")
device = set_device(USE_GPU=cfg.USE_GPU)
logger.info('using device: {}'.format(device))
smpl_layer = SMPL_Layer(
center_idx=0,
gender=cfg.MODEL.GENDER,
model_root='smplpytorch/native/models')
file_num = 0
for root, dirs, files in os.walk(cfg.DATASET.PATH):
for file in files:
file_num += 1
logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name,
os.path.join(root, file)))).float()
res = train(smpl_layer, target,
logger, writer, device,
args, cfg)
# save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
save_params(res, file, logger, args.dataset_name)