Simplify code.
This commit is contained in:
@ -6,7 +6,7 @@ from torch.nn import Module
|
|||||||
|
|
||||||
from smpl.native.webuser.serialization import ready_arguments
|
from smpl.native.webuser.serialization import ready_arguments
|
||||||
from smpl.pytorch import rodrigues_layer
|
from smpl.pytorch import rodrigues_layer
|
||||||
from smpl.pytorch.tensutils import (th_posemap_axisang, th_with_zeros, th_pack, make_list)
|
from smpl.pytorch.tensutils import (th_posemap_axisang, th_with_zeros, th_pack, make_list, subtract_flat_id)
|
||||||
|
|
||||||
|
|
||||||
class SMPL_Layer(Module):
|
class SMPL_Layer(Module):
|
||||||
@ -61,36 +61,41 @@ class SMPL_Layer(Module):
|
|||||||
self.num_joints = len(parents) # 24
|
self.num_joints = len(parents) # 24
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
th_pose_coeffs,
|
th_pose_axisang,
|
||||||
th_betas=torch.zeros(1),
|
th_betas=torch.zeros(1),
|
||||||
th_trans=torch.zeros(1)):
|
th_trans=torch.zeros(1)):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
th_pose_axisang (Tensor (batch_size x 72)): pose parameters in axis-angle representation
|
||||||
th_betas (Tensor (batch_size x 10)): if provided, uses given shape parameters
|
th_betas (Tensor (batch_size x 10)): if provided, uses given shape parameters
|
||||||
th_trans (Tensor (batch_size x 3)): if provided, applies trans to joints and vertices
|
th_trans (Tensor (batch_size x 3)): if provided, applies trans to joints and vertices
|
||||||
"""
|
"""
|
||||||
|
|
||||||
batch_size = th_pose_coeffs.shape[0]
|
batch_size = th_pose_axisang.shape[0]
|
||||||
th_pose_map, th_rot_map = th_posemap_axisang(th_pose_coeffs)
|
# Convert axis-angle representation to rotation matrix rep.
|
||||||
th_pose_coeffs = th_pose_coeffs.view(batch_size, -1, 3)
|
th_pose_rotmat = th_posemap_axisang(th_pose_axisang)
|
||||||
root_rot = rodrigues_layer.batch_rodrigues(
|
# Take out the first rotmat (global rotation)
|
||||||
th_pose_coeffs[:, 0]).view(batch_size, 3, 3)
|
root_rot = th_pose_rotmat[:, :9].view(batch_size, 3, 3)
|
||||||
|
# Take out the remaining rotmats (23 joints)
|
||||||
|
th_pose_rotmat = th_pose_rotmat[:, 9:]
|
||||||
|
th_pose_map = subtract_flat_id(th_pose_rotmat)
|
||||||
|
|
||||||
|
# Below does: v_shaped = v_template + shapedirs * betas
|
||||||
|
# If shape parameters are not provided
|
||||||
if th_betas is None or bool(torch.norm(th_betas) == 0):
|
if th_betas is None or bool(torch.norm(th_betas) == 0):
|
||||||
th_v_shaped = torch.matmul(self.th_shapedirs,
|
th_v_shaped = self.th_v_template + torch.matmul(
|
||||||
self.th_betas.transpose(1, 0)).permute(
|
self.th_shapedirs, self.th_betas.transpose(1, 0)).permute(2, 0, 1)
|
||||||
2, 0, 1) + self.th_v_template
|
|
||||||
th_j = torch.matmul(self.th_J_regressor, th_v_shaped).repeat(
|
th_j = torch.matmul(self.th_J_regressor, th_v_shaped).repeat(
|
||||||
batch_size, 1, 1)
|
batch_size, 1, 1)
|
||||||
else:
|
else:
|
||||||
th_v_shaped = torch.matmul(self.th_shapedirs,
|
th_v_shaped = self.th_v_template + torch.matmul(
|
||||||
th_betas.transpose(1, 0)).permute(
|
self.th_shapedirs, th_betas.transpose(1, 0)).permute(2, 0, 1)
|
||||||
2, 0, 1) + self.th_v_template
|
|
||||||
th_j = torch.matmul(self.th_J_regressor, th_v_shaped)
|
th_j = torch.matmul(self.th_J_regressor, th_v_shaped)
|
||||||
|
|
||||||
|
# Below does: v_posed = v_shaped + posedirs * pose_map
|
||||||
th_v_posed = th_v_shaped + torch.matmul(
|
th_v_posed = th_v_shaped + torch.matmul(
|
||||||
self.th_posedirs, th_pose_map.transpose(0, 1)).permute(2, 0, 1)
|
self.th_posedirs, th_pose_map.transpose(0, 1)).permute(2, 0, 1)
|
||||||
# Final T pose with transformation done !
|
# Final T pose with transformation done!
|
||||||
|
|
||||||
# Global rigid transformation
|
# Global rigid transformation
|
||||||
th_results = []
|
th_results = []
|
||||||
@ -101,7 +106,7 @@ class SMPL_Layer(Module):
|
|||||||
# Rotate each part
|
# Rotate each part
|
||||||
for i in range(self.num_joints - 1):
|
for i in range(self.num_joints - 1):
|
||||||
i_val = int(i + 1)
|
i_val = int(i + 1)
|
||||||
joint_rot = th_rot_map[:, (i_val - 1) * 9:i_val *
|
joint_rot = th_pose_rotmat[:, (i_val - 1) * 9:i_val *
|
||||||
9].contiguous().view(batch_size, 3, 3)
|
9].contiguous().view(batch_size, 3, 3)
|
||||||
joint_j = th_j[:, i_val, :].contiguous().view(batch_size, 3, 1)
|
joint_j = th_j[:, i_val, :].contiguous().view(batch_size, 3, 1)
|
||||||
parent = make_list(self.kintree_parents)[i_val]
|
parent = make_list(self.kintree_parents)[i_val]
|
||||||
@ -137,6 +142,7 @@ class SMPL_Layer(Module):
|
|||||||
th_verts = th_verts[:, :, :3]
|
th_verts = th_verts[:, :, :3]
|
||||||
th_jtr = torch.stack(th_results_global, dim=1)[:, :, :3, 3]
|
th_jtr = torch.stack(th_results_global, dim=1)[:, :, :3, 3]
|
||||||
|
|
||||||
|
# If translation is not provided
|
||||||
if th_trans is None or bool(torch.norm(th_trans) == 0):
|
if th_trans is None or bool(torch.norm(th_trans) == 0):
|
||||||
if self.center_idx is not None:
|
if self.center_idx is not None:
|
||||||
center_joint = th_jtr[:, self.center_idx].unsqueeze(1)
|
center_joint = th_jtr[:, self.center_idx].unsqueeze(1)
|
||||||
@ -146,7 +152,5 @@ class SMPL_Layer(Module):
|
|||||||
th_jtr = th_jtr + th_trans.unsqueeze(1)
|
th_jtr = th_jtr + th_trans.unsqueeze(1)
|
||||||
th_verts = th_verts + th_trans.unsqueeze(1)
|
th_verts = th_verts + th_trans.unsqueeze(1)
|
||||||
|
|
||||||
# Scale to milimeters
|
# Vertices and joints in meters
|
||||||
# th_verts = th_verts * 1000
|
|
||||||
# th_jtr = th_jtr * 1000
|
|
||||||
return th_verts, th_jtr
|
return th_verts, th_jtr
|
||||||
|
|||||||
@ -4,18 +4,19 @@ from smpl.pytorch import rodrigues_layer
|
|||||||
|
|
||||||
|
|
||||||
def th_posemap_axisang(pose_vectors):
|
def th_posemap_axisang(pose_vectors):
|
||||||
|
'''
|
||||||
|
Converts axis-angle to rotmat
|
||||||
|
pose_vectors (Tensor (batch_size x 72)): pose parameters in axis-angle representation
|
||||||
|
'''
|
||||||
rot_nb = int(pose_vectors.shape[1] / 3)
|
rot_nb = int(pose_vectors.shape[1] / 3)
|
||||||
rot_mats = []
|
rot_mats = []
|
||||||
for joint_idx in range(rot_nb - 1):
|
for joint_idx in range(rot_nb):
|
||||||
joint_idx_val = joint_idx + 1
|
axis_ang = pose_vectors[:, joint_idx * 3:(joint_idx + 1) * 3]
|
||||||
axis_ang = pose_vectors[:, joint_idx_val * 3:(joint_idx_val + 1) * 3]
|
|
||||||
rot_mat = rodrigues_layer.batch_rodrigues(axis_ang)
|
rot_mat = rodrigues_layer.batch_rodrigues(axis_ang)
|
||||||
rot_mats.append(rot_mat)
|
rot_mats.append(rot_mat)
|
||||||
|
|
||||||
# rot_mats = torch.stack(rot_mats, 1).view(-1, 15 *9)
|
|
||||||
rot_mats = torch.cat(rot_mats, 1)
|
rot_mats = torch.cat(rot_mats, 1)
|
||||||
pose_maps = subtract_flat_id(rot_mats)
|
return rot_mats
|
||||||
return pose_maps, rot_mats
|
|
||||||
|
|
||||||
|
|
||||||
def th_with_zeros(tensor):
|
def th_with_zeros(tensor):
|
||||||
|
|||||||
Reference in New Issue
Block a user