update code
This commit is contained in:
@ -72,8 +72,8 @@ def train(smpl_layer, target,
|
|||||||
scale = params["scale"]
|
scale = params["scale"]
|
||||||
|
|
||||||
early_stop = Early_Stop()
|
early_stop = Early_Stop()
|
||||||
# for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
|
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
|
||||||
for epoch in range(cfg.TRAIN.MAX_EPOCH):
|
# for epoch in range(cfg.TRAIN.MAX_EPOCH):
|
||||||
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
|
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
|
||||||
loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100 * scale,
|
loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100 * scale,
|
||||||
target.index_select(1, index["dataset_index"]) * 100)
|
target.index_select(1, index["dataset_index"]) * 100)
|
||||||
@ -89,8 +89,8 @@ def train(smpl_layer, target,
|
|||||||
break
|
break
|
||||||
|
|
||||||
if epoch % cfg.TRAIN.WRITE == 0:
|
if epoch % cfg.TRAIN.WRITE == 0:
|
||||||
logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f} EarlyStopSatis: {}".format(
|
# logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f} EarlyStopSatis: {}".format(
|
||||||
epoch, float(loss),float(scale), early_stop.satis_num))
|
# epoch, float(loss),float(scale), early_stop.satis_num))
|
||||||
writer.add_scalar('loss', float(loss), epoch)
|
writer.add_scalar('loss', float(loss), epoch)
|
||||||
writer.add_scalar('learning_rate', float(
|
writer.add_scalar('learning_rate', float(
|
||||||
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
|
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
|
||||||
|
|||||||
Reference in New Issue
Block a user