243 lines
8.4 KiB
Python
243 lines
8.4 KiB
Python
"""
|
||
SMPL模型拟合主程序
|
||
该脚本用于将人体姿态数据拟合到SMPL(Skinned Multi-Person Linear)模型中
|
||
主要功能:
|
||
1. 加载人体姿态数据
|
||
2. 使用SMPL模型进行拟合优化
|
||
3. 保存拟合结果和可视化图像
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
sys.path.append(os.getcwd())
|
||
|
||
# 导入自定义模块
|
||
from meters import Meters # 用于跟踪训练指标的工具类
|
||
from smplpytorch.pytorch.smpl_layer import SMPL_Layer # SMPL模型层
|
||
from train import train # 训练函数
|
||
from transform import transform # 数据变换函数
|
||
from save import save_pic, save_params # 保存结果的函数
|
||
from load import load # 数据加载函数
|
||
|
||
# 导入标准库
|
||
import torch # PyTorch深度学习框架
|
||
import numpy as np # 数值计算库
|
||
from easydict import EasyDict as edict # 用于创建字典对象的便捷工具
|
||
import time # 时间处理
|
||
import logging # 日志记录
|
||
|
||
import argparse # 命令行参数解析
|
||
import json # JSON文件处理
|
||
|
||
# 启用CUDNN加速,提高训练效率
|
||
torch.backends.cudnn.enabled = True
|
||
torch.backends.cudnn.benchmark = True
|
||
|
||
|
||
def parse_args():
|
||
"""
|
||
解析命令行参数
|
||
|
||
Returns:
|
||
args: 包含解析后参数的命名空间对象
|
||
"""
|
||
parser = argparse.ArgumentParser(description='Fit SMPL')
|
||
|
||
# 实验名称,默认使用当前时间戳
|
||
parser.add_argument('--exp', dest='exp',
|
||
help='Define exp name',
|
||
default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
|
||
|
||
# 数据集名称,用于选择对应的配置文件
|
||
parser.add_argument('--dataset_name', '-n', dest='dataset_name',
|
||
help='select dataset',
|
||
default='', type=str)
|
||
|
||
# 数据集路径,可以覆盖配置文件中的默认路径
|
||
parser.add_argument('--dataset_path', dest='dataset_path',
|
||
help='path of dataset',
|
||
default=None, type=str)
|
||
|
||
args = parser.parse_args()
|
||
return args
|
||
|
||
|
||
def get_config(args):
|
||
"""
|
||
根据数据集名称加载对应的配置文件
|
||
|
||
Args:
|
||
args: 命令行参数对象
|
||
|
||
Returns:
|
||
cfg: 配置对象,包含所有训练和模型参数
|
||
"""
|
||
# 根据数据集名称构建配置文件路径
|
||
config_path = 'fit/configs/{}.json'.format(args.dataset_name)
|
||
|
||
# 读取JSON配置文件
|
||
with open(config_path, 'r') as f:
|
||
data = json.load(f)
|
||
|
||
# 将字典转换为edict对象,支持点号访问属性
|
||
cfg = edict(data.copy())
|
||
|
||
# 如果命令行指定了数据集路径,则覆盖配置文件中的设置
|
||
if not args.dataset_path == None:
|
||
cfg.DATASET.PATH = args.dataset_path
|
||
|
||
return cfg
|
||
|
||
|
||
def set_device(USE_GPU):
|
||
"""
|
||
根据配置和硬件可用性设置计算设备
|
||
|
||
Args:
|
||
USE_GPU: 是否使用GPU的布尔值
|
||
|
||
Returns:
|
||
device: PyTorch设备对象('cuda' 或 'cpu')
|
||
"""
|
||
if USE_GPU and torch.cuda.is_available():
|
||
device = torch.device('cuda')
|
||
else:
|
||
device = torch.device('cpu')
|
||
return device
|
||
|
||
|
||
def get_logger(cur_path):
|
||
"""
|
||
设置日志记录器,同时输出到文件和控制台
|
||
|
||
Args:
|
||
cur_path: 当前实验路径,用于保存日志文件
|
||
|
||
Returns:
|
||
logger: 日志记录器对象
|
||
writer: TensorBoard写入器(当前设置为None)
|
||
"""
|
||
# 创建日志记录器
|
||
logger = logging.getLogger(__name__)
|
||
logger.setLevel(level=logging.INFO)
|
||
|
||
# 设置文件输出处理器,将日志保存到log.txt文件
|
||
handler = logging.FileHandler(os.path.join(cur_path, "log.txt"))
|
||
handler.setLevel(logging.INFO)
|
||
formatter = logging.Formatter(
|
||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||
handler.setFormatter(formatter)
|
||
logger.addHandler(handler)
|
||
|
||
# 设置控制台输出处理器,将日志同时输出到终端
|
||
handler = logging.StreamHandler()
|
||
handler.setLevel(logging.INFO)
|
||
formatter = logging.Formatter(
|
||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||
handler.setFormatter(formatter)
|
||
logger.addHandler(handler)
|
||
|
||
# TensorBoard写入器(目前被注释掉,设置为None)
|
||
# from tensorboardX import SummaryWriter
|
||
# writer = SummaryWriter(os.path.join(cur_path, 'tb'))
|
||
writer = None
|
||
|
||
return logger, writer
|
||
|
||
|
||
if __name__ == "__main__":
|
||
"""
|
||
主函数:执行SMPL模型拟合流程
|
||
|
||
主要步骤:
|
||
1. 解析命令行参数
|
||
2. 创建实验目录
|
||
3. 加载配置文件
|
||
4. 设置日志记录
|
||
5. 初始化SMPL模型
|
||
6. 遍历数据集进行拟合
|
||
7. 保存结果
|
||
"""
|
||
# 解析命令行参数
|
||
args = parse_args()
|
||
|
||
# 创建实验目录,使用时间戳或用户指定的实验名称
|
||
cur_path = os.path.join(os.getcwd(), 'exp', args.exp)
|
||
assert not os.path.exists(cur_path), 'Duplicate exp name' # 确保实验名称不重复
|
||
os.mkdir(cur_path)
|
||
|
||
# 加载配置文件
|
||
cfg = get_config(args)
|
||
# 将配置保存到实验目录中,便于后续追踪实验设置
|
||
json.dump(dict(cfg), open(os.path.join(cur_path, 'config.json'), 'w'))
|
||
|
||
# 设置日志记录器
|
||
logger, writer = get_logger(cur_path)
|
||
logger.info("Start print log")
|
||
|
||
# 设置计算设备(GPU或CPU)
|
||
device = set_device(USE_GPU=cfg.USE_GPU)
|
||
logger.info('using device: {}'.format(device))
|
||
|
||
# 初始化SMPL模型层
|
||
# center_idx=0: 设置中心关节点索引
|
||
# gender='male': 设置性别为男性(注释掉的cfg.MODEL.GENDER可能用于从配置文件读取)
|
||
# model_root: SMPL模型文件的路径
|
||
smpl_layer = SMPL_Layer(
|
||
center_idx=0,
|
||
gender='male', #cfg.MODEL.GENDER,
|
||
model_root='smplpytorch/native/models')
|
||
|
||
# 初始化指标记录器,用于跟踪训练损失等指标
|
||
meters = Meters()
|
||
file_num = 0
|
||
# 遍历数据集目录中的所有文件
|
||
for root, dirs, files in os.walk(cfg.DATASET.PATH):
|
||
for file in sorted(files): # 按文件名排序处理
|
||
# 可选的文件过滤器(当前被注释掉)
|
||
# 可以用于只处理特定的文件,如包含'baseball_swing'的文件
|
||
###if not 'baseball_swing' in file:
|
||
###continue
|
||
|
||
file_num += 1 # 文件计数器
|
||
logger.info(
|
||
'Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
|
||
|
||
# 加载并变换目标数据
|
||
# 1. load(): 根据数据集类型加载原始数据
|
||
# 2. transform(): 将数据转换为模型所需的格式
|
||
# 3. torch.from_numpy(): 将numpy数组转换为PyTorch张量
|
||
# 4. .float(): 确保数据类型为float32
|
||
target = torch.from_numpy(transform(args.dataset_name, load(args.dataset_name,
|
||
os.path.join(root, file)))).float()
|
||
logger.info("target shape:{}".format(target.shape))
|
||
|
||
# 执行SMPL模型拟合训练
|
||
# 传入SMPL层、目标数据、日志记录器、设备信息等
|
||
res = train(smpl_layer, target,
|
||
logger, writer, device,
|
||
args, cfg, meters)
|
||
|
||
# 更新平均损失指标
|
||
# k=target.shape[0] 表示批次大小,用于加权平均
|
||
meters.update_avg(meters.min_loss, k=target.shape[0])
|
||
|
||
# 重置早停计数器,为下一个文件的训练做准备
|
||
meters.reset_early_stop()
|
||
|
||
# 记录当前的平均损失
|
||
logger.info("avg_loss:{:.4f}".format(meters.avg))
|
||
|
||
# 保存拟合结果
|
||
# 1. save_params(): 保存拟合得到的SMPL参数
|
||
# 2. save_pic(): 保存可视化图像,包括拟合结果和原始目标的对比
|
||
save_params(res, file, logger, args.dataset_name)
|
||
save_pic(res, smpl_layer, file, logger, args.dataset_name, target)
|
||
|
||
# 清空GPU缓存,防止内存溢出
|
||
torch.cuda.empty_cache()
|
||
|
||
# 所有文件处理完成,记录最终的平均损失
|
||
logger.info(
|
||
"Fitting finished! Average loss: {:.9f}".format(meters.avg))
|