Files
SGraFormer/test.py
2024-08-05 11:19:19 +08:00

257 lines
8.9 KiB
Python

import os
import torch
import logging
import random
import torch.optim as optim
from tqdm import tqdm
# from torch.utils.tensorboard import SummaryWriter
from common.utils import *
from common.opt import opts
from common.h36m_dataset import Human36mDataset
from common.Mydataset import Fusion
from model.SGraFormer import sgraformer
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
# Environment tweaks below must be in place before CUDA is first initialised.
# KMP_DUPLICATE_LIB_OK lets the process continue if more than one Intel OpenMP
# runtime gets loaded (e.g. by both torch and numpy/matplotlib builds).
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Enumerate GPUs in PCI-bus order and expose only GPU 0 to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Device ids (after CUDA_VISIBLE_DEVICES remapping) passed to DataParallel.
CUDA_ID = [0]
# Fall back to the CPU when no GPU is visible instead of failing on .to(device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def visualize_skeletons(input_2D, output_3D, gt_3D, idx=5, output_dir='./output'):
    """Plot one sample's 2D inputs next to its predicted and ground-truth 3D poses.

    The figure holds one 2D panel for each of the first four entries of
    ``input_2D[idx, 0]``, followed by the predicted and the ground-truth 3D
    skeletons. It is saved to ``<output_dir>/skeletons_visualization.png``
    (the directory is created if needed), shown, and then closed.

    Args:
        input_2D: tensor such that ``input_2D[idx, 0]`` is indexed as
            ``[view, joint, coord]`` with at least 4 views and 2 coords.
        output_3D: predicted poses; ``output_3D[idx, 0]`` is indexed as
            ``[joint, coord]`` with at least 3 coords.
        gt_3D: ground-truth poses with the same layout as ``output_3D``.
        idx: batch index of the sample to plot.
        output_dir: directory the PNG is written to.
    """
    # Select the sample first, then detach (in case gradients are tracked) and
    # copy only that slice to host memory rather than the whole batch.
    input_sample = input_2D[idx, 0].detach().cpu().numpy()
    output_sample = output_3D[idx, 0].detach().cpu().numpy()
    gt_3D_sample = gt_3D[idx, 0].detach().cpu().numpy()
    print(f'\ninput_sample shape: {input_sample.shape}')
    print(f'output_sample shape: {output_sample.shape}')

    # Bones (joint-index pairs) grouped by body part; each group has one colour.
    # Grouping by membership (not by endpoint lookup) makes the shoulder bones
    # (7, 11) and (7, 14) arm-coloured, as their grouping intends.
    # NOTE(review): which chain is left vs. right depends on the dataset's joint
    # order, which is not visible here -- confirm before labelling them.
    bone_groups = (
        ('green', [(0, 1), (1, 2), (2, 3),             # leg chain 1
                   (0, 4), (4, 5), (5, 6)]),           # leg chain 2
        ('blue', [(0, 7), (7, 8), (8, 9), (9, 10)]),   # spine / head
        ('red', [(7, 11), (11, 12), (12, 13),          # arm chain 1
                 (7, 14), (14, 15), (15, 16)]),        # arm chain 2
    )
    # Flatten to (start, end, colour); drawing order follows the groups above.
    bones = [(start, end, color) for color, pairs in bone_groups for start, end in pairs]

    num_views = 4
    num_panels = num_views + 2  # 2D views + predicted 3D + ground-truth 3D
    fig = plt.figure(figsize=(25, 5))

    def padded_limits(values, pad=1):
        """Return (min - pad, max + pad) over all entries of ``values``."""
        return np.min(values) - pad, np.max(values) + pad

    def plot_3d(position, title, joints, point_color, label, linestyle='-'):
        """Draw one 3D skeleton into panel ``position`` and return its axes."""
        ax = fig.add_subplot(1, num_panels, position, projection='3d')
        ax.set_title(title)
        ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], color=point_color, label=label)
        for start, end, color in bones:
            ax.plot([joints[start, 0], joints[end, 0]],
                    [joints[start, 1], joints[end, 1]],
                    [joints[start, 2], joints[end, 2]], color=color, linestyle=linestyle)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_xlim(*padded_limits(joints[:, 0]))
        ax.set_ylim(*padded_limits(joints[:, 1]))
        ax.set_zlim(*padded_limits(joints[:, 2]))
        ax.legend()
        return ax

    # 2D panels: one per view, all sharing the same axis range for comparison.
    for view in range(num_views):
        ax = fig.add_subplot(1, num_panels, view + 1)
        ax.set_title(f'2D angle {view+1}')
        joints = input_sample[view]
        ax.scatter(joints[:, 0], joints[:, 1], color='blue')
        for start, end, color in bones:
            ax.plot([joints[start, 0], joints[end, 0]],
                    [joints[start, 1], joints[end, 1]], color=color)
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_xlim(*padded_limits(input_sample[:, :, 0]))
        ax.set_ylim(*padded_limits(input_sample[:, :, 1]))
        ax.grid()

    plot_3d(num_views + 1, '3D Predicted Skeleton', output_sample, 'red', 'Predicted')
    gt_ax = plot_3d(num_views + 2, '3D Ground Truth Skeleton', gt_3D_sample,
                    'blue', 'Ground Truth', linestyle='--')
    gt_ax.grid()

    fig.tight_layout()
    # savefig does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, 'skeletons_visualization.png'))
    plt.show()
    # Release the figure so repeated calls (e.g. once per batch) do not pile up.
    plt.close(fig)
def val(opt, actions, val_loader, model):
    """Evaluate ``model`` on ``val_loader`` with autograd disabled.

    Thin wrapper around ``step`` using the ``'test'`` split; whatever ``step``
    returns for that split is passed straight back to the caller.
    """
    no_grad_ctx = torch.no_grad()
    with no_grad_ctx:
        outcome = step('test', opt, actions, val_loader, model)
    return outcome
def step(split, opt, actions, dataLoader, model, optimizer=None, epoch=None, writer=None, adaptive_weight=None):
    """Run ``model`` over ``dataLoader`` and accumulate per-action errors.

    Only evaluation is implemented in this script: the model is always put in
    eval mode and no loss or optimisation step is performed. ``optimizer``,
    ``epoch``, ``writer`` and ``adaptive_weight`` are accepted for signature
    compatibility but are unused.

    Args:
        split: ``'train'`` or ``'test'``. ``'test'`` routes the input through
            ``input_augmentation``; ``'train'`` calls the model on the full input.
        opt: parsed options; ``opt.pad``, ``opt.dataset`` and ``opt.train`` are read.
        actions: action list used to build the per-action error table.
        dataLoader: yields tuples ``(batch_cam, gt_3D, input_2D, action, subject,
            scale, bb_box, start, end, hops)``.
        model: callable as ``model(input_2D, hops)``.

    Returns:
        For ``'test'``, the ``(p1, p2)`` pair from ``print_error``; for
        ``'train'``, ``loss_all['loss'].avg`` (never updated here).

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'test'``.
    """
    if split not in ('train', 'test'):
        raise ValueError(f"split must be 'train' or 'test', got {split!r}")

    loss_all = {'loss': AccumLoss()}
    action_error_sum = define_error_list(actions)
    model.eval()

    progress = tqdm(dataLoader, total=len(dataLoader), ncols=100)
    for data in progress:
        batch_cam, gt_3D, input_2D, action, subject, scale, bb_box, start, end, hops = data
        [input_2D, gt_3D, batch_cam, scale, bb_box, hops] = get_varialbe(
            split, [input_2D, gt_3D, batch_cam, scale, bb_box, hops])

        if split == 'train':
            output_3D = model(input_2D, hops)
        else:
            input_2D, output_3D = input_augmentation(input_2D, hops, model)

        # Target with index 0 along axis 2 zeroed (presumably the root joint).
        out_target = gt_3D.clone()
        out_target[:, :, 0] = 0

        # Keep only frame opt.pad (presumably the centre frame) when the model
        # predicts more than one frame.
        if output_3D.shape[1] != 1:
            output_3D = output_3D[:, opt.pad].unsqueeze(1)
        # Express the prediction relative to joint 0 to match out_target.
        output_3D[:, :, 1:, :] -= output_3D[:, :, :1, :]
        output_3D[:, :, 0, :] = 0

        action_error_sum = test_calculation(output_3D, out_target, action, action_error_sum, opt.dataset, subject)

    if split == 'train':
        return loss_all['loss'].avg
    # Summarise once, after all batches have been accumulated.
    p1, p2 = print_error(opt.dataset, action_error_sum, opt.train)
    return p1, p2
def input_augmentation(input_2D, hops, model):
    """Feed the first entry along axis 1 of ``input_2D`` through ``model``.

    Despite the name, no flip or other test-time augmentation is performed.

    Returns:
        A ``(model_input, model_output)`` tuple, where ``model_input`` is
        ``input_2D[:, 0]`` and ``model_output`` is ``model(model_input, hops)``.
    """
    first_entry = input_2D[:, 0]
    return first_entry, model(first_entry, hops)
if __name__ == '__main__':
    opt = opts().parse()
    root_path = opt.root_path

    # Seed every RNG the script touches so evaluation runs are reproducible.
    opt.manualSeed = 1
    random.seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    dataset_path = root_path + 'data_3d_' + opt.dataset + '.npz'
    dataset = Human36mDataset(dataset_path, opt)
    actions = define_actions(opt.actions)

    # NOTE(review): the training split is never evaluated here. It is still
    # constructed in case Fusion(train=True) modifies `dataset` in a way the
    # test split depends on -- confirm and remove if it does not.
    train_data = Fusion(opt=opt, train=True, dataset=dataset, root_path=root_path)

    test_data = Fusion(opt=opt, train=False, dataset=dataset, root_path=root_path)
    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=int(opt.workers), pin_memory=True)

    model = sgraformer(num_frame=opt.frames, num_joints=17, in_chans=2, embed_dim_ratio=32, depth=4,
                       num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None, drop_path_rate=0.1)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model, device_ids=CUDA_ID)
    model = model.to(device)

    # Machine-specific checkpoint location; adjust before running elsewhere.
    model_path = '/home/zlt/Documents/SGraFormer-master/checkpoint/epoch_50.pth'
    # map_location lets a GPU-saved checkpoint load onto whatever `device` is.
    pre_dict = torch.load(model_path, map_location=device)
    model_dict = model.state_dict()
    # Copy only entries whose names exist in the model. Report any mismatch so
    # a key mismatch (e.g. a "module." prefix from DataParallel) cannot
    # silently leave the model at its initial weights.
    state_dict = {k: v for k, v in pre_dict.items() if k in model_dict}
    ignored = len(pre_dict) - len(state_dict)
    not_loaded = len(model_dict) - len(state_dict)
    if ignored or not_loaded:
        print(f'Warning: {ignored} checkpoint entries were ignored and '
              f'{not_loaded} model entries were not found in {model_path}')
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)

    p1, p2 = val(opt, actions, test_dataloader, model)
    print('p1: %.2f, p2: %.2f' % (p1, p2))