diff --git a/fit/tools/main.py b/fit/tools/main.py
index 095f5d3..fe2634f 100644
--- a/fit/tools/main.py
+++ b/fit/tools/main.py
@@ -1,5 +1,5 @@
-import numpy as np
 import torch
+import numpy as np
 from tensorboardX import SummaryWriter
 from easydict import EasyDict as edict
 import time
diff --git a/fit/tools/save.py b/fit/tools/save.py
index e8fd9fe..8f7a401 100644
--- a/fit/tools/save.py
+++ b/fit/tools/save.py
@@ -6,8 +6,8 @@ import numpy as np
 import pickle
 
 sys.path.append(os.getcwd())
-from display_utils import display_model
 from label import get_label
+from display_utils import display_model
 
 
 def create_dir_not_exist(path):
@@ -15,11 +15,11 @@ def create_dir_not_exist(path):
         os.mkdir(path)
 
 
-def save_pic(res, smpl_layer, file, logger, dataset_name,target):
+def save_pic(res, smpl_layer, file, logger, dataset_name, target):
     _, _, verts, Jtr = res
     file_name = re.split('[/.]', file)[-2]
-    fit_path = "fit/output/{}/picture/fit/{}".format(dataset_name,file_name)
-    gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name,file_name)
+    fit_path = "fit/output/{}/picture/fit/{}".format(dataset_name, file_name)
+    gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name, file_name)
     create_dir_not_exist(fit_path)
     create_dir_not_exist(gt_path)
     logger.info('Saving pictures at {}'.format(fit_path))
@@ -53,7 +53,7 @@ def save_params(res, file, logger, dataset_name):
     fit_path = "fit/output/{}/".format(dataset_name)
     create_dir_not_exist(fit_path)
     logger.info('Saving params at {}'.format(fit_path))
-    label=get_label(file_name, dataset_name)
+    label = get_label(file_name, dataset_name)
    pose_params = (pose_params.cpu().detach()).numpy().tolist()
     shape_params = (shape_params.cpu().detach()).numpy().tolist()
     Jtr = (Jtr.cpu().detach()).numpy().tolist()
@@ -64,5 +64,5 @@ def save_params(res, file, logger, dataset_name):
     params["shape_params"] = shape_params
     params["Jtr"] = Jtr
     with open(os.path.join((fit_path),
-                       "{}_params.pkl".format(file_name)), 'wb') as f:
+                           "{}_params.pkl".format(file_name)), 'wb') as f:
         pickle.dump(params, f)
diff --git a/fit/tools/train.py b/fit/tools/train.py
index 954b08f..28852af 100644
--- a/fit/tools/train.py
+++ b/fit/tools/train.py
@@ -7,21 +7,21 @@ import os
 from tqdm import tqdm
 
 sys.path.append(os.getcwd())
-    
+

 class Early_Stop:
-    def __init__(self, eps = -1e-3, stop_threshold = 10) -> None:
-        self.min_loss=float('inf')
-        self.eps=eps
-        self.stop_threshold=stop_threshold
-        self.satis_num=0
-    
+    def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
+        self.min_loss = float('inf')
+        self.eps = eps
+        self.stop_threshold = stop_threshold
+        self.satis_num = 0
+
     def update(self, loss):
         delta = (loss - self.min_loss) / self.min_loss
         if float(loss) < self.min_loss:
             self.min_loss = float(loss)
-            update_res=True
+            update_res = True
         else:
-            update_res=False
+            update_res = False
         if delta >= self.eps:
             self.satis_num += 1
         else:
@@ -30,71 +30,71 @@ class Early_Stop:
 
 
 def init(smpl_layer, target, device, cfg):
-    params={}
+    params = {}
     params["pose_params"] = torch.rand(target.shape[0], 72) * 0.0
     params["shape_params"] = torch.rand(target.shape[0], 10) * 0.03
     params["scale"] = torch.ones([1])
-    
+
     smpl_layer = smpl_layer.to(device)
     params["pose_params"] = params["pose_params"].to(device)
     params["shape_params"] = params["shape_params"].to(device)
     target = target.to(device)
     params["scale"] = params["scale"].to(device)
-    
+
     params["pose_params"].requires_grad = True
     params["shape_params"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SHAPE)
     params["scale"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SCALE)
-    
+
     optimizer = optim.Adam([params["pose_params"], params["shape_params"], params["scale"]],
                            lr=cfg.TRAIN.LEARNING_RATE)
-    
-    index={}
-    smpl_index=[]
-    dataset_index=[]
+
+    index = {}
+    smpl_index = []
+    dataset_index = []
     for tp in cfg.DATASET.DATA_MAP:
         smpl_index.append(tp[0])
         dataset_index.append(tp[1])
-    
-    index["smpl_index"]=torch.tensor(smpl_index).to(device)
-    index["dataset_index"]=torch.tensor(dataset_index).to(device)
-    
-    return smpl_layer, params,target, optimizer, index
+
+    index["smpl_index"] = torch.tensor(smpl_index).to(device)
+    index["dataset_index"] = torch.tensor(dataset_index).to(device)
+
+    return smpl_layer, params, target, optimizer, index
 
 
 def train(smpl_layer, target, logger, writer, device, args, cfg):
     res = []
-    smpl_layer, params,target, optimizer, index = \
+    smpl_layer, params, target, optimizer, index = \
         init(smpl_layer, target, device, cfg)
     pose_params = params["pose_params"]
     shape_params = params["shape_params"]
     scale = params["scale"]
-    
+
     early_stop = Early_Stop()
     for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
-    # for epoch in range(cfg.TRAIN.MAX_EPOCH):
+        # for epoch in range(cfg.TRAIN.MAX_EPOCH):
         verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
         loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100 * scale,
                                 target.index_select(1, index["dataset_index"]) * 100)
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
-        
+
         update_res, stop = early_stop.update(float(loss))
         if update_res:
             res = [pose_params, shape_params, verts, Jtr]
         if stop:
             logger.info("Early stop at epoch {} !".format(epoch))
             break
-        
+
         if epoch % cfg.TRAIN.WRITE == 0:
             # logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f} EarlyStopSatis: {}".format(
             #     epoch, float(loss),float(scale), early_stop.satis_num))
             writer.add_scalar('loss', float(loss), epoch)
             writer.add_scalar('learning_rate', float(
                 optimizer.state_dict()['param_groups'][0]['lr']), epoch)
-
-
-    logger.info('Train ended, min_loss = {:.9f}'.format(float(early_stop.min_loss)))
+
+    logger.info('Train ended, min_loss = {:.9f}'.format(
+        float(early_stop.min_loss)))
     return res
diff --git a/fit/tools/transform.py b/fit/tools/transform.py
index 7a94078..9530004 100644
--- a/fit/tools/transform.py
+++ b/fit/tools/transform.py
@@ -14,4 +14,4 @@ def transform(name, arr: np.ndarray):
             arr[i][j] -= origin
             for k in range(3):
                 arr[i][j][k] *= rotate[name][k]
-    return arr
+    return arr
\ No newline at end of file
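
Note, not part of the patch: save_params() serializes plain Python lists (the tensors are detached and converted via .numpy().tolist() above), so the resulting pickle can be read back without torch. A minimal reader sketch, assuming the output path format assembled in save_params(); the dataset name "DATASET" and file name "seq01" are placeholders:

import pickle

# The path follows "fit/output/{dataset_name}/{file_name}_params.pkl" as
# built in save_params(); substitute real names for the placeholders.
with open("fit/output/DATASET/seq01_params.pkl", "rb") as f:
    params = pickle.load(f)

# Only "shape_params" and "Jtr" are visible in the hunk above; the dict is
# assembled before that hunk starts, so it likely carries further keys
# (e.g. the pose parameters and the label fetched via get_label()).
shape_params = params["shape_params"]  # nested lists, one 10-dim betas row per frame
Jtr = params["Jtr"]                    # nested lists of SMPL joint coordinates
print(len(shape_params), len(Jtr))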
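The early-stopping rule in fit/tools/train.py tracks the minimum loss seen so far and, on each update, computes the relative change delta = (loss - min_loss) / min_loss; an epoch whose relative improvement is no better than eps = -1e-3 (under 0.1%) increments a plateau counter, and train() stops once update() reports the counter reached stop_threshold. A behavioural sketch, assuming the repo root as working directory and that update() returns the (update_res, stop) pair unpacked in train() above (the tail of update() falls outside the hunks shown):

import os
import sys

# train.py lives in fit/tools and itself appends os.getcwd() to sys.path,
# so mirror that layout here (an assumption about how the repo is run).
sys.path.append(os.path.join(os.getcwd(), "fit", "tools"))
from train import Early_Stop

stopper = Early_Stop(eps=-1e-3, stop_threshold=10)
# Three sharp improvements keep resetting the plateau counter; the
# near-constant losses that follow (relative change >= -1e-3) accumulate
# until stop_threshold consecutive plateau epochs trigger the stop signal.
losses = [1.0, 0.5, 0.25] + [0.2499] * 12
for epoch, loss in enumerate(losses):
    update_res, stop = stopper.update(loss)
    if stop:
        print("would stop at epoch {}".format(epoch))
        break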