initial commit
This commit is contained in:
25
fit/configs/config.json
Normal file
25
fit/configs/config.json
Normal file
@ -0,0 +1,25 @@
|
||||
{
|
||||
"MODEL": {
|
||||
"GENDER": "male"
|
||||
},
|
||||
"TRAIN": {
|
||||
"LEARNING_RATE":2e-2,
|
||||
"MAX_EPOCH": 5,
|
||||
"WRITE": 1,
|
||||
"SAVE": 10,
|
||||
"BATCH_SIZE": 1,
|
||||
"MOMENTUM": 0.9,
|
||||
"lr_scheduler": {
|
||||
"T_0": 10,
|
||||
"T_mult": 2,
|
||||
"eta_min": 1e-2
|
||||
},
|
||||
"loss_func": ""
|
||||
},
|
||||
"USE_GPU": 1,
|
||||
"DATA_LOADER": {
|
||||
"NUM_WORKERS": 1
|
||||
},
|
||||
"TARGET_PATH":"../Action2Motion/HumanAct12/HumanAct12/P01G01R01F0069T0143A0102.npy",
|
||||
"DATASET_PATH":"../Action2Motion/HumanAct12/HumanAct12/"
|
||||
}
|
||||
116
fit/tools/main.py
Normal file
116
fit/tools/main.py
Normal file
@ -0,0 +1,116 @@
|
||||
import matplotlib as plt
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.modules import module
|
||||
from torch.optim import lr_scheduler
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import sampler
|
||||
import torchvision.datasets as dset
|
||||
import torchvision.transforms as T
|
||||
import numpy as np
|
||||
from tensorboardX import SummaryWriter
|
||||
from easydict import EasyDict as edict
|
||||
import time
|
||||
import inspect
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
sys.path.append(os.getcwd())
|
||||
from display_utils import display_model
|
||||
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
||||
from train import train
|
||||
from transform import transform
|
||||
from save import save_pic
|
||||
torch.backends.cudnn.benchmark=True
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Fit SMPL')
|
||||
parser.add_argument('--exp', dest='exp',
|
||||
help='Define exp name',
|
||||
default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
|
||||
parser.add_argument('--config_path', dest='config_path',
|
||||
help='Select configuration file',
|
||||
default='fit/configs/config.json', type=str)
|
||||
parser.add_argument('--dataset_path', dest='dataset_path',
|
||||
help='select dataset',
|
||||
default='', type=str)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def get_config(args):
|
||||
with open(args.config_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
cfg = edict(data.copy())
|
||||
return cfg
|
||||
|
||||
def set_device(USE_GPU):
|
||||
if USE_GPU and torch.cuda.is_available():
|
||||
device = torch.device('cuda')
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
return device
|
||||
|
||||
def get_logger(cur_path):
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(level=logging.INFO)
|
||||
|
||||
handler = logging.FileHandler(os.path.join(cur_path, "log.txt"))
|
||||
handler.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
writer = SummaryWriter(os.path.join(cur_path, 'tb'))
|
||||
|
||||
return logger, writer
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
cur_path = os.path.join(os.getcwd(), 'exp', args.exp)
|
||||
assert not os.path.exists(cur_path), 'Duplicate exp name'
|
||||
os.mkdir(cur_path)
|
||||
|
||||
cfg = get_config(args)
|
||||
json.dump(dict(cfg), open(os.path.join(cur_path, 'config.json'), 'w'))
|
||||
|
||||
logger, writer = get_logger(cur_path)
|
||||
logger.info("Start print log")
|
||||
|
||||
device = set_device(USE_GPU=cfg.USE_GPU)
|
||||
logger.info('using device: {}'.format(device))
|
||||
|
||||
smpl_layer = SMPL_Layer(
|
||||
center_idx = 0,
|
||||
gender='neutral',
|
||||
model_root='smplpytorch/native/models')
|
||||
|
||||
for root,dirs,files in os.walk(cfg.DATASET_PATH):
|
||||
for file in files:
|
||||
logger.info('Processing file: {}'.format(file))
|
||||
target_path=os.path.join(root,file)
|
||||
|
||||
target = np.array(transform(np.load(cfg.TARGET_PATH)))
|
||||
logger.info('File shape: {}'.format(target.shape))
|
||||
target = torch.from_numpy(target).float()
|
||||
|
||||
res = train(smpl_layer,target,
|
||||
logger,writer,device,
|
||||
args,cfg)
|
||||
|
||||
# save_pic(target,res,smpl_layer,file)
|
||||
|
||||
38
fit/tools/save.py
Normal file
38
fit/tools/save.py
Normal file
@ -0,0 +1,38 @@
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
from display_utils import display_model
|
||||
|
||||
def create_dir_not_exist(path):
|
||||
if not os.path.exists(path):
|
||||
os.mkdir(path)
|
||||
|
||||
def save_pic(target, res, smpl_layer, file):
|
||||
pose_params, shape_params, verts, Jtr = res
|
||||
name=re.split('[/.]',file)[-2]
|
||||
gt_path="fit/output/HumanAct12/picture/gt/{}".format(name)
|
||||
fit_path="fit/output/HumanAct12/picture/fit/{}".format(name)
|
||||
create_dir_not_exist(gt_path)
|
||||
create_dir_not_exist(fit_path)
|
||||
for i in range(target.shape[0]):
|
||||
display_model(
|
||||
{'verts': verts.cpu().detach(),
|
||||
'joints': target.cpu().detach()},
|
||||
model_faces=smpl_layer.th_faces,
|
||||
with_joints=True,
|
||||
kintree_table=smpl_layer.kintree_table,
|
||||
savepath=os.path.join(gt_path+"/frame_{}".format(i)),
|
||||
batch_idx=i,
|
||||
show=False,
|
||||
only_joint=True)
|
||||
display_model(
|
||||
{'verts': verts.cpu().detach(),
|
||||
'joints': Jtr.cpu().detach()},
|
||||
model_faces=smpl_layer.th_faces,
|
||||
with_joints=True,
|
||||
kintree_table=smpl_layer.kintree_table,
|
||||
savepath=os.path.join(fit_path+"/frame_{}".format(i)),
|
||||
batch_idx=i,
|
||||
show=False)
|
||||
66
fit/tools/train.py
Normal file
66
fit/tools/train.py
Normal file
@ -0,0 +1,66 @@
|
||||
import matplotlib as plt
|
||||
from matplotlib.pyplot import show
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.modules import module
|
||||
from torch.optim import lr_scheduler
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
import torchvision.datasets as dset
|
||||
import torchvision.transforms as T
|
||||
import numpy as np
|
||||
from tensorboardX import SummaryWriter
|
||||
from easydict import EasyDict as edict
|
||||
import time
|
||||
import inspect
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
sys.path.append(os.getcwd())
|
||||
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
||||
from display_utils import display_model
|
||||
|
||||
def train(smpl_layer, target,
|
||||
logger, writer, device,
|
||||
args, cfg):
|
||||
res = []
|
||||
pose_params = torch.rand(target.shape[0], 72) * 0.0
|
||||
shape_params = torch.rand(target.shape[0], 10) * 0.1
|
||||
scale = torch.ones([1])
|
||||
|
||||
smpl_layer = smpl_layer.to(device)
|
||||
pose_params = pose_params.to(device)
|
||||
shape_params = shape_params.to(device)
|
||||
target = target.to(device)
|
||||
scale = scale.to(device)
|
||||
|
||||
pose_params.requires_grad = True
|
||||
shape_params.requires_grad = True
|
||||
scale.requires_grad = True
|
||||
|
||||
optimizer = optim.Adam([pose_params],
|
||||
lr=cfg.TRAIN.LEARNING_RATE)
|
||||
|
||||
min_loss = float('inf')
|
||||
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
|
||||
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
|
||||
loss = F.smooth_l1_loss(Jtr * 100, target * 100)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
if float(loss) < min_loss:
|
||||
min_loss = float(loss)
|
||||
res = [pose_params, shape_params, verts, Jtr]
|
||||
if epoch % cfg.TRAIN.WRITE == 0:
|
||||
# logger.info("Epoch {}, lossPerBatch={:.9f}, scale={:.6f}".format(
|
||||
# epoch, float(loss), float(scale)))
|
||||
writer.add_scalar('loss', float(loss), epoch)
|
||||
writer.add_scalar('learning_rate', float(
|
||||
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
|
||||
logger.info('Train ended, loss = {:.9f}'.format(float(loss)))
|
||||
return res
|
||||
12
fit/tools/transform.py
Normal file
12
fit/tools/transform.py
Normal file
@ -0,0 +1,12 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def transform(arr: np.ndarray):
|
||||
for i in range(arr.shape[0]):
|
||||
origin = arr[i][0].copy()
|
||||
for j in range(arr.shape[1]):
|
||||
arr[i][j] -= origin
|
||||
arr[i][j][1] *= -1
|
||||
arr[i][j][2] *= -1
|
||||
arr[i][0] = [0.0, 0.0, 0.0]
|
||||
return arr
|
||||
Reference in New Issue
Block a user