diff --git a/configs/skeletongait/README.md b/configs/skeletongait/README.md new file mode 100644 index 0000000..7563a68 --- /dev/null +++ b/configs/skeletongait/README.md @@ -0,0 +1,89 @@ +# SkeletonGait: Gait Recognition Using Skeleton Maps + +This [paper](https://arxiv.org/abs/2311.13444) has been accepted by AAAI 2024. + +## Generating Heatmap and Training Steps + +### Step 1: Generating Heatmap +Leveraging the power of Distributed Data Parallel (DDP), we've streamlined the heatmap generation process. Below is the script to initiate the generation: +``` +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +python -m torch.distributed.launch \ +--nproc_per_node=4 \ +datasets/pretreatment_heatmap.py \ +--pose_data_path=<path_to_pose_pkl_dir> \ +--save_root=<heatmap_output_root> \ +--dataset_name=<dataset_name> +``` + +Parameter Guide: +- `--pose_data_path`: Specifies the directory containing the pose data files (`.pkl`, ID-Level). This is **required**. +- `--save_root`: Designates the root directory for storing the generated heatmap files (`.pkl`, ID-Level). This is **required**. +- `--dataset_name`: The name of the dataset undergoing preprocessing. This is **required**. +- `--ext_name`: An **optional** suffix for the 'save_root' directory to facilitate identification. Defaults to an empty string. +- `--heatmap_cfg_path`: Path to the configuration file of the heatmap generator. The default setting is `configs/skeletongait/pretreatment_heatmap.yaml`. + +Note: If your pose data follows the COCO 18 format (for instance, OU-MVLP pose data or data extracted using [OpenPose](https://github.com/CMU-Perceptual-Computing-Lab/openpose) in COCO format), be sure to set `transfer_to_coco17` to True in the configuration file `configs/skeletongait/pretreatment_heatmap.yaml`. 
+ + +**Optional** + +### Step 2: Creating Symbolic Links for Heatmap and Silhouette Data + +The script to symlink heatmaps and silhouettes is as follows: + +``` +python datasets/ln_sil_heatmap.py \ +--heatmap_data_path=<absolute_path_to_heatmap_data> \ +--silhouette_data_path=<absolute_path_to_silhouette_data> \ +--output_path=<output_dir> +``` + +Parameter Guide: +- `--heatmap_data_path`: The **absolute** path to your heatmap data. This is **required**. +- `--silhouette_data_path`: The **absolute** path to your silhouette data. This is **required**. +- `--output_path`: Designates the directory for linked output data. This is **required**. +- `--dataset_pkl_ext_name`: An **optional** parameter to specify the extension for `.pkl` silhouette files. Defaults to `.pkl`. For CCPG use `aligned-sils.pkl`, for SUSTech-1K use `Camera-Sils_aligned.pkl`, and for other datasets use `.pkl`. + +### Step 3: Training SkeletonGait or SkeletonGait++ + +The script to train SkeletonGait is as follows: + +``` +CUDA_VISIBLE_DEVICES=0,1,2,3 \ + python -m torch.distributed.launch \ + --nproc_per_node=4 opengait/main.py \ + --cfgs ./configs/skeletongait/skeletongait_Gait3D.yaml \ + --phase train --log_to_file +``` + +The script to train SkeletonGait++ is as follows: + +``` +CUDA_VISIBLE_DEVICES=0,1,2,3 \ + python -m torch.distributed.launch \ + --nproc_per_node=4 opengait/main.py \ + --cfgs ./configs/skeletongait/skeletongait++_Gait3D.yaml \ + --phase train --log_to_file +``` + +## Performance for SkeletonGait and SkeletonGait++ + +### SkeletonGait +| Datasets | `Rank1` | Configuration | +|---------------------|---------|----------------------------------------------| +| CCPG | CL: 52.4, UP: 65.4, DN: 72.8, BG: 80.9 | [skeletongait_CCPG.yaml](./skeletongait_CCPG.yaml) | +| OU-MVLP (AlphaPose) | TODO | [skeletongait_OUMVLP.yaml](./skeletongait_OUMVLP.yaml) | +| SUSTech-1K | Normal: 54.2, Bag: 51.7, Clothing: 21.34, Carrying: 51.59, Umbrella: 44.5, Uniform: 53.37, Occlusion: 67.07, Night: 44.15, Overall: 51.46 | [skeletongait_SUSTech1K.yaml](./skeletongait_SUSTech1K.yaml) | +| Gait3D | 38.1 | 
[skeletongait_Gait3D.yaml](./skeletongait_Gait3D.yaml) | +| GREW | TODO | [skeletongait_GREW.yaml](./skeletongait_GREW.yaml) | + +### SkeletonGait++ +| Datasets | `Rank1` | Configuration | +|---------------------|---------|-------------------------------------------------| +| CCPG | CL: 90.1, UP: 95.0, DN: 92.9, BG: 97.0 | [skeletongait++_CCPG.yaml](./skeletongait++_CCPG.yaml) | +| SUSTech-1K | Normal: 85.09, Bag: 82.90, Clothing: 46.53, Carrying: 81.88, Umbrella: 80.76, Uniform: 82.50, Occlusion: 86.16, Night: 47.48, Overall: 81.33 | [skeletongait++_SUSTech1K.yaml](./skeletongait++_SUSTech1K.yaml) | +| Gait3D | 77.40 | [skeletongait++_Gait3D.yaml](./skeletongait++_Gait3D.yaml) | +| GREW | 87.04 | [skeletongait++_GREW.yaml](./skeletongait++_GREW.yaml) | + + diff --git a/configs/skeletongait/pretreatment_heatmap.yaml b/configs/skeletongait/pretreatment_heatmap.yaml new file mode 100644 index 0000000..8f8c7bf --- /dev/null +++ b/configs/skeletongait/pretreatment_heatmap.yaml @@ -0,0 +1,25 @@ +coco18tococo17_args: + transfer_to_coco17: False # Set to True for OU-MVLP and CCPG, False for other datasets + +padkeypoints_args: + pad_method: knn # knn or simple + use_conf: True # Indicates whether confidence scores are used. 
+ +norm_args: + pose_format: coco # coco or openpose-x where 'x' can be either 18 or 25, indicating the number of keypoints used by the OpenPose model + use_conf: ${padkeypoints_args.use_conf} + heatmap_image_height: 128 # Sets the height (in pixels) of the heatmap images after normalization + +heatmap_generator_args: + sigma: 8.0 # The standard deviation of the Gaussian kernel used to generate the heatmaps + use_score: ${padkeypoints_args.use_conf} + img_h: ${norm_args.heatmap_image_height} + img_w: ${norm_args.heatmap_image_height} + with_limb: null # set automatically in the code + with_kp: null # set automatically in the code + +align_args: + align: True # Indicates whether the images will be aligned + final_img_size: 64 # Sets the size (in pixels) for the final images + offset: 0 + heatmap_image_size: ${norm_args.heatmap_image_height} \ No newline at end of file diff --git a/configs/skeletongait/skeletongait++_CCPG.yaml b/configs/skeletongait/skeletongait++_CCPG.yaml new file mode 100644 index 0000000..c051c05 --- /dev/null +++ b/configs/skeletongait/skeletongait++_CCPG.yaml @@ -0,0 +1,98 @@ +data_cfg: + dataset_name: CCPG + dataset_root: your_path + dataset_partition: ./datasets/CCPG/CCPG.json + num_workers: 1 + data_in_use: [True, True] # heatmap, sil + remove_no_gallery: false # Remove probe if no gallery for it + test_dataset_name: CCPG + +evaluator_cfg: + enable_float16: true + restore_ckpt_strict: true + restore_hint: 60000 + save_name: SkeletonGaitPP + eval_func: evaluate_CCPG + sampler: + batch_shuffle: false + batch_size: 4 + sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered + frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory + metric: euc # cos + transform: + - type: BaseSilCuttingTransform + +loss_cfg: + - loss_term_weights: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + - loss_term_weights: 1.0 
+ scale: 16 + type: CrossEntropyLoss + log_prefix: softmax + log_accuracy: true + +model_cfg: + model: SkeletonGaitPP + Backbone: + in_channels: 3 + blocks: + - 1 + - 1 + - 1 + - 1 + C: 2 + SeparateBNNecks: + class_num: 100 + +optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 0.0005 + +scheduler_cfg: + gamma: 0.1 + milestones: # Learning Rate Reduction at each milestones + - 20000 + - 30000 + - 40000 + scheduler: MultiStepLR + +trainer_cfg: + enable_float16: true # half_percesion float for memory reduction and speedup + fix_BN: false + log_iter: 100 + restore_ckpt_strict: true + restore_hint: 30000 + save_iter: 10000 + save_name: DeepGaitV2_P3D_GaitMap_B1C2_Sigma-8.0_Hot_False_Align-True_OpenGaitDA-True_ML_LowLevel + sync_BN: true + total_iter: 60000 + sampler: + batch_shuffle: true + batch_size: + - 8 # TripletSampler, batch_size[0] indicates Number of Identity + - 16 # batch_size[1] indicates Samples sequqnce for each Identity + frames_num_fixed: 30 # fixed frames number for training + frames_skip_num: 4 + sample_type: fixed_ordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered + type: TripletSampler + transform: + # - type: Compose + # trf_cfg: + # - type: BaseSilCuttingTransform + # - type: RandomHorizontalFlip + # prob: 0.5 + - type: Compose + trf_cfg: + - type: RandomPerspective + prob: 0.2 + - type: BaseSilCuttingTransform + - type: RandomHorizontalFlip + prob: 0.2 + - type: RandomRotate + prob: 0.2 + + diff --git a/configs/skeletongait/skeletongait++_GREW.yaml b/configs/skeletongait/skeletongait++_GREW.yaml new file mode 100644 index 0000000..0cbe899 --- /dev/null +++ b/configs/skeletongait/skeletongait++_GREW.yaml @@ -0,0 +1,92 @@ +data_cfg: + dataset_name: GREW + dataset_root: your_path + dataset_partition: ./datasets/GREW/GREW.json + num_workers: 1 + data_in_use: [True, True] # heatmap, sil + remove_no_gallery: false # Remove probe if no gallery for 
it + test_dataset_name: GREW + +evaluator_cfg: + enable_float16: true + restore_ckpt_strict: true + restore_hint: 180000 + save_name: SkeletonGaitPP + eval_func: GREW_submission + sampler: + batch_shuffle: false + batch_size: 4 + sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered + frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory + metric: euc # cos + transform: + - type: BaseSilCuttingTransform + +loss_cfg: + - loss_term_weights: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + - loss_term_weights: 1.0 + scale: 16 + type: CrossEntropyLoss + log_prefix: softmax + log_accuracy: true + +model_cfg: + model: SkeletonGaitPP + Backbone: + in_channels: 3 + blocks: + - 1 + - 4 + - 4 + - 1 + C: 2 + SeparateBNNecks: + class_num: 20000 + +optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 0.0005 + +scheduler_cfg: + gamma: 0.1 + milestones: # Learning Rate Reduction at each milestones + - 80000 + - 120000 + - 150000 + scheduler: MultiStepLR + +trainer_cfg: + enable_float16: true # half_percesion float for memory reduction and speedup + fix_BN: false + log_iter: 100 + restore_ckpt_strict: true + restore_hint: 0 + save_iter: 30000 + save_name: SkeletonGaitPP + sync_BN: true + total_iter: 180000 + sampler: + batch_shuffle: true + batch_size: + - 32 # TripletSampler, batch_size[0] indicates Number of Identity + - 4 # batch_size[1] indicates Samples sequqnce for each Identity + frames_num_fixed: 30 # fixed frames number for training + frames_skip_num: 4 + sample_type: fixed_ordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered + type: TripletSampler + transform: + - type: Compose + trf_cfg: + - type: RandomPerspective + prob: 0.2 + - type: BaseSilCuttingTransform + - type: RandomHorizontalFlip + prob: 0.2 + - type: 
RandomRotate + prob: 0.2 + diff --git a/configs/skeletongait/skeletongait++_Gait3D.yaml b/configs/skeletongait/skeletongait++_Gait3D.yaml new file mode 100644 index 0000000..8873952 --- /dev/null +++ b/configs/skeletongait/skeletongait++_Gait3D.yaml @@ -0,0 +1,93 @@ +data_cfg: + dataset_name: Gait3D + dataset_root: your_path + dataset_partition: ./datasets/Gait3D/Gait3D.json + num_workers: 1 + data_in_use: [True, True] # heatmap, sil + remove_no_gallery: false # Remove probe if no gallery for it + test_dataset_name: Gait3D + +evaluator_cfg: + enable_float16: true + restore_ckpt_strict: true + restore_hint: 60000 + save_name: SkeletonGaitPP # LowLevel + eval_func: evaluate_Gait3D + sampler: + batch_shuffle: false + batch_size: 4 + sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered + frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory + metric: euc # cos + transform: + - type: BaseSilTransform + +loss_cfg: + - loss_term_weights: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + - loss_term_weights: 1.0 + scale: 16 + type: CrossEntropyLoss + log_prefix: softmax + log_accuracy: true + +model_cfg: + model: SkeletonGaitPP + Backbone: + in_channels: 3 + blocks: + - 1 + - 4 + - 4 + - 1 + C: 2 + SeparateBNNecks: + class_num: 3000 + +optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 0.0005 + +scheduler_cfg: + gamma: 0.1 + milestones: # Learning Rate Reduction at each milestones + - 20000 + - 30000 + - 40000 + scheduler: MultiStepLR + +trainer_cfg: + enable_float16: true # half_percesion float for memory reduction and speedup + fix_BN: false + log_iter: 100 + restore_ckpt_strict: true + restore_hint: 0 + save_iter: 10000 + save_name: SkeletonGaitPP # LowLevel + sync_BN: true + total_iter: 60000 + sampler: + batch_shuffle: true + batch_size: + - 32 # TripletSampler, batch_size[0] indicates Number of Identity 
+ - 4 # batch_size[1] indicates Samples sequqnce for each Identity + frames_num_fixed: 30 # fixed frames number for training + frames_skip_num: 4 + sample_type: fixed_ordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered + type: TripletSampler + transform: + - type: Compose + trf_cfg: + - type: RandomPerspective + prob: 0.2 + - type: BaseSilTransform + - type: RandomHorizontalFlip + prob: 0.2 + - type: RandomRotate + prob: 0.2 + + diff --git a/configs/skeletongait/skeletongait++_SUSTech1K.yaml b/configs/skeletongait/skeletongait++_SUSTech1K.yaml new file mode 100644 index 0000000..65be0c8 --- /dev/null +++ b/configs/skeletongait/skeletongait++_SUSTech1K.yaml @@ -0,0 +1,92 @@ +data_cfg: + dataset_name: SUSTech1K + dataset_root: your_path + dataset_partition: ./datasets/SUSTech1K/SUSTech1K.json + num_workers: 4 + data_in_use: [True, True] # heatmap, sil + remove_no_gallery: false # Remove probe if no gallery for it + test_dataset_name: SUSTech1K + +evaluator_cfg: + enable_float16: true + restore_ckpt_strict: true + restore_hint: 50000 + save_name: SkeletonGaitPP + eval_func: evaluate_indoor_dataset #evaluate_Gait3D + sampler: + batch_shuffle: false + batch_size: 4 + sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered + frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory + metric: euc # cos + transform: + - type: BaseSilCuttingTransform + +loss_cfg: + - loss_term_weight: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + - loss_term_weight: 1.0 + scale: 16 + type: CrossEntropyLoss + log_prefix: softmax + log_accuracy: true + +model_cfg: + model: SkeletonGaitPP + Backbone: + in_channels: 3 + blocks: + - 1 + - 1 + - 1 + - 1 + C: 2 + SeparateBNNecks: + class_num: 250 + +optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 
0.0005 + +scheduler_cfg: + gamma: 0.1 + milestones: # Learning Rate Reduction at each milestones + - 20000 + - 30000 + - 40000 + scheduler: MultiStepLR + +trainer_cfg: + enable_float16: true # half_percesion float for memory reduction and speedup + fix_BN: false + with_test: false #true + log_iter: 100 + restore_ckpt_strict: true + restore_hint: 20000 + save_iter: 10000 + save_name: SkeletonGaitPP + sync_BN: true + total_iter: 50000 + sampler: + batch_shuffle: true + batch_size: + - 8 # TripletSampler, batch_size[0] indicates Number of Identity + - 8 # batch_size[1] indicates Samples sequqnce for each Identity + frames_num_fixed: 10 # fixed frames number for training + frames_skip_num: 4 + sample_type: fixed_ordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered + type: TripletSampler + transform: + - type: Compose + trf_cfg: + - type: RandomPerspective + prob: 0.2 + - type: BaseSilCuttingTransform + - type: RandomHorizontalFlip + prob: 0.2 + - type: RandomRotate + prob: 0.2 diff --git a/configs/skeletongait/skeletongait_CCPG.yaml b/configs/skeletongait/skeletongait_CCPG.yaml new file mode 100644 index 0000000..d16ef6a --- /dev/null +++ b/configs/skeletongait/skeletongait_CCPG.yaml @@ -0,0 +1,99 @@ +data_cfg: + dataset_name: CCPG + dataset_root: your_path + dataset_partition: ./datasets/CCPG/CCPG.json + num_workers: 1 + data_in_use: [True, False] # heatmap, sil + remove_no_gallery: false # Remove probe if no gallery for it + test_dataset_name: CCPG + +evaluator_cfg: + enable_float16: true + restore_ckpt_strict: true + restore_hint: 60000 + save_name: SkeletonGait + eval_func: evaluate_CCPG + sampler: + batch_shuffle: false + batch_size: 4 + sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered + frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory + 
metric: euc # cos + transform: + - type: BaseSilCuttingTransform + +loss_cfg: + - loss_term_weights: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + - loss_term_weights: 1.0 + scale: 16 + type: CrossEntropyLoss + log_prefix: softmax + log_accuracy: true + +model_cfg: + model: DeepGaitV2 + Backbone: + in_channels: 2 + mode: p3d + layers: + - 1 + - 1 + - 1 + - 1 + channels: + - 64 + - 128 + - 256 + - 512 + SeparateBNNecks: + class_num: 100 + use_emb2: true + +optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 0.0005 + +scheduler_cfg: + gamma: 0.1 + milestones: # Learning Rate Reduction at each milestones + - 20000 + - 30000 + - 40000 + scheduler: MultiStepLR + +trainer_cfg: + enable_float16: true # half_percesion float for memory reduction and speedup + fix_BN: false + log_iter: 100 + restore_ckpt_strict: true + restore_hint: 0 + save_iter: 10000 + save_name: SkeletonGait + sync_BN: true + total_iter: 60000 + sampler: + batch_shuffle: true + batch_size: + - 8 # TripletSampler, batch_size[0] indicates Number of Identity + - 16 # batch_size[1] indicates Samples sequqnce for each Identity + frames_num_fixed: 30 # fixed frames number for training + frames_skip_num: 4 + sample_type: fixed_ordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered + type: TripletSampler + transform: + - type: Compose + trf_cfg: + - type: RandomPerspective + prob: 0.2 + - type: BaseSilCuttingTransform + - type: RandomHorizontalFlip + prob: 0.2 + - type: RandomRotate + prob: 0.2 + + diff --git a/configs/skeletongait/skeletongait_GREW.yaml b/configs/skeletongait/skeletongait_GREW.yaml new file mode 100644 index 0000000..dcc93a8 --- /dev/null +++ b/configs/skeletongait/skeletongait_GREW.yaml @@ -0,0 +1,97 @@ +data_cfg: + dataset_name: GREW + dataset_root: your_path + dataset_partition: ./datasets/GREW/GREW.json + num_workers: 1 + data_in_use: [True, False] # heatmap, sil + 
remove_no_gallery: false # Remove probe if no gallery for it + test_dataset_name: GREW + +evaluator_cfg: + enable_float16: true + restore_ckpt_strict: true + restore_hint: 180000 + save_name: SkeletonGait + eval_func: GREW_submission + sampler: + batch_shuffle: false + batch_size: 4 + sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered + frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory + metric: euc # cos + transform: + - type: BaseSilCuttingTransform + +loss_cfg: + - loss_term_weights: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + - loss_term_weights: 1.0 + scale: 16 + type: CrossEntropyLoss + log_prefix: softmax + log_accuracy: true + +model_cfg: + model: DeepGaitV2 + Backbone: + in_channels: 2 + mode: p3d + layers: + - 1 + - 4 + - 4 + - 1 + channels: + - 64 + - 128 + - 256 + - 512 + SeparateBNNecks: + class_num: 20000 + +optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 0.0005 + +scheduler_cfg: + gamma: 0.1 + milestones: # Learning Rate Reduction at each milestones + - 80000 + - 120000 + - 150000 + scheduler: MultiStepLR + +trainer_cfg: + enable_float16: true # half_percesion float for memory reduction and speedup + fix_BN: false + log_iter: 100 + restore_ckpt_strict: true + restore_hint: 90000 + save_iter: 30000 + save_name: SkeletonGait + sync_BN: true + total_iter: 180000 + sampler: + batch_shuffle: true + batch_size: + - 32 # TripletSampler, batch_size[0] indicates Number of Identity + - 4 # batch_size[1] indicates Samples sequqnce for each Identity + frames_num_fixed: 30 # fixed frames number for training + frames_skip_num: 4 + sample_type: fixed_ordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered + type: TripletSampler + transform: + - type: Compose + trf_cfg: + - type: RandomPerspective + prob: 
0.2 + - type: BaseSilCuttingTransform + - type: RandomHorizontalFlip + prob: 0.2 + - type: RandomRotate + prob: 0.2 + diff --git a/configs/skeletongait/skeletongait_Gait3D.yaml b/configs/skeletongait/skeletongait_Gait3D.yaml new file mode 100644 index 0000000..027f683 --- /dev/null +++ b/configs/skeletongait/skeletongait_Gait3D.yaml @@ -0,0 +1,97 @@ +data_cfg: + dataset_name: Gait3D + dataset_root: your_path + dataset_partition: ./datasets/Gait3D/Gait3D.json + num_workers: 1 + data_in_use: [True, False] # heatmap, sil + remove_no_gallery: false # Remove probe if no gallery for it + test_dataset_name: Gait3D + +evaluator_cfg: + enable_float16: true + restore_ckpt_strict: true + restore_hint: 60000 + save_name: SkeletonGait + eval_func: evaluate_Gait3D + sampler: + batch_shuffle: false + batch_size: 4 + sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered + frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory + metric: euc # cos + transform: + - type: BaseSilCuttingTransform + +loss_cfg: + - loss_term_weights: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + - loss_term_weights: 1.0 + scale: 16 + type: CrossEntropyLoss + log_prefix: softmax + log_accuracy: true + +model_cfg: + model: DeepGaitV2 + Backbone: + in_channels: 2 + mode: p3d + layers: + - 1 + - 4 + - 4 + - 1 + channels: + - 64 + - 128 + - 256 + - 512 + SeparateBNNecks: + class_num: 3000 + +optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 0.0005 + +scheduler_cfg: + gamma: 0.1 + milestones: # Learning Rate Reduction at each milestones + - 20000 + - 30000 + - 40000 + scheduler: MultiStepLR + +trainer_cfg: + enable_float16: true # half_percesion float for memory reduction and speedup + fix_BN: false + log_iter: 100 + restore_ckpt_strict: true + restore_hint: 0 + save_iter: 10000 + save_name: SkeletonGait + sync_BN: true + total_iter: 60000 + 
sampler: + batch_shuffle: true + batch_size: + - 32 # TripletSampler, batch_size[0] indicates Number of Identity + - 4 # batch_size[1] indicates Samples sequqnce for each Identity + frames_num_fixed: 30 # fixed frames number for training + frames_skip_num: 4 + sample_type: fixed_ordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered + type: TripletSampler + transform: + - type: Compose + trf_cfg: + - type: RandomPerspective + prob: 0.2 + - type: BaseSilCuttingTransform + - type: RandomHorizontalFlip + prob: 0.2 + - type: RandomRotate + prob: 0.2 + diff --git a/configs/skeletongait/skeletongait_OUMVLP.yaml b/configs/skeletongait/skeletongait_OUMVLP.yaml new file mode 100644 index 0000000..b02b2dc --- /dev/null +++ b/configs/skeletongait/skeletongait_OUMVLP.yaml @@ -0,0 +1,92 @@ +data_cfg: + dataset_name: OUMVLP + dataset_root: your_path + dataset_partition: ./datasets/OUMVLP/OUMVLP.json + num_workers: 1 + data_in_use: [True, False] # heatmap, sil + remove_no_gallery: false # Remove probe if no gallery for it + test_dataset_name: OUMVLP + +evaluator_cfg: + enable_float16: true + restore_ckpt_strict: true + restore_hint: 120000 + save_name: SkeletonGait + sampler: + batch_shuffle: false + batch_size: 4 + sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered + frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory + metric: euc # cos + transform: + - type: BaseSilCuttingTransform + +loss_cfg: + - loss_term_weights: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + - loss_term_weights: 1.0 + scale: 16 + type: CrossEntropyLoss + log_prefix: softmax + log_accuracy: true + +model_cfg: + model: DeepGaitV2 + Backbone: + in_channels: 2 + mode: p3d + layers: + - 1 + - 1 + - 1 + - 1 + channels: + - 64 + - 128 + - 256 + - 512 + SeparateBNNecks: + 
class_num: 5153 + +optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 0.0005 + +scheduler_cfg: + gamma: 0.1 + milestones: # Learning Rate Reduction at each milestones + - 60000 + - 80000 + - 100000 + scheduler: MultiStepLR + +trainer_cfg: + enable_float16: true # half_percesion float for memory reduction and speedup + fix_BN: false + log_iter: 100 + restore_ckpt_strict: true + restore_hint: 0 + save_iter: 20000 + save_name: SkeletonGait + sync_BN: true + total_iter: 120000 + sampler: + batch_shuffle: true + batch_size: + - 32 # TripletSampler, batch_size[0] indicates Number of Identity + - 8 # batch_size[1] indicates Samples sequqnce for each Identity + frames_num_fixed: 30 # fixed frames number for training + frames_skip_num: 4 + sample_type: fixed_ordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered + type: TripletSampler + transform: + - type: Compose + trf_cfg: + - type: BaseSilCuttingTransform + - type: RandomHorizontalFlip + prob: 0.5 + diff --git a/configs/skeletongait/skeletongait_SUSTech1K.yaml b/configs/skeletongait/skeletongait_SUSTech1K.yaml new file mode 100644 index 0000000..76d51ec --- /dev/null +++ b/configs/skeletongait/skeletongait_SUSTech1K.yaml @@ -0,0 +1,96 @@ +data_cfg: + dataset_name: SUSTech1K + dataset_root: your_path + dataset_partition: ./datasets/SUSTech1K/SUSTech1K.json + num_workers: 4 + data_in_use: [True, False] # heatmap, sil + remove_no_gallery: false # Remove probe if no gallery for it + test_dataset_name: SUSTech1K + +evaluator_cfg: + enable_float16: true + restore_ckpt_strict: true + restore_hint: 50000 + save_name: SkeletonGait + eval_func: evaluate_indoor_dataset #evaluate_Gait3D + sampler: + batch_shuffle: false + batch_size: 4 + sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered + frames_all_limit: 720 # limit the 
number of sampled frames to prevent out of memory + metric: euc # cos + transform: + - type: BaseSilCuttingTransform + +loss_cfg: + - loss_term_weight: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + - loss_term_weight: 1.0 + scale: 16 + type: CrossEntropyLoss + log_prefix: softmax + log_accuracy: true + +model_cfg: + model: DeepGaitV2 + Backbone: + in_channels: 2 + mode: p3d + layers: + - 1 + - 1 + - 1 + - 1 + channels: + - 64 + - 128 + - 256 + - 512 + SeparateBNNecks: + class_num: 250 + +optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 0.0005 + +scheduler_cfg: + gamma: 0.1 + milestones: # Learning Rate Reduction at each milestones + - 20000 + - 30000 + - 40000 + scheduler: MultiStepLR + +trainer_cfg: + enable_float16: true # half_percesion float for memory reduction and speedup + fix_BN: false + with_test: true #true + log_iter: 100 + restore_ckpt_strict: true + restore_hint: 0 + save_iter: 10000 + save_name: SkeletonGait + sync_BN: true + total_iter: 50000 + sampler: + batch_shuffle: true + batch_size: + - 8 # TripletSampler, batch_size[0] indicates Number of Identity + - 8 # batch_size[1] indicates Samples sequqnce for each Identity + frames_num_fixed: 10 # fixed frames number for training + sample_type: fixed_unordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered + type: TripletSampler + transform: + - type: Compose + trf_cfg: + - type: RandomPerspective + prob: 0.2 + - type: BaseSilCuttingTransform + - type: RandomHorizontalFlip + prob: 0.2 + - type: RandomRotate + prob: 0.2 diff --git a/datasets/ln_sil_heatmap.py b/datasets/ln_sil_heatmap.py new file mode 100644 index 0000000..7348443 --- /dev/null +++ b/datasets/ln_sil_heatmap.py @@ -0,0 +1,77 @@ +import os +import sys +import argparse +from tqdm import tqdm +from glob import glob + +def get_args(): + parser = argparse.ArgumentParser(description='Symlink silouette data and pose data 
into the same folder for SkeletonGait++ training.') + parser.add_argument('--heatmap_data_path', type=str, required=True, help="path of heatmap data, must be the absolute path.") + parser.add_argument('--silhouette_data_path', type=str, required=True, help="path of silouette data, must be the absolute path.") + parser.add_argument('--dataset_pkl_ext_name', type=str, default='.pkl', help="The extent name for .pkl files of silouettes data.") + parser.add_argument('--output_path', type=str, required=True, help="path of output data") + opt = parser.parse_args() + return opt + +def main(): + opt = get_args() + heatmap_data_path = opt.heatmap_data_path + silhouette_data_path = opt.silhouette_data_path + if not os.path.exists(heatmap_data_path): + print(f"heatmap data path {heatmap_data_path} does not exist.") + sys.exit(1) + if not os.path.exists(silhouette_data_path): + print(f"silouette data path {silhouette_data_path} does not exist.") + sys.exit(1) + + all_heatmap_files = sorted(glob(os.path.join(heatmap_data_path, "*/*/*/*.pkl"))) + all_silouette_files = sorted(glob(os.path.join(silhouette_data_path, f"*/*/*/*{opt.dataset_pkl_ext_name}"))) + # print(len(all_heatmap_files), len(all_silouette_files)) + # assert len(all_heatmap_files) == len(all_silouette_files), "The number of heatmap files and silouette files are not equal." 
+ + if len(all_heatmap_files) >= len(all_silouette_files): + for heatmap_file in tqdm(all_heatmap_files): + tmp_list = heatmap_file.split('/') + sil_folder = os.path.join(silhouette_data_path, *tmp_list[-4:-1]) + if not os.path.exists(sil_folder): + print(f"silouette folder {sil_folder} does not exist.") + continue + else: + silouette_file = sorted(glob(os.path.join(sil_folder, f"*{opt.dataset_pkl_ext_name}")))[0] + + output_file = os.path.join(opt.output_path, *tmp_list[-4:-1]) + os.makedirs(output_file, exist_ok=True) + os.system(f"ln -s {silouette_file} {output_file}/1_sil.pkl") + os.system(f"ln -s {heatmap_file} {output_file}/0_heatmap.pkl") + else: + for silouette_file in tqdm(all_silouette_files): + tmp_list = silouette_file.split('/') + heatmap_folder = os.path.join(heatmap_data_path, *tmp_list[-4:-1]) + if not os.path.exists(heatmap_folder): + print(f"heatmap folder {heatmap_folder} does not exist.") + continue + else: + heatmap_file = sorted(glob(os.path.join(heatmap_folder, "*.pkl")))[0] + + output_file = os.path.join(opt.output_path, *tmp_list[-4:-1]) + os.makedirs(output_file, exist_ok=True) + os.system(f"ln -s {silouette_file} {output_file}/1_sil.pkl") + os.system(f"ln -s {heatmap_file} {output_file}/0_heatmap.pkl") + + print("Done! 
class GeneratePoseTarget:
    """Generate pseudo heatmaps based on joint coordinates and confidence.

    Each instance renders either keypoint heatmaps or limb heatmaps: exactly
    one of ``with_kp`` and ``with_limb`` must be enabled (enforced by an
    assertion in ``__init__``).

    Args:
        sigma (float): The sigma of the generated gaussian map. Default: 0.6.
        use_score (bool): Use the confidence score of keypoints as the maximum
            of the gaussian maps. Default: True.
        with_kp (bool): Generate pseudo heatmaps for keypoints. Default: True.
        with_limb (bool): Generate pseudo heatmaps for limbs. Default: False.
        skeletons (tuple[tuple]): Limb definitions as (start, end) keypoint
            index pairs. The default is the COCO-17p skeleton.
        double (bool): Output both original heatmaps and flipped heatmaps.
            Default: False.
        left_kp (tuple[int]): Indexes of left keypoints, used when flipping.
            Default: (1, 3, 5, 7, 9, 11, 13, 15).
        right_kp (tuple[int]): Indexes of right keypoints, used when flipping.
            Default: (2, 4, 6, 8, 10, 12, 14, 16).
        left_limb (tuple[int]): Indexes (into ``skeletons``) of left limbs,
            used when flipping. Default: (0, 2, 4, 5, 6, 10, 11, 12).
        right_limb (tuple[int]): Indexes (into ``skeletons``) of right limbs,
            used when flipping. Default: (1, 3, 7, 8, 9, 13, 14, 15).
        scaling (float): Factor applied to both the canvas size and the
            keypoint coordinates. Default: 1.
        eps (float): Keypoints (or limbs) whose confidence is below this value
            are not drawn. Default: 1e-3.
        img_h (int): Unscaled canvas height. Default: 64.
        img_w (int): Unscaled canvas width. Default: 64.
    """

    def __init__(self,
                 sigma=0.6,
                 use_score=True,
                 with_kp=True,
                 with_limb=False,
                 skeletons=((0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (5, 7),
                            (7, 9), (0, 6), (6, 8), (8, 10), (5, 11), (11, 13),
                            (13, 15), (6, 12), (12, 14), (14, 16), (11, 12)),
                 double=False,
                 left_kp=(1, 3, 5, 7, 9, 11, 13, 15),
                 right_kp=(2, 4, 6, 8, 10, 12, 14, 16),
                 left_limb=(0, 2, 4, 5, 6, 10, 11, 12),
                 right_limb=(1, 3, 7, 8, 9, 13, 14, 15),
                 scaling=1.,
                 eps=1e-3,
                 img_h=64,
                 img_w=64):

        self.sigma = sigma
        self.use_score = use_score
        self.with_kp = with_kp
        self.with_limb = with_limb
        self.double = double
        self.eps = eps

        # Exactly one of the two rendering modes may be active.
        assert self.with_kp + self.with_limb == 1, ('One of "with_limb" and "with_kp" should be set as True.')
        self.left_kp = left_kp
        self.right_kp = right_kp
        self.skeletons = skeletons
        self.left_limb = left_limb
        self.right_limb = right_limb
        self.scaling = scaling
        self.img_h = img_h
        self.img_w = img_w

    def generate_a_heatmap(self, arr, centers, max_values, point_center):
        """Draw the gaussian blob(s) of one keypoint into ``arr`` in place.

        Args:
            arr (np.ndarray): Canvas to draw on. Shape: img_h * img_w.
            centers (np.ndarray): (x, y) coordinates of the keypoint, one row
                per person. Shape: 1 * 2.
            max_values (np.ndarray): Peak value for each row of ``centers``.
                Shape: (1, ).
            point_center: Mean keypoint position; accepted for interface
                compatibility and not used by the drawing itself.
        """

        sigma = self.sigma
        img_h, img_w = arr.shape

        for center, max_value in zip(centers, max_values):
            if max_value < self.eps:
                continue

            mu_x, mu_y = center[0], center[1]

            # Only a +-3 sigma window around the centre is rendered.
            tmp_st_x = int(mu_x - 3 * sigma)
            tmp_ed_x = int(mu_x + 3 * sigma)
            tmp_st_y = int(mu_y - 3 * sigma)
            tmp_ed_y = int(mu_y + 3 * sigma)

            st_x = max(tmp_st_x, 0)
            ed_x = min(tmp_ed_x + 1, img_w)
            st_y = max(tmp_st_y, 0)
            ed_y = min(tmp_ed_y + 1, img_h)
            x = np.arange(st_x, ed_x, 1, np.float32)
            y = np.arange(st_y, ed_y, 1, np.float32)

            # The window lies completely outside the canvas.
            if not (len(x) and len(y)):
                continue
            y = y[:, None]

            patch = np.exp(-((x - mu_x)**2 + (y - mu_y)**2) / 2 / sigma**2)
            patch = patch * max_value

            arr[st_y:ed_y, st_x:ed_x] = np.maximum(arr[st_y:ed_y, st_x:ed_x], patch)

    def generate_a_limb_heatmap(self, arr, starts, ends, start_values, end_values, point_center):
        """Draw one limb (a blurred line segment) into ``arr`` in place.

        Args:
            arr (np.ndarray): Canvas to draw on. Shape: img_h * img_w.
            starts (np.ndarray): (x, y) of one end of the limb. Shape: 1 * 2.
            ends (np.ndarray): (x, y) of the other end. Shape: 1 * 2.
            start_values (np.ndarray): Confidence of the start keypoint.
                Shape: (1, ).
            end_values (np.ndarray): Confidence of the end keypoint.
                Shape: (1, ).
            point_center: Forwarded to ``generate_a_heatmap`` for degenerate
                limbs; otherwise unused.
        """

        sigma = self.sigma
        img_h, img_w = arr.shape

        for start, end, start_value, end_value in zip(starts, ends, start_values, end_values):
            # A limb is only as reliable as its weaker endpoint.
            value_coeff = min(start_value, end_value)
            if value_coeff < self.eps:
                continue

            min_x, max_x = min(start[0], end[0]), max(start[0], end[0])
            min_y, max_y = min(start[1], end[1]), max(start[1], end[1])

            tmp_min_x = int(min_x - 3 * sigma)
            tmp_max_x = int(max_x + 3 * sigma)
            tmp_min_y = int(min_y - 3 * sigma)
            tmp_max_y = int(max_y + 3 * sigma)

            min_x = max(tmp_min_x, 0)
            max_x = min(tmp_max_x + 1, img_w)
            min_y = max(tmp_min_y, 0)
            max_y = min(tmp_max_y + 1, img_h)

            x = np.arange(min_x, max_x, 1, np.float32)
            y = np.arange(min_y, max_y, 1, np.float32)

            if not (len(x) and len(y)):
                continue

            y = y[:, None]
            x_0 = np.zeros_like(x)
            y_0 = np.zeros_like(y)

            # squared distance to the start keypoint
            d2_start = ((x - start[0])**2 + (y - start[1])**2)

            # squared distance to the end keypoint
            d2_end = ((x - end[0])**2 + (y - end[1])**2)

            # squared length of the limb
            d2_ab = ((start[0] - end[0])**2 + (start[1] - end[1])**2)

            # Endpoints (almost) coincide: render a single keypoint blob.
            if d2_ab < 1:
                self.generate_a_heatmap(arr, start[None], start_value[None], point_center)
                continue

            # Normalised position of each pixel's projection on the segment.
            coeff = (d2_start - d2_end + d2_ab) / 2. / d2_ab

            a_dominate = coeff <= 0
            b_dominate = coeff >= 1
            seg_dominate = 1 - a_dominate - b_dominate

            position = np.stack([x + y_0, y + x_0], axis=-1)
            projection = start + np.stack([coeff, coeff], axis=-1) * (end - start)
            d2_line = position - projection
            d2_line = d2_line[:, :, 0]**2 + d2_line[:, :, 1]**2
            d2_seg = a_dominate * d2_start + b_dominate * d2_end + seg_dominate * d2_line

            patch = np.exp(-d2_seg / 2. / sigma**2)
            patch = patch * value_coeff

            arr[min_y:max_y, min_x:max_x] = np.maximum(arr[min_y:max_y, min_x:max_x], patch)

    def generate_heatmap(self, arr, kps, max_values):
        """Draw all keypoints or all limbs of one frame into ``arr`` in place.

        Args:
            arr (np.ndarray): Canvas stack, one channel per keypoint/limb.
                Shape: V * img_h * img_w.
            kps (np.ndarray): Keypoint coordinates of this frame.
                Shape: 1 * V * 2.
            max_values (np.ndarray): Confidence of each keypoint. Shape: 1 * V.
        """

        point_center = kps.mean(1)

        if self.with_kp:
            num_kp = kps.shape[1]
            for i in range(num_kp):
                self.generate_a_heatmap(arr[i], kps[:, i], max_values[:, i], point_center)

        if self.with_limb:
            for i, limb in enumerate(self.skeletons):
                start_idx, end_idx = limb
                starts = kps[:, start_idx]
                ends = kps[:, end_idx]

                start_values = max_values[:, start_idx]
                end_values = max_values[:, end_idx]
                self.generate_a_limb_heatmap(arr[i], starts, ends, start_values, end_values, point_center)

    def gen_an_aug(self, pose_data):
        """Generate pseudo heatmaps for all frames.

        Args:
            pose_data (np.ndarray): [1, T, V, C] with C = 2 (x, y) or
                C = 3 (x, y, score). The array is not modified.

        Returns:
            np.ndarray: float32 heatmaps of shape [T, num_c, img_h, img_w],
            where the canvas size is the configured size times ``scaling``.
        """

        kp_shape = pose_data.shape

        if pose_data.shape[-1] == 3:
            all_kpscores = pose_data[..., -1]  # [1, T, V]
        else:
            all_kpscores = np.ones(kp_shape[:-1], dtype=np.float32)

        # Scale the canvas and the keypoints. The multiplication allocates a
        # new array on purpose: scaling the slice in place would write through
        # to the caller's pose data and compound on every subsequent call.
        img_h = int(self.img_h * self.scaling + 0.5)
        img_w = int(self.img_w * self.scaling + 0.5)
        all_kps = pose_data[..., :2] * self.scaling

        num_frame = kp_shape[1]
        num_c = 0
        if self.with_kp:
            num_c += all_kps.shape[2]
        if self.with_limb:
            num_c += len(self.skeletons)
        ret = np.zeros([num_frame, num_c, img_h, img_w], dtype=np.float32)

        for i in range(num_frame):
            # 1, V, C
            kps = all_kps[:, i]
            # 1, V
            kpscores = all_kpscores[:, i] if self.use_score else np.ones_like(all_kpscores[:, i])

            self.generate_heatmap(ret[i], kps, kpscores)
        return ret

    def __call__(self, pose_data):
        """Render heatmaps for a pose sequence.

        Args:
            pose_data (np.ndarray): (T, V, C) with C = 2 or 3.

        Returns:
            np.ndarray: Heatmaps of shape (T, num_c, H, W); when ``double`` is
            set, the horizontally flipped heatmaps are appended along axis 0.
        """
        pose_data = pose_data[None, ...]  # (1, T, V, C): a single person

        heatmap = self.gen_an_aug(pose_data)

        if self.double:
            indices = np.arange(heatmap.shape[1], dtype=np.int64)
            left, right = (self.left_kp, self.right_kp) if self.with_kp else (self.left_limb, self.right_limb)
            # Mirroring the image swaps the left/right channels as well.
            for l, r in zip(left, right):  # noqa: E741
                indices[l] = r
                indices[r] = l
            heatmap_flip = heatmap[..., ::-1][:, indices]
            heatmap = np.concatenate([heatmap, heatmap_flip])
        return heatmap

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'sigma={self.sigma}, '
                    f'use_score={self.use_score}, '
                    f'with_kp={self.with_kp}, '
                    f'with_limb={self.with_limb}, '
                    f'skeletons={self.skeletons}, '
                    f'double={self.double}, '
                    f'left_kp={self.left_kp}, '
                    f'right_kp={self.right_kp})')
        return repr_str
class PadKeypoints:
    """
    Pad the keypoints with missing values.

    Coordinates equal to 0.0 are treated as missing and are filled in by the
    configured imputer, which is fitted on each sequence independently.
    """

    def __init__(self, pad_method="knn", use_conf=True) -> None:
        """
        pad_method (str): Specifies the method used to pad the missing values.
                            The supported methods are "knn" and "simple".
        use_conf (bool): Indicates whether the last channel of the input holds
                            confidence scores (which are never imputed).
        """
        self.use_conf = use_conf
        method = pad_method.lower()
        if method == "knn":
            self.imputer = KNNImputer(missing_values=0.0, n_neighbors=4, weights="distance", add_indicator=False)
        elif method == "simple":
            # add_indicator must stay off: indicator columns widen the output,
            # the reshape in __call__ then fails and the unpadded data would be
            # returned, i.e. this method would never impute anything.
            self.imputer = SimpleImputer(missing_values=0.0, strategy='mean', add_indicator=False)
        else:
            raise ValueError(f"Error value for padding method: {pad_method}")

    def __call__(self, raw_data):
        """
        raw_data: (T, V, C)
            - T: number of frames
            - V: number of joints
            - C: dimensionality, where 2 indicates joint coordinates and 1 indicates the confidence score
        return padded_data: (T, V, C)
        """
        T, V, C = raw_data.shape
        if self.use_conf:
            # Split off the confidence channel; only coordinates are imputed.
            data = raw_data[..., :-1]
            score = raw_data[..., -1:]
            C = C - 1
        else:
            # No confidence channel: every channel is a coordinate.
            data = raw_data
            score = None
        data = data.reshape((T, V * C))
        padded_data = self.imputer.fit_transform(data)
        try:
            padded_data = padded_data.reshape((T, V, C))
        except ValueError:
            # The imputer returned a different number of columns (presumably
            # because a coordinate was missing in every frame and was dropped);
            # keep the unpadded coordinates in that case.
            padded_data = data.reshape((T, V, C))
        if score is not None:
            padded_data = np.concatenate([padded_data, score], axis=-1)
        return padded_data
class GatherTransform(object):
    """
    Apply the shared preprocessing once and stack the bone and joint heatmaps.
    """

    def __init__(self, base_transform, transform_bone, transform_joint):
        """
        base_transform: Common preprocessing applied before both branches,
                        e.g., COCO18toCOCO17, PadKeypoints, CenterAndScale
        transform_bone: Callable producing the bone (limb) heatmap
        transform_joint: Callable producing the joint (keypoint) heatmap
        """
        self.base_transform = base_transform
        self.transform_bone = transform_bone
        self.transform_joint = transform_joint

    def __call__(self, pose_data):
        shared = self.base_transform(pose_data)
        # Bone first, joint second: this fixes the channel order on axis 1.
        branches = (self.transform_bone, self.transform_joint)
        return np.concatenate([branch(shared) for branch in branches], axis=1)
class SequentialDistributedSampler(torch.utils.data.sampler.Sampler):
    """
    Distributed sampler that hands each rank one contiguous slice of indices,
    making it easier to collate all results at the end.

    The index list is padded by repeating the last index so that every rank
    receives the same number of samples, a multiple of ``batch_size``.

    Args:
        dataset: Sized dataset to sample from.
        batch_size (int): Per-rank batch size used to round up the shard size.
        rank (int, optional): Rank of this process; read from
            ``torch.distributed`` when omitted.
        num_replicas (int, optional): Number of participating processes; read
            from ``torch.distributed`` when omitted.

    Raises:
        RuntimeError: If ``rank`` or ``num_replicas`` is omitted and the
            distributed package is not available.
    """

    def __init__(self, dataset, batch_size, rank=None, num_replicas=None):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.batch_size = batch_size
        # Shard size: ceil(len / batch_size / replicas) full batches per rank.
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.batch_size / self.num_replicas)) * self.batch_size
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        # Pad with the last index to make the list evenly divisible. The guard
        # keeps an empty dataset from raising IndexError on `indices[-1]`.
        padding = self.total_size - len(indices)
        if padding > 0:
            indices += [indices[-1]] * padding
        # Contiguous shard owned by this rank.
        first = self.rank * self.num_samples
        return iter(indices[first:first + self.num_samples])

    def __len__(self):
        return self.num_samples
