update code

This commit is contained in:
Iridoudou
2021-08-07 08:03:40 +08:00
parent 0152f1909f
commit 9f6274fc19
8 changed files with 160 additions and 65 deletions

View File

@ -10,6 +10,7 @@ from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as T
import numpy as np
import scipy.io
from tensorboardX import SummaryWriter
from easydict import EasyDict as edict
import time
@ -24,6 +25,7 @@ from tqdm import tqdm
sys.path.append(os.getcwd())
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from display_utils import display_model
from map import mapping
def train(smpl_layer, target,
logger, writer, device,
@ -48,9 +50,11 @@ def train(smpl_layer, target,
lr=cfg.TRAIN.LEARNING_RATE)
min_loss = float('inf')
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
data_map=torch.tensor(cfg.DATASET.DATA_MAP.UTD_MHAD)[0].to(device)
# for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
for epoch in range(cfg.TRAIN.MAX_EPOCH):
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
loss = F.smooth_l1_loss(Jtr * 100, target * 100)
loss = F.smooth_l1_loss(Jtr.index_select(1, data_map) * 100, target * 100)
optimizer.zero_grad()
loss.backward()
optimizer.step()
@ -58,10 +62,22 @@ def train(smpl_layer, target,
min_loss = float(loss)
res = [pose_params, shape_params, verts, Jtr]
if epoch % cfg.TRAIN.WRITE == 0:
# logger.info("Epoch {}, lossPerBatch={:.9f}, scale={:.6f}".format(
# epoch, float(loss), float(scale)))
logger.info("Epoch {}, lossPerBatch={:.9f}, scale={:.6f}".format(
epoch, float(loss), float(scale)))
writer.add_scalar('loss', float(loss), epoch)
writer.add_scalar('learning_rate', float(
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
if epoch % cfg.TRAIN.SAVE == 0 and epoch > 0:
for i in tqdm(range(Jtr.shape[0])):
display_model(
{'verts': verts.cpu().detach(),
'joints': Jtr.cpu().detach()},
model_faces=smpl_layer.th_faces,
with_joints=True,
kintree_table=smpl_layer.kintree_table,
savepath="fit/output/UTD_MHAD/picture/frame_{}".format(str(i).zfill(4)),
batch_idx=i,
show=True,
only_joint=True)
logger.info('Train ended, min_loss = {:.9f}'.format(float(min_loss)))
return res