Simplify code.

This commit is contained in:
gulvarol
2019-07-10 14:52:40 +02:00
parent 095e65e498
commit 89f5c60799
2 changed files with 29 additions and 24 deletions

View File

@ -6,7 +6,7 @@ from torch.nn import Module
from smpl.native.webuser.serialization import ready_arguments from smpl.native.webuser.serialization import ready_arguments
from smpl.pytorch import rodrigues_layer from smpl.pytorch import rodrigues_layer
from smpl.pytorch.tensutils import (th_posemap_axisang, th_with_zeros, th_pack, make_list) from smpl.pytorch.tensutils import (th_posemap_axisang, th_with_zeros, th_pack, make_list, subtract_flat_id)
class SMPL_Layer(Module): class SMPL_Layer(Module):
@ -61,33 +61,38 @@ class SMPL_Layer(Module):
self.num_joints = len(parents) # 24 self.num_joints = len(parents) # 24
def forward(self, def forward(self,
th_pose_coeffs, th_pose_axisang,
th_betas=torch.zeros(1), th_betas=torch.zeros(1),
th_trans=torch.zeros(1)): th_trans=torch.zeros(1)):
""" """
Args: Args:
th_pose_axisang (Tensor (batch_size x 72)): pose parameters in axis-angle representation
th_betas (Tensor (batch_size x 10)): if provided, uses given shape parameters th_betas (Tensor (batch_size x 10)): if provided, uses given shape parameters
th_trans (Tensor (batch_size x 3)): if provided, applies trans to joints and vertices th_trans (Tensor (batch_size x 3)): if provided, applies trans to joints and vertices
""" """
batch_size = th_pose_coeffs.shape[0] batch_size = th_pose_axisang.shape[0]
th_pose_map, th_rot_map = th_posemap_axisang(th_pose_coeffs) # Convert axis-angle representation to rotation matrix rep.
th_pose_coeffs = th_pose_coeffs.view(batch_size, -1, 3) th_pose_rotmat = th_posemap_axisang(th_pose_axisang)
root_rot = rodrigues_layer.batch_rodrigues( # Take out the first rotmat (global rotation)
th_pose_coeffs[:, 0]).view(batch_size, 3, 3) root_rot = th_pose_rotmat[:, :9].view(batch_size, 3, 3)
# Take out the remaining rotmats (23 joints)
th_pose_rotmat = th_pose_rotmat[:, 9:]
th_pose_map = subtract_flat_id(th_pose_rotmat)
# Below does: v_shaped = v_template + shapedirs * betas
# If shape parameters are not provided
if th_betas is None or bool(torch.norm(th_betas) == 0): if th_betas is None or bool(torch.norm(th_betas) == 0):
th_v_shaped = torch.matmul(self.th_shapedirs, th_v_shaped = self.th_v_template + torch.matmul(
self.th_betas.transpose(1, 0)).permute( self.th_shapedirs, self.th_betas.transpose(1, 0)).permute(2, 0, 1)
2, 0, 1) + self.th_v_template
th_j = torch.matmul(self.th_J_regressor, th_v_shaped).repeat( th_j = torch.matmul(self.th_J_regressor, th_v_shaped).repeat(
batch_size, 1, 1) batch_size, 1, 1)
else: else:
th_v_shaped = torch.matmul(self.th_shapedirs, th_v_shaped = self.th_v_template + torch.matmul(
th_betas.transpose(1, 0)).permute( self.th_shapedirs, th_betas.transpose(1, 0)).permute(2, 0, 1)
2, 0, 1) + self.th_v_template
th_j = torch.matmul(self.th_J_regressor, th_v_shaped) th_j = torch.matmul(self.th_J_regressor, th_v_shaped)
# Below does: v_posed = v_shaped + posedirs * pose_map
th_v_posed = th_v_shaped + torch.matmul( th_v_posed = th_v_shaped + torch.matmul(
self.th_posedirs, th_pose_map.transpose(0, 1)).permute(2, 0, 1) self.th_posedirs, th_pose_map.transpose(0, 1)).permute(2, 0, 1)
# Final T pose with transformation done! # Final T pose with transformation done!
@ -101,7 +106,7 @@ class SMPL_Layer(Module):
# Rotate each part # Rotate each part
for i in range(self.num_joints - 1): for i in range(self.num_joints - 1):
i_val = int(i + 1) i_val = int(i + 1)
joint_rot = th_rot_map[:, (i_val - 1) * 9:i_val * joint_rot = th_pose_rotmat[:, (i_val - 1) * 9:i_val *
9].contiguous().view(batch_size, 3, 3) 9].contiguous().view(batch_size, 3, 3)
joint_j = th_j[:, i_val, :].contiguous().view(batch_size, 3, 1) joint_j = th_j[:, i_val, :].contiguous().view(batch_size, 3, 1)
parent = make_list(self.kintree_parents)[i_val] parent = make_list(self.kintree_parents)[i_val]
@ -137,6 +142,7 @@ class SMPL_Layer(Module):
th_verts = th_verts[:, :, :3] th_verts = th_verts[:, :, :3]
th_jtr = torch.stack(th_results_global, dim=1)[:, :, :3, 3] th_jtr = torch.stack(th_results_global, dim=1)[:, :, :3, 3]
# If translation is not provided
if th_trans is None or bool(torch.norm(th_trans) == 0): if th_trans is None or bool(torch.norm(th_trans) == 0):
if self.center_idx is not None: if self.center_idx is not None:
center_joint = th_jtr[:, self.center_idx].unsqueeze(1) center_joint = th_jtr[:, self.center_idx].unsqueeze(1)
@ -146,7 +152,5 @@ class SMPL_Layer(Module):
th_jtr = th_jtr + th_trans.unsqueeze(1) th_jtr = th_jtr + th_trans.unsqueeze(1)
th_verts = th_verts + th_trans.unsqueeze(1) th_verts = th_verts + th_trans.unsqueeze(1)
# Scale to milimeters # Vertices and joints in meters
# th_verts = th_verts * 1000
# th_jtr = th_jtr * 1000
return th_verts, th_jtr return th_verts, th_jtr

View File

@ -4,18 +4,19 @@ from smpl.pytorch import rodrigues_layer
def th_posemap_axisang(pose_vectors): def th_posemap_axisang(pose_vectors):
'''
Converts axis-angle to rotmat
pose_vectors (Tensor (batch_size x 72)): pose parameters in axis-angle representation
'''
rot_nb = int(pose_vectors.shape[1] / 3) rot_nb = int(pose_vectors.shape[1] / 3)
rot_mats = [] rot_mats = []
for joint_idx in range(rot_nb - 1): for joint_idx in range(rot_nb):
joint_idx_val = joint_idx + 1 axis_ang = pose_vectors[:, joint_idx * 3:(joint_idx + 1) * 3]
axis_ang = pose_vectors[:, joint_idx_val * 3:(joint_idx_val + 1) * 3]
rot_mat = rodrigues_layer.batch_rodrigues(axis_ang) rot_mat = rodrigues_layer.batch_rodrigues(axis_ang)
rot_mats.append(rot_mat) rot_mats.append(rot_mat)
# rot_mats = torch.stack(rot_mats, 1).view(-1, 15 *9)
rot_mats = torch.cat(rot_mats, 1) rot_mats = torch.cat(rot_mats, 1)
pose_maps = subtract_flat_id(rot_mats) return rot_mats
return pose_maps, rot_mats
def th_with_zeros(tensor): def th_with_zeros(tensor):