.gitignore (vendored, 3 changes)
@@ -14,3 +14,6 @@ smplpytorch/native/models/*.pkl
 exp/
 output/
 make_gif.py
+*.pkl
+*.whl
+*.png

.python-version (new file, 1 line)
3.12

.vscode/launch.json (vendored, new file, 15 lines)
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python Debugger: Current File",
            "type": "debugpy",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal"
        }
    ]
}

dataset/single_people_smpl.json (new file, 122 lines)
[
    [0.3718571662902832, 0.6803165078163147, -0.7502949833869934],
    [0.45830726623535156, 0.6820494532585144, -0.6965712904930115],
    [0.28540709614753723, 0.6785836219787598, -0.8040187358856201],
    [0.3710346221923828, 0.8273340463638306, -0.6851714849472046],
    [0.4792303442955017, 0.2899134159088135, -0.6879411935806274],
    [0.2960263788700104, 0.29219183325767517, -0.8322873711585999],
    [0.3702120780944824, 0.9743515253067017, -0.6200480461120605],
    [0.5172274112701416, -0.09073961526155472, -0.7129285335540771],
    [0.3149631917476654, -0.08920176327228546, -0.9007936120033264],
    [0.3702120780944824, 1.0243514776229858, -0.6200480461120605],
    [0.5201388001441956, -0.13739190995693207, -0.6414260864257812],
    [0.2730123996734619, -0.13770559430122375, -0.8578707575798035],
    [0.3782634735107422, 1.2141607999801636, -0.7061752080917358],
    [0.4474300742149353, 1.1862998008728027, -0.6677815914154053],
    [0.3090968728065491, 1.1820218563079834, -0.7445688247680664],
    [0.36856698989868164, 1.2683864831924438, -0.4898010492324829],
    [0.5165966749191284, 1.1984386444091797, -0.6293879151344299],
    [0.23993025720119476, 1.1898829936981201, -0.7829625010490417],
    [0.478150337934494, 0.971466600894928, -0.5572866201400757],
    [0.20990325510501862, 0.936416506767273, -0.8448361158370972],
    [0.3065527677536011, 1.1182676553726196, -0.5591280460357666],
    [0.17920757830142975, 0.6927539706230164, -0.7849017381668091],
    [0.25156882405281067, 1.1860939264297485, -0.5887487530708313],
    [0.1682051420211792, 0.6014264822006226, -0.7367973327636719]
]

dataset/single_people_smpl_2.json (new file, 122 lines)
[
    [0.3723624646663666, 0.6807540655136108, -0.7536930441856384],
    [0.4606369733810425, 0.6819767355918884, -0.7026320695877075],
    [0.2840879261493683, 0.6795313954353333, -0.8047540187835693],
    [0.37141507863998413, 0.8276264667510986, -0.6875424385070801],
    [0.479561448097229, 0.2904614806175232, -0.6911981701850891],
    [0.29553577303886414, 0.2926357686519623, -0.832754373550415],
    [0.3704676926136017, 0.9744989275932312, -0.621391773223877],
    [0.5182272791862488, -0.09084545075893402, -0.7154797911643982],
    [0.31344518065452576, -0.08931314945220947, -0.9003074169158936],
    [0.3704676926136017, 1.0244989395141602, -0.621391773223877],
    [0.5207709074020386, -0.1377965807914734, -0.6441547870635986],
    [0.2722875475883484, -0.13784976303577423, -0.8576371669769287],
    [0.3782576322555542, 1.214632511138916, -0.7064842581748962],
    [0.4484415650367737, 1.1865851879119873, -0.6695226430892944],
    [0.3080736994743347, 1.1826797723770142, -0.743445873260498],
    [0.3685729205608368, 1.2682437896728516, -0.48909056186676025],
    [0.5186254978179932, 1.1985379457473755, -0.6325609683990479],
    [0.23788979649543762, 1.1907269954681396, -0.7804075479507446],
    [0.4849953353404999, 0.969235897064209, -0.5583885312080383],
    [0.20581978559494019, 0.9366675019264221, -0.8437999486923218],
    [0.3075758218765259, 1.1038239002227783, -0.5611312389373779],
    [0.17861124873161316, 0.6925873756408691, -0.7852163314819336],
    [0.24666009843349457, 1.1842905282974243, -0.5826389193534851],
    [0.16754546761512756, 0.6011085510253906, -0.7378559708595276]
]

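Each of the two JSON files above stores a single pose as 24 joints, one [x, y, z] triple per joint. A minimal sketch for inspecting one of them (it assumes only NumPy; the joint ordering is whatever convention the upstream skeleton uses):

import json
import numpy as np

# Load one of the single-person joint files added in this commit.
with open('dataset/single_people_smpl.json', 'r') as f:
    joints = np.array(json.load(f))   # shape (24, 3)

print(joints.shape)          # (24, 3): 24 joints, x/y/z each
print(joints.mean(axis=0))   # rough centroid of the pose
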
demo.py (6 changes)
@@ -14,10 +14,10 @@ if __name__ == '__main__':
     smpl_layer = SMPL_Layer(
         center_idx=0,
         gender='male',
-        model_root='smplpytorch/native/models')
+        model_root='/home/lmd/Code/Pose_to_SMPL_an_230402/smplpytorch/native/models')

     # Generate random pose and shape parameters
-    pose_params = torch.rand(batch_size, 72) * 0.01
+    pose_params = (torch.rand(batch_size, 72)*2-1) * np.pi
     shape_params = torch.rand(batch_size, 10) * 0.03

     # GPU mode
@@ -36,5 +36,5 @@ if __name__ == '__main__':
         model_faces=smpl_layer.th_faces,
         with_joints=True,
         kintree_table=smpl_layer.kintree_table,
-        savepath='image.png',
+        savepath='image1.png',
         show=True)

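The pose_params edit above swaps small perturbations around the rest pose for full-range random rotations. A quick sketch of the two distributions (this only restates the diffed lines; nothing here is part of the commit itself):

import torch
import numpy as np

batch_size = 1
# Old line: each of the 72 axis-angle components lies in [0, 0.01), i.e. almost the rest pose.
pose_small = torch.rand(batch_size, 72) * 0.01
# New line: each component is uniform in [-pi, pi), i.e. arbitrary joint rotations.
pose_full = (torch.rand(batch_size, 72) * 2 - 1) * np.pi
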
@@ -3,11 +3,13 @@ from matplotlib import pyplot as plt
 from mpl_toolkits.mplot3d import Axes3D
 from mpl_toolkits.mplot3d.art3d import Poly3DCollection
 # plt.switch_backend('agg')
+from torch import Tensor
+from typing import Optional


 def display_model(
     model_info,
-    model_faces=None,
+    model_faces: Optional[Tensor] = None,
     with_joints=False,
     kintree_table=None,
     ax=None,
@@ -27,7 +29,7 @@ def display_model(
     if model_faces is None:
         ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2)
     elif not only_joint:
-        mesh = Poly3DCollection(verts[model_faces], alpha=0.2)
+        mesh = Poly3DCollection(verts[model_faces.cpu()], alpha=0.2)
         face_color = (141 / 255, 184 / 255, 226 / 255)
         edge_color = (50 / 255, 50 / 255, 50 / 255)
         mesh.set_edgecolor(edge_color)
@@ -46,8 +48,8 @@ def display_model(
     if savepath:
         # print('Saving figure at {}.'.format(savepath))
         plt.savefig(savepath, bbox_inches='tight', pad_inches=0)
-    if show:
-        plt.show()
+    # if show:
+    #     plt.show()
     plt.close()
     return ax

@@ -12,7 +12,7 @@
     "USE_GPU": 1,
     "DATASET": {
         "NAME": "UTD-MHAD",
-        "PATH": "../Action2Motion/CMU Mocap/mocap/mocap_3djoints/",
+        "PATH": "/home/lmd/Code/Pose_to_SMPL_an_230402/database",
         "TARGET_PATH": "",
         "DATA_MAP": [
             [

@@ -12,72 +12,100 @@
     "USE_GPU": 1,
     "DATASET": {
         "NAME": "UTD-MHAD",
-        "PATH": "../UTD-MHAD/Skeleton/Skeleton/",
+        "PATH": "/home/lmd/Code/Pose_to_SMPL_an_230402/database",
         "TARGET_PATH": "",
         "DATA_MAP": [
             [12, 1, 1],
             [2, 2],
             [0, 3, 3],
             [16, 4, 4],
             [18, 5, 5],
             [20, 6, 6],
             [22, 7, 7],
             [17, 8, 8],
             [19, 9, 9],
             [21, 10, 10],
             [23, 11, 11],
             [1, 12, 12],
             [4, 13, 13],
             [7, 14, 14],
             [15, 15],
             [2, 16, 16],
             [5, 17, 17],
             [8, 18, 18],
             [19, 19],
             [20, 20],
             [21, 21],
             [22, 22],
             [23, 23]
         ]
     },

@@ -1284,8 +1284,9 @@ def get_label(file_name, dataset_name):
         key = file_name[-5:]
         return HumanAct12[key]
     elif dataset_name == 'UTD_MHAD':
-        key = file_name.split('_')[0][1:]
-        return UTD_MHAD[key]
+        ### key = file_name.split('_')[0][1:]
+        ### return UTD_MHAD[key]
+        return "single_person"
     elif dataset_name == 'CMU_Mocap':
         key = file_name.split(':')[0]
         return CMU_Mocap[key] if key in CMU_Mocap.keys() else ""

@@ -1,25 +1,46 @@
 import scipy.io
 import numpy as np
-import json
+import json  # import the json module


 def load(name, path):
+    # Handle the UTD-MHAD JSON files (the single-frame data)
     if name == 'UTD_MHAD':
-        arr = scipy.io.loadmat(path)['d_skel']
-        new_arr = np.zeros([arr.shape[2], arr.shape[0], arr.shape[1]])
-        for i in range(arr.shape[2]):
-            for j in range(arr.shape[0]):
-                for k in range(arr.shape[1]):
-                    new_arr[i][j][k] = arr[j][k][i]
-        return new_arr
+        # Check whether the file is JSON by its extension
+        if path.endswith('.json'):
+            with open(path, 'r') as f:
+                data = json.load(f)  # load the JSON list
+            # Convert to a NumPy array; the expected shape is [num_joints, 3]
+            data_np = np.array(data)
+            # Validate the format to guard against malformed input
+            assert data_np.ndim == 2 and data_np.shape[1] == 3, \
+                f"Invalid UTD-MHAD JSON format, expected [num_joints, 3], got {data_np.shape}"
+            # If a single-frame dimension is needed ([1, num_joints, 3]), expand the array
+            return data_np[np.newaxis, ...]  # output shape: [1, N, 3] (1 = single frame)
+
+        # Keep the original .mat support for UTD_MHAD (if .mat data still needs processing)
+        elif path.endswith('.mat'):
+            arr = scipy.io.loadmat(path)['d_skel']
+            new_arr = np.zeros([arr.shape[2], arr.shape[0], arr.shape[1]])
+            for i in range(arr.shape[2]):
+                for j in range(arr.shape[0]):
+                    for k in range(arr.shape[1]):
+                        new_arr[i][j][k] = arr[j][k][i]
+            return new_arr
+
+        else:
+            raise ValueError(f"Unsupported file format for UTD-MHAD: {path}")
+
+    # The original logic for the other datasets is unchanged
     elif name == 'HumanAct12':
         return np.load(path, allow_pickle=True)
     elif name == "CMU_Mocap":
         return np.load(path, allow_pickle=True)
     elif name == "Human3.6M":
-        return np.load(path, allow_pickle=True)[0::5]  # down_sample
+        return np.load(path, allow_pickle=True)[0::5]  # down-sample
     elif name == "NTU":
         return np.load(path, allow_pickle=True)[0::2]
     elif name == "HAA4D":
         return np.load(path, allow_pickle=True)
+
+    else:
+        raise ValueError(f"Unsupported dataset name: {name}")

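A hedged sketch of how the new JSON branch would be called with one of the joint files added above; the (1, 24, 3) shape follows from data_np[np.newaxis, ...] applied to a 24-joint file, and the relative path assumes the repository root as working directory:

import numpy as np
from load import load

# The JSON branch returns a single-frame array of shape [1, N, 3].
frames = load('UTD_MHAD', 'dataset/single_people_smpl.json')
assert frames.shape == (1, 24, 3)
print(frames[0, 0])  # x, y, z of the first joint
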
@@ -1,53 +1,104 @@
+"""
+SMPL model fitting main program.
+This script fits human pose data to the SMPL (Skinned Multi-Person Linear) model.
+Main functionality:
+1. Load human pose data
+2. Fit and optimize with the SMPL model
+3. Save the fitting results and visualization images
+"""
+
 import os
 import sys
 sys.path.append(os.getcwd())
-from meters import Meters
-from smplpytorch.pytorch.smpl_layer import SMPL_Layer
-from train import train
-from transform import transform
-from save import save_pic, save_params
-from load import load
-import torch
-import numpy as np
-from tensorboardX import SummaryWriter
-from easydict import EasyDict as edict
-import time
-import logging
-
-import argparse
-import json
+# Custom modules
+from meters import Meters  # utility class for tracking training metrics
+from smplpytorch.pytorch.smpl_layer import SMPL_Layer  # SMPL model layer
+from train import train  # training function
+from transform import transform  # data transformation function
+from save import save_pic, save_params  # functions for saving results
+from load import load  # data loading function
+
+# Standard libraries
+import torch  # PyTorch deep learning framework
+import numpy as np  # numerical computing library
+from easydict import EasyDict as edict  # convenience tool for attribute-style dict access
+import time  # time handling
+import logging  # logging
+
+import argparse  # command-line argument parsing
+import json  # JSON file handling
+
+# Enable CUDNN acceleration to improve training efficiency
 torch.backends.cudnn.enabled = True
 torch.backends.cudnn.benchmark = True


 def parse_args():
+    """
+    Parse command-line arguments.
+
+    Returns:
+        args: namespace object containing the parsed arguments
+    """
     parser = argparse.ArgumentParser(description='Fit SMPL')
+
+    # Experiment name; defaults to the current timestamp
     parser.add_argument('--exp', dest='exp',
                         help='Define exp name',
                         default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
+
+    # Dataset name, used to select the corresponding config file
     parser.add_argument('--dataset_name', '-n', dest='dataset_name',
                         help='select dataset',
                         default='', type=str)
+
+    # Dataset path; can override the default path from the config file
     parser.add_argument('--dataset_path', dest='dataset_path',
                         help='path of dataset',
                         default=None, type=str)
+
     args = parser.parse_args()
     return args


 def get_config(args):
+    """
+    Load the config file that matches the dataset name.
+
+    Args:
+        args: command-line argument object
+
+    Returns:
+        cfg: config object containing all training and model parameters
+    """
+    # Build the config file path from the dataset name
     config_path = 'fit/configs/{}.json'.format(args.dataset_name)
+
+    # Read the JSON config file
     with open(config_path, 'r') as f:
         data = json.load(f)
+
+    # Convert the dict to an edict so attributes can be accessed with dot notation
     cfg = edict(data.copy())
+
+    # If a dataset path was given on the command line, override the config setting
     if not args.dataset_path == None:
         cfg.DATASET.PATH = args.dataset_path
+
     return cfg


 def set_device(USE_GPU):
+    """
+    Select the compute device based on the config and hardware availability.
+
+    Args:
+        USE_GPU: boolean flag for whether to use the GPU
+
+    Returns:
+        device: PyTorch device object ('cuda' or 'cpu')
+    """
     if USE_GPU and torch.cuda.is_available():
         device = torch.device('cuda')
     else:
@@ -56,9 +107,21 @@ def set_device(USE_GPU):


 def get_logger(cur_path):
+    """
+    Set up a logger that writes both to a file and to the console.
+
+    Args:
+        cur_path: current experiment path, used for saving the log file
+
+    Returns:
+        logger: logger object
+        writer: TensorBoard writer (currently set to None)
+    """
+    # Create the logger
     logger = logging.getLogger(__name__)
     logger.setLevel(level=logging.INFO)

+    # File handler: save the log to log.txt
     handler = logging.FileHandler(os.path.join(cur_path, "log.txt"))
     handler.setLevel(logging.INFO)
     formatter = logging.Formatter(
@@ -66,6 +129,7 @@ def get_logger(cur_path):
     handler.setFormatter(formatter)
     logger.addHandler(handler)

+    # Console handler: also print the log to the terminal
     handler = logging.StreamHandler()
     handler.setLevel(logging.INFO)
     formatter = logging.Formatter(
@@ -73,54 +137,106 @@ def get_logger(cur_path):
     handler.setFormatter(formatter)
     logger.addHandler(handler)

-    writer = SummaryWriter(os.path.join(cur_path, 'tb'))
+    # TensorBoard writer (currently commented out, set to None)
+    # from tensorboardX import SummaryWriter
+    # writer = SummaryWriter(os.path.join(cur_path, 'tb'))
+    writer = None

     return logger, writer


 if __name__ == "__main__":
+    """
+    Main routine: run the SMPL fitting pipeline.
+
+    Main steps:
+    1. Parse command-line arguments
+    2. Create the experiment directory
+    3. Load the config file
+    4. Set up logging
+    5. Initialize the SMPL model
+    6. Iterate over the dataset and fit each file
+    7. Save the results
+    """
+    # Parse command-line arguments
     args = parse_args()

+    # Create the experiment directory, named by the timestamp or the user-specified experiment name
     cur_path = os.path.join(os.getcwd(), 'exp', args.exp)
-    assert not os.path.exists(cur_path), 'Duplicate exp name'
+    assert not os.path.exists(cur_path), 'Duplicate exp name'  # make sure the experiment name is not duplicated
     os.mkdir(cur_path)

+    # Load the config file
     cfg = get_config(args)
+    # Save the config into the experiment directory so the settings can be traced later
     json.dump(dict(cfg), open(os.path.join(cur_path, 'config.json'), 'w'))

+    # Set up the logger
     logger, writer = get_logger(cur_path)
     logger.info("Start print log")

+    # Select the compute device (GPU or CPU)
     device = set_device(USE_GPU=cfg.USE_GPU)
     logger.info('using device: {}'.format(device))

+    # Initialize the SMPL model layer
+    # center_idx=0: index of the center joint
+    # gender='male': gender is fixed to male (the commented-out cfg.MODEL.GENDER could read it from the config)
+    # model_root: path to the SMPL model files
     smpl_layer = SMPL_Layer(
         center_idx=0,
-        gender=cfg.MODEL.GENDER,
+        gender='male',  # cfg.MODEL.GENDER,
         model_root='smplpytorch/native/models')

+    # Initialize the metrics recorder used to track training loss and other metrics
     meters = Meters()
     file_num = 0
+    # Iterate over all files in the dataset directory
     for root, dirs, files in os.walk(cfg.DATASET.PATH):
-        for file in sorted(files):
-            if not 'baseball_swing' in file:
-                continue
-            file_num += 1
+        for file in sorted(files):  # process files in sorted order
+            # Optional file filter (currently commented out);
+            # it can be used to process only specific files, e.g. ones containing 'baseball_swing'
+            ### if not 'baseball_swing' in file:
+            ###     continue
+
+            file_num += 1  # file counter
             logger.info(
                 'Processing file: {} [{} / {}]'.format(file, file_num, len(files)))

+            # Load and transform the target data:
+            # 1. load(): load the raw data according to the dataset type
+            # 2. transform(): convert the data into the format the model expects
+            # 3. torch.from_numpy(): convert the numpy array to a PyTorch tensor
+            # 4. .float(): make sure the dtype is float32
             target = torch.from_numpy(transform(args.dataset_name, load(args.dataset_name,
                                                 os.path.join(root, file)))).float()
             logger.info("target shape:{}".format(target.shape))

+            # Run the SMPL fitting optimization,
+            # passing the SMPL layer, the target data, the logger, the device, and the config
             res = train(smpl_layer, target,
                         logger, writer, device,
                         args, cfg, meters)

+            # Update the average loss metric;
+            # k=target.shape[0] is the batch size, used for the weighted average
             meters.update_avg(meters.min_loss, k=target.shape[0])

+            # Reset the early-stopping counter before fitting the next file
             meters.reset_early_stop()

+            # Log the current average loss
             logger.info("avg_loss:{:.4f}".format(meters.avg))

+            # Save the fitting results:
+            # 1. save_params(): save the fitted SMPL parameters
+            # 2. save_pic(): save visualization images comparing the fit with the original target
             save_params(res, file, logger, args.dataset_name)
             save_pic(res, smpl_layer, file, logger, args.dataset_name, target)

+            # Empty the GPU cache to avoid running out of memory
             torch.cuda.empty_cache()

+    # All files processed; log the final average loss
     logger.info(
         "Fitting finished! Average loss: {:.9f}".format(meters.avg))

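A small sketch of the config handling used above: get_config wraps the JSON config in an EasyDict so nested keys can be read with attribute access, and an explicit --dataset_path overrides the value from the file. The UTD_MHAD config name and the override path are taken from this commit; running from the repository root is an assumption:

import json
from easydict import EasyDict as edict

with open('fit/configs/UTD_MHAD.json', 'r') as f:
    cfg = edict(json.load(f))

print(cfg.DATASET.PATH)  # value from the config file
# Equivalent to passing --dataset_path on the command line:
cfg.DATASET.PATH = '/home/lmd/Code/Pose_to_SMPL_an_230402/database'
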
@@ -78,9 +78,9 @@ def train(smpl_layer, target,
             # epoch, float(loss),float(scale)))
             print("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
                 epoch, float(loss), float(scale)))
-            writer.add_scalar('loss', float(loss), epoch)
-            writer.add_scalar('learning_rate', float(
-                optimizer.state_dict()['param_groups'][0]['lr']), epoch)
+            ### writer.add_scalar('loss', float(loss), epoch)
+            ### writer.add_scalar('learning_rate', float(
+            ###     optimizer.state_dict()['param_groups'][0]['lr']), epoch)
             # save_single_pic(res,smpl_layer,epoch,logger,args.dataset_name,target)

     logger.info('Train ended, min_loss = {:.4f}'.format(

@@ -3,7 +3,7 @@ import numpy as np
 rotate = {
     'HumanAct12': [1., -1., -1.],
     'CMU_Mocap': [0.05, 0.05, 0.05],
-    'UTD_MHAD': [-1., 1., -1.],
+    'UTD_MHAD': [1., 1., 1.],
     'Human3.6M': [-0.001, -0.001, 0.001],
     'NTU': [1., 1., -1.],
     'HAA4D': [1., -1., -1.],

pyproject.toml (new file, 64 lines)
[project]
name = "template_project_alpha"
version = "0.1.0"
description = "a basic template project that has the ability to complete most of ML task with PyTorch and MMCV"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "anyio>=4.9.0",
    "awkward>=2.8.5",
    "beartype>=0.21.0",
    "click>=8.2.1",
    "easydict>=1.13",
    "jaxtyping>=0.3.2",
    "loguru>=0.7.3",
    "mmcv",
    "mmdet>=3.3.0",
    "mmpose>=1.3.2",
    "nats-py>=2.10.0",
    "numba>=0.61.2",
    "nvidia-nvimgcodec-cu12>=0.5.0.13",
    "opencv-python>=4.12.0.88",
    "orjson>=3.10.18",
    "pyarrow>=20.0.0",
    "pydantic>=2.11.7",
    "pytorch3d",
    "redis>=6.2.0",
    "result>=0.17.0",
    "tomli>=2.2.1",
    "tomli-w>=1.2.0",
    "torch>=2.7.0",
    "torchvision>=0.22.0",
    "ultralytics>=8.3.166",
    "xtcocotools",
]

[dependency-groups]
dev = ["jupyterlab>=4.4.4", "pip", "psutil>=7.0.0", "setuptools"]

[tool.uv]
no-build-isolation-package = ["chumpy", "xtcocotools"]

[tool.uv.sources]
torch = [
    { index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
    { index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
]
torchvision = [
    { index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
    { index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
]
# built with cu128
mmcv = { path = "misc/mmcv-2.2.0-cp312-cp312-linux_x86_64.whl" }
xtcocotools = { path = "misc/xtcocotools-1.14.3-cp312-cp312-linux_x86_64.whl" }
pytorch3d = { path = "misc/pytorch3d-0.7.8-cp312-cp312-linux_x86_64.whl" }

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

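A quick sanity check after resolving this pyproject with uv (a hedged sketch; it assumes the Linux path where the pytorch-cu128 index is selected):

import torch

print(torch.__version__)          # pinned to >= 2.7.0 above
print(torch.version.cuda)         # expected to report 12.8 when the cu128 wheels were installed
print(torch.cuda.is_available())
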
smpl.py (new file, 57 lines)
import torch
import random
import numpy as np
import pickle  # new: used to load the pkl file

from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from display_utils import display_model


if __name__ == '__main__':
    cuda = True
    batch_size = 1

    # Create the SMPL layer
    smpl_layer = SMPL_Layer(
        center_idx=0,
        gender='male',  # make sure this matches the gender stored in the pkl file
        model_root='/home/lmd/Code/Pose_to_SMPL_an_230402/smplpytorch/native/models')

    # Load the parameters from the pkl file
    pkl_path = '/home/lmd/Code/Pose_to_SMPL_an_230402/fit/output/UTD_MHAD/sigle_people_smpl_params.pkl'  # replace with the actual pkl file path
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f)

    # Extract the pose and shape parameters
    pose_params = torch.tensor(data['pose_params']).float()  # make sure the dtype is float
    shape_params = torch.tensor(data['shape_params']).float()

    # Adjust dimensions (if needed)
    if pose_params.dim() == 1:
        pose_params = pose_params.unsqueeze(0)  # add a batch dimension
    if shape_params.dim() == 1:
        shape_params = shape_params.unsqueeze(0)

    # Verify the batch size
    if pose_params.shape[0] != batch_size:
        batch_size = pose_params.shape[0]
        print(f"Warning: Batch size adjusted to {batch_size} based on loaded data.")

    # GPU mode
    if cuda:
        pose_params = pose_params.cuda()
        shape_params = shape_params.cuda()
        smpl_layer.cuda()

    # Forward from the SMPL layer
    verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)

    # Draw output vertices and joints
    display_model(
        {'verts': verts.cpu().detach(),
         'joints': Jtr.cpu().detach()},
        model_faces=smpl_layer.th_faces,
        with_joints=True,
        kintree_table=smpl_layer.kintree_table,
        savepath='image2.png',
        show=True)

@@ -27,12 +27,19 @@ class SMPL_Layer(Module):
         self.center_idx = center_idx
         self.gender = gender

-        if gender == 'neutral':
-            self.model_path = os.path.join(model_root, 'basicModel_neutral_lbs_10_207_0_v1.0.0.pkl')
-        elif gender == 'female':
+        # if gender == 'neutral':
+        #     self.model_path = os.path.join(model_root, 'basicModel_neutral_lbs_10_207_0_v1.0.0.pkl')
+        # elif gender == 'female':
+        #     self.model_path = os.path.join(model_root, 'basicModel_f_lbs_10_207_0_v1.0.0.pkl')
+        # elif gender == 'male':
+        #     self.model_path = os.path.join(model_root, 'basicModel_m_lbs_10_207_0_v1.0.0.pkl')
+
+        if gender == 'female':
             self.model_path = os.path.join(model_root, 'basicModel_f_lbs_10_207_0_v1.0.0.pkl')
         elif gender == 'male':
-            self.model_path = os.path.join(model_root, 'basicModel_m_lbs_10_207_0_v1.0.0.pkl')
+            self.model_path = os.path.join(model_root, 'basicmodel_m_lbs_10_207_0_v1.0.0.pkl')
         else:
             raise ValueError("no valid gender")

         smpl_data = ready_arguments(self.model_path)
         self.smpl_data = smpl_data

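The gender branches above expect the SMPL .pkl files to sit under model_root; a minimal hedged check that the expected files are present (model_root as used elsewhere in this commit):

import os

model_root = 'smplpytorch/native/models'
# List the SMPL model pickles actually on disk, e.g. basicModel_f_... / basicmodel_m_...
print(sorted(f for f in os.listdir(model_root) if f.endswith('.pkl')))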