update meters

This commit is contained in:
Iridoudou
2021-08-19 12:40:54 +08:00
parent e0ee13d6d6
commit 8573a09b84
10 changed files with 81 additions and 59 deletions

View File

@ -8,31 +8,11 @@ from tqdm import tqdm
sys.path.append(os.getcwd())
class Early_Stop:
def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
self.min_loss = float('inf')
self.eps = eps
self.stop_threshold = stop_threshold
self.satis_num = 0
def update(self, loss):
delta = (loss - self.min_loss) / self.min_loss
if float(loss) < self.min_loss:
self.min_loss = float(loss)
update_res = True
else:
update_res = False
if delta >= self.eps:
self.satis_num += 1
else:
self.satis_num = 0
return update_res, self.satis_num >= self.stop_threshold
def init(smpl_layer, target, device, cfg):
params = {}
params["pose_params"] = torch.rand(target.shape[0], 72) * 0.0
params["shape_params"] = torch.rand(target.shape[0], 10) * 0.03
params["pose_params"] = torch.zeros(target.shape[0], 72)
params["shape_params"] = torch.zeros(target.shape[0], 10)
params["scale"] = torch.ones([1])
smpl_layer = smpl_layer.to(device)
@ -63,7 +43,7 @@ def init(smpl_layer, target, device, cfg):
def train(smpl_layer, target,
logger, writer, device,
args, cfg):
args, cfg, meters):
res = []
smpl_layer, params, target, optimizer, index = \
init(smpl_layer, target, device, cfg)
@ -71,20 +51,19 @@ def train(smpl_layer, target,
shape_params = params["shape_params"]
scale = params["scale"]
early_stop = Early_Stop()
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
# for epoch in range(cfg.TRAIN.MAX_EPOCH):
# for epoch in range(cfg.TRAIN.MAX_EPOCH):
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100 * scale,
target.index_select(1, index["dataset_index"]) * 100)
loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100,
target.index_select(1, index["dataset_index"]) * 100 * scale)
optimizer.zero_grad()
loss.backward()
optimizer.step()
update_res, stop = early_stop.update(float(loss))
if update_res:
meters.update_early_stop(float(loss))
if meters.update_res:
res = [pose_params, shape_params, verts, Jtr]
if stop:
if meters.early_stop:
logger.info("Early stop at epoch {} !".format(epoch))
break
@ -95,6 +74,6 @@ def train(smpl_layer, target,
writer.add_scalar('learning_rate', float(
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
logger.info('Train ended, min_loss = {:.9f}'.format(
float(early_stop.min_loss)))
logger.info('Train ended, min_loss = {:.4f}'.format(
float(meters.min_loss)))
return res