initial commit

This commit is contained in:
Iridoudou
2021-08-05 11:59:22 +08:00
parent a18b0197d8
commit 791f02f280
10 changed files with 275 additions and 6 deletions

3
.gitignore vendored
View File

@ -10,3 +10,6 @@ dist/
image.png image.png
smplpytorch/native/models/*.pkl smplpytorch/native/models/*.pkl
exp/
output/

View File

@ -5,13 +5,13 @@ from display_utils import display_model
if __name__ == '__main__': if __name__ == '__main__':
cuda = False cuda = True
batch_size = 1 batch_size = 1
# Create the SMPL layer # Create the SMPL layer
smpl_layer = SMPL_Layer( smpl_layer = SMPL_Layer(
center_idx=0, center_idx=0,
gender='neutral', gender='male',
model_root='smplpytorch/native/models') model_root='smplpytorch/native/models')
# Generate random pose and shape parameters # Generate random pose and shape parameters
@ -26,6 +26,7 @@ if __name__ == '__main__':
# Forward from the SMPL layer # Forward from the SMPL layer
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params) verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
print(Jtr)
# Draw output vertices and joints # Draw output vertices and joints
display_model( display_model(

View File

@ -12,7 +12,8 @@ def display_model(
ax=None, ax=None,
batch_idx=0, batch_idx=0,
show=True, show=True,
savepath=None): savepath=None,
only_joint=False):
""" """
Displays mesh batch_idx in batch of model_info, model_info as returned by Displays mesh batch_idx in batch of model_info, model_info as returned by
generate_random_model generate_random_model
@ -20,8 +21,7 @@ def display_model(
if ax is None: if ax is None:
fig = plt.figure() fig = plt.figure()
ax = fig.add_subplot(111, projection='3d') ax = fig.add_subplot(111, projection='3d')
verts, joints = model_info['verts'][batch_idx], model_info['joints'][ verts, joints = model_info['verts'][batch_idx], model_info['joints'][batch_idx]
batch_idx]
if model_faces is None: if model_faces is None:
ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2) ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2)
else: else:
@ -30,7 +30,8 @@ def display_model(
edge_color = (50 / 255, 50 / 255, 50 / 255) edge_color = (50 / 255, 50 / 255, 50 / 255)
mesh.set_edgecolor(edge_color) mesh.set_edgecolor(edge_color)
mesh.set_facecolor(face_color) mesh.set_facecolor(face_color)
ax.add_collection3d(mesh) if not only_joint:
ax.add_collection3d(mesh)
if with_joints: if with_joints:
draw_skeleton(joints, kintree_table=kintree_table, ax=ax) draw_skeleton(joints, kintree_table=kintree_table, ax=ax)
ax.set_xlabel('X') ax.set_xlabel('X')

25
fit/configs/config.json Normal file
View File

@ -0,0 +1,25 @@
{
"MODEL": {
"GENDER": "male"
},
"TRAIN": {
"LEARNING_RATE":2e-2,
"MAX_EPOCH": 5,
"WRITE": 1,
"SAVE": 10,
"BATCH_SIZE": 1,
"MOMENTUM": 0.9,
"lr_scheduler": {
"T_0": 10,
"T_mult": 2,
"eta_min": 1e-2
},
"loss_func": ""
},
"USE_GPU": 1,
"DATA_LOADER": {
"NUM_WORKERS": 1
},
"TARGET_PATH":"../Action2Motion/HumanAct12/HumanAct12/P01G01R01F0069T0143A0102.npy",
"DATASET_PATH":"../Action2Motion/HumanAct12/HumanAct12/"
}

116
fit/tools/main.py Normal file
View File

@ -0,0 +1,116 @@
import matplotlib as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import module
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dset
import torchvision.transforms as T
import numpy as np
from tensorboardX import SummaryWriter
from easydict import EasyDict as edict
import time
import inspect
import sys
import os
import logging
import argparse
import json
from tqdm import tqdm
sys.path.append(os.getcwd())
from display_utils import display_model
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from train import train
from transform import transform
from save import save_pic
torch.backends.cudnn.benchmark=True
def parse_args():
parser = argparse.ArgumentParser(description='Fit SMPL')
parser.add_argument('--exp', dest='exp',
help='Define exp name',
default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
parser.add_argument('--config_path', dest='config_path',
help='Select configuration file',
default='fit/configs/config.json', type=str)
parser.add_argument('--dataset_path', dest='dataset_path',
help='select dataset',
default='', type=str)
args = parser.parse_args()
return args
def get_config(args):
with open(args.config_path, 'r') as f:
data = json.load(f)
cfg = edict(data.copy())
return cfg
def set_device(USE_GPU):
if USE_GPU and torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
return device
def get_logger(cur_path):
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
handler = logging.FileHandler(os.path.join(cur_path, "log.txt"))
handler.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
writer = SummaryWriter(os.path.join(cur_path, 'tb'))
return logger, writer
if __name__ == "__main__":
args = parse_args()
cur_path = os.path.join(os.getcwd(), 'exp', args.exp)
assert not os.path.exists(cur_path), 'Duplicate exp name'
os.mkdir(cur_path)
cfg = get_config(args)
json.dump(dict(cfg), open(os.path.join(cur_path, 'config.json'), 'w'))
logger, writer = get_logger(cur_path)
logger.info("Start print log")
device = set_device(USE_GPU=cfg.USE_GPU)
logger.info('using device: {}'.format(device))
smpl_layer = SMPL_Layer(
center_idx = 0,
gender='neutral',
model_root='smplpytorch/native/models')
for root,dirs,files in os.walk(cfg.DATASET_PATH):
for file in files:
logger.info('Processing file: {}'.format(file))
target_path=os.path.join(root,file)
target = np.array(transform(np.load(cfg.TARGET_PATH)))
logger.info('File shape: {}'.format(target.shape))
target = torch.from_numpy(target).float()
res = train(smpl_layer,target,
logger,writer,device,
args,cfg)
# save_pic(target,res,smpl_layer,file)

38
fit/tools/save.py Normal file
View File

@ -0,0 +1,38 @@
import sys
import os
import re
sys.path.append(os.getcwd())
from display_utils import display_model
def create_dir_not_exist(path):
if not os.path.exists(path):
os.mkdir(path)
def save_pic(target, res, smpl_layer, file):
pose_params, shape_params, verts, Jtr = res
name=re.split('[/.]',file)[-2]
gt_path="fit/output/HumanAct12/picture/gt/{}".format(name)
fit_path="fit/output/HumanAct12/picture/fit/{}".format(name)
create_dir_not_exist(gt_path)
create_dir_not_exist(fit_path)
for i in range(target.shape[0]):
display_model(
{'verts': verts.cpu().detach(),
'joints': target.cpu().detach()},
model_faces=smpl_layer.th_faces,
with_joints=True,
kintree_table=smpl_layer.kintree_table,
savepath=os.path.join(gt_path+"/frame_{}".format(i)),
batch_idx=i,
show=False,
only_joint=True)
display_model(
{'verts': verts.cpu().detach(),
'joints': Jtr.cpu().detach()},
model_faces=smpl_layer.th_faces,
with_joints=True,
kintree_table=smpl_layer.kintree_table,
savepath=os.path.join(fit_path+"/frame_{}".format(i)),
batch_idx=i,
show=False)

66
fit/tools/train.py Normal file
View File

@ -0,0 +1,66 @@
import matplotlib as plt
from matplotlib.pyplot import show
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import module
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as T
import numpy as np
from tensorboardX import SummaryWriter
from easydict import EasyDict as edict
import time
import inspect
import sys
import os
import logging
import argparse
import json
from tqdm import tqdm
sys.path.append(os.getcwd())
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from display_utils import display_model
def train(smpl_layer, target,
logger, writer, device,
args, cfg):
res = []
pose_params = torch.rand(target.shape[0], 72) * 0.0
shape_params = torch.rand(target.shape[0], 10) * 0.1
scale = torch.ones([1])
smpl_layer = smpl_layer.to(device)
pose_params = pose_params.to(device)
shape_params = shape_params.to(device)
target = target.to(device)
scale = scale.to(device)
pose_params.requires_grad = True
shape_params.requires_grad = True
scale.requires_grad = True
optimizer = optim.Adam([pose_params],
lr=cfg.TRAIN.LEARNING_RATE)
min_loss = float('inf')
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
loss = F.smooth_l1_loss(Jtr * 100, target * 100)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if float(loss) < min_loss:
min_loss = float(loss)
res = [pose_params, shape_params, verts, Jtr]
if epoch % cfg.TRAIN.WRITE == 0:
# logger.info("Epoch {}, lossPerBatch={:.9f}, scale={:.6f}".format(
# epoch, float(loss), float(scale)))
writer.add_scalar('loss', float(loss), epoch)
writer.add_scalar('learning_rate', float(
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
logger.info('Train ended, loss = {:.9f}'.format(float(loss)))
return res

12
fit/tools/transform.py Normal file
View File

@ -0,0 +1,12 @@
import numpy as np
def transform(arr: np.ndarray):
for i in range(arr.shape[0]):
origin = arr[i][0].copy()
for j in range(arr.shape[1]):
arr[i][j] -= origin
arr[i][j][1] *= -1
arr[i][j][2] *= -1
arr[i][0] = [0.0, 0.0, 0.0]
return arr

BIN
image.png

Binary file not shown.

Before

Width:  |  Height:  |  Size: 109 KiB

7
make_gif.py Normal file
View File

@ -0,0 +1,7 @@
import matplotlib.pyplot as plt
import imageio, os
images = []
filenames = sorted(fn for fn in os.listdir('./output/') )
for filename in filenames:
images.append(imageio.imread('./output/'+filename))
imageio.mimsave('./output/gif.gif', images, duration=0.5)