Compare commits

...

10 Commits

Author SHA1 Message Date
lmd 65769e5eb6 name 2025-07-25 15:05:31 +08:00
24e1e31234 Merge branch 'main' of github.com:Iridoudou/Pose_to_SMPL 2022-10-01 21:42:18 +08:00
6ac10e22e0 Support HAA4D 2022-10-01 21:42:12 +08:00
03debbe123 Create README.md 2022-09-28 20:57:19 +08:00
fd2ff6616f Support NTU 2021-09-03 15:31:39 +08:00
52603950d6 support NTU 2021-09-03 11:07:42 +08:00
16387a7afe support NTU 2021-09-03 11:06:26 +08:00
0702081dbe update code 2021-08-20 22:48:02 +08:00
bd93f8dcaf down sample 2021-08-19 15:08:28 +08:00
8573a09b84 update meters 2021-08-19 12:40:54 +08:00
25 changed files with 4022 additions and 150 deletions

5
.gitignore vendored
View File

@@ -13,4 +13,7 @@ smplpytorch/native/models/*.pkl
exp/
output/
make_gif.py
make_gif.py
*.pkl
*.whl
*.png

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.12

15
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,15 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python 调试程序: 当前文件",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal"
}
]
}

View File

@@ -1,4 +1,4 @@
pose2smpl
Pose_to_SMPL
=======
### Fitting SMPL Parameters by 3D-pose Key-points
@@ -40,11 +40,14 @@ The SMPL human body layer for Pytorch is from the [smplpytorch](https://github.c
- Download the datasets you want to fit
currently supported datasets:
currently supported:
- [HumanAct12](https://ericguo5513.github.io/action-to-motion/)
- [CMU Mocap](https://ericguo5513.github.io/action-to-motion/)
- [UTD-MHAD](https://personal.utdallas.edu/~kehtar/UTD-MHAD.html)
- [Human3.6M](http://vision.imar.ro/human3.6m/description.php)
- [NTU](https://rose1.ntu.edu.sg/dataset/actionRecognition/)
- [HAA4D](https://cse.hkust.edu.hk/haa4d/dataset.html)
- Set the **DATASET.PATH** in the corresponding configuration file to the location of the dataset.

View File

@@ -0,0 +1,122 @@
[
[
0.3718571662902832,
0.6803165078163147,
-0.7502949833869934
],
[
0.45830726623535156,
0.6820494532585144,
-0.6965712904930115
],
[
0.28540709614753723,
0.6785836219787598,
-0.8040187358856201
],
[
0.3710346221923828,
0.8273340463638306,
-0.6851714849472046
],
[
0.4792303442955017,
0.2899134159088135,
-0.6879411935806274
],
[
0.2960263788700104,
0.29219183325767517,
-0.8322873711585999
],
[
0.3702120780944824,
0.9743515253067017,
-0.6200480461120605
],
[
0.5172274112701416,
-0.09073961526155472,
-0.7129285335540771
],
[
0.3149631917476654,
-0.08920176327228546,
-0.9007936120033264
],
[
0.3702120780944824,
1.0243514776229858,
-0.6200480461120605
],
[
0.5201388001441956,
-0.13739190995693207,
-0.6414260864257812
],
[
0.2730123996734619,
-0.13770559430122375,
-0.8578707575798035
],
[
0.3782634735107422,
1.2141607999801636,
-0.7061752080917358
],
[
0.4474300742149353,
1.1862998008728027,
-0.6677815914154053
],
[
0.3090968728065491,
1.1820218563079834,
-0.7445688247680664
],
[
0.36856698989868164,
1.2683864831924438,
-0.4898010492324829
],
[
0.5165966749191284,
1.1984386444091797,
-0.6293879151344299
],
[
0.23993025720119476,
1.1898829936981201,
-0.7829625010490417
],
[
0.478150337934494,
0.971466600894928,
-0.5572866201400757
],
[
0.20990325510501862,
0.936416506767273,
-0.8448361158370972
],
[
0.3065527677536011,
1.1182676553726196,
-0.5591280460357666
],
[
0.17920757830142975,
0.6927539706230164,
-0.7849017381668091
],
[
0.25156882405281067,
1.1860939264297485,
-0.5887487530708313
],
[
0.1682051420211792,
0.6014264822006226,
-0.7367973327636719
]
]

View File

@@ -0,0 +1,122 @@
[
[
0.3723624646663666,
0.6807540655136108,
-0.7536930441856384
],
[
0.4606369733810425,
0.6819767355918884,
-0.7026320695877075
],
[
0.2840879261493683,
0.6795313954353333,
-0.8047540187835693
],
[
0.37141507863998413,
0.8276264667510986,
-0.6875424385070801
],
[
0.479561448097229,
0.2904614806175232,
-0.6911981701850891
],
[
0.29553577303886414,
0.2926357686519623,
-0.832754373550415
],
[
0.3704676926136017,
0.9744989275932312,
-0.621391773223877
],
[
0.5182272791862488,
-0.09084545075893402,
-0.7154797911643982
],
[
0.31344518065452576,
-0.08931314945220947,
-0.9003074169158936
],
[
0.3704676926136017,
1.0244989395141602,
-0.621391773223877
],
[
0.5207709074020386,
-0.1377965807914734,
-0.6441547870635986
],
[
0.2722875475883484,
-0.13784976303577423,
-0.8576371669769287
],
[
0.3782576322555542,
1.214632511138916,
-0.7064842581748962
],
[
0.4484415650367737,
1.1865851879119873,
-0.6695226430892944
],
[
0.3080736994743347,
1.1826797723770142,
-0.743445873260498
],
[
0.3685729205608368,
1.2682437896728516,
-0.48909056186676025
],
[
0.5186254978179932,
1.1985379457473755,
-0.6325609683990479
],
[
0.23788979649543762,
1.1907269954681396,
-0.7804075479507446
],
[
0.4849953353404999,
0.969235897064209,
-0.5583885312080383
],
[
0.20581978559494019,
0.9366675019264221,
-0.8437999486923218
],
[
0.3075758218765259,
1.1038239002227783,
-0.5611312389373779
],
[
0.17861124873161316,
0.6925873756408691,
-0.7852163314819336
],
[
0.24666009843349457,
1.1842905282974243,
-0.5826389193534851
],
[
0.16754546761512756,
0.6011085510253906,
-0.7378559708595276
]
]

View File

@@ -1,4 +1,6 @@
import torch
import random
import numpy as np
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from display_utils import display_model
@@ -12,10 +14,10 @@ if __name__ == '__main__':
smpl_layer = SMPL_Layer(
center_idx=0,
gender='male',
model_root='smplpytorch/native/models')
model_root='/home/lmd/Code/Pose_to_SMPL_an_230402/smplpytorch/native/models')
# Generate random pose and shape parameters
pose_params = torch.rand(batch_size, 72) * 0.2
pose_params = (torch.rand(batch_size, 72)*2-1) * np.pi
shape_params = torch.rand(batch_size, 10) * 0.03
# GPU mode
@@ -26,7 +28,6 @@ if __name__ == '__main__':
# Forward from the SMPL layer
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
print(Jtr)
# Draw output vertices and joints
display_model(
@@ -35,5 +36,5 @@ if __name__ == '__main__':
model_faces=smpl_layer.th_faces,
with_joints=True,
kintree_table=smpl_layer.kintree_table,
savepath='image.png',
savepath='image1.png',
show=True)
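The pose-parameter change above widens the random initialization from a small positive jitter to the full axis-angle range. A quick sketch of the effect (torch.rand is uniform on [0, 1), so x*2-1 maps it to [-1, 1)):

    # Quick check of the new sampling range; shapes follow the diff above.
    import torch
    import numpy as np

    pose_params = (torch.rand(1, 72) * 2 - 1) * np.pi  # uniform on [-pi, pi)
    assert float(pose_params.min()) >= -np.pi
    assert float(pose_params.max()) < np.pi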

View File

@@ -1,12 +1,15 @@
from xml.parsers.expat import model
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
# plt.switch_backend('agg')
from torch import Tensor
from typing import Optional
def display_model(
model_info,
model_faces=None,
model_faces:Optional[Tensor]=None,
with_joints=False,
kintree_table=None,
ax=None,
@@ -21,11 +24,12 @@ def display_model(
if ax is None:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
verts, joints = model_info['verts'][batch_idx], model_info['joints'][batch_idx]
verts = model_info['verts'][batch_idx]
joints = model_info['joints'][batch_idx]
if model_faces is None:
ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2)
elif not only_joint:
mesh = Poly3DCollection(verts[model_faces], alpha=0.2)
mesh = Poly3DCollection(verts[model_faces.cpu()], alpha=0.2)
face_color = (141 / 255, 184 / 255, 226 / 255)
edge_color = (50 / 255, 50 / 255, 50 / 255)
mesh.set_edgecolor(edge_color)
@@ -44,8 +48,8 @@ def display_model(
if savepath:
# print('Saving figure at {}.'.format(savepath))
plt.savefig(savepath, bbox_inches='tight', pad_inches=0)
if show:
plt.show()
# if show:
# plt.show()
plt.close()
return ax

View File

@@ -12,7 +12,7 @@
"USE_GPU": 1,
"DATASET": {
"NAME": "UTD-MHAD",
"PATH": "../Action2Motion/CMU Mocap/mocap/mocap_3djoints/",
"PATH": "/home/lmd/Code/Pose_to_SMPL_an_230402/database",
"TARGET_PATH": "",
"DATA_MAP": [
[

24
fit/configs/HAA4D.json Normal file
View File

@@ -0,0 +1,24 @@
{
"MODEL": {
"GENDER": "neutral"
},
"TRAIN": {
"LEARNING_RATE": 1e-2,
"MAX_EPOCH": 1000,
"WRITE": 10,
"OPTIMIZE_SCALE":1,
"OPTIMIZE_SHAPE":1
},
"USE_GPU": 1,
"DATASET": {
"NAME": "NTU",
"PATH": "../NTU RGB+D/result",
"TARGET_PATH": "",
"DATA_MAP": [
[0,0],[1,4],[2,1],[4,5],[5,2],[7,6],[8,3],
[12,9],[18,12],[19,15],[20,13],[21,16],
[15,10],[6,1]
]
},
"DEBUG": 0
}

93
fit/configs/NTU.json Normal file
View File

@@ -0,0 +1,93 @@
{
"MODEL": {
"GENDER": "neutral"
},
"TRAIN": {
"LEARNING_RATE": 5e-2,
"MAX_EPOCH": 1000,
"WRITE": 10,
"OPTIMIZE_SCALE":0,
"OPTIMIZE_SHAPE":0
},
"USE_GPU": 1,
"DATASET": {
"NAME": "NTU",
"PATH": "../NTU RGB+D/result",
"TARGET_PATH": "",
"DATA_MAP": [
[
0,
0
],
[
2,
12
],
[
1,
16
],
[
5,
13
],
[
4,
17
],
[
6,
1
],
[
8,
14
],
[
7,
18
],
[
9,
20
],
[
12,
2
],
[
14,
4
],
[
13,
8
],
[
19,
5
],
[
18,
9
],
[
21,
6
],
[
20,
10
],
[
23,
22
],
[
22,
24
]
]
},
"DEBUG": 0
}
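Each DATA_MAP entry pairs a joint index in one skeleton with the corresponding joint index in the other, so the fitting loss only compares joints that exist in both. A minimal sketch of how such a map could be split into the two index tensors that train.py uses (the pair order [smpl_joint, dataset_joint] is an assumption; check init() in fit/tools/train.py):

    # Minimal sketch; the pair order is an assumption, and the values are the
    # first three entries of the NTU.json DATA_MAP above.
    import torch

    data_map = [[0, 0], [2, 12], [1, 16]]
    smpl_index = torch.tensor([pair[0] for pair in data_map])
    dataset_index = torch.tensor([pair[1] for pair in data_map])
    # The loss then compares only the mapped joints:
    #   F.smooth_l1_loss(scale * Jtr.index_select(1, smpl_index),
    #                    target.index_select(1, dataset_index))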

View File

@@ -12,72 +12,100 @@
"USE_GPU": 1,
"DATASET": {
"NAME": "UTD-MHAD",
"PATH": "../UTD-MHAD/Skeleton/Skeleton/",
"PATH": "/home/lmd/Code/Pose_to_SMPL_an_230402/database",
"TARGET_PATH": "",
"DATA_MAP": [
[
12,
1,
1
],
[
2,
2
],
[
0,
3,
3
],
[
16,
4,
4
],
[
18,
5,
5
],
[
20,
6,
6
],
[
22,
7,
7
],
[
17,
8,
8
],
[
19,
9,
9
],
[
21,
10,
10
],
[
23,
11,
11
],
[
1,
12,
12
],
[
4,
13,
13
],
[
7,
14,
14
],
[
15,
15
],
[
2,
16,
16
],
[
5,
17,
17
],
[
8,
18,
18
],
[
19,
19
],
[
20,
20
],
[
21,
21
],
[
22,
22
],
[
23,
23
]
]
},

View File

@@ -4,6 +4,7 @@ import os
import json
import argparse
def parse_args():
parser = argparse.ArgumentParser(description='Detect cross joints')
parser.add_argument('--dataset_name', dest='dataset_name',
@@ -15,10 +16,12 @@ def parse_args():
args = parser.parse_args()
return args
def create_dir_not_exist(path):
if not os.path.exists(path):
os.mkdir(path)
def load_Jtr(file_path):
with open(file_path, 'rb') as f:
data = pickle.load(f)
@@ -40,15 +43,18 @@ def cross_frames(Jtr: np.ndarray):
def cross_detector(dir_path):
ans={}
ans = {}
for root, dirs, files in os.walk(dir_path):
for file in files:
file_path = os.path.join(dir_path, file)
Jtr = load_Jtr(file_path)
ans[file]=cross_frames(Jtr)
ans[file] = cross_frames(Jtr)
return ans
if __name__ == "__main__":
args=parse_args()
d=cross_detector(args.output_path)
json.dump(d,open("./fit/output/cross_detection/{}.json".format(args.dataset_name),'w'))
args = parse_args()
d = cross_detector(args.output_path)
json.dump(
d, open("./fit/output/cross_detection/{}.json"
.format(args.dataset_name), 'w'))

View File

@@ -1156,17 +1156,140 @@ CMU_Mocap = {
"41_09": "Climb"
}
NTU = {
"1":"drink water",
"2":"eat meal/snack",
"3":"brushing teeth",
"4":"brushing hair",
"5":"drop",
"6":"pickup",
"7":"throw",
"8":"sitting down",
"9":"standing up (from sitting position)",
"10":"clapping",
"11":"reading",
"12":"writing",
"13":"tear up paper",
"14":"wear jacket",
"15":"take off jacket",
"16":"wear a shoe",
"17":"take off a shoe",
"18":"wear on glasses",
"19":"take off glasses",
"20":"put on a hat/cap",
"21":"take off a hat/cap",
"22":"cheer up",
"23":"hand waving",
"24":"kicking something",
"25":"reach into pocket",
"26":"hopping (one foot jumping)",
"27":"jump up",
"28":"make a phone call/answer phone",
"29":"playing with phone/tablet",
"30":"typing on a keyboard",
"31":"pointing to something with finger",
"32":"taking a selfie",
"33":"check time (from watch)",
"34":"rub two hands together",
"35":"nod head/bow",
"36":"shake head",
"37":"wipe face",
"38":"salute",
"39":"put the palms together",
"40":"cross hands in front (say stop)",
"41":"sneeze/cough",
"42":"staggering",
"43":"falling",
"44":"touch head (headache)",
"45":"touch chest (stomachache/heart pain)",
"46":"touch back (backache)",
"47":"touch neck (neckache)",
"48":"nausea or vomiting condition",
"49":"use a fan (with hand or paper)/feeling warm",
"50":"punching/slapping other person",
"51":"kicking other person",
"52":"pushing other person",
"53":"pat on back of other person",
"54":"point finger at the other person",
"55":"hugging other person",
"56":"giving something to other person",
"57":"touch other person's pocket",
"58":"handshaking",
"59":"walking towards each other",
"60":"walking apart from each other",
"61":"put on headphone",
"62":"take off headphone",
"63":"shoot at the basket",
"64":"bounce ball",
"65":"tennis bat swing",
"66":"juggling table tennis balls",
"67":"hush (quite)",
"68":"flick hair",
"69":"thumb up",
"70":"thumb down",
"71":"make ok sign",
"72":"make victory sign",
"73":"staple book",
"74":"counting money",
"75":"cutting nails",
"76":"cutting paper (using scissors)",
"77":"snapping fingers",
"78":"open bottle",
"79":"sniff (smell)",
"80":"squat down",
"81":"toss a coin",
"82":"fold paper",
"83":"ball up paper",
"84":"play magic cube",
"85":"apply cream on face",
"86":"apply cream on hand back",
"87":"put on bag",
"88":"take off bag",
"89":"put something into a bag",
"90":"take something out of a bag",
"91":"open a box",
"92":"move heavy objects",
"93":"shake fist",
"94":"throw up cap/hat",
"95":"hands up (both hands)",
"96":"cross arms",
"97":"arm circles",
"98":"arm swings",
"99":"running on the spot",
"100":"butt kicks (kick backward)",
"101":"cross toe touch",
"102":"side kick",
"103":"yawn",
"104":"stretch oneself",
"105":"blow nose",
"106":"hit other person with something",
"107":"wield knife towards other person",
"108":"knock over other person (hit with body)",
"109":"grab other persons stuff",
"110":"shoot at other person with a gun",
"111":"step on foot",
"112":"high-five",
"113":"cheers and drink",
"114":"carry something with other person",
"115":"take a photo of other person",
"116":"follow other person",
"117":"whisper in other persons ear",
"118":"exchange things with other person",
"119":"support somebody with hand",
"120":"finger-guessing game (playing rock-paper-scissors)",
}
def get_label(file_name, dataset_name):
if dataset_name == 'HumanAct12':
key = file_name[-5:]
return HumanAct12[key]
elif dataset_name == 'UTD_MHAD':
key = file_name.split('_')[0][1:]
return UTD_MHAD[key]
###key = file_name.split('_')[0][1:]
###return UTD_MHAD[key]
return "single_person"
elif dataset_name == 'CMU_Mocap':
key = file_name.split('.')[0]
if key in CMU_Mocap.keys():
return CMU_Mocap[key]
else:
return ""
key = file_name.split(':')[0]
return CMU_Mocap[key] if key in CMU_Mocap.keys() else ""
elif dataset_name == 'NTU':
key = str(int(file_name[-3:]))
return NTU[key]

View File

@@ -1,19 +1,46 @@
import scipy.io
import numpy as np
import json  # for reading JSON files
def load(name, path):
# Handle UTD-MHAD JSON files (single-frame data)
if name == 'UTD_MHAD':
data = scipy.io.loadmat(path)
arr = data['d_skel']
new_arr = np.zeros([arr.shape[2], arr.shape[0], arr.shape[1]])
for i in range(arr.shape[2]):
for j in range(arr.shape[0]):
for k in range(arr.shape[1]):
new_arr[i][j][k] = arr[j][k][i]
return new_arr
elif name == 'HumanAct12':
return np.load(path,allow_pickle=True)
elif name == "CMU_Mocap":
return np.load(path,allow_pickle=True)
# Decide whether the file is JSON by its extension
if path.endswith('.json'):
with open(path, 'r') as f:
data = json.load(f)  # load the JSON list
# Convert to a NumPy array; the expected shape is [num_joints, 3]
data_np = np.array(data)
# Validate the data format
assert data_np.ndim == 2 and data_np.shape[1] == 3, \
f"Invalid UTD-MHAD JSON format: expected [num_joints, 3], got {data_np.shape}"
# Add a single-frame dimension ([1, num_joints, 3]) if needed
return data_np[np.newaxis, ...]  # output shape [1, N, 3]; the leading 1 is the single frame
# Keep the original .mat support for UTD_MHAD (in case .mat data still needs handling)
elif path.endswith('.mat'):
arr = scipy.io.loadmat(path)['d_skel']
new_arr = np.zeros([arr.shape[2], arr.shape[0], arr.shape[1]])
for i in range(arr.shape[2]):
for j in range(arr.shape[0]):
for k in range(arr.shape[1]):
new_arr[i][j][k] = arr[j][k][i]
return new_arr
else:
raise ValueError(f"Unsupported UTD-MHAD file format: {path}")
# The original logic for the other datasets is unchanged
elif name == 'HumanAct12':
return np.load(path, allow_pickle=True)
elif name == "CMU_Mocap":
return np.load(path, allow_pickle=True)
elif name == "Human3.6M":
return np.load(path, allow_pickle=True)[0::5]  # downsample
elif name == "NTU":
return np.load(path, allow_pickle=True)[0::2]
elif name == "HAA4D":
return np.load(path, allow_pickle=True)
else:
raise ValueError(f"Unsupported dataset name: {name}")

View File

@@ -1,49 +1,104 @@
import torch
import numpy as np
from tensorboardX import SummaryWriter
from easydict import EasyDict as edict
import time
import sys
import os
import logging
"""
SMPL fitting entry point.
This script fits human pose data to the SMPL (Skinned Multi-Person Linear) model.
Main steps:
1. Load the human pose data
2. Fit and optimize with the SMPL model
3. Save the fitted results and visualization images
"""
import argparse
import json
import os
import sys
sys.path.append(os.getcwd())
from load import load
from save import save_pic, save_params
from transform import transform
from train import train
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
# Custom modules
from meters import Meters  # utility class for tracking training metrics
from smplpytorch.pytorch.smpl_layer import SMPL_Layer  # SMPL model layer
from train import train  # training function
from transform import transform  # data transforms
from save import save_pic, save_params  # result-saving helpers
from load import load  # data loading
# Standard libraries
import torch  # PyTorch deep learning framework
import numpy as np  # numerical computing
from easydict import EasyDict as edict  # convenient dict with attribute access
import time  # timestamps
import logging  # logging
import argparse  # command-line argument parsing
import json  # JSON handling
# Enable CUDNN to speed up training
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
def parse_args():
"""
Parse command-line arguments.
Returns:
args: namespace holding the parsed arguments
"""
parser = argparse.ArgumentParser(description='Fit SMPL')
# Experiment name; defaults to the current timestamp
parser.add_argument('--exp', dest='exp',
help='Define exp name',
default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
parser.add_argument('--dataset_name', dest='dataset_name',
# Dataset name, used to select the matching configuration file
parser.add_argument('--dataset_name', '-n', dest='dataset_name',
help='select dataset',
default='', type=str)
# Dataset path; overrides the default path from the configuration file
parser.add_argument('--dataset_path', dest='dataset_path',
help='path of dataset',
default=None, type=str)
args = parser.parse_args()
return args
def get_config(args):
"""
Load the configuration file matching the dataset name.
Args:
args: command-line argument object
Returns:
cfg: configuration object with all training and model parameters
"""
# Build the config-file path from the dataset name
config_path = 'fit/configs/{}.json'.format(args.dataset_name)
# Read the JSON configuration file
with open(config_path, 'r') as f:
data = json.load(f)
# Convert the dict to an edict so attributes are dot-accessible
cfg = edict(data.copy())
# A dataset path given on the command line overrides the config file
if not args.dataset_path == None:
cfg.DATASET.PATH = args.dataset_path
return cfg
def set_device(USE_GPU):
"""
Choose the compute device based on the config and hardware availability.
Args:
USE_GPU: boolean, whether to use the GPU
Returns:
device: PyTorch device object ('cuda' or 'cpu')
"""
if USE_GPU and torch.cuda.is_available():
device = torch.device('cuda')
else:
@@ -52,9 +107,21 @@ def set_device(USE_GPU):
def get_logger(cur_path):
"""
Set up a logger that writes to both a file and the console.
Args:
cur_path: current experiment path, used for the log file
Returns:
logger: logger object
writer: TensorBoard writer (currently set to None)
"""
# Create the logger
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
# File handler: save the log to log.txt
handler = logging.FileHandler(os.path.join(cur_path, "log.txt"))
handler.setLevel(logging.INFO)
formatter = logging.Formatter(
@@ -62,6 +129,7 @@ def get_logger(cur_path):
handler.setFormatter(formatter)
logger.addHandler(handler)
# Console handler: also print the log to the terminal
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
formatter = logging.Formatter(
@@ -69,43 +137,106 @@ def get_logger(cur_path):
handler.setFormatter(formatter)
logger.addHandler(handler)
writer = SummaryWriter(os.path.join(cur_path, 'tb'))
# The TensorBoard writer is currently disabled and set to None
# from tensorboardX import SummaryWriter
# writer = SummaryWriter(os.path.join(cur_path, 'tb'))
writer = None
return logger, writer
if __name__ == "__main__":
"""
Main entry: run the SMPL fitting pipeline.
Steps:
1. Parse command-line arguments
2. Create the experiment directory
3. Load the configuration file
4. Set up logging
5. Initialize the SMPL model
6. Iterate over the dataset and fit
7. Save the results
"""
# Parse command-line arguments
args = parse_args()
# Create the experiment directory, named after the timestamp or the user-given name
cur_path = os.path.join(os.getcwd(), 'exp', args.exp)
assert not os.path.exists(cur_path), 'Duplicate exp name'
assert not os.path.exists(cur_path), 'Duplicate exp name'  # make sure the experiment name is unique
os.mkdir(cur_path)
# Load the configuration file
cfg = get_config(args)
# Save the config into the experiment directory for later reference
json.dump(dict(cfg), open(os.path.join(cur_path, 'config.json'), 'w'))
# Set up the logger
logger, writer = get_logger(cur_path)
logger.info("Start print log")
# Choose the compute device (GPU or CPU)
device = set_device(USE_GPU=cfg.USE_GPU)
logger.info('using device: {}'.format(device))
# Initialize the SMPL layer
# center_idx=0: index of the center joint
# gender='male': hard-coded; the commented-out cfg.MODEL.GENDER would read it from the config
# model_root: path to the SMPL model files
smpl_layer = SMPL_Layer(
center_idx=0,
gender=cfg.MODEL.GENDER,
gender='male', #cfg.MODEL.GENDER,
model_root='smplpytorch/native/models')
# Initialize the meters used to track training losses and other metrics
meters = Meters()
file_num = 0
# Walk every file in the dataset directory
for root, dirs, files in os.walk(cfg.DATASET.PATH):
for file in files:
file_num += 1
logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name,
os.path.join(root, file)))).float()
for file in sorted(files):  # process files in sorted order
# Optional file filter (currently commented out);
# can be used to process only specific files, e.g. those containing 'baseball_swing'
###if not 'baseball_swing' in file:
###continue
file_num += 1  # file counter
logger.info(
'Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
# Load and transform the target data:
# 1. load(): load the raw data for this dataset type
# 2. transform(): convert the data into the format the model expects
# 3. torch.from_numpy(): convert the numpy array into a PyTorch tensor
# 4. .float(): make sure the dtype is float32
target = torch.from_numpy(transform(args.dataset_name, load(args.dataset_name,
os.path.join(root, file)))).float()
logger.info("target shape:{}".format(target.shape))
# Run the SMPL fitting loop,
# passing in the SMPL layer, target data, logger, device, and so on
res = train(smpl_layer, target,
logger, writer, device,
args, cfg)
args, cfg, meters)
# Update the running-average loss;
# k=target.shape[0] is the batch size, used as the weight
meters.update_avg(meters.min_loss, k=target.shape[0])
# Reset the early-stop counters before fitting the next file
meters.reset_early_stop()
# Log the current average loss
logger.info("avg_loss:{:.4f}".format(meters.avg))
# save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
# Save the results:
# 1. save_params(): save the fitted SMPL parameters
# 2. save_pic(): save visualization images comparing the fit against the target
save_params(res, file, logger, args.dataset_name)
save_pic(res, smpl_layer, file, logger, args.dataset_name, target)
# Empty the GPU cache to avoid running out of memory
torch.cuda.empty_cache()
# All files processed; log the final average loss
logger.info(
"Fitting finished! Average loss: {:.9f}".format(meters.avg))

27
fit/tools/meters.py Normal file
View File

@@ -0,0 +1,27 @@
class Meters:
def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
self.eps = eps
self.stop_threshold = stop_threshold
self.avg = 0
self.cnt = 0
self.reset_early_stop()
def reset_early_stop(self):
self.min_loss = float('inf')
self.satis_num = 0
self.update_res = True
self.early_stop = False
def update_avg(self, val, k=1):
self.avg = self.avg + (val - self.avg) * k / (self.cnt + k)
self.cnt += k
def update_early_stop(self, val):
delta = (val - self.min_loss) / self.min_loss
if float(val) < self.min_loss:
self.min_loss = float(val)
self.update_res = True
else:
self.update_res = False
self.satis_num = self.satis_num + 1 if delta >= self.eps else 0
self.early_stop = self.satis_num >= self.stop_threshold
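A small sketch of the early-stop behaviour (the loss sequence is invented): satis_num counts consecutive epochs whose relative change (val - min_loss) / min_loss stays above eps, i.e. the loss improves by less than 0.1%, and early_stop trips once stop_threshold such epochs occur in a row.

    # Invented loss sequence to illustrate Meters' early stopping.
    from meters import Meters  # fit/tools/meters.py

    m = Meters(eps=-1e-3, stop_threshold=3)
    for epoch, loss in enumerate([1.0, 0.5, 0.4999, 0.4999, 0.4999, 0.4999]):
        m.update_early_stop(loss)
        if m.early_stop:
            print("early stop at epoch", epoch)  # fires on the fifth value
            break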

View File

@@ -1,3 +1,5 @@
from display_utils import display_model
from label import get_label
import sys
import os
import re
@@ -6,8 +8,6 @@ import numpy as np
import pickle
sys.path.append(os.getcwd())
from label import get_label
from display_utils import display_model
def create_dir_not_exist(path):
@@ -18,10 +18,8 @@ def create_dir_not_exist(path):
def save_pic(res, smpl_layer, file, logger, dataset_name, target):
_, _, verts, Jtr = res
file_name = re.split('[/.]', file)[-2]
fit_path = "fit/output/{}/picture/fit/{}".format(dataset_name, file_name)
gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name, file_name)
create_dir_not_exist(fit_path)
create_dir_not_exist(gt_path)
fit_path = "fit/output/{}/picture/{}".format(dataset_name, file_name)
os.makedirs(fit_path,exist_ok=True)
logger.info('Saving pictures at {}'.format(fit_path))
for i in tqdm(range(Jtr.shape[0])):
display_model(
@@ -30,20 +28,10 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
model_faces=smpl_layer.th_faces,
with_joints=True,
kintree_table=smpl_layer.kintree_table,
savepath=os.path.join(fit_path+"/frame_{}".format(i)),
savepath=os.path.join(fit_path+"/frame_{:0>4d}".format(i)),
batch_idx=i,
show=False,
only_joint=True)
# display_model(
# {'verts': verts.cpu().detach(),
# 'joints': target.cpu().detach()},
# model_faces=smpl_layer.th_faces,
# with_joints=True,
# kintree_table=smpl_layer.kintree_table,
# savepath=os.path.join(gt_path+"/frame_{}".format(i)),
# batch_idx=i,
# show=False,
# only_joint=True)
logger.info('Pictures saved')
@@ -63,6 +51,25 @@ def save_params(res, file, logger, dataset_name):
params["pose_params"] = pose_params
params["shape_params"] = shape_params
params["Jtr"] = Jtr
print("label:{}".format(label))
with open(os.path.join((fit_path),
"{}_params.pkl".format(file_name)), 'wb') as f:
pickle.dump(params, f)
def save_single_pic(res, smpl_layer, epoch, logger, dataset_name, target):
_, _, verts, Jtr = res
fit_path = "fit/output/{}/picture".format(dataset_name)
create_dir_not_exist(fit_path)
logger.info('Saving pictures at {}'.format(fit_path))
display_model(
{'verts': verts.cpu().detach(),
'joints': Jtr.cpu().detach()},
model_faces=smpl_layer.th_faces,
with_joints=True,
kintree_table=smpl_layer.kintree_table,
savepath=fit_path+"/epoch_{:0>4d}".format(epoch),
batch_idx=60,
show=False,
only_joint=False)
logger.info('Picture saved')
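The switch to frame_{:0>4d} zero-pads the frame index so that a plain lexicographic sort (which make_gif.py relies on) returns frames in temporal order:

    # Why the zero-padding matters for sorted(); frame indices are invented.
    unpadded = ["frame_{}".format(i) for i in (1, 2, 10)]
    padded = ["frame_{:0>4d}".format(i) for i in (1, 2, 10)]
    print(sorted(unpadded))  # ['frame_1', 'frame_10', 'frame_2'] -- wrong order
    print(sorted(padded))    # ['frame_0001', 'frame_0002', 'frame_0010']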

View File

@@ -6,33 +6,14 @@ import os
from tqdm import tqdm
sys.path.append(os.getcwd())
from save import save_single_pic
class Early_Stop:
def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
self.min_loss = float('inf')
self.eps = eps
self.stop_threshold = stop_threshold
self.satis_num = 0
def update(self, loss):
delta = (loss - self.min_loss) / self.min_loss
if float(loss) < self.min_loss:
self.min_loss = float(loss)
update_res = True
else:
update_res = False
if delta >= self.eps:
self.satis_num += 1
else:
self.satis_num = 0
return update_res, self.satis_num >= self.stop_threshold
def init(smpl_layer, target, device, cfg):
params = {}
params["pose_params"] = torch.rand(target.shape[0], 72) * 0.0
params["shape_params"] = torch.rand(target.shape[0], 10) * 0.03
params["pose_params"] = torch.zeros(target.shape[0], 72)
params["shape_params"] = torch.zeros(target.shape[0], 10)
params["scale"] = torch.ones([1])
smpl_layer = smpl_layer.to(device)
@@ -45,8 +26,10 @@ def init(smpl_layer, target, device, cfg):
params["shape_params"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SHAPE)
params["scale"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SCALE)
optimizer = optim.Adam([params["pose_params"], params["shape_params"], params["scale"]],
lr=cfg.TRAIN.LEARNING_RATE)
optim_params = [{'params': params["pose_params"], 'lr': cfg.TRAIN.LEARNING_RATE},
{'params': params["shape_params"], 'lr': cfg.TRAIN.LEARNING_RATE},
{'params': params["scale"], 'lr': cfg.TRAIN.LEARNING_RATE*10},]
optimizer = optim.Adam(optim_params)
index = {}
smpl_index = []
@@ -63,38 +46,43 @@ def init(smpl_layer, target, device, cfg):
def train(smpl_layer, target,
logger, writer, device,
args, cfg):
args, cfg, meters):
res = []
smpl_layer, params, target, optimizer, index = \
init(smpl_layer, target, device, cfg)
pose_params = params["pose_params"]
shape_params = params["shape_params"]
scale = params["scale"]
early_stop = Early_Stop()
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
# for epoch in range(cfg.TRAIN.MAX_EPOCH):
with torch.no_grad():
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100 * scale,
target.index_select(1, index["dataset_index"]) * 100)
params["scale"]*=(torch.max(torch.abs(target))/torch.max(torch.abs(Jtr)))
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
loss = F.smooth_l1_loss(scale*Jtr.index_select(1, index["smpl_index"]),
target.index_select(1, index["dataset_index"]))
optimizer.zero_grad()
loss.backward()
optimizer.step()
update_res, stop = early_stop.update(float(loss))
if update_res:
meters.update_early_stop(float(loss))
if meters.update_res:
res = [pose_params, shape_params, verts, Jtr]
if stop:
if meters.early_stop:
logger.info("Early stop at epoch {} !".format(epoch))
break
if epoch % cfg.TRAIN.WRITE == 0:
# logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f} EarlyStopSatis: {}".format(
# epoch, float(loss),float(scale), early_stop.satis_num))
writer.add_scalar('loss', float(loss), epoch)
writer.add_scalar('learning_rate', float(
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
if epoch % cfg.TRAIN.WRITE == 0 or epoch<10:
# logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
# epoch, float(loss),float(scale)))
print("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
epoch, float(loss),float(scale)))
###writer.add_scalar('loss', float(loss), epoch)
###writer.add_scalar('learning_rate', float(
###optimizer.state_dict()['param_groups'][0]['lr']), epoch)
# save_single_pic(res,smpl_layer,epoch,logger,args.dataset_name,target)
logger.info('Train ended, min_loss = {:.9f}'.format(
float(early_stop.min_loss)))
logger.info('Train ended, min_loss = {:.4f}'.format(
float(meters.min_loss)))
return res
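The no_grad block above warm-starts scale by matching the largest absolute joint coordinate of the SMPL output to that of the target, roughly aligning units before Adam takes over (the new param groups then give scale a 10x learning rate so it can keep adapting). A toy illustration with invented values:

    # Toy illustration of the scale warm start; all numbers are invented.
    import torch

    Jtr = torch.tensor([[[0.1, 0.5, -0.2]]])     # SMPL joints, roughly metres
    target = torch.tensor([[[10., 50., -20.]]])  # dataset joints, e.g. centimetres
    scale = torch.ones([1])
    scale *= torch.max(torch.abs(target)) / torch.max(torch.abs(Jtr))
    print(scale)  # tensor([100.]) -- magnitudes matched before optimization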

View File

@@ -3,7 +3,10 @@ import numpy as np
rotate = {
'HumanAct12': [1., -1., -1.],
'CMU_Mocap': [0.05, 0.05, 0.05],
'UTD_MHAD': [-1., 1., -1.]
'UTD_MHAD': [1., 1., 1.],
'Human3.6M': [-0.001, -0.001, 0.001],
'NTU': [1., 1., -1.],
'HAA4D': [1., -1., -1.],
}
@@ -14,4 +17,4 @@ def transform(name, arr: np.ndarray):
arr[i][j] -= origin
for k in range(3):
arr[i][j][k] *= rotate[name][k]
return arr
return arr
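Each rotate entry is a per-axis multiplier that doubles as a sign flip and a unit rescale (0.05 scales CMU Mocap down; Human3.6M's ±0.001 entries rescale what look like millimetre coordinates while flipping two axes). A small sketch of the effect on one joint (coordinates invented):

    # Per-axis flip/rescale as applied in transform(); coordinates invented.
    import numpy as np

    rotate = {'NTU': [1., 1., -1.]}
    joint = np.array([0.5, 1.0, 2.0])
    for k in range(3):
        joint[k] *= rotate['NTU'][k]
    print(joint)  # [ 0.5  1.  -2. ] -- z axis flipped for NTU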

View File

@@ -1,7 +1,7 @@
import matplotlib.pyplot as plt
import imageio, os
images = []
filenames = sorted(fn for fn in os.listdir('./fit/output/CMU_Mocap/picture/fit/01_01') )
filenames = sorted(fn for fn in os.listdir('D:/OneDrive - sjtu.edu.cn/MVIG/Action-Dataset/Pose_to_SMPL/fit/output/NTU/picture') )
for filename in filenames:
images.append(imageio.imread('./fit/output/CMU_Mocap/picture/fit/01_01/'+filename))
imageio.mimsave('fit.gif', images, duration=0.2)
images.append(imageio.imread('D:/OneDrive - sjtu.edu.cn/MVIG/Action-Dataset/Pose_to_SMPL/fit/output/NTU/picture/'+filename))
imageio.mimsave('clapping_example.gif', images, duration=0.2)
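The updated script points at an absolute OneDrive path, so it only runs on that one machine. A portable sketch that takes the frame directory as a command-line argument (this variant is not part of the repo):

    # Portable sketch of make_gif.py; the CLI argument is an assumption.
    import os
    import sys
    import imageio

    pic_dir = sys.argv[1]  # e.g. fit/output/NTU/picture
    filenames = sorted(os.listdir(pic_dir))  # relies on zero-padded frame names
    images = [imageio.imread(os.path.join(pic_dir, fn)) for fn in filenames]
    imageio.mimsave('fit.gif', images, duration=0.2)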

64
pyproject.toml Normal file
View File

@@ -0,0 +1,64 @@
[project]
name = "template_project_alpha"
version = "0.1.0"
description = "a basic template project that has the ability to complete most of ML task with PyTorch and MMCV"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"anyio>=4.9.0",
"awkward>=2.8.5",
"beartype>=0.21.0",
"click>=8.2.1",
"easydict>=1.13",
"jaxtyping>=0.3.2",
"loguru>=0.7.3",
"mmcv",
"mmdet>=3.3.0",
"mmpose>=1.3.2",
"nats-py>=2.10.0",
"numba>=0.61.2",
"nvidia-nvimgcodec-cu12>=0.5.0.13",
"opencv-python>=4.12.0.88",
"orjson>=3.10.18",
"pyarrow>=20.0.0",
"pydantic>=2.11.7",
"pytorch3d",
"redis>=6.2.0",
"result>=0.17.0",
"tomli>=2.2.1",
"tomli-w>=1.2.0",
"torch>=2.7.0",
"torchvision>=0.22.0",
"ultralytics>=8.3.166",
"xtcocotools",
]
[dependency-groups]
dev = ["jupyterlab>=4.4.4", "pip", "psutil>=7.0.0", "setuptools"]
[tool.uv]
no-build-isolation-package = ["chumpy", "xtcocotools"]
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
]
torchvision = [
{ index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
]
# built with cu128
mmcv = { path = "misc/mmcv-2.2.0-cp312-cp312-linux_x86_64.whl" }
xtcocotools = { path = "misc/xtcocotools-1.14.3-cp312-cp312-linux_x86_64.whl" }
pytorch3d = { path = "misc/pytorch3d-0.7.8-cp312-cp312-linux_x86_64.whl" }
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

57
smpl.py Normal file
View File

@@ -0,0 +1,57 @@
import torch
import random
import numpy as np
import pickle  # added: used to load the pkl file
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from display_utils import display_model
if __name__ == '__main__':
cuda = True
batch_size = 1
# Create the SMPL layer
smpl_layer = SMPL_Layer(
center_idx=0,
gender='male', # make sure this matches the gender stored in the pkl file
model_root='/home/lmd/Code/Pose_to_SMPL_an_230402/smplpytorch/native/models')
# Load the parameters from the pkl file
pkl_path = '/home/lmd/Code/Pose_to_SMPL_an_230402/fit/output/UTD_MHAD/sigle_people_smpl_params.pkl' # replace with the actual pkl path
with open(pkl_path, 'rb') as f:
data = pickle.load(f)
# Extract the pose and shape parameters
pose_params = torch.tensor(data['pose_params']).float() # make sure the dtype is float
shape_params = torch.tensor(data['shape_params']).float()
# Adjust dimensions if needed
if pose_params.dim() == 1:
pose_params = pose_params.unsqueeze(0) # add a batch dimension
if shape_params.dim() == 1:
shape_params = shape_params.unsqueeze(0)
# Check the batch size
if pose_params.shape[0] != batch_size:
batch_size = pose_params.shape[0]
print(f"Warning: Batch size adjusted to {batch_size} based on loaded data.")
# GPU mode
if cuda:
pose_params = pose_params.cuda()
shape_params = shape_params.cuda()
smpl_layer.cuda()
# Forward from the SMPL layer
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
# Draw output vertices and joints
display_model(
{'verts': verts.cpu().detach(),
'joints': Jtr.cpu().detach()},
model_faces=smpl_layer.th_faces,
with_joints=True,
kintree_table=smpl_layer.kintree_table,
savepath='image2.png',
show=True)

View File

@@ -27,12 +27,19 @@ class SMPL_Layer(Module):
self.center_idx = center_idx
self.gender = gender
if gender == 'neutral':
self.model_path = os.path.join(model_root, 'basicModel_neutral_lbs_10_207_0_v1.0.0.pkl')
elif gender == 'female':
# if gender == 'neutral':
# self.model_path = os.path.join(model_root, 'basicModel_neutral_lbs_10_207_0_v1.0.0.pkl')
# elif gender == 'female':
# self.model_path = os.path.join(model_root, 'basicModel_f_lbs_10_207_0_v1.0.0.pkl')
# elif gender == 'male':
# self.model_path = os.path.join(model_root, 'basicModel_m_lbs_10_207_0_v1.0.0.pkl')
if gender == 'female':
self.model_path = os.path.join(model_root, 'basicModel_f_lbs_10_207_0_v1.0.0.pkl')
elif gender == 'male':
self.model_path = os.path.join(model_root, 'basicModel_m_lbs_10_207_0_v1.0.0.pkl')
self.model_path = os.path.join(model_root, 'basicmodel_m_lbs_10_207_0_v1.0.0.pkl')
else:
raise ValueError("no valid gender")
smpl_data = ready_arguments(self.model_path)
self.smpl_data = smpl_data

3016
uv.lock generated Normal file

File diff suppressed because it is too large.