57 lines
1.8 KiB
Python
57 lines
1.8 KiB
Python
import torch
|
||
import random
|
||
import numpy as np
|
||
import pickle # 新增:用于加载pkl文件
|
||
|
||
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
||
from display_utils import display_model
|
||
|
||
|
||
# Defaults for the demo run; override them through main(...) keyword arguments.
DEFAULT_MODEL_ROOT = '/home/lmd/Code/Pose_to_SMPL_an_230402/smplpytorch/native/models'
# Replace with the actual pkl file path.
DEFAULT_PKL_PATH = '/home/lmd/Code/Pose_to_SMPL_an_230402/fit/output/UTD_MHAD/sigle_people_smpl_params.pkl'


def _load_smpl_params(pkl_path):
    """Load SMPL pose and shape parameters from a pickle file.

    The pickle must hold a dict with 'pose_params' and 'shape_params'
    entries (anything ``torch.as_tensor`` accepts, e.g. numpy arrays or lists).

    Args:
        pkl_path: Path of the pickle file to read.

    Returns:
        (pose_params, shape_params) as float32 tensors that both carry a
        leading batch dimension of the same size.

    Raises:
        ValueError: if the pose and shape batch sizes cannot be reconciled.
    """
    # NOTE: pickle.load can execute arbitrary code -- only load files you
    # trust (e.g. ones written by your own fitting pipeline).
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f)

    # as_tensor accepts lists, numpy arrays and tensors alike (torch.tensor
    # warns on tensors); dtype=float32 matches the former .float() conversion.
    pose_params = torch.as_tensor(data['pose_params'], dtype=torch.float32)
    shape_params = torch.as_tensor(data['shape_params'], dtype=torch.float32)

    # Single-sample parameters get a batch dimension of 1.
    if pose_params.dim() == 1:
        pose_params = pose_params.unsqueeze(0)
    if shape_params.dim() == 1:
        shape_params = shape_params.unsqueeze(0)

    n_pose, n_shape = pose_params.shape[0], shape_params.shape[0]
    if n_shape == 1 and n_pose > 1:
        # One body shape shared by every pose: repeat it so each pose row has
        # its own betas row.
        shape_params = shape_params.repeat(n_pose, 1)
    elif n_shape != n_pose:
        raise ValueError(
            f"pose_params batch ({n_pose}) does not match "
            f"shape_params batch ({n_shape}) in {pkl_path}")
    return pose_params, shape_params


def main(pkl_path=DEFAULT_PKL_PATH, model_root=DEFAULT_MODEL_ROOT,
         gender='male', savepath='image2.png', cuda=None):
    """Run SMPL parameters loaded from a pickle through SMPL_Layer and display the mesh.

    Args:
        pkl_path: Pickle holding 'pose_params' and 'shape_params'.
        model_root: Directory passed to SMPL_Layer as ``model_root``.
        gender: SMPL model gender; make sure it matches the gender of the
            parameters stored in ``pkl_path``.
        savepath: File path handed to display_model for saving the figure.
        cuda: Run on the GPU. None (default) uses the GPU only when
            ``torch.cuda.is_available()`` is True.
    """
    if cuda is None:
        # Previously hard-coded to True, which fails on machines without CUDA.
        cuda = torch.cuda.is_available()
    device = torch.device('cuda' if cuda else 'cpu')
    batch_size = 1

    # Create the SMPL layer.
    smpl_layer = SMPL_Layer(
        center_idx=0,
        gender=gender,
        model_root=model_root)

    # Load pose and shape parameters (with batch dimension) from the pkl file.
    pose_params, shape_params = _load_smpl_params(pkl_path)

    # Verify the batch size against the loaded data.
    if pose_params.shape[0] != batch_size:
        batch_size = pose_params.shape[0]
        print(f"Warning: Batch size adjusted to {batch_size} based on loaded data.")

    # Move the layer and its inputs to the selected device.
    pose_params = pose_params.to(device)
    shape_params = shape_params.to(device)
    smpl_layer.to(device)

    # Forward through the SMPL layer; inference only, so skip autograd tracking.
    with torch.no_grad():
        verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)

    # Draw output vertices and joints.
    display_model(
        {'verts': verts.cpu().detach(),
         'joints': Jtr.cpu().detach()},
        # verts/joints are passed on the CPU, so pass CPU faces too; after
        # smpl_layer.to('cuda') the th_faces buffer would otherwise be a CUDA
        # tensor indexing CPU data.
        model_faces=smpl_layer.th_faces.cpu(),
        with_joints=True,
        kintree_table=smpl_layer.kintree_table,
        savepath=savepath,
        show=True)


if __name__ == '__main__':
    main()