# File listing metadata (not code): 316 lines, 12 KiB, Python.
import os
|
|
import torch
|
|
import logging
|
|
import random
|
|
import torch.optim as optim
|
|
from tqdm import tqdm
|
|
# from torch.utils.tensorboard import SummaryWriter
|
|
|
|
from common.utils import *
|
|
from common.opt import opts
|
|
from common.h36m_dataset import Human36mDataset
|
|
from common.Mydataset import Fusion
|
|
|
|
from model.SGraFormer import sgraformer
|
|
|
|
import matplotlib.pyplot as plt
|
|
from mpl_toolkits.mplot3d import Axes3D
|
|
import numpy as np
|
|
|
|
|
|
# os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
|
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
|
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
|
# GPU indices handed to torch.nn.DataParallel when more than one GPU is visible.
CUDA_ID = [0]
# Prefer CUDA, but fall back to the CPU so the module can still be imported and
# the model placed on a device when no GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
def visualize_skeletons(input_2D, output_3D, gt_3D, idx=0, output_dir='./output'):
    """Plot one sample's 2D inputs next to its predicted and ground-truth 3D pose.

    The figure has four 2D panels (one per entry along the first axis of the
    selected 2D sample) followed by the predicted and the ground-truth 3D
    skeletons. It is written to ``{output_dir}/skeletons_visualization.png``,
    shown with ``plt.show()`` and then closed.

    Args:
        input_2D: Tensor indexed here as ``[idx, 0, angle, joint, xy]``; the
            first four ``angle`` entries are plotted.
        output_3D: Tensor indexed here as ``[idx, 0, joint, xyz]``.
        gt_3D: Tensor indexed the same way as ``output_3D``.
        idx: Index along the first (batch) axis of the sample to draw.
        output_dir: Directory for the saved image; created when missing.
    """
    # detach() lets tensors that still track gradients be converted to numpy.
    input_2D = input_2D.detach().cpu().numpy()
    output_3D = output_3D.detach().cpu().numpy()
    gt_3D = gt_3D.detach().cpu().numpy()

    # First entry along axis 1 of the selected batch element.
    input_sample = input_2D[idx, 0]
    output_sample = output_3D[idx, 0]
    gt_3D_sample = gt_3D[idx, 0]

    print(f'\ninput_sample shape: {input_sample.shape}')
    print(f'output_sample shape: {output_sample.shape}')

    part_colors = {"leg": 'green', "spine": 'blue', "arm": 'red'}
    # Each bone is coloured by the chain it is listed in. Deriving the colour
    # from joint indices instead paints the shoulder bones (7, 11) and (7, 14)
    # as spine, because joint 7 belongs to both chains.
    skeleton = (
        ("leg", ((0, 1), (1, 2), (2, 3))),
        ("leg", ((0, 4), (4, 5), (5, 6))),
        ("spine", ((0, 7), (7, 8), (8, 9), (9, 10))),
        ("arm", ((7, 11), (11, 12), (12, 13))),
        ("arm", ((7, 14), (14, 15), (15, 16))),
    )
    bones = [
        (start, end, part_colors[part])
        for part, chain in skeleton
        for start, end in chain
    ]

    fig = plt.figure(figsize=(25, 5))

    # All 2D panels share axis limits computed over every plotted angle.
    x_limits = (np.min(input_sample[:, :, 0]) - 1, np.max(input_sample[:, :, 0]) + 1)
    y_limits = (np.min(input_sample[:, :, 1]) - 1, np.max(input_sample[:, :, 1]) + 1)

    for i in range(4):
        ax = fig.add_subplot(1, 7, i + 1)
        ax.set_title(f'2D angle {i+1}')
        ax.scatter(input_sample[i, :, 0], input_sample[i, :, 1], color='blue')

        for start, end, bone_color in bones:
            ax.plot([input_sample[i, start, 0], input_sample[i, end, 0]],
                    [input_sample[i, start, 1], input_sample[i, end, 1]], color=bone_color)

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_xlim(*x_limits)
        ax.set_ylim(*y_limits)
        ax.grid()

    def draw_3d(position, joints, title, point_color, label, linestyle):
        """Draw one 3D skeleton panel at subplot ``position`` of the 1x7 grid."""
        ax3d = fig.add_subplot(1, 7, position, projection='3d')
        ax3d.set_title(title)
        ax3d.scatter(joints[:, 0], joints[:, 1], joints[:, 2], color=point_color, label=label)

        for start, end, bone_color in bones:
            ax3d.plot([joints[start, 0], joints[end, 0]],
                      [joints[start, 1], joints[end, 1]],
                      [joints[start, 2], joints[end, 2]], color=bone_color, linestyle=linestyle)

        ax3d.set_xlabel('X')
        ax3d.set_ylabel('Y')
        ax3d.set_zlabel('Z')
        ax3d.set_xlim(np.min(joints[:, 0]) - 1, np.max(joints[:, 0]) + 1)
        ax3d.set_ylim(np.min(joints[:, 1]) - 1, np.max(joints[:, 1]) + 1)
        ax3d.set_zlim(np.min(joints[:, 2]) - 1, np.max(joints[:, 2]) + 1)
        ax3d.legend()
        return ax3d

    # Prediction uses solid bones, ground truth dashed ones.
    draw_3d(5, output_sample, '3D Predicted Skeleton', 'red', 'Predicted', '-')
    draw_3d(6, gt_3D_sample, '3D Ground Truth Skeleton', 'blue', 'Ground Truth', '--')

    # Applies to the current axes, i.e. the ground-truth panel created last.
    plt.grid()

    # savefig does not create missing directories.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    plt.tight_layout()
    plt.savefig(f'{output_dir}/skeletons_visualization.png')
    plt.show()
    # Release the figure; this function may be called once per batch.
    plt.close(fig)
