update meters

This commit is contained in:
Iridoudou
2021-08-19 12:40:54 +08:00
parent e0ee13d6d6
commit 8573a09b84
10 changed files with 81 additions and 59 deletions

View File

@ -9,14 +9,17 @@ import logging
import argparse
import json
sys.path.append(os.getcwd())
from load import load
from save import save_pic, save_params
from transform import transform
from train import train
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
torch.backends.cudnn.benchmark = True
from meters import Meters
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
def parse_args():
parser = argparse.ArgumentParser(description='Fit SMPL')
@ -94,7 +97,8 @@ if __name__ == "__main__":
center_idx=0,
gender=cfg.MODEL.GENDER,
model_root='smplpytorch/native/models')
meters=Meters()
file_num = 0
for root, dirs, files in os.walk(cfg.DATASET.PATH):
for file in files:
@ -102,10 +106,15 @@ if __name__ == "__main__":
logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name,
os.path.join(root, file)))).float()
logger.info("target shape:{}".format(target.shape))
res = train(smpl_layer, target,
logger, writer, device,
args, cfg)
args, cfg, meters)
meters.update_avg(meters.min_loss, k=target.shape[0])
meters.reset_early_stop()
logger.info("avg_loss:{:.4f}".format(meters.avg))
# save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
save_params(res, file, logger, args.dataset_name)
torch.cuda.empty_cache()
logger.info("Fitting finished! Average loss: {:.9f}".format(meters.avg))