update code: PEP8 style cleanup (operator spacing, import order, line wrapping)

Author: Iridoudou
Date:   2021-08-14 16:24:48 +08:00
parent: f9c327d2a2
commit: a5f0cf4653
4 changed files with 38 additions and 38 deletions


@@ -1,5 +1,5 @@
-import numpy as np
 import torch
+import numpy as np
 from tensorboardX import SummaryWriter
 from easydict import EasyDict as edict
 import time


@@ -6,8 +6,8 @@ import numpy as np
 import pickle
 sys.path.append(os.getcwd())

-from display_utils import display_model
 from label import get_label
+from display_utils import display_model


 def create_dir_not_exist(path):
@@ -15,11 +15,11 @@ def create_dir_not_exist(path):
         os.mkdir(path)


-def save_pic(res, smpl_layer, file, logger, dataset_name,target):
+def save_pic(res, smpl_layer, file, logger, dataset_name, target):
     _, _, verts, Jtr = res
     file_name = re.split('[/.]', file)[-2]
-    fit_path = "fit/output/{}/picture/fit/{}".format(dataset_name,file_name)
-    gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name,file_name)
+    fit_path = "fit/output/{}/picture/fit/{}".format(dataset_name, file_name)
+    gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name, file_name)
     create_dir_not_exist(fit_path)
     create_dir_not_exist(gt_path)
     logger.info('Saving pictures at {}'.format(fit_path))
@@ -53,7 +53,7 @@ def save_params(res, file, logger, dataset_name):
     fit_path = "fit/output/{}/".format(dataset_name)
     create_dir_not_exist(fit_path)
     logger.info('Saving params at {}'.format(fit_path))
-    label=get_label(file_name, dataset_name)
+    label = get_label(file_name, dataset_name)
     pose_params = (pose_params.cpu().detach()).numpy().tolist()
     shape_params = (shape_params.cpu().detach()).numpy().tolist()
     Jtr = (Jtr.cpu().detach()).numpy().tolist()
@@ -64,5 +64,5 @@ def save_params(res, file, logger, dataset_name):
     params["shape_params"] = shape_params
     params["Jtr"] = Jtr
     with open(os.path.join((fit_path),
-            "{}_params.pkl".format(file_name)), 'wb') as f:
+              "{}_params.pkl".format(file_name)), 'wb') as f:
         pickle.dump(params, f)
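Note: save_params() serializes plain Python lists (via .numpy().tolist()), so the pickle can be read back without torch. A minimal sketch — the dataset_name and file_name in the path are made up, and the "pose_params" key is inferred from the surrounding code:

import pickle

# Hypothetical path; save_params() writes to
# fit/output/<dataset_name>/<file_name>_params.pkl.
with open("fit/output/demo_dataset/seq01_params.pkl", "rb") as f:
    params = pickle.load(f)

pose_params = params["pose_params"]    # presumably 72 axis-angle values per frame
shape_params = params["shape_params"]  # 10 SMPL shape coefficients per frame
Jtr = params["Jtr"]                    # regressed 3D joint positions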


@@ -9,19 +9,19 @@ sys.path.append(os.getcwd())


 class Early_Stop:
-    def __init__(self, eps = -1e-3, stop_threshold = 10) -> None:
-        self.min_loss=float('inf')
-        self.eps=eps
-        self.stop_threshold=stop_threshold
-        self.satis_num=0
+    def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
+        self.min_loss = float('inf')
+        self.eps = eps
+        self.stop_threshold = stop_threshold
+        self.satis_num = 0

     def update(self, loss):
         delta = (loss - self.min_loss) / self.min_loss
         if float(loss) < self.min_loss:
             self.min_loss = float(loss)
-            update_res=True
+            update_res = True
         else:
-            update_res=False
+            update_res = False
         if delta >= self.eps:
             self.satis_num += 1
         else:
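Note: a worked example of the stopping criterion above (not from the repo, and assuming the loss stays positive). Since eps is negative, delta >= eps counts every epoch whose relative improvement over min_loss is smaller than 0.1%; satis_num presumably triggers a stop once it reaches stop_threshold.

# Sketch of the Early_Stop rule with concrete numbers.
min_loss, eps = 1.0, -1e-3

for loss in (0.9995, 0.9990, 0.9989):
    delta = (loss - min_loss) / min_loss   # relative change vs. best loss so far
    print(delta >= eps)                    # True -> improvement below 0.1%,
                                           # counts toward stop_threshold
    min_loss = min(min_loss, loss)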
@@ -30,7 +30,7 @@ class Early_Stop:


 def init(smpl_layer, target, device, cfg):
-    params={}
+    params = {}
     params["pose_params"] = torch.rand(target.shape[0], 72) * 0.0
     params["shape_params"] = torch.rand(target.shape[0], 10) * 0.03
     params["scale"] = torch.ones([1])
@@ -48,24 +48,24 @@ def init(smpl_layer, target, device, cfg):
     optimizer = optim.Adam([params["pose_params"], params["shape_params"], params["scale"]],
                            lr=cfg.TRAIN.LEARNING_RATE)

-    index={}
-    smpl_index=[]
-    dataset_index=[]
+    index = {}
+    smpl_index = []
+    dataset_index = []
     for tp in cfg.DATASET.DATA_MAP:
         smpl_index.append(tp[0])
         dataset_index.append(tp[1])

-    index["smpl_index"]=torch.tensor(smpl_index).to(device)
-    index["dataset_index"]=torch.tensor(dataset_index).to(device)
+    index["smpl_index"] = torch.tensor(smpl_index).to(device)
+    index["dataset_index"] = torch.tensor(dataset_index).to(device)

-    return smpl_layer, params,target, optimizer, index
+    return smpl_layer, params, target, optimizer, index


 def train(smpl_layer, target,
           logger, writer, device,
           args, cfg):
     res = []
-    smpl_layer, params,target, optimizer, index = \
+    smpl_layer, params, target, optimizer, index = \
         init(smpl_layer, target, device, cfg)
     pose_params = params["pose_params"]
     shape_params = params["shape_params"]
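Note: to illustrate what the init() hunk above builds — a sketch with an invented DATA_MAP (the real pairs live in cfg.DATASET.DATA_MAP). Each tuple maps an SMPL joint index to the matching joint index in the dataset skeleton, and the two lists become index tensors used by the loss.

import torch
from easydict import EasyDict as edict

# Invented mapping for illustration: (smpl_joint, dataset_joint) pairs.
cfg = edict({"DATASET": {"DATA_MAP": [(0, 0), (1, 4), (2, 1)]}})

smpl_index = [tp[0] for tp in cfg.DATASET.DATA_MAP]      # [0, 1, 2]
dataset_index = [tp[1] for tp in cfg.DATASET.DATA_MAP]   # [0, 4, 1]

index = {
    "smpl_index": torch.tensor(smpl_index),
    "dataset_index": torch.tensor(dataset_index),
}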
@@ -73,7 +73,7 @@ def train(smpl_layer, target,
     early_stop = Early_Stop()

     for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
-    # for epoch in range(cfg.TRAIN.MAX_EPOCH):
+        # for epoch in range(cfg.TRAIN.MAX_EPOCH):
         verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
         loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100 * scale,
                                 target.index_select(1, index["dataset_index"]) * 100)
@@ -95,6 +95,6 @@ def train(smpl_layer, target,
         writer.add_scalar('learning_rate', float(
             optimizer.state_dict()['param_groups'][0]['lr']), epoch)

-    logger.info('Train ended, min_loss = {:.9f}'.format(float(early_stop.min_loss)))
-
+    logger.info('Train ended, min_loss = {:.9f}'.format(
+        float(early_stop.min_loss)))
     return res
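Note: a self-contained sketch of the joint loss computed in the loop above, with toy shapes (24 SMPL joints against a hypothetical 17-joint dataset skeleton; the index values are made up). index_select picks the matched joints along dim 1, both sides are scaled by 100 (metres to centimetres, presumably), and the learned scale absorbs the global size difference.

import torch
import torch.nn.functional as F

Jtr = torch.rand(2, 24, 3)     # joints from the SMPL layer (batch of 2 frames)
target = torch.rand(2, 17, 3)  # ground-truth joints from the dataset
scale = torch.ones([1])        # optimized alongside pose/shape in train()
smpl_index = torch.tensor([0, 1, 2])
dataset_index = torch.tensor([0, 4, 1])

# Same form as the loss in train(): compare only the matched joint pairs.
loss = F.smooth_l1_loss(Jtr.index_select(1, smpl_index) * 100 * scale,
                        target.index_select(1, dataset_index) * 100)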