initial commit
.gitignore (vendored, 3 lines changed)
@@ -10,3 +10,6 @@ dist/
 image.png
 smplpytorch/native/models/*.pkl
 
+exp/
+output/
demo.py (5 lines changed)
@@ -5,13 +5,13 @@ from display_utils import display_model
 
 
 if __name__ == '__main__':
-    cuda = False
+    cuda = True
     batch_size = 1
 
     # Create the SMPL layer
     smpl_layer = SMPL_Layer(
         center_idx=0,
-        gender='neutral',
+        gender='male',
         model_root='smplpytorch/native/models')
 
     # Generate random pose and shape parameters
@@ -26,6 +26,7 @@ if __name__ == '__main__':
 
     # Forward from the SMPL layer
     verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
+    print(Jtr)
 
     # Draw output vertices and joints
    display_model(
display_utils.py (modified)
@@ -12,7 +12,8 @@ def display_model(
         ax=None,
         batch_idx=0,
         show=True,
-        savepath=None):
+        savepath=None,
+        only_joint=False):
     """
     Displays mesh batch_idx in batch of model_info, model_info as returned by
     generate_random_model
@@ -20,8 +21,7 @@ def display_model(
     if ax is None:
         fig = plt.figure()
         ax = fig.add_subplot(111, projection='3d')
-    verts, joints = model_info['verts'][batch_idx], model_info['joints'][
-        batch_idx]
+    verts, joints = model_info['verts'][batch_idx], model_info['joints'][batch_idx]
     if model_faces is None:
         ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2)
     else:
@@ -30,7 +30,8 @@ def display_model(
         edge_color = (50 / 255, 50 / 255, 50 / 255)
         mesh.set_edgecolor(edge_color)
         mesh.set_facecolor(face_color)
-        ax.add_collection3d(mesh)
+        if not only_joint:
+            ax.add_collection3d(mesh)
     if with_joints:
         draw_skeleton(joints, kintree_table=kintree_table, ax=ax)
     ax.set_xlabel('X')
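For reference, a minimal sketch of calling display_model with the new only_joint flag, mirroring the call pattern used by fit/tools/save.py below. It assumes the SMPL model files are installed under smplpytorch/native/models and that the script runs from the repository root; the zero pose/shape values are illustrative only.

import torch
from display_utils import display_model
from smplpytorch.pytorch.smpl_layer import SMPL_Layer

# Zero pose and zero betas give the mean-shape rest pose.
smpl_layer = SMPL_Layer(center_idx=0, gender='male',
                        model_root='smplpytorch/native/models')
verts, Jtr = smpl_layer(torch.zeros(1, 72), th_betas=torch.zeros(1, 10))

# only_joint=True skips ax.add_collection3d(mesh), so only the skeleton is drawn.
display_model({'verts': verts.detach(), 'joints': Jtr.detach()},
              model_faces=smpl_layer.th_faces,
              with_joints=True,
              kintree_table=smpl_layer.kintree_table,
              only_joint=True,
              show=True)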
fit/configs/config.json (new file, 25 lines)

{
    "MODEL": {
        "GENDER": "male"
    },
    "TRAIN": {
        "LEARNING_RATE": 2e-2,
        "MAX_EPOCH": 5,
        "WRITE": 1,
        "SAVE": 10,
        "BATCH_SIZE": 1,
        "MOMENTUM": 0.9,
        "lr_scheduler": {
            "T_0": 10,
            "T_mult": 2,
            "eta_min": 1e-2
        },
        "loss_func": ""
    },
    "USE_GPU": 1,
    "DATA_LOADER": {
        "NUM_WORKERS": 1
    },
    "TARGET_PATH": "../Action2Motion/HumanAct12/HumanAct12/P01G01R01F0069T0143A0102.npy",
    "DATASET_PATH": "../Action2Motion/HumanAct12/HumanAct12/"
}
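The lr_scheduler block matches the constructor arguments of torch.optim.lr_scheduler.CosineAnnealingWarmRestarts, although nothing in this commit actually builds a scheduler. A hypothetical sketch of wiring those keys up (not part of fit/tools/train.py below; the placeholder parameter tensor is mine):

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

params = [torch.zeros(10, requires_grad=True)]        # placeholder parameters
optimizer = optim.Adam(params, lr=2e-2)               # TRAIN.LEARNING_RATE
scheduler = CosineAnnealingWarmRestarts(optimizer,
                                        T_0=10,       # lr_scheduler.T_0
                                        T_mult=2,     # lr_scheduler.T_mult
                                        eta_min=1e-2) # lr_scheduler.eta_min

for epoch in range(5):                                # TRAIN.MAX_EPOCH
    optimizer.step()                                  # backward pass omitted
    scheduler.step()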
fit/tools/main.py (new file, 116 lines)

import matplotlib as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import module
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dset
import torchvision.transforms as T
import numpy as np
from tensorboardX import SummaryWriter
from easydict import EasyDict as edict
import time
import inspect
import sys
import os
import logging

import argparse
import json
from tqdm import tqdm
sys.path.append(os.getcwd())
from display_utils import display_model
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from train import train
from transform import transform
from save import save_pic
torch.backends.cudnn.benchmark = True


def parse_args():
    parser = argparse.ArgumentParser(description='Fit SMPL')
    parser.add_argument('--exp', dest='exp',
                        help='Define exp name',
                        default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
    parser.add_argument('--config_path', dest='config_path',
                        help='Select configuration file',
                        default='fit/configs/config.json', type=str)
    parser.add_argument('--dataset_path', dest='dataset_path',
                        help='select dataset',
                        default='', type=str)
    args = parser.parse_args()
    return args


def get_config(args):
    with open(args.config_path, 'r') as f:
        data = json.load(f)
    cfg = edict(data.copy())
    return cfg


def set_device(USE_GPU):
    if USE_GPU and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    return device


def get_logger(cur_path):
    logger = logging.getLogger(__name__)
    logger.setLevel(level=logging.INFO)

    # Log to a file inside the experiment directory.
    handler = logging.FileHandler(os.path.join(cur_path, "log.txt"))
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    # Mirror the same messages to stdout.
    handler = logging.StreamHandler()
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    writer = SummaryWriter(os.path.join(cur_path, 'tb'))

    return logger, writer


if __name__ == "__main__":
    args = parse_args()

    cur_path = os.path.join(os.getcwd(), 'exp', args.exp)
    assert not os.path.exists(cur_path), 'Duplicate exp name'
    os.mkdir(cur_path)

    cfg = get_config(args)
    json.dump(dict(cfg), open(os.path.join(cur_path, 'config.json'), 'w'))

    logger, writer = get_logger(cur_path)
    logger.info("Start print log")

    device = set_device(USE_GPU=cfg.USE_GPU)
    logger.info('using device: {}'.format(device))

    smpl_layer = SMPL_Layer(
        center_idx=0,
        gender='neutral',
        model_root='smplpytorch/native/models')

    # Walk the dataset directory. Note that np.load() below reads
    # cfg.TARGET_PATH, not the per-file target_path computed here.
    for root, dirs, files in os.walk(cfg.DATASET_PATH):
        for file in files:
            logger.info('Processing file: {}'.format(file))
            target_path = os.path.join(root, file)

            target = np.array(transform(np.load(cfg.TARGET_PATH)))
            logger.info('File shape: {}'.format(target.shape))
            target = torch.from_numpy(target).float()

            res = train(smpl_layer, target,
                        logger, writer, device,
                        args, cfg)

            # save_pic(target, res, smpl_layer, file)
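Usage note: the script is meant to be run from the repository root, which both sys.path.append(os.getcwd()) and the default config path assume, e.g. python fit/tools/main.py --exp test_run --config_path fit/configs/config.json. Each run writes exp/<exp name>/ with log.txt, a copy of the resolved config.json, and a tb/ directory for TensorBoard; an exp/ directory must already exist at the root, since os.mkdir only creates the final path component, and the exp/ and output/ entries added to .gitignore above keep these artefacts out of version control.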
fit/tools/save.py (new file, 38 lines)

import sys
import os
import re

sys.path.append(os.getcwd())
from display_utils import display_model


def create_dir_not_exist(path):
    if not os.path.exists(path):
        os.mkdir(path)


def save_pic(target, res, smpl_layer, file):
    pose_params, shape_params, verts, Jtr = res
    name = re.split('[/.]', file)[-2]
    gt_path = "fit/output/HumanAct12/picture/gt/{}".format(name)
    fit_path = "fit/output/HumanAct12/picture/fit/{}".format(name)
    create_dir_not_exist(gt_path)
    create_dir_not_exist(fit_path)
    for i in range(target.shape[0]):
        # Ground-truth joints only: only_joint=True suppresses the mesh.
        display_model(
            {'verts': verts.cpu().detach(),
             'joints': target.cpu().detach()},
            model_faces=smpl_layer.th_faces,
            with_joints=True,
            kintree_table=smpl_layer.kintree_table,
            savepath=os.path.join(gt_path + "/frame_{}".format(i)),
            batch_idx=i,
            show=False,
            only_joint=True)
        # Fitted mesh together with the predicted joints.
        display_model(
            {'verts': verts.cpu().detach(),
             'joints': Jtr.cpu().detach()},
            model_faces=smpl_layer.th_faces,
            with_joints=True,
            kintree_table=smpl_layer.kintree_table,
            savepath=os.path.join(fit_path + "/frame_{}".format(i)),
            batch_idx=i,
            show=False)
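create_dir_not_exist uses os.mkdir, which only creates the last path component, so the fit/output/HumanAct12/picture/gt and .../fit parents must already exist before save_pic runs. A small sketch of a variant that creates the whole tree instead (a behavioural change, not what this commit does):

import os

def create_dir_not_exist(path):
    # os.makedirs also creates missing parent directories;
    # exist_ok=True makes repeated calls a no-op.
    os.makedirs(path, exist_ok=True)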
fit/tools/train.py (new file, 66 lines)

import matplotlib as plt
from matplotlib.pyplot import show
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import module
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as T
import numpy as np
from tensorboardX import SummaryWriter
from easydict import EasyDict as edict
import time
import inspect
import sys
import os
import logging

import argparse
import json
from tqdm import tqdm
sys.path.append(os.getcwd())
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from display_utils import display_model


def train(smpl_layer, target,
          logger, writer, device,
          args, cfg):
    res = []
    pose_params = torch.rand(target.shape[0], 72) * 0.0
    shape_params = torch.rand(target.shape[0], 10) * 0.1
    scale = torch.ones([1])

    smpl_layer = smpl_layer.to(device)
    pose_params = pose_params.to(device)
    shape_params = shape_params.to(device)
    target = target.to(device)
    scale = scale.to(device)

    pose_params.requires_grad = True
    shape_params.requires_grad = True
    scale.requires_grad = True

    # Only the pose parameters are handed to the optimizer; shape_params and
    # scale keep requires_grad=True but are never updated here.
    optimizer = optim.Adam([pose_params],
                           lr=cfg.TRAIN.LEARNING_RATE)

    min_loss = float('inf')
    for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
        verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
        loss = F.smooth_l1_loss(Jtr * 100, target * 100)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Keep the best (lowest-loss) parameters and outputs seen so far.
        if float(loss) < min_loss:
            min_loss = float(loss)
            res = [pose_params, shape_params, verts, Jtr]
        if epoch % cfg.TRAIN.WRITE == 0:
            # logger.info("Epoch {}, lossPerBatch={:.9f}, scale={:.6f}".format(
            #     epoch, float(loss), float(scale)))
            writer.add_scalar('loss', float(loss), epoch)
            writer.add_scalar('learning_rate', float(
                optimizer.state_dict()['param_groups'][0]['lr']), epoch)
    logger.info('Train ended, loss = {:.9f}'.format(float(loss)))
    return res
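A hypothetical smoke test for train(), showing the expected target shape of (frames, 24, 3) joints that F.smooth_l1_loss compares against Jtr. It assumes the SMPL model files are installed, that fit/tools is on sys.path (as it is for code living next to train.py), and uses a random target rather than real HumanAct12 data; args is passed as None since it is unused in the function body above.

import logging

import torch
from easydict import EasyDict as edict
from tensorboardX import SummaryWriter

from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from train import train

cfg = edict({'TRAIN': {'LEARNING_RATE': 2e-2, 'MAX_EPOCH': 5, 'WRITE': 1}})
logger = logging.getLogger('smoke_test')
writer = SummaryWriter('exp/smoke_test/tb')

smpl_layer = SMPL_Layer(center_idx=0, gender='neutral',
                        model_root='smplpytorch/native/models')
target = torch.rand(4, 24, 3)   # 4 frames, 24 joints, xyz
pose_params, shape_params, verts, Jtr = train(
    smpl_layer, target, logger, writer, torch.device('cpu'), None, cfg)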
fit/tools/transform.py (new file, 12 lines)

import numpy as np


def transform(arr: np.ndarray):
    # Per frame: re-centre every joint on the root joint (index 0),
    # then negate the y and z coordinates.
    for i in range(arr.shape[0]):
        origin = arr[i][0].copy()
        for j in range(arr.shape[1]):
            arr[i][j] -= origin
            arr[i][j][1] *= -1
            arr[i][j][2] *= -1
        arr[i][0] = [0.0, 0.0, 0.0]
    return arr
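A tiny worked example of transform (assuming fit/tools is importable): the root joint of each frame moves to the origin and every other joint has its y and z coordinates negated.

import numpy as np

from transform import transform

arr = np.array([[[1.0, 2.0, 3.0],     # root joint of frame 0
                 [2.0, 4.0, 6.0]]])   # one other joint
print(transform(arr))
# [[[ 0.  0.  0.]
#   [ 1. -2. -3.]]]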
make_gif.py (new file, 7 lines)

import matplotlib.pyplot as plt
import imageio, os

# Assemble the rendered frames in ./output/ into a GIF at 0.5 s per frame.
images = []
filenames = sorted(fn for fn in os.listdir('./output/'))
for filename in filenames:
    images.append(imageio.imread('./output/' + filename))
imageio.mimsave('./output/gif.gif', images, duration=0.5)