commit 8b2e804ccc82b72ecd1fa4b355f7b87a7a2569fb
Author: IamZLT
Date:   Mon Aug 5 11:19:19 2024 +0800

    first commit

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..f610dcc
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+checkpoint
+dataset
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..0c17ad3
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 z0911k
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..d3cb49a
--- /dev/null
+++ b/README.md
@@ -0,0 +1,60 @@
+# Deep Semantic Graph Transformer for Multi-view 3D Human Pose Estimation [AAAI 2024]
+
+
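At a glance, the model takes multi-view 2D keypoints together with multi-hop skeleton adjacency ("hop") matrices and predicts per-frame 3D joint positions. The minimal shape sketch below is illustrative only: the tensor shapes follow the commented-out test at the bottom of `model/SGraFormer.py` and the constructor call in `main.py`, while the variable names themselves are not taken from the repository.

```python
import torch

B, F, V, J = 8, 27, 4, 17       # batch, frames, camera views, joints (Human3.6M: 4 views, 17 joints)
x = torch.rand(B, F, V, J, 2)   # multi-view 2D keypoints: (x, y) per joint, per view, per frame
hops = torch.rand(B, 4, J, J)   # 1- to 4-hop skeleton adjacency matrices (built in common/Mydataset.py)

# main.py constructs the model as:
#   model = sgraformer(num_frame=F, num_joints=J, in_chans=2, embed_dim_ratio=32,
#                      depth=4, num_heads=8, mlp_ratio=2., qkv_bias=True,
#                      qk_scale=None, drop_path_rate=0.1)
# and a forward pass returns predicted 3D joints of shape (B, F, J, 3):
#   out = model(x, hops)
print(x.shape, hops.shape)
```

In the training pipeline these tensors are produced by the `Fusion` dataset in `common/Mydataset.py`, which stacks the four Human3.6M camera views and supplies normalized hop matrices for each sample.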

+
+> **Deep Semantic Graph Transformer for Multi-view 3D Human Pose Estimation**,
+> Lijun Zhang, Kangkang Zhou, Feng Lu, Xiang-Dong Zhou, Yu Shi,
+> *The 38th Annual AAAI Conference on Artificial Intelligence (AAAI), 2024*
+
+## TODO
+- The paper will be released soon!
+- Test code and model weights will be released soon!
+
+## Release
+- [14/12/2023] We released the model and training code for SGraFormer.
+
+## Installation
+
+- Create a conda environment: ```conda create -n SGraFormer python=3.7```
+- Download cudatoolkit=11.0 from [here](https://developer.nvidia.com/cuda-11.0-download-archive) and install it
+- ```pip3 install torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html```
+- ```pip3 install -r requirements.txt```
+
+## Dataset Setup
+
+Please download the dataset from the [Human3.6M](http://vision.imar.ro/human3.6m/) website and refer to [VideoPose3D](https://github.com/facebookresearch/VideoPose3D) to set up the Human3.6M dataset (in the './dataset' directory).
+Alternatively, you can download the processed data from [here](https://drive.google.com/drive/folders/1F_qbuZTwLJGUSib1oBUTYfOrLB6-MKrM?usp=sharing).
+
+```bash
+${POSE_ROOT}/
+|-- dataset
+|   |-- data_3d_h36m.npz
+|   |-- data_2d_h36m_gt.npz
+|   |-- data_2d_h36m_cpn_ft_h36m_dbb.npz
+```
+
+## Quick Start
+To train a model on Human3.6M:
+
+```bash
+python main.py --frames 27 --batch_size 1024 --nepoch 50 --lr 0.0002
+```
+
+## Citation
+If you find our work useful in your research, please consider citing:
+
+    @inproceedings{zhang2024sgraformer,
+      author    = {Lijun Zhang and Kangkang Zhou and Feng Lu and Xiang-Dong Zhou and Yu Shi},
+      title     = {Deep Semantic Graph Transformer for Multi-view 3D Human Pose Estimation},
+      booktitle = {The 38th Annual AAAI Conference on Artificial Intelligence (AAAI)},
+      year      = {2024},
+    }
+
+
+## Acknowledgement
+
+Our code is extended from the following repositories. We thank the authors for releasing their code.
+ +- [PoseFormer](https://github.com/zczcwh/PoseFormer) +- [VideoPose3D](https://github.com/facebookresearch/VideoPose3D) + diff --git a/common/Mydataset.py b/common/Mydataset.py new file mode 100644 index 0000000..4092c71 --- /dev/null +++ b/common/Mydataset.py @@ -0,0 +1,360 @@ +import torch +import numpy as np +import torch.utils.data as data + +from common.cameras import normalize_screen_coordinates + + +class ChunkedGenerator: + + def __init__(self, batch_size, cameras, poses_3d, poses_2d, + chunk_length=1, pad=0, causal_shift=0, + shuffle=False, random_seed=1234, + augment=False, reverse_aug=False, kps_left=None, kps_right=None, joints_left=None, joints_right=None, + endless=False, out_all=False): + assert poses_3d is None or len(poses_3d) == len(poses_2d), (len(poses_3d), len(poses_2d)) + assert cameras is None or len(cameras) == len(poses_2d) + + pairs = [] + self.saved_index = {} + start_index = 0 + + for key in poses_2d.keys(): + assert poses_3d is None or poses_2d[key].shape[0] == poses_3d[key].shape[0] + n_chunks = (poses_2d[key].shape[0] + chunk_length - 1) // chunk_length + offset = (n_chunks * chunk_length - poses_2d[key].shape[0]) // 2 + bounds = np.arange(n_chunks + 1) * chunk_length - offset + augment_vector = np.full(len(bounds - 1), False, dtype=bool) + reverse_augment_vector = np.full(len(bounds - 1), False, dtype=bool) + keys = np.tile(np.array(key).reshape([1, 2]), (len(bounds - 1), 1)) + pairs += list(zip(keys, bounds[:-1], bounds[1:], augment_vector, reverse_augment_vector)) + if reverse_aug: + pairs += list(zip(keys, bounds[:-1], bounds[1:], augment_vector, ~reverse_augment_vector)) + if augment: + if reverse_aug: + pairs += list(zip(keys, bounds[:-1], bounds[1:], ~augment_vector, ~reverse_augment_vector)) + else: + pairs += list(zip(keys, bounds[:-1], bounds[1:], ~augment_vector, reverse_augment_vector)) + end_index = start_index + poses_3d[key].shape[0] + self.saved_index[key] = [start_index, end_index] + start_index = start_index + poses_3d[key].shape[0] + + if cameras is not None: + self.batch_cam = np.empty((batch_size, cameras[key].shape[-1])) + + if poses_3d is not None: + self.batch_3d = np.empty((batch_size, chunk_length, poses_3d[key].shape[-2], poses_3d[key].shape[-1])) + self.batch_2d = np.empty( + (batch_size, chunk_length + 2 * pad, poses_2d[key].shape[-3], poses_2d[key].shape[-2], + poses_2d[key].shape[-1])) + + self.num_batches = (len(pairs) + batch_size - 1) // batch_size + self.batch_size = batch_size + self.random = np.random.RandomState(random_seed) + self.pairs = pairs + self.shuffle = shuffle + self.pad = pad + self.causal_shift = causal_shift + self.endless = endless + self.state = None + + self.cameras = cameras + if cameras is not None: + self.cameras = cameras + self.poses_3d = poses_3d + self.poses_2d = poses_2d + + self.augment = augment + self.kps_left = kps_left + self.kps_right = kps_right + self.joints_left = joints_left + self.joints_right = joints_right + self.out_all = out_all + + def num_frames(self): + return self.num_batches * self.batch_size + + def random_state(self): + return self.random + + def set_random_state(self, random): + self.random = random + + def augment_enabled(self): + return self.augment + + def next_pairs(self): + if self.state is None: + if self.shuffle: + pairs = self.random.permutation(self.pairs) + else: + pairs = self.pairs + return 0, pairs + else: + return self.state + + def get_batch(self, seq_i, start_3d, end_3d, flip, reverse): + subject, action = seq_i + seq_name = (subject, action) + start_2d = 
start_3d - self.pad - self.causal_shift # \u5f00\u59cb\u4f4d\u7f6e + end_2d = end_3d + self.pad - self.causal_shift + + seq_2d = self.poses_2d[seq_name].copy() + low_2d = max(start_2d, 0) + high_2d = min(end_2d, seq_2d.shape[0]) + pad_left_2d = low_2d - start_2d + pad_right_2d = end_2d - high_2d + + if pad_left_2d != 0 or pad_right_2d != 0: + self.batch_2d = np.pad(seq_2d[low_2d:high_2d], ((pad_left_2d, pad_right_2d), (0, 0), (0, 0), (0, 0)), + 'edge') + else: + self.batch_2d = seq_2d[low_2d:high_2d] + + if flip: + self.batch_2d[:, :, :, 0] *= -1 + self.batch_2d[:, :, self.kps_left + self.kps_right] = self.batch_2d[:, :, self.kps_right + self.kps_left] + + if reverse: + self.batch_2d = self.batch_2d[::-1].copy() + + if self.poses_3d is not None: + seq_3d = self.poses_3d[seq_name].copy() + if self.out_all: + low_3d = low_2d + high_3d = high_2d + pad_left_3d = pad_left_2d + pad_right_3d = pad_right_2d + else: + low_3d = max(start_3d, 0) + high_3d = min(end_3d, seq_3d.shape[0]) + pad_left_3d = low_3d - start_3d + pad_right_3d = end_3d - high_3d + if pad_left_3d != 0 or pad_right_3d != 0: + self.batch_3d = np.pad(seq_3d[low_3d:high_3d], + ((pad_left_3d, pad_right_3d), (0, 0), (0, 0)), 'edge') + else: + self.batch_3d = seq_3d[low_3d:high_3d] + + if flip: + self.batch_3d[:, :, 0] *= -1 + self.batch_3d[:, self.joints_left + self.joints_right] = \ + self.batch_3d[:, self.joints_right + self.joints_left] + if reverse: + self.batch_3d = self.batch_3d[::-1].copy() + + if self.poses_3d is None and self.cameras is None: + return None, None, self.batch_2d.copy(), action, subject + elif self.poses_3d is not None and self.cameras is None: + return np.zeros(9), self.batch_3d.copy(), self.batch_2d.copy(), action, subject, low_2d, high_2d + elif self.poses_3d is None: + return self.batch_cam, None, self.batch_2d.copy(), action, subject + else: + return self.batch_cam, self.batch_3d.copy(), self.batch_2d.copy(), action, subject + + +class Fusion(data.Dataset): + def __init__(self, opt, dataset, root_path, train=True): + self.hop1 = torch.tensor([[0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]]) + + self.hop2 = torch.tensor([[0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 
1, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]]) + + self.hop3 = torch.tensor([[0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0], + [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1], + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]]) + + self.hop4 = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1], + [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1], + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0], + [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0], + [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0]]) + + + self.data_type = opt.dataset + self.train = train + self.keypoints_name = opt.keypoints + self.root_path = root_path + + self.train_list = opt.subjects_train.split(',') + self.test_list = opt.subjects_test.split(',') + self.action_filter = None if opt.actions == '*' else opt.actions.split(',') + self.downsample = opt.downsample + self.subset = opt.subset + self.stride = opt.stride + self.crop_uv = opt.crop_uv + self.test_aug = opt.test_augmentation + self.pad = opt.pad + + if self.train: + self.keypoints = self.prepare_data(dataset, self.train_list) + self.cameras_train, self.poses_train, self.poses_train_2d = self.fetch(dataset, self.train_list, + subset=self.subset) + self.generator = ChunkedGenerator(opt.batch_size // opt.stride, self.cameras_train, self.poses_train, + self.poses_train_2d, self.stride, pad=self.pad, + augment=opt.data_augmentation, reverse_aug=opt.reverse_augmentation, + kps_left=self.kps_left, kps_right=self.kps_right, + joints_left=self.joints_left, + joints_right=self.joints_right, out_all=opt.out_all) + print('INFO: Training on {} frames'.format(self.generator.num_frames())) + else: + self.keypoints = self.prepare_data(dataset, self.test_list) + self.cameras_test, 
self.poses_test, self.poses_test_2d = self.fetch(dataset, self.test_list, + subset=self.subset) + self.generator = ChunkedGenerator(opt.batch_size // opt.stride, self.cameras_test, self.poses_test, + self.poses_test_2d, + pad=self.pad, augment=False, kps_left=self.kps_left, + kps_right=self.kps_right, joints_left=self.joints_left, + joints_right=self.joints_right) + self.key_index = self.generator.saved_index + print('INFO: Testing on {} frames'.format(self.generator.num_frames())) + + def prepare_data(self, dataset, folder_list): + + for subject in folder_list: + for action in dataset[subject].keys(): + dataset[subject][action]['positions'][:, 1:] -= dataset[subject][action]['positions'][:, :1] + + keypoints = np.load(self.root_path + 'data_2d_' + self.data_type + '_' + self.keypoints_name + '.npz', + allow_pickle=True) + keypoints_symmetry = keypoints['metadata'].item()['keypoints_symmetry'] + + self.kps_left, self.kps_right = list(keypoints_symmetry[0]), list(keypoints_symmetry[1]) + self.joints_left, self.joints_right = list(dataset.skeleton().joints_left()), list( + dataset.skeleton().joints_right()) + + keypoints = keypoints['positions_2d'].item() + + for subject in folder_list: + for action in dataset[subject].keys(): + mocap_length = dataset[subject][action]['positions'].shape[0] + for cam_idx in range(len(keypoints[subject][action])): + assert keypoints[subject][action][cam_idx].shape[0] >= mocap_length + if keypoints[subject][action][cam_idx].shape[0] > mocap_length: + keypoints[subject][action][cam_idx] = keypoints[subject][action][cam_idx][:mocap_length] + + for subject in keypoints.keys(): + for action in keypoints[subject]: + for cam_idx, kps in enumerate(keypoints[subject][action]): + cam = dataset.cameras()[subject][cam_idx] + if self.crop_uv == 0: + kps[..., :2] = normalize_screen_coordinates(kps[..., :2], w=cam['res_w'], h=cam['res_h']) + keypoints[subject][action][cam_idx] = kps + + for subject in folder_list: + for action in dataset[subject].keys(): + positions_2d_pairs = [] + for cam_idx in range(len(keypoints[subject][action])): + positions_2d_pairs.append(keypoints[subject][action][cam_idx]) + + keypoints[subject][action].append( + np.array(positions_2d_pairs).transpose((1, 0, 2,3))) + return keypoints + + def fetch(self, dataset, subjects, subset=1, ): + out_poses_3d = {} + out_poses_2d = {} + out_camera_params = {} + for subject in subjects: + for action in self.keypoints[subject].keys(): + poses_2d = self.keypoints[subject][action][4] + out_poses_2d[(subject, action)] = poses_2d + + poses_3d = dataset[subject][action]['positions'] + out_poses_3d[(subject, action)] = poses_3d + + if len(out_camera_params) == 0: + out_camera_params = None + + downsample = 1 + if downsample: + pass + return out_camera_params, out_poses_3d, out_poses_2d + + def hop_normalize(self, x1, x2, x3, x4): + x1 = x1 / torch.sum(x1, dim=1) + x2 = x2 / torch.sum(x1, dim=1) + x3 = x3 / torch.sum(x1, dim=1) + x4 = x4 / torch.sum(x1, dim=1) + return torch.cat((x1.unsqueeze(0), x2.unsqueeze(0), x3.unsqueeze(0), x4.unsqueeze(0)), dim=0) + + def __len__(self): + return len(self.generator.pairs) + + def __getitem__(self, index): + seq_name, start_3d, end_3d, flip, reverse = self.generator.pairs[index] + + cam, gt_3D, input_2D, action, subject, low_2d, high_2d = self.generator.get_batch(seq_name, start_3d, end_3d, + False, False) + + if self.train == False and self.test_aug: + _, _, input_2D_aug, _, _, _, _ = self.generator.get_batch(seq_name, start_3d, end_3d, flip=False, + reverse=False) + input_2D = 
np.concatenate((np.expand_dims(input_2D, axis=0), np.expand_dims(input_2D_aug, axis=0)), 0) + + bb_box = np.array([0, 0, 1, 1]) + input_2D_update = input_2D + + hops = self.hop_normalize(self.hop1, self.hop2, self.hop3, self.hop4) + + scale = np.float64(1.0) + + return cam, gt_3D, input_2D_update, action, subject, scale, bb_box, low_2d, high_2d, hops + diff --git a/common/__pycache__/Mydataset.cpython-37.pyc b/common/__pycache__/Mydataset.cpython-37.pyc new file mode 100644 index 0000000..fac95c8 Binary files /dev/null and b/common/__pycache__/Mydataset.cpython-37.pyc differ diff --git a/common/__pycache__/Mydataset.cpython-38.pyc b/common/__pycache__/Mydataset.cpython-38.pyc new file mode 100644 index 0000000..9f20fe5 Binary files /dev/null and b/common/__pycache__/Mydataset.cpython-38.pyc differ diff --git a/common/__pycache__/cameras.cpython-37.pyc b/common/__pycache__/cameras.cpython-37.pyc new file mode 100644 index 0000000..c0b0a2b Binary files /dev/null and b/common/__pycache__/cameras.cpython-37.pyc differ diff --git a/common/__pycache__/cameras.cpython-38.pyc b/common/__pycache__/cameras.cpython-38.pyc new file mode 100644 index 0000000..78c33d6 Binary files /dev/null and b/common/__pycache__/cameras.cpython-38.pyc differ diff --git a/common/__pycache__/h36m_dataset.cpython-37.pyc b/common/__pycache__/h36m_dataset.cpython-37.pyc new file mode 100644 index 0000000..3ad3f10 Binary files /dev/null and b/common/__pycache__/h36m_dataset.cpython-37.pyc differ diff --git a/common/__pycache__/h36m_dataset.cpython-38.pyc b/common/__pycache__/h36m_dataset.cpython-38.pyc new file mode 100644 index 0000000..c1abb31 Binary files /dev/null and b/common/__pycache__/h36m_dataset.cpython-38.pyc differ diff --git a/common/__pycache__/opt.cpython-37.pyc b/common/__pycache__/opt.cpython-37.pyc new file mode 100644 index 0000000..4973c2a Binary files /dev/null and b/common/__pycache__/opt.cpython-37.pyc differ diff --git a/common/__pycache__/opt.cpython-38.pyc b/common/__pycache__/opt.cpython-38.pyc new file mode 100644 index 0000000..7bcaf6e Binary files /dev/null and b/common/__pycache__/opt.cpython-38.pyc differ diff --git a/common/__pycache__/utils.cpython-37.pyc b/common/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000..f28a2e4 Binary files /dev/null and b/common/__pycache__/utils.cpython-37.pyc differ diff --git a/common/__pycache__/utils.cpython-38.pyc b/common/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000..3f431d3 Binary files /dev/null and b/common/__pycache__/utils.cpython-38.pyc differ diff --git a/common/cameras.py b/common/cameras.py new file mode 100644 index 0000000..96dd745 --- /dev/null +++ b/common/cameras.py @@ -0,0 +1,258 @@ +import sys +import numpy as np +import torch + + +def normalize_screen_coordinates(X, w, h): + + assert X.shape[-1] == 2 + return X / w * 2 - [1, h / w] + + +def world_to_camera(X, R, t): # https://blog.csdn.net/Hurt_Town/article/details/125071279 + Rt = wrap(qinverse, R) + # return wrap(qrot, np.tile(Rt, (*X.shape[:-1], 1)), X - t) + return wrap(qrot, Rt.repeat(*X.shape[:-1], 1), X - t) + + +def camera_to_world(X, R, t): + return wrap(qrot, np.tile(R, (*X.shape[:-1], 1)), X) + t + + +def wrap(func, *args, unsqueeze=False): + args = list(args) + for i, arg in enumerate(args): + if type(arg) == np.ndarray: + args[i] = torch.from_numpy(arg) + if unsqueeze: + args[i] = args[i].unsqueeze(0) + + result = func(*args) + + if isinstance(result, tuple): + result = list(result) + for i, res in enumerate(result): + if type(res) == 
torch.Tensor: + if unsqueeze: + res = res.squeeze(0) + result[i] = res.numpy() + return tuple(result) + elif type(result) == torch.Tensor: + if unsqueeze: + result = result.squeeze(0) + # return result.numpy() + return result + else: + return result + + +def qrot(q, v): + assert q.shape[-1] == 4 + assert v.shape[-1] == 3 + assert q.shape[:-1] == v.shape[:-1] + + qvec = q[..., 1:] + uv = torch.cross(qvec, v, dim=len(q.shape) - 1) + uuv = torch.cross(qvec, uv, dim=len(q.shape) - 1) + return (v + 2 * (q[..., :1] * uv + uuv)) + + +def qinverse(q, inplace=False): + if inplace: + q[..., 1:] *= -1 + return q + else: + w = q[..., :1] + xyz = q[..., 1:] + return torch.cat((w, -xyz), dim=len(q.shape) - 1) + + +h36m_cameras_intrinsic_params = [ + { + 'id': '54138969', + 'center': [512.54150390625, 515.4514770507812], + 'focal_length': [1145.0494384765625, 1143.7811279296875], + 'radial_distortion': [-0.20709891617298126, 0.24777518212795258, -0.0030751503072679043], + 'tangential_distortion': [-0.0009756988729350269, -0.00142447161488235], + 'res_w': 1000, + 'res_h': 1002, + 'azimuth': 70, + }, + { + 'id': '55011271', + 'center': [508.8486328125, 508.0649108886719], + 'focal_length': [1149.6756591796875, 1147.5916748046875], + 'radial_distortion': [-0.1942136287689209, 0.2404085397720337, 0.006819975562393665], + 'tangential_distortion': [-0.0016190266469493508, -0.0027408944442868233], + 'res_w': 1000, + 'res_h': 1000, + 'azimuth': -70, + }, + { + 'id': '58860488', + 'center': [519.8158569335938, 501.40264892578125], + 'focal_length': [1149.1407470703125, 1148.7989501953125], + 'radial_distortion': [-0.2083381861448288, 0.25548800826072693, -0.0024604974314570427], + 'tangential_distortion': [0.0014843869721516967, -0.0007599993259645998], + 'res_w': 1000, + 'res_h': 1000, + 'azimuth': 110, + }, + { + 'id': '60457274', + 'center': [514.9682006835938, 501.88201904296875], + 'focal_length': [1145.5113525390625, 1144.77392578125], + 'radial_distortion': [-0.198384091258049, 0.21832367777824402, -0.008947807364165783], + 'tangential_distortion': [-0.0005872055771760643, -0.0018133620033040643], + 'res_w': 1000, + 'res_h': 1002, + 'azimuth': -110, + }, +] + +h36m_cameras_extrinsic_params = { + 'S1': [ + { + 'orientation': [0.1407056450843811, -0.1500701755285263, -0.755240797996521, 0.6223280429840088], + 'translation': [1841.1070556640625, 4955.28466796875, 1563.4454345703125], + }, + { + 'orientation': [0.6157187819480896, -0.764836311340332, -0.14833825826644897, 0.11794740706682205], + 'translation': [1761.278564453125, -5078.0068359375, 1606.2650146484375], + }, + { + 'orientation': [0.14651472866535187, -0.14647851884365082, 0.7653023600578308, -0.6094175577163696], + 'translation': [-1846.7777099609375, 5215.04638671875, 1491.972412109375], + }, + { + 'orientation': [0.5834008455276489, -0.7853162288665771, 0.14548823237419128, -0.14749594032764435], + 'translation': [-1794.7896728515625, -3722.698974609375, 1574.8927001953125], + }, + ], + 'S2': [ + {}, + {}, + {}, + {}, + ], + 'S3': [ + {}, + {}, + {}, + {}, + ], + 'S4': [ + {}, + {}, + {}, + {}, + ], + 'S5': [ + { + 'orientation': [0.1467377245426178, -0.162370964884758, -0.7551892995834351, 0.6178938746452332], + 'translation': [2097.3916015625, 4880.94482421875, 1605.732421875], + }, + { + 'orientation': [0.6159758567810059, -0.7626792192459106, -0.15728192031383514, 0.1189815029501915], + 'translation': [2031.7008056640625, -5167.93310546875, 1612.923095703125], + }, + { + 'orientation': [0.14291371405124664, -0.12907841801643372, 
0.7678384780883789, -0.6110143065452576], + 'translation': [-1620.5948486328125, 5171.65869140625, 1496.43701171875], + }, + { + 'orientation': [0.5920479893684387, -0.7814217805862427, 0.1274748593568802, -0.15036417543888092], + 'translation': [-1637.1737060546875, -3867.3173828125, 1547.033203125], + }, + ], + 'S6': [ + { + 'orientation': [0.1337897777557373, -0.15692396461963654, -0.7571090459823608, 0.6198879480361938], + 'translation': [1935.4517822265625, 4950.24560546875, 1618.0838623046875], + }, + { + 'orientation': [0.6147197484970093, -0.7628812789916992, -0.16174767911434174, 0.11819244921207428], + 'translation': [1969.803955078125, -5128.73876953125, 1632.77880859375], + }, + { + 'orientation': [0.1529948115348816, -0.13529130816459656, 0.7646096348762512, -0.6112781167030334], + 'translation': [-1769.596435546875, 5185.361328125, 1476.993408203125], + }, + { + 'orientation': [0.5916101336479187, -0.7804774045944214, 0.12832270562648773, -0.1561593860387802], + 'translation': [-1721.668701171875, -3884.13134765625, 1540.4879150390625], + }, + ], + 'S7': [ + { + 'orientation': [0.1435241848230362, -0.1631336808204651, -0.7548328638076782, 0.6188824772834778], + 'translation': [1974.512939453125, 4926.3544921875, 1597.8326416015625], + }, + { + 'orientation': [0.6141672730445862, -0.7638262510299683, -0.1596645563840866, 0.1177929937839508], + 'translation': [1937.0584716796875, -5119.7900390625, 1631.5665283203125], + }, + { + 'orientation': [0.14550060033798218, -0.12874816358089447, 0.7660516500473022, -0.6127139329910278], + 'translation': [-1741.8111572265625, 5208.24951171875, 1464.8245849609375], + }, + { + 'orientation': [0.5912848114967346, -0.7821764349937439, 0.12445473670959473, -0.15196487307548523], + 'translation': [-1734.7105712890625, -3832.42138671875, 1548.5830078125], + }, + ], + 'S8': [ + { + 'orientation': [0.14110587537288666, -0.15589867532253265, -0.7561917304992676, 0.619644045829773], + 'translation': [2150.65185546875, 4896.1611328125, 1611.9046630859375], + }, + { + 'orientation': [0.6169601678848267, -0.7647668123245239, -0.14846350252628326, 0.11158157885074615], + 'translation': [2219.965576171875, -5148.453125, 1613.0440673828125], + }, + { + 'orientation': [0.1471444070339203, -0.13377119600772858, 0.7670128345489502, -0.6100369691848755], + 'translation': [-1571.2215576171875, 5137.0185546875, 1498.1761474609375], + }, + { + 'orientation': [0.5927824378013611, -0.7825870513916016, 0.12147816270589828, -0.14631995558738708], + 'translation': [-1476.913330078125, -3896.7412109375, 1547.97216796875], + }, + ], + 'S9': [ + { + 'orientation': [0.15540587902069092, -0.15548215806484222, -0.7532095313072205, 0.6199594736099243], + 'translation': [2044.45849609375, 4935.1171875, 1481.2275390625], + }, + { + 'orientation': [0.618784487247467, -0.7634735107421875, -0.14132238924503326, 0.11933968216180801], + 'translation': [1990.959716796875, -5123.810546875, 1568.8048095703125], + }, + { + 'orientation': [0.13357827067375183, -0.1367100477218628, 0.7689454555511475, -0.6100738644599915], + 'translation': [-1670.9921875, 5211.98583984375, 1528.387939453125], + }, + { + 'orientation': [0.5879399180412292, -0.7823407053947449, 0.1427614390850067, -0.14794869720935822], + 'translation': [-1696.04345703125, -3827.099853515625, 1591.4127197265625], + }, + ], + 'S11': [ + { + 'orientation': [0.15232472121715546, -0.15442320704460144, -0.7547563314437866, 0.6191070079803467], + 'translation': [2098.440185546875, 4926.5546875, 1500.278564453125], + }, + { + 
'orientation': [0.6189449429512024, -0.7600917220115662, -0.15300633013248444, 0.1255258321762085], + 'translation': [2083.182373046875, -4912.1728515625, 1561.07861328125], + }, + { + 'orientation': [0.14943228662014008, -0.15650227665901184, 0.7681233882904053, -0.6026304364204407], + 'translation': [-1609.8153076171875, 5177.3359375, 1537.896728515625], + }, + { + 'orientation': [0.5894251465797424, -0.7818877100944519, 0.13991211354732513, -0.14715361595153809], + 'translation': [-1590.738037109375, -3854.1689453125, 1578.017578125], + }, + ], +} diff --git a/common/h36m_dataset.py b/common/h36m_dataset.py new file mode 100644 index 0000000..0b6cbbc --- /dev/null +++ b/common/h36m_dataset.py @@ -0,0 +1,171 @@ +import numpy as np +import copy + +from common.cameras import h36m_cameras_intrinsic_params, h36m_cameras_extrinsic_params, \ + normalize_screen_coordinates + + +class Skeleton: + + def __init__(self, parents, joints_left, joints_right): + assert len(joints_left) == len(joints_right) + + self._parents = np.array(parents) + self._joints_left = joints_left + self._joints_right = joints_right + self._compute_metadata() + + def num_joints(self): + return len(self._parents) + + def parents(self): + return self._parents + + def has_children(self): + return self._has_children + + def children(self): + return self._children + + def remove_joints(self, joints_to_remove): + + valid_joints = [] + for joint in range(len(self._parents)): + if joint not in joints_to_remove: + valid_joints.append(joint) + + for i in range(len(self._parents)): + while self._parents[i] in joints_to_remove: + self._parents[i] = self._parents[self._parents[i]] + + index_offsets = np.zeros(len(self._parents), dtype=int) + new_parents = [] + for i, parent in enumerate(self._parents): + if i not in joints_to_remove: + new_parents.append(parent - index_offsets[parent]) + else: + index_offsets[i:] += 1 + self._parents = np.array(new_parents) + + if self._joints_left is not None: + new_joints_left = [] + for joint in self._joints_left: + if joint in valid_joints: + new_joints_left.append(joint - index_offsets[joint]) + self._joints_left = new_joints_left + if self._joints_right is not None: + new_joints_right = [] + for joint in self._joints_right: + if joint in valid_joints: + new_joints_right.append(joint - index_offsets[joint]) + self._joints_right = new_joints_right + + self._compute_metadata() + + return valid_joints + + def joints_left(self): + return self._joints_left + + def joints_right(self): + return self._joints_right + + def _compute_metadata(self): + self._has_children = np.zeros(len(self._parents)).astype(bool) + for i, parent in enumerate(self._parents): + if parent != -1: + self._has_children[parent] = True + + self._children = [] + for i, parent in enumerate(self._parents): + self._children.append([]) + for i, parent in enumerate(self._parents): + if parent != -1: + self._children[parent].append(i) + + +h36m_skeleton = Skeleton(parents=[-1, 0, 1, 2, 3, 4, 0, 6, 7, 8, 9, 0, 11, 12, 13, 14, 12, + 16, 17, 18, 19, 20, 19, 22, 12, 24, 25, 26, 27, 28, 27, 30], # 树的双亲表示法 + joints_left=[6, 7, 8, 9, 10, 16, 17, 18, 19, 20, 21, 22, 23], + joints_right=[1, 2, 3, 4, 5, 24, 25, 26, 27, 28, 29, 30, 31]) + + +class MocapDataset: + def __init__(self, fps, skeleton): + self._skeleton = skeleton + self._fps = fps + self._data = None + self._cameras = None + + def remove_joints(self, joints_to_remove): + kept_joints = self._skeleton.remove_joints(joints_to_remove) + for subject in self._data.keys(): + for action in 
self._data[subject].keys(): + s = self._data[subject][action] + s['positions'] = s['positions'][:, kept_joints] + + def __getitem__(self, key): + return self._data[key] + + def subjects(self): + return self._data.keys() + + def fps(self): + return self._fps + + def skeleton(self): + return self._skeleton + + def cameras(self): + return self._cameras + + def supports_semi_supervised(self): + return False + + +class Human36mDataset(MocapDataset): + def __init__(self, path, opt, remove_static_joints=True): + super().__init__(fps=50, skeleton=h36m_skeleton) + self.train_list = ['S1', 'S5', 'S6', 'S7', 'S8'] + self.test_list = ['S9', 'S11'] + + self._cameras = copy.deepcopy(h36m_cameras_extrinsic_params) + for cameras in self._cameras.values(): + for i, cam in enumerate(cameras): + cam.update(h36m_cameras_intrinsic_params[i]) + for k, v in cam.items(): + if k not in ['id', 'res_w', 'res_h']: + cam[k] = np.array(v, dtype='float32') + + if opt.crop_uv == 0: + cam['center'] = normalize_screen_coordinates(cam['center'], w=cam['res_w'], h=cam['res_h']).astype( + 'float32') + cam['focal_length'] = cam['focal_length'] / cam['res_w'] * 2 + + if 'translation' in cam: + cam['translation'] = cam['translation'] / 1000 + + cam['intrinsic'] = np.concatenate((cam['focal_length'], + cam['center'], + cam['radial_distortion'], + cam['tangential_distortion'])) + + data = np.load(path, allow_pickle=True)['positions_3d'].item() + + self._data = {} + for subject, actions in data.items(): + self._data[subject] = {} + for action_name, positions in actions.items(): + self._data[subject][action_name] = { + 'positions': positions, + 'cameras': self._cameras[subject], + } + + if remove_static_joints: + self.remove_joints([4, 5, 9, 10, 11, 16, 20, 21, 22, 23, 24, 28, 29, 30, 31]) + + self._skeleton._parents[11] = 8 + self._skeleton._parents[14] = 8 + + def supports_semi_supervised(self): + return True \ No newline at end of file diff --git a/common/utils.py b/common/utils.py new file mode 100644 index 0000000..ca7dd16 --- /dev/null +++ b/common/utils.py @@ -0,0 +1,211 @@ +import torch +import numpy as np +import hashlib +from torch.autograd import Variable +import os + + +def deterministic_random(min_value, max_value, data): + digest = hashlib.sha256(data.encode()).digest() + raw_value = int.from_bytes(digest[:4], byteorder='little', signed=False) + return int(raw_value / (2 ** 32 - 1) * (max_value - min_value)) + min_value + + +def mpjpe_cal(predicted, target): + + assert predicted.shape == target.shape + return torch.mean(torch.norm(predicted - target, dim=len(target.shape) - 1)) + + +def test_calculation(predicted, target, action, error_sum, data_type, subject): + error_sum = mpjpe_by_action_p1(predicted, target, action, error_sum) + error_sum = mpjpe_by_action_p2(predicted, target, action, error_sum) + + return error_sum + + +def mpjpe_by_action_p1(predicted, target, action, action_error_sum): + assert predicted.shape == target.shape + num = predicted.size(0) + dist = torch.mean(torch.norm(predicted - target, dim=len(target.shape) - 1), dim=len(target.shape) - 2) + + if len(set(list(action))) == 1: + end_index = action[0].find(' ') + if end_index != -1: + action_name = action[0][:end_index] + else: + action_name = action[0] + + action_error_sum[action_name]['p1'].update(torch.mean(dist).item() * num, num) + else: + for i in range(num): + end_index = action[i].find(' ') + if end_index != -1: + action_name = action[i][:end_index] + else: + action_name = action[i] + + 
action_error_sum[action_name]['p1'].update(dist[i].item(), 1) + + return action_error_sum + + +def mpjpe_by_action_p2(predicted, target, action, action_error_sum): + assert predicted.shape == target.shape + num = predicted.size(0) + pred = predicted.detach().cpu().numpy().reshape(-1, predicted.shape[-2], predicted.shape[-1]) + gt = target.detach().cpu().numpy().reshape(-1, target.shape[-2], target.shape[-1]) + dist = p_mpjpe(pred, gt) + + if len(set(list(action))) == 1: + end_index = action[0].find(' ') + if end_index != -1: + action_name = action[0][:end_index] + else: + action_name = action[0] + action_error_sum[action_name]['p2'].update(np.mean(dist) * num, num) + else: + for i in range(num): + end_index = action[i].find(' ') + if end_index != -1: + action_name = action[i][:end_index] + else: + action_name = action[i] + action_error_sum[action_name]['p2'].update(np.mean(dist), 1) + + return action_error_sum + + +def p_mpjpe(predicted, target): + assert predicted.shape == target.shape + + muX = np.mean(target, axis=1, keepdims=True) + muY = np.mean(predicted, axis=1, keepdims=True) + + X0 = target - muX + Y0 = predicted - muY + + normX = np.sqrt(np.sum(X0 ** 2, axis=(1, 2), keepdims=True)) + normY = np.sqrt(np.sum(Y0 ** 2, axis=(1, 2), keepdims=True)) + + X0 /= normX + Y0 /= normY + + H = np.matmul(X0.transpose(0, 2, 1), Y0) + U, s, Vt = np.linalg.svd(H) + V = Vt.transpose(0, 2, 1) + R = np.matmul(V, U.transpose(0, 2, 1)) + + sign_detR = np.sign(np.expand_dims(np.linalg.det(R), axis=1)) + V[:, :, -1] *= sign_detR + s[:, -1] *= sign_detR.flatten() + R = np.matmul(V, U.transpose(0, 2, 1)) + + tr = np.expand_dims(np.sum(s, axis=1, keepdims=True), axis=2) + + a = tr * normX / normY + t = muX - a * np.matmul(muY, R) + + predicted_aligned = a * np.matmul(predicted, R) + t + + return np.mean(np.linalg.norm(predicted_aligned - target, axis=len(target.shape) - 1), axis=len(target.shape) - 2) + + +def define_actions(action): + actions = ["Directions", "Discussion", "Eating", "Greeting", + "Phoning", "Photo", "Posing", "Purchases", + "Sitting", "SittingDown", "Smoking", "Waiting", + "WalkDog", "Walking", "WalkTogether"] + + if action == "All" or action == "all" or action == '*': + return actions + + if not action in actions: + raise (ValueError, "Unrecognized action: %s" % action) + + return [action] + + +def define_error_list(actions): + error_sum = {} + error_sum.update({actions[i]: + {'p1': AccumLoss(), 'p2': AccumLoss()} + for i in range(len(actions))}) + return error_sum + + +class AccumLoss(object): + def __init__(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val + self.count += n + self.avg = self.sum / self.count + + +def get_varialbe(split, target): + num = len(target) + var = [] + if split == 'train': + for i in range(num): + temp = Variable(target[i], requires_grad=False).contiguous().type(torch.cuda.FloatTensor) + var.append(temp) + else: + for i in range(num): + temp = Variable(target[i]).contiguous().cuda().type(torch.cuda.FloatTensor) + var.append(temp) + + return var + + +def print_error(data_type, action_error_sum, is_train): + mean_error_p1, mean_error_p2 = print_error_action(action_error_sum, is_train) + + return mean_error_p1, mean_error_p2 + + +def print_error_action(action_error_sum, is_train): + mean_error_each = {'p1': 0.0, 'p2': 0.0} + mean_error_all = {'p1': AccumLoss(), 'p2': AccumLoss()} + + if is_train == 0: + print("{0:=^12} {1:=^10} {2:=^8}".format("Action", "p#1 mm", "p#2 mm")) + + 
for action, value in action_error_sum.items(): + if is_train == 0: + print("{0:<12} ".format(action), end="") + + mean_error_each['p1'] = action_error_sum[action]['p1'].avg * 1000.0 + mean_error_all['p1'].update(mean_error_each['p1'], 1) + + mean_error_each['p2'] = action_error_sum[action]['p2'].avg * 1000.0 + mean_error_all['p2'].update(mean_error_each['p2'], 1) + + if is_train == 0: + print("{0:>6.2f} {1:>10.2f}".format(mean_error_each['p1'], mean_error_each['p2'])) + + if is_train == 0: + print("{0:<12} {1:>6.2f} {2:>10.2f}".format("Average", mean_error_all['p1'].avg, mean_error_all['p2'].avg)) + + return mean_error_all['p1'].avg, mean_error_all['p2'].avg + + +def save_model(previous_name, save_dir, epoch, data_threshold, model): + if os.path.exists(previous_name): + os.remove(previous_name) + + torch.save(model.state_dict(), '%s/model_%d_%d.pth' % (save_dir, epoch, data_threshold * 100)) + + previous_name = '%s/model_%d_%d.pth' % (save_dir, epoch, data_threshold * 100) + + return previous_name + + +def save_model_epoch(save_dir, epoch, model): + torch.save(model.state_dict(), '%s/epoch_%d.pth' % (save_dir, epoch)) + diff --git a/framework.png b/framework.png new file mode 100644 index 0000000..445cdcd Binary files /dev/null and b/framework.png differ diff --git a/get_2D_skletons.py b/get_2D_skletons.py new file mode 100644 index 0000000..e69de29 diff --git a/main.py b/main.py new file mode 100644 index 0000000..02f895d --- /dev/null +++ b/main.py @@ -0,0 +1,315 @@ +import os +import torch +import logging +import random +import torch.optim as optim +from tqdm import tqdm +# from torch.utils.tensorboard import SummaryWriter + +from common.utils import * +from common.opt import opts +from common.h36m_dataset import Human36mDataset +from common.Mydataset import Fusion + +from model.SGraFormer import sgraformer + +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D +import numpy as np + + +# os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" +# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +# os.environ["CUDA_VISIBLE_DEVICES"] = "0" +CUDA_ID = [0] +device = torch.device("cuda") + + +def visualize_skeletons(input_2D, output_3D, gt_3D, idx=0, output_dir='./output'): + # Ensure the tensors are on the CPU and convert them to numpy arrays + input_2D = input_2D.cpu().numpy() + output_3D = output_3D.cpu().numpy() + gt_3D = gt_3D.cpu().numpy() + + # print("====> input_2D: ", input_2D[-1]) + # Get the first action and first sample from the batch + input_sample = input_2D[idx, 0] + output_sample = output_3D[idx, 0] + gt_3D_sample = gt_3D[idx, 0] + + print(f'\ninput_sample shape: {input_sample.shape}') + print(f'output_sample shape: {output_sample.shape}') + + fig = plt.figure(figsize=(25, 5)) + + # Define the connections (bones) between joints + bones = [ + (0, 1), (1, 2), (2, 3), # Left leg + (0, 4), (4, 5), (5, 6), # Right leg + (0, 7), (7, 8), (8, 9), (9, 10), # Spine + (7, 11), (11, 12), (12, 13), # Right arm + (7, 14), (14, 15), (15, 16) # Left arm + ] + + # Colors for different parts + bone_colors = { + "leg": 'green', + "spine": 'blue', + "arm": 'red' + } + + # Function to get bone color based on index + def get_bone_color(start, end): + if (start in [1, 2, 3] or end in [1, 2, 3] or + start in [4, 5, 6] or end in [4, 5, 6]): + return bone_colors["leg"] + elif start in [7, 8, 9, 10] or end in [7, 8, 9, 10]: + return bone_colors["spine"] + else: + return bone_colors["arm"] + + # Plotting 2D skeletons from different angles + for i in range(4): + ax = fig.add_subplot(1, 7, i + 1) + 
ax.set_title(f'2D angle {i+1}') + ax.scatter(input_sample[i, :, 0], input_sample[i, :, 1], color='blue') + + # Draw the bones + for start, end in bones: + bone_color = get_bone_color(start, end) + ax.plot([input_sample[i, start, 0], input_sample[i, end, 0]], + [input_sample[i, start, 1], input_sample[i, end, 1]], color=bone_color) + + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_xlim(np.min(input_sample[:, :, 0]) - 1, np.max(input_sample[:, :, 0]) + 1) + ax.set_ylim(np.min(input_sample[:, :, 1]) - 1, np.max(input_sample[:, :, 1]) + 1) + ax.grid() + + # Plotting predicted 3D skeleton + ax = fig.add_subplot(1, 7, 5, projection='3d') + ax.set_title('3D Predicted Skeleton') + ax.scatter(output_sample[:, 0], output_sample[:, 1], output_sample[:, 2], color='red', label='Predicted') + + # Draw the bones in 3D for output_sample + for start, end in bones: + bone_color = get_bone_color(start, end) + ax.plot([output_sample[start, 0], output_sample[end, 0]], + [output_sample[start, 1], output_sample[end, 1]], + [output_sample[start, 2], output_sample[end, 2]], color=bone_color) + + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + ax.set_xlim(np.min(output_sample[:, 0]) - 1, np.max(output_sample[:, 0]) + 1) + ax.set_ylim(np.min(output_sample[:, 1]) - 1, np.max(output_sample[:, 1]) + 1) + ax.set_zlim(np.min(output_sample[:, 2]) - 1, np.max(output_sample[:, 2]) + 1) + ax.legend() + + # Plotting ground truth 3D skeleton + ax = fig.add_subplot(1, 7, 6, projection='3d') + ax.set_title('3D Ground Truth Skeleton') + ax.scatter(gt_3D_sample[:, 0], gt_3D_sample[:, 1], gt_3D_sample[:, 2], color='blue', label='Ground Truth') + + # Draw the bones in 3D for gt_3D_sample + for start, end in bones: + bone_color = get_bone_color(start, end) + ax.plot([gt_3D_sample[start, 0], gt_3D_sample[end, 0]], + [gt_3D_sample[start, 1], gt_3D_sample[end, 1]], + [gt_3D_sample[start, 2], gt_3D_sample[end, 2]], color=bone_color, linestyle='--') + + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + ax.set_xlim(np.min(gt_3D_sample[:, 0]) - 1, np.max(gt_3D_sample[:, 0]) + 1) + ax.set_ylim(np.min(gt_3D_sample[:, 1]) - 1, np.max(gt_3D_sample[:, 1]) + 1) + ax.set_zlim(np.min(gt_3D_sample[:, 2]) - 1, np.max(gt_3D_sample[:, 2]) + 1) + ax.legend() + + plt.grid() + + # Save the figure + plt.tight_layout() + plt.savefig(f'{output_dir}/skeletons_visualization.png') + plt.show() + +def train(opt, actions, train_loader, model, optimizer, epoch, writer, adaptive_weight=None): + return step('train', opt, actions, train_loader, model, optimizer, epoch, writer, adaptive_weight) + + +def val(opt, actions, val_loader, model): + with torch.no_grad(): + return step('test', opt, actions, val_loader, model) + + +def step(split, opt, actions, dataLoader, model, optimizer=None, epoch=None, writer=None, adaptive_weight=None): + loss_all = {'loss': AccumLoss()} + action_error_sum = define_error_list(actions) + + if split == 'train': + model.train() + else: + model.eval() + + TQDM = tqdm(enumerate(dataLoader), total=len(dataLoader), ncols=100) + for i, data in TQDM: + batch_cam, gt_3D, input_2D, action, subject, scale, bb_box, start, end, hops = data + + [input_2D, gt_3D, batch_cam, scale, bb_box, hops] = get_varialbe(split, [input_2D, gt_3D, batch_cam, scale, bb_box, hops]) + + if split == 'train': + output_3D = model(input_2D, hops) + elif split == 'test': + # input_2D = input_2D.to(device) + # model = model.to(device) + # hops = hops.to(device) + input_2D, output_3D = input_augmentation(input_2D, hops, model) + + + 
visualize_skeletons(input_2D, output_3D, gt_3D) + + out_target = gt_3D.clone() + out_target[:, :, 0] = 0 + + if split == 'train': + loss = mpjpe_cal(output_3D, out_target) + + TQDM.set_description(f'Epoch [{epoch}/{opt.nepoch}]') + TQDM.set_postfix({"l": loss.item()}) + + N = input_2D.size(0) + loss_all['loss'].update(loss.detach().cpu().numpy() * N, N) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # writer.add_scalars(main_tag='scalars1/train_loss', + # tag_scalar_dict={'trianloss': loss.item()}, + # global_step=(epoch - 1) * len(dataLoader) + i) + + elif split == 'test': + if output_3D.shape[1] != 1: + output_3D = output_3D[:, opt.pad].unsqueeze(1) + output_3D[:, :, 1:, :] -= output_3D[:, :, :1, :] + output_3D[:, :, 0, :] = 0 + action_error_sum = test_calculation(output_3D, out_target, action, action_error_sum, opt.dataset, subject) + + if split == 'train': + return loss_all['loss'].avg + elif split == 'test': + p1, p2 = print_error(opt.dataset, action_error_sum, opt.train) + return p1, p2 + + +def input_augmentation(input_2D, hops, model): + input_2D_non_flip = input_2D[:, 0] + output_3D_non_flip = model(input_2D_non_flip, hops) + + return input_2D_non_flip, output_3D_non_flip + + +if __name__ == '__main__': + opt = opts().parse() + root_path = opt.root_path + opt.manualSeed = 1 + random.seed(opt.manualSeed) + torch.manual_seed(opt.manualSeed) + + if opt.train: + logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y/%m/%d %H:%M:%S', + filename=os.path.join(opt.checkpoint, 'train.log'), level=logging.INFO) + + root_path = opt.root_path + dataset_path = root_path + 'data_3d_' + opt.dataset + '.npz' + + dataset = Human36mDataset(dataset_path, opt) + actions = define_actions(opt.actions) + + if opt.train: + train_data = Fusion(opt=opt, train=True, dataset=dataset, root_path=root_path) + train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=opt.batch_size, + shuffle=True, num_workers=int(opt.workers), pin_memory=True) + + test_data = Fusion(opt=opt, train=False, dataset=dataset, root_path=root_path) + test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), pin_memory=True) + + model = sgraformer(num_frame=opt.frames, num_joints=17, in_chans=2, embed_dim_ratio=32, depth=4, + num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None, drop_path_rate=0.1) + # model = FuseModel() + + if torch.cuda.device_count() > 1: + print("Let's use", torch.cuda.device_count(), "GPUs!") + model = torch.nn.DataParallel(model, device_ids=CUDA_ID).to(device) + else: + model = model.to(device) + + # 定义一个函数来去除 'module.' 前缀 + def remove_module_prefix(state_dict): + new_state_dict = {} + for k, v in state_dict.items(): + name = k[7:] if k.startswith('module.') else k # 去除 `module.` + new_state_dict[name] = v + return new_state_dict + + model_dict = model.state_dict() + if opt.previous_dir != '': + print('pretrained model path:', opt.previous_dir) + model_path = opt.previous_dir + pre_dict = torch.load(model_path) + # print("=====> pre_dict:", pre_dict.keys()) + # 去除 'module.' 
前缀 + state_dict = remove_module_prefix(pre_dict) + # print("=====> state_dict:", state_dict.keys()) + # 只保留在模型字典中的键值对 + state_dict = {k: v for k, v in state_dict.items() if k in model_dict.keys()} + # 更新模型字典 + model_dict.update(state_dict) + # 加载更新后的模型字典 + model.load_state_dict(model_dict) + + + all_param = [] + lr = opt.lr + all_param += list(model.parameters()) + + optimizer = optim.AdamW(all_param, lr=lr, weight_decay=0.1) + + ## tensorboard + # writer = SummaryWriter("runs/nin") + writer = None + flag = 0 + + + for epoch in range(1, opt.nepoch + 1): + p1, p2 = val(opt, actions, test_dataloader, model) + print("=====> p1, p2", p1, p2) + if opt.train: + loss = train(opt, actions, train_dataloader, model, optimizer, epoch, writer) + + + if opt.train: + save_model_epoch(opt.checkpoint, epoch, model) + + if p1 < opt.previous_best_threshold: + opt.previous_name = save_model(opt.previous_name, opt.checkpoint, epoch, p1, model) + opt.previous_best_threshold = p1 + + if opt.train == 0: + print('p1: %.2f, p2: %.2f' % (p1, p2)) + break + else: + logging.info('epoch: %d, lr: %.7f, loss: %.4f, p1: %.2f, p2: %.2f' % (epoch, lr, loss, p1, p2)) + print('e: %d, lr: %.7f, loss: %.4f, p1: %.2f, p2: %.2f' % (epoch, lr, loss, p1, p2)) + + if epoch % opt.large_decay_epoch == 0: + for param_group in optimizer.param_groups: + param_group['lr'] *= opt.lr_decay_large + lr *= opt.lr_decay_large + else: + for param_group in optimizer.param_groups: + param_group['lr'] *= opt.lr_decay + lr *= opt.lr_decay + + print(opt.checkpoint) diff --git a/md5.py b/md5.py new file mode 100644 index 0000000..a1caac5 --- /dev/null +++ b/md5.py @@ -0,0 +1,19 @@ +import hashlib + +def calculate_md5(file_path): + # 创建一个新的MD5 hash对象 + md5_hash = hashlib.md5() + + # 打开文件,以二进制模式读取 + with open(file_path, "rb") as f: + # 分块读取文件,防止文件过大导致内存不足 + for chunk in iter(lambda: f.read(4096), b""): + md5_hash.update(chunk) + + # 返回MD5值,转换为十六进制格式 + return md5_hash.hexdigest() + +# 文件路径 +file_path = "/home/zlt/Documents/SGraFormer-master/checkpoint/epoch_50.pth" +md5_value = calculate_md5(file_path) +print(f"MD5: {md5_value}") \ No newline at end of file diff --git a/model/SGraFormer.py b/model/SGraFormer.py new file mode 100644 index 0000000..092ff5a --- /dev/null +++ b/model/SGraFormer.py @@ -0,0 +1,176 @@ +## Our model was revised from https://github.com/zczcwh/PoseFormer/blob/main/common/model_poseformer.py + +import torch +import torch.nn as nn +from functools import partial +from einops import rearrange +from timm.models.layers import DropPath + +from common.opt import opts + +from model.Spatial_encoder import First_view_Spatial_features, Spatial_features +from model.Temporal_encoder import Temporal__features + +opt = opts().parse() + + +####################################################################################################################### +class sgraformer(nn.Module): + def __init__(self, num_frame=9, num_joints=17, in_chans=2, embed_dim_ratio=32, depth=4, + num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, norm_layer=None): + """ ##########hybrid_backbone=None, representation_size=None, + Args: + num_frame (int, tuple): input frame number + num_joints (int, tuple): joints number + in_chans (int): number of input channels, 2D joints have 2 channels: (x,y) + embed_dim_ratio (int): embedding dimension ratio + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): 
enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): normalization layer + """ + super().__init__() + + embed_dim = embed_dim_ratio * num_joints + out_dim = num_joints * 3 #### output dimension is num_joints * 3 + ##Spatial_features + self.SF1 = First_view_Spatial_features(num_frame, num_joints, in_chans, embed_dim_ratio, depth, + num_heads, mlp_ratio, qkv_bias, qk_scale, + drop_rate, attn_drop_rate, drop_path_rate, norm_layer) + self.SF2 = Spatial_features(num_frame, num_joints, in_chans, embed_dim_ratio, depth, + num_heads, mlp_ratio, qkv_bias, qk_scale, + drop_rate, attn_drop_rate, drop_path_rate, norm_layer) + self.SF3 = Spatial_features(num_frame, num_joints, in_chans, embed_dim_ratio, depth, + num_heads, mlp_ratio, qkv_bias, qk_scale, + drop_rate, attn_drop_rate, drop_path_rate, norm_layer) + self.SF4 = Spatial_features(num_frame, num_joints, in_chans, embed_dim_ratio, depth, + num_heads, mlp_ratio, qkv_bias, qk_scale, + drop_rate, attn_drop_rate, drop_path_rate, norm_layer) + + ## MVF + self.view_pos_embed = nn.Parameter(torch.zeros(1, 4, num_frame, embed_dim)) + self.pos_drop = nn.Dropout(p=0.) + + self.conv = nn.Sequential( + nn.BatchNorm2d(4, momentum=0.1), + nn.Conv2d(4, 1, kernel_size=opt.mvf_kernel, stride=1, padding=int(opt.mvf_kernel // 2), bias=False), + nn.ReLU(inplace=True), + ) + + + self.conv_hop = nn.Sequential( + nn.BatchNorm2d(4, momentum=0.1), + nn.Conv2d(4, 1, kernel_size=opt.mvf_kernel, stride=1, padding=int(opt.mvf_kernel // 2), bias=False), + nn.ReLU(inplace=True), + ) + + self.conv_norm = nn.LayerNorm(embed_dim) + + self.conv_hop_norm = nn.LayerNorm(embed_dim) + + + # Time Serial + self.TF = Temporal__features(num_frame, num_joints, in_chans, embed_dim_ratio, depth, + num_heads, mlp_ratio, qkv_bias, qk_scale, + drop_rate, attn_drop_rate, drop_path_rate, norm_layer) + + self.head = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, out_dim), + ) + + self.hop_w0 = nn.Parameter(torch.ones(17, 17)) + self.hop_w1 = nn.Parameter(torch.ones(17, 17)) + self.hop_w2 = nn.Parameter(torch.ones(17, 17)) + self.hop_w3 = nn.Parameter(torch.ones(17, 17)) + self.hop_w4 = nn.Parameter(torch.ones(17, 17)) + + self.hop_global = nn.Parameter(torch.ones(17, 17)) + + self.linear_hop = nn.Linear(8, 2) + # self.max_pool = nn.MaxPool1d(2) + + self.edge_embedding = nn.Linear(17*17*4, 17*17) + + def forward(self, x, hops): + b, f, v, j, c = x.shape + + edge_embedding = self.edge_embedding(hops[0].reshape(1, -1)) + + ###############golbal feature################# + x_hop_global = x.unsqueeze(3).repeat(1, 1, 1, 17, 1, 1) + x_hop_global = x_hop_global - x_hop_global.permute(0, 1, 2, 4, 3, 5) + x_hop_global = torch.sum(x_hop_global ** 2, dim=-1) + hop_global = x_hop_global / torch.sum(x_hop_global, dim=-1).unsqueeze(-1) + hops = hops.unsqueeze(1).unsqueeze(2).repeat(1, f, v, 1, 1, 1) + hops1 = hop_global * hops[:, :, :, 0] + hops2 = hop_global * hops[:, :, :, 1] + hops3 = hop_global * hops[:, :, :, 2] + hops4 = hop_global * hops[:, :, :, 3] + # hops = torch.cat((hops1,hops2,hops3,hops4), dim=-1) + hops = torch.cat((hops1,hops2,hops3,hops4), dim=-1) + + + x1 = x[:, :, 0] + x2 = x[:, :, 1] + x3 = x[:, :, 2] + x4 = x[:, :, 3] + + x1 = x1.permute(0, 3, 1, 2) + x2 = x2.permute(0, 3, 1, 2) + x3 = x3.permute(0, 3, 1, 2) + x4 = x4.permute(0, 3, 1, 2) + + hop1 = hops[:, :, 0] + hop2 = 
hops[:, :, 1] + hop3 = hops[:, :, 2] + hop4 = hops[:, :, 3] + + hop1 = hop1.permute(0, 3, 1, 2) + hop2 = hop2.permute(0, 3, 1, 2) + hop3 = hop3.permute(0, 3, 1, 2) + hop4 = hop4.permute(0, 3, 1, 2) + + ### Semantic graph transformer encoder + x1, hop1, MSA1, MSA2, MSA3, MSA4 = self.SF1(x1, hop1, edge_embedding) + x2, hop2, MSA1, MSA2, MSA3, MSA4 = self.SF2(x2, hop2, MSA1, MSA2, MSA3, MSA4, edge_embedding) + x3, hop3, MSA1, MSA2, MSA3, MSA4 = self.SF3(x3, hop3, MSA1, MSA2, MSA3, MSA4, edge_embedding) + x4, hop4, MSA1, MSA2, MSA3, MSA4 = self.SF4(x4, hop4, MSA1, MSA2, MSA3, MSA4, edge_embedding) + + ### Multi-view cross-channel fusion + x = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1), x3.unsqueeze(1), x4.unsqueeze(1)), dim=1) + self.view_pos_embed + x = self.pos_drop(x) + x = self.conv(x).squeeze(1) + x1 + x2 + x3 + x4 + x = self.conv_norm(x) + + hop = torch.cat((hop1.unsqueeze(1), hop2.unsqueeze(1), hop3.unsqueeze(1), hop4.unsqueeze(1)), dim=1) + self.view_pos_embed + hop = self.pos_drop(hop) + # hop = self.conv_hop(hop).squeeze(1) + hop1 + hop2 + hop3 + hop4 + # hop = self.conv_hop_norm(hop) + hop = self.conv(hop).squeeze(1) + hop1 + hop2 + hop3 + hop4 + hop = self.conv_norm(hop) + + x = x * hop + + + ### Temporal transformer encoder + x = self.TF(x) + + x = self.head(x) + x = x.view(b, opt.frames, j, -1) + + print("=============> x.shape", x.shape) + return x + + +# x = torch.rand((8, 27, 4, 17 , 2)) +# hops = torch.rand((8,4,17,17)) +# mvft = hmvformer(num_frame=opt.frames, num_joints=17, in_chans=2, embed_dim_ratio=32, depth=4, +# num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None, drop_path_rate=0.1) +# print(mvft(x, hops).shape) \ No newline at end of file diff --git a/model/Spatial_encoder.py b/model/Spatial_encoder.py new file mode 100644 index 0000000..b5a9b21 --- /dev/null +++ b/model/Spatial_encoder.py @@ -0,0 +1,343 @@ +## Our model was revised from https://github.com/zczcwh/PoseFormer/blob/main/common/model_poseformer.py + +import torch +import torch.nn as nn +from functools import partial +from einops import rearrange +from timm.models.layers import DropPath + + +####################################################################################################################### +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +####################################################################################################################### +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + + self.edge_embedding = nn.Linear(17*17, 17*17) + + def forward(self, x, edge_embedding): + B, N, C = x.shape + qkv = 
self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + + edge_embedding = self.edge_embedding(edge_embedding) + edge_embedding = edge_embedding.reshape(1, 17, 17).unsqueeze(0).repeat(B, self.num_heads, 1, 1) + # print(edge_embedding.shape) + + attn = attn + edge_embedding + + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +####################################################################################################################### +class CVA_Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.Qnorm = nn.LayerNorm(dim) + self.Knorm = nn.LayerNorm(dim) + self.Vnorm = nn.LayerNorm(dim) + self.QLinear = nn.Linear(dim, dim) + self.KLinear = nn.Linear(dim, dim) + self.VLinear = nn.Linear(dim, dim) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + + self.edge_embedding = nn.Linear(17*17, 17*17) + + + + + def forward(self, x, CVA_input, edge_embedding): + B, N, C = x.shape + # CVA_input = self.max_pool(CVA_input) + # print(CVA_input.shape) + q = self.QLinear(self.Qnorm(CVA_input)).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + k = self.KLinear(self.Knorm(CVA_input)).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.VLinear(self.Vnorm(x)).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-2, -1)) * self.scale + + edge_embedding = self.edge_embedding(edge_embedding) + edge_embedding = edge_embedding.reshape(1, 17, 17).unsqueeze(0).repeat(B, self.num_heads, 1, 1) + + + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +####################################################################################################################### +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x, edge_embedding):
+        x = x + self.drop_path(self.attn(self.norm1(x), edge_embedding))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+#######################################################################################################################
+class Multi_Out_Block(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        self.norm_hop1 = norm_layer(dim)
+        self.norm_hop2 = norm_layer(dim)
+        self.mlp_hop = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x, hops, edge_embedding):
+        MSA = self.drop_path(self.attn(self.norm1(x), edge_embedding))
+        MSA = self.norm_hop1(hops) * MSA
+
+        x = x + MSA
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+        hops = hops + MSA
+        hops = hops + self.drop_path(self.mlp_hop(self.norm_hop2(hops)))
+
+        return x, hops, MSA
+
+
+#######################################################################################################################
+class Multi_In_Out_Block(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.cva_attn = CVA_Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + # self.max_pool = nn.MaxPool1d(3, stride=1, padding=1, dilation=1, return_indices=False, ceil_mode=False) + + + self.norm_hop1 = norm_layer(dim) + self.norm_hop2 = norm_layer(dim) + self.mlp_hop = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, hops, CVA_input, edge_embedding): + MSA = self.drop_path(self.cva_attn(x, CVA_input, edge_embedding)) + MSA = self.norm_hop1(hops) * MSA + + x = x + MSA + x = x + self.drop_path(self.mlp(self.norm2(x))) + + hops = hops + MSA + hops = hops + self.drop_path(self.mlp_hop(self.norm_hop2(hops))) + return x, hops, MSA + + +####################################################################################################################### +class First_view_Spatial_features(nn.Module): + def __init__(self, num_frame=9, num_joints=17, in_chans=2, embed_dim_ratio=32, depth=4, + num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, norm_layer=None): + super().__init__() + + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + + ### spatial patch embedding + self.Spatial_patch_to_embedding = nn.Linear(in_chans, embed_dim_ratio) + self.Spatial_pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim_ratio)) + + + self.hop_to_embedding = nn.Linear(68, embed_dim_ratio) + self.hop_pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim_ratio)) + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + self.block1 = Multi_Out_Block(dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, + qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], + norm_layer=norm_layer) + self.block2 = Multi_Out_Block(dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, + qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], + norm_layer=norm_layer) + self.block3 = Multi_Out_Block(dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, + qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[2], + norm_layer=norm_layer) + self.block4 = Multi_Out_Block(dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, + qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[3], + norm_layer=norm_layer) + + self.Spatial_norm = norm_layer(embed_dim_ratio) + + self.hop_norm = norm_layer(embed_dim_ratio) + + def forward(self, x, hops, edge_embedding): + b, _, f, p = x.shape ##### b is batch size, f is number of frames, p is number of joints + x = rearrange(x, 'b c f p -> (b f) p c', ) + + x = self.Spatial_patch_to_embedding(x) + x += self.Spatial_pos_embed + x = self.pos_drop(x) + + hops = rearrange(hops, 'b c f p -> (b f) p c', ) + hops = self.hop_to_embedding(hops) + hops += self.hop_pos_embed + hops = self.pos_drop(hops) + + + x, hops, MSA1 = self.block1(x, hops, edge_embedding) + x, hops, MSA2 = self.block2(x, hops, edge_embedding) + x, hops, MSA3 = self.block3(x, hops, edge_embedding) + x, hops, MSA4 = self.block4(x, hops, edge_embedding) + + x = 
self.Spatial_norm(x) + x = rearrange(x, '(b f) w c -> b f (w c)', f=f) + + hops = self.hop_norm(hops) + hops = rearrange(hops, '(b f) w c -> b f (w c)', f=f) + + return x, hops, MSA1, MSA2, MSA3, MSA4 + + +####################################################################################################################### +class Spatial_features(nn.Module): + def __init__(self, num_frame=9, num_joints=17, in_chans=2, embed_dim_ratio=32, depth=4, + num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, norm_layer=None): + super().__init__() + + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + + ### spatial patch embedding + self.Spatial_patch_to_embedding = nn.Linear(in_chans, embed_dim_ratio) + self.Spatial_pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim_ratio)) + + self.hop_to_embedding = nn.Linear(68, embed_dim_ratio) + self.hop_pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim_ratio)) + + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + self.block1 = Multi_In_Out_Block( + dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer) + self.block2 = Multi_In_Out_Block( + dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer) + self.block3 = Multi_In_Out_Block( + dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[2], norm_layer=norm_layer) + self.block4 = Multi_In_Out_Block( + dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[3], norm_layer=norm_layer) + + self.Spatial_norm = norm_layer(embed_dim_ratio) + + self.hop_norm = norm_layer(embed_dim_ratio) + + def forward(self, x, hops, MSA1, MSA2, MSA3, MSA4, edge_embedding): + b, _, f, p = x.shape ##### b is batch size, f is number of frames, p is number of joints + x = rearrange(x, 'b c f p -> (b f) p c', ) + + x = self.Spatial_patch_to_embedding(x) + x += self.Spatial_pos_embed + x = self.pos_drop(x) + + + hops = rearrange(hops, 'b c f p -> (b f) p c', ) + hops = self.hop_to_embedding(hops) + hops += self.hop_pos_embed + hops = self.pos_drop(hops) + + + x, hops, MSA1 = self.block1(x, hops, MSA1, edge_embedding) + x, hops, MSA2 = self.block2(x, hops, MSA2, edge_embedding) + x, hops, MSA3 = self.block3(x, hops, MSA3, edge_embedding) + x, hops, MSA4 = self.block4(x, hops, MSA4, edge_embedding) + + + x = self.Spatial_norm(x) + x = rearrange(x, '(b f) w c -> b f (w c)', f=f) + + hops = self.hop_norm(hops) + hops = rearrange(hops, '(b f) w c -> b f (w c)', f=f) + + return x, hops, MSA1, MSA2, MSA3, MSA4 \ No newline at end of file diff --git a/model/Temporal_encoder.py b/model/Temporal_encoder.py new file mode 100644 index 0000000..c5b2186 --- /dev/null +++ b/model/Temporal_encoder.py @@ -0,0 +1,159 @@ +## Our model was revised from https://github.com/zczcwh/PoseFormer/blob/main/common/model_poseformer.py + +import torch +import torch.nn as nn +from functools import partial +from einops import rearrange +from timm.models.layers import DropPath + +from common.opt import opts + +opt = 
opts().parse() + + +####################################################################################################################### +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +####################################################################################################################### +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +####################################################################################################################### +class CVA_Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.Qnorm = nn.LayerNorm(dim) + self.Knorm = nn.LayerNorm(dim) + self.Vnorm = nn.LayerNorm(dim) + self.QLinear = nn.Linear(dim, dim) + self.KLinear = nn.Linear(dim, dim) + self.VLinear = nn.Linear(dim, dim) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + + + + def forward(self, x, CVA_input): + B, N, C = x.shape + q = self.QLinear(self.Qnorm(CVA_input)).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + k = self.KLinear(self.Knorm(CVA_input)).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.VLinear(self.Vnorm(x)).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +####################################################################################################################### +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, 
norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +####################################################################################################################### +class Temporal__features(nn.Module): + def __init__(self, num_frame=9, num_joints=17, in_chans=2, embed_dim_ratio=32, depth=4, + num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, norm_layer=None): + super().__init__() + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + embed_dim = embed_dim_ratio * num_joints #### temporal embed_dim is num_joints * spatial embedding dim ratio + out_dim = num_joints * 3 #### output dimension is num_joints * 3 + ### Temporal patch embedding + self.Temporal_pos_embed = nn.Parameter(torch.zeros(1, num_frame, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + + self.Temporal_norm = norm_layer(embed_dim) + ####### A easy way to implement weighted mean + self.weighted_mean = torch.nn.Conv1d(in_channels=num_frame, out_channels=1, kernel_size=1) + + def forward(self, x): + b = x.shape[0] + x += self.Temporal_pos_embed + x = self.pos_drop(x) + for blk in self.blocks: + x = blk(x) + + x = self.Temporal_norm(x) + ##### x size [b, f, emb_dim], then take weighted mean on frame dimension, we only predict 3D pose of the center frame + # x = self.weighted_mean(x) + x = x.view(b, opt.frames, -1) + return x \ No newline at end of file diff --git a/model/__pycache__/HMVFormer.cpython-37.pyc b/model/__pycache__/HMVFormer.cpython-37.pyc new file mode 100644 index 0000000..6e0ca7b Binary files /dev/null and b/model/__pycache__/HMVFormer.cpython-37.pyc differ diff --git a/model/__pycache__/HMVFormer_new.cpython-37.pyc b/model/__pycache__/HMVFormer_new.cpython-37.pyc new file mode 100644 index 0000000..14bedf0 Binary files /dev/null and b/model/__pycache__/HMVFormer_new.cpython-37.pyc differ diff --git a/model/__pycache__/SGraFormer.cpython-38.pyc b/model/__pycache__/SGraFormer.cpython-38.pyc new file mode 100644 index 0000000..cb4d2f2 Binary files /dev/null and b/model/__pycache__/SGraFormer.cpython-38.pyc differ diff --git a/model/__pycache__/Spatial_encoder.cpython-37.pyc b/model/__pycache__/Spatial_encoder.cpython-37.pyc new file mode 100644 index 0000000..e0e19db Binary files /dev/null and b/model/__pycache__/Spatial_encoder.cpython-37.pyc differ diff --git a/model/__pycache__/Spatial_encoder.cpython-38.pyc b/model/__pycache__/Spatial_encoder.cpython-38.pyc new file mode 100644 index 0000000..6e4a027 Binary files /dev/null and 
b/model/__pycache__/Spatial_encoder.cpython-38.pyc differ diff --git a/model/__pycache__/Spatial_encoder_new.cpython-37.pyc b/model/__pycache__/Spatial_encoder_new.cpython-37.pyc new file mode 100644 index 0000000..cf017f8 Binary files /dev/null and b/model/__pycache__/Spatial_encoder_new.cpython-37.pyc differ diff --git a/model/__pycache__/Temporal_encoder.cpython-38.pyc b/model/__pycache__/Temporal_encoder.cpython-38.pyc new file mode 100644 index 0000000..bdda422 Binary files /dev/null and b/model/__pycache__/Temporal_encoder.cpython-38.pyc differ diff --git a/model/__pycache__/Temtemporal_encoder.cpython-37.pyc b/model/__pycache__/Temtemporal_encoder.cpython-37.pyc new file mode 100644 index 0000000..c648970 Binary files /dev/null and b/model/__pycache__/Temtemporal_encoder.cpython-37.pyc differ diff --git a/output/skeletons_2D_first_action_first_sample.png b/output/skeletons_2D_first_action_first_sample.png new file mode 100644 index 0000000..edbdc71 Binary files /dev/null and b/output/skeletons_2D_first_action_first_sample.png differ diff --git a/output/skeletons_3D_first_action_first_sample.png b/output/skeletons_3D_first_action_first_sample.png new file mode 100644 index 0000000..2bf59cc Binary files /dev/null and b/output/skeletons_3D_first_action_first_sample.png differ diff --git a/output/skeletons_visualization.png b/output/skeletons_visualization.png new file mode 100644 index 0000000..c8b09d8 Binary files /dev/null and b/output/skeletons_visualization.png differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..36d3d3f --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +opencv-python +tqdm +yacs +numba +scikit-image +filterpy +ipython +einops +tensorboard +timm==0.4.5 +matplotlib==2.2.2 +tensorboardX \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000..969c04f --- /dev/null +++ b/test.py @@ -0,0 +1,256 @@ +import os +import torch +import logging +import random +import torch.optim as optim +from tqdm import tqdm +# from torch.utils.tensorboard import SummaryWriter + +from common.utils import * +from common.opt import opts +from common.h36m_dataset import Human36mDataset +from common.Mydataset import Fusion + +from model.SGraFormer import sgraformer + +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D +import numpy as np + +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +CUDA_ID = [0] +device = torch.device("cuda") + + + +def visualize_skeletons(input_2D, output_3D, gt_3D, idx=5, output_dir='./output'): + # Ensure the tensors are on the CPU and convert them to numpy arrays + input_2D = input_2D.cpu().numpy() + output_3D = output_3D.cpu().numpy() + gt_3D = gt_3D.cpu().numpy() + + # Get the first action and first sample from the batch + input_sample = input_2D[idx, 0] + output_sample = output_3D[idx, 0] + gt_3D_sample = gt_3D[idx, 0] + + print(f'\ninput_sample shape: {input_sample.shape}') + print(f'output_sample shape: {output_sample.shape}') + + fig = plt.figure(figsize=(25, 5)) + + # Define the connections (bones) between joints + bones = [ + (0, 1), (1, 2), (2, 3), # Left leg + (0, 4), (4, 5), (5, 6), # Right leg + (0, 7), (7, 8), (8, 9), (9, 10), # Spine + (7, 11), (11, 12), (12, 13), # Right arm + (7, 14), (14, 15), (15, 16) # Left arm + ] + + # Colors for different parts + bone_colors = { + "leg": 'green', + "spine": 'blue', + "arm": 'red' + } + + # Function to get bone color based on 
index + def get_bone_color(start, end): + if (start in [1, 2, 3] or end in [1, 2, 3] or + start in [4, 5, 6] or end in [4, 5, 6]): + return bone_colors["leg"] + elif start in [7, 8, 9, 10] or end in [7, 8, 9, 10]: + return bone_colors["spine"] + else: + return bone_colors["arm"] + + # Plotting 2D skeletons from different angles + for i in range(4): + ax = fig.add_subplot(1, 7, i + 1) + ax.set_title(f'2D angle {i+1}') + ax.scatter(input_sample[i, :, 0], input_sample[i, :, 1], color='blue') + + # Draw the bones + for start, end in bones: + bone_color = get_bone_color(start, end) + ax.plot([input_sample[i, start, 0], input_sample[i, end, 0]], + [input_sample[i, start, 1], input_sample[i, end, 1]], color=bone_color) + + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_xlim(np.min(input_sample[:, :, 0]) - 1, np.max(input_sample[:, :, 0]) + 1) + ax.set_ylim(np.min(input_sample[:, :, 1]) - 1, np.max(input_sample[:, :, 1]) + 1) + ax.grid() + + # Plotting predicted 3D skeleton + ax = fig.add_subplot(1, 7, 5, projection='3d') + ax.set_title('3D Predicted Skeleton') + ax.scatter(output_sample[:, 0], output_sample[:, 1], output_sample[:, 2], color='red', label='Predicted') + + # Draw the bones in 3D for output_sample + for start, end in bones: + bone_color = get_bone_color(start, end) + ax.plot([output_sample[start, 0], output_sample[end, 0]], + [output_sample[start, 1], output_sample[end, 1]], + [output_sample[start, 2], output_sample[end, 2]], color=bone_color) + + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + ax.set_xlim(np.min(output_sample[:, 0]) - 1, np.max(output_sample[:, 0]) + 1) + ax.set_ylim(np.min(output_sample[:, 1]) - 1, np.max(output_sample[:, 1]) + 1) + ax.set_zlim(np.min(output_sample[:, 2]) - 1, np.max(output_sample[:, 2]) + 1) + ax.legend() + + # Plotting ground truth 3D skeleton + ax = fig.add_subplot(1, 7, 6, projection='3d') + ax.set_title('3D Ground Truth Skeleton') + ax.scatter(gt_3D_sample[:, 0], gt_3D_sample[:, 1], gt_3D_sample[:, 2], color='blue', label='Ground Truth') + + # Draw the bones in 3D for gt_3D_sample + for start, end in bones: + bone_color = get_bone_color(start, end) + ax.plot([gt_3D_sample[start, 0], gt_3D_sample[end, 0]], + [gt_3D_sample[start, 1], gt_3D_sample[end, 1]], + [gt_3D_sample[start, 2], gt_3D_sample[end, 2]], color=bone_color, linestyle='--') + + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + ax.set_xlim(np.min(gt_3D_sample[:, 0]) - 1, np.max(gt_3D_sample[:, 0]) + 1) + ax.set_ylim(np.min(gt_3D_sample[:, 1]) - 1, np.max(gt_3D_sample[:, 1]) + 1) + ax.set_zlim(np.min(gt_3D_sample[:, 2]) - 1, np.max(gt_3D_sample[:, 2]) + 1) + ax.legend() + + plt.grid() + + # Save the figure + plt.tight_layout() + plt.savefig(f'{output_dir}/skeletons_visualization.png') + plt.show() + +def val(opt, actions, val_loader, model): + with torch.no_grad(): + return step('test', opt, actions, val_loader, model) + +def step(split, opt, actions, dataLoader, model, optimizer=None, epoch=None, writer=None, adaptive_weight=None): + loss_all = {'loss': AccumLoss()} + action_error_sum = define_error_list(actions) + + + model.eval() + + TQDM = tqdm(enumerate(dataLoader), total=len(dataLoader), ncols=100) + for i, data in TQDM: + batch_cam, gt_3D, input_2D, action, subject, scale, bb_box, start, end, hops = data + + [input_2D, gt_3D, batch_cam, scale, bb_box, hops] = get_varialbe(split, [input_2D, gt_3D, batch_cam, scale, bb_box, hops]) + + # print("\n======> input_2D: ", input_2D.shape) + # print("======> gt_3D: ", gt_3D.shape) + + + if split == 
'train': + output_3D = model(input_2D, hops) + elif split == 'test': + input_2D, output_3D = input_augmentation(input_2D, hops, model) + + out_target = gt_3D.clone() + out_target[:, :, 0] = 0 + + # print("======> output_3D: ", output_3D.shape) + # visualize_skeletons(input_2D, output_3D, gt_3D) + + if output_3D.shape[1] != 1: + output_3D = output_3D[:, opt.pad].unsqueeze(1) + output_3D[:, :, 1:, :] -= output_3D[:, :, :1, :] + output_3D[:, :, 0, :] = 0 + action_error_sum = test_calculation(output_3D, out_target, action, action_error_sum, opt.dataset, subject) + + p1, p2 = print_error(opt.dataset, action_error_sum, opt.train) + # print("======> p1, p2: ", p1, p2) + + if split == 'train': + return loss_all['loss'].avg + elif split == 'test': + p1, p2 = print_error(opt.dataset, action_error_sum, opt.train) + return p1, p2 + + +def input_augmentation(input_2D, hops, model): + input_2D_non_flip = input_2D[:, 0] + output_3D_non_flip = model(input_2D_non_flip, hops) + + # print("======> input_2D_non_flip: ", input_2D_non_flip.shape) + # print("======> output_3D_non_flip: ", output_3D_non_flip.shape) + # visualize_skeletons(input_2D_non_flip, output_3D_non_flip) + + return input_2D_non_flip, output_3D_non_flip + + + + +if __name__ == '__main__': + opt = opts().parse() + root_path = opt.root_path + opt.manualSeed = 1 + random.seed(opt.manualSeed) + torch.manual_seed(opt.manualSeed) + + + root_path = opt.root_path + dataset_path = root_path + 'data_3d_' + opt.dataset + '.npz' + + dataset = Human36mDataset(dataset_path, opt) + actions = define_actions(opt.actions) + + train_data = Fusion(opt=opt, train=True, dataset=dataset, root_path=root_path) + train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=opt.batch_size, + shuffle=True, num_workers=int(opt.workers), pin_memory=True) + + + test_data = Fusion(opt=opt, train=False, dataset=dataset, root_path=root_path) + test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=opt.batch_size, + shuffle=False, num_workers=int(opt.workers), pin_memory=True) + + model = sgraformer(num_frame=opt.frames, num_joints=17, in_chans=2, embed_dim_ratio=32, depth=4, + num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None, drop_path_rate=0.1) + # model = FuseModel() + + if torch.cuda.device_count() > 1: + print("Let's use", torch.cuda.device_count(), "GPUs!") + model = torch.nn.DataParallel(model, device_ids=CUDA_ID).to(device) + model = model.to(device) + + model_dict = model.state_dict() + + model_path = '/home/zlt/Documents/SGraFormer-master/checkpoint/epoch_50.pth' + + pre_dict = torch.load(model_path) + + model_dict = model.state_dict() + state_dict = {k: v for k, v in pre_dict.items() if k in model_dict.keys()} + model_dict.update(state_dict) + model.load_state_dict(model_dict) + + all_param = [] + lr = opt.lr + all_param += list(model.parameters()) + + optimizer = optim.AdamW(all_param, lr=lr, weight_decay=0.1) + + ## tensorboard + # writer = SummaryWriter("runs/nin") + writer = None + flag = 0 + + + + p1, p2 = val(opt, actions, test_dataloader, model) + + + print('p1: %.2f, p2: %.2f' % (p1, p2)) + diff --git a/vs b/vs new file mode 100644 index 0000000..8e0c4e0 --- /dev/null +++ b/vs @@ -0,0 +1,209 @@ +====> input_2D: [[[[-0.17325002 -0.17840627] + [-0.16224998 -0.18938544] + [-0.1595 0.00549481] + ... + [-0.12924999 -0.38701043] + [-0.12924999 -0.28819796] + [-0.23925 -0.26623955]] + + [[ 0.11524999 -0.21262503] + [ 0.06274998 -0.21262503] + [ 0.05225003 -0.02362502] + ... 
+ [ 0.04174995 -0.40162498] + [ 0.00762498 -0.28875 ] + [ 0.07587504 -0.28875 ]] + + [[ 0.126382 -0.27275002] + [ 0.17350698 -0.28324997] + [ 0.18659711 -0.10212499] + ... + [ 0.1892153 -0.47224998] + [ 0.22063196 -0.36725003] + [ 0.16303468 -0.41974998]] + + [[ 0.11715972 -0.28818753] + [ 0.10402083 -0.27506253] + [ 0.07774305 -0.03881248] + ... + [ 0.08759713 -0.5277187 ] + [ 0.08759713 -0.37678126] + [ 0.22884035 -0.3669375 ]]] + + + [[[-0.17236805 -0.17731252] + [-0.16139579 -0.18827084] + [-0.15865278 0.00623951] + ... + [-0.12847918 -0.38552085] + [-0.12573606 -0.28415623] + [-0.23351866 -0.263042 ]] + + [[ 0.11031246 -0.21437502] + [ 0.06462502 -0.2116875 ] + [ 0.05387497 -0.02356249] + ... + [ 0.04312503 -0.4051875 ] + [ 0.01087499 -0.28693748] + [ 0.07537496 -0.2815625 ]] + + [[ 0.12379158 -0.2751823 ] + [ 0.17327082 -0.27778125] + [ 0.18889582 -0.10105205] + ... + [ 0.18889582 -0.47530204] + [ 0.21754158 -0.36354685] + [ 0.15504158 -0.41552603]] + + [[ 0.12193751 -0.28907296] + [ 0.10231245 -0.27596876] + [ 0.0826875 -0.04009375] + ... + [ 0.08595836 -0.5216719 ] + [ 0.09904158 -0.37752607] + [ 0.23314583 -0.35459378]]] + + + [[[-0.1726042 -0.17912498] + [-0.1616875 -0.17912498] + [-0.1562292 0.00645826] + ... + [-0.12893748 -0.3865417 ] + [-0.12893748 -0.28829166] + [-0.23810416 -0.2555417 ]] + + [[ 0.11325002 -0.20937502] + [ 0.063375 -0.20674998] + [ 0.05025005 -0.02037501] + ... + [ 0.03974998 -0.403625 ] + [ 0.00825 -0.28812498] + [ 0.07912505 -0.28025 ]] + + [[ 0.126382 -0.27367705] + [ 0.17350698 -0.28415626] + [ 0.18659711 -0.10339063] + ... + [ 0.1892153 -0.47278124] + [ 0.22063196 -0.36798954] + [ 0.16303468 -0.40990627]] + + [[ 0.12304163 -0.28884378] + [ 0.1034584 -0.27578124] + [ 0.07408333 -0.03739062] + ... + [ 0.08713889 -0.52070314] + [ 0.08713889 -0.37048438] + [ 0.23074996 -0.34435937]]] + + + ... + + + [[[-0.15943056 -0.18728128] + [-0.15660417 -0.19010422] + [-0.15660417 0.00185416] + ... + [-0.148125 -0.39617708] + [-0.1452986 -0.29172918] + [-0.23856944 -0.25503126]] + + [[ 0.09365964 -0.20290625] + [ 0.05642354 -0.20290625] + [ 0.05110407 -0.02192706] + ... + [ 0.03248608 -0.40251565] + [-0.01006943 -0.28008854] + [ 0.08568048 -0.2375052 ]] + + [[ 0.14960408 -0.274875 ] + [ 0.18345833 -0.27747917] + [ 0.19387496 -0.10560417] + ... + [ 0.20168746 -0.47279167] + [ 0.23554158 -0.3660208 ] + [ 0.14960408 -0.3582083 ]] + + [[ 0.10611105 -0.2872292 ] + [ 0.09300005 -0.28068748] + [ 0.07988894 -0.03210416] + ... + [ 0.10611105 -0.5292708 ] + [ 0.07988894 -0.36572918] + [ 0.20116663 -0.31012502]]] + + + [[[-0.15781945 -0.17923954] + [-0.15502083 -0.18764582] + [-0.15502083 0.00289581] + ... + [-0.14382643 -0.40060422] + [-0.13822919 -0.2997292 ] + [-0.23338199 -0.25489584]] + + [[ 0.09533334 -0.20466667] + [ 0.05533338 -0.20200002] + [ 0.04999995 -0.02066666] + ... + [ 0.03133333 -0.402 ] + [-0.01133329 -0.2793333 ] + [ 0.08466661 -0.23666668]] + + [[ 0.1532222 -0.27335936] + [ 0.18455553 -0.2785781 ] + [ 0.1976111 -0.09853125] + ... + [ 0.1976111 -0.47428125] + [ 0.22894442 -0.36990625] + [ 0.15061104 -0.35946876]] + + [[ 0.09612501 -0.29290107] + [ 0.09612501 -0.27979687] + [ 0.07977092 -0.03409376] + ... + [ 0.10593748 -0.5320521 ] + [ 0.07977092 -0.36169794] + [ 0.2007916 -0.3125573 ]]] + + + [[[-0.15511107 -0.18240628] + [-0.15511107 -0.18796876] + [-0.15233332 0.00393746] + ... + [-0.144 -0.39934376] + [-0.144 -0.30478123] + [-0.23288894 -0.25471875]] + + [[ 0.08933342 -0.20692188] + [ 0.05466664 -0.20426041] + [ 0.04666662 -0.02061981] + ... 
+ [ 0.03333342 -0.40653127] + [-0.00666666 -0.28676564] + [ 0.08666658 -0.23619795]] + + [[ 0.15420842 -0.2762708 ] + [ 0.18285418 -0.278875 ] + [ 0.19847918 -0.10179168] + ... + [ 0.19587505 -0.4741875 ] + [ 0.22712505 -0.3648125 ] + [ 0.1456331 -0.36036688]] + + [[ 0.09231246 -0.28507295] + [ 0.09885418 -0.28507295] + [ 0.07595837 -0.02626565] + ... + [ 0.10212505 -0.530776 ] + [ 0.08577085 -0.36369792] + [ 0.20679164 -0.30800518]]]] + +=======> hops: tensor([[0.0000, 0.0014, 0.0000, ..., 0.0000, 0.0445, 0.0000], + [0.0009, 0.0000, 0.0185, ..., 0.0504, 0.0000, 0.0000], + [0.0000, 0.0094, 0.0000, ..., 0.0000, 0.0000, 0.0000], + ..., + [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], + device='cuda:0') + diff --git a/vs_3090 b/vs_3090 new file mode 100644 index 0000000..cfa0aeb --- /dev/null +++ b/vs_3090 @@ -0,0 +1,223 @@ +input_2D: [[[[-0.17325002 -0.17840627] + [-0.16224998 -0.18938544] + [-0.1595 0.00549481] + ... + [-0.12924999 -0.38701043] + [-0.12924999 -0.28819796] + [-0.23925 -0.26623955]] + + [[ 0.11524999 -0.21262503] + [ 0.06274998 -0.21262503] + [ 0.05225003 -0.02362502] + ... + [ 0.04174995 -0.40162498] + [ 0.00762498 -0.28875 ] + [ 0.07587504 -0.28875 ]] + + [[ 0.126382 -0.27275002] + [ 0.17350698 -0.28324997] + [ 0.18659711 -0.10212499] + ... + [ 0.1892153 -0.47224998] + [ 0.22063196 -0.36725003] + [ 0.16303468 -0.41974998]] + + [[ 0.11715972 -0.28818753] + [ 0.10402083 -0.27506253] + [ 0.07774305 -0.03881248] + ... + [ 0.08759713 -0.5277187 ] + [ 0.08759713 -0.37678126] + [ 0.22884035 -0.3669375 ]]] + + + [[[-0.17236805 -0.17731252] + [-0.16139579 -0.18827084] + [-0.15865278 0.00623951] + ... + [-0.12847918 -0.38552085] + [-0.12573606 -0.28415623] + [-0.23351866 -0.263042 ]] + + [[ 0.11031246 -0.21437502] + [ 0.06462502 -0.2116875 ] + [ 0.05387497 -0.02356249] + ... + [ 0.04312503 -0.4051875 ] + [ 0.01087499 -0.28693748] + [ 0.07537496 -0.2815625 ]] + + [[ 0.12379158 -0.2751823 ] + [ 0.17327082 -0.27778125] + [ 0.18889582 -0.10105205] + ... + [ 0.18889582 -0.47530204] + [ 0.21754158 -0.36354685] + [ 0.15504158 -0.41552603]] + + [[ 0.12193751 -0.28907296] + [ 0.10231245 -0.27596876] + [ 0.0826875 -0.04009375] + ... + [ 0.08595836 -0.5216719 ] + [ 0.09904158 -0.37752607] + [ 0.23314583 -0.35459378]]] + + + [[[-0.1726042 -0.17912498] + [-0.1616875 -0.17912498] + [-0.1562292 0.00645826] + ... + [-0.12893748 -0.3865417 ] + [-0.12893748 -0.28829166] + [-0.23810416 -0.2555417 ]] + + [[ 0.11325002 -0.20937502] + [ 0.063375 -0.20674998] + [ 0.05025005 -0.02037501] + ... + [ 0.03974998 -0.403625 ] + [ 0.00825 -0.28812498] + [ 0.07912505 -0.28025 ]] + + [[ 0.126382 -0.27367705] + [ 0.17350698 -0.28415626] + [ 0.18659711 -0.10339063] + ... + [ 0.1892153 -0.47278124] + [ 0.22063196 -0.36798954] + [ 0.16303468 -0.40990627]] + + [[ 0.12304163 -0.28884378] + [ 0.1034584 -0.27578124] + [ 0.07408333 -0.03739062] + ... + [ 0.08713889 -0.52070314] + [ 0.08713889 -0.37048438] + [ 0.23074996 -0.34435937]]] + + + ... + + + [[[-0.15943056 -0.18728128] + [-0.15660417 -0.19010422] + [-0.15660417 0.00185416] + ... + [-0.148125 -0.39617708] + [-0.1452986 -0.29172918] + [-0.23856944 -0.25503126]] + + [[ 0.09365964 -0.20290625] + [ 0.05642354 -0.20290625] + [ 0.05110407 -0.02192706] + ... + [ 0.03248608 -0.40251565] + [-0.01006943 -0.28008854] + [ 0.08568048 -0.2375052 ]] + + [[ 0.14960408 -0.274875 ] + [ 0.18345833 -0.27747917] + [ 0.19387496 -0.10560417] + ... 
+ [ 0.20168746 -0.47279167] + [ 0.23554158 -0.3660208 ] + [ 0.14960408 -0.3582083 ]] + + [[ 0.10611105 -0.2872292 ] + [ 0.09300005 -0.28068748] + [ 0.07988894 -0.03210416] + ... + [ 0.10611105 -0.5292708 ] + [ 0.07988894 -0.36572918] + [ 0.20116663 -0.31012502]]] + + + [[[-0.15781945 -0.17923954] + [-0.15502083 -0.18764582] + [-0.15502083 0.00289581] + ... + [-0.14382643 -0.40060422] + [-0.13822919 -0.2997292 ] + [-0.23338199 -0.25489584]] + + [[ 0.09533334 -0.20466667] + [ 0.05533338 -0.20200002] + [ 0.04999995 -0.02066666] + ... + [ 0.03133333 -0.402 ] + [-0.01133329 -0.2793333 ] + [ 0.08466661 -0.23666668]] + + [[ 0.1532222 -0.27335936] + [ 0.18455553 -0.2785781 ] + [ 0.1976111 -0.09853125] + ... + [ 0.1976111 -0.47428125] + [ 0.22894442 -0.36990625] + [ 0.15061104 -0.35946876]] + + [[ 0.09612501 -0.29290107] + [ 0.09612501 -0.27979687] + [ 0.07977092 -0.03409376] + ... + [ 0.10593748 -0.5320521 ] + [ 0.07977092 -0.36169794] + [ 0.2007916 -0.3125573 ]]] + + + [[[-0.15511107 -0.18240628] + [-0.15511107 -0.18796876] + [-0.15233332 0.00393746] + ... + [-0.144 -0.39934376] + [-0.144 -0.30478123] + [-0.23288894 -0.25471875]] + + [[ 0.08933342 -0.20692188] + [ 0.05466664 -0.20426041] + [ 0.04666662 -0.02061981] + ... + [ 0.03333342 -0.40653127] + [-0.00666666 -0.28676564] + [ 0.08666658 -0.23619795]] + + [[ 0.15420842 -0.2762708 ] + [ 0.18285418 -0.278875 ] + [ 0.19847918 -0.10179168] + ... + [ 0.19587505 -0.4741875 ] + [ 0.22712505 -0.3648125 ] + [ 0.1456331 -0.36036688]] + + [[ 0.09231246 -0.28507295] + [ 0.09885418 -0.28507295] + [ 0.07595837 -0.02626565] + ... + [ 0.10212505 -0.530776 ] + [ 0.08577085 -0.36369792] + [ 0.20679164 -0.30800518]]]] + + +=======> hops: tensor([[0.0000, 0.0014, 0.0000, ..., 0.0000, 0.0445, 0.0000], + [0.0009, 0.0000, 0.0185, ..., 0.0504, 0.0000, 0.0000], + [0.0000, 0.0094, 0.0000, ..., 0.0000, 0.0000, 0.0000], + ..., + [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]], + device='cuda:0') + +=======> hops: tensor([[0.0000e+00, 4.9881e-05, 0.0000e+00, ..., 0.0000e+00, 6.5165e-03, + 0.0000e+00], + [3.2784e-05, 0.0000e+00, 2.7641e-02, ..., 7.3896e-02, 0.0000e+00, + 0.0000e+00], + [0.0000e+00, 1.1971e-02, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00, + 0.0000e+00], + ..., + [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00, + 0.0000e+00], + [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00, + 0.0000e+00], + [0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00, + 0.0000e+00]], device='cuda:1') \ No newline at end of file
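The "global feature" block in `sgraformer.forward` above amounts to a row-normalised matrix of pairwise squared distances between the 17 joints of every view and frame. A minimal shape-level sketch of that computation, using broadcasting instead of the explicit `repeat` and purely illustrative sizes:

```python
import torch

def global_hop_feature(x):
    # x: (b, f, v, j, c) 2D keypoints; pairwise differences -> (b, f, v, j, j, c)
    diff = x.unsqueeze(3) - x.unsqueeze(4)
    # squared Euclidean distance for every joint pair -> (b, f, v, j, j)
    dist2 = (diff ** 2).sum(dim=-1)
    # normalise each row so one joint's weights over all joints sum to 1
    return dist2 / dist2.sum(dim=-1, keepdim=True)

x = torch.rand(2, 27, 4, 17, 2)        # (batch, frames, views, joints, xy)
print(global_hop_feature(x).shape)     # torch.Size([2, 27, 4, 17, 17])
```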
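The multi-view cross-channel fusion in `sgraformer.forward` stacks the four per-view frame embeddings, adds a learnable view positional embedding, and collapses the view axis with a 4-to-1 `Conv2d` plus a residual sum of the views. A standalone sketch, assuming 27 frames, `embed_dim = 17 * 32 = 544`, and a kernel size of 3 standing in for `opt.mvf_kernel`:

```python
import torch
import torch.nn as nn

frames, embed_dim, k = 27, 17 * 32, 3          # k stands in for opt.mvf_kernel
view_pos_embed = nn.Parameter(torch.zeros(1, 4, frames, embed_dim))
fuse = nn.Sequential(
    nn.BatchNorm2d(4, momentum=0.1),
    nn.Conv2d(4, 1, kernel_size=k, stride=1, padding=k // 2, bias=False),
    nn.ReLU(inplace=True),
)
norm = nn.LayerNorm(embed_dim)

x1, x2, x3, x4 = (torch.rand(8, frames, embed_dim) for _ in range(4))  # per-view features
x = torch.stack((x1, x2, x3, x4), dim=1) + view_pos_embed              # (8, 4, 27, 544)
x = fuse(x).squeeze(1) + x1 + x2 + x3 + x4                             # fused map + residual views
x = norm(x)                                                            # (8, 27, 544)
print(x.shape)
```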
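In `Spatial_encoder.Attention`, a 17x17 bias derived from the skeleton hop structure is projected and added to the attention logits of every head before the softmax. A compact sketch of that pattern with illustrative sizes; broadcasting replaces the explicit per-head `repeat`:

```python
import torch
import torch.nn as nn

B, H, J, head_dim = 4, 8, 17, 4                       # batch*frames, heads, joints, dim per head
q = torch.rand(B, H, J, head_dim)
k = torch.rand(B, H, J, head_dim)

edge_embedding = torch.rand(1, J * J)                 # flattened hop-based embedding
bias_proj = nn.Linear(J * J, J * J)                   # plays the role of self.edge_embedding

attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5   # (4, 8, 17, 17) logits
bias = bias_proj(edge_embedding).reshape(1, 1, J, J)  # one bias shared across batch and heads
attn = (attn + bias).softmax(dim=-1)
print(attn.shape)                                     # torch.Size([4, 8, 17, 17])
```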
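`Spatial_encoder.CVA_Attention` draws queries and keys from the attention output handed over from the previous view (`CVA_input`) and values from the current view's joint tokens; the per-input LayerNorms are omitted in this minimal sketch, and the sizes are only illustrative:

```python
import torch
import torch.nn as nn

B, J, C, H = 8, 17, 32, 8
x = torch.rand(B, J, C)                 # current view's joint tokens
cva_input = torch.rand(B, J, C)         # MSA passed in from the previous view

q_proj, k_proj, v_proj = nn.Linear(C, C), nn.Linear(C, C), nn.Linear(C, C)

def split_heads(t):                     # (B, J, C) -> (B, H, J, C // H)
    return t.reshape(B, J, H, C // H).permute(0, 2, 1, 3)

q = split_heads(q_proj(cva_input))
k = split_heads(k_proj(cva_input))
v = split_heads(v_proj(x))

attn = ((q @ k.transpose(-2, -1)) * (C // H) ** -0.5).softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, J, C)
print(out.shape)                        # torch.Size([8, 17, 32])
```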
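The spatial encoders treat joints as tokens within each frame and then fold them back into one embedding per frame for the temporal stage; the `einops` rearranges at the start and end of `First_view_Spatial_features.forward` and `Spatial_features.forward` do this bookkeeping. A shape walk-through with the default sizes (27 frames, 17 joints, `embed_dim_ratio = 32`):

```python
import torch
import torch.nn as nn
from einops import rearrange

b, f, j = 8, 27, 17
in_chans, embed_dim_ratio = 2, 32
to_embedding = nn.Linear(in_chans, embed_dim_ratio)   # role of Spatial_patch_to_embedding

x = torch.rand(b, in_chans, f, j)                     # layout handed to the spatial encoders
x = rearrange(x, 'b c f p -> (b f) p c')              # (216, 17, 2): joints as tokens
x = to_embedding(x)                                   # (216, 17, 32)
# ... the four spatial transformer blocks keep this (b*f, 17, 32) shape ...
x = rearrange(x, '(b f) w c -> b f (w c)', f=f)       # (8, 27, 544): one token per frame
print(x.shape)
```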
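`Temporal__features` defines (but currently leaves commented out) a `weighted_mean` Conv1d that collapses the frame dimension with one learned weight per frame; as written, the module instead returns all `opt.frames` frames. A sketch of that collapse, assuming 27 frames and a 544-dimensional frame embedding:

```python
import torch
import torch.nn as nn

b, f, emb_dim = 8, 27, 17 * 32
x = torch.rand(b, f, emb_dim)                    # output of the temporal blocks
weighted_mean = nn.Conv1d(in_channels=f, out_channels=1, kernel_size=1)
center = weighted_mean(x)                        # (8, 1, 544): learned mix of all frames
print(center.shape)
```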
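Before the MPJPE computation in `test.py`'s `step`, both the prediction and the target are made root-relative: the predicted root (joint 0) is subtracted from the other joints and then zeroed, and the target's root joint is zeroed as well. A minimal sketch with random tensors of the same layout:

```python
import torch

output_3D = torch.rand(8, 1, 17, 3)     # (batch, 1, joints, xyz) after selecting the centre frame
gt_3D = torch.rand(8, 1, 17, 3)

out_target = gt_3D.clone()
out_target[:, :, 0] = 0                 # zero the target's root joint

output_3D[:, :, 1:, :] = output_3D[:, :, 1:, :] - output_3D[:, :, :1, :]
output_3D[:, :, 0, :] = 0               # prediction is now root-relative
print(output_3D.shape, out_target.shape)
```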
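The checkpoint loading at the bottom of `test.py` copies only the keys that also exist in the current model, which tolerates renamed or extra parameters in the saved file. The same pattern in isolation, with a placeholder module and dictionary standing in for `sgraformer(...)` and `torch.load(model_path)`:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)                                   # placeholder for sgraformer(...)
pre_dict = {'weight': torch.zeros(4, 4), 'extra': 1}      # placeholder for torch.load(model_path)

model_dict = model.state_dict()
state_dict = {k: v for k, v in pre_dict.items() if k in model_dict}
model_dict.update(state_dict)                             # only matching keys are overwritten
model.load_state_dict(model_dict)
print(sorted(model_dict.keys()))                          # ['bias', 'weight']
```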