From 89f5c6079986275d0476d835677f75e6bacb59b8 Mon Sep 17 00:00:00 2001 From: gulvarol Date: Wed, 10 Jul 2019 14:52:40 +0200 Subject: [PATCH] Simplify code. --- smpl/pytorch/smpl_layer.py | 40 +++++++++++++++++++++----------------- smpl/pytorch/tensutils.py | 13 +++++++------ 2 files changed, 29 insertions(+), 24 deletions(-) diff --git a/smpl/pytorch/smpl_layer.py b/smpl/pytorch/smpl_layer.py index 9ab4cb3..49b2b5b 100644 --- a/smpl/pytorch/smpl_layer.py +++ b/smpl/pytorch/smpl_layer.py @@ -6,7 +6,7 @@ from torch.nn import Module from smpl.native.webuser.serialization import ready_arguments from smpl.pytorch import rodrigues_layer -from smpl.pytorch.tensutils import (th_posemap_axisang, th_with_zeros, th_pack, make_list) +from smpl.pytorch.tensutils import (th_posemap_axisang, th_with_zeros, th_pack, make_list, subtract_flat_id) class SMPL_Layer(Module): @@ -61,36 +61,41 @@ class SMPL_Layer(Module): self.num_joints = len(parents) # 24 def forward(self, - th_pose_coeffs, + th_pose_axisang, th_betas=torch.zeros(1), th_trans=torch.zeros(1)): """ Args: + th_pose_axisang (Tensor (batch_size x 72)): pose parameters in axis-angle representation th_betas (Tensor (batch_size x 10)): if provided, uses given shape parameters th_trans (Tensor (batch_size x 3)): if provided, applies trans to joints and vertices """ - batch_size = th_pose_coeffs.shape[0] - th_pose_map, th_rot_map = th_posemap_axisang(th_pose_coeffs) - th_pose_coeffs = th_pose_coeffs.view(batch_size, -1, 3) - root_rot = rodrigues_layer.batch_rodrigues( - th_pose_coeffs[:, 0]).view(batch_size, 3, 3) + batch_size = th_pose_axisang.shape[0] + # Convert axis-angle representation to rotation matrix rep. + th_pose_rotmat = th_posemap_axisang(th_pose_axisang) + # Take out the first rotmat (global rotation) + root_rot = th_pose_rotmat[:, :9].view(batch_size, 3, 3) + # Take out the remaining rotmats (23 joints) + th_pose_rotmat = th_pose_rotmat[:, 9:] + th_pose_map = subtract_flat_id(th_pose_rotmat) + # Below does: v_shaped = v_template + shapedirs * betas + # If shape parameters are not provided if th_betas is None or bool(torch.norm(th_betas) == 0): - th_v_shaped = torch.matmul(self.th_shapedirs, - self.th_betas.transpose(1, 0)).permute( - 2, 0, 1) + self.th_v_template + th_v_shaped = self.th_v_template + torch.matmul( + self.th_shapedirs, self.th_betas.transpose(1, 0)).permute(2, 0, 1) th_j = torch.matmul(self.th_J_regressor, th_v_shaped).repeat( batch_size, 1, 1) else: - th_v_shaped = torch.matmul(self.th_shapedirs, - th_betas.transpose(1, 0)).permute( - 2, 0, 1) + self.th_v_template + th_v_shaped = self.th_v_template + torch.matmul( + self.th_shapedirs, th_betas.transpose(1, 0)).permute(2, 0, 1) th_j = torch.matmul(self.th_J_regressor, th_v_shaped) + # Below does: v_posed = v_shaped + posedirs * pose_map th_v_posed = th_v_shaped + torch.matmul( self.th_posedirs, th_pose_map.transpose(0, 1)).permute(2, 0, 1) - # Final T pose with transformation done ! + # Final T pose with transformation done! # Global rigid transformation th_results = [] @@ -101,7 +106,7 @@ class SMPL_Layer(Module): # Rotate each part for i in range(self.num_joints - 1): i_val = int(i + 1) - joint_rot = th_rot_map[:, (i_val - 1) * 9:i_val * + joint_rot = th_pose_rotmat[:, (i_val - 1) * 9:i_val * 9].contiguous().view(batch_size, 3, 3) joint_j = th_j[:, i_val, :].contiguous().view(batch_size, 3, 1) parent = make_list(self.kintree_parents)[i_val] @@ -137,6 +142,7 @@ class SMPL_Layer(Module): th_verts = th_verts[:, :, :3] th_jtr = torch.stack(th_results_global, dim=1)[:, :, :3, 3] + # If translation is not provided if th_trans is None or bool(torch.norm(th_trans) == 0): if self.center_idx is not None: center_joint = th_jtr[:, self.center_idx].unsqueeze(1) @@ -146,7 +152,5 @@ class SMPL_Layer(Module): th_jtr = th_jtr + th_trans.unsqueeze(1) th_verts = th_verts + th_trans.unsqueeze(1) - # Scale to milimeters - # th_verts = th_verts * 1000 - # th_jtr = th_jtr * 1000 + # Vertices and joints in meters return th_verts, th_jtr diff --git a/smpl/pytorch/tensutils.py b/smpl/pytorch/tensutils.py index 5d07bbc..0082fc2 100644 --- a/smpl/pytorch/tensutils.py +++ b/smpl/pytorch/tensutils.py @@ -4,18 +4,19 @@ from smpl.pytorch import rodrigues_layer def th_posemap_axisang(pose_vectors): + ''' + Converts axis-angle to rotmat + pose_vectors (Tensor (batch_size x 72)): pose parameters in axis-angle representation + ''' rot_nb = int(pose_vectors.shape[1] / 3) rot_mats = [] - for joint_idx in range(rot_nb - 1): - joint_idx_val = joint_idx + 1 - axis_ang = pose_vectors[:, joint_idx_val * 3:(joint_idx_val + 1) * 3] + for joint_idx in range(rot_nb): + axis_ang = pose_vectors[:, joint_idx * 3:(joint_idx + 1) * 3] rot_mat = rodrigues_layer.batch_rodrigues(axis_ang) rot_mats.append(rot_mat) - # rot_mats = torch.stack(rot_mats, 1).view(-1, 15 *9) rot_mats = torch.cat(rot_mats, 1) - pose_maps = subtract_flat_id(rot_mats) - return pose_maps, rot_mats + return rot_mats def th_with_zeros(tensor):