update code

This commit is contained in:
Iridoudou
2021-08-14 16:24:48 +08:00
parent f9c327d2a2
commit a5f0cf4653
4 changed files with 38 additions and 38 deletions

View File

@ -1,5 +1,5 @@
import numpy as np
import torch import torch
import numpy as np
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from easydict import EasyDict as edict from easydict import EasyDict as edict
import time import time

View File

@ -6,8 +6,8 @@ import numpy as np
import pickle import pickle
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from display_utils import display_model
from label import get_label from label import get_label
from display_utils import display_model
def create_dir_not_exist(path): def create_dir_not_exist(path):
@ -15,11 +15,11 @@ def create_dir_not_exist(path):
os.mkdir(path) os.mkdir(path)
def save_pic(res, smpl_layer, file, logger, dataset_name,target): def save_pic(res, smpl_layer, file, logger, dataset_name, target):
_, _, verts, Jtr = res _, _, verts, Jtr = res
file_name = re.split('[/.]', file)[-2] file_name = re.split('[/.]', file)[-2]
fit_path = "fit/output/{}/picture/fit/{}".format(dataset_name,file_name) fit_path = "fit/output/{}/picture/fit/{}".format(dataset_name, file_name)
gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name,file_name) gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name, file_name)
create_dir_not_exist(fit_path) create_dir_not_exist(fit_path)
create_dir_not_exist(gt_path) create_dir_not_exist(gt_path)
logger.info('Saving pictures at {}'.format(fit_path)) logger.info('Saving pictures at {}'.format(fit_path))
@ -53,7 +53,7 @@ def save_params(res, file, logger, dataset_name):
fit_path = "fit/output/{}/".format(dataset_name) fit_path = "fit/output/{}/".format(dataset_name)
create_dir_not_exist(fit_path) create_dir_not_exist(fit_path)
logger.info('Saving params at {}'.format(fit_path)) logger.info('Saving params at {}'.format(fit_path))
label=get_label(file_name, dataset_name) label = get_label(file_name, dataset_name)
pose_params = (pose_params.cpu().detach()).numpy().tolist() pose_params = (pose_params.cpu().detach()).numpy().tolist()
shape_params = (shape_params.cpu().detach()).numpy().tolist() shape_params = (shape_params.cpu().detach()).numpy().tolist()
Jtr = (Jtr.cpu().detach()).numpy().tolist() Jtr = (Jtr.cpu().detach()).numpy().tolist()
@ -64,5 +64,5 @@ def save_params(res, file, logger, dataset_name):
params["shape_params"] = shape_params params["shape_params"] = shape_params
params["Jtr"] = Jtr params["Jtr"] = Jtr
with open(os.path.join((fit_path), with open(os.path.join((fit_path),
"{}_params.pkl".format(file_name)), 'wb') as f: "{}_params.pkl".format(file_name)), 'wb') as f:
pickle.dump(params, f) pickle.dump(params, f)

View File

