update code
This commit is contained in:
@ -27,6 +27,7 @@ from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
||||
from train import train
|
||||
from transform import transform
|
||||
from save import save_pic,save_params
|
||||
from load import load
|
||||
torch.backends.cudnn.benchmark=True
|
||||
|
||||
def parse_args():
|
||||
@ -96,22 +97,34 @@ if __name__ == "__main__":
|
||||
|
||||
smpl_layer = SMPL_Layer(
|
||||
center_idx = 0,
|
||||
gender='neutral',
|
||||
gender=cfg.MODEL.GENDER,
|
||||
model_root='smplpytorch/native/models')
|
||||
|
||||
for root,dirs,files in os.walk(cfg.DATASET_PATH):
|
||||
for file in files:
|
||||
logger.info('Processing file: {}'.format(file))
|
||||
target_path=os.path.join(root,file)
|
||||
|
||||
target = np.array(transform(np.load(target_path)))
|
||||
logger.info('File shape: {}'.format(target.shape))
|
||||
target = torch.from_numpy(target).float()
|
||||
|
||||
res = train(smpl_layer,target,
|
||||
logger,writer,device,
|
||||
args,cfg)
|
||||
|
||||
# save_pic(target,res,smpl_layer,file,logger)
|
||||
save_params(res,file,logger)
|
||||
if not cfg.DEBUG:
|
||||
for root,dirs,files in os.walk(cfg.DATASET_PATH):
|
||||
for file in files:
|
||||
logger.info('Processing file: {}'.format(file))
|
||||
target_path=os.path.join(root,file)
|
||||
|
||||
target = np.array(transform(np.load(target_path)))
|
||||
logger.info('File shape: {}'.format(target.shape))
|
||||
target = torch.from_numpy(target).float()
|
||||
|
||||
|
||||
res = train(smpl_layer,target,
|
||||
logger,writer,device,
|
||||
args,cfg)
|
||||
|
||||
# save_pic(target,res,smpl_layer,file,logger)
|
||||
save_params(res,file,logger)
|
||||
else:
|
||||
target = np.array(transform(load('UTD_MHAD',cfg.DATASET.TARGET_PATH),
|
||||
rotate=[-1,1,-1]))
|
||||
target = torch.from_numpy(target).float()
|
||||
data_map_dataset=torch.tensor(cfg.DATASET.DATA_MAP.UTD_MHAD[1])
|
||||
target = target.index_select(1, data_map_dataset)
|
||||
print(target.shape)
|
||||
res = train(smpl_layer,target,
|
||||
logger,writer,device,
|
||||
args,cfg)
|
||||
|
||||
Reference in New Issue
Block a user