update code
@@ -1,5 +1,5 @@
-import numpy as np
 import torch
+import numpy as np
 from tensorboardX import SummaryWriter
 from easydict import EasyDict as edict
 import time
@@ -6,8 +6,8 @@ import numpy as np
 import pickle
 
 sys.path.append(os.getcwd())
-from display_utils import display_model
 from label import get_label
+from display_utils import display_model
 
 
 def create_dir_not_exist(path):
@@ -15,11 +15,11 @@ def create_dir_not_exist(path):
         os.mkdir(path)
 
 
-def save_pic(res, smpl_layer, file, logger, dataset_name,target):
+def save_pic(res, smpl_layer, file, logger, dataset_name, target):
     _, _, verts, Jtr = res
     file_name = re.split('[/.]', file)[-2]
-    fit_path = "fit/output/{}/picture/fit/{}".format(dataset_name,file_name)
-    gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name,file_name)
+    fit_path = "fit/output/{}/picture/fit/{}".format(dataset_name, file_name)
+    gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name, file_name)
     create_dir_not_exist(fit_path)
     create_dir_not_exist(gt_path)
     logger.info('Saving pictures at {}'.format(fit_path))
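A note on the file_name line above: re.split('[/.]', file)[-2] splits the input path on both '/' and '.', and the second-to-last token is the file stem. A standalone illustration (the path below is made up, not from the repository):

import re

file = "fit/data/Human36M/S1_Walking.npy"  # hypothetical input path
file_name = re.split('[/.]', file)[-2]     # tokens: fit, data, Human36M, S1_Walking, npy
print(file_name)                           # -> S1_Walking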
@@ -53,7 +53,7 @@ def save_params(res, file, logger, dataset_name):
     fit_path = "fit/output/{}/".format(dataset_name)
     create_dir_not_exist(fit_path)
     logger.info('Saving params at {}'.format(fit_path))
-    label=get_label(file_name, dataset_name)
+    label = get_label(file_name, dataset_name)
     pose_params = (pose_params.cpu().detach()).numpy().tolist()
     shape_params = (shape_params.cpu().detach()).numpy().tolist()
     Jtr = (Jtr.cpu().detach()).numpy().tolist()
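The three conversion lines above follow the usual PyTorch escape hatch: detach from the autograd graph and move to CPU before .numpy(). A minimal sketch with a dummy tensor in place of the fitted params:

import torch

pose_params = torch.rand(1, 72, requires_grad=True)  # dummy stand-in for the fitted pose
# .numpy() raises on a tensor that requires grad, hence the detach();
# .cpu() is needed when the tensor lives on a CUDA device.
as_list = pose_params.cpu().detach().numpy().tolist()
print(len(as_list), len(as_list[0]))  # 1 72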
@@ -64,5 +64,5 @@ def save_params(res, file, logger, dataset_name):
     params["shape_params"] = shape_params
     params["Jtr"] = Jtr
     with open(os.path.join((fit_path),
               "{}_params.pkl".format(file_name)), 'wb') as f:
         pickle.dump(params, f)
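For completeness, reading the pickle written above back in might look like this (the dataset_name and file_name are hypothetical, purely for illustration):

import os
import pickle

fit_path = "fit/output/Human36M/"  # hypothetical dataset_name
with open(os.path.join(fit_path, "S1_Walking_params.pkl"), 'rb') as f:
    params = pickle.load(f)
print(sorted(params.keys()))  # expect Jtr, pose_params, shape_params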
@@ -7,21 +7,21 @@ import os
 from tqdm import tqdm
 sys.path.append(os.getcwd())
 
 
 class Early_Stop:
-    def __init__(self, eps = -1e-3, stop_threshold = 10) -> None:
-        self.min_loss=float('inf')
-        self.eps=eps
-        self.stop_threshold=stop_threshold
-        self.satis_num=0
+    def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
+        self.min_loss = float('inf')
+        self.eps = eps
+        self.stop_threshold = stop_threshold
+        self.satis_num = 0
 
     def update(self, loss):
         delta = (loss - self.min_loss) / self.min_loss
         if float(loss) < self.min_loss:
             self.min_loss = float(loss)
-            update_res=True
+            update_res = True
         else:
-            update_res=False
+            update_res = False
         if delta >= self.eps:
             self.satis_num += 1
         else:
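The hunk cuts Early_Stop off after the second else:, so the tail of update() is not shown. Below is a minimal self-contained sketch of the visible logic; the last two lines of update() are an assumed completion (reset the counter on real progress, stop after stop_threshold consecutive stalled epochs), not the repository's code:

class Early_Stop:
    def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
        self.min_loss = float('inf')
        self.eps = eps
        self.stop_threshold = stop_threshold
        self.satis_num = 0

    def update(self, loss):
        # relative change vs. the best loss so far; delta >= eps (-1e-3)
        # means this epoch improved the loss by less than 0.1%
        delta = (loss - self.min_loss) / self.min_loss
        if float(loss) < self.min_loss:
            self.min_loss = float(loss)
            update_res = True
        else:
            update_res = False
        if delta >= self.eps:
            self.satis_num += 1
        else:
            self.satis_num = 0  # assumed: this branch is not in the diff
        return update_res, self.satis_num >= self.stop_threshold  # assumed return

stopper = Early_Stop(stop_threshold=3)
for loss in [10.0, 5.0, 4.9995, 4.9994, 4.9993]:
    improved, stop = stopper.update(loss)
print(stop)  # True: three epochs in a row improved by less than 0.1%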
@@ -30,71 +30,71 @@ class Early_Stop:
 
 
 def init(smpl_layer, target, device, cfg):
-    params={}
+    params = {}
     params["pose_params"] = torch.rand(target.shape[0], 72) * 0.0
     params["shape_params"] = torch.rand(target.shape[0], 10) * 0.03
     params["scale"] = torch.ones([1])
 
     smpl_layer = smpl_layer.to(device)
     params["pose_params"] = params["pose_params"].to(device)
     params["shape_params"] = params["shape_params"].to(device)
     target = target.to(device)
     params["scale"] = params["scale"].to(device)
 
     params["pose_params"].requires_grad = True
     params["shape_params"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SHAPE)
     params["scale"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SCALE)
 
     optimizer = optim.Adam([params["pose_params"], params["shape_params"], params["scale"]],
                            lr=cfg.TRAIN.LEARNING_RATE)
 
-    index={}
-    smpl_index=[]
-    dataset_index=[]
+    index = {}
+    smpl_index = []
+    dataset_index = []
     for tp in cfg.DATASET.DATA_MAP:
         smpl_index.append(tp[0])
         dataset_index.append(tp[1])
 
-    index["smpl_index"]=torch.tensor(smpl_index).to(device)
-    index["dataset_index"]=torch.tensor(dataset_index).to(device)
+    index["smpl_index"] = torch.tensor(smpl_index).to(device)
+    index["dataset_index"] = torch.tensor(dataset_index).to(device)
 
-    return smpl_layer, params,target, optimizer, index
+    return smpl_layer, params, target, optimizer, index
 
 
 def train(smpl_layer, target,
           logger, writer, device,
           args, cfg):
     res = []
-    smpl_layer, params,target, optimizer, index = \
+    smpl_layer, params, target, optimizer, index = \
         init(smpl_layer, target, device, cfg)
     pose_params = params["pose_params"]
     shape_params = params["shape_params"]
     scale = params["scale"]
 
     early_stop = Early_Stop()
     for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
     # for epoch in range(cfg.TRAIN.MAX_EPOCH):
         verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
         loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100 * scale,
                                 target.index_select(1, index["dataset_index"]) * 100)
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
 
         update_res, stop = early_stop.update(float(loss))
         if update_res:
             res = [pose_params, shape_params, verts, Jtr]
         if stop:
             logger.info("Early stop at epoch {} !".format(epoch))
             break
 
         if epoch % cfg.TRAIN.WRITE == 0:
             # logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f} EarlyStopSatis: {}".format(
             #     epoch, float(loss),float(scale), early_stop.satis_num))
             writer.add_scalar('loss', float(loss), epoch)
             writer.add_scalar('learning_rate', float(
                 optimizer.state_dict()['param_groups'][0]['lr']), epoch)
 
-    logger.info('Train ended, min_loss = {:.9f}'.format(float(early_stop.min_loss)))
+    logger.info('Train ended, min_loss = {:.9f}'.format(
+        float(early_stop.min_loss)))
     return res
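The loss in train() only compares mapped joint pairs: cfg.DATASET.DATA_MAP is a list of (smpl_index, dataset_index) pairs, and index_select(1, ...) pulls those joints out of the (batch, joints, 3) tensors. A dummy-shape illustration (the two-pair map and the 17-joint target below are made up):

import torch
import torch.nn.functional as F

data_map = [(1, 0), (4, 2)]  # hypothetical stand-in for cfg.DATASET.DATA_MAP
smpl_index = torch.tensor([tp[0] for tp in data_map])
dataset_index = torch.tensor([tp[1] for tp in data_map])

Jtr = torch.rand(2, 24, 3)     # SMPL regresses 24 joints per frame
target = torch.rand(2, 17, 3)  # e.g. a 17-joint dataset skeleton

# the * 100 mirrors the training code; plausibly a meters-to-centimeters
# rescale so residuals leave smooth L1's quadratic region near zero
loss = F.smooth_l1_loss(Jtr.index_select(1, smpl_index) * 100,
                        target.index_select(1, dataset_index) * 100)
print(float(loss))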
@@ -14,4 +14,4 @@ def transform(name, arr: np.ndarray):
             arr[i][j] -= origin
             for k in range(3):
                 arr[i][j][k] *= rotate[name][k]
     return arr
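This last hunk is the inner loop of transform(); rotate and origin are defined above the hunk and not shown here. Assuming rotate[name] is a length-3 per-axis sign vector, as the elementwise multiply implies, the loop is equivalent to a single NumPy broadcast:

import numpy as np

rotate = {"demo": [1, -1, 1]}   # hypothetical rotate entry
arr = np.random.rand(2, 17, 3)  # (frames, joints, xyz); shapes illustrative
origin = arr[:, :1, :].copy()   # stand-in for the origin joint used above
arr = (arr - origin) * np.asarray(rotate["demo"])
print(arr.shape)                # (2, 17, 3)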