add HumanAct12, UTD_MHAD
This commit is contained in:
@ -1,14 +1,4 @@
|
||||
import matplotlib as plt
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.modules import module
|
||||
from torch.optim import lr_scheduler
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import sampler
|
||||
import torchvision.datasets as dset
|
||||
import torchvision.transforms as T
|
||||
import numpy as np
|
||||
from tensorboardX import SummaryWriter
|
||||
from easydict import EasyDict as edict
|
||||
@ -35,17 +25,15 @@ def parse_args():
|
||||
parser.add_argument('--exp', dest='exp',
|
||||
help='Define exp name',
|
||||
default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
|
||||
parser.add_argument('--config_path', dest='config_path',
|
||||
help='Select configuration file',
|
||||
default='fit/configs/config.json', type=str)
|
||||
parser.add_argument('--dataset_path', dest='dataset_path',
|
||||
parser.add_argument('--dataset_name', dest='dataset_name',
|
||||
help='select dataset',
|
||||
default='', type=str)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def get_config(args):
|
||||
with open(args.config_path, 'r') as f:
|
||||
config_path='fit/configs/{}.json'.format(args.dataset_name)
|
||||
with open(config_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
cfg = edict(data.copy())
|
||||
return cfg
|
||||
@ -100,31 +88,18 @@ if __name__ == "__main__":
|
||||
gender=cfg.MODEL.GENDER,
|
||||
model_root='smplpytorch/native/models')
|
||||
|
||||
if not cfg.DEBUG:
|
||||
for root,dirs,files in os.walk(cfg.DATASET_PATH):
|
||||
for file in files:
|
||||
logger.info('Processing file: {}'.format(file))
|
||||
target_path=os.path.join(root,file)
|
||||
|
||||
target = np.array(transform(np.load(target_path)))
|
||||
logger.info('File shape: {}'.format(target.shape))
|
||||
target = torch.from_numpy(target).float()
|
||||
|
||||
|
||||
res = train(smpl_layer,target,
|
||||
logger,writer,device,
|
||||
args,cfg)
|
||||
|
||||
# save_pic(target,res,smpl_layer,file,logger)
|
||||
save_params(res,file,logger)
|
||||
else:
|
||||
target = np.array(transform(load('UTD_MHAD',cfg.DATASET.TARGET_PATH),
|
||||
rotate=[-1,1,-1]))
|
||||
target = torch.from_numpy(target).float()
|
||||
data_map_dataset=torch.tensor(cfg.DATASET.DATA_MAP.UTD_MHAD[1])
|
||||
target = target.index_select(1, data_map_dataset)
|
||||
print(target.shape)
|
||||
res = train(smpl_layer,target,
|
||||
logger,writer,device,
|
||||
args,cfg)
|
||||
for root,dirs,files in os.walk(cfg.DATASET.PATH):
|
||||
for file in files:
|
||||
logger.info('Processing file: {}'.format(file))
|
||||
target = torch.from_numpy(transform(args.dataset_name,
|
||||
load(args.dataset_name,
|
||||
os.path.join(root,file)))).float()
|
||||
|
||||
|
||||
res = train(smpl_layer,target,
|
||||
logger,writer,device,
|
||||
args,cfg)
|
||||
|
||||
# save_pic(res,smpl_layer,file,logger,args.dataset_name)
|
||||
save_params(res,file,logger, args.dataset_name)
|
||||
|
||||
Reference in New Issue
Block a user