Files
Pose_to_SMPL_an_230402/fit/tools/train.py
2021-08-07 21:19:21 +08:00

122 lines
4.0 KiB
Python

from fit.tools.save import save_pic
import matplotlib as plt
from matplotlib.pyplot import show
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import module
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as T
import numpy as np
import scipy.io
from tensorboardX import SummaryWriter
from easydict import EasyDict as edict
import time
import inspect
import sys
import os
import logging
import argparse
import json
from tqdm import tqdm
sys.path.append(os.getcwd())
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from display_utils import display_model
from map import mapping
class Early_Stop:
    """Track the best loss seen so far and signal when progress has stalled.

    Each call to :meth:`update` computes the relative change
    ``delta = (loss - min_loss) / min_loss`` against the best (minimum) loss
    recorded before this call.

    - If ``delta >= eps``, the stall counter ``satis_num`` is incremented.
    - Otherwise it is decremented, never going below zero.

    Once the counter reaches ``stop_threshold``, ``update`` reports that the
    caller should stop.

    Args:
        eps: Relative-change threshold. With the default ``-1e-3``, a step
            whose loss does not drop by more than 0.1% relative to the best
            loss counts as a stall.
        stop_threshold: Counter value at which a stop is signalled.
    """

    def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
        self.min_loss = float('inf')  # best (lowest) loss seen so far
        self.eps = eps
        self.stop_threshold = stop_threshold
        self.satis_num = 0  # net number of stalled updates

    def _relative_change(self, loss):
        """Return ``(loss - min_loss) / min_loss`` with its edge cases defined.

        - Before any loss is recorded, ``min_loss`` is infinite and the plain
          formula gives NaN. That case is treated as an unbounded improvement
          (``-inf``), which leads to the same decrement the NaN comparison
          produced.
        - When the best loss is exactly zero, the division is undefined. A
          further zero is reported as "no change" (``0.0``); anything else is
          reported as an unbounded change with the sign of ``loss``.
        """
        if self.min_loss == float('inf'):
            return float('-inf')
        if self.min_loss == 0:
            if loss == 0:
                return 0.0
            return float('inf') if loss > 0 else float('-inf')
        return (loss - self.min_loss) / self.min_loss

    def update(self, loss):
        """Record ``loss`` and report whether it improved and whether to stop.

        Args:
            loss: Scalar loss for this step; anything accepted by ``float()``.

        Returns:
            tuple[bool, bool]: ``(update_res, stop)``.
            ``update_res`` is True when ``loss`` is a new strict minimum.
            ``stop`` is True once ``satis_num`` has reached ``stop_threshold``.
        """
        loss = float(loss)
        # Compare against the best loss *before* possibly replacing it.
        delta = self._relative_change(loss)

        update_res = loss < self.min_loss
        if update_res:
            self.min_loss = loss

        if delta >= self.eps:
            self.satis_num += 1
        else:
            self.satis_num = max(0, self.satis_num - 1)
        return update_res, self.satis_num >= self.stop_threshold
