Files
2025-07-25 15:05:31 +08:00

57 lines
1.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import random
import numpy as np
import pickle # 新增用于加载pkl文件
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from display_utils import display_model
def to_batched_tensor(values):
    """Convert loaded SMPL parameters into a float32 tensor with a leading batch axis.

    Args:
        values: Anything ``torch.as_tensor`` accepts (list, numpy array, tensor).

    Returns:
        A float32 tensor of shape ``(batch, features)``. Inputs with at most one
        dimension get a batch axis of size 1. Inputs with more than two
        dimensions are flattened per sample, e.g. a ``(N, 24, 3)`` per-joint
        pose array becomes ``(N, 72)``.
    """
    # as_tensor (unlike torch.tensor) does not warn when the pickle already
    # stores tensors, and it converts numpy/list data the same way.
    tensor = torch.as_tensor(values, dtype=torch.float32)
    if tensor.dim() <= 1:
        tensor = tensor.reshape(1, -1)  # add the batch dimension
    elif tensor.dim() > 2:
        tensor = tensor.reshape(tensor.shape[0], -1)
    return tensor


def prepare_smpl_params(data):
    """Extract batched, batch-aligned pose and shape tensors from a loaded dict.

    Args:
        data: Mapping with ``'pose_params'`` and ``'shape_params'`` entries.

    Returns:
        ``(pose_params, shape_params)`` as float32 tensors with equal batch
        sizes. A batch of 1 on either side is repeated to match the other, so
        one shape vector can drive a whole sequence of poses.

    Raises:
        KeyError: If either key is missing from ``data``.
        ValueError: If the batch sizes differ and neither of them is 1.
    """
    pose_params = to_batched_tensor(data['pose_params'])
    shape_params = to_batched_tensor(data['shape_params'])

    n_pose, n_shape = pose_params.shape[0], shape_params.shape[0]
    if n_pose != n_shape:
        # smplpytorch's SMPL_Layer appears to reshape per-sample joints using
        # the pose batch size, so betas must share that batch size. Broadcast
        # a single shared vector explicitly instead of failing mid-forward.
        if n_shape == 1:
            shape_params = shape_params.expand(n_pose, -1).contiguous()
        elif n_pose == 1:
            pose_params = pose_params.expand(n_shape, -1).contiguous()
        else:
            raise ValueError(
                f"pose_params batch ({n_pose}) and shape_params batch "
                f"({n_shape}) differ and neither is 1; cannot align them.")
    return pose_params, shape_params


if __name__ == '__main__':
    # Fall back to the CPU instead of crashing on machines without a GPU.
    cuda = torch.cuda.is_available()
    batch_size = 1
    model_root = '/home/lmd/Code/Pose_to_SMPL_an_230402/smplpytorch/native/models'
    # Replace with the actual path of the pkl file to visualise.
    pkl_path = '/home/lmd/Code/Pose_to_SMPL_an_230402/fit/output/UTD_MHAD/sigle_people_smpl_params.pkl'

    # Create the SMPL layer. The gender must match the one the parameters in
    # the pkl file were fitted with.
    smpl_layer = SMPL_Layer(
        center_idx=0,
        gender='male',
        model_root=model_root)

    # Load pose and shape parameters from the pkl file.
    # NOTE: pickle.load can execute arbitrary code; only load files produced
    # by our own fitting pipeline, never untrusted downloads.
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f)
    pose_params, shape_params = prepare_smpl_params(data)

    # The loaded data, not the default above, decides the batch size.
    if pose_params.shape[0] != batch_size:
        batch_size = pose_params.shape[0]
        print(f"Warning: Batch size adjusted to {batch_size} based on loaded data.")

    # GPU mode
    if cuda:
        pose_params = pose_params.cuda()
        shape_params = shape_params.cuda()
        smpl_layer.cuda()

    # Forward from the SMPL layer. This is inference only, so autograd
    # bookkeeping is skipped.
    with torch.no_grad():
        verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)

    # Draw output vertices and joints. Faces are moved to the CPU as well, so
    # they sit on the same device as the CPU verts/joints passed alongside them
    # (display_model presumably indexes verts by faces; confirm).
    display_model(
        {'verts': verts.cpu().detach(),
         'joints': Jtr.cpu().detach()},
        model_faces=smpl_layer.th_faces.cpu(),
        with_joints=True,
        kintree_table=smpl_layer.kintree_table,
        savepath='image2.png',
        show=True)