def get_args():
    """Parse the command-line options of the heatmap pretreatment script.

    The configuration path is accepted under its documented spelling
    ``--heatmap_cfg_path`` as well as the historical ``--heatemap_cfg_path``;
    both are stored as ``opt.heatemap_cfg_path`` so existing callers keep
    working. The local rank is accepted as ``--local_rank`` or
    ``--local-rank`` and stored as ``opt.local_rank``.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description='Utility for generating heatmaps from pose data.')
    parser.add_argument('--pose_data_path', type=str, required=True, help="Path to the root directory containing pose data (.pkl files, ID-level) files.")
    parser.add_argument('--save_root', type=str, required=True, help="Root directory where generated heatmap .pkl files will be saved (ID-level).")
    parser.add_argument('--ext_name', type=str, default='', help="Extension name to be appended to the 'save_root' for identification.")
    parser.add_argument('--dataset_name', type=str, required=True, help="Name of the dataset being preprocessed.")
    # The README documents `--heatmap_cfg_path`; the misspelled flag is kept
    # as an alias, and `dest` preserves the attribute name read elsewhere.
    parser.add_argument('--heatmap_cfg_path', '--heatemap_cfg_path', dest='heatemap_cfg_path', type=str, default='configs/skeletongait/pretreatment_heatmap.yaml', help="Path to the heatmap generator configuration file.")
    # Launchers pass either the underscored or the dashed spelling, depending
    # on the PyTorch version, so both are accepted.
    parser.add_argument("--local_rank", "--local-rank", dest="local_rank", type=int, default=0, help="Local rank for distributed processing, defaults to 0 for non-distributed setups.")
    opt = parser.parse_args()
    return opt
seq.shape[-2:] # rotation degree = random.uniform(-self.degree, self.degree) M1 = cv2.getRotationMatrix2D((dh // 2, dw // 2), degree, 1) # affine + if len(seq.shape) == 4: + seq = seq.transpose(0, 2, 3, 1) seq = [cv2.warpAffine(_[0, ...], M1, (dw, dh)) - for _ in np.split(seq, seq.shape[0], axis=0)] + for _ in np.split(seq, seq.shape[0], axis=0)] seq = np.concatenate([np.array(_)[np.newaxis, ...] for _ in seq], 0) + if len(seq.shape) == 4: + seq = seq.transpose(0, 3, 1, 2) return seq @@ -152,7 +156,7 @@ class RandomPerspective(object): if random.uniform(0, 1) >= self.prob: return seq else: - _, h, w = seq.shape + h, w = seq.shape[-2:] cutting = int(w // 44) * 10 x_left = list(range(0, cutting)) x_right = list(range(w - cutting, w)) @@ -164,10 +168,14 @@ class RandomPerspective(object): canvasPoints = np.float32([[0, 0], [w, 0], [w, h], [0, h]]) perspectiveMatrix = cv2.getPerspectiveTransform( np.array(srcPoints), np.array(canvasPoints)) + if len(seq.shape) == 4: + seq = seq.transpose(0, 2, 3, 1) seq = [cv2.warpPerspective(_[0, ...], perspectiveMatrix, (w, h)) for _ in np.split(seq, seq.shape[0], axis=0)] seq = np.concatenate([np.array(_)[np.newaxis, ...] for _ in seq], 0) + if len(seq.shape) == 4: + seq = seq.transpose(0, 3, 1, 2) return seq @@ -180,7 +188,7 @@ class RandomAffine(object): if random.uniform(0, 1) >= self.prob: return seq else: - _, dh, dw = seq.shape + dh, dw = seq.shape[-2:] # rotation max_shift = int(dh // 64 * 10) shift_range = list(range(0, max_shift)) @@ -190,10 +198,14 @@ class RandomAffine(object): dh-random.choice(shift_range), random.choice(shift_range)], [random.choice(shift_range), dw-random.choice(shift_range)]]) M1 = cv2.getAffineTransform(pts1, pts2) # affine + if len(seq.shape) == 4: + seq = seq.transpose(0, 2, 3, 1) seq = [cv2.warpAffine(_[0, ...], M1, (dw, dh)) for _ in np.split(seq, seq.shape[0], axis=0)] seq = np.concatenate([np.array(_)[np.newaxis, ...] 
for _ in seq], 0) + if len(seq.shape) == 4: + seq = seq.transpose(0, 3, 1, 2) return seq # ****************************************** diff --git a/opengait/modeling/models/deepgaitv2.py b/opengait/modeling/models/deepgaitv2.py index 4dc3493..3e66a62 100644 --- a/opengait/modeling/models/deepgaitv2.py +++ b/opengait/modeling/models/deepgaitv2.py @@ -27,6 +27,7 @@ class DeepGaitV2(BaseModel): in_channels = model_cfg['Backbone']['in_channels'] layers = model_cfg['Backbone']['layers'] channels = model_cfg['Backbone']['channels'] + self.inference_use_emb2 = model_cfg['use_emb2'] if 'use_emb2' in model_cfg else False if mode == '3d': strides = [ @@ -92,7 +93,11 @@ class DeepGaitV2(BaseModel): def forward(self, inputs): ipts, labs, typs, vies, seqL = inputs - sils = ipts[0].unsqueeze(1) + if len(ipts[0].size()) == 4: + sils = ipts[0].unsqueeze(1) + else: + sils = ipts[0] + sils = sils.transpose(1, 2).contiguous() assert sils.size(-1) in [44, 88] del ipts @@ -111,7 +116,10 @@ class DeepGaitV2(BaseModel): embed_1 = self.FCs(feat) # [n, c, p] embed_2, logits = self.BNNecks(embed_1) # [n, c, p] - embed = embed_1 + if self.inference_use_emb2: + embed = embed_2 + else: + embed = embed_1 retval = { 'training_feat': { diff --git a/opengait/modeling/models/skeletongait++.py b/opengait/modeling/models/skeletongait++.py new file mode 100644 index 0000000..9c169d6 --- /dev/null +++ b/opengait/modeling/models/skeletongait++.py @@ -0,0 +1,191 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from ..base_model import BaseModel +from ..modules import HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks, SetBlockWrapper, conv3x3, conv1x1, BasicBlock2D, BasicBlockP3D + +from einops import rearrange + +import copy + +class SkeletonGaitPP(BaseModel): + + def build_network(self, model_cfg): + #B, C = [1, 4, 4, 1], 2 + in_C, B, C = model_cfg['Backbone']['in_channels'], model_cfg['Backbone']['blocks'], model_cfg['Backbone']['C'] 
def make_layer(self, block, planes, stride, blocks_num, mode='2d'):
    """Build one residual stage of ``blocks_num`` blocks.

    The first block applies ``stride`` and, when the stride or the channel
    count changes, receives a 1x1 projection as ``downsample``; the remaining
    blocks use unit stride. ``self.inplanes`` is updated to the stage's
    output channel count.

    Args:
        block: Block class exposing ``expansion`` and accepting
            ``(inplanes, planes, stride=..., downsample=...)``.
        planes (int): Base channel count of the stage.
        stride (list[int]): Stride of the first block.
        blocks_num (int): Number of blocks in the stage.
        mode (str): '2d', '3d' or 'p3d'; selects the projection type.

    Returns:
        nn.Sequential: The assembled stage.

    Raises:
        TypeError: If a projection is required and ``mode`` is unsupported.
    """
    out_planes = planes * block.expansion

    if max(stride) > 1 or self.inplanes != out_planes:
        if mode == '3d':
            downsample = nn.Sequential(nn.Conv3d(self.inplanes, out_planes, kernel_size=[1, 1, 1], stride=stride, padding=[0, 0, 0], bias=False), nn.BatchNorm3d(out_planes))
        elif mode == '2d':
            downsample = nn.Sequential(conv1x1(self.inplanes, out_planes, stride=stride), nn.BatchNorm2d(out_planes))
        elif mode == 'p3d':
            # The 2-element stride is prefixed with 1 for the first Conv3d axis.
            downsample = nn.Sequential(nn.Conv3d(self.inplanes, out_planes, kernel_size=[1, 1, 1], stride=[1, *stride], padding=[0, 0, 0], bias=False), nn.BatchNorm3d(out_planes))
        else:
            raise TypeError(f"Unsupported mode {mode!r}; expected '2d', '3d' or 'p3d'.")
    else:
        # Shapes already match: the shortcut is the identity.
        downsample = lambda x: x

    layers = [block(self.inplanes, planes, stride=stride, downsample=downsample)]
    self.inplanes = out_planes
    s = [1, 1] if mode in ['2d', 'p3d'] else [1, 1, 1]
    layers.extend(block(self.inplanes, planes, stride=s) for _ in range(1, blocks_num))
    return nn.Sequential(*layers)
class CatFusion(nn.Module):
    """Fuse silhouette and skeleton-map features by channel concatenation.

    The two feature maps are concatenated along the channel axis and reduced
    back to ``in_channels`` channels by a 1x1 convolution.
    """

    def __init__(self, in_channels=64):
        super(CatFusion, self).__init__()
        self.conv = SetBlockWrapper(
            nn.Sequential(
                conv1x1(in_channels * 2, in_channels),
            )
        )

    def forward(self, sil_feat, map_feat):
        '''
            sil_feat: [n, c, s, h, w]
            map_feat: [n, c, s, h, w]
        '''
        # Concatenate on the channel axis (dim=1): the projection expects
        # 2 * c input channels. The default dim=0 would stack the batches.
        feats = torch.cat([sil_feat, map_feat], dim=1)
        return self.conv(feats)