update code

This commit is contained in:
Iridoudou
2021-08-11 19:36:41 +08:00
parent d1c497bbbc
commit f9c327d2a2

View File

@ -72,8 +72,8 @@ def train(smpl_layer, target,
scale = params["scale"] scale = params["scale"]
early_stop = Early_Stop() early_stop = Early_Stop()
# for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)): for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
for epoch in range(cfg.TRAIN.MAX_EPOCH): # for epoch in range(cfg.TRAIN.MAX_EPOCH):
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params) verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100 * scale, loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100 * scale,
target.index_select(1, index["dataset_index"]) * 100) target.index_select(1, index["dataset_index"]) * 100)
@ -89,8 +89,8 @@ def train(smpl_layer, target,
break break
if epoch % cfg.TRAIN.WRITE == 0: if epoch % cfg.TRAIN.WRITE == 0:
logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f} EarlyStopSatis: {}".format( # logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f} EarlyStopSatis: {}".format(
epoch, float(loss),float(scale), early_stop.satis_num)) # epoch, float(loss),float(scale), early_stop.satis_num))
writer.add_scalar('loss', float(loss), epoch) writer.add_scalar('loss', float(loss), epoch)
writer.add_scalar('learning_rate', float( writer.add_scalar('learning_rate', float(
optimizer.state_dict()['param_groups'][0]['lr']), epoch) optimizer.state_dict()['param_groups'][0]['lr']), epoch)