Files
2025-07-25 15:05:31 +08:00

243 lines
8.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
SMPL模型拟合主程序
该脚本用于将人体姿态数据拟合到SMPLSkinned Multi-Person Linear模型中
主要功能:
1. 加载人体姿态数据
2. 使用SMPL模型进行拟合优化
3. 保存拟合结果和可视化图像
"""
import os
import sys
sys.path.append(os.getcwd())
# 导入自定义模块
from meters import Meters # 用于跟踪训练指标的工具类
from smplpytorch.pytorch.smpl_layer import SMPL_Layer # SMPL模型层
from train import train # 训练函数
from transform import transform # 数据变换函数
from save import save_pic, save_params # 保存结果的函数
from load import load # 数据加载函数
# 导入标准库
import torch # PyTorch深度学习框架
import numpy as np # 数值计算库
from easydict import EasyDict as edict # 用于创建字典对象的便捷工具
import time # 时间处理
import logging # 日志记录
import argparse # 命令行参数解析
import json # JSON文件处理
# 启用CUDNN加速提高训练效率
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
def parse_args():
"""
解析命令行参数
Returns:
args: 包含解析后参数的命名空间对象
"""
parser = argparse.ArgumentParser(description='Fit SMPL')
# 实验名称,默认使用当前时间戳
parser.add_argument('--exp', dest='exp',
help='Define exp name',
default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
# 数据集名称,用于选择对应的配置文件
parser.add_argument('--dataset_name', '-n', dest='dataset_name',
help='select dataset',
default='', type=str)
# 数据集路径,可以覆盖配置文件中的默认路径
parser.add_argument('--dataset_path', dest='dataset_path',
help='path of dataset',
default=None, type=str)
args = parser.parse_args()
return args
def get_config(args):
"""
根据数据集名称加载对应的配置文件
Args:
args: 命令行参数对象
Returns:
cfg: 配置对象,包含所有训练和模型参数
"""
# 根据数据集名称构建配置文件路径
config_path = 'fit/configs/{}.json'.format(args.dataset_name)
# 读取JSON配置文件
with open(config_path, 'r') as f:
data = json.load(f)
# 将字典转换为edict对象支持点号访问属性
cfg = edict(data.copy())
# 如果命令行指定了数据集路径,则覆盖配置文件中的设置
if not args.dataset_path == None:
cfg.DATASET.PATH = args.dataset_path
return cfg
def set_device(USE_GPU):
"""
根据配置和硬件可用性设置计算设备
Args:
USE_GPU: 是否使用GPU的布尔值
Returns:
device: PyTorch设备对象'cuda''cpu'
"""
if USE_GPU and torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
return device
def get_logger(cur_path):
"""
设置日志记录器,同时输出到文件和控制台
Args:
cur_path: 当前实验路径,用于保存日志文件
Returns:
logger: 日志记录器对象
writer: TensorBoard写入器当前设置为None
"""
# 创建日志记录器
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
# 设置文件输出处理器将日志保存到log.txt文件
handler = logging.FileHandler(os.path.join(cur_path, "log.txt"))
handler.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
# 设置控制台输出处理器,将日志同时输出到终端
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
# TensorBoard写入器目前被注释掉设置为None
# from tensorboardX import SummaryWriter
# writer = SummaryWriter(os.path.join(cur_path, 'tb'))
writer = None
return logger, writer
if __name__ == "__main__":
"""
主函数执行SMPL模型拟合流程
主要步骤:
1. 解析命令行参数
2. 创建实验目录
3. 加载配置文件
4. 设置日志记录
5. 初始化SMPL模型
6. 遍历数据集进行拟合
7. 保存结果
"""
# 解析命令行参数
args = parse_args()
# 创建实验目录,使用时间戳或用户指定的实验名称
cur_path = os.path.join(os.getcwd(), 'exp', args.exp)
assert not os.path.exists(cur_path), 'Duplicate exp name' # 确保实验名称不重复
os.mkdir(cur_path)
# 加载配置文件
cfg = get_config(args)
# 将配置保存到实验目录中,便于后续追踪实验设置
json.dump(dict(cfg), open(os.path.join(cur_path, 'config.json'), 'w'))
# 设置日志记录器
logger, writer = get_logger(cur_path)
logger.info("Start print log")
# 设置计算设备GPU或CPU
device = set_device(USE_GPU=cfg.USE_GPU)
logger.info('using device: {}'.format(device))
# 初始化SMPL模型层
# center_idx=0: 设置中心关节点索引
# gender='male': 设置性别为男性注释掉的cfg.MODEL.GENDER可能用于从配置文件读取
# model_root: SMPL模型文件的路径
smpl_layer = SMPL_Layer(
center_idx=0,
gender='male', #cfg.MODEL.GENDER,
model_root='smplpytorch/native/models')
# 初始化指标记录器,用于跟踪训练损失等指标
meters = Meters()
file_num = 0
# 遍历数据集目录中的所有文件
for root, dirs, files in os.walk(cfg.DATASET.PATH):
for file in sorted(files): # 按文件名排序处理
# 可选的文件过滤器(当前被注释掉)
# 可以用于只处理特定的文件,如包含'baseball_swing'的文件
###if not 'baseball_swing' in file:
###continue
file_num += 1 # 文件计数器
logger.info(
'Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
# 加载并变换目标数据
# 1. load(): 根据数据集类型加载原始数据
# 2. transform(): 将数据转换为模型所需的格式
# 3. torch.from_numpy(): 将numpy数组转换为PyTorch张量
# 4. .float(): 确保数据类型为float32
target = torch.from_numpy(transform(args.dataset_name, load(args.dataset_name,
os.path.join(root, file)))).float()
logger.info("target shape:{}".format(target.shape))
# 执行SMPL模型拟合训练
# 传入SMPL层、目标数据、日志记录器、设备信息等
res = train(smpl_layer, target,
logger, writer, device,
args, cfg, meters)
# 更新平均损失指标
# k=target.shape[0] 表示批次大小,用于加权平均
meters.update_avg(meters.min_loss, k=target.shape[0])
# 重置早停计数器,为下一个文件的训练做准备
meters.reset_early_stop()
# 记录当前的平均损失
logger.info("avg_loss:{:.4f}".format(meters.avg))
# 保存拟合结果
# 1. save_params(): 保存拟合得到的SMPL参数
# 2. save_pic(): 保存可视化图像,包括拟合结果和原始目标的对比
save_params(res, file, logger, args.dataset_name)
save_pic(res, smpl_layer, file, logger, args.dataset_name, target)
# 清空GPU缓存防止内存溢出
torch.cuda.empty_cache()
# 所有文件处理完成,记录最终的平均损失
logger.info(
"Fitting finished! Average loss: {:.9f}".format(meters.avg))