Compare commits

...

10 Commits

Author SHA1 Message Date
lmd
65769e5eb6 name 2025-07-25 15:05:31 +08:00
24e1e31234 Merge branch 'main' of github.com:Iridoudou/Pose_to_SMPL 2022-10-01 21:42:18 +08:00
6ac10e22e0 Support HAA4D 2022-10-01 21:42:12 +08:00
03debbe123 Create README.md 2022-09-28 20:57:19 +08:00
fd2ff6616f Support NTU 2021-09-03 15:31:39 +08:00
52603950d6 support NTU 2021-09-03 11:07:42 +08:00
16387a7afe support NTU 2021-09-03 11:06:26 +08:00
0702081dbe update code 2021-08-20 22:48:02 +08:00
bd93f8dcaf down sample 2021-08-19 15:08:28 +08:00
8573a09b84 update meters 2021-08-19 12:40:54 +08:00
25 changed files with 4022 additions and 150 deletions

3
.gitignore vendored
View File

@ -14,3 +14,6 @@ smplpytorch/native/models/*.pkl
exp/ exp/
output/ output/
make_gif.py make_gif.py
*.pkl
*.whl
*.png

1
.python-version Normal file
View File

@ -0,0 +1 @@
3.12

15
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,15 @@
{
// 使用 IntelliSense 了解相关属性。
// 悬停以查看现有属性的描述。
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python 调试程序: 当前文件",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal"
}
]
}

View File

@ -1,4 +1,4 @@
pose2smpl Pose_to_SMPL
======= =======
### Fitting SMPL Parameters by 3D-pose Key-points ### Fitting SMPL Parameters by 3D-pose Key-points
@ -40,11 +40,14 @@ The SMPL human body layer for Pytorch is from the [smplpytorch](https://github.c
- Download the datasets you want to fit - Download the datasets you want to fit
currently supported datasets: currently support:
- [HumanAct12](https://ericguo5513.github.io/action-to-motion/) - [HumanAct12](https://ericguo5513.github.io/action-to-motion/)
- [CMU Mocap](https://ericguo5513.github.io/action-to-motion/) - [CMU Mocap](https://ericguo5513.github.io/action-to-motion/)
- [UTD-MHAD](https://personal.utdallas.edu/~kehtar/UTD-MHAD.html) - [UTD-MHAD](https://personal.utdallas.edu/~kehtar/UTD-MHAD.html)
- [Human3.6M](http://vision.imar.ro/human3.6m/description.php)
- [NTU](https://rose1.ntu.edu.sg/dataset/actionRecognition/)
- [HAA4D](https://cse.hkust.edu.hk/haa4d/dataset.html)
- Set the **DATASET.PATH** in the corresponding configuration file to the location of dataset. - Set the **DATASET.PATH** in the corresponding configuration file to the location of dataset.

View File

@ -0,0 +1,122 @@
[
[
0.3718571662902832,
0.6803165078163147,
-0.7502949833869934
],
[
0.45830726623535156,
0.6820494532585144,
-0.6965712904930115
],
[
0.28540709614753723,
0.6785836219787598,
-0.8040187358856201
],
[
0.3710346221923828,
0.8273340463638306,
-0.6851714849472046
],
[
0.4792303442955017,
0.2899134159088135,
-0.6879411935806274
],
[
0.2960263788700104,
0.29219183325767517,
-0.8322873711585999
],
[
0.3702120780944824,
0.9743515253067017,
-0.6200480461120605
],
[
0.5172274112701416,
-0.09073961526155472,
-0.7129285335540771
],
[
0.3149631917476654,
-0.08920176327228546,
-0.9007936120033264
],
[
0.3702120780944824,
1.0243514776229858,
-0.6200480461120605
],
[
0.5201388001441956,
-0.13739190995693207,
-0.6414260864257812
],
[
0.2730123996734619,
-0.13770559430122375,
-0.8578707575798035
],
[
0.3782634735107422,
1.2141607999801636,
-0.7061752080917358
],
[
0.4474300742149353,
1.1862998008728027,
-0.6677815914154053
],
[
0.3090968728065491,
1.1820218563079834,
-0.7445688247680664
],
[
0.36856698989868164,
1.2683864831924438,
-0.4898010492324829
],
[
0.5165966749191284,
1.1984386444091797,
-0.6293879151344299
],
[
0.23993025720119476,
1.1898829936981201,
-0.7829625010490417
],
[
0.478150337934494,
0.971466600894928,
-0.5572866201400757
],
[
0.20990325510501862,
0.936416506767273,
-0.8448361158370972
],
[
0.3065527677536011,
1.1182676553726196,
-0.5591280460357666
],
[
0.17920757830142975,
0.6927539706230164,
-0.7849017381668091
],
[
0.25156882405281067,
1.1860939264297485,
-0.5887487530708313
],
[
0.1682051420211792,
0.6014264822006226,
-0.7367973327636719
]
]

View File

@ -0,0 +1,122 @@
[
[
0.3723624646663666,
0.6807540655136108,
-0.7536930441856384
],
[
0.4606369733810425,
0.6819767355918884,
-0.7026320695877075
],
[
0.2840879261493683,
0.6795313954353333,
-0.8047540187835693
],
[
0.37141507863998413,
0.8276264667510986,
-0.6875424385070801
],
[
0.479561448097229,
0.2904614806175232,
-0.6911981701850891
],
[
0.29553577303886414,
0.2926357686519623,
-0.832754373550415
],
[
0.3704676926136017,
0.9744989275932312,
-0.621391773223877
],
[
0.5182272791862488,
-0.09084545075893402,
-0.7154797911643982
],
[
0.31344518065452576,
-0.08931314945220947,
-0.9003074169158936
],
[
0.3704676926136017,
1.0244989395141602,
-0.621391773223877
],
[
0.5207709074020386,
-0.1377965807914734,
-0.6441547870635986
],
[
0.2722875475883484,
-0.13784976303577423,
-0.8576371669769287
],
[
0.3782576322555542,
1.214632511138916,
-0.7064842581748962
],
[
0.4484415650367737,
1.1865851879119873,
-0.6695226430892944
],
[
0.3080736994743347,
1.1826797723770142,
-0.743445873260498
],
[
0.3685729205608368,
1.2682437896728516,
-0.48909056186676025
],
[
0.5186254978179932,
1.1985379457473755,
-0.6325609683990479
],
[
0.23788979649543762,
1.1907269954681396,
-0.7804075479507446
],
[
0.4849953353404999,
0.969235897064209,
-0.5583885312080383
],
[
0.20581978559494019,
0.9366675019264221,
-0.8437999486923218
],
[
0.3075758218765259,
1.1038239002227783,
-0.5611312389373779
],
[
0.17861124873161316,
0.6925873756408691,
-0.7852163314819336
],
[
0.24666009843349457,
1.1842905282974243,
-0.5826389193534851
],
[
0.16754546761512756,
0.6011085510253906,
-0.7378559708595276
]
]

View File

@ -1,4 +1,6 @@
import torch import torch
import random
import numpy as np
from smplpytorch.pytorch.smpl_layer import SMPL_Layer from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from display_utils import display_model from display_utils import display_model
@ -12,10 +14,10 @@ if __name__ == '__main__':
smpl_layer = SMPL_Layer( smpl_layer = SMPL_Layer(
center_idx=0, center_idx=0,
gender='male', gender='male',
model_root='smplpytorch/native/models') model_root='/home/lmd/Code/Pose_to_SMPL_an_230402/smplpytorch/native/models')
# Generate random pose and shape parameters # Generate random pose and shape parameters
pose_params = torch.rand(batch_size, 72) * 0.2 pose_params = (torch.rand(batch_size, 72)*2-1) * np.pi
shape_params = torch.rand(batch_size, 10) * 0.03 shape_params = torch.rand(batch_size, 10) * 0.03
# GPU mode # GPU mode
@ -26,7 +28,6 @@ if __name__ == '__main__':
# Forward from the SMPL layer # Forward from the SMPL layer
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params) verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
print(Jtr)
# Draw output vertices and joints # Draw output vertices and joints
display_model( display_model(
@ -35,5 +36,5 @@ if __name__ == '__main__':
model_faces=smpl_layer.th_faces, model_faces=smpl_layer.th_faces,
with_joints=True, with_joints=True,
kintree_table=smpl_layer.kintree_table, kintree_table=smpl_layer.kintree_table,
savepath='image.png', savepath='image1.png',
show=True) show=True)

View File

@ -1,12 +1,15 @@
from xml.parsers.expat import model
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection from mpl_toolkits.mplot3d.art3d import Poly3DCollection
# plt.switch_backend('agg') # plt.switch_backend('agg')
from torch import Tensor
from typing import Optional
def display_model( def display_model(
model_info, model_info,
model_faces=None, model_faces:Optional[Tensor]=None,
with_joints=False, with_joints=False,
kintree_table=None, kintree_table=None,
ax=None, ax=None,
@ -21,11 +24,12 @@ def display_model(
if ax is None: if ax is None:
fig = plt.figure() fig = plt.figure()
ax = fig.add_subplot(111, projection='3d') ax = fig.add_subplot(111, projection='3d')
verts, joints = model_info['verts'][batch_idx], model_info['joints'][batch_idx] verts = model_info['verts'][batch_idx]
joints = model_info['joints'][batch_idx]
if model_faces is None: if model_faces is None:
ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2) ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2)
elif not only_joint: elif not only_joint:
mesh = Poly3DCollection(verts[model_faces], alpha=0.2) mesh = Poly3DCollection(verts[model_faces.cpu()], alpha=0.2)
face_color = (141 / 255, 184 / 255, 226 / 255) face_color = (141 / 255, 184 / 255, 226 / 255)
edge_color = (50 / 255, 50 / 255, 50 / 255) edge_color = (50 / 255, 50 / 255, 50 / 255)
mesh.set_edgecolor(edge_color) mesh.set_edgecolor(edge_color)
@ -44,8 +48,8 @@ def display_model(
if savepath: if savepath:
# print('Saving figure at {}.'.format(savepath)) # print('Saving figure at {}.'.format(savepath))
plt.savefig(savepath, bbox_inches='tight', pad_inches=0) plt.savefig(savepath, bbox_inches='tight', pad_inches=0)
if show: # if show:
plt.show() # plt.show()
plt.close() plt.close()
return ax return ax

View File

@ -12,7 +12,7 @@
"USE_GPU": 1, "USE_GPU": 1,
"DATASET": { "DATASET": {
"NAME": "UTD-MHAD", "NAME": "UTD-MHAD",
"PATH": "../Action2Motion/CMU Mocap/mocap/mocap_3djoints/", "PATH": "/home/lmd/Code/Pose_to_SMPL_an_230402/database",
"TARGET_PATH": "", "TARGET_PATH": "",
"DATA_MAP": [ "DATA_MAP": [
[ [

24
fit/configs/HAA4D.json Normal file
View File

@ -0,0 +1,24 @@
{
"MODEL": {
"GENDER": "neutral"
},
"TRAIN": {
"LEARNING_RATE": 1e-2,
"MAX_EPOCH": 1000,
"WRITE": 10,
"OPTIMIZE_SCALE":1,
"OPTIMIZE_SHAPE":1
},
"USE_GPU": 1,
"DATASET": {
"NAME": "NTU",
"PATH": "../NTU RGB+D/result",
"TARGET_PATH": "",
"DATA_MAP": [
[0,0],[1,4],[2,1],[4,5],[5,2],[7,6],[8,3],
[12,9],[18,12],[19,15],[20,13],[21,16],
[15,10],[6,1]
]
},
"DEBUG": 0
}

93
fit/configs/NTU.json Normal file
View File

@ -0,0 +1,93 @@
{
"MODEL": {
"GENDER": "neutral"
},
"TRAIN": {
"LEARNING_RATE": 5e-2,
"MAX_EPOCH": 1000,
"WRITE": 10,
"OPTIMIZE_SCALE":0,
"OPTIMIZE_SHAPE":0
},
"USE_GPU": 1,
"DATASET": {
"NAME": "NTU",
"PATH": "../NTU RGB+D/result",
"TARGET_PATH": "",
"DATA_MAP": [
[
0,
0
],
[
2,
12
],
[
1,
16
],
[
5,
13
],
[
4,
17
],
[
6,
1
],
[
8,
14
],
[
7,
18
],
[
9,
20
],
[
12,
2
],
[
14,
4
],
[
13,
8
],
[
19,
5
],
[
18,
9
],
[
21,
6
],
[
20,
10
],
[
23,
22
],
[
22,
24
]
]
},
"DEBUG": 0
}

View File

@ -12,72 +12,100 @@
"USE_GPU": 1, "USE_GPU": 1,
"DATASET": { "DATASET": {
"NAME": "UTD-MHAD", "NAME": "UTD-MHAD",
"PATH": "../UTD-MHAD/Skeleton/Skeleton/", "PATH": "/home/lmd/Code/Pose_to_SMPL_an_230402/database",
"TARGET_PATH": "", "TARGET_PATH": "",
"DATA_MAP": [ "DATA_MAP": [
[ [
12, 1,
1 1
],
[
2,
2
], ],
[ [
0, 3,
3 3
], ],
[ [
16, 4,
4 4
], ],
[ [
18, 5,
5 5
], ],
[ [
20, 6,
6 6
], ],
[ [
22, 7,
7 7
], ],
[ [
17, 8,
8 8
], ],
[ [
19, 9,
9 9
], ],
[ [
21, 10,
10 10
], ],
[ [
23, 11,
11 11
], ],
[ [
1, 12,
12 12
], ],
[ [
4, 13,
13 13
], ],
[ [
7, 14,
14 14
],
[
15,
15
], ],
[ [
2, 16,
16 16
], ],
[ [
5, 17,
17 17
], ],
[ [
8, 18,
18 18
],
[
19,
19
],
[
20,
20
],
[
21,
21
],
[
22,
22
],
[
23,
23
] ]
] ]
}, },

View File

@ -4,6 +4,7 @@ import os
import json import json
import argparse import argparse
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Detect cross joints') parser = argparse.ArgumentParser(description='Detect cross joints')
parser.add_argument('--dataset_name', dest='dataset_name', parser.add_argument('--dataset_name', dest='dataset_name',
@ -15,10 +16,12 @@ def parse_args():
args = parser.parse_args() args = parser.parse_args()
return args return args
def create_dir_not_exist(path): def create_dir_not_exist(path):
if not os.path.exists(path): if not os.path.exists(path):
os.mkdir(path) os.mkdir(path)
def load_Jtr(file_path): def load_Jtr(file_path):
with open(file_path, 'rb') as f: with open(file_path, 'rb') as f:
data = pickle.load(f) data = pickle.load(f)
@ -40,15 +43,18 @@ def cross_frames(Jtr: np.ndarray):
def cross_detector(dir_path): def cross_detector(dir_path):
ans={} ans = {}
for root, dirs, files in os.walk(dir_path): for root, dirs, files in os.walk(dir_path):
for file in files: for file in files:
file_path = os.path.join(dir_path, file) file_path = os.path.join(dir_path, file)
Jtr = load_Jtr(file_path) Jtr = load_Jtr(file_path)
ans[file]=cross_frames(Jtr) ans[file] = cross_frames(Jtr)
return ans return ans
if __name__ == "__main__": if __name__ == "__main__":
args=parse_args() args = parse_args()
d=cross_detector(args.output_path) d = cross_detector(args.output_path)
json.dump(d,open("./fit/output/cross_detection/{}.json".format(args.dataset_name),'w')) json.dump(
d, open("./fit/output/cross_detection/{}.json"
.format(args.dataset_name), 'w'))

View File

@ -1156,17 +1156,140 @@ CMU_Mocap = {
"41_09": "Climb" "41_09": "Climb"
} }
NTU = {
"1":"drink water",
"2":"eat meal/snack",
"3":"brushing teeth",
"4":"brushing hair",
"5":"drop",
"6":"pickup",
"7":"throw",
"8":"sitting down",
"9":"standing up (from sitting position)",
"10":"clapping",
"11":"reading",
"12":"writing",
"13":"tear up paper",
"14":"wear jacket",
"15":"take off jacket",
"16":"wear a shoe",
"17":"take off a shoe",
"18":"wear on glasses",
"19":"take off glasses",
"20":"put on a hat/cap",
"21":"take off a hat/cap",
"22":"cheer up",
"23":"hand waving",
"24":"kicking something",
"25":"reach into pocket",
"26":"hopping (one foot jumping)",
"27":"jump up",
"28":"make a phone call/answer phone",
"29":"playing with phone/tablet",
"30":"typing on a keyboard",
"31":"pointing to something with finger",
"32":"taking a selfie",
"33":"check time (from watch)",
"34":"rub two hands together",
"35":"nod head/bow",
"36":"shake head",
"37":"wipe face",
"38":"salute",
"39":"put the palms together",
"40":"cross hands in front (say stop)",
"41":"sneeze/cough",
"42":"staggering",
"43":"falling",
"44":"touch head (headache)",
"45":"touch chest (stomachache/heart pain)",
"46":"touch back (backache)",
"47":"touch neck (neckache)",
"48":"nausea or vomiting condition",
"49":"use a fan (with hand or paper)/feeling warm",
"50":"punching/slapping other person",
"51":"kicking other person",
"52":"pushing other person",
"53":"pat on back of other person",
"54":"point finger at the other person",
"55":"hugging other person",
"56":"giving something to other person",
"57":"touch other person's pocket",
"58":"handshaking",
"59":"walking towards each other",
"60":"walking apart from each other",
"61":"put on headphone",
"62":"take off headphone",
"63":"shoot at the basket",
"64":"bounce ball",
"65":"tennis bat swing",
"66":"juggling table tennis balls",
"67":"hush (quite)",
"68":"flick hair",
"69":"thumb up",
"70":"thumb down",
"71":"make ok sign",
"72":"make victory sign",
"73":"staple book",
"74":"counting money",
"75":"cutting nails",
"76":"cutting paper (using scissors)",
"77":"snapping fingers",
"78":"open bottle",
"79":"sniff (smell)",
"80":"squat down",
"81":"toss a coin",
"82":"fold paper",
"83":"ball up paper",
"84":"play magic cube",
"85":"apply cream on face",
"86":"apply cream on hand back",
"87":"put on bag",
"88":"take off bag",
"89":"put something into a bag",
"90":"take something out of a bag",
"91":"open a box",
"92":"move heavy objects",
"93":"shake fist",
"94":"throw up cap/hat",
"95":"hands up (both hands)",
"96":"cross arms",
"97":"arm circles",
"98":"arm swings",
"99":"running on the spot",
"100":"butt kicks (kick backward)",
"101":"cross toe touch",
"102":"side kick",
"103":"yawn",
"104":"stretch oneself",
"105":"blow nose",
"106":"hit other person with something",
"107":"wield knife towards other person",
"108":"knock over other person (hit with body)",
"109":"grab other persons stuff",
"110":"shoot at other person with a gun",
"111":"step on foot",
"112":"high-five",
"113":"cheers and drink",
"114":"carry something with other person",
"115":"take a photo of other person",
"116":"follow other person",
"117":"whisper in other persons ear",
"118":"exchange things with other person",
"119":"support somebody with hand",
"120":"finger-guessing game (playing rock-paper-scissors)",
}
def get_label(file_name, dataset_name): def get_label(file_name, dataset_name):
if dataset_name == 'HumanAct12': if dataset_name == 'HumanAct12':
key = file_name[-5:] key = file_name[-5:]
return HumanAct12[key] return HumanAct12[key]
elif dataset_name == 'UTD_MHAD': elif dataset_name == 'UTD_MHAD':
key = file_name.split('_')[0][1:] ###key = file_name.split('_')[0][1:]
return UTD_MHAD[key] ###return UTD_MHAD[key]
return "single_person"
elif dataset_name == 'CMU_Mocap': elif dataset_name == 'CMU_Mocap':
key = file_name.split('.')[0] key = file_name.split(':')[0]
if key in CMU_Mocap.keys(): return CMU_Mocap[key] if key in CMU_Mocap.keys() else ""
return CMU_Mocap[key] elif dataset_name == 'NTU':
else: key = str(int(file_name[-3:]))
return "" return NTU[key]

View File

@ -1,19 +1,46 @@
import scipy.io import scipy.io
import numpy as np import numpy as np
import json # 引入json模块
def load(name, path): def load(name, path):
# 处理UTD-MHAD的JSON文件你的单帧数据
if name == 'UTD_MHAD': if name == 'UTD_MHAD':
data = scipy.io.loadmat(path) # 判断文件是否为JSON格式通过后缀
arr = data['d_skel'] if path.endswith('.json'):
new_arr = np.zeros([arr.shape[2], arr.shape[0], arr.shape[1]]) with open(path, 'r') as f:
for i in range(arr.shape[2]): data = json.load(f) # 加载JSON列表
for j in range(arr.shape[0]): # 转换为NumPy数组确保形状为[关节数, 3]
for k in range(arr.shape[1]): data_np = np.array(data)
new_arr[i][j][k] = arr[j][k][i] # 校验数据格式(防止错误)
return new_arr assert data_np.ndim == 2 and data_np.shape[1] == 3, \
elif name == 'HumanAct12': f"UTD-MHAD JSON格式错误应为[关节数, 3],实际为{data_np.shape}"
return np.load(path,allow_pickle=True) # 若需要单帧维度([1, 关节数, 3]),可扩展维度
elif name == "CMU_Mocap": return data_np[np.newaxis, ...] # 输出形状:[1, N, 3]1表示单帧
return np.load(path,allow_pickle=True)
# 保留原UTD_MHAD的.mat文件支持如果还需要处理.mat数据
elif path.endswith('.mat'):
arr = scipy.io.loadmat(path)['d_skel']
new_arr = np.zeros([arr.shape[2], arr.shape[0], arr.shape[1]])
for i in range(arr.shape[2]):
for j in range(arr.shape[0]):
for k in range(arr.shape[1]):
new_arr[i][j][k] = arr[j][k][i]
return new_arr
else:
raise ValueError(f"UTD-MHAD不支持的文件格式{path}")
# 其他数据集的原有逻辑保持不变
elif name == 'HumanAct12':
return np.load(path, allow_pickle=True)
elif name == "CMU_Mocap":
return np.load(path, allow_pickle=True)
elif name == "Human3.6M":
return np.load(path, allow_pickle=True)[0::5] # 下采样
elif name == "NTU":
return np.load(path, allow_pickle=True)[0::2]
elif name == "HAA4D":
return np.load(path, allow_pickle=True)
else:
raise ValueError(f"不支持的数据集名称:{name}")

View File

@ -1,49 +1,104 @@
import torch """
import numpy as np SMPL模型拟合主程序
from tensorboardX import SummaryWriter 该脚本用于将人体姿态数据拟合到SMPLSkinned Multi-Person Linear模型中
from easydict import EasyDict as edict 主要功能:
import time 1. 加载人体姿态数据
import sys 2. 使用SMPL模型进行拟合优化
import os 3. 保存拟合结果和可视化图像
import logging """
import argparse import os
import json import sys
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from load import load
from save import save_pic, save_params # 导入自定义模块
from transform import transform from meters import Meters # 用于跟踪训练指标的工具类
from train import train from smplpytorch.pytorch.smpl_layer import SMPL_Layer # SMPL模型层
from smplpytorch.pytorch.smpl_layer import SMPL_Layer from train import train # 训练函数
from transform import transform # 数据变换函数
from save import save_pic, save_params # 保存结果的函数
from load import load # 数据加载函数
# 导入标准库
import torch # PyTorch深度学习框架
import numpy as np # 数值计算库
from easydict import EasyDict as edict # 用于创建字典对象的便捷工具
import time # 时间处理
import logging # 日志记录
import argparse # 命令行参数解析
import json # JSON文件处理
# 启用CUDNN加速提高训练效率
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
def parse_args(): def parse_args():
"""
解析命令行参数
Returns:
args: 包含解析后参数的命名空间对象
"""
parser = argparse.ArgumentParser(description='Fit SMPL') parser = argparse.ArgumentParser(description='Fit SMPL')
# 实验名称,默认使用当前时间戳
parser.add_argument('--exp', dest='exp', parser.add_argument('--exp', dest='exp',
help='Define exp name', help='Define exp name',
default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str) default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
parser.add_argument('--dataset_name', dest='dataset_name',
# 数据集名称,用于选择对应的配置文件
parser.add_argument('--dataset_name', '-n', dest='dataset_name',
help='select dataset', help='select dataset',
default='', type=str) default='', type=str)
# 数据集路径,可以覆盖配置文件中的默认路径
parser.add_argument('--dataset_path', dest='dataset_path', parser.add_argument('--dataset_path', dest='dataset_path',
help='path of dataset', help='path of dataset',
default=None, type=str) default=None, type=str)
args = parser.parse_args() args = parser.parse_args()
return args return args
def get_config(args): def get_config(args):
"""
根据数据集名称加载对应的配置文件
Args:
args: 命令行参数对象
Returns:
cfg: 配置对象,包含所有训练和模型参数
"""
# 根据数据集名称构建配置文件路径
config_path = 'fit/configs/{}.json'.format(args.dataset_name) config_path = 'fit/configs/{}.json'.format(args.dataset_name)
# 读取JSON配置文件
with open(config_path, 'r') as f: with open(config_path, 'r') as f:
data = json.load(f) data = json.load(f)
# 将字典转换为edict对象支持点号访问属性
cfg = edict(data.copy()) cfg = edict(data.copy())
# 如果命令行指定了数据集路径,则覆盖配置文件中的设置
if not args.dataset_path == None: if not args.dataset_path == None:
cfg.DATASET.PATH = args.dataset_path cfg.DATASET.PATH = args.dataset_path
return cfg return cfg
def set_device(USE_GPU): def set_device(USE_GPU):
"""
根据配置和硬件可用性设置计算设备
Args:
USE_GPU: 是否使用GPU的布尔值
Returns:
device: PyTorch设备对象'cuda''cpu'
"""
if USE_GPU and torch.cuda.is_available(): if USE_GPU and torch.cuda.is_available():
device = torch.device('cuda') device = torch.device('cuda')
else: else:
@ -52,9 +107,21 @@ def set_device(USE_GPU):
def get_logger(cur_path): def get_logger(cur_path):
"""
设置日志记录器,同时输出到文件和控制台
Args:
cur_path: 当前实验路径,用于保存日志文件
Returns:
logger: 日志记录器对象
writer: TensorBoard写入器当前设置为None
"""
# 创建日志记录器
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO) logger.setLevel(level=logging.INFO)
# 设置文件输出处理器将日志保存到log.txt文件
handler = logging.FileHandler(os.path.join(cur_path, "log.txt")) handler = logging.FileHandler(os.path.join(cur_path, "log.txt"))
handler.setLevel(logging.INFO) handler.setLevel(logging.INFO)
formatter = logging.Formatter( formatter = logging.Formatter(
@ -62,6 +129,7 @@ def get_logger(cur_path):
handler.setFormatter(formatter) handler.setFormatter(formatter)
logger.addHandler(handler) logger.addHandler(handler)
# 设置控制台输出处理器,将日志同时输出到终端
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setLevel(logging.INFO) handler.setLevel(logging.INFO)
formatter = logging.Formatter( formatter = logging.Formatter(
@ -69,43 +137,106 @@ def get_logger(cur_path):
handler.setFormatter(formatter) handler.setFormatter(formatter)
logger.addHandler(handler) logger.addHandler(handler)
writer = SummaryWriter(os.path.join(cur_path, 'tb')) # TensorBoard写入器目前被注释掉设置为None
# from tensorboardX import SummaryWriter
# writer = SummaryWriter(os.path.join(cur_path, 'tb'))
writer = None
return logger, writer return logger, writer
if __name__ == "__main__": if __name__ == "__main__":
"""
主函数执行SMPL模型拟合流程
主要步骤:
1. 解析命令行参数
2. 创建实验目录
3. 加载配置文件
4. 设置日志记录
5. 初始化SMPL模型
6. 遍历数据集进行拟合
7. 保存结果
"""
# 解析命令行参数
args = parse_args() args = parse_args()
# 创建实验目录,使用时间戳或用户指定的实验名称
cur_path = os.path.join(os.getcwd(), 'exp', args.exp) cur_path = os.path.join(os.getcwd(), 'exp', args.exp)
assert not os.path.exists(cur_path), 'Duplicate exp name' assert not os.path.exists(cur_path), 'Duplicate exp name' # 确保实验名称不重复
os.mkdir(cur_path) os.mkdir(cur_path)
# 加载配置文件
cfg = get_config(args) cfg = get_config(args)
# 将配置保存到实验目录中,便于后续追踪实验设置
json.dump(dict(cfg), open(os.path.join(cur_path, 'config.json'), 'w')) json.dump(dict(cfg), open(os.path.join(cur_path, 'config.json'), 'w'))
# 设置日志记录器
logger, writer = get_logger(cur_path) logger, writer = get_logger(cur_path)
logger.info("Start print log") logger.info("Start print log")
# 设置计算设备GPU或CPU
device = set_device(USE_GPU=cfg.USE_GPU) device = set_device(USE_GPU=cfg.USE_GPU)
logger.info('using device: {}'.format(device)) logger.info('using device: {}'.format(device))
# 初始化SMPL模型层
# center_idx=0: 设置中心关节点索引
# gender='male': 设置性别为男性注释掉的cfg.MODEL.GENDER可能用于从配置文件读取
# model_root: SMPL模型文件的路径
smpl_layer = SMPL_Layer( smpl_layer = SMPL_Layer(
center_idx=0, center_idx=0,
gender=cfg.MODEL.GENDER, gender='male', #cfg.MODEL.GENDER,
model_root='smplpytorch/native/models') model_root='smplpytorch/native/models')
# 初始化指标记录器,用于跟踪训练损失等指标
meters = Meters()
file_num = 0 file_num = 0
# 遍历数据集目录中的所有文件
for root, dirs, files in os.walk(cfg.DATASET.PATH): for root, dirs, files in os.walk(cfg.DATASET.PATH):
for file in files: for file in sorted(files): # 按文件名排序处理
file_num += 1 # 可选的文件过滤器(当前被注释掉)
logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files))) # 可以用于只处理特定的文件,如包含'baseball_swing'的文件
target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name, ###if not 'baseball_swing' in file:
os.path.join(root, file)))).float() ###continue
file_num += 1 # 文件计数器
logger.info(
'Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
# 加载并变换目标数据
# 1. load(): 根据数据集类型加载原始数据
# 2. transform(): 将数据转换为模型所需的格式
# 3. torch.from_numpy(): 将numpy数组转换为PyTorch张量
# 4. .float(): 确保数据类型为float32
target = torch.from_numpy(transform(args.dataset_name, load(args.dataset_name,
os.path.join(root, file)))).float()
logger.info("target shape:{}".format(target.shape))
# 执行SMPL模型拟合训练
# 传入SMPL层、目标数据、日志记录器、设备信息等
res = train(smpl_layer, target, res = train(smpl_layer, target,
logger, writer, device, logger, writer, device,
args, cfg) args, cfg, meters)
# save_pic(res,smpl_layer,file,logger,args.dataset_name,target) # 更新平均损失指标
# k=target.shape[0] 表示批次大小,用于加权平均
meters.update_avg(meters.min_loss, k=target.shape[0])
# 重置早停计数器,为下一个文件的训练做准备
meters.reset_early_stop()
# 记录当前的平均损失
logger.info("avg_loss:{:.4f}".format(meters.avg))
# 保存拟合结果
# 1. save_params(): 保存拟合得到的SMPL参数
# 2. save_pic(): 保存可视化图像,包括拟合结果和原始目标的对比
save_params(res, file, logger, args.dataset_name) save_params(res, file, logger, args.dataset_name)
save_pic(res, smpl_layer, file, logger, args.dataset_name, target)
# 清空GPU缓存防止内存溢出
torch.cuda.empty_cache()
# 所有文件处理完成,记录最终的平均损失
logger.info(
"Fitting finished! Average loss: {:.9f}".format(meters.avg))

27
fit/tools/meters.py Normal file
View File

@ -0,0 +1,27 @@
class Meters:
def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
self.eps = eps
self.stop_threshold = stop_threshold
self.avg = 0
self.cnt = 0
self.reset_early_stop()
def reset_early_stop(self):
self.min_loss = float('inf')
self.satis_num = 0
self.update_res = True
self.early_stop = False
def update_avg(self, val, k=1):
self.avg = self.avg + (val - self.avg) * k / (self.cnt + k)
self.cnt += k
def update_early_stop(self, val):
delta = (val - self.min_loss) / self.min_loss
if float(val) < self.min_loss:
self.min_loss = float(val)
self.update_res = True
else:
self.update_res = False
self.satis_num = self.satis_num + 1 if delta >= self.eps else 0
self.early_stop = self.satis_num >= self.stop_threshold

View File

@ -1,3 +1,5 @@
from display_utils import display_model
from label import get_label
import sys import sys
import os import os
import re import re
@ -6,8 +8,6 @@ import numpy as np
import pickle import pickle
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from label import get_label
from display_utils import display_model
def create_dir_not_exist(path): def create_dir_not_exist(path):
@ -18,10 +18,8 @@ def create_dir_not_exist(path):
def save_pic(res, smpl_layer, file, logger, dataset_name, target): def save_pic(res, smpl_layer, file, logger, dataset_name, target):
_, _, verts, Jtr = res _, _, verts, Jtr = res
file_name = re.split('[/.]', file)[-2] file_name = re.split('[/.]', file)[-2]
fit_path = "fit/output/{}/picture/fit/{}".format(dataset_name, file_name) fit_path = "fit/output/{}/picture/{}".format(dataset_name, file_name)
gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name, file_name) os.makedirs(fit_path,exist_ok=True)
create_dir_not_exist(fit_path)
create_dir_not_exist(gt_path)
logger.info('Saving pictures at {}'.format(fit_path)) logger.info('Saving pictures at {}'.format(fit_path))
for i in tqdm(range(Jtr.shape[0])): for i in tqdm(range(Jtr.shape[0])):
display_model( display_model(
@ -30,20 +28,10 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
model_faces=smpl_layer.th_faces, model_faces=smpl_layer.th_faces,
with_joints=True, with_joints=True,
kintree_table=smpl_layer.kintree_table, kintree_table=smpl_layer.kintree_table,
savepath=os.path.join(fit_path+"/frame_{}".format(i)), savepath=os.path.join(fit_path+"/frame_{:0>4d}".format(i)),
batch_idx=i, batch_idx=i,
show=False, show=False,
only_joint=True) only_joint=True)
# display_model(
# {'verts': verts.cpu().detach(),
# 'joints': target.cpu().detach()},
# model_faces=smpl_layer.th_faces,
# with_joints=True,
# kintree_table=smpl_layer.kintree_table,
# savepath=os.path.join(gt_path+"/frame_{}".format(i)),
# batch_idx=i,
# show=False,
# only_joint=True)
logger.info('Pictures saved') logger.info('Pictures saved')
@ -63,6 +51,25 @@ def save_params(res, file, logger, dataset_name):
params["pose_params"] = pose_params params["pose_params"] = pose_params
params["shape_params"] = shape_params params["shape_params"] = shape_params
params["Jtr"] = Jtr params["Jtr"] = Jtr
print("label:{}".format(label))
with open(os.path.join((fit_path), with open(os.path.join((fit_path),
"{}_params.pkl".format(file_name)), 'wb') as f: "{}_params.pkl".format(file_name)), 'wb') as f:
pickle.dump(params, f) pickle.dump(params, f)
def save_single_pic(res, smpl_layer, epoch, logger, dataset_name, target):
_, _, verts, Jtr = res
fit_path = "fit/output/{}/picture".format(dataset_name)
create_dir_not_exist(fit_path)
logger.info('Saving pictures at {}'.format(fit_path))
display_model(
{'verts': verts.cpu().detach(),
'joints': Jtr.cpu().detach()},
model_faces=smpl_layer.th_faces,
with_joints=True,
kintree_table=smpl_layer.kintree_table,
savepath=fit_path+"/epoch_{:0>4d}".format(epoch),
batch_idx=60,
show=False,
only_joint=False)
logger.info('Picture saved')

View File

@ -6,33 +6,14 @@ import os
from tqdm import tqdm from tqdm import tqdm
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from save import save_single_pic
class Early_Stop:
def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
self.min_loss = float('inf')
self.eps = eps
self.stop_threshold = stop_threshold
self.satis_num = 0
def update(self, loss):
delta = (loss - self.min_loss) / self.min_loss
if float(loss) < self.min_loss:
self.min_loss = float(loss)
update_res = True
else:
update_res = False
if delta >= self.eps:
self.satis_num += 1
else:
self.satis_num = 0
return update_res, self.satis_num >= self.stop_threshold
def init(smpl_layer, target, device, cfg): def init(smpl_layer, target, device, cfg):
params = {} params = {}
params["pose_params"] = torch.rand(target.shape[0], 72) * 0.0 params["pose_params"] = torch.zeros(target.shape[0], 72)
params["shape_params"] = torch.rand(target.shape[0], 10) * 0.03 params["shape_params"] = torch.zeros(target.shape[0], 10)
params["scale"] = torch.ones([1]) params["scale"] = torch.ones([1])
smpl_layer = smpl_layer.to(device) smpl_layer = smpl_layer.to(device)
@ -45,8 +26,10 @@ def init(smpl_layer, target, device, cfg):
params["shape_params"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SHAPE) params["shape_params"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SHAPE)
params["scale"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SCALE) params["scale"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SCALE)
optimizer = optim.Adam([params["pose_params"], params["shape_params"], params["scale"]], optim_params = [{'params': params["pose_params"], 'lr': cfg.TRAIN.LEARNING_RATE},
lr=cfg.TRAIN.LEARNING_RATE) {'params': params["shape_params"], 'lr': cfg.TRAIN.LEARNING_RATE},
{'params': params["scale"], 'lr': cfg.TRAIN.LEARNING_RATE*10},]
optimizer = optim.Adam(optim_params)
index = {} index = {}
smpl_index = [] smpl_index = []
@ -63,7 +46,7 @@ def init(smpl_layer, target, device, cfg):
def train(smpl_layer, target, def train(smpl_layer, target,
logger, writer, device, logger, writer, device,
args, cfg): args, cfg, meters):
res = [] res = []
smpl_layer, params, target, optimizer, index = \ smpl_layer, params, target, optimizer, index = \
init(smpl_layer, target, device, cfg) init(smpl_layer, target, device, cfg)
@ -71,30 +54,35 @@ def train(smpl_layer, target,
shape_params = params["shape_params"] shape_params = params["shape_params"]
scale = params["scale"] scale = params["scale"]
early_stop = Early_Stop() with torch.no_grad():
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
# for epoch in range(cfg.TRAIN.MAX_EPOCH):
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params) verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100 * scale, params["scale"]*=(torch.max(torch.abs(target))/torch.max(torch.abs(Jtr)))
target.index_select(1, index["dataset_index"]) * 100)
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
loss = F.smooth_l1_loss(scale*Jtr.index_select(1, index["smpl_index"]),
target.index_select(1, index["dataset_index"]))
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
update_res, stop = early_stop.update(float(loss)) meters.update_early_stop(float(loss))
if update_res: if meters.update_res:
res = [pose_params, shape_params, verts, Jtr] res = [pose_params, shape_params, verts, Jtr]
if stop: if meters.early_stop:
logger.info("Early stop at epoch {} !".format(epoch)) logger.info("Early stop at epoch {} !".format(epoch))
break break
if epoch % cfg.TRAIN.WRITE == 0: if epoch % cfg.TRAIN.WRITE == 0 or epoch<10:
# logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f} EarlyStopSatis: {}".format( # logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
# epoch, float(loss),float(scale), early_stop.satis_num)) # epoch, float(loss),float(scale)))
writer.add_scalar('loss', float(loss), epoch) print("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
writer.add_scalar('learning_rate', float( epoch, float(loss),float(scale)))
optimizer.state_dict()['param_groups'][0]['lr']), epoch) ###writer.add_scalar('loss', float(loss), epoch)
###writer.add_scalar('learning_rate', float(
###optimizer.state_dict()['param_groups'][0]['lr']), epoch)
# save_single_pic(res,smpl_layer,epoch,logger,args.dataset_name,target)
logger.info('Train ended, min_loss = {:.9f}'.format( logger.info('Train ended, min_loss = {:.4f}'.format(
float(early_stop.min_loss))) float(meters.min_loss)))
return res return res

View File

@ -3,7 +3,10 @@ import numpy as np
rotate = { rotate = {
'HumanAct12': [1., -1., -1.], 'HumanAct12': [1., -1., -1.],
'CMU_Mocap': [0.05, 0.05, 0.05], 'CMU_Mocap': [0.05, 0.05, 0.05],
'UTD_MHAD': [-1., 1., -1.] 'UTD_MHAD': [1., 1., 1.],
'Human3.6M': [-0.001, -0.001, 0.001],
'NTU': [1., 1., -1.],
'HAA4D': [1., -1., -1.],
} }

View File

@ -1,7 +1,7 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import imageio, os import imageio, os
images = [] images = []
filenames = sorted(fn for fn in os.listdir('./fit/output/CMU_Mocap/picture/fit/01_01') ) filenames = sorted(fn for fn in os.listdir('D:/OneDrive - sjtu.edu.cn/MVIG/Action-Dataset/Pose_to_SMPL/fit/output/NTU/picture') )
for filename in filenames: for filename in filenames:
images.append(imageio.imread('./fit/output/CMU_Mocap/picture/fit/01_01/'+filename)) images.append(imageio.imread('D:/OneDrive - sjtu.edu.cn/MVIG/Action-Dataset/Pose_to_SMPL/fit/output/NTU/picture/'+filename))
imageio.mimsave('fit.gif', images, duration=0.2) imageio.mimsave('clapping_example.gif', images, duration=0.2)

64
pyproject.toml Normal file
View File

@ -0,0 +1,64 @@
[project]
name = "template_project_alpha"
version = "0.1.0"
description = "a basic template project that has the ability to complete most of ML task with PyTorch and MMCV"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"anyio>=4.9.0",
"awkward>=2.8.5",
"beartype>=0.21.0",
"click>=8.2.1",
"easydict>=1.13",
"jaxtyping>=0.3.2",
"loguru>=0.7.3",
"mmcv",
"mmdet>=3.3.0",
"mmpose>=1.3.2",
"nats-py>=2.10.0",
"numba>=0.61.2",
"nvidia-nvimgcodec-cu12>=0.5.0.13",
"opencv-python>=4.12.0.88",
"orjson>=3.10.18",
"pyarrow>=20.0.0",
"pydantic>=2.11.7",
"pytorch3d",
"redis>=6.2.0",
"result>=0.17.0",
"tomli>=2.2.1",
"tomli-w>=1.2.0",
"torch>=2.7.0",
"torchvision>=0.22.0",
"ultralytics>=8.3.166",
"xtcocotools",
]
[dependency-groups]
dev = ["jupyterlab>=4.4.4", "pip", "psutil>=7.0.0", "setuptools"]
[tool.uv]
no-build-isolation-package = ["chumpy", "xtcocotools"]
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
]
torchvision = [
{ index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
]
# built with cu128
mmcv = { path = "misc/mmcv-2.2.0-cp312-cp312-linux_x86_64.whl" }
xtcocotools = { path = "misc/xtcocotools-1.14.3-cp312-cp312-linux_x86_64.whl" }
pytorch3d = { path = "misc/pytorch3d-0.7.8-cp312-cp312-linux_x86_64.whl" }
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

57
smpl.py Normal file
View File

@ -0,0 +1,57 @@
import torch
import random
import numpy as np
import pickle # 新增用于加载pkl文件
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from display_utils import display_model
if __name__ == '__main__':
cuda = True
batch_size = 1
# Create the SMPL layer
smpl_layer = SMPL_Layer(
center_idx=0,
gender='male', # 确保与pkl文件中的性别一致
model_root='/home/lmd/Code/Pose_to_SMPL_an_230402/smplpytorch/native/models')
# 从pkl文件加载参数
pkl_path = '/home/lmd/Code/Pose_to_SMPL_an_230402/fit/output/UTD_MHAD/sigle_people_smpl_params.pkl' # 替换为实际的pkl文件路径
with open(pkl_path, 'rb') as f:
data = pickle.load(f)
# 提取pose和shape参数
pose_params = torch.tensor(data['pose_params']).float() # 确保数据类型为float
shape_params = torch.tensor(data['shape_params']).float()
# 调整维度(如果需要)
if pose_params.dim() == 1:
pose_params = pose_params.unsqueeze(0) # 添加batch维度
if shape_params.dim() == 1:
shape_params = shape_params.unsqueeze(0)
# 验证batch size
if pose_params.shape[0] != batch_size:
batch_size = pose_params.shape[0]
print(f"Warning: Batch size adjusted to {batch_size} based on loaded data.")
# GPU mode
if cuda:
pose_params = pose_params.cuda()
shape_params = shape_params.cuda()
smpl_layer.cuda()
# Forward from the SMPL layer
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
# Draw output vertices and joints
display_model(
{'verts': verts.cpu().detach(),
'joints': Jtr.cpu().detach()},
model_faces=smpl_layer.th_faces,
with_joints=True,
kintree_table=smpl_layer.kintree_table,
savepath='image2.png',
show=True)

View File

@ -27,12 +27,19 @@ class SMPL_Layer(Module):
self.center_idx = center_idx self.center_idx = center_idx
self.gender = gender self.gender = gender
if gender == 'neutral': # if gender == 'neutral':
self.model_path = os.path.join(model_root, 'basicModel_neutral_lbs_10_207_0_v1.0.0.pkl') # self.model_path = os.path.join(model_root, 'basicModel_neutral_lbs_10_207_0_v1.0.0.pkl')
elif gender == 'female': # elif gender == 'female':
# self.model_path = os.path.join(model_root, 'basicModel_f_lbs_10_207_0_v1.0.0.pkl')
# elif gender == 'male':
# self.model_path = os.path.join(model_root, 'basicModel_m_lbs_10_207_0_v1.0.0.pkl')
if gender == 'female':
self.model_path = os.path.join(model_root, 'basicModel_f_lbs_10_207_0_v1.0.0.pkl') self.model_path = os.path.join(model_root, 'basicModel_f_lbs_10_207_0_v1.0.0.pkl')
elif gender == 'male': elif gender == 'male':
self.model_path = os.path.join(model_root, 'basicModel_m_lbs_10_207_0_v1.0.0.pkl') self.model_path = os.path.join(model_root, 'basicmodel_m_lbs_10_207_0_v1.0.0.pkl')
else:
raise ValueError("no valid gender")
smpl_data = ready_arguments(self.model_path) smpl_data = ready_arguments(self.model_path)
self.smpl_data = smpl_data self.smpl_data = smpl_data

3016
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff