update meters

This commit is contained in:
Iridoudou
2021-08-19 12:40:54 +08:00
parent e0ee13d6d6
commit 8573a09b84
10 changed files with 81 additions and 59 deletions

View File

@ -40,11 +40,12 @@ The SMPL human body layer for Pytorch is from the [smplpytorch](https://github.c
- Download the datasets you want to fit
currently supported datasets:
currently support:
- [HumanAct12](https://ericguo5513.github.io/action-to-motion/)
- [CMU Mocap](https://ericguo5513.github.io/action-to-motion/)
- [UTD-MHAD](https://personal.utdallas.edu/~kehtar/UTD-MHAD.html)
- [Human3.6M](http://vision.imar.ro/human3.6m/description.php)
- Set the **DATASET.PATH** in the corresponding configuration file to the location of dataset.

View File

@ -4,6 +4,7 @@ import os
import json
import argparse
def parse_args():
parser = argparse.ArgumentParser(description='Detect cross joints')
parser.add_argument('--dataset_name', dest='dataset_name',
@ -15,10 +16,12 @@ def parse_args():
args = parser.parse_args()
return args
def create_dir_not_exist(path):
if not os.path.exists(path):
os.mkdir(path)
def load_Jtr(file_path):
with open(file_path, 'rb') as f:
data = pickle.load(f)
@ -40,15 +43,18 @@ def cross_frames(Jtr: np.ndarray):
def cross_detector(dir_path):
ans={}
ans = {}
for root, dirs, files in os.walk(dir_path):
for file in files:
file_path = os.path.join(dir_path, file)
Jtr = load_Jtr(file_path)
ans[file]=cross_frames(Jtr)
ans[file] = cross_frames(Jtr)
return ans
if __name__ == "__main__":
args=parse_args()
d=cross_detector(args.output_path)
json.dump(d,open("./fit/output/cross_detection/{}.json".format(args.dataset_name),'w'))
args = parse_args()
d = cross_detector(args.output_path)
json.dump(
d, open("./fit/output/cross_detection/{}.json"
.format(args.dataset_name), 'w'))

View File

@ -1166,7 +1166,4 @@ def get_label(file_name, dataset_name):
return UTD_MHAD[key]
elif dataset_name == 'CMU_Mocap':
key = file_name.split('.')[0]
if key in CMU_Mocap.keys():
return CMU_Mocap[key]
else:
return ""
return CMU_Mocap[key] if key in CMU_Mocap.keys() else ""

View File

@ -1,11 +1,11 @@
import scipy.io
import numpy as np
import json
def load(name, path):
if name == 'UTD_MHAD':
data = scipy.io.loadmat(path)
arr = data['d_skel']
arr = scipy.io.loadmat(path)['d_skel']
new_arr = np.zeros([arr.shape[2], arr.shape[0], arr.shape[1]])
for i in range(arr.shape[2]):
for j in range(arr.shape[0]):
@ -13,7 +13,9 @@ def load(name, path):
new_arr[i][j][k] = arr[j][k][i]
return new_arr
elif name == 'HumanAct12':
return np.load(path,allow_pickle=True)
return np.load(path, allow_pickle=True)
elif name == "CMU_Mocap":
return np.load(path,allow_pickle=True)
return np.load(path, allow_pickle=True)
elif name == "Human3.6M":
return np.load(path, allow_pickle=True)

View File

@ -9,14 +9,17 @@ import logging
import argparse
import json
sys.path.append(os.getcwd())
from load import load
from save import save_pic, save_params
from transform import transform
from train import train
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
torch.backends.cudnn.benchmark = True
from meters import Meters
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
def parse_args():
parser = argparse.ArgumentParser(description='Fit SMPL')
@ -95,6 +98,7 @@ if __name__ == "__main__":
gender=cfg.MODEL.GENDER,
model_root='smplpytorch/native/models')
meters=Meters()
file_num = 0
for root, dirs, files in os.walk(cfg.DATASET.PATH):
for file in files:
@ -102,10 +106,15 @@ if __name__ == "__main__":
logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name,
os.path.join(root, file)))).float()
logger.info("target shape:{}".format(target.shape))
res = train(smpl_layer, target,
logger, writer, device,
args, cfg)
args, cfg, meters)
meters.update_avg(meters.min_loss, k=target.shape[0])
meters.reset_early_stop()
logger.info("avg_loss:{:.4f}".format(meters.avg))
# save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
save_params(res, file, logger, args.dataset_name)
torch.cuda.empty_cache()
logger.info("Fitting finished! Average loss: {:.9f}".format(meters.avg))

27
fit/tools/meters.py Normal file
View File

@ -0,0 +1,27 @@
class Meters:
def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
self.eps = eps
self.stop_threshold = stop_threshold
self.avg = 0
self.cnt = 0
self.reset_early_stop()
def reset_early_stop(self):
self.min_loss = float('inf')
self.satis_num = 0
self.update_res = True
self.early_stop = False
def update_avg(self, val, k=1):
self.avg = self.avg + (val - self.avg) * k / (self.cnt + k)
self.cnt += k
def update_early_stop(self, val):
delta = (val - self.min_loss) / self.min_loss
if float(val) < self.min_loss:
self.min_loss = float(val)
self.update_res = True
else:
self.update_res = False
self.satis_num = self.satis_num + 1 if delta >= self.eps else 0
self.early_stop = self.satis_num >= self.stop_threshold

View File

@ -19,9 +19,9 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
_, _, verts, Jtr = res
file_name = re.split('[/.]', file)[-2]
fit_path = "fit/output/{}/picture/fit/{}".format(dataset_name, file_name)
gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name, file_name)
# gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name, file_name)
create_dir_not_exist(fit_path)
create_dir_not_exist(gt_path)
# create_dir_not_exist(gt_path)
logger.info('Saving pictures at {}'.format(fit_path))
for i in tqdm(range(Jtr.shape[0])):
display_model(
@ -32,7 +32,7 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
kintree_table=smpl_layer.kintree_table,
savepath=os.path.join(fit_path+"/frame_{}".format(i)),
batch_idx=i,
show=False,
show=True,
only_joint=True)
# display_model(
# {'verts': verts.cpu().detach(),
@ -59,7 +59,7 @@ def save_params(res, file, logger, dataset_name):
Jtr = (Jtr.cpu().detach()).numpy().tolist()
verts = (verts.cpu().detach()).numpy().tolist()
params = {}
params["label"] = label
# params["label"] = label
params["pose_params"] = pose_params
params["shape_params"] = shape_params
params["Jtr"] = Jtr

