import os

import numpy as np
import torch
from torch.nn import Module

from smpl.native.webuser.serialization import ready_arguments
from smpl.pytorch import rodrigues_layer
from smpl.pytorch.tensutils import (th_posemap_axisang, th_with_zeros, th_pack, make_list)


class SMPL_Layer(Module):
    __constants__ = ['kintree_parents', 'gender', 'center_idx', 'num_joints']

    def __init__(self,
                 center_idx=None,
                 gender='neutral',
                 model_root='smpl/native/models'):
        """
        Args:
            center_idx: index of the joint used to center joints and vertices
            gender: 'neutral' (default), 'female' or 'male'
            model_root: path to the pkl files for the model
        """
        super().__init__()

        self.center_idx = center_idx
        self.gender = gender

        # Pick the model file matching the requested gender
        if gender == 'neutral':
            self.model_path = os.path.join(model_root, 'basicModel_neutral_lbs_10_207_0_v1.0.0.pkl')
        elif gender == 'female':
            self.model_path = os.path.join(model_root, 'basicModel_f_lbs_10_207_0_v1.0.0.pkl')
        elif gender == 'male':
            self.model_path = os.path.join(model_root, 'basicModel_m_lbs_10_207_0_v1.0.0.pkl')
        else:
            raise ValueError("gender must be 'neutral', 'female' or 'male', got '{}'".format(gender))

        smpl_data = ready_arguments(self.model_path)
        self.smpl_data = smpl_data

        # Register the SMPL model tensors as buffers so they follow .to()/.cuda()
        self.register_buffer('th_betas',
                             torch.Tensor(smpl_data['betas'].r).unsqueeze(0))
        self.register_buffer('th_shapedirs',
                             torch.Tensor(smpl_data['shapedirs'].r))
        self.register_buffer('th_posedirs',
                             torch.Tensor(smpl_data['posedirs'].r))
        self.register_buffer(
            'th_v_template',
            torch.Tensor(smpl_data['v_template'].r).unsqueeze(0))
        self.register_buffer(
            'th_J_regressor',
            torch.Tensor(np.array(smpl_data['J_regressor'].toarray())))
        self.register_buffer('th_weights',
                             torch.Tensor(smpl_data['weights'].r))
        self.register_buffer('th_faces',
                             torch.Tensor(smpl_data['f'].astype(np.int32)).long())

        # Kinematic chain params
        self.kintree_table = smpl_data['kintree_table']
        parents = list(self.kintree_table[0].tolist())
        self.kintree_parents = parents
        self.num_joints = len(parents)  # 24

    def forward(self,
                th_pose_coeffs,
                th_betas=torch.zeros(1),
                th_trans=torch.zeros(1)):
        """
        Args:
            th_pose_coeffs (Tensor (batch_size x 72)): pose parameters as per-joint
                axis-angle rotations, root rotation included
            th_betas (Tensor (batch_size x 10)): if provided, uses given shape parameters
            th_trans (Tensor (batch_size x 3)): if provided, applies trans to joints and vertices
        """
        batch_size = th_pose_coeffs.shape[0]
        # Convert axis-angle pose coefficients to per-joint rotation matrices
        th_pose_map, th_rot_map = th_posemap_axisang(th_pose_coeffs)
        th_pose_coeffs = th_pose_coeffs.view(batch_size, -1, 3)
        root_rot = rodrigues_layer.batch_rodrigues(
            th_pose_coeffs[:, 0]).view(batch_size, 3, 3)

        # Shape blend shapes: fall back to the mean shape when betas are absent or all zero
        if th_betas is None or bool(torch.norm(th_betas) == 0):
            th_v_shaped = torch.matmul(self.th_shapedirs,
                                       self.th_betas.transpose(1, 0)).permute(
                                           2, 0, 1) + self.th_v_template
            th_j = torch.matmul(self.th_J_regressor, th_v_shaped).repeat(
                batch_size, 1, 1)
        else:
            th_v_shaped = torch.matmul(self.th_shapedirs,
                                       th_betas.transpose(1, 0)).permute(
                                           2, 0, 1) + self.th_v_template
            th_j = torch.matmul(self.th_J_regressor, th_v_shaped)

        # Pose blend shapes: add pose-dependent corrective offsets to the shaped vertices
        th_v_posed = th_v_shaped + torch.matmul(
            self.th_posedirs, th_pose_map.transpose(0, 1)).permute(2, 0, 1)
        # Final T pose with transformation done!

        # Global rigid transformation
        th_results = []

        root_j = th_j[:, 0, :].contiguous().view(batch_size, 3, 1)
        th_results.append(th_with_zeros(torch.cat([root_rot, root_j], 2)))

        # Rotate each part: accumulate relative transforms along the kinematic chain
        for i in range(self.num_joints - 1):
            i_val = int(i + 1)
            joint_rot = th_rot_map[:, (i_val - 1) * 9:i_val *
                                   9].contiguous().view(batch_size, 3, 3)
            joint_j = th_j[:, i_val, :].contiguous().view(batch_size, 3, 1)
            parent = make_list(self.kintree_parents)[i_val]
            parent_j = th_j[:, parent, :].contiguous().view(batch_size, 3, 1)
            joint_rel_transform = th_with_zeros(
                torch.cat([joint_rot, joint_j - parent_j], 2))
            th_results.append(
                torch.matmul(th_results[parent], joint_rel_transform))
        th_results_global = th_results

        # Subtract the rest-pose joint locations from the global transforms
        th_results2 = torch.zeros((batch_size, 4, 4, self.num_joints),
                                  dtype=root_j.dtype,
                                  device=root_j.device)

        for i in range(self.num_joints):
            padd_zero = torch.zeros(1, dtype=th_j.dtype, device=th_j.device)
            joint_j = torch.cat(
                [th_j[:, i],
                 padd_zero.view(1, 1).repeat(batch_size, 1)], 1)
            tmp = torch.bmm(th_results[i], joint_j.unsqueeze(2))
            th_results2[:, :, :, i] = th_results[i] - th_pack(tmp)

        # Linear blend skinning: mix the joint transforms with the skinning weights
        th_T = torch.matmul(th_results2, self.th_weights.transpose(0, 1))

        th_rest_shape_h = torch.cat([
            th_v_posed.transpose(2, 1),
            torch.ones((batch_size, 1, th_v_posed.shape[1]),
                       dtype=th_T.dtype,
                       device=th_T.device),
        ], 1)

        th_verts = (th_T * th_rest_shape_h.unsqueeze(1)).sum(2).transpose(2, 1)
        th_verts = th_verts[:, :, :3]
        th_jtr = torch.stack(th_results_global, dim=1)[:, :, :3, 3]

        # Apply the global translation, or center on center_idx when no translation is given
        if th_trans is None or bool(torch.norm(th_trans) == 0):
            if self.center_idx is not None:
                center_joint = th_jtr[:, self.center_idx].unsqueeze(1)
                th_jtr = th_jtr - center_joint
                th_verts = th_verts - center_joint
        else:
            th_jtr = th_jtr + th_trans.unsqueeze(1)
            th_verts = th_verts + th_trans.unsqueeze(1)

        # Scale to millimeters
        # th_verts = th_verts * 1000
        # th_jtr = th_jtr * 1000
        return th_verts, th_jtr
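

# Minimal usage sketch, assuming the SMPL pickle files are available under the
# default model_root ('smpl/native/models'). Pose is 72-dimensional
# (24 joints x 3 axis-angle values, root rotation included) and shape is
# 10-dimensional, matching the buffers registered above.
if __name__ == '__main__':
    batch_size = 2
    smpl_layer = SMPL_Layer(center_idx=0, gender='neutral')

    th_pose_coeffs = torch.zeros(batch_size, 72)  # rest pose
    th_betas = torch.zeros(batch_size, 10)        # mean shape
    th_trans = torch.zeros(batch_size, 3)         # no global translation

    th_verts, th_jtr = smpl_layer(th_pose_coeffs, th_betas, th_trans)
    print(th_verts.shape)  # (batch_size, num_vertices, 3)
    print(th_jtr.shape)    # (batch_size, 24, 3)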