From 2b3e65e2a2e6ef2bafe947ad4b97a75f773cde24 Mon Sep 17 00:00:00 2001
From: Iridoudou <2534936416@qq.com>
Date: Sat, 7 Aug 2021 21:19:21 +0800
Subject: [PATCH] add HumanAct12, UTD_MHAD

---
 fit/configs/HumanAct12.json | 115 +++++++++++++++++++++++++++++++++++
 fit/configs/UTD_MHAD.json   |  83 ++++++++++++++++++++++++++
 fit/configs/config.json     |   2 +-
 fit/tools/load.py           |   2 +-
 fit/tools/main.py           |  59 ++++++------------
 fit/tools/save.py           |  39 +++++-------
 fit/tools/train.py          | 116 ++++++++++++++++++++++++------------
 fit/tools/transform.py      |  28 ++++++---
 make_gif.py                 |   2 +-
 9 files changed, 329 insertions(+), 117 deletions(-)
 create mode 100644 fit/configs/HumanAct12.json
 create mode 100644 fit/configs/UTD_MHAD.json

diff --git a/fit/configs/HumanAct12.json b/fit/configs/HumanAct12.json
new file mode 100644
index 0000000..cbe1207
--- /dev/null
+++ b/fit/configs/HumanAct12.json
@@ -0,0 +1,115 @@
+{
+    "MODEL": {
+        "GENDER": "neutral"
+    },
+    "TRAIN": {
+        "LEARNING_RATE": 2e-2,
+        "MAX_EPOCH": 500,
+        "WRITE": 1
+    },
+    "USE_GPU": 1,
+    "DATASET": {
+        "NAME": "HumanAct12",
+        "PATH": "../Action2Motion/HumanAct12/HumanAct12/",
+        "TARGET_PATH": "",
+        "DATA_MAP": [
+            [
+                0,
+                0
+            ],
+            [
+                1,
+                1
+            ],
+            [
+                2,
+                2
+            ],
+            [
+                3,
+                3
+            ],
+            [
+                4,
+                4
+            ],
+            [
+                5,
+                5
+            ],
+            [
+                6,
+                6
+            ],
+            [
+                7,
+                7
+            ],
+            [
+                8,
+                8
+            ],
+            [
+                9,
+                9
+            ],
+            [
+                10,
+                10
+            ],
+            [
+                11,
+                11
+            ],
+            [
+                12,
+                12
+            ],
+            [
+                13,
+                13
+            ],
+            [
+                14,
+                14
+            ],
+            [
+                15,
+                15
+            ],
+            [
+                16,
+                16
+            ],
+            [
+                17,
+                17
+            ],
+            [
+                18,
+                18
+            ],
+            [
+                19,
+                19
+            ],
+            [
+                20,
+                20
+            ],
+            [
+                21,
+                21
+            ],
+            [
+                22,
+                22
+            ],
+            [
+                23,
+                23
+            ]
+        ]
+    },
+    "DEBUG": 0
+}
\ No newline at end of file
diff --git a/fit/configs/UTD_MHAD.json b/fit/configs/UTD_MHAD.json
new file mode 100644
index 0000000..fc91d77
--- /dev/null
+++ b/fit/configs/UTD_MHAD.json
@@ -0,0 +1,83 @@
+{
+    "MODEL": {
+        "GENDER": "neutral"
+    },
+    "TRAIN": {
+        "LEARNING_RATE": 2e-2,
+        "MAX_EPOCH": 500,
+        "WRITE": 1
+    },
+    "USE_GPU": 1,
+    "DATASET": {
+        "NAME": "UTD-MHAD",
+        "PATH": "../UTD-MHAD/Skeleton/Skeleton/",
+        "TARGET_PATH": "",
+        "DATA_MAP": [
+            [
+                12,
+                1
+            ],
+            [
+                0,
+                3
+            ],
+            [
+                16,
+                4
+            ],
+            [
+                18,
+                5
+            ],
+            [
+                20,
+                6
+            ],
+            [
+                22,
+                7
+            ],
+            [
+                17,
+                8
+            ],
+            [
+                19,
+                9
+            ],
+            [
+                21,
+                10
+            ],
+            [
+                23,
+                11
+            ],
+            [
+                1,
+                12
+            ],
+            [
+                4,
+                13
+            ],
+            [
+                7,
+                14
+            ],
+            [
+                2,
+                16
+            ],
+            [
+                5,
+                17
+            ],
+            [
+                8,
+                18
+            ]
+        ]
+    },
+    "DEBUG": 0
+}
\ No newline at end of file
diff --git a/fit/configs/config.json b/fit/configs/config.json
index 478e8ea..5319de1 100644
--- a/fit/configs/config.json
+++ b/fit/configs/config.json
@@ -54,5 +54,5 @@
             ]
         }
     },
-    "DEBUG": 1
+    "DEBUG": 0
 }
\ No newline at end of file
diff --git a/fit/tools/load.py b/fit/tools/load.py
index a3522f5..91889c3 100644
--- a/fit/tools/load.py
+++ b/fit/tools/load.py
@@ -13,4 +13,4 @@ def load(name, path):
                 new_arr[i][j][k] = arr[j][k][i]
         return new_arr
     elif name == 'HumanAct12':
-        return np.load(path)
+        return np.load(path, allow_pickle=True)
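A note on the new configs: each DATA_MAP entry is a pair [smpl_joint_index, dataset_joint_index], and init() in train.py below splits the pairs into the two index tensors used to align the skeletons before the loss is computed. A minimal sketch of that read-out, assuming the patch is applied and the code runs from the repository root:

    import json
    import torch
    from easydict import EasyDict as edict

    # Load one of the configs added above and split its joint map,
    # mirroring what init() in fit/tools/train.py does.
    with open('fit/configs/UTD_MHAD.json', 'r') as f:
        cfg = edict(json.load(f))

    smpl_index = [tp[0] for tp in cfg.DATASET.DATA_MAP]     # joints on the SMPL skeleton
    dataset_index = [tp[1] for tp in cfg.DATASET.DATA_MAP]  # matching UTD-MHAD joints
    print(torch.tensor(smpl_index))
    print(torch.tensor(dataset_index))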
diff --git a/fit/tools/main.py b/fit/tools/main.py
index f89cfaf..c0021d5 100644
--- a/fit/tools/main.py
+++ b/fit/tools/main.py
@@ -1,14 +1,4 @@
-import matplotlib as plt
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn.modules import module
-from torch.optim import lr_scheduler
-import torch.optim as optim
-from torch.utils.data import DataLoader
-from torch.utils.data import sampler
-import torchvision.datasets as dset
-import torchvision.transforms as T
 import numpy as np
 from tensorboardX import SummaryWriter
 from easydict import EasyDict as edict
@@ -35,17 +25,15 @@ def parse_args():
     parser.add_argument('--exp', dest='exp',
                         help='Define exp name',
                         default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
-    parser.add_argument('--config_path', dest='config_path',
-                        help='Select configuration file',
-                        default='fit/configs/config.json', type=str)
-    parser.add_argument('--dataset_path', dest='dataset_path',
+    parser.add_argument('--dataset_name', dest='dataset_name',
                         help='select dataset',
                         default='', type=str)
     args = parser.parse_args()
     return args

 def get_config(args):
-    with open(args.config_path, 'r') as f:
+    config_path = 'fit/configs/{}.json'.format(args.dataset_name)
+    with open(config_path, 'r') as f:
         data = json.load(f)
         cfg = edict(data.copy())
     return cfg
@@ -100,31 +88,18 @@ if __name__ == "__main__":
                          gender=cfg.MODEL.GENDER,
                          model_root='smplpytorch/native/models')

-    if not cfg.DEBUG:
-        for root, dirs, files in os.walk(cfg.DATASET_PATH):
-            for file in files:
-                logger.info('Processing file: {}'.format(file))
-                target_path = os.path.join(root, file)
-
-                target = np.array(transform(np.load(target_path)))
-                logger.info('File shape: {}'.format(target.shape))
-                target = torch.from_numpy(target).float()
-
-
-                res = train(smpl_layer, target,
-                            logger, writer, device,
-                            args, cfg)
-
-                # save_pic(target, res, smpl_layer, file, logger)
-                save_params(res, file, logger)
-    else:
-        target = np.array(transform(load('UTD_MHAD', cfg.DATASET.TARGET_PATH),
-                                    rotate=[-1, 1, -1]))
-        target = torch.from_numpy(target).float()
-        data_map_dataset = torch.tensor(cfg.DATASET.DATA_MAP.UTD_MHAD[1])
-        target = target.index_select(1, data_map_dataset)
-        print(target.shape)
-        res = train(smpl_layer, target,
-                    logger, writer, device,
-                    args, cfg)
+    for root, dirs, files in os.walk(cfg.DATASET.PATH):
+        for file in files:
+            logger.info('Processing file: {}'.format(file))
+            target = torch.from_numpy(transform(args.dataset_name,
+                                                load(args.dataset_name,
+                                                     os.path.join(root, file)))).float()
+
+            res = train(smpl_layer, target,
+                        logger, writer, device,
+                        args, cfg)
+
+            # save_pic(res, smpl_layer, file, logger, args.dataset_name)
+            save_params(res, file, logger, args.dataset_name)
\ No newline at end of file
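With --config_path removed, the dataset name now doubles as the config name, so the flag must match one of the JSON files added above. A sketch of the assumed usage, launching from the repository root:

    # Hypothetical invocations; the flag value selects fit/configs/<name>.json:
    #   python fit/tools/main.py --dataset_name UTD_MHAD     # -> fit/configs/UTD_MHAD.json
    #   python fit/tools/main.py --dataset_name HumanAct12   # -> fit/configs/HumanAct12.json
    # get_config() simply formats the name into the path:
    config_path = 'fit/configs/{}.json'.format('UTD_MHAD')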
diff --git a/fit/tools/save.py b/fit/tools/save.py
index d2eeadc..8d11ac8 100644
--- a/fit/tools/save.py
+++ b/fit/tools/save.py
@@ -15,25 +15,13 @@ def create_dir_not_exist(path):
         os.mkdir(path)


-def save_pic(target, res, smpl_layer, file, logger):
-    pose_params, shape_params, verts, Jtr = res
-    name = re.split('[/.]', file)[-2]
-    gt_path = "fit/output/HumanAct12/picture/gt/{}".format(name)
-    fit_path = "fit/output/HumanAct12/picture/fit/{}".format(name)
-    create_dir_not_exist(gt_path)
+def save_pic(res, smpl_layer, file, logger, dataset_name):
+    _, _, verts, Jtr = res
+    file_name = re.split('[/.]', file)[-2]
+    fit_path = "fit/output/{}/picture/fit/{}".format(dataset_name, file_name)
     create_dir_not_exist(fit_path)
-    logger.info('Saving pictures at {} and {}'.format(gt_path, fit_path))
-    for i in tqdm(range(target.shape[0])):
-        display_model(
-            {'verts': verts.cpu().detach(),
-             'joints': target.cpu().detach()},
-            model_faces=smpl_layer.th_faces,
-            with_joints=True,
-            kintree_table=smpl_layer.kintree_table,
-            savepath=os.path.join(gt_path+"/frame_{}".format(i)),
-            batch_idx=i,
-            show=False,
-            only_joint=True)
+    logger.info('Saving pictures at {}'.format(fit_path))
+    for i in tqdm(range(Jtr.shape[0])):
         display_model(
             {'verts': verts.cpu().detach(),
              'joints': Jtr.cpu().detach()},
@@ -42,14 +30,15 @@ def save_pic(target, res, smpl_layer, file, logger):
             kintree_table=smpl_layer.kintree_table,
             savepath=os.path.join(fit_path+"/frame_{}".format(i)),
             batch_idx=i,
-            show=False)
+            show=False,
+            only_joint=False)
     logger.info('Pictures saved')


-def save_params(res, file, logger):
+def save_params(res, file, logger, dataset_name):
     pose_params, shape_params, verts, Jtr = res
-    name = re.split('[/.]', file)[-2]
-    fit_path = "fit/output/HumanAct12/params/"
+    file_name = re.split('[/.]', file)[-2]
+    fit_path = "fit/output/{}/params/".format(dataset_name)
     create_dir_not_exist(fit_path)
     logger.info('Saving params at {}'.format(fit_path))
     pose_params = pose_params.cpu().detach()
@@ -58,11 +47,13 @@ def save_params(res, file, logger):
     shape_params = shape_params.numpy().tolist()
     Jtr = Jtr.cpu().detach()
     Jtr = Jtr.numpy().tolist()
+    verts = verts.cpu().detach()
+    verts = verts.numpy().tolist()
     params = {}
     params["pose_params"] = pose_params
     params["shape_params"] = shape_params
     params["Jtr"] = Jtr
+    params["mesh"] = verts
     f = open(os.path.join((fit_path),
-             "{}_params.json".format(name)), 'w')
+             "{}_params.json".format(file_name)), 'w')
     json.dump(params, f)
-    logger.info('Params saved')
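As a quick check of the new output format, the saved JSON can be read straight back into tensors. A minimal sketch, assuming save_params() has already written a fit of the UTD_MHAD clip a10_s1_t1_skeleton (the clip name is illustrative):

    import json
    import torch

    with open('fit/output/UTD_MHAD/params/a10_s1_t1_skeleton_params.json') as f:
        params = json.load(f)

    pose = torch.tensor(params["pose_params"])    # (frames, 72) axis-angle SMPL pose
    shape = torch.tensor(params["shape_params"])  # (frames, 10) SMPL shape betas
    joints = torch.tensor(params["Jtr"])          # fitted joint positions per frame
    mesh = torch.tensor(params["mesh"])           # SMPL vertices; "mesh" key added by this patch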
diff --git a/fit/tools/train.py b/fit/tools/train.py
index 7819713..3c0a74b 100644
--- a/fit/tools/train.py
+++ b/fit/tools/train.py
@@ -1,3 +1,4 @@
+from fit.tools.save import save_pic
 import matplotlib as plt
 from matplotlib.pyplot import show
 import torch
@@ -27,57 +28,94 @@ from smplpytorch.pytorch.smpl_layer import SMPL_Layer
 from display_utils import display_model
 from map import mapping


+class Early_Stop:
+    def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
+        self.min_loss = float('inf')
+        self.eps = eps
+        self.stop_threshold = stop_threshold
+        self.satis_num = 0
+
+    def update(self, loss):
+        delta = (loss - self.min_loss) / self.min_loss
+        if float(loss) < self.min_loss:
+            self.min_loss = float(loss)
+            update_res = True
+        else:
+            update_res = False
+        if delta >= self.eps:
+            self.satis_num += 1
+        else:
+            self.satis_num = max(0, self.satis_num - 1)
+        return update_res, self.satis_num >= self.stop_threshold
+
+
+def init(smpl_layer, target, device, cfg):
+    params = {}
+    params["pose_params"] = torch.rand(target.shape[0], 72) * 0.0
+    params["shape_params"] = torch.rand(target.shape[0], 10) * 0.03
+    params["scale"] = torch.ones([1])
+
+    smpl_layer = smpl_layer.to(device)
+    params["pose_params"] = params["pose_params"].to(device)
+    params["shape_params"] = params["shape_params"].to(device)
+    target = target.to(device)
+    params["scale"] = params["scale"].to(device)
+
+    params["pose_params"].requires_grad = True
+    params["shape_params"].requires_grad = True
+    params["scale"].requires_grad = False
+
+    optimizer = optim.Adam([params["pose_params"], params["shape_params"]],
+                           lr=cfg.TRAIN.LEARNING_RATE)
+
+    index = {}
+    smpl_index = []
+    dataset_index = []
+    for tp in cfg.DATASET.DATA_MAP:
+        smpl_index.append(tp[0])
+        dataset_index.append(tp[1])
+
+    index["smpl_index"] = torch.tensor(smpl_index).to(device)
+    index["dataset_index"] = torch.tensor(dataset_index).to(device)
+
+    return smpl_layer, params, target, optimizer, index
+
+
 def train(smpl_layer, target,
           logger, writer, device,
           args, cfg):
     res = []
-    pose_params = torch.rand(target.shape[0], 72) * 0.0
-    shape_params = torch.rand(target.shape[0], 10) * 0.03
-    scale = torch.ones([1])
-
-    smpl_layer = smpl_layer.to(device)
-    pose_params = pose_params.to(device)
-    shape_params = shape_params.to(device)
-    target = target.to(device)
-    scale = scale.to(device)
-
-    pose_params.requires_grad = True
-    shape_params.requires_grad = True
-    scale.requires_grad = False
-    smpl_layer.requires_grad = False
-
-    optimizer = optim.Adam([pose_params, shape_params],
-                           lr=cfg.TRAIN.LEARNING_RATE)
+    smpl_layer, params, target, optimizer, index = \
+        init(smpl_layer, target, device, cfg)
+    pose_params = params["pose_params"]
+    shape_params = params["shape_params"]
+    scale = params["scale"]

-    min_loss = float('inf')
-    data_map = torch.tensor(cfg.DATASET.DATA_MAP.UTD_MHAD)[0].to(device)
-    # for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
-    for epoch in range(cfg.TRAIN.MAX_EPOCH):
+    early_stop = Early_Stop()
+    for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
         verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
-        loss = F.smooth_l1_loss(Jtr.index_select(1, data_map) * 100, target * 100)
+        loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100,
+                                target.index_select(1, index["dataset_index"]) * 100)
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
-        if float(loss) < min_loss:
-            min_loss = float(loss)
+
+        update_res, stop = early_stop.update(float(loss))
+        if update_res:
             res = [pose_params, shape_params, verts, Jtr]
+        if stop:
+            logger.info("Early stop at epoch {} !".format(epoch))
+            break
+
         if epoch % cfg.TRAIN.WRITE == 0:
-            logger.info("Epoch {}, lossPerBatch={:.9f}, scale={:.6f}".format(
-                epoch, float(loss), float(scale)))
             writer.add_scalar('loss', float(loss), epoch)
             writer.add_scalar('learning_rate', float(
                 optimizer.state_dict()['param_groups'][0]['lr']), epoch)
-        if epoch % cfg.TRAIN.SAVE == 0 and epoch > 0:
-            for i in tqdm(range(Jtr.shape[0])):
-                display_model(
-                    {'verts': verts.cpu().detach(),
-                     'joints': Jtr.cpu().detach()},
-                    model_faces=smpl_layer.th_faces,
-                    with_joints=True,
-                    kintree_table=smpl_layer.kintree_table,
-                    savepath="fit/output/UTD_MHAD/picture/frame_{}".format(str(i).zfill(4)),
-                    batch_idx=i,
-                    show=True,
-                    only_joint=True)
-    logger.info('Train ended, min_loss = {:.9f}'.format(float(min_loss)))
+
+    logger.info('Train ended, min_loss = {:.9f}'.format(float(early_stop.min_loss)))
     return res
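To make the stopping rule concrete: with the default eps = -1e-3, an epoch whose loss fails to improve the running minimum by more than 0.1% (relative) pushes the counter up, while a strong improvement pushes it back down; stopping is signalled once the counter reaches stop_threshold. A minimal sketch of Early_Stop in isolation, assuming the patch is applied and the repository root is on PYTHONPATH:

    from fit.tools.train import Early_Stop  # class introduced in this patch

    stopper = Early_Stop(eps=-1e-3, stop_threshold=3)
    for loss in [1.0, 0.5, 0.4999, 0.4999, 0.4998]:
        improved, should_stop = stopper.update(loss)
        print(improved, stopper.satis_num, should_stop)
    # The 0.5 -> 0.4999 steps improve by far less than 0.1%, so satis_num climbs
    # toward stop_threshold even though the loss is still (barely) decreasing.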
diff --git a/fit/tools/transform.py b/fit/tools/transform.py
index 9a59c78..f062f94 100644
--- a/fit/tools/transform.py
+++ b/fit/tools/transform.py
@@ -1,13 +1,23 @@
 import numpy as np


-def transform(arr: np.ndarray, rotate=[1., -1., -1.]):
-    for i in range(arr.shape[0]):
-        origin = arr[i][3].copy()
-        for j in range(arr.shape[1]):
-            arr[i][j] -= origin
-            for k in range(3):
-                arr[i][j][k] *= rotate[k]
-        arr[i][3] = [0.0, 0.0, 0.0]
-    print(arr[0])
+def transform(name, arr: np.ndarray):
+    if name == 'HumanAct12':
+        rotate = [1., -1., -1.]
+        for i in range(arr.shape[0]):
+            origin = arr[i][0].copy()
+            for j in range(arr.shape[1]):
+                arr[i][j] -= origin
+                for k in range(3):
+                    arr[i][j][k] *= rotate[k]
+            arr[i][0] = [0.0, 0.0, 0.0]
+    elif name == 'UTD_MHAD':
+        rotate = [-1., 1., -1.]
+        for i in range(arr.shape[0]):
+            origin = arr[i][3].copy()
+            for j in range(arr.shape[1]):
+                arr[i][j] -= origin
+                for k in range(3):
+                    arr[i][j][k] *= rotate[k]
+            arr[i][3] = [0.0, 0.0, 0.0]
     return arr
diff --git a/make_gif.py b/make_gif.py
index 82db709..b8d4623 100644
--- a/make_gif.py
+++ b/make_gif.py
@@ -3,5 +3,5 @@ import imageio, os
 images = []
-filenames = sorted(fn for fn in os.listdir('./fit/output/UTD_MHAD/picture/'))
+filenames = sorted(fn for fn in os.listdir('./fit/output/UTD_MHAD/picture/fit/a10_s1_t1_skeleton/'))
 for filename in filenames:
-    images.append(imageio.imread('./fit/output/UTD_MHAD/picture/'+filename))
+    images.append(imageio.imread('./fit/output/UTD_MHAD/picture/fit/a10_s1_t1_skeleton/'+filename))
 imageio.mimsave('./fit.gif', images, duration=0.3)
\ No newline at end of file
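A possible follow-up, not part of the patch: make_gif.py still hardcodes one clip. A hypothetical helper that parameterizes the clip name and sorts frames numerically (save_pic writes unpadded frame indices, so a plain lexicographic sort would put frame_10 before frame_9):

    import os
    import imageio

    def make_gif(clip, dataset='UTD_MHAD', duration=0.3):
        # Frames live where save_pic() writes them: fit/output/<dataset>/picture/fit/<clip>/
        frame_dir = './fit/output/{}/picture/fit/{}/'.format(dataset, clip)
        # Numeric sort on the digits in each filename, e.g. frame_9 < frame_10.
        filenames = sorted(os.listdir(frame_dir),
                           key=lambda fn: int(''.join(ch for ch in fn if ch.isdigit())))
        frames = [imageio.imread(frame_dir + fn) for fn in filenames]
        imageio.mimsave('./{}.gif'.format(clip), frames, duration=duration)

    make_gif('a10_s1_t1_skeleton')  # the clip hardcoded in make_gif.py above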