update code

This commit is contained in:
Iridoudou
2021-08-05 16:22:06 +08:00
parent 791f02f280
commit bc0766fc76
11 changed files with 47 additions and 17 deletions

View File

@ -30,7 +30,7 @@ def train(smpl_layer, target,
args, cfg):
res = []
pose_params = torch.rand(target.shape[0], 72) * 0.0
shape_params = torch.rand(target.shape[0], 10) * 0.1
shape_params = torch.rand(target.shape[0], 10) * 0.03
scale = torch.ones([1])
smpl_layer = smpl_layer.to(device)
@ -41,9 +41,10 @@ def train(smpl_layer, target,
pose_params.requires_grad = True
shape_params.requires_grad = True
scale.requires_grad = True
scale.requires_grad = False
smpl_layer.requires_grad = False
optimizer = optim.Adam([pose_params],
optimizer = optim.Adam([pose_params, shape_params],
lr=cfg.TRAIN.LEARNING_RATE)
min_loss = float('inf')
@ -62,5 +63,5 @@ def train(smpl_layer, target,
writer.add_scalar('loss', float(loss), epoch)
writer.add_scalar('learning_rate', float(
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
logger.info('Train ended, loss = {:.9f}'.format(float(loss)))
logger.info('Train ended, min_loss = {:.9f}'.format(float(min_loss)))
return res