update code

This commit is contained in:
Iridoudou
2021-08-07 08:03:40 +08:00
parent 0152f1909f
commit 9f6274fc19
8 changed files with 160 additions and 65 deletions

View File

@ -27,6 +27,7 @@ from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from train import train
from transform import transform
from save import save_pic,save_params
from load import load
torch.backends.cudnn.benchmark=True
def parse_args():
@ -96,22 +97,34 @@ if __name__ == "__main__":
smpl_layer = SMPL_Layer(
center_idx = 0,
gender='neutral',
gender=cfg.MODEL.GENDER,
model_root='smplpytorch/native/models')
for root,dirs,files in os.walk(cfg.DATASET_PATH):
for file in files:
logger.info('Processing file: {}'.format(file))
target_path=os.path.join(root,file)
target = np.array(transform(np.load(target_path)))
logger.info('File shape: {}'.format(target.shape))
target = torch.from_numpy(target).float()
res = train(smpl_layer,target,
logger,writer,device,
args,cfg)
# save_pic(target,res,smpl_layer,file,logger)
save_params(res,file,logger)
if not cfg.DEBUG:
for root,dirs,files in os.walk(cfg.DATASET_PATH):
for file in files:
logger.info('Processing file: {}'.format(file))
target_path=os.path.join(root,file)
target = np.array(transform(np.load(target_path)))
logger.info('File shape: {}'.format(target.shape))
target = torch.from_numpy(target).float()
res = train(smpl_layer,target,
logger,writer,device,
args,cfg)
# save_pic(target,res,smpl_layer,file,logger)
save_params(res,file,logger)
else:
target = np.array(transform(load('UTD_MHAD',cfg.DATASET.TARGET_PATH),
rotate=[-1,1,-1]))
target = torch.from_numpy(target).float()
data_map_dataset=torch.tensor(cfg.DATASET.DATA_MAP.UTD_MHAD[1])
target = target.index_select(1, data_map_dataset)
print(target.shape)
res = train(smpl_layer,target,
logger,writer,device,
args,cfg)