Support HAA4D
This commit is contained in:
@ -47,6 +47,7 @@ The SMPL human body layer for Pytorch is from the [smplpytorch](https://github.c
|
|||||||
- [UTD-MHAD](https://personal.utdallas.edu/~kehtar/UTD-MHAD.html)
|
- [UTD-MHAD](https://personal.utdallas.edu/~kehtar/UTD-MHAD.html)
|
||||||
- [Human3.6M](http://vision.imar.ro/human3.6m/description.php)
|
- [Human3.6M](http://vision.imar.ro/human3.6m/description.php)
|
||||||
- [NTU](https://rose1.ntu.edu.sg/dataset/actionRecognition/)
|
- [NTU](https://rose1.ntu.edu.sg/dataset/actionRecognition/)
|
||||||
|
- [HAA4D](https://cse.hkust.edu.hk/haa4d/dataset.html)
|
||||||
|
|
||||||
- Set the **DATASET.PATH** in the corresponding configuration file to the location of dataset.
|
- Set the **DATASET.PATH** in the corresponding configuration file to the location of dataset.
|
||||||
|
|
||||||
|
|||||||
5
demo.py
5
demo.py
@ -1,4 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
||||||
from display_utils import display_model
|
from display_utils import display_model
|
||||||
@ -15,7 +17,7 @@ if __name__ == '__main__':
|
|||||||
model_root='smplpytorch/native/models')
|
model_root='smplpytorch/native/models')
|
||||||
|
|
||||||
# Generate random pose and shape parameters
|
# Generate random pose and shape parameters
|
||||||
pose_params = torch.rand(batch_size, 72) * 0.2
|
pose_params = torch.rand(batch_size, 72) * 0.01
|
||||||
shape_params = torch.rand(batch_size, 10) * 0.03
|
shape_params = torch.rand(batch_size, 10) * 0.03
|
||||||
|
|
||||||
# GPU mode
|
# GPU mode
|
||||||
@ -26,7 +28,6 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
# Forward from the SMPL layer
|
# Forward from the SMPL layer
|
||||||
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
|
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
|
||||||
print(Jtr)
|
|
||||||
|
|
||||||
# Draw output vertices and joints
|
# Draw output vertices and joints
|
||||||
display_model(
|
display_model(
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from xml.parsers.expat import model
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from mpl_toolkits.mplot3d import Axes3D
|
from mpl_toolkits.mplot3d import Axes3D
|
||||||
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
|
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
|
||||||
@ -21,7 +22,8 @@ def display_model(
|
|||||||
if ax is None:
|
if ax is None:
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
ax = fig.add_subplot(111, projection='3d')
|
ax = fig.add_subplot(111, projection='3d')
|
||||||
verts, joints = model_info['verts'][batch_idx], model_info['joints'][batch_idx]
|
verts = model_info['verts'][batch_idx]
|
||||||
|
joints = model_info['joints'][batch_idx]
|
||||||
if model_faces is None:
|
if model_faces is None:
|
||||||
ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2)
|
ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.2)
|
||||||
elif not only_joint:
|
elif not only_joint:
|
||||||
|
|||||||
24
fit/configs/HAA4D.json
Normal file
24
fit/configs/HAA4D.json
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
{
|
||||||
|
"MODEL": {
|
||||||
|
"GENDER": "neutral"
|
||||||
|
},
|
||||||
|
"TRAIN": {
|
||||||
|
"LEARNING_RATE": 1e-2,
|
||||||
|
"MAX_EPOCH": 1000,
|
||||||
|
"WRITE": 10,
|
||||||
|
"OPTIMIZE_SCALE":1,
|
||||||
|
"OPTIMIZE_SHAPE":1
|
||||||
|
},
|
||||||
|
"USE_GPU": 1,
|
||||||
|
"DATASET": {
|
||||||
|
"NAME": "NTU",
|
||||||
|
"PATH": "../NTU RGB+D/result",
|
||||||
|
"TARGET_PATH": "",
|
||||||
|
"DATA_MAP": [
|
||||||
|
[0,0],[1,4],[2,1],[4,5],[5,2],[7,6],[8,3],
|
||||||
|
[12,9],[18,12],[19,15],[20,13],[21,16],
|
||||||
|
[15,10],[6,1]
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"DEBUG": 0
|
||||||
|
}
|
||||||
@ -20,19 +20,19 @@
|
|||||||
0
|
0
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
1,
|
2,
|
||||||
12
|
12
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
2,
|
1,
|
||||||
16
|
16
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
4,
|
5,
|
||||||
13
|
13
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
5,
|
4,
|
||||||
17
|
17
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
@ -40,52 +40,51 @@
|
|||||||
1
|
1
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
7,
|
8,
|
||||||
14
|
14
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
8,
|
7,
|
||||||
18
|
18
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
9,
|
9,
|
||||||
20
|
20
|
||||||
],
|
],
|
||||||
|
|
||||||
[
|
[
|
||||||
12,
|
12,
|
||||||
2
|
2
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
13,
|
14,
|
||||||
4
|
4
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
14,
|
13,
|
||||||
8
|
8
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
18,
|
19,
|
||||||
5
|
5
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
19,
|
18,
|
||||||
9
|
9
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
20,
|
21,
|
||||||
6
|
6
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
21,
|
20,
|
||||||
10
|
10
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
22,
|
23,
|
||||||
22
|
22
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
23,
|
22,
|
||||||
24
|
24
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1156,127 +1156,127 @@ CMU_Mocap = {
|
|||||||
"41_09": "Climb"
|
"41_09": "Climb"
|
||||||
}
|
}
|
||||||
|
|
||||||
NTU={
|
NTU = {
|
||||||
"1":" drink water",
|
"1":"drink water",
|
||||||
"2":" eat meal/snack",
|
"2":"eat meal/snack",
|
||||||
"3":" brushing teeth",
|
"3":"brushing teeth",
|
||||||
"4":" brushing hair",
|
"4":"brushing hair",
|
||||||
"5":" drop",
|
"5":"drop",
|
||||||
"6":" pickup",
|
"6":"pickup",
|
||||||
"7":" throw",
|
"7":"throw",
|
||||||
"8":" sitting down",
|
"8":"sitting down",
|
||||||
"9":" standing up (from sitting position)",
|
"9":"standing up (from sitting position)",
|
||||||
"10":" clapping",
|
"10":"clapping",
|
||||||
"11":" reading",
|
"11":"reading",
|
||||||
"12":" writing",
|
"12":"writing",
|
||||||
"13":" tear up paper",
|
"13":"tear up paper",
|
||||||
"14":" wear jacket",
|
"14":"wear jacket",
|
||||||
"15":" take off jacket",
|
"15":"take off jacket",
|
||||||
"16":" wear a shoe",
|
"16":"wear a shoe",
|
||||||
"17":" take off a shoe",
|
"17":"take off a shoe",
|
||||||
"18":" wear on glasses",
|
"18":"wear on glasses",
|
||||||
"19":" take off glasses",
|
"19":"take off glasses",
|
||||||
"20":" put on a hat/cap",
|
"20":"put on a hat/cap",
|
||||||
"21":" take off a hat/cap",
|
"21":"take off a hat/cap",
|
||||||
"22":" cheer up",
|
"22":"cheer up",
|
||||||
"23":" hand waving",
|
"23":"hand waving",
|
||||||
"24":" kicking something",
|
"24":"kicking something",
|
||||||
"25":" reach into pocket",
|
"25":"reach into pocket",
|
||||||
"26":" hopping (one foot jumping)",
|
"26":"hopping (one foot jumping)",
|
||||||
"27":" jump up",
|
"27":"jump up",
|
||||||
"28":" make a phone call/answer phone",
|
"28":"make a phone call/answer phone",
|
||||||
"29":" playing with phone/tablet",
|
"29":"playing with phone/tablet",
|
||||||
"30":" typing on a keyboard",
|
"30":"typing on a keyboard",
|
||||||
"31":" pointing to something with finger",
|
"31":"pointing to something with finger",
|
||||||
"32":" taking a selfie",
|
"32":"taking a selfie",
|
||||||
"33":" check time (from watch)",
|
"33":"check time (from watch)",
|
||||||
"34":" rub two hands together",
|
"34":"rub two hands together",
|
||||||
"35":" nod head/bow",
|
"35":"nod head/bow",
|
||||||
"36":" shake head",
|
"36":"shake head",
|
||||||
"37":" wipe face",
|
"37":"wipe face",
|
||||||
"38":" salute",
|
"38":"salute",
|
||||||
"39":" put the palms together",
|
"39":"put the palms together",
|
||||||
"40":" cross hands in front (say stop)",
|
"40":"cross hands in front (say stop)",
|
||||||
"41":" sneeze/cough",
|
"41":"sneeze/cough",
|
||||||
"42":" staggering",
|
"42":"staggering",
|
||||||
"43":" falling",
|
"43":"falling",
|
||||||
"44":" touch head (headache)",
|
"44":"touch head (headache)",
|
||||||
"45":" touch chest (stomachache/heart pain)",
|
"45":"touch chest (stomachache/heart pain)",
|
||||||
"46":" touch back (backache)",
|
"46":"touch back (backache)",
|
||||||
"47":" touch neck (neckache)",
|
"47":"touch neck (neckache)",
|
||||||
"48":" nausea or vomiting condition",
|
"48":"nausea or vomiting condition",
|
||||||
"49":" use a fan (with hand or paper)/feeling warm",
|
"49":"use a fan (with hand or paper)/feeling warm",
|
||||||
"50":" punching/slapping other person",
|
"50":"punching/slapping other person",
|
||||||
"51":" kicking other person",
|
"51":"kicking other person",
|
||||||
"52":" pushing other person",
|
"52":"pushing other person",
|
||||||
"53":" pat on back of other person",
|
"53":"pat on back of other person",
|
||||||
"54":" point finger at the other person",
|
"54":"point finger at the other person",
|
||||||
"55":" hugging other person",
|
"55":"hugging other person",
|
||||||
"56":" giving something to other person",
|
"56":"giving something to other person",
|
||||||
"57":" touch other person's pocket",
|
"57":"touch other person's pocket",
|
||||||
"58":" handshaking",
|
"58":"handshaking",
|
||||||
"59":" walking towards each other",
|
"59":"walking towards each other",
|
||||||
"60":" walking apart from each other",
|
"60":"walking apart from each other",
|
||||||
"61":" put on headphone",
|
"61":"put on headphone",
|
||||||
"62":" take off headphone",
|
"62":"take off headphone",
|
||||||
"63":" shoot at the basket",
|
"63":"shoot at the basket",
|
||||||
"64":" bounce ball",
|
"64":"bounce ball",
|
||||||
"65":" tennis bat swing",
|
"65":"tennis bat swing",
|
||||||
"66":" juggling table tennis balls",
|
"66":"juggling table tennis balls",
|
||||||
"67":" hush (quite)",
|
"67":"hush (quite)",
|
||||||
"68":" flick hair",
|
"68":"flick hair",
|
||||||
"69":" thumb up",
|
"69":"thumb up",
|
||||||
"70":" thumb down",
|
"70":"thumb down",
|
||||||
"71":" make ok sign",
|
"71":"make ok sign",
|
||||||
"72":" make victory sign",
|
"72":"make victory sign",
|
||||||
"73":" staple book",
|
"73":"staple book",
|
||||||
"74":" counting money",
|
"74":"counting money",
|
||||||
"75":" cutting nails",
|
"75":"cutting nails",
|
||||||
"76":" cutting paper (using scissors)",
|
"76":"cutting paper (using scissors)",
|
||||||
"77":" snapping fingers",
|
"77":"snapping fingers",
|
||||||
"78":" open bottle",
|
"78":"open bottle",
|
||||||
"79":" sniff (smell)",
|
"79":"sniff (smell)",
|
||||||
"80":" squat down",
|
"80":"squat down",
|
||||||
"81":" toss a coin",
|
"81":"toss a coin",
|
||||||
"82":" fold paper",
|
"82":"fold paper",
|
||||||
"83":" ball up paper",
|
"83":"ball up paper",
|
||||||
"84":" play magic cube",
|
"84":"play magic cube",
|
||||||
"85":" apply cream on face",
|
"85":"apply cream on face",
|
||||||
"86":" apply cream on hand back",
|
"86":"apply cream on hand back",
|
||||||
"87":" put on bag",
|
"87":"put on bag",
|
||||||
"88":" take off bag",
|
"88":"take off bag",
|
||||||
"89":" put something into a bag",
|
"89":"put something into a bag",
|
||||||
"90":" take something out of a bag",
|
"90":"take something out of a bag",
|
||||||
"91":" open a box",
|
"91":"open a box",
|
||||||
"92":" move heavy objects",
|
"92":"move heavy objects",
|
||||||
"93":" shake fist",
|
"93":"shake fist",
|
||||||
"94":" throw up cap/hat",
|
"94":"throw up cap/hat",
|
||||||
"95":" hands up (both hands)",
|
"95":"hands up (both hands)",
|
||||||
"96":" cross arms",
|
"96":"cross arms",
|
||||||
"97":" arm circles",
|
"97":"arm circles",
|
||||||
"98":" arm swings",
|
"98":"arm swings",
|
||||||
"99":" running on the spot",
|
"99":"running on the spot",
|
||||||
"100":" butt kicks (kick backward)",
|
"100":"butt kicks (kick backward)",
|
||||||
"101":" cross toe touch",
|
"101":"cross toe touch",
|
||||||
"102":" side kick",
|
"102":"side kick",
|
||||||
"103":" yawn",
|
"103":"yawn",
|
||||||
"104":" stretch oneself",
|
"104":"stretch oneself",
|
||||||
"105":" blow nose",
|
"105":"blow nose",
|
||||||
"106":" hit other person with something",
|
"106":"hit other person with something",
|
||||||
"107":" wield knife towards other person",
|
"107":"wield knife towards other person",
|
||||||
"108":" knock over other person (hit with body)",
|
"108":"knock over other person (hit with body)",
|
||||||
"109":" grab other person’s stuff",
|
"109":"grab other person’s stuff",
|
||||||
"110":" shoot at other person with a gun",
|
"110":"shoot at other person with a gun",
|
||||||
"111":" step on foot",
|
"111":"step on foot",
|
||||||
"112":" high-five",
|
"112":"high-five",
|
||||||
"113":" cheers and drink",
|
"113":"cheers and drink",
|
||||||
"114":" carry something with other person",
|
"114":"carry something with other person",
|
||||||
"115":" take a photo of other person",
|
"115":"take a photo of other person",
|
||||||
"116":" follow other person",
|
"116":"follow other person",
|
||||||
"117":" whisper in other person’s ear",
|
"117":"whisper in other person’s ear",
|
||||||
"118":" exchange things with other person",
|
"118":"exchange things with other person",
|
||||||
"119":" support somebody with hand",
|
"119":"support somebody with hand",
|
||||||
"120":" finger-guessing game (playing rock-paper-scissors)",
|
"120":"finger-guessing game (playing rock-paper-scissors)",
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_label(file_name, dataset_name):
|
def get_label(file_name, dataset_name):
|
||||||
|
|||||||
@ -19,5 +19,7 @@ def load(name, path):
|
|||||||
elif name == "Human3.6M":
|
elif name == "Human3.6M":
|
||||||
return np.load(path, allow_pickle=True)[0::5] # down_sample
|
return np.load(path, allow_pickle=True)[0::5] # down_sample
|
||||||
elif name == "NTU":
|
elif name == "NTU":
|
||||||
|
return np.load(path, allow_pickle=True)[0::2]
|
||||||
|
elif name == "HAA4D":
|
||||||
return np.load(path, allow_pickle=True)
|
return np.load(path, allow_pickle=True)
|
||||||
|
|
||||||
|
|||||||
@ -1,32 +1,33 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
sys.path.append(os.getcwd())
|
||||||
|
from meters import Meters
|
||||||
|
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
||||||
|
from train import train
|
||||||
|
from transform import transform
|
||||||
|
from save import save_pic, save_params
|
||||||
|
from load import load
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
from easydict import EasyDict as edict
|
from easydict import EasyDict as edict
|
||||||
import time
|
import time
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
|
||||||
sys.path.append(os.getcwd())
|
|
||||||
from load import load
|
|
||||||
from save import save_pic, save_params
|
|
||||||
from transform import transform
|
|
||||||
from train import train
|
|
||||||
from smplpytorch.pytorch.smpl_layer import SMPL_Layer
|
|
||||||
from meters import Meters
|
|
||||||
|
|
||||||
torch.backends.cudnn.enabled = True
|
torch.backends.cudnn.enabled = True
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description='Fit SMPL')
|
parser = argparse.ArgumentParser(description='Fit SMPL')
|
||||||
parser.add_argument('--exp', dest='exp',
|
parser.add_argument('--exp', dest='exp',
|
||||||
help='Define exp name',
|
help='Define exp name',
|
||||||
default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
|
default=time.strftime('%Y-%m-%d %H-%M-%S', time.localtime(time.time())), type=str)
|
||||||
parser.add_argument('--dataset_name', dest='dataset_name',
|
parser.add_argument('--dataset_name', '-n', dest='dataset_name',
|
||||||
help='select dataset',
|
help='select dataset',
|
||||||
default='', type=str)
|
default='', type=str)
|
||||||
parser.add_argument('--dataset_path', dest='dataset_path',
|
parser.add_argument('--dataset_path', dest='dataset_path',
|
||||||
@ -98,13 +99,16 @@ if __name__ == "__main__":
|
|||||||
gender=cfg.MODEL.GENDER,
|
gender=cfg.MODEL.GENDER,
|
||||||
model_root='smplpytorch/native/models')
|
model_root='smplpytorch/native/models')
|
||||||
|
|
||||||
meters=Meters()
|
meters = Meters()
|
||||||
file_num = 0
|
file_num = 0
|
||||||
for root, dirs, files in os.walk(cfg.DATASET.PATH):
|
for root, dirs, files in os.walk(cfg.DATASET.PATH):
|
||||||
for file in files:
|
for file in sorted(files):
|
||||||
|
if not 'baseball_swing' in file:
|
||||||
|
continue
|
||||||
file_num += 1
|
file_num += 1
|
||||||
logger.info('Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
|
logger.info(
|
||||||
target = torch.from_numpy(transform(args.dataset_name,load(args.dataset_name,
|
'Processing file: {} [{} / {}]'.format(file, file_num, len(files)))
|
||||||
|
target = torch.from_numpy(transform(args.dataset_name, load(args.dataset_name,
|
||||||
os.path.join(root, file)))).float()
|
os.path.join(root, file)))).float()
|
||||||
logger.info("target shape:{}".format(target.shape))
|
logger.info("target shape:{}".format(target.shape))
|
||||||
res = train(smpl_layer, target,
|
res = train(smpl_layer, target,
|
||||||
@ -115,7 +119,8 @@ if __name__ == "__main__":
|
|||||||
logger.info("avg_loss:{:.4f}".format(meters.avg))
|
logger.info("avg_loss:{:.4f}".format(meters.avg))
|
||||||
|
|
||||||
save_params(res, file, logger, args.dataset_name)
|
save_params(res, file, logger, args.dataset_name)
|
||||||
# save_pic(res,smpl_layer,file,logger,args.dataset_name,target)
|
save_pic(res, smpl_layer, file, logger, args.dataset_name, target)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
logger.info("Fitting finished! Average loss: {:.9f}".format(meters.avg))
|
logger.info(
|
||||||
|
"Fitting finished! Average loss: {:.9f}".format(meters.avg))
|
||||||
|
|||||||
@ -19,7 +19,7 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
|
|||||||
_, _, verts, Jtr = res
|
_, _, verts, Jtr = res
|
||||||
file_name = re.split('[/.]', file)[-2]
|
file_name = re.split('[/.]', file)[-2]
|
||||||
fit_path = "fit/output/{}/picture/{}".format(dataset_name, file_name)
|
fit_path = "fit/output/{}/picture/{}".format(dataset_name, file_name)
|
||||||
create_dir_not_exist(fit_path)
|
os.makedirs(fit_path,exist_ok=True)
|
||||||
logger.info('Saving pictures at {}'.format(fit_path))
|
logger.info('Saving pictures at {}'.format(fit_path))
|
||||||
for i in tqdm(range(Jtr.shape[0])):
|
for i in tqdm(range(Jtr.shape[0])):
|
||||||
display_model(
|
display_model(
|
||||||
@ -28,7 +28,7 @@ def save_pic(res, smpl_layer, file, logger, dataset_name, target):
|
|||||||
model_faces=smpl_layer.th_faces,
|
model_faces=smpl_layer.th_faces,
|
||||||
with_joints=True,
|
with_joints=True,
|
||||||
kintree_table=smpl_layer.kintree_table,
|
kintree_table=smpl_layer.kintree_table,
|
||||||
savepath=os.path.join(fit_path+"/frame_{}".format(i)),
|
savepath=os.path.join(fit_path+"/frame_{:0>4d}".format(i)),
|
||||||
batch_idx=i,
|
batch_idx=i,
|
||||||
show=False,
|
show=False,
|
||||||
only_joint=True)
|
only_joint=True)
|
||||||
@ -55,3 +55,21 @@ def save_params(res, file, logger, dataset_name):
|
|||||||
with open(os.path.join((fit_path),
|
with open(os.path.join((fit_path),
|
||||||
"{}_params.pkl".format(file_name)), 'wb') as f:
|
"{}_params.pkl".format(file_name)), 'wb') as f:
|
||||||
pickle.dump(params, f)
|
pickle.dump(params, f)
|
||||||
|
|
||||||
|
|
||||||
|
def save_single_pic(res, smpl_layer, epoch, logger, dataset_name, target):
|
||||||
|
_, _, verts, Jtr = res
|
||||||
|
fit_path = "fit/output/{}/picture".format(dataset_name)
|
||||||
|
create_dir_not_exist(fit_path)
|
||||||
|
logger.info('Saving pictures at {}'.format(fit_path))
|
||||||
|
display_model(
|
||||||
|
{'verts': verts.cpu().detach(),
|
||||||
|
'joints': Jtr.cpu().detach()},
|
||||||
|
model_faces=smpl_layer.th_faces,
|
||||||
|
with_joints=True,
|
||||||
|
kintree_table=smpl_layer.kintree_table,
|
||||||
|
savepath=fit_path+"/epoch_{:0>4d}".format(epoch),
|
||||||
|
batch_idx=60,
|
||||||
|
show=False,
|
||||||
|
only_joint=False)
|
||||||
|
logger.info('Picture saved')
|
||||||
@ -6,6 +6,7 @@ import os
|
|||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
|
from save import save_single_pic
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -25,8 +26,10 @@ def init(smpl_layer, target, device, cfg):
|
|||||||
params["shape_params"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SHAPE)
|
params["shape_params"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SHAPE)
|
||||||
params["scale"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SCALE)
|
params["scale"].requires_grad = bool(cfg.TRAIN.OPTIMIZE_SCALE)
|
||||||
|
|
||||||
optimizer = optim.Adam([params["pose_params"], params["shape_params"], params["scale"]],
|
optim_params = [{'params': params["pose_params"], 'lr': cfg.TRAIN.LEARNING_RATE},
|
||||||
lr=cfg.TRAIN.LEARNING_RATE)
|
{'params': params["shape_params"], 'lr': cfg.TRAIN.LEARNING_RATE},
|
||||||
|
{'params': params["scale"], 'lr': cfg.TRAIN.LEARNING_RATE*10},]
|
||||||
|
optimizer = optim.Adam(optim_params)
|
||||||
|
|
||||||
index = {}
|
index = {}
|
||||||
smpl_index = []
|
smpl_index = []
|
||||||
@ -51,11 +54,14 @@ def train(smpl_layer, target,
|
|||||||
shape_params = params["shape_params"]
|
shape_params = params["shape_params"]
|
||||||
scale = params["scale"]
|
scale = params["scale"]
|
||||||
|
|
||||||
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
|
with torch.no_grad():
|
||||||
# for epoch in range(cfg.TRAIN.MAX_EPOCH):
|
|
||||||
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
|
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
|
||||||
loss = F.smooth_l1_loss(Jtr.index_select(1, index["smpl_index"]) * 100,
|
params["scale"]*=(torch.max(torch.abs(target))/torch.max(torch.abs(Jtr)))
|
||||||
target.index_select(1, index["dataset_index"]) * 100 * scale)
|
|
||||||
|
for epoch in tqdm(range(cfg.TRAIN.MAX_EPOCH)):
|
||||||
|
verts, Jtr = smpl_layer(pose_params, th_betas=shape_params)
|
||||||
|
loss = F.smooth_l1_loss(scale*Jtr.index_select(1, index["smpl_index"]),
|
||||||
|
target.index_select(1, index["dataset_index"]))
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
@ -67,12 +73,15 @@ def train(smpl_layer, target,
|
|||||||
logger.info("Early stop at epoch {} !".format(epoch))
|
logger.info("Early stop at epoch {} !".format(epoch))
|
||||||
break
|
break
|
||||||
|
|
||||||
if epoch % cfg.TRAIN.WRITE == 0:
|
if epoch % cfg.TRAIN.WRITE == 0 or epoch<10:
|
||||||
# logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
|
# logger.info("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
|
||||||
# epoch, float(loss),float(scale)))
|
# epoch, float(loss),float(scale)))
|
||||||
|
print("Epoch {}, lossPerBatch={:.6f}, scale={:.4f}".format(
|
||||||
|
epoch, float(loss),float(scale)))
|
||||||
writer.add_scalar('loss', float(loss), epoch)
|
writer.add_scalar('loss', float(loss), epoch)
|
||||||
writer.add_scalar('learning_rate', float(
|
writer.add_scalar('learning_rate', float(
|
||||||
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
|
optimizer.state_dict()['param_groups'][0]['lr']), epoch)
|
||||||
|
# save_single_pic(res,smpl_layer,epoch,logger,args.dataset_name,target)
|
||||||
|
|
||||||
logger.info('Train ended, min_loss = {:.4f}'.format(
|
logger.info('Train ended, min_loss = {:.4f}'.format(
|
||||||
float(meters.min_loss)))
|
float(meters.min_loss)))
|
||||||
|
|||||||
@ -5,7 +5,8 @@ rotate = {
|
|||||||
'CMU_Mocap': [0.05, 0.05, 0.05],
|
'CMU_Mocap': [0.05, 0.05, 0.05],
|
||||||
'UTD_MHAD': [-1., 1., -1.],
|
'UTD_MHAD': [-1., 1., -1.],
|
||||||
'Human3.6M': [-0.001, -0.001, 0.001],
|
'Human3.6M': [-0.001, -0.001, 0.001],
|
||||||
'NTU': [-1., 1., -1.]
|
'NTU': [1., 1., -1.],
|
||||||
|
'HAA4D': [1., -1., -1.],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import imageio, os
|
import imageio, os
|
||||||
images = []
|
images = []
|
||||||
filenames = sorted(fn for fn in os.listdir('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02') )
|
filenames = sorted(fn for fn in os.listdir('D:/OneDrive - sjtu.edu.cn/MVIG/Action-Dataset/Pose_to_SMPL/fit/output/NTU/picture') )
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
images.append(imageio.imread('./fit/output/Human3.6M/picture/fit/s_01_act_09_subact_02_ca_02/'+filename))
|
images.append(imageio.imread('D:/OneDrive - sjtu.edu.cn/MVIG/Action-Dataset/Pose_to_SMPL/fit/output/NTU/picture/'+filename))
|
||||||
imageio.mimsave('fit_mesh.gif', images, duration=0.2)
|
imageio.mimsave('clapping_example.gif', images, duration=0.2)
|
||||||
Reference in New Issue
Block a user