Support HAA4D
This commit is contained in:
@ -1,32 +1,33 @@
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.getcwd())
|
||||
from meters import Meters
|
||||
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
||||
from train import train
|
||||
from transform import transform
|
||||
from save import save_pic, save_params
|
||||
from load import load
|
||||
import torch
|
||||
import numpy as np
|
||||
from tensorboardX import SummaryWriter
|
||||
from easydict import EasyDict as edict
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
from load import load
|
||||
from save import save_pic, save_params
|
||||
from transform import transform
|
||||
from train import train
|
||||
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
||||
from meters import Meters
|
||||
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Fit SMPL')
|
||||
parser.add_argument('--exp', dest='exp',
|
||||
help='Define exp name',
|
||||
default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
|
||||
parser.add_argument('--dataset_name', dest='dataset_name',
|
||||
parser.add_argument('--dataset_name', '-n', dest='dataset_name',
|
||||
help='select dataset',
|
||||
default='', type=str)
|
||||
parser.add_argument('--dataset_path', dest='dataset_path',
|
||||
@ -97,15 +98,18 @@ if __name__ == "__main__":
|
||||
center_idx=0,
|
||||
gender=cfg.MODEL.GENDER,
|
||||
model_root='smplpytorch/native/models')
|
||||
|
||||
meters=Meters()
|
||||
|
||||
meters = Meters()
|
||||
file_num = 0
|
||||
for root, dirs, files in os.walk(cfg.DATASET.PATH):
|
||||
for file in files:
|
||||
for file in sorted(files):
|
||||
if not 'baseball_swing' in file:
|
||||
continue
|
||||
file_num += 1
|
||||
logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
|
||||
target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name,
|
||||
os.path.join(root, file)))).float()
|
||||
logger.info(
|
||||
'Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
|
||||
target = torch.from_numpy(transform(args.dataset_name, load(args.dataset_name,
|
||||
os.path.join(root, file)))).float()
|
||||
logger.info("target shape:{}".format(target.shape))
|
||||
res = train(smpl_layer, target,
|
||||
logger, writer, device,
|
||||
@ -115,7 +119,8 @@ if __name__ == "__main__":
|
||||
logger.info("avg_loss:{:.4f}".format(meters.avg))
|
||||
|
||||
save_params(res, file, logger, args.dataset_name)
|
||||
# save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
|
||||
|
||||
save_pic(res, smpl_layer, file, logger, args.dataset_name, target)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
logger.info("Fitting finished! Average loss: {:.9f}".format(meters.avg))
|
||||
logger.info(
|
||||
"Fitting finished! Average loss: {:.9f}".format(meters.avg))
|
||||
|
||||
Reference in New Issue
Block a user