update code

This commit is contained in:
Iridoudou
2021-08-07 08:03:40 +08:00
parent 0152f1909f
commit 9f6274fc19
8 changed files with 160 additions and 65 deletions

View File

@ -3,23 +3,56 @@
"GENDER": "male" "GENDER": "male"
}, },
"TRAIN": { "TRAIN": {
"LEARNING_RATE":2e-2, "LEARNING_RATE": 5e-2,
"MAX_EPOCH": 500, "MAX_EPOCH": 500,
"WRITE": 1, "WRITE": 1,
"SAVE": 10, "SAVE": 100
"BATCH_SIZE": 1,
"MOMENTUM": 0.9,
"lr_scheduler": {
"T_0": 10,
"T_mult": 2,
"eta_min": 1e-2
},
"loss_func": ""
}, },
"USE_GPU": 1, "USE_GPU": 1,
"DATA_LOADER": { "DATASET": {
"NUM_WORKERS": 1 "NAME": "UTD-MHAD",
"PATH": "../UTD-MHAD/Skeleton/Skeleton/",
"TARGET_PATH": "../UTD-MHAD/Skeleton/Skeleton/a5_s7_t3_skeleton.mat",
"DATA_MAP": {
"UTD_MHAD": [
[
12,
0,
16,
18,
20,
22,
17,
19,
21,
23,
1,
4,
7,
2,
5,
8
],
[
1,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
16,
17,
18
]
]
}
}, },
"TARGET_PATH":"../Action2Motion/HumanAct12/HumanAct12/P01G01R01F0069T0143A0102.npy", "DEBUG": 1
"DATASET_PATH":"../Action2Motion/HumanAct12/HumanAct12/"
} }

16
fit/tools/load.py Normal file
View File

@ -0,0 +1,16 @@
import scipy.io
import numpy as np
def load(name, path):
if name == 'UTD_MHAD':
data = scipy.io.loadmat(path)
arr = data['d_skel']
new_arr = np.zeros([arr.shape[2], arr.shape[0], arr.shape[1]])
for i in range(arr.shape[2]):
for j in range(arr.shape[0]):
for k in range(arr.shape[1]):
new_arr[i][j][k] = arr[j][k][i]
return new_arr
elif name == 'HumanAct12':
return np.load(path)

View File

@ -27,6 +27,7 @@ from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from train import train from train import train
from transform import transform from transform import transform
from save import save_pic,save_params from save import save_pic,save_params
from load import load
torch.backends.cudnn.benchmark=True torch.backends.cudnn.benchmark=True
def parse_args(): def parse_args():
@ -96,22 +97,34 @@ if __name__ == "__main__":
smpl_layer = SMPL_Layer( smpl_layer = SMPL_Layer(
center_idx = 0, center_idx = 0,
gender='neutral', gender=cfg.MODEL.GENDER,
model_root='smplpytorch/native/models') model_root='smplpytorch/native/models')
for root,dirs,files in os.walk(cfg.DATASET_PATH): if not cfg.DEBUG:
for file in files: for root,dirs,files in os.walk(cfg.DATASET_PATH):
logger.info('Processing file: {}'.format(file)) for file in files:
target_path=os.path.join(root,file) logger.info('Processing file: {}'.format(file))
target_path=os.path.join(root,file)
target = np.array(transform(np.load(target_path))) target = np.array(transform(np.load(target_path)))
logger.info('File shape: {}'.format(target.shape)) logger.info('File shape: {}'.format(target.shape))
target = torch.from_numpy(target).float() target = torch.from_numpy(target).float()
res = train(smpl_layer,target,
logger,writer,device,
args,cfg)
# save_pic(target,res,smpl_layer,file,logger) res = train(smpl_layer,target,
save_params(res,file,logger) logger,writer,device,
args,cfg)
# save_pic(target,res,smpl_layer,file,logger)
save_params(res,file,logger)
else:
target = np.array(transform(load('UTD_MHAD',cfg.DATASET.TARGET_PATH),
rotate=[-1,1,-1]))
target = torch.from_numpy(target).float()
data_map_dataset=torch.tensor(cfg.DATASET.DATA_MAP.UTD_MHAD[1])
target = target.index_select(1, data_map_dataset)
print(target.shape)
res = train(smpl_layer,target,
logger,writer,device,
args,cfg)

