update meters
@@ -8,31 +8,11 @@ from tqdm import tqdm
 
 sys.path.append(os.getcwd())
 
 
-class Early_Stop:
-    def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
-        self.min_loss = float('inf')
-        self.eps = eps
-        self.stop_threshold = stop_threshold
-        self.satis_num = 0
-
-    def update(self, loss):
-        delta = (loss - self.min_loss) / self.min_loss
-        if float(loss) < self.min_loss:
-            self.min_loss = float(loss)
-            update_res = True
-        else:
-            update_res = False
-        if delta >= self.eps:
-            self.satis_num += 1
-        else:
-            self.satis_num = 0
-        return update_res, self.satis_num >= self.stop_threshold
-
-
 def init(smpl_layer, target, device, cfg):
     params = {}
-    params["pose_params"] = torch.rand(target.shape[0], 72) * 0.0
-    params["shape_params"] = torch.rand(target.shape[0], 10) * 0.03
+    params["pose_params"] = torch.zeros(target.shape[0], 72)
+    params["shape_params"] = torch.zeros(target.shape[0], 10)
     params["scale"] = torch.ones([1])
 
     smpl_layer = smpl_layer.to(device)
@@ -63,7 +43,7 @@ def init(smpl_layer, target, device, cfg):
 
 def train(smpl_layer, target,
           logger, writer, device,
-          args, cfg):
+          args, cfg, meters):
     res = []
     smpl_layer, params, target, optimizer, index = \
         init(smpl_layer, target, device, cfg)
@@ -71,20 +51,19 @@ def train(smpl_layer, target,
     shape_params = params["shape_params"]
     scale = params["scale"]
 
-    early_stop = Early_Stop()
     for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
-    # for epoch in range(cfg.TRAIN.MAX_EPOCH):
+    # for epoch in range(cfg.TRAIN.MAX_EPOCH):
         verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
-        loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100 * scale,
-                                target.index_select(1, index["dataset_index"]) * 100)
+        loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100,
+                                target.index_select(1, index["dataset_index"]) * 100 * scale)
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
 
-        update_res, stop = early_stop.update(float(loss))
-        if update_res:
+        meters.update_early_stop(float(loss))
+        if meters.update_res:
             res = [pose_params, shape_params, verts, Jtr]
-        if stop:
+        if meters.early_stop:
             logger.info("Early stop at epoch {} !".format(epoch))
             break
 
@@ -95,6 +74,6 @@ def train(smpl_layer, target,
         writer.add_scalar('learning_rate', float(
             optimizer.state_dict()['param_groups'][0]['lr']), epoch)
 
-    logger.info('Train ended, min_loss = {:.9f}'.format(
-        float(early_stop.min_loss)))
+    logger.info('Train ended, min_loss = {:.4f}'.format(
+        float(meters.min_loss)))
     return res
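
The meters object that replaces Early_Stop in this commit is not defined anywhere in the diff. A minimal sketch, assuming it simply carries the removed early-stopping bookkeeping behind the attributes the new loop reads (update_res, early_stop, min_loss), might look like the following; the class name, constructor defaults, and attribute layout are assumptions, not the repository's actual implementation.

# Hypothetical sketch -- Meters is not shown in this diff. It mirrors the deleted
# Early_Stop logic behind the interface train() now uses; names and defaults are assumed.
class Meters:
    def __init__(self, eps=-1e-3, stop_threshold=10):
        self.min_loss = float('inf')          # lowest loss observed so far
        self.eps = eps                        # relative-improvement threshold
        self.stop_threshold = stop_threshold  # patience, in epochs
        self.satis_num = 0                    # epochs without sufficient improvement
        self.update_res = False               # True if the last update set a new minimum
        self.early_stop = False               # True once patience is exhausted

    def update_early_stop(self, loss):
        # Same bookkeeping as the deleted Early_Stop.update(), but the results are
        # exposed as attributes instead of return values.
        delta = (loss - self.min_loss) / self.min_loss
        if float(loss) < self.min_loss:
            self.min_loss = float(loss)
            self.update_res = True
        else:
            self.update_res = False
        if delta >= self.eps:
            self.satis_num += 1
        else:
            self.satis_num = 0
        self.early_stop = self.satis_num >= self.stop_threshold

The actual class may track additional metrics; only the attributes referenced in the diff are sketched here.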