initial commit

This commit is contained in:
Iridoudou
2021-08-05 11:59:22 +08:00
parent a18b0197d8
commit 791f02f280
10 changed files with 275 additions and 6 deletions

66
fit/tools/train.py Normal file
View File

@ -0,0 +1,66 @@
import matplotlib as plt
from matplotlib.pyplot import show
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import module
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torchvision.transforms as T
import numpy as np
from tensorboardX import SummaryWriter
from easydict import EasyDict as edict
import time
import inspect
import sys
import os
import logging
import argparse
import json
from tqdm import tqdm
sys.path.append(os.getcwd())
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
from display_utils import display_model
def train(smpl_layer, target,
logger, writer, device,
args, cfg):
res = []
pose_params = torch.rand(target.shape[0], 72) * 0.0
shape_params = torch.rand(target.shape[0], 10) * 0.1
scale = torch.ones([1])
smpl_layer = smpl_layer.to(device)
pose_params = pose_params.to(device)
shape_params = shape_params.to(device)
target = target.to(device)
scale = scale.to(device)
pose_params.requires_grad = True
shape_params.requires_grad = True
scale.requires_grad = True
optimizer = optim.Adam([pose_params],
lr=cfg.TRAIN.LEARNING_RATE)
min_loss = float('inf')
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
loss = F.smooth_l1_loss(Jtr * 100, target * 100)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if float(loss) < min_loss:
min_loss = float(loss)
res = [pose_params, shape_params, verts, Jtr]
if epoch % cfg.TRAIN.WRITE == 0:
# logger.info("Epoch {}, lossPerBatch={:.9f}, scale={:.6f}".format(
# epoch, float(loss), float(scale)))
writer.add_scalar('loss', float(loss), epoch)
writer.add_scalar('learning_rate', float(
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
logger.info('Train ended, loss = {:.9f}'.format(float(loss)))
return res