From 8573a09b8470ba0b2f9030ed0577b55f07677c0a Mon Sep 17 00:00:00 2001 From: Iridoudou <2534936416@qq.com> Date: Thu, 19 Aug 2021 12:40:54 +0800 Subject: [PATCH] update meters --- README.md | 3 ++- fit/tools/cross_detector.py | 18 ++++++++++------ fit/tools/label.py | 5 +---- fit/tools/load.py | 10 +++++---- fit/tools/main.py | 17 +++++++++++---- fit/tools/meters.py | 27 +++++++++++++++++++++++ fit/tools/save.py | 8 +++---- fit/tools/train.py | 43 ++++++++++--------------------------- fit/tools/transform.py | 3 ++- make_gif.py | 6 +++--- 10 files changed, 81 insertions(+), 59 deletions(-) create mode 100644 fit/tools/meters.py diff --git a/README.md b/README.md index 41ed84c..593f4eb 100644 --- a/README.md +++ b/README.md @@ -40,11 +40,12 @@ The SMPL human body layer for Pytorch is from the [smplpytorch](https://github.c - Download the datasets you want to fit - currently supported datasets: + currently supported: - [HumanAct12](https://ericguo5513.github.io/action-to-motion/) - [CMU Mocap](https://ericguo5513.github.io/action-to-motion/) - [UTD-MHAD](https://personal.utdallas.edu/~kehtar/UTD-MHAD.html) + - [Human3.6M](http://vision.imar.ro/human3.6m/description.php) - Set the **DATASET.PATH** in the corresponding configuration file to the location of dataset. 
diff --git a/fit/tools/cross_detector.py b/fit/tools/cross_detector.py index b080144..3771a03 100644 --- a/fit/tools/cross_detector.py +++ b/fit/tools/cross_detector.py @@ -4,6 +4,7 @@ import os import json import argparse + def parse_args(): parser = argparse.ArgumentParser(description='Detect cross joints') parser.add_argument('--dataset_name', dest='dataset_name', @@ -15,10 +16,12 @@ def parse_args(): args = parser.parse_args() return args + def create_dir_not_exist(path): if not os.path.exists(path): os.mkdir(path) + def load_Jtr(file_path): with open(file_path, 'rb') as f: data = pickle.load(f) @@ -40,15 +43,18 @@ def cross_frames(Jtr: np.ndarray): def cross_detector(dir_path): - ans={} + ans = {} for root, dirs, files in os.walk(dir_path): for file in files: file_path = os.path.join(dir_path, file) Jtr = load_Jtr(file_path) - ans[file]=cross_frames(Jtr) + ans[file] = cross_frames(Jtr) return ans - + + if __name__ == "__main__": - args=parse_args() - d=cross_detector(args.output_path) - json.dump(d,open("./fit/output/cross_detection/{}.json".format(args.dataset_name),'w')) \ No newline at end of file + args = parse_args() + d = cross_detector(args.output_path) + json.dump( + d, open("./fit/output/cross_detection/{}.json" + .format(args.dataset_name), 'w')) \ No newline at end of file diff --git a/fit/tools/label.py b/fit/tools/label.py index de6eea7..6b842e3 100644 --- a/fit/tools/label.py +++ b/fit/tools/label.py @@ -1166,7 +1166,4 @@ def get_label(file_name, dataset_name): return UTD_MHAD[key] elif dataset_name == 'CMU_Mocap': key = file_name.split('.')[0] - if key in CMU_Mocap.keys(): - return CMU_Mocap[key] - else: - return "" + return CMU_Mocap[key] if key in CMU_Mocap.keys() else "" \ No newline at end of file diff --git a/fit/tools/load.py b/fit/tools/load.py index 422c161..b23b29c 100644 --- a/fit/tools/load.py +++ b/fit/tools/load.py @@ -1,11 +1,11 @@ import scipy.io import numpy as np +import json def load(name, path): if name == 'UTD_MHAD': - data = 
scipy.io.loadmat(path) - arr = data['d_skel'] + arr = scipy.io.loadmat(path)['d_skel'] new_arr = np.zeros([arr.shape[2], arr.shape[0], arr.shape[1]]) for i in range(arr.shape[2]): for j in range(arr.shape[0]): @@ -13,7 +13,9 @@ def load(name, path): new_arr[i][j][k] = arr[j][k][i] return new_arr elif name == 'HumanAct12': - return np.load(path,allow_pickle=True) + return np.load(path, allow_pickle=True) elif name == "CMU_Mocap": - return np.load(path,allow_pickle=True) + return np.load(path, allow_pickle=True) + elif name == "Human3.6M": + return np.load(path, allow_pickle=True) diff --git a/fit/tools/main.py b/fit/tools/main.py index fe2634f..644b2e2 100644 --- a/fit/tools/main.py +++ b/fit/tools/main.py @@ -9,14 +9,17 @@ import logging import argparse import json + sys.path.append(os.getcwd()) from load import load from save import save_pic, save_params from transform import transform from train import train from smplpytorch.pytorch.smpl_layer import SMPL_Layer -torch.backends.cudnn.benchmark = True +from meters import Meters +torch.backends.cudnn.enabled = True +torch.backends.cudnn.benchmark = True def parse_args(): parser = argparse.ArgumentParser(description='Fit SMPL') @@ -94,7 +97,8 @@ if __name__ == "__main__": center_idx=0, gender=cfg.MODEL.GENDER, model_root='smplpytorch/native/models') - + + meters=Meters() file_num = 0 for root, dirs, files in os.walk(cfg.DATASET.PATH): for file in files: @@ -102,10 +106,15 @@ if __name__ == "__main__": logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files))) target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name, os.path.join(root, file)))).float() - + logger.info("target shape:{}".format(target.shape)) res = train(smpl_layer, target, logger, writer, device, - args, cfg) + args, cfg, meters) + meters.update_avg(meters.min_loss, k=target.shape[0]) + meters.reset_early_stop() + logger.info("avg_loss:{:.4f}".format(meters.avg)) # 
save_pic(res,smpl_layer,file,logger,args.dataset_name,target) save_params(res, file, logger, args.dataset_name) + torch.cuda.empty_cache() + logger.info("Fitting finished! Average loss: {:.9f}".format(meters.avg)) diff --git a/fit/tools/meters.py b/fit/tools/meters.py new file mode 100644 index 0000000..b6edb82 --- /dev/null +++ b/fit/tools/meters.py @@ -0,0 +1,27 @@ +class Meters: + def __init__(self, eps=-1e-3, stop_threshold=10) -> None: + self.eps = eps + self.stop_threshold = stop_threshold + self.avg = 0 + self.cnt = 0 + self.reset_early_stop() + + def reset_early_stop(self): + self.min_loss = float('inf') + self.satis_num = 0 + self.update_res = True + self.early_stop = False + + def update_avg(self, val, k=1): + self.avg = self.avg + (val - self.avg) * k / (self.cnt + k) + self.cnt += k + + def update_early_stop(self, val): + delta = (val - self.min_loss) / self.min_loss + if float(val) < self.min_loss: + self.min_loss = float(val) + self.update_res = True + else: + self.update_res = False + self.satis_num = self.satis_num + 1 if delta >= self.eps else 0 + self.early_stop = self.satis_num >= self.stop_threshold \ No newline at end of file diff --git a/fit/tools/save.py b/fit/tools/save.py index 8f7a401..5233a5f 100644 --- a/fit/tools/save.py +++ b/fit/tools/save.py @@ -19,9 +19,9 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target): _, _, verts, Jtr = res file_name = re.split('[/.]', file)[-2] fit_path = "fit/output/{}/picture/fit/{}".format(dataset_name, file_name) - gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name, file_name) + # gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name, file_name) create_dir_not_exist(fit_path) - create_dir_not_exist(gt_path) + # create_dir_not_exist(gt_path) logger.info('Saving pictures at {}'.format(fit_path)) for i in tqdm(range(Jtr.shape[0])): display_model( @@ -32,7 +32,7 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target): kintree_table=smpl_layer.kintree_table, 
savepath=os.path.join(fit_path+"/frame_{}".format(i)), batch_idx=i, - show=False, + show=True, only_joint=True) # display_model( # {'verts': verts.cpu().detach(), @@ -59,7 +59,7 @@ def save_params(res, file, logger, dataset_name): Jtr = (Jtr.cpu().detach()).numpy().tolist() verts = (verts.cpu().detach()).numpy().tolist() params = {} - params["label"] = label + # params["label"] = label params["pose_params"] = pose_params params["shape_params"] = shape_params params["Jtr"] = Jtr diff --git a/fit/tools/train.py b/fit/tools/train.py index 28852af..1e8f855 100644 --- a/fit/tools/train.py +++ b/fit/tools/train.py @@ -8,31 +8,11 @@ from tqdm import tqdm sys.path.append(os.getcwd()) -class Early_Stop: - def __init__(self, eps=-1e-3, stop_threshold=10) -> None: - self.min_loss = float('inf') - self.eps = eps - self.stop_threshold = stop_threshold - self.satis_num = 0 - - def update(self, loss): - delta = (loss - self.min_loss) / self.min_loss - if float(loss) < self.min_loss: - self.min_loss = float(loss) - update_res = True - else: - update_res = False - if delta >= self.eps: - self.satis_num += 1 - else: - self.satis_num = 0 - return update_res, self.satis_num >= self.stop_threshold - def init(smpl_layer, target, device, cfg): params = {} - params["pose_params"] = torch.rand(target.shape[0], 72) * 0.0 - params["shape_params"] = torch.rand(target.shape[0], 10) * 0.03 + params["pose_params"] = torch.zeros(target.shape[0], 72) + params["shape_params"] = torch.zeros(target.shape[0], 10) params["scale"] = torch.ones([1]) smpl_layer = smpl_layer.to(device) @@ -63,7 +43,7 @@ def init(smpl_layer, target, device, cfg): def train(smpl_layer, target, logger, writer, device, - args, cfg): + args, cfg, meters): res = [] smpl_layer, params, target, optimizer, index = \ init(smpl_layer, target, device, cfg) @@ -71,20 +51,19 @@ def train(smpl_layer, target, shape_params = params["shape_params"] scale = params["scale"] - early_stop = Early_Stop() for epoch in 
tqdm(range(cfg.TRAIN.MAX_EPOCH)): - # for epoch in range(cfg.TRAIN.MAX_EPOCH): + # for epoch in range(cfg.TRAIN.MAX_EPOCH): verts, Jtr = smpl_layer(pose_params, th_betas=shape_params) - loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100 * scale, - target.index_select(1, index["dataset_index"]) * 100) + loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100, + target.index_select(1, index["dataset_index"]) * 100 * scale) optimizer.zero_grad() loss.backward() optimizer.step() - update_res, stop = early_stop.update(float(loss)) - if update_res: + meters.update_early_stop(float(loss)) + if meters.update_res: res = [pose_params, shape_params, verts, Jtr] - if stop: + if meters.early_stop: logger.info("Early stop at epoch {} !".format(epoch)) break @@ -95,6 +74,6 @@ def train(smpl_layer, target, writer.add_scalar('learning_rate', float( optimizer.state_dict()['param_groups'][0]['lr']), epoch) - logger.info('Train ended, min_loss = {:.9f}'.format( - float(early_stop.min_loss))) + logger.info('Train ended, min_loss = {:.4f}'.format( + float(meters.min_loss))) return res diff --git a/fit/tools/transform.py b/fit/tools/transform.py index 9530004..c36317b 100644 --- a/fit/tools/transform.py +++ b/fit/tools/transform.py @@ -3,7 +3,8 @@ import numpy as np rotate = { 'HumanAct12': [1., -1., -1.], 'CMU_Mocap': [0.05, 0.05, 0.05], - 'UTD_MHAD': [-1., 1., -1.] 
+ 'UTD_MHAD': [-1., 1., -1.], + 'Human3.6M': [-0.001, -0.001, 0.001] } diff --git a/make_gif.py b/make_gif.py index 1bfe009..2d022d9 100644 --- a/make_gif.py +++ b/make_gif.py @@ -1,7 +1,7 @@ import matplotlib.pyplot as plt import imageio, os images = [] -filenames = sorted(fn for fn in os.listdir('./fit/output/CMU_Mocap/picture/fit/01_01') ) +filenames = sorted(fn for fn in os.listdir('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02') ) for filename in filenames: - images.append(imageio.imread('./fit/output/CMU_Mocap/picture/fit/01_01/'+filename)) -imageio.mimsave('fit.gif', images, duration=0.2) \ No newline at end of file + images.append(imageio.imread('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02/'+filename)) +imageio.mimsave('fit_mesh.gif', images, duration=0.2) \ No newline at end of file