update code

This commit is contained in:
Iridoudou
2021-08-07 08:03:40 +08:00
parent 0152f1909f
commit 9f6274fc19
8 changed files with 160 additions and 65 deletions

View File

@ -3,23 +3,56 @@
"GENDER": "male"
},
"TRAIN": {
"LEARNING_RATE":2e-2,
"LEARNING_RATE": 5e-2,
"MAX_EPOCH": 500,
"WRITE": 1,
"SAVE": 10,
"BATCH_SIZE": 1,
"MOMENTUM": 0.9,
"lr_scheduler": {
"T_0": 10,
"T_mult": 2,
"eta_min": 1e-2
},
"loss_func": ""
"SAVE": 100
},
"USE_GPU": 1,
"DATA_LOADER": {
"NUM_WORKERS": 1
"DATASET": {
"NAME": "UTD-MHAD",
"PATH": "../UTD-MHAD/Skeleton/Skeleton/",
"TARGET_PATH": "../UTD-MHAD/Skeleton/Skeleton/a5_s7_t3_skeleton.mat",
"DATA_MAP": {
"UTD_MHAD": [
[
12,
0,
16,
18,
20,
22,
17,
19,
21,
23,
1,
4,
7,
2,
5,
8
],
[
1,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
16,
17,
18
]
]
}
},
"TARGET_PATH":"../Action2Motion/HumanAct12/HumanAct12/P01G01R01F0069T0143A0102.npy",
"DATASET_PATH":"../Action2Motion/HumanAct12/HumanAct12/"
"DEBUG": 1
}

16
fit/tools/load.py Normal file
View File

@ -0,0 +1,16 @@
import scipy.io
import numpy as np
def load(name, path):
if name == 'UTD_MHAD':
data = scipy.io.loadmat(path)
arr = data['d_skel']
new_arr = np.zeros([arr.shape[2], arr.shape[0], arr.shape[1]])
for i in range(arr.shape[2]):
for j in range(arr.shape[0]):
for k in range(arr.shape[1]):
new_arr[i][j][k] = arr[j][k][i]
return new_arr
elif name == 'HumanAct12':
return np.load(path)

View File

@ -27,6 +27,7 @@ from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from train import train
from transform import transform
from save import save_pic,save_params
from load import load
torch.backends.cudnn.benchmark=True
def parse_args():
@ -96,22 +97,34 @@ if __name__ == "__main__":
smpl_layer = SMPL_Layer(
center_idx = 0,
gender='neutral',
gender=cfg.MODEL.GENDER,
model_root='smplpytorch/native/models')
for root,dirs,files in os.walk(cfg.DATASET_PATH):
for file in files:
logger.info('Processing file: {}'.format(file))
target_path=os.path.join(root,file)
if not cfg.DEBUG:
for root,dirs,files in os.walk(cfg.DATASET_PATH):
for file in files:
logger.info('Processing file: {}'.format(file))
target_path=os.path.join(root,file)
target = np.array(transform(np.load(target_path)))
logger.info('File shape: {}'.format(target.shape))
target = torch.from_numpy(target).float()
target = np.array(transform(np.load(target_path)))
logger.info('File shape: {}'.format(target.shape))
target = torch.from_numpy(target).float()
res = train(smpl_layer,target,
logger,writer,device,
args,cfg)
# save_pic(target,res,smpl_layer,file,logger)
save_params(res,file,logger)
res = train(smpl_layer,target,
logger,writer,device,
args,cfg)
# save_pic(target,res,smpl_layer,file,logger)
save_params(res,file,logger)
else:
target = np.array(transform(load('UTD_MHAD',cfg.DATASET.TARGET_PATH),
rotate=[-1,1,-1]))
target = torch.from_numpy(target).float()
data_map_dataset=torch.tensor(cfg.DATASET.DATA_MAP.UTD_MHAD[1])
target = target.index_select(1, data_map_dataset)
print(target.shape)
res = train(smpl_layer,target,
logger,writer,device,
args,cfg)

14
fit/tools/map.py Normal file
View File

@ -0,0 +1,14 @@
import numpy as np
def mapping(Jtr,cfg):
name=cfg.DATASET.NAME
if not name=='HumanAct12':
mapped_joint=cfg.DATASET.DATA_MAP.UTD_MHAD
Jtr_mapped=np.zeros([Jtr.shape[0],len(mapped_joint),Jtr.shape[2]])
for i in range(Jtr.shape[0]):
for j in range(len(mapped_joint)):
for k in range(Jtr.shape[2]):
Jtr_mapped[i][j][k]=Jtr[i][mapped_joint[j]][k]
return Jtr_mapped
return Jtr

View File

