update meters
This commit is contained in:
@ -9,14 +9,17 @@ import logging
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
sys.path.append(os.getcwd())
|
||||
from load import load
|
||||
from save import save_pic, save_params
|
||||
from transform import transform
|
||||
from train import train
|
||||
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
||||
torch.backends.cudnn.benchmark = True
|
||||
from meters import Meters
|
||||
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Fit SMPL')
|
||||
@ -94,7 +97,8 @@ if __name__ == "__main__":
|
||||
center_idx=0,
|
||||
gender=cfg.MODEL.GENDER,
|
||||
model_root='smplpytorch/native/models')
|
||||
|
||||
|
||||
meters=Meters()
|
||||
file_num = 0
|
||||
for root, dirs, files in os.walk(cfg.DATASET.PATH):
|
||||
for file in files:
|
||||
@ -102,10 +106,15 @@ if __name__ == "__main__":
|
||||
logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
|
||||
target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name,
|
||||
os.path.join(root, file)))).float()
|
||||
|
||||
logger.info("target shape:{}".format(target.shape))
|
||||
res = train(smpl_layer, target,
|
||||
logger, writer, device,
|
||||
args, cfg)
|
||||
args, cfg, meters)
|
||||
meters.update_avg(meters.min_loss, k=target.shape[0])
|
||||
meters.reset_early_stop()
|
||||
logger.info("avg_loss:{:.4f}".format(meters.avg))
|
||||
|
||||
# save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
|
||||
save_params(res, file, logger, args.dataset_name)
|
||||
torch.cuda.empty_cache()
|
||||
logger.info("Fitting finished! Average loss: {:.9f}".format(meters.avg))
|
||||
|
||||
Reference in New Issue
Block a user