Simplify code.

This commit is contained in:
gulvarol
2019-07-10 14:52:40 +02:00
parent 095e65e498
commit 89f5c60799
2 changed files with 29 additions and 24 deletions

View File

@ -6,7 +6,7 @@ from torch.nn import Module
from smpl.native.webuser.serialization import ready_arguments
from smpl.pytorch import rodrigues_layer
from smpl.pytorch.tensutils import (th_posemap_axisang, th_with_zeros, th_pack, make_list)
from smpl.pytorch.tensutils import (th_posemap_axisang, th_with_zeros, th_pack, make_list, subtract_flat_id)
class SMPL_Layer(Module):
@ -61,36 +61,41 @@ class SMPL_Layer(Module):
self.num_joints = len(parents) # 24
def forward(self,
th_pose_coeffs,
th_pose_axisang,
th_betas=torch.zeros(1),
th_trans=torch.zeros(1)):
"""
Args:
th_pose_axisang (Tensor (batch_size x 72)): pose parameters in axis-angle representation
th_betas (Tensor (batch_size x 10)): if provided, uses given shape parameters
th_trans (Tensor (batch_size x 3)): if provided, applies trans to joints and vertices
"""
batch_size = th_pose_coeffs.shape[0]
th_pose_map, th_rot_map = th_posemap_axisang(th_pose_coeffs)
th_pose_coeffs = th_pose_coeffs.view(batch_size, -1, 3)
root_rot = rodrigues_layer.batch_rodrigues(
th_pose_coeffs[:, 0]).view(batch_size, 3, 3)
batch_size = th_pose_axisang.shape[0]
# Convert axis-angle representation to rotation matrix rep.
th_pose_rotmat = th_posemap_axisang(th_pose_axisang)
# Take out the first rotmat (global rotation)
root_rot = th_pose_rotmat[:, :9].view(batch_size, 3, 3)
# Take out the remaining rotmats (23 joints)
th_pose_rotmat = th_pose_rotmat[:, 9:]
th_pose_map = subtract_flat_id(th_pose_rotmat)
# Below does: v_shaped = v_template + shapedirs * betas
# If shape parameters are not provided
if th_betas is None or bool(torch.norm(th_betas) == 0):
th_v_shaped = torch.matmul(self.th_shapedirs,
self.th_betas.transpose(1, 0)).permute(
2, 0, 1) + self.th_v_template
th_v_shaped = self.th_v_template + torch.matmul(
self.th_shapedirs, self.th_betas.transpose(1, 0)).permute(2, 0, 1)
th_j = torch.matmul(self.th_J_regressor, th_v_shaped).repeat(
batch_size, 1, 1)
else:
th_v_shaped = torch.matmul(self.th_shapedirs,
th_betas.transpose(1, 0)).permute(
2, 0, 1) + self.th_v_template
th_v_shaped = self.th_v_template + torch.matmul(
self.th_shapedirs, th_betas.transpose(1, 0)).permute(2, 0, 1)
th_j = torch.matmul(self.th_J_regressor, th_v_shaped)
# Below does: v_posed = v_shaped + posedirs * pose_map
th_v_posed = th_v_shaped + torch.matmul(
self.th_posedirs, th_pose_map.transpose(0, 1)).permute(2, 0, 1)
# Final T pose with transformation done !
# Final T pose with transformation done!
# Global rigid transformation
th_results = []
@ -101,7 +106,7 @@ class SMPL_Layer(Module):
# Rotate each part
for i in range(self.num_joints - 1):
i_val = int(i + 1)
joint_rot = th_rot_map[:, (i_val - 1) * 9:i_val *
joint_rot = th_pose_rotmat[:, (i_val - 1) * 9:i_val *
9].contiguous().view(batch_size, 3, 3)
joint_j = th_j[:, i_val, :].contiguous().view(batch_size, 3, 1)
parent = make_list(self.kintree_parents)[i_val]
@ -137,6 +142,7 @@ class SMPL_Layer(Module):
th_verts = th_verts[:, :, :3]
th_jtr = torch.stack(th_results_global, dim=1)[:, :, :3, 3]
# If translation is not provided
if th_trans is None or bool(torch.norm(th_trans) == 0):
if self.center_idx is not None:
center_joint = th_jtr[:, self.center_idx].unsqueeze(1)
@ -146,7 +152,5 @@ class SMPL_Layer(Module):
th_jtr = th_jtr + th_trans.unsqueeze(1)
th_verts = th_verts + th_trans.unsqueeze(1)
# Scale to milimeters
# th_verts = th_verts * 1000
# th_jtr = th_jtr * 1000
# Vertices and joints in meters
return th_verts, th_jtr

View File

@ -4,18 +4,19 @@ from smpl.pytorch import rodrigues_layer
def th_posemap_axisang(pose_vectors):
'''
Converts axis-angle to rotmat
pose_vectors (Tensor (batch_size x 72)): pose parameters in axis-angle representation
'''
rot_nb = int(pose_vectors.shape[1] / 3)
rot_mats = []
for joint_idx in range(rot_nb - 1):
joint_idx_val = joint_idx + 1
axis_ang = pose_vectors[:, joint_idx_val * 3:(joint_idx_val + 1) * 3]
for joint_idx in range(rot_nb):
axis_ang = pose_vectors[:, joint_idx * 3:(joint_idx + 1) * 3]
rot_mat = rodrigues_layer.batch_rodrigues(axis_ang)
rot_mats.append(rot_mat)
# rot_mats = torch.stack(rot_mats, 1).view(-1, 15 *9)
rot_mats = torch.cat(rot_mats, 1)
pose_maps = subtract_flat_id(rot_mats)
return pose_maps, rot_mats
return rot_mats
def th_with_zeros(tensor):