@ -7,21 +7,21 @@ import os
from tqdm import tqdm from tqdm import tqdm
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
class Early_Stop: class Early_Stop:
def __init__(self, eps = -1e-3, stop_threshold = 10) -> None: def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
self.min_loss=float('inf') self.min_loss = float('inf')
self.eps=eps self.eps = eps
self.stop_threshold=stop_threshold self.stop_threshold = stop_threshold
self.satis_num=0 self.satis_num = 0
def update(self, loss): def update(self, loss):
delta = (loss - self.min_loss) / self.min_loss delta = (loss - self.min_loss) / self.min_loss
if float(loss) < self.min_loss: if float(loss) < self.min_loss:
self.min_loss = float(loss) self.min_loss = float(loss)
update_res=True update_res = True
else: else:
update_res=False update_res = False
if delta >= self.eps: if delta >= self.eps:
self.satis_num += 1 self.satis_num += 1
else: else:
@ -30,71 +30,71 @@ class Early_Stop:
def init(smpl_layer, target, device, cfg): def init(smpl_layer, target, device, cfg):
params={} params = {}
params["pose_params"] = torch.rand(target.shape[0], 72) * 0.0 params["pose_params"] = torch.rand(target.shape[0], 72) * 0.0
params["shape_params"] = torch.rand(target.shape[0], 10) * 0.03 params["shape_params"] = torch.rand(target.shape[0], 10) * 0.03
params["scale"] = torch.ones([1]) params["scale"] = torch.ones([1])
smpl_layer = smpl_layer.to(device) smpl_layer = smpl_layer.to(device)
params["pose_params"] = params["pose_params"].to(device) params["pose_params"] = params["pose_params"].to(device)
params["shape_params"] = params["shape_params"].to(device) params["shape_params"] = params["shape_params"].to(device)
target = target.to(device) target = target.to(device)
params["scale"] = params["scale"].to(device) params["scale"] = params["scale"].to(device)
params["pose_params"].requires_grad = True params["pose_params"].requires_grad = True
params["shape_params"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SHAPE) params["shape_params"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SHAPE)
params["scale"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SCALE) params["scale"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SCALE)
optimizer = optim.Adam([params["pose_params"], params["shape_params"], params["scale"]], optimizer = optim.Adam([params["pose_params"], params["shape_params"], params["scale"]],
lr=cfg.TRAIN.LEARNING_RATE) lr=cfg.TRAIN.LEARNING_RATE)
index={} index = {}
smpl_index=[] smpl_index = []
dataset_index=[] dataset_index = []
for tp in cfg.DATASET.DATA_MAP: for tp in cfg.DATASET.DATA_MAP:
smpl_index.append(tp[0]) smpl_index.append(tp[0])
dataset_index.append(tp[1]) dataset_index.append(tp[1])
index["smpl_index"]=torch.tensor(smpl_index).to(device) index["smpl_index"] = torch.tensor(smpl_index).to(device)
index["dataset_index"]=torch.tensor(dataset_index).to(device) index["dataset_index"] = torch.tensor(dataset_index).to(device)
return smpl_layer, params,target, optimizer, index return smpl_layer, params, target, optimizer, index
def train(smpl_layer, target, def train(smpl_layer, target,
logger, writer, device, logger, writer, device,
args, cfg): args, cfg):
res = [] res = []
smpl_layer, params,target, optimizer, index = \ smpl_layer, params, target, optimizer, index = \
init(smpl_layer, target, device, cfg) init(smpl_layer, target, device, cfg)
pose_params = params["pose_params"] pose_params = params["pose_params"]
shape_params = params["shape_params"] shape_params = params["shape_params"]
scale = params["scale"] scale = params["scale"]
early_stop = Early_Stop() early_stop = Early_Stop()
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)): for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
# for epoch in range(cfg.TRAIN.MAX_EPOCH): # for epoch in range(cfg.TRAIN.MAX_EPOCH):
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params) verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100 * scale, loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100 * scale,
target.index_select(1, index["dataset_index"]) * 100) target.index_select(1, index["dataset_index"]) * 100)
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
update_res, stop = early_stop.update(float(loss)) update_res, stop = early_stop.update(float(loss))
if update_res: if update_res:
res = [pose_params, shape_params, verts, Jtr] res = [pose_params, shape_params, verts, Jtr]
if stop: if stop:
logger.info("Early stop at epoch {} !".format(epoch)) logger.info("Early stop at epoch {} !".format(epoch))
break break
if epoch % cfg.TRAIN.WRITE == 0: if epoch % cfg.TRAIN.WRITE == 0:
# logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f} EarlyStopSatis: {}".format( # logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f} EarlyStopSatis: {}".format(
# epoch, float(loss),float(scale), early_stop.satis_num)) # epoch, float(loss),float(scale), early_stop.satis_num))
writer.add_scalar('loss', float(loss), epoch) writer.add_scalar('loss', float(loss), epoch)
writer.add_scalar('learning_rate', float( writer.add_scalar('learning_rate', float(
optimizer.state_dict()['param_groups'][0]['lr']), epoch) optimizer.state_dict()['param_groups'][0]['lr']), epoch)
logger.info('Train ended, min_loss = {:.9f}'.format(
logger.info('Train ended, min_loss = {:.9f}'.format(float(early_stop.min_loss))) float(early_stop.min_loss)))
return res return res

View File

@ -14,4 +14,4 @@ def transform(name, arr: np.ndarray):
arr[i][j] -= origin arr[i][j] -= origin
for k in range(3): for k in range(3):
arr[i][j][k] *= rotate[name][k] arr[i][j][k] *= rotate[name][k]
return arr return arr