first commit
This commit is contained in:
315
main.py
Normal file
315
main.py
Normal file
@ -0,0 +1,315 @@
|
||||
import os
|
||||
import torch
|
||||
import logging
|
||||
import random
|
||||
import torch.optim as optim
|
||||
from tqdm import tqdm
|
||||
# from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from common.utils import *
|
||||
from common.opt import opts
|
||||
from common.h36m_dataset import Human36mDataset
|
||||
from common.Mydataset import Fusion
|
||||
|
||||
from model.SGraFormer import sgraformer
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
import numpy as np
|
||||
|
||||
|
||||
# os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
CUDA_ID = [0]
|
||||
device = torch.device("cuda")
|
||||
|
||||
|
||||
def visualize_skeletons(input_2D, output_3D, gt_3D, idx=0, output_dir='./output'):
|
||||
# Ensure the tensors are on the CPU and convert them to numpy arrays
|
||||
input_2D = input_2D.cpu().numpy()
|
||||
output_3D = output_3D.cpu().numpy()
|
||||
gt_3D = gt_3D.cpu().numpy()
|
||||
|
||||
# print("====> input_2D: ", input_2D[-1])
|
||||
# Get the first action and first sample from the batch
|
||||
input_sample = input_2D[idx, 0]
|
||||
output_sample = output_3D[idx, 0]
|
||||
gt_3D_sample = gt_3D[idx, 0]
|
||||
|
||||
print(f'\ninput_sample shape: {input_sample.shape}')
|
||||
print(f'output_sample shape: {output_sample.shape}')
|
||||
|
||||
fig = plt.figure(figsize=(25, 5))
|
||||
|
||||
# Define the connections (bones) between joints
|
||||
bones = [
|
||||
(0, 1), (1, 2), (2, 3), # Left leg
|
||||
(0, 4), (4, 5), (5, 6), # Right leg
|
||||
(0, 7), (7, 8), (8, 9), (9, 10), # Spine
|
||||
(7, 11), (11, 12), (12, 13), # Right arm
|
||||
(7, 14), (14, 15), (15, 16) # Left arm
|
||||
]
|
||||
|
||||
# Colors for different parts
|
||||
bone_colors = {
|
||||
"leg": 'green',
|
||||
"spine": 'blue',
|
||||
"arm": 'red'
|
||||
}
|
||||
|
||||
# Function to get bone color based on index
|
||||
def get_bone_color(start, end):
|
||||
if (start in [1, 2, 3] or end in [1, 2, 3] or
|
||||
start in [4, 5, 6] or end in [4, 5, 6]):
|
||||
return bone_colors["leg"]
|
||||
elif start in [7, 8, 9, 10] or end in [7, 8, 9, 10]:
|
||||
return bone_colors["spine"]
|
||||
else:
|
||||
return bone_colors["arm"]
|
||||
|
||||
# Plotting 2D skeletons from different angles
|
||||
for i in range(4):
|
||||
ax = fig.add_subplot(1, 7, i + 1)
|
||||
ax.set_title(f'2D angle {i+1}')
|
||||
ax.scatter(input_sample[i, :, 0], input_sample[i, :, 1], color='blue')
|
||||
|
||||
# Draw the bones
|
||||
for start, end in bones:
|
||||
bone_color = get_bone_color(start, end)
|
||||
ax.plot([input_sample[i, start, 0], input_sample[i, end, 0]],
|
||||
[input_sample[i, start, 1], input_sample[i, end, 1]], color=bone_color)
|
||||
|
||||
ax.set_xlabel('X')
|
||||
ax.set_ylabel('Y')
|
||||
ax.set_xlim(np.min(input_sample[:, :, 0]) - 1, np.max(input_sample[:, :, 0]) + 1)
|
||||
ax.set_ylim(np.min(input_sample[:, :, 1]) - 1, np.max(input_sample[:, :, 1]) + 1)
|
||||
ax.grid()
|
||||
|
||||
# Plotting predicted 3D skeleton
|
||||
ax = fig.add_subplot(1, 7, 5, projection='3d')
|
||||
ax.set_title('3D Predicted Skeleton')
|
||||
ax.scatter(output_sample[:, 0], output_sample[:, 1], output_sample[:, 2], color='red', label='Predicted')
|
||||
|
||||
# Draw the bones in 3D for output_sample
|
||||
for start, end in bones:
|
||||
bone_color = get_bone_color(start, end)
|
||||
ax.plot([output_sample[start, 0], output_sample[end, 0]],
|
||||
[output_sample[start, 1], output_sample[end, 1]],
|
||||
[output_sample[start, 2], output_sample[end, 2]], color=bone_color)
|
||||
|
||||
ax.set_xlabel('X')
|
||||
ax.set_ylabel('Y')
|
||||
ax.set_zlabel('Z')
|
||||
ax.set_xlim(np.min(output_sample[:, 0]) - 1, np.max(output_sample[:, 0]) + 1)
|
||||
ax.set_ylim(np.min(output_sample[:, 1]) - 1, np.max(output_sample[:, 1]) + 1)
|
||||
ax.set_zlim(np.min(output_sample[:, 2]) - 1, np.max(output_sample[:, 2]) + 1)
|
||||
ax.legend()
|
||||
|
||||
# Plotting ground truth 3D skeleton
|
||||
ax = fig.add_subplot(1, 7, 6, projection='3d')
|
||||
ax.set_title('3D Ground Truth Skeleton')
|
||||
ax.scatter(gt_3D_sample[:, 0], gt_3D_sample[:, 1], gt_3D_sample[:, 2], color='blue', label='Ground Truth')
|
||||
|
||||
# Draw the bones in 3D for gt_3D_sample
|
||||
for start, end in bones:
|
||||
bone_color = get_bone_color(start, end)
|
||||
ax.plot([gt_3D_sample[start, 0], gt_3D_sample[end, 0]],
|
||||
[gt_3D_sample[start, 1], gt_3D_sample[end, 1]],
|
||||
[gt_3D_sample[start, 2], gt_3D_sample[end, 2]], color=bone_color, linestyle='--')
|
||||
|
||||
ax.set_xlabel('X')
|
||||
ax.set_ylabel('Y')
|
||||
ax.set_zlabel('Z')
|
||||
ax.set_xlim(np.min(gt_3D_sample[:, 0]) - 1, np.max(gt_3D_sample[:, 0]) + 1)
|
||||
ax.set_ylim(np.min(gt_3D_sample[:, 1]) - 1, np.max(gt_3D_sample[:, 1]) + 1)
|
||||
ax.set_zlim(np.min(gt_3D_sample[:, 2]) - 1, np.max(gt_3D_sample[:, 2]) + 1)
|
||||
ax.legend()
|
||||
|
||||
plt.grid()
|
||||
|
||||
# Save the figure
|
||||
plt.tight_layout()
|
||||
plt.savefig(f'{output_dir}/skeletons_visualization.png')
|
||||
plt.show()
|
||||
|
||||
def train(opt, actions, train_loader, model, optimizer, epoch, writer, adaptive_weight=None):
|
||||
return step('train', opt, actions, train_loader, model, optimizer, epoch, writer, adaptive_weight)
|
||||
|
||||
|
||||
def val(opt, actions, val_loader, model):
|
||||
with torch.no_grad():
|
||||
return step('test', opt, actions, val_loader, model)
|
||||
|
||||
|
||||
def step(split, opt, actions, dataLoader, model, optimizer=None, epoch=None, writer=None, adaptive_weight=None):
|
||||
loss_all = {'loss': AccumLoss()}
|
||||
action_error_sum = define_error_list(actions)
|
||||
|
||||
if split == 'train':
|
||||
model.train()
|
||||
else:
|
||||
model.eval()
|
||||
|
||||
TQDM = tqdm(enumerate(dataLoader), total=len(dataLoader), ncols=100)
|
||||
for i, data in TQDM:
|
||||
batch_cam, gt_3D, input_2D, action, subject, scale, bb_box, start, end, hops = data
|
||||
|
||||
[input_2D, gt_3D, batch_cam, scale, bb_box, hops] = get_varialbe(split, [input_2D, gt_3D, batch_cam, scale, bb_box, hops])
|
||||
|
||||
if split == 'train':
|
||||
output_3D = model(input_2D, hops)
|
||||
elif split == 'test':
|
||||
# input_2D = input_2D.to(device)
|
||||
# model = model.to(device)
|
||||
# hops = hops.to(device)
|
||||
input_2D, output_3D = input_augmentation(input_2D, hops, model)
|
||||
|
||||
|
||||
visualize_skeletons(input_2D, output_3D, gt_3D)
|
||||
|
||||
out_target = gt_3D.clone()
|
||||
out_target[:, :, 0] = 0
|
||||
|
||||
if split == 'train':
|
||||
loss = mpjpe_cal(output_3D, out_target)
|
||||
|
||||
TQDM.set_description(f'Epoch [{epoch}/{opt.nepoch}]')
|
||||
TQDM.set_postfix({"l": loss.item()})
|
||||
|
||||
N = input_2D.size(0)
|
||||
loss_all['loss'].update(loss.detach().cpu().numpy() * N, N)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# writer.add_scalars(main_tag='scalars1/train_loss',
|
||||
# tag_scalar_dict={'trianloss': loss.item()},
|
||||
# global_step=(epoch - 1) * len(dataLoader) + i)
|
||||
|
||||
elif split == 'test':
|
||||
if output_3D.shape[1] != 1:
|
||||
output_3D = output_3D[:, opt.pad].unsqueeze(1)
|
||||
output_3D[:, :, 1:, :] -= output_3D[:, :, :1, :]
|
||||
output_3D[:, :, 0, :] = 0
|
||||
action_error_sum = test_calculation(output_3D, out_target, action, action_error_sum, opt.dataset, subject)
|
||||
|
||||
if split == 'train':
|
||||
return loss_all['loss'].avg
|
||||
elif split == 'test':
|
||||
p1, p2 = print_error(opt.dataset, action_error_sum, opt.train)
|
||||
return p1, p2
|
||||
|
||||
|
||||
def input_augmentation(input_2D, hops, model):
|
||||
input_2D_non_flip = input_2D[:, 0]
|
||||
output_3D_non_flip = model(input_2D_non_flip, hops)
|
||||
|
||||
return input_2D_non_flip, output_3D_non_flip
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
opt = opts().parse()
|
||||
root_path = opt.root_path
|
||||
opt.manualSeed = 1
|
||||
random.seed(opt.manualSeed)
|
||||
torch.manual_seed(opt.manualSeed)
|
||||
|
||||
if opt.train:
|
||||
logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y/%m/%d %H:%M:%S',
|
||||
filename=os.path.join(opt.checkpoint, 'train.log'), level=logging.INFO)
|
||||
|
||||
root_path = opt.root_path
|
||||
dataset_path = root_path + 'data_3d_' + opt.dataset + '.npz'
|
||||
|
||||
dataset = Human36mDataset(dataset_path, opt)
|
||||
actions = define_actions(opt.actions)
|
||||
|
||||
if opt.train:
|
||||
train_data = Fusion(opt=opt, train=True, dataset=dataset, root_path=root_path)
|
||||
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=opt.batch_size,
|
||||
shuffle=True, num_workers=int(opt.workers), pin_memory=True)
|
||||
|
||||
test_data = Fusion(opt=opt, train=False, dataset=dataset, root_path=root_path)
|
||||
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=opt.batch_size,
|
||||
shuffle=False, num_workers=int(opt.workers), pin_memory=True)
|
||||
|
||||
model = sgraformer(num_frame=opt.frames, num_joints=17, in_chans=2, embed_dim_ratio=32, depth=4,
|
||||
num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None, drop_path_rate=0.1)
|
||||
# model = FuseModel()
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
print("Let's use", torch.cuda.device_count(), "GPUs!")
|
||||
model = torch.nn.DataParallel(model, device_ids=CUDA_ID).to(device)
|
||||
else:
|
||||
model = model.to(device)
|
||||
|
||||
# 定义一个函数来去除 'module.' 前缀
|
||||
def remove_module_prefix(state_dict):
|
||||
new_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
name = k[7:] if k.startswith('module.') else k # 去除 `module.`
|
||||
new_state_dict[name] = v
|
||||
return new_state_dict
|
||||
|
||||
model_dict = model.state_dict()
|
||||
if opt.previous_dir != '':
|
||||
print('pretrained model path:', opt.previous_dir)
|
||||
model_path = opt.previous_dir
|
||||
pre_dict = torch.load(model_path)
|
||||
# print("=====> pre_dict:", pre_dict.keys())
|
||||
# 去除 'module.' 前缀
|
||||
state_dict = remove_module_prefix(pre_dict)
|
||||
# print("=====> state_dict:", state_dict.keys())
|
||||
# 只保留在模型字典中的键值对
|
||||
state_dict = {k: v for k, v in state_dict.items() if k in model_dict.keys()}
|
||||
# 更新模型字典
|
||||
model_dict.update(state_dict)
|
||||
# 加载更新后的模型字典
|
||||
model.load_state_dict(model_dict)
|
||||
|
||||
|
||||
all_param = []
|
||||
lr = opt.lr
|
||||
all_param += list(model.parameters())
|
||||
|
||||
optimizer = optim.AdamW(all_param, lr=lr, weight_decay=0.1)
|
||||
|
||||
## tensorboard
|
||||
# writer = SummaryWriter("runs/nin")
|
||||
writer = None
|
||||
flag = 0
|
||||
|
||||
|
||||
for epoch in range(1, opt.nepoch + 1):
|
||||
p1, p2 = val(opt, actions, test_dataloader, model)
|
||||
print("=====> p1, p2", p1, p2)
|
||||
if opt.train:
|
||||
loss = train(opt, actions, train_dataloader, model, optimizer, epoch, writer)
|
||||
|
||||
|
||||
if opt.train:
|
||||
save_model_epoch(opt.checkpoint, epoch, model)
|
||||
|
||||
if p1 < opt.previous_best_threshold:
|
||||
opt.previous_name = save_model(opt.previous_name, opt.checkpoint, epoch, p1, model)
|
||||
opt.previous_best_threshold = p1
|
||||
|
||||
if opt.train == 0:
|
||||
print('p1: %.2f, p2: %.2f' % (p1, p2))
|
||||
break
|
||||
else:
|
||||
logging.info('epoch: %d, lr: %.7f, loss: %.4f, p1: %.2f, p2: %.2f' % (epoch, lr, loss, p1, p2))
|
||||
print('e: %d, lr: %.7f, loss: %.4f, p1: %.2f, p2: %.2f' % (epoch, lr, loss, p1, p2))
|
||||
|
||||
if epoch % opt.large_decay_epoch == 0:
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] *= opt.lr_decay_large
|
||||
lr *= opt.lr_decay_large
|
||||
else:
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] *= opt.lr_decay
|
||||
lr *= opt.lr_decay
|
||||
|
||||
print(opt.checkpoint)
|
||||
Reference in New Issue
Block a user