View File

@ -8,31 +8,11 @@ from tqdm import tqdm
sys.path.append(os.getcwd())
class Early_Stop:
def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
self.min_loss = float('inf')
self.eps = eps
self.stop_threshold = stop_threshold
self.satis_num = 0
def update(self, loss):
delta = (loss - self.min_loss) / self.min_loss
if float(loss) < self.min_loss:
self.min_loss = float(loss)
update_res = True
else:
update_res = False
if delta >= self.eps:
self.satis_num += 1
else:
self.satis_num = 0
return update_res, self.satis_num >= self.stop_threshold
def init(smpl_layer, target, device, cfg):
params = {}
params["pose_params"] = torch.rand(target.shape[0], 72) * 0.0
params["shape_params"] = torch.rand(target.shape[0], 10) * 0.03
params["pose_params"] = torch.zeros(target.shape[0], 72)
params["shape_params"] = torch.zeros(target.shape[0], 10)
params["scale"] = torch.ones([1])
smpl_layer = smpl_layer.to(device)
@ -63,7 +43,7 @@ def init(smpl_layer, target, device, cfg):
def train(smpl_layer, target,
logger, writer, device,
args, cfg):
args, cfg, meters):
res = []
smpl_layer, params, target, optimizer, index = \
init(smpl_layer, target, device, cfg)
@ -71,20 +51,19 @@ def train(smpl_layer, target,
shape_params = params["shape_params"]
scale = params["scale"]
early_stop = Early_Stop()
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
# for epoch in range(cfg.TRAIN.MAX_EPOCH):
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100 * scale,
target.index_select(1, index["dataset_index"]) * 100)
loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100,
target.index_select(1, index["dataset_index"]) * 100 * scale)
optimizer.zero_grad()
loss.backward()
optimizer.step()
update_res, stop = early_stop.update(float(loss))
if update_res:
meters.update_early_stop(float(loss))
if meters.update_res:
res = [pose_params, shape_params, verts, Jtr]
if stop:
if meters.early_stop:
logger.info("Early stop at epoch {} !".format(epoch))
break
@ -95,6 +74,6 @@ def train(smpl_layer, target,
writer.add_scalar('learning_rate', float(
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
logger.info('Train ended, min_loss = {:.9f}'.format(
float(early_stop.min_loss)))
logger.info('Train ended, min_loss = {:.4f}'.format(
float(meters.min_loss)))
return res

View File

@ -3,7 +3,8 @@ import numpy as np
rotate = {
'HumanAct12': [1., -1., -1.],
'CMU_Mocap': [0.05, 0.05, 0.05],
'UTD_MHAD': [-1., 1., -1.]
'UTD_MHAD': [-1., 1., -1.],
'Human3.6M': [-0.001, -0.001, 0.001]
}

View File

@ -1,7 +1,7 @@
import matplotlib.pyplot as plt
import imageio, os
images = []
filenames = sorted(fn for fn in os.listdir('./fit/output/CMU_Mocap/picture/fit/01_01') )
filenames = sorted(fn for fn in os.listdir('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02') )
for filename in filenames:
images.append(imageio.imread('./fit/output/CMU_Mocap/picture/fit/01_01/'+filename))
imageio.mimsave('fit.gif', images, duration=0.2)
images.append(imageio.imread('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02/'+filename))
imageio.mimsave('fit_mesh.gif', images, duration=0.2)