update readme
This commit is contained in:
@ -12,8 +12,9 @@ sys.path.append(os.getcwd())
|
||||
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
||||
from train import train
|
||||
from transform import transform
|
||||
from save import save_params
|
||||
from save import save_pic, save_params
|
||||
from load import load
|
||||
import numpy as np
|
||||
torch.backends.cudnn.benchmark=True
|
||||
|
||||
def parse_args():
|
||||
@ -101,6 +102,6 @@ if __name__ == "__main__":
|
||||
logger,writer,device,
|
||||
args,cfg)
|
||||
|
||||
# save_pic(res,smpl_layer,file,logger,args.dataset_name)
|
||||
# save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
|
||||
save_params(res,file,logger, args.dataset_name)
|
||||
|
||||
@ -15,11 +15,13 @@ def create_dir_not_exist(path):
|
||||
os.mkdir(path)
|
||||
|
||||
|
||||
def save_pic(res, smpl_layer, file, logger, dataset_name):
|
||||
def save_pic(res, smpl_layer, file, logger, dataset_name,target):
|
||||
_, _, verts, Jtr = res
|
||||
file_name = re.split('[/.]', file)[-2]
|
||||
fit_path = "fit/output/{}/picture/fit/{}".format(dataset_name,file_name)
|
||||
gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name,file_name)
|
||||
create_dir_not_exist(fit_path)
|
||||
create_dir_not_exist(gt_path)
|
||||
logger.info('Saving pictures at {}'.format(fit_path))
|
||||
for i in tqdm(range(Jtr.shape[0])):
|
||||
display_model(
|
||||
@ -32,6 +34,16 @@ def save_pic(res, smpl_layer, file, logger, dataset_name):
|
||||
batch_idx=i,
|
||||
show=False,
|
||||
only_joint=False)
|
||||
display_model(
|
||||
{'verts': verts.cpu().detach(),
|
||||
'joints': target.cpu().detach()},
|
||||
model_faces=smpl_layer.th_faces,
|
||||
with_joints=True,
|
||||
kintree_table=smpl_layer.kintree_table,
|
||||
savepath=os.path.join(gt_path+"/frame_{}".format(i)),
|
||||
batch_idx=i,
|
||||
show=False,
|
||||
only_joint=True)
|
||||
logger.info('Pictures saved')
|
||||
|
||||
|
||||
|
||||
@ -73,7 +73,6 @@ def train(smpl_layer, target,
|
||||
|
||||
early_stop = Early_Stop()
|
||||
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
|
||||
# for epoch in range(cfg.TRAIN.MAX_EPOCH):
|
||||
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
|
||||
loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100,
|
||||
target.index_select(1, index["dataset_index"]) * 100)
|
||||
|
||||
Reference in New Issue
Block a user