update meters
This commit is contained in:
@ -40,11 +40,12 @@ The SMPL human body layer for Pytorch is from the [smplpytorch](https://github.c
|
|||||||
|
|
||||||
- Download the datasets you want to fit
|
- Download the datasets you want to fit
|
||||||
|
|
||||||
currently supported datasets:
|
currently support:
|
||||||
|
|
||||||
- [HumanAct12](https://ericguo5513.github.io/action-to-motion/)
|
- [HumanAct12](https://ericguo5513.github.io/action-to-motion/)
|
||||||
- [CMU Mocap](https://ericguo5513.github.io/action-to-motion/)
|
- [CMU Mocap](https://ericguo5513.github.io/action-to-motion/)
|
||||||
- [UTD-MHAD](https://personal.utdallas.edu/~kehtar/UTD-MHAD.html)
|
- [UTD-MHAD](https://personal.utdallas.edu/~kehtar/UTD-MHAD.html)
|
||||||
|
- [Human3.6M](http://vision.imar.ro/human3.6m/description.php)
|
||||||
|
|
||||||
- Set the **DATASET.PATH** in the corresponding configuration file to the location of dataset.
|
- Set the **DATASET.PATH** in the corresponding configuration file to the location of dataset.
|
||||||
|
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import os
|
|||||||
import json
|
import json
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description='Detect cross joints')
|
parser = argparse.ArgumentParser(description='Detect cross joints')
|
||||||
parser.add_argument('--dataset_name', dest='dataset_name',
|
parser.add_argument('--dataset_name', dest='dataset_name',
|
||||||
@ -15,10 +16,12 @@ def parse_args():
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def create_dir_not_exist(path):
|
def create_dir_not_exist(path):
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
os.mkdir(path)
|
os.mkdir(path)
|
||||||
|
|
||||||
|
|
||||||
def load_Jtr(file_path):
|
def load_Jtr(file_path):
|
||||||
with open(file_path, 'rb') as f:
|
with open(file_path, 'rb') as f:
|
||||||
data = pickle.load(f)
|
data = pickle.load(f)
|
||||||
@ -48,7 +51,10 @@ def cross_detector(dir_path):
|
|||||||
ans[file] = cross_frames(Jtr)
|
ans[file] = cross_frames(Jtr)
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
d = cross_detector(args.output_path)
|
d = cross_detector(args.output_path)
|
||||||
json.dump(d,open("./fit/output/cross_detection/{}.json".format(args.dataset_name),'w'))
|
json.dump(
|
||||||
|
d, open("./fit/output/cross_detection/{}.json"
|
||||||
|
.format(args.dataset_name), 'w'))
|
||||||
@ -1166,7 +1166,4 @@ def get_label(file_name, dataset_name):
|
|||||||
return UTD_MHAD[key]
|
return UTD_MHAD[key]
|
||||||
elif dataset_name == 'CMU_Mocap':
|
elif dataset_name == 'CMU_Mocap':
|
||||||
key = file_name.split('.')[0]
|
key = file_name.split('.')[0]
|
||||||
if key in CMU_Mocap.keys():
|
return CMU_Mocap[key] if key in CMU_Mocap.keys() else ""
|
||||||
return CMU_Mocap[key]
|
|
||||||
else:
|
|
||||||
return ""
|
|
||||||
@ -1,11 +1,11 @@
|
|||||||
import scipy.io
|
import scipy.io
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
def load(name, path):
|
def load(name, path):
|
||||||
if name == 'UTD_MHAD':
|
if name == 'UTD_MHAD':
|
||||||
data = scipy.io.loadmat(path)
|
arr = scipy.io.loadmat(path)['d_skel']
|
||||||
arr = data['d_skel']
|
|
||||||
new_arr = np.zeros([arr.shape[2], arr.shape[0], arr.shape[1]])
|
new_arr = np.zeros([arr.shape[2], arr.shape[0], arr.shape[1]])
|
||||||
for i in range(arr.shape[2]):
|
for i in range(arr.shape[2]):
|
||||||
for j in range(arr.shape[0]):
|
for j in range(arr.shape[0]):
|
||||||
@ -16,4 +16,6 @@ def load(name, path):
|
|||||||
return np.load(path, allow_pickle=True)
|
return np.load(path, allow_pickle=True)
|
||||||
elif name == "CMU_Mocap":
|
elif name == "CMU_Mocap":
|
||||||
return np.load(path, allow_pickle=True)
|
return np.load(path, allow_pickle=True)
|
||||||
|
elif name == "Human3.6M":
|
||||||
|
return np.load(path, allow_pickle=True)
|
||||||
|
|
||||||
|
|||||||
@ -9,14 +9,17 @@ import logging
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
|
||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
from load import load
|
from load import load
|
||||||
from save import save_pic, save_params
|
from save import save_pic, save_params
|
||||||
from transform import transform
|
from transform import transform
|
||||||
from train import train
|
from train import train
|
||||||
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
||||||
torch.backends.cudnn.benchmark = True
|
from meters import Meters
|
||||||
|
|
||||||
|
torch.backends.cudnn.enabled = True
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description='Fit SMPL')
|
parser = argparse.ArgumentParser(description='Fit SMPL')
|
||||||
@ -95,6 +98,7 @@ if __name__ == "__main__":
|
|||||||
gender=cfg.MODEL.GENDER,
|
gender=cfg.MODEL.GENDER,
|
||||||
model_root='smplpytorch/native/models')
|
model_root='smplpytorch/native/models')
|
||||||
|
|
||||||
|
meters=Meters()
|
||||||
file_num = 0
|
file_num = 0
|
||||||
for root, dirs, files in os.walk(cfg.DATASET.PATH):
|
for root, dirs, files in os.walk(cfg.DATASET.PATH):
|
||||||
for file in files:
|
for file in files:
|
||||||
@ -102,10 +106,15 @@ if __name__ == "__main__":
|
|||||||
logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
|
logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
|
||||||
target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name,
|
target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name,
|
||||||
os.path.join(root, file)))).float()
|
os.path.join(root, file)))).float()
|
||||||
|
logger.info("target shape:{}".format(target.shape))
|
||||||
res = train(smpl_layer, target,
|
res = train(smpl_layer, target,
|
||||||
logger, writer, device,
|
logger, writer, device,
|
||||||
args, cfg)
|
args, cfg, meters)
|
||||||
|
meters.update_avg(meters.min_loss, k=target.shape[0])
|
||||||
|
meters.reset_early_stop()
|
||||||
|
logger.info("avg_loss:{:.4f}".format(meters.avg))
|
||||||
|
|
||||||
# save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
|
# save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
|
||||||
save_params(res, file, logger, args.dataset_name)
|
save_params(res, file, logger, args.dataset_name)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
logger.info("Fitting finished! Average loss: {:.9f}".format(meters.avg))
|
||||||
|
|||||||
27
fit/tools/meters.py
Normal file
27
fit/tools/meters.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
class Meters:
|
||||||
|
def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
|
||||||
|
self.eps = eps
|
||||||
|
self.stop_threshold = stop_threshold
|
||||||
|
self.avg = 0
|
||||||
|
self.cnt = 0
|
||||||
|
self.reset_early_stop()
|
||||||
|
|
||||||
|
def reset_early_stop(self):
|
||||||
|
self.min_loss = float('inf')
|
||||||
|
self.satis_num = 0
|
||||||
|
self.update_res = True
|
||||||
|
self.early_stop = False
|
||||||
|
|
||||||
|
def update_avg(self, val, k=1):
|
||||||
|
self.avg = self.avg + (val - self.avg) * k / (self.cnt + k)
|
||||||
|
self.cnt += k
|
||||||
|
|
||||||
|
def update_early_stop(self, val):
|
||||||
|
delta = (val - self.min_loss) / self.min_loss
|
||||||
|
if float(val) < self.min_loss:
|
||||||
|
self.min_loss = float(val)
|
||||||
|
self.update_res = True
|
||||||
|
else:
|
||||||
|
self.update_res = False
|
||||||
|
self.satis_num = self.satis_num + 1 if delta >= self.eps else 0
|
||||||
|
self.early_stop = self.satis_num >= self.stop_threshold
|
||||||
@ -19,9 +19,9 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
|
|||||||
_, _, verts, Jtr = res
|
_, _, verts, Jtr = res
|
||||||
file_name = re.split('[/.]', file)[-2]
|
file_name = re.split('[/.]', file)[-2]
|
||||||
fit_path = "fit/output/{}/picture/fit/{}".format(dataset_name, file_name)
|
fit_path = "fit/output/{}/picture/fit/{}".format(dataset_name, file_name)
|
||||||
gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name, file_name)
|
# gt_path = "fit/output/{}/picture/gt/{}".format(dataset_name, file_name)
|
||||||
create_dir_not_exist(fit_path)
|
create_dir_not_exist(fit_path)
|
||||||
create_dir_not_exist(gt_path)
|
# create_dir_not_exist(gt_path)
|
||||||
logger.info('Saving pictures at {}'.format(fit_path))
|
logger.info('Saving pictures at {}'.format(fit_path))
|
||||||
for i in tqdm(range(Jtr.shape[0])):
|
for i in tqdm(range(Jtr.shape[0])):
|
||||||
display_model(
|
display_model(
|
||||||
@ -32,7 +32,7 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
|
|||||||
kintree_table=smpl_layer.kintree_table,
|
kintree_table=smpl_layer.kintree_table,
|
||||||
savepath=os.path.join(fit_path+"/frame_{}".format(i)),
|
savepath=os.path.join(fit_path+"/frame_{}".format(i)),
|
||||||
batch_idx=i,
|
batch_idx=i,
|
||||||
show=False,
|
show=True,
|
||||||
only_joint=True)
|
only_joint=True)
|
||||||
# display_model(
|
# display_model(
|
||||||
# {'verts': verts.cpu().detach(),
|
# {'verts': verts.cpu().detach(),
|
||||||
@ -59,7 +59,7 @@ def save_params(res, file, logger, dataset_name):
|
|||||||
Jtr = (Jtr.cpu().detach()).numpy().tolist()
|
Jtr = (Jtr.cpu().detach()).numpy().tolist()
|
||||||
verts = (verts.cpu().detach()).numpy().tolist()
|
verts = (verts.cpu().detach()).numpy().tolist()
|
||||||
params = {}
|
params = {}
|
||||||
params["label"] = label
|
# params["label"] = label
|
||||||
params["pose_params"] = pose_params
|
params["pose_params"] = pose_params
|
||||||
params["shape_params"] = shape_params
|
params["shape_params"] = shape_params
|
||||||
params["Jtr"] = Jtr
|
params["Jtr"] = Jtr
|
||||||
|
|||||||
@ -8,31 +8,11 @@ from tqdm import tqdm
|
|||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
|
|
||||||
|
|
||||||
class Early_Stop:
|
|
||||||
def __init__(self, eps=-1e-3, stop_threshold=10) -> None:
|
|
||||||
self.min_loss = float('inf')
|
|
||||||
self.eps = eps
|
|
||||||
self.stop_threshold = stop_threshold
|
|
||||||
self.satis_num = 0
|
|
||||||
|
|
||||||
def update(self, loss):
|
|
||||||
delta = (loss - self.min_loss) / self.min_loss
|
|
||||||
if float(loss) < self.min_loss:
|
|
||||||
self.min_loss = float(loss)
|
|
||||||
update_res = True
|
|
||||||
else:
|
|
||||||
update_res = False
|
|
||||||
if delta >= self.eps:
|
|
||||||
self.satis_num += 1
|
|
||||||
else:
|
|
||||||
self.satis_num = 0
|
|
||||||
return update_res, self.satis_num >= self.stop_threshold
|
|
||||||
|
|
||||||
|
|
||||||
def init(smpl_layer, target, device, cfg):
|
def init(smpl_layer, target, device, cfg):
|
||||||
params = {}
|
params = {}
|
||||||
params["pose_params"] = torch.rand(target.shape[0], 72) * 0.0
|
params["pose_params"] = torch.zeros(target.shape[0], 72)
|
||||||
params["shape_params"] = torch.rand(target.shape[0], 10) * 0.03
|
params["shape_params"] = torch.zeros(target.shape[0], 10)
|
||||||
params["scale"] = torch.ones([1])
|
params["scale"] = torch.ones([1])
|
||||||
|
|
||||||
smpl_layer = smpl_layer.to(device)
|
smpl_layer = smpl_layer.to(device)
|
||||||
@ -63,7 +43,7 @@ def init(smpl_layer, target, device, cfg):
|
|||||||
|
|
||||||
def train(smpl_layer, target,
|
def train(smpl_layer, target,
|
||||||
logger, writer, device,
|
logger, writer, device,
|
||||||
args, cfg):
|
args, cfg, meters):
|
||||||
res = []
|
res = []
|
||||||
smpl_layer, params, target, optimizer, index = \
|
smpl_layer, params, target, optimizer, index = \
|
||||||
init(smpl_layer, target, device, cfg)
|
init(smpl_layer, target, device, cfg)
|
||||||
@ -71,20 +51,19 @@ def train(smpl_layer, target,
|
|||||||
shape_params = params["shape_params"]
|
shape_params = params["shape_params"]
|
||||||
scale = params["scale"]
|
scale = params["scale"]
|
||||||
|
|
||||||
early_stop = Early_Stop()
|
|
||||||
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
|
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
|
||||||
# for epoch in range(cfg.TRAIN.MAX_EPOCH):
|
# for epoch in range(cfg.TRAIN.MAX_EPOCH):
|
||||||
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
|
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
|
||||||
loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100 * scale,
|
loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100,
|
||||||
target.index_select(1, index["dataset_index"]) * 100)
|
target.index_select(1, index["dataset_index"]) * 100 * scale)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
update_res, stop = early_stop.update(float(loss))
|
meters.update_early_stop(float(loss))
|
||||||
if update_res:
|
if meters.update_res:
|
||||||
res = [pose_params, shape_params, verts, Jtr]
|
res = [pose_params, shape_params, verts, Jtr]
|
||||||
if stop:
|
if meters.early_stop:
|
||||||
logger.info("Early stop at epoch {} !".format(epoch))
|
logger.info("Early stop at epoch {} !".format(epoch))
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -95,6 +74,6 @@ def train(smpl_layer, target,
|
|||||||
writer.add_scalar('learning_rate', float(
|
writer.add_scalar('learning_rate', float(
|
||||||
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
|
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
|
||||||
|
|
||||||
logger.info('Train ended, min_loss = {:.9f}'.format(
|
logger.info('Train ended, min_loss = {:.4f}'.format(
|
||||||
float(early_stop.min_loss)))
|
float(meters.min_loss)))
|
||||||
return res
|
return res
|
||||||
|
|||||||
@ -3,7 +3,8 @@ import numpy as np
|
|||||||
rotate = {
|
rotate = {
|
||||||
'HumanAct12': [1., -1., -1.],
|
'HumanAct12': [1., -1., -1.],
|
||||||
'CMU_Mocap': [0.05, 0.05, 0.05],
|
'CMU_Mocap': [0.05, 0.05, 0.05],
|
||||||
'UTD_MHAD': [-1., 1., -1.]
|
'UTD_MHAD': [-1., 1., -1.],
|
||||||
|
'Human3.6M': [-0.001, -0.001, 0.001]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import imageio, os
|
import imageio, os
|
||||||
images = []
|
images = []
|
||||||
filenames = sorted(fn for fn in os.listdir('./fit/output/CMU_Mocap/picture/fit/01_01') )
|
filenames = sorted(fn for fn in os.listdir('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02') )
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
images.append(imageio.imread('./fit/output/CMU_Mocap/picture/fit/01_01/'+filename))
|
images.append(imageio.imread('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02/'+filename))
|
||||||
imageio.mimsave('fit.gif', images, duration=0.2)
|
imageio.mimsave('fit_mesh.gif', images, duration=0.2)
|
||||||
Reference in New Issue
Block a user