@ -9,18 +9,20 @@ import json
sys.path.append(os.getcwd())
from display_utils import display_model
def create_dir_not_exist(path):
if not os.path.exists(path):
os.mkdir(path)
def save_pic(target, res, smpl_layer, file,logger):
def save_pic(target, res, smpl_layer, file, logger):
pose_params, shape_params, verts, Jtr = res
name=re.split('[/.]',file)[-2]
gt_path="fit/output/HumanAct12/picture/gt/{}".format(name)
fit_path="fit/output/HumanAct12/picture/fit/{}".format(name)
name = re.split('[/.]', file)[-2]
gt_path = "fit/output/HumanAct12/picture/gt/{}".format(name)
fit_path = "fit/output/HumanAct12/picture/fit/{}".format(name)
create_dir_not_exist(gt_path)
create_dir_not_exist(fit_path)
logger.info('Saving pictures at {} and {}'.format(gt_path,fit_path))
logger.info('Saving pictures at {} and {}'.format(gt_path, fit_path))
for i in tqdm(range(target.shape[0])):
display_model(
{'verts': verts.cpu().detach(),
@ -43,24 +45,24 @@ def save_pic(target, res, smpl_layer, file,logger):
show=False)
logger.info('Pictures saved')
def save_params(res,file,logger):
def save_params(res, file, logger):
pose_params, shape_params, verts, Jtr = res
name=re.split('[/.]',file)[-2]
fit_path="fit/output/HumanAct12/params/"
name = re.split('[/.]', file)[-2]
fit_path = "fit/output/HumanAct12/params/"
create_dir_not_exist(fit_path)
logger.info('Saving params at {}'.format(fit_path))
pose_params=pose_params.cpu().detach()
pose_params=pose_params.numpy().tolist()
shape_params=shape_params.cpu().detach()
shape_params=shape_params.numpy().tolist()
Jtr=Jtr.cpu().detach()
Jtr=Jtr.numpy().tolist()
params={}
params["pose_params"]=pose_params
params["shape_params"]=shape_params
params["Jtr"]=Jtr
f=open(os.path.join((fit_path),
"{}_params.json".format(name)),'w')
json.dump(params,f)
pose_params = pose_params.cpu().detach()
pose_params = pose_params.numpy().tolist()
shape_params = shape_params.cpu().detach()
shape_params = shape_params.numpy().tolist()
Jtr = Jtr.cpu().detach()
Jtr = Jtr.numpy().tolist()
params = {}
params["pose_params"] = pose_params
params["shape_params"] = shape_params
params["Jtr"] = Jtr
f = open(os.path.join((fit_path),
"{}_params.json".format(name)), 'w')
json.dump(params, f)
logger.info('Params saved')

View File

@ -10,6 +10,7 @@ from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as T
import numpy as np
import scipy.io
from tensorboardX import SummaryWriter
from easydict import EasyDict as edict
import time
@ -24,6 +25,7 @@ from tqdm import tqdm
sys.path.append(os.getcwd())
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from display_utils import display_model
from map import mapping
def train(smpl_layer, target,
logger, writer, device,
@ -48,9 +50,11 @@ def train(smpl_layer, target,
lr=cfg.TRAIN.LEARNING_RATE)
min_loss = float('inf')
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
data_map=torch.tensor(cfg.DATASET.DATA_MAP.UTD_MHAD)[0].to(device)
# for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
for epoch in range(cfg.TRAIN.MAX_EPOCH):
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
loss = F.smooth_l1_loss(Jtr * 100, target * 100)
loss = F.smooth_l1_loss(Jtr.index_select(1, data_map) * 100, target * 100)
optimizer.zero_grad()
loss.backward()
optimizer.step()
@ -58,10 +62,22 @@ def train(smpl_layer, target,
min_loss = float(loss)
res = [pose_params, shape_params, verts, Jtr]
if epoch % cfg.TRAIN.WRITE == 0:
# logger.info("Epoch {}, lossPerBatch={:.9f}, scale={:.6f}".format(
# epoch, float(loss), float(scale)))
logger.info("Epoch {}, lossPerBatch={:.9f}, scale={:.6f}".format(
epoch, float(loss), float(scale)))
writer.add_scalar('loss', float(loss), epoch)
writer.add_scalar('learning_rate', float(
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
if epoch % cfg.TRAIN.SAVE == 0 and epoch > 0:
for i in tqdm(range(Jtr.shape[0])):
display_model(
{'verts': verts.cpu().detach(),
'joints': Jtr.cpu().detach()},
model_faces=smpl_layer.th_faces,
with_joints=True,
kintree_table=smpl_layer.kintree_table,
savepath="fit/output/UTD_MHAD/picture/frame_{}".format(str(i).zfill(4)),
batch_idx=i,
show=True,
only_joint=True)
logger.info('Train ended, min_loss = {:.9f}'.format(float(min_loss)))
return res

View File

@ -1,12 +1,13 @@
import numpy as np
def transform(arr: np.ndarray):
def transform(arr: np.ndarray, rotate=[1.,-1.,-1.]):
for i in range(arr.shape[0]):
origin = arr[i][0].copy()
origin = arr[i][3].copy()
for j in range(arr.shape[1]):
arr[i][j] -= origin
arr[i][j][1] *= -1
arr[i][j][2] *= -1
arr[i][0] = [0.0, 0.0, 0.0]
for k in range(3):
arr[i][j][k] *= rotate[k]
arr[i][3] = [0.0, 0.0, 0.0]
print(arr[0])
return arr

View File

@ -1,7 +1,7 @@
import matplotlib.pyplot as plt
import imageio, os
images = []
filenames = sorted(fn for fn in os.listdir('./fit/output/HumanAct12/picture/fit/P01G01R01F0449T0505A0201') )
filenames = sorted(fn for fn in os.listdir('./fit/output/UTD_MHAD/picture/') )
for filename in filenames:
images.append(imageio.imread('./fit/output/HumanAct12/picture/fit/P01G01R01F0449T0505A0201/'+filename))
images.append(imageio.imread('./fit/output/UTD_MHAD/picture/'+filename))
imageio.mimsave('./fit.gif', images, duration=0.3)