|
|
|
|
def train(opt, actions, train_loader, model, optimizer, epoch, writer, adaptive_weight=None):
    """Run one training pass over ``train_loader``.

    Delegates to ``step`` with split ``'train'`` and returns whatever it
    returns for that split (the averaged accumulated loss).
    """
    return step(
        'train', opt, actions, train_loader, model,
        optimizer=optimizer,
        epoch=epoch,
        writer=writer,
        adaptive_weight=adaptive_weight,
    )
|
|
|
|
|
|
def val(opt, actions, val_loader, model):
    """Evaluate ``model`` on ``val_loader`` with gradient tracking disabled.

    Delegates to ``step`` with split ``'test'`` and returns its result (the
    ``p1, p2`` pair).
    """
    with torch.no_grad():
        metrics = step('test', opt, actions, val_loader, model)
    return metrics
|
|
|
|
|
|
def step(split, opt, actions, dataLoader, model, optimizer=None, epoch=None, writer=None, adaptive_weight=None):
    """Run one pass over ``dataLoader`` in ``'train'`` or ``'test'`` mode.

    Args:
        split: ``'train'`` to optimise the model, ``'test'`` to evaluate it.
        opt: Options object; ``nepoch``, ``pad``, ``dataset`` and ``train``
            are read here.
        actions: Passed to ``define_error_list`` to set up per-action errors.
        dataLoader: Iterable with a length, yielding 10-element batches.
        model: Called as ``model(input_2D, hops)``.
        optimizer: Required for ``'train'``; stepped once per batch.
        epoch: Only used for the progress-bar description while training.
        writer: Accepted for interface compatibility; not used here.
        adaptive_weight: Accepted for interface compatibility; not used here.

    Returns:
        For ``'train'``: the ``avg`` of the accumulated loss.
        For ``'test'``: the ``(p1, p2)`` pair produced by ``print_error``.
    """
    is_train = split == 'train'
    is_test = split == 'test'

    loss_all = {'loss': AccumLoss()}
    action_error_sum = define_error_list(actions)

    if is_train:
        model.train()
    else:
        model.eval()

    progress = tqdm(enumerate(dataLoader), total=len(dataLoader), ncols=100)
    for i, data in progress:
        batch_cam, gt_3D, input_2D, action, subject, scale, bb_box, start, end, hops = data

        input_2D, gt_3D, batch_cam, scale, bb_box, hops = get_varialbe(
            split, [input_2D, gt_3D, batch_cam, scale, bb_box, hops])

        if is_train:
            output_3D = model(input_2D, hops)
        elif is_test:
            input_2D, output_3D = input_augmentation(input_2D, hops, model)

            # Visualisation is kept on the evaluation path only: during
            # training output_3D still tracks gradients.
            visualize_skeletons(input_2D, output_3D, gt_3D)

        # Target with index 0 of the third axis zeroed.
        out_target = gt_3D.clone()
        out_target[:, :, 0] = 0

        if is_train:
            loss = mpjpe_cal(output_3D, out_target)

            progress.set_description(f'Epoch [{epoch}/{opt.nepoch}]')
            progress.set_postfix({"l": loss.item()})

            # Weight the batch loss by the batch size for the running average.
            batch_size = input_2D.size(0)
            loss_all['loss'].update(loss.detach().cpu().numpy() * batch_size, batch_size)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        elif is_test:
            # Keep only the frame at index opt.pad when several are predicted.
            if output_3D.shape[1] != 1:
                output_3D = output_3D[:, opt.pad].unsqueeze(1)

            # Express every joint relative to joint 0, then zero joint 0.
            output_3D[:, :, 1:, :] -= output_3D[:, :, :1, :]
            output_3D[:, :, 0, :] = 0

            action_error_sum = test_calculation(
                output_3D, out_target, action, action_error_sum, opt.dataset, subject)

    if is_train:
        return loss_all['loss'].avg
    elif is_test:
        p1, p2 = print_error(opt.dataset, action_error_sum, opt.train)
        return p1, p2
