import torch
import numpy as np
from tensorboardX import SummaryWriter
from easydict import EasyDict as edict
import time
import sys
import os
import logging

import argparse
import json

# Make the current working directory importable so the project helper modules
# below (load, save, transform, train, meters) resolve when the script is run
# from the repository root.
sys.path.append(os.getcwd())

from load import load
from save import save_pic, save_params
from transform import transform
from train import train
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from meters import Meters

# Let cuDNN autotune kernels for the fixed-size fitting batches.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


def parse_args():
    parser = argparse.ArgumentParser(description='Fit SMPL')
    parser.add_argument('--exp', dest='exp',
                        help='Define exp name',
                        default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
    parser.add_argument('--dataset_name', dest='dataset_name',
                        help='select dataset',
                        default='', type=str)
    parser.add_argument('--dataset_path', dest='dataset_path',
                        help='path of dataset',
                        default=None, type=str)
    args = parser.parse_args()
    return args


def get_config(args):
    config_path = 'fit/configs/{}.json'.format(args.dataset_name)
    with open(config_path, 'r') as f:
        data = json.load(f)
    cfg = edict(data.copy())
    if args.dataset_path is not None:
        cfg.DATASET.PATH = args.dataset_path
    return cfg
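
# The config JSON is wrapped in an EasyDict, so nested keys become attribute
# lookups (cfg.USE_GPU, cfg.MODEL.GENDER, cfg.DATASET.PATH). An illustrative,
# hypothetical fit/configs/<dataset_name>.json showing only the keys this
# script itself reads -- the real configs may define more:
#
# {
#     "USE_GPU": true,
#     "MODEL": {"GENDER": "neutral"},
#     "DATASET": {"PATH": "data/<dataset_name>"}
# }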


def set_device(USE_GPU):
    if USE_GPU and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    return device


def get_logger(cur_path):
    logger = logging.getLogger(__name__)
    logger.setLevel(level=logging.INFO)

    # Log to a file inside the experiment folder ...
    handler = logging.FileHandler(os.path.join(cur_path, "log.txt"))
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    # ... and to the console.
    handler = logging.StreamHandler()
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    # TensorBoard summaries go to exp/<exp_name>/tb.
    writer = SummaryWriter(os.path.join(cur_path, 'tb'))

    return logger, writer
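
# Note: the Meters helper imported above is expected to track the running fit
# statistics used in the loop below -- at minimum a `min_loss` value, an `avg`
# accumulator, `update_avg(value, k=...)`, and `reset_early_stop()`. This is
# inferred from the calls made here; see the meters module for the actual
# interface.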


if __name__ == "__main__":
    args = parse_args()

    # Each run gets its own folder under exp/; refuse to overwrite an old one.
    cur_path = os.path.join(os.getcwd(), 'exp', args.exp)
    assert not os.path.exists(cur_path), 'Duplicate exp name'
    os.makedirs(cur_path)

    cfg = get_config(args)
    with open(os.path.join(cur_path, 'config.json'), 'w') as f:
        json.dump(dict(cfg), f)

    logger, writer = get_logger(cur_path)
    logger.info("Start logging")

    device = set_device(USE_GPU=cfg.USE_GPU)
    logger.info('using device: {}'.format(device))

    # SMPL_Layer loads the SMPL model of the configured gender from model_root;
    # the pickled model files are expected to have been placed there beforehand.
    smpl_layer = SMPL_Layer(
        center_idx=0,
        gender=cfg.MODEL.GENDER,
        model_root='smplpytorch/native/models')

    meters = Meters()
    file_num = 0
    for root, dirs, files in os.walk(cfg.DATASET.PATH):
        for file in files:
            file_num += 1
            logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
            # Load the raw data for this file and transform it into the target
            # array the optimizer fits the SMPL parameters against.
            target = torch.from_numpy(transform(args.dataset_name, load(args.dataset_name,
                                                os.path.join(root, file)))).float()
            logger.info("target shape:{}".format(target.shape))
            res = train(smpl_layer, target,
                        logger, writer, device,
                        args, cfg, meters)
            # Fold this file's best loss into the running average, weighted by
            # the size of the target's first dimension, then reset early stopping.
            meters.update_avg(meters.min_loss, k=target.shape[0])
            meters.reset_early_stop()
            logger.info("avg_loss:{:.4f}".format(meters.avg))

            # save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
            save_params(res, file, logger, args.dataset_name)
            torch.cuda.empty_cache()  # safe no-op when CUDA is not in use
    logger.info("Fitting finished! Average loss: {:.9f}".format(meters.avg))
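
# Illustrative usage (assuming this script is saved as, e.g., fit/main.py and
# is run from the repository root -- adjust the path to wherever it lives):
#
#     python fit/main.py --dataset_name <name> --dataset_path /path/to/dataset
#
# Outputs land in exp/<exp_name>/: config.json, log.txt, and the TensorBoard
# run in tb/; fitted parameters are written by save_params (see the save
# module).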