update code
This commit is contained in:
@ -30,7 +30,7 @@ def train(smpl_layer, target,
|
||||
args, cfg):
|
||||
res = []
|
||||
pose_params = torch.rand(target.shape[0], 72) * 0.0
|
||||
shape_params = torch.rand(target.shape[0], 10) * 0.1
|
||||
shape_params = torch.rand(target.shape[0], 10) * 0.03
|
||||
scale = torch.ones([1])
|
||||
|
||||
smpl_layer = smpl_layer.to(device)
|
||||
@ -41,9 +41,10 @@ def train(smpl_layer, target,
|
||||
|
||||
pose_params.requires_grad = True
|
||||
shape_params.requires_grad = True
|
||||
scale.requires_grad = True
|
||||
scale.requires_grad = False
|
||||
smpl_layer.requires_grad = False
|
||||
|
||||
optimizer = optim.Adam([pose_params],
|
||||
optimizer = optim.Adam([pose_params, shape_params],
|
||||
lr=cfg.TRAIN.LEARNING_RATE)
|
||||
|
||||
min_loss = float('inf')
|
||||
@ -62,5 +63,5 @@ def train(smpl_layer, target,
|
||||
writer.add_scalar('loss', float(loss), epoch)
|
||||
writer.add_scalar('learning_rate', float(
|
||||
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
|
||||
logger.info('Train ended, loss = {:.9f}'.format(float(loss)))
|
||||
logger.info('Train ended, min_loss = {:.9f}'.format(float(min_loss)))
|
||||
return res
|
||||
|
||||
Reference in New Issue
Block a user