""" SMPL模型拟合主程序 该脚本用于将人体姿态数据拟合到SMPL(Skinned Multi-Person Linear)模型中 主要功能: 1. 加载人体姿态数据 2. 使用SMPL模型进行拟合优化 3. 保存拟合结果和可视化图像 """ import os import sys sys.path.append(os.getcwd()) # 导入自定义模块 from meters import Meters # 用于跟踪训练指标的工具类 from smplpytorch.pytorch.smpl_layer import SMPL_Layer # SMPL模型层 from train import train # 训练函数 from transform import transform # 数据变换函数 from save import save_pic, save_params # 保存结果的函数 from load import load # 数据加载函数 # 导入标准库 import torch # PyTorch深度学习框架 import numpy as np # 数值计算库 from easydict import EasyDict as edict # 用于创建字典对象的便捷工具 import time # 时间处理 import logging # 日志记录 import argparse # 命令行参数解析 import json # JSON文件处理 # 启用CUDNN加速,提高训练效率 torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True def parse_args(): """ 解析命令行参数 Returns: args: 包含解析后参数的命名空间对象 """ parser = argparse.ArgumentParser(description='Fit SMPL') # 实验名称,默认使用当前时间戳 parser.add_argument('--exp', dest='exp', help='Define exp name', default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str) # 数据集名称,用于选择对应的配置文件 parser.add_argument('--dataset_name', '-n', dest='dataset_name', help='select dataset', default='', type=str) # 数据集路径,可以覆盖配置文件中的默认路径 parser.add_argument('--dataset_path', dest='dataset_path', help='path of dataset', default=None, type=str) args = parser.parse_args() return args def get_config(args): """ 根据数据集名称加载对应的配置文件 Args: args: 命令行参数对象 Returns: cfg: 配置对象,包含所有训练和模型参数 """ # 根据数据集名称构建配置文件路径 config_path = 'fit/configs/{}.json'.format(args.dataset_name) # 读取JSON配置文件 with open(config_path, 'r') as f: data = json.load(f) # 将字典转换为edict对象,支持点号访问属性 cfg = edict(data.copy()) # 如果命令行指定了数据集路径,则覆盖配置文件中的设置 if not args.dataset_path == None: cfg.DATASET.PATH = args.dataset_path return cfg def set_device(USE_GPU): """ 根据配置和硬件可用性设置计算设备 Args: USE_GPU: 是否使用GPU的布尔值 Returns: device: PyTorch设备对象('cuda' 或 'cpu') """ if USE_GPU and torch.cuda.is_available(): device = torch.device('cuda') else: device = torch.device('cpu') return device def get_logger(cur_path): """ 设置日志记录器,同时输出到文件和控制台 Args: cur_path: 当前实验路径,用于保存日志文件 Returns: logger: 日志记录器对象 writer: TensorBoard写入器(当前设置为None) """ # 创建日志记录器 logger = logging.getLogger(__name__) logger.setLevel(level=logging.INFO) # 设置文件输出处理器,将日志保存到log.txt文件 handler = logging.FileHandler(os.path.join(cur_path, "log.txt")) handler.setLevel(logging.INFO) formatter = logging.Formatter( '%(asctime)s - %(name)s - %(levelname)s - %(message)s') handler.setFormatter(formatter) logger.addHandler(handler) # 设置控制台输出处理器,将日志同时输出到终端 handler = logging.StreamHandler() handler.setLevel(logging.INFO) formatter = logging.Formatter( '%(asctime)s - %(name)s - %(levelname)s - %(message)s') handler.setFormatter(formatter) logger.addHandler(handler) # TensorBoard写入器(目前被注释掉,设置为None) # from tensorboardX import SummaryWriter # writer = SummaryWriter(os.path.join(cur_path, 'tb')) writer = None return logger, writer if __name__ == "__main__": """ 主函数:执行SMPL模型拟合流程 主要步骤: 1. 解析命令行参数 2. 创建实验目录 3. 加载配置文件 4. 设置日志记录 5. 初始化SMPL模型 6. 遍历数据集进行拟合 7. 保存结果 """ # 解析命令行参数 args = parse_args() # 创建实验目录,使用时间戳或用户指定的实验名称 cur_path = os.path.join(os.getcwd(), 'exp', args.exp) assert not os.path.exists(cur_path), 'Duplicate exp name' # 确保实验名称不重复 os.mkdir(cur_path) # 加载配置文件 cfg = get_config(args) # 将配置保存到实验目录中,便于后续追踪实验设置 json.dump(dict(cfg), open(os.path.join(cur_path, 'config.json'), 'w')) # 设置日志记录器 logger, writer = get_logger(cur_path) logger.info("Start print log") # 设置计算设备(GPU或CPU) device = set_device(USE_GPU=cfg.USE_GPU) logger.info('using device: {}'.format(device)) # 初始化SMPL模型层 # center_idx=0: 设置中心关节点索引 # gender='male': 设置性别为男性(注释掉的cfg.MODEL.GENDER可能用于从配置文件读取) # model_root: SMPL模型文件的路径 smpl_layer = SMPL_Layer( center_idx=0, gender='male', #cfg.MODEL.GENDER, model_root='smplpytorch/native/models') # 初始化指标记录器,用于跟踪训练损失等指标 meters = Meters() file_num = 0 # 遍历数据集目录中的所有文件 for root, dirs, files in os.walk(cfg.DATASET.PATH): for file in sorted(files): # 按文件名排序处理 # 可选的文件过滤器(当前被注释掉) # 可以用于只处理特定的文件,如包含'baseball_swing'的文件 ###if not 'baseball_swing' in file: ###continue file_num += 1 # 文件计数器 logger.info( 'Processing file: {} [{} / {}]'.format(file, file_num, len(files))) # 加载并变换目标数据 # 1. load(): 根据数据集类型加载原始数据 # 2. transform(): 将数据转换为模型所需的格式 # 3. torch.from_numpy(): 将numpy数组转换为PyTorch张量 # 4. .float(): 确保数据类型为float32 target = torch.from_numpy(transform(args.dataset_name, load(args.dataset_name, os.path.join(root, file)))).float() logger.info("target shape:{}".format(target.shape)) # 执行SMPL模型拟合训练 # 传入SMPL层、目标数据、日志记录器、设备信息等 res = train(smpl_layer, target, logger, writer, device, args, cfg, meters) # 更新平均损失指标 # k=target.shape[0] 表示批次大小,用于加权平均 meters.update_avg(meters.min_loss, k=target.shape[0]) # 重置早停计数器,为下一个文件的训练做准备 meters.reset_early_stop() # 记录当前的平均损失 logger.info("avg_loss:{:.4f}".format(meters.avg)) # 保存拟合结果 # 1. save_params(): 保存拟合得到的SMPL参数 # 2. save_pic(): 保存可视化图像,包括拟合结果和原始目标的对比 save_params(res, file, logger, args.dataset_name) save_pic(res, smpl_layer, file, logger, args.dataset_name, target) # 清空GPU缓存,防止内存溢出 torch.cuda.empty_cache() # 所有文件处理完成,记录最终的平均损失 logger.info( "Fitting finished! Average loss: {:.9f}".format(meters.avg))