def init(smpl_layer, target, device, cfg):
    """Prepare the layer, fit parameters, target, optimizer and joint indices.

    Args:
        smpl_layer: Module moved to ``device`` and returned.
        target: Target tensor moved to ``device``. Only ``target.shape[0]``
            (used as the batch size) is read here.
        device: Torch device (or device string) for all tensors.
        cfg: Config exposing ``cfg.TRAIN.LEARNING_RATE`` and
            ``cfg.DATASET.DATA_MAP``. ``DATA_MAP`` is an iterable of pairs
            whose item 0 is an SMPL joint index and item 1 the matching
            dataset joint index.

    Returns:
        tuple: ``(smpl_layer, params, target, optimizer, index)``.

        - ``params["pose_params"]``: trainable zeros of shape (batch, 72).
        - ``params["shape_params"]``: trainable values of shape (batch, 10),
          drawn uniformly from [0, 0.03).
        - ``params["scale"]``: frozen ones of shape (1,).
        - ``optimizer``: Adam over the pose and shape parameters.
        - ``index``: int64 tensors ``"smpl_index"`` and ``"dataset_index"``,
          built from ``DATA_MAP``.
    """
    batch = target.shape[0]
    smpl_layer = smpl_layer.to(device)
    target = target.to(device)

    params = {
        "pose_params": torch.zeros(batch, 72, device=device),
        # Drawn on the CPU generator, then moved, matching the original seeding.
        "shape_params": (torch.rand(batch, 10) * 0.03).to(device),
        "scale": torch.ones([1], device=device),
    }
    params["pose_params"].requires_grad_(True)
    params["shape_params"].requires_grad_(True)
    params["scale"].requires_grad_(False)  # scale is not optimized

    optimizer = optim.Adam([params["pose_params"], params["shape_params"]],
                           lr=cfg.TRAIN.LEARNING_RATE)

    # Materialize once so a one-shot iterable is not exhausted by the first pass.
    data_map = list(cfg.DATASET.DATA_MAP)
    # index_select requires an integer index. An explicit long dtype keeps an
    # empty DATA_MAP from producing a float32 tensor.
    index = {
        "smpl_index": torch.tensor([pair[0] for pair in data_map],
                                   dtype=torch.long, device=device),
        "dataset_index": torch.tensor([pair[1] for pair in data_map],
                                      dtype=torch.long, device=device),
    }
    return smpl_layer, params, target, optimizer, index
def train(smpl_layer, target,
          logger, writer, device,
          args, cfg):
    """Fit SMPL pose/shape parameters so the layer's joints match ``target``.

    Runs up to ``cfg.TRAIN.MAX_EPOCH`` optimizer steps. The loss is smooth-L1
    between two selections, each scaled by 100:

    - the joints ``Jtr`` chosen by ``index["smpl_index"]``;
    - the ``target`` entries chosen by ``index["dataset_index"]``.

    Training stops early when ``Early_Stop`` reports a stall.

    Args:
        smpl_layer: Callable returning ``(verts, Jtr)`` for
            ``(pose_params, th_betas=shape_params)``.
        target: Target tensor; indexed along dim 1.
        logger: Object with an ``info`` method.
        writer: Object with an ``add_scalar(tag, value, step)`` method.
        device: Torch device the fit runs on.
        args: Not used in this function; kept for interface compatibility.
        cfg: Config with ``TRAIN.MAX_EPOCH``, ``TRAIN.WRITE``, and whatever
            ``init`` reads.

    Returns:
        list: ``[pose_params, shape_params, verts, Jtr]`` from the epoch with
        the lowest loss, all detached from the autograd graph. Returns ``[]``
        if no epoch produced a new minimum (e.g. ``MAX_EPOCH == 0``).
    """
    res = []
    smpl_layer, params, target, optimizer, index = \
        init(smpl_layer, target, device, cfg)
    pose_params = params["pose_params"]
    shape_params = params["shape_params"]
    early_stop = Early_Stop()
    for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
        verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
        loss = F.smooth_l1_loss(
            Jtr.index_select(1, index["smpl_index"]) * 100,
            target.index_select(1, index["dataset_index"]) * 100)
        loss_value = float(loss)

        # Snapshot before optimizer.step(): step() updates pose_params and
        # shape_params in place. Storing the live tensors would pair this
        # epoch's verts/Jtr with parameters from a later epoch.
        update_res, stop = early_stop.update(loss_value)
        if update_res:
            res = [pose_params.detach().clone(), shape_params.detach().clone(),
                   verts.detach(), Jtr.detach()]
        if stop:
            logger.info("Early stop at epoch {} !".format(epoch))
            break

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch % cfg.TRAIN.WRITE == 0:
            writer.add_scalar('loss', loss_value, epoch)
            writer.add_scalar('learning_rate',
                              float(optimizer.param_groups[0]['lr']), epoch)
    logger.info('Train ended, min_loss = {:.9f}'.format(float(early_stop.min_loss)))
    return res