14
fit/tools/map.py Normal file
View File

@ -0,0 +1,14 @@
import numpy as np
def mapping(Jtr,cfg):
name=cfg.DATASET.NAME
if not name=='HumanAct12':
mapped_joint=cfg.DATASET.DATA_MAP.UTD_MHAD
Jtr_mapped=np.zeros([Jtr.shape[0],len(mapped_joint),Jtr.shape[2]])
for i in range(Jtr.shape[0]):
for j in range(len(mapped_joint)):
for k in range(Jtr.shape[2]):
Jtr_mapped[i][j][k]=Jtr[i][mapped_joint[j]][k]
return Jtr_mapped
return Jtr

View File

@ -9,18 +9,20 @@ import json
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from display_utils import display_model from display_utils import display_model
def create_dir_not_exist(path): def create_dir_not_exist(path):
if not os.path.exists(path): if not os.path.exists(path):
os.mkdir(path) os.mkdir(path)
def save_pic(target, res, smpl_layer, file,logger):
def save_pic(target, res, smpl_layer, file, logger):
pose_params, shape_params, verts, Jtr = res pose_params, shape_params, verts, Jtr = res
name=re.split('[/.]',file)[-2] name = re.split('[/.]', file)[-2]
gt_path="fit/output/HumanAct12/picture/gt/{}".format(name) gt_path = "fit/output/HumanAct12/picture/gt/{}".format(name)
fit_path="fit/output/HumanAct12/picture/fit/{}".format(name) fit_path = "fit/output/HumanAct12/picture/fit/{}".format(name)
create_dir_not_exist(gt_path) create_dir_not_exist(gt_path)
create_dir_not_exist(fit_path) create_dir_not_exist(fit_path)
logger.info('Saving pictures at {} and {}'.format(gt_path,fit_path)) logger.info('Saving pictures at {} and {}'.format(gt_path, fit_path))
for i in tqdm(range(target.shape[0])): for i in tqdm(range(target.shape[0])):
display_model( display_model(
{'verts': verts.cpu().detach(), {'verts': verts.cpu().detach(),
@ -43,24 +45,24 @@ def save_pic(target, res, smpl_layer, file,logger):
show=False) show=False)
logger.info('Pictures saved') logger.info('Pictures saved')
def save_params(res,file,logger):
def save_params(res, file, logger):
pose_params, shape_params, verts, Jtr = res pose_params, shape_params, verts, Jtr = res
name=re.split('[/.]',file)[-2] name = re.split('[/.]', file)[-2]
fit_path="fit/output/HumanAct12/params/" fit_path = "fit/output/HumanAct12/params/"
create_dir_not_exist(fit_path) create_dir_not_exist(fit_path)
logger.info('Saving params at {}'.format(fit_path)) logger.info('Saving params at {}'.format(fit_path))
pose_params=pose_params.cpu().detach() pose_params = pose_params.cpu().detach()
pose_params=pose_params.numpy().tolist() pose_params = pose_params.numpy().tolist()
shape_params=shape_params.cpu().detach() shape_params = shape_params.cpu().detach()
shape_params=shape_params.numpy().tolist() shape_params = shape_params.numpy().tolist()
Jtr=Jtr.cpu().detach() Jtr = Jtr.cpu().detach()
Jtr=Jtr.numpy().tolist() Jtr = Jtr.numpy().tolist()
params={} params = {}
params["pose_params"]=pose_params params["pose_params"] = pose_params
params["shape_params"]=shape_params params["shape_params"] = shape_params
params["Jtr"]=Jtr params["Jtr"] = Jtr
f=open(os.path.join((fit_path), f = open(os.path.join((fit_path),
"{}_params.json".format(name)),'w') "{}_params.json".format(name)), 'w')
json.dump(params,f) json.dump(params, f)
logger.info('Params saved') logger.info('Params saved')

View File

@ -10,6 +10,7 @@ from torch.utils.data import DataLoader
import torchvision.datasets as dset import torchvision.datasets as dset
import torchvision.transforms as T import torchvision.transforms as T
import numpy as np import numpy as np
import scipy.io
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from easydict import EasyDict as edict from easydict import EasyDict as edict
import time import time
@ -24,6 +25,7 @@ from tqdm import tqdm
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from smplpytorch.pytorch.smpl_layer import SMPL_Layer from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from display_utils import display_model from display_utils import display_model
from map import mapping
def train(smpl_layer, target, def train(smpl_layer, target,
logger, writer, device, logger, writer, device,
@ -48,9 +50,11 @@ def train(smpl_layer, target,
lr=cfg.TRAIN.LEARNING_RATE) lr=cfg.TRAIN.LEARNING_RATE)
min_loss = float('inf') min_loss = float('inf')
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)): data_map=torch.tensor(cfg.DATASET.DATA_MAP.UTD_MHAD)[0].to(device)
# for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
for epoch in range(cfg.TRAIN.MAX_EPOCH):
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params) verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
loss = F.smooth_l1_loss(Jtr * 100, target * 100) loss = F.smooth_l1_loss(Jtr.index_select(1, data_map) * 100, target * 100)
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
@ -58,10 +62,22 @@ def train(smpl_layer, target,
min_loss = float(loss) min_loss = float(loss)
res = [pose_params, shape_params, verts, Jtr] res = [pose_params, shape_params, verts, Jtr]
if epoch % cfg.TRAIN.WRITE == 0: if epoch % cfg.TRAIN.WRITE == 0:
# logger.info("Epoch {}, lossPerBatch={:.9f}, scale={:.6f}".format( logger.info("Epoch {}, lossPerBatch={:.9f}, scale={:.6f}".format(
# epoch, float(loss), float(scale))) epoch, float(loss), float(scale)))
writer.add_scalar('loss', float(loss), epoch) writer.add_scalar('loss', float(loss), epoch)
writer.add_scalar('learning_rate', float( writer.add_scalar('learning_rate', float(
optimizer.state_dict()['param_groups'][0]['lr']), epoch) optimizer.state_dict()['param_groups'][0]['lr']), epoch)
if epoch % cfg.TRAIN.SAVE == 0 and epoch > 0:
for i in tqdm(range(Jtr.shape[0])):
display_model(
{'verts': verts.cpu().detach(),
'joints': Jtr.cpu().detach()},
model_faces=smpl_layer.th_faces,
with_joints=True,
kintree_table=smpl_layer.kintree_table,
savepath="fit/output/UTD_MHAD/picture/frame_{}".format(str(i).zfill(4)),
batch_idx=i,
show=True,
only_joint=True)
logger.info('Train ended, min_loss = {:.9f}'.format(float(min_loss))) logger.info('Train ended, min_loss = {:.9f}'.format(float(min_loss)))
return res return res

View File

@ -1,12 +1,13 @@
import numpy as np import numpy as np
def transform(arr: np.ndarray): def transform(arr: np.ndarray, rotate=[1.,-1.,-1.]):
for i in range(arr.shape[0]): for i in range(arr.shape[0]):
origin = arr[i][0].copy() origin = arr[i][3].copy()
for j in range(arr.shape[1]): for j in range(arr.shape[1]):
arr[i][j] -= origin arr[i][j] -= origin
arr[i][j][1] *= -1 for k in range(3):
arr[i][j][2] *= -1 arr[i][j][k] *= rotate[k]
arr[i][0] = [0.0, 0.0, 0.0] arr[i][3] = [0.0, 0.0, 0.0]
print(arr[0])
return arr return arr

View File

@ -1,7 +1,7 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import imageio, os import imageio, os
images = [] images = []
filenames = sorted(fn for fn in os.listdir('./fit/output/HumanAct12/picture/fit/P01G01R01F0449T0505A0201') ) filenames = sorted(fn for fn in os.listdir('./fit/output/UTD_MHAD/picture/') )
for filename in filenames: for filename in filenames:
images.append(imageio.imread('./fit/output/HumanAct12/picture/fit/P01G01R01F0449T0505A0201/'+filename)) images.append(imageio.imread('./fit/output/UTD_MHAD/picture/'+filename))
imageio.mimsave('./fit.gif', images, duration=0.3) imageio.mimsave('./fit.gif', images, duration=0.3)