initial commit
This commit is contained in:
116
fit/tools/main.py
Normal file
116
fit/tools/main.py
Normal file
@ -0,0 +1,116 @@
|
||||
import matplotlib as plt
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.modules import module
|
||||
from torch.optim import lr_scheduler
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import sampler
|
||||
import torchvision.datasets as dset
|
||||
import torchvision.transforms as T
|
||||
import numpy as np
|
||||
from tensorboardX import SummaryWriter
|
||||
from easydict import EasyDict as edict
|
||||
import time
|
||||
import inspect
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
sys.path.append(os.getcwd())
|
||||
from display_utils import display_model
|
||||
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
||||
from train import train
|
||||
from transform import transform
|
||||
from save import save_pic
|
||||
torch.backends.cudnn.benchmark=True
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Fit SMPL')
|
||||
parser.add_argument('--exp', dest='exp',
|
||||
help='Define exp name',
|
||||
default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
|
||||
parser.add_argument('--config_path', dest='config_path',
|
||||
help='Select configuration file',
|
||||
default='fit/configs/config.json', type=str)
|
||||
parser.add_argument('--dataset_path', dest='dataset_path',
|
||||
help='select dataset',
|
||||
default='', type=str)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def get_config(args):
|
||||
with open(args.config_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
cfg = edict(data.copy())
|
||||
return cfg
|
||||
|
||||
def set_device(USE_GPU):
|
||||
if USE_GPU and torch.cuda.is_available():
|
||||
device = torch.device('cuda')
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
return device
|
||||
|
||||
def get_logger(cur_path):
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(level=logging.INFO)
|
||||
|
||||
handler = logging.FileHandler(os.path.join(cur_path, "log.txt"))
|
||||
handler.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
writer = SummaryWriter(os.path.join(cur_path, 'tb'))
|
||||
|
||||
return logger, writer
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
cur_path = os.path.join(os.getcwd(), 'exp', args.exp)
|
||||
assert not os.path.exists(cur_path), 'Duplicate exp name'
|
||||
os.mkdir(cur_path)
|
||||
|
||||
cfg = get_config(args)
|
||||
json.dump(dict(cfg), open(os.path.join(cur_path, 'config.json'), 'w'))
|
||||
|
||||
logger, writer = get_logger(cur_path)
|
||||
logger.info("Start print log")
|
||||
|
||||
device = set_device(USE_GPU=cfg.USE_GPU)
|
||||
logger.info('using device: {}'.format(device))
|
||||
|
||||
smpl_layer = SMPL_Layer(
|
||||
center_idx = 0,
|
||||
gender='neutral',
|
||||
model_root='smplpytorch/native/models')
|
||||
|
||||
for root,dirs,files in os.walk(cfg.DATASET_PATH):
|
||||
for file in files:
|
||||
logger.info('Processing file: {}'.format(file))
|
||||
target_path=os.path.join(root,file)
|
||||
|
||||
target = np.array(transform(np.load(cfg.TARGET_PATH)))
|
||||
logger.info('File shape: {}'.format(target.shape))
|
||||
target = torch.from_numpy(target).float()
|
||||
|
||||
res = train(smpl_layer,target,
|
||||
logger,writer,device,
|
||||
args,cfg)
|
||||
|
||||
# save_pic(target,res,smpl_layer,file)
|
||||
|
||||
Reference in New Issue
Block a user