|
|
|
|
|
|
def input_augmentation(input_2D, hops, model):
    """Run ``model`` on index 0 along dimension 1 of ``input_2D``.

    Only that single slice is evaluated; no flipped copy is passed through
    the model here.

    Returns:
        A tuple ``(selected_input, model_output)``.
    """
    non_flip_input = input_2D[:, 0]
    return non_flip_input, model(non_flip_input, hops)
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse options, build data loaders and the model,
    # optionally load pretrained weights, then run evaluation (and training
    # when opt.train is set) for up to opt.nepoch epochs.
    opt = opts().parse()
    root_path = opt.root_path
    opt.manualSeed = 1
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    if opt.train:
        logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y/%m/%d %H:%M:%S',
                            filename=os.path.join(opt.checkpoint, 'train.log'), level=logging.INFO)

    dataset_path = root_path + 'data_3d_' + opt.dataset + '.npz'

    dataset = Human36mDataset(dataset_path, opt)
    actions = define_actions(opt.actions)

    if opt.train:
        train_data = Fusion(opt=opt, train=True, dataset=dataset, root_path=root_path)
        train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=opt.batch_size,
                                                       shuffle=True, num_workers=int(opt.workers), pin_memory=True)

    test_data = Fusion(opt=opt, train=False, dataset=dataset, root_path=root_path)
    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), pin_memory=True)

    model = sgraformer(num_frame=opt.frames, num_joints=17, in_chans=2, embed_dim_ratio=32, depth=4,
                       num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None, drop_path_rate=0.1)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model, device_ids=CUDA_ID).to(device)
    else:
        model = model.to(device)

    def remove_module_prefix(state_dict):
        """Return a copy of ``state_dict`` with a leading 'module.' stripped from keys."""
        prefix = 'module.'
        return {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

    # DataParallel prefixes every parameter name with 'module.'. Load through
    # the wrapped network so the prefix-stripped checkpoint keys match;
    # filtering against the wrapper's keys would silently drop every weight.
    target_model = model.module if isinstance(model, torch.nn.DataParallel) else model
    model_dict = target_model.state_dict()
    if opt.previous_dir != '':
        print('pretrained model path:', opt.previous_dir)
        model_path = opt.previous_dir
        # map_location keeps loading independent of the device the checkpoint
        # was saved from.
        pre_dict = torch.load(model_path, map_location=device)
        # Strip the 'module.' prefix from checkpoint keys.
        state_dict = remove_module_prefix(pre_dict)
        # Keep only the entries the current model actually has.
        state_dict = {k: v for k, v in state_dict.items() if k in model_dict}
        # Overlay the checkpoint values on the model's own state, then load it.
        model_dict.update(state_dict)
        target_model.load_state_dict(model_dict)

    lr = opt.lr
    all_param = list(model.parameters())

    optimizer = optim.AdamW(all_param, lr=lr, weight_decay=0.1)

    # TensorBoard logging is disabled; step() accepts None for the writer.
    writer = None

    for epoch in range(1, opt.nepoch + 1):
        # NOTE: evaluation runs before this epoch's training pass, so p1/p2
        # describe the weights as they were at the start of the epoch, while
        # the checkpoints below are written after the training pass.
        p1, p2 = val(opt, actions, test_dataloader, model)
        print("=====> p1, p2", p1, p2)
        if opt.train:
            loss = train(opt, actions, train_dataloader, model, optimizer, epoch, writer)

        if opt.train:
            save_model_epoch(opt.checkpoint, epoch, model)

            # Track the best p1 seen so far and keep its checkpoint.
            if p1 < opt.previous_best_threshold:
                opt.previous_name = save_model(opt.previous_name, opt.checkpoint, epoch, p1, model)
                opt.previous_best_threshold = p1

        if opt.train == 0:
            # Evaluation-only run: report once and stop.
            print('p1: %.2f, p2: %.2f' % (p1, p2))
            break
        else:
            logging.info('epoch: %d, lr: %.7f, loss: %.4f, p1: %.2f, p2: %.2f' % (epoch, lr, loss, p1, p2))
            print('e: %d, lr: %.7f, loss: %.4f, p1: %.2f, p2: %.2f' % (epoch, lr, loss, p1, p2))

            # Larger decay every opt.large_decay_epoch epochs, the regular
            # decay otherwise; `lr` mirrors the optimizer value for logging.
            if epoch % opt.large_decay_epoch == 0:
                decay = opt.lr_decay_large
            else:
                decay = opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] *= decay
            lr *= decay

    print(opt.checkpoint)
|