Support HAA4D

This commit is contained in:
Yiming Dou
2022-10-01 21:42:12 +08:00
parent fd2ff6616f
commit 6ac10e22e0
12 changed files with 232 additions and 170 deletions

View File

@ -1,32 +1,33 @@
import os
import sys
sys.path.append(os.getcwd())
from meters import Meters
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from train import train
from transform import transform
from save import save_pic, save_params
from load import load
import torch
import numpy as np
from tensorboardX import SummaryWriter
from easydict import EasyDict as edict
import time
import sys
import os
import logging
import argparse
import json
sys.path.append(os.getcwd())
from load import load
from save import save_pic, save_params
from transform import transform
from train import train
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from meters import Meters
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
def parse_args():
parser = argparse.ArgumentParser(description='Fit SMPL')
parser.add_argument('--exp', dest='exp',
help='Define exp name',
default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
parser.add_argument('--dataset_name', dest='dataset_name',
parser.add_argument('--dataset_name', '-n', dest='dataset_name',
help='select dataset',
default='', type=str)
parser.add_argument('--dataset_path', dest='dataset_path',
@ -97,15 +98,18 @@ if __name__ == "__main__":
center_idx=0,
gender=cfg.MODEL.GENDER,
model_root='smplpytorch/native/models')
meters=Meters()
meters = Meters()
file_num = 0
for root, dirs, files in os.walk(cfg.DATASET.PATH):
for file in files:
for file in sorted(files):
if not 'baseball_swing' in file:
continue
file_num += 1
logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name,
os.path.join(root, file)))).float()
logger.info(
'Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
target = torch.from_numpy(transform(args.dataset_name, load(args.dataset_name,
os.path.join(root, file)))).float()
logger.info("target shape:{}".format(target.shape))
res = train(smpl_layer, target,
logger, writer, device,
@ -115,7 +119,8 @@ if __name__ == "__main__":
logger.info("avg_loss:{:.4f}".format(meters.avg))
save_params(res, file, logger, args.dataset_name)
# save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
save_pic(res, smpl_layer, file, logger, args.dataset_name, target)
torch.cuda.empty_cache()
logger.info("Fitting finished! Average loss: {:.9f}".format(meters.avg))
logger.info(
"Fitting finished! Average loss: {:.9f}".format(meters.avg))