update meters
@@ -40,11 +40,12 @@ The SMPL human body layer for Pytorch is from the [smplpytorch](https://github.c
 - Download the datasets you want to fit
-  currently supported datasets:
+  currently support:
   - [HumanAct12](https://ericguo5513.github.io/action-to-motion/)
   - [CMU Mocap](https://ericguo5513.github.io/action-to-motion/)
   - [UTD-MHAD](https://personal.utdallas.edu/~kehtar/UTD-MHAD.html)
+  - [Human3.6M](http://vision.imar.ro/human3.6m/description.php)
 - Set the **DATASET.PATH** in the corresponding configuration file to the location of dataset.
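The **DATASET.PATH** key mentioned above appears later in this diff as `cfg.DATASET.PATH` (alongside `cfg.MODEL.GENDER` and `cfg.TRAIN.MAX_EPOCH`). As orientation only, here is a minimal sketch of such a config, assuming a yacs-style `CfgNode`; the library choice, paths, and default values are assumptions, not taken from this commit:

```python
# Hypothetical config sketch (yacs assumed); only the key names mirror this diff.
from yacs.config import CfgNode as CN

cfg = CN()
cfg.DATASET = CN()
cfg.DATASET.PATH = "/data/HumanAct12"   # set this to where the dataset was downloaded
cfg.MODEL = CN()
cfg.MODEL.GENDER = "neutral"            # SMPL model variant to load
cfg.TRAIN = CN()
cfg.TRAIN.MAX_EPOCH = 500               # upper bound on fitting iterations per sequence

if __name__ == "__main__":
    print(cfg.DATASET.PATH)
```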
@@ -4,6 +4,7 @@ import os
 import json
 import argparse
 
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Detect cross joints')
     parser.add_argument('--dataset_name', dest='dataset_name',
@@ -15,10 +16,12 @@ def parse_args():
     args = parser.parse_args()
     return args
 
 
 def create_dir_not_exist(path):
     if not os.path.exists(path):
         os.mkdir(path)
 
 
 def load_Jtr(file_path):
     with open(file_path, 'rb') as f:
         data = pickle.load(f)
@@ -40,15 +43,18 @@ def cross_frames(Jtr: np.ndarray):
 
 
 def cross_detector(dir_path):
-    ans={}
+    ans = {}
     for root, dirs, files in os.walk(dir_path):
         for file in files:
             file_path = os.path.join(dir_path, file)
             Jtr = load_Jtr(file_path)
-            ans[file]=cross_frames(Jtr)
+            ans[file] = cross_frames(Jtr)
     return ans
 
 
 if __name__ == "__main__":
-    args=parse_args()
-    d=cross_detector(args.output_path)
-    json.dump(d,open("./fit/output/cross_detection/{}.json".format(args.dataset_name),'w'))
+    args = parse_args()
+    d = cross_detector(args.output_path)
+    json.dump(
+        d, open("./fit/output/cross_detection/{}.json"
+                .format(args.dataset_name), 'w'))
@@ -1166,7 +1166,4 @@ def get_label(file_name, dataset_name):
         return UTD_MHAD[key]
     elif dataset_name == 'CMU_Mocap':
         key = file_name.split('.')[0]
-        if key in CMU_Mocap.keys():
-            return CMU_Mocap[key]
-        else:
-            return ""
+        return CMU_Mocap[key] if key in CMU_Mocap.keys() else ""
@@ -1,11 +1,11 @@
 import scipy.io
 import numpy as np
 import json
 
 
 def load(name, path):
     if name == 'UTD_MHAD':
-        data = scipy.io.loadmat(path)
-        arr = data['d_skel']
+        arr = scipy.io.loadmat(path)['d_skel']
         new_arr = np.zeros([arr.shape[2], arr.shape[0], arr.shape[1]])
         for i in range(arr.shape[2]):
             for j in range(arr.shape[0]):
@@ -13,7 +13,9 @@ def load(name, path):
                 new_arr[i][j][k] = arr[j][k][i]
         return new_arr
    elif name == 'HumanAct12':
-        return np.load(path,allow_pickle=True)
+        return np.load(path, allow_pickle=True)
     elif name == "CMU_Mocap":
-        return np.load(path,allow_pickle=True)
+        return np.load(path, allow_pickle=True)
+    elif name == "Human3.6M":
+        return np.load(path, allow_pickle=True)
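As a side note, the nested loops above reorder the UTD-MHAD `d_skel` array from (joints, xyz, frames) to (frames, joints, xyz), which is a plain axis transpose. A loop-free sketch of the same operation (not part of this commit; the (n_joints, 3, n_frames) storage layout of `d_skel` is stated here as an assumption based on the loop bounds):

```python
import numpy as np
import scipy.io

def load_utd_mhad(path):
    # d_skel is assumed to be stored as (n_joints, 3, n_frames); reorder it to
    # (n_frames, n_joints, 3), which is what the nested loops in load() build
    # element by element.
    arr = scipy.io.loadmat(path)['d_skel']
    return np.ascontiguousarray(arr.transpose(2, 0, 1))
```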
@@ -9,14 +9,17 @@ import logging
 
 import argparse
 import json
 
 sys.path.append(os.getcwd())
 from load import load
 from save import save_pic, save_params
 from transform import transform
 from train import train
 from smplpytorch.pytorch.smpl_layer import SMPL_Layer
-torch.backends.cudnn.benchmark = True
+from meters import Meters
 
+torch.backends.cudnn.enabled = True
+torch.backends.cudnn.benchmark = True
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Fit SMPL')
@@ -95,6 +98,7 @@ if __name__ == "__main__":
         gender=cfg.MODEL.GENDER,
         model_root='smplpytorch/native/models')
 
+    meters=Meters()
     file_num = 0
     for root, dirs, files in os.walk(cfg.DATASET.PATH):
         for file in files:
@@ -102,10 +106,15 @@ if __name__ == "__main__":
             logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
             target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name,
                                       os.path.join(root, file)))).float()
 
             logger.info("target shape:{}".format(target.shape))
             res = train(smpl_layer, target,
                         logger, writer, device,
-                        args, cfg)
+                        args, cfg, meters)
+            meters.update_avg(meters.min_loss, k=target.shape[0])
+            meters.reset_early_stop()
+            logger.info("avg_loss:{:.4f}".format(meters.avg))
 
             # save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
             save_params(res, file, logger, args.dataset_name)
             torch.cuda.empty_cache()
+    logger.info("Fitting finished! Average loss: {:.9f}".format(meters.avg))
fit/tools/meters.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+class Meters:
+    def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
+        self.eps = eps
+        self.stop_threshold = stop_threshold
+        self.avg = 0
+        self.cnt = 0
+        self.reset_early_stop()
+
+    def reset_early_stop(self):
+        self.min_loss = float('inf')
+        self.satis_num = 0
+        self.update_res = True
+        self.early_stop = False
+
+    def update_avg(self, val, k=1):
+        self.avg = self.avg + (val - self.avg) * k / (self.cnt + k)
+        self.cnt += k
+
+    def update_early_stop(self, val):
+        delta = (val - self.min_loss) / self.min_loss
+        if float(val) < self.min_loss:
+            self.min_loss = float(val)
+            self.update_res = True
+        else:
+            self.update_res = False
+        self.satis_num = self.satis_num + 1 if delta >= self.eps else 0
+        self.early_stop = self.satis_num >= self.stop_threshold
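A short sketch of how this class is driven, mirroring the calls made in the updated main script and `train.py` elsewhere in this diff; the per-epoch loss values below are made up for illustration:

```python
from meters import Meters   # fit/tools/meters.py, as added by this commit

meters = Meters()           # eps=-1e-3: an epoch counts toward early stopping
                            # unless the loss drops by more than 0.1% relative

for epoch_loss in [10.0, 9.0, 8.999, 8.998, 8.997]:   # fake per-epoch losses
    meters.update_early_stop(epoch_loss)
    if meters.update_res:            # new best loss -> keep this epoch's result
        best_loss = epoch_loss
    if meters.early_stop:            # stop_threshold non-improving epochs in a row
        break

# Fold this sequence's best loss into a running average weighted by k
# (the main script uses k = number of frames in the sequence).
meters.update_avg(meters.min_loss, k=120)
meters.reset_early_stop()            # ready for the next sequence
print(meters.avg)
```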
@@ -19,9 +19,9 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
     _, _, verts, Jtr = res
     file_name = re.split('[/.]', file)[-2]
     fit_path = "fit/output/{}/picture/fit/{}".format(dataset_name, file_name)
-    gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name, file_name)
+    # gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name, file_name)
     create_dir_not_exist(fit_path)
-    create_dir_not_exist(gt_path)
+    # create_dir_not_exist(gt_path)
     logger.info('Saving pictures at {}'.format(fit_path))
     for i in tqdm(range(Jtr.shape[0])):
         display_model(
@@ -32,7 +32,7 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
             kintree_table=smpl_layer.kintree_table,
             savepath=os.path.join(fit_path+"/frame_{}".format(i)),
             batch_idx=i,
-            show=False,
+            show=True,
             only_joint=True)
         # display_model(
         #     {'verts': verts.cpu().detach(),
@@ -59,7 +59,7 @@ def save_params(res, file, logger, dataset_name):
     Jtr = (Jtr.cpu().detach()).numpy().tolist()
     verts = (verts.cpu().detach()).numpy().tolist()
     params = {}
-    params["label"] = label
+    # params["label"] = label
     params["pose_params"] = pose_params
     params["shape_params"] = shape_params
     params["Jtr"] = Jtr
@@ -8,31 +8,11 @@ from tqdm import tqdm
 sys.path.append(os.getcwd())
 
 
-class Early_Stop:
-    def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
-        self.min_loss = float('inf')
-        self.eps = eps
-        self.stop_threshold = stop_threshold
-        self.satis_num = 0
-
-    def update(self, loss):
-        delta = (loss - self.min_loss) / self.min_loss
-        if float(loss) < self.min_loss:
-            self.min_loss = float(loss)
-            update_res = True
-        else:
-            update_res = False
-        if delta >= self.eps:
-            self.satis_num += 1
-        else:
-            self.satis_num = 0
-        return update_res, self.satis_num >= self.stop_threshold
-
-
 def init(smpl_layer, target, device, cfg):
     params = {}
-    params["pose_params"] = torch.rand(target.shape[0], 72) * 0.0
-    params["shape_params"] = torch.rand(target.shape[0], 10) * 0.03
+    params["pose_params"] = torch.zeros(target.shape[0], 72)
+    params["shape_params"] = torch.zeros(target.shape[0], 10)
     params["scale"] = torch.ones([1])
 
     smpl_layer = smpl_layer.to(device)
@@ -63,7 +43,7 @@ def init(smpl_layer, target, device, cfg):
 
 def train(smpl_layer, target,
           logger, writer, device,
-          args, cfg):
+          args, cfg, meters):
     res = []
     smpl_layer, params, target, optimizer, index = \
         init(smpl_layer, target, device, cfg)
@@ -71,20 +51,19 @@ def train(smpl_layer, target,
     shape_params = params["shape_params"]
     scale = params["scale"]
 
-    early_stop = Early_Stop()
     for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
     # for epoch in range(cfg.TRAIN.MAX_EPOCH):
         verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
-        loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100 * scale,
-                                target.index_select(1, index["dataset_index"]) * 100)
+        loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100,
+                                target.index_select(1, index["dataset_index"]) * 100 * scale)
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
 
-        update_res, stop = early_stop.update(float(loss))
-        if update_res:
+        meters.update_early_stop(float(loss))
+        if meters.update_res:
             res = [pose_params, shape_params, verts, Jtr]
-        if stop:
+        if meters.early_stop:
             logger.info("Early stop at epoch {} !".format(epoch))
             break
 
@@ -95,6 +74,6 @@ def train(smpl_layer, target,
         writer.add_scalar('learning_rate', float(
             optimizer.state_dict()['param_groups'][0]['lr']), epoch)
 
-    logger.info('Train ended, min_loss = {:.9f}'.format(
-        float(early_stop.min_loss)))
+    logger.info('Train ended, min_loss = {:.4f}'.format(
+        float(meters.min_loss)))
     return res
@@ -3,7 +3,8 @@ import numpy as np
 rotate = {
     'HumanAct12': [1., -1., -1.],
     'CMU_Mocap': [0.05, 0.05, 0.05],
-    'UTD_MHAD': [-1., 1., -1.]
+    'UTD_MHAD': [-1., 1., -1.],
+    'Human3.6M': [-0.001, -0.001, 0.001]
 }
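The new `Human3.6M` entry follows the same pattern as the others: three per-axis multipliers that flip signs and rescale units (Human3.6M joints are stored in millimetres). The body of `transform()` is not shown in this diff, so the following is only a hedged sketch of how such factors are typically applied, not the repo's implementation:

```python
import numpy as np

# Assumed usage: element-wise per-axis scaling of raw (x, y, z) joint coordinates.
rotate = {'Human3.6M': [-0.001, -0.001, 0.001]}

def apply_axis_factors(joints, dataset_name):
    # joints: (n_frames, n_joints, 3) array straight from load()
    return joints * np.asarray(rotate[dataset_name], dtype=joints.dtype)

demo = apply_axis_factors(np.ones((2, 17, 3), dtype=np.float32) * 1000.0, 'Human3.6M')
# demo is now in metres, with the x and y axes mirrored
```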
@@ -1,7 +1,7 @@
 import matplotlib.pyplot as plt
 import imageio, os
 images = []
-filenames = sorted(fn for fn in os.listdir('./fit/output/CMU_Mocap/picture/fit/01_01') )
+filenames = sorted(fn for fn in os.listdir('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02') )
 for filename in filenames:
-    images.append(imageio.imread('./fit/output/CMU_Mocap/picture/fit/01_01/'+filename))
-imageio.mimsave('fit.gif', images, duration=0.2)
+    images.append(imageio.imread('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02/'+filename))
+imageio.mimsave('fit_mesh.gif', images, duration=0.2)