From 7fe5d1b26adbb03f6ecfd51fd0f76ad0e53deef1 Mon Sep 17 00:00:00 2001 From: Dongyang Jin <1410234026@qq.com> Date: Sun, 10 Dec 2023 13:54:59 +0800 Subject: [PATCH] deepgaitv2 --- configs/deepgaitv2/DeepGaitV2_casiab.yaml | 95 +++++++++++ configs/deepgaitv2/DeepGaitV2_ccpg.yaml | 97 ++++++++++++ configs/deepgaitv2/DeepGaitV2_gait3d.yaml | 96 ++++++++++++ configs/deepgaitv2/DeepGaitV2_grew.yaml | 96 ++++++++++++ configs/deepgaitv2/DeepGaitV2_oumvlp.yaml | 95 +++++++++++ configs/deepgaitv2/DeepGaitV2_sustech1k.yaml | 96 ++++++++++++ opengait/modeling/models/deepgaitv2.py | 129 +++++++++++++++ opengait/modeling/modules.py | 157 ++++++++++++++++++- 8 files changed, 860 insertions(+), 1 deletion(-) create mode 100644 configs/deepgaitv2/DeepGaitV2_casiab.yaml create mode 100644 configs/deepgaitv2/DeepGaitV2_ccpg.yaml create mode 100644 configs/deepgaitv2/DeepGaitV2_gait3d.yaml create mode 100644 configs/deepgaitv2/DeepGaitV2_grew.yaml create mode 100644 configs/deepgaitv2/DeepGaitV2_oumvlp.yaml create mode 100644 configs/deepgaitv2/DeepGaitV2_sustech1k.yaml create mode 100644 opengait/modeling/models/deepgaitv2.py diff --git a/configs/deepgaitv2/DeepGaitV2_casiab.yaml b/configs/deepgaitv2/DeepGaitV2_casiab.yaml new file mode 100644 index 0000000..f41df22 --- /dev/null +++ b/configs/deepgaitv2/DeepGaitV2_casiab.yaml @@ -0,0 +1,95 @@ +data_cfg: + dataset_name: CASIA-B + dataset_root: your_path + dataset_partition: ./datasets/CASIA-B/CASIA-B.json + num_workers: 1 + remove_no_gallery: false # Remove probe if no gallery for it + test_dataset_name: CASIA-B + +evaluator_cfg: + enable_float16: true + restore_ckpt_strict: true + restore_hint: 60000 + save_name: DeepGaitV2 + sampler: + batch_shuffle: false + batch_size: 4 + sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered + frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory + metric: euc # 
cos + transform: + - type: BaseSilCuttingTransform + +loss_cfg: + - loss_term_weight: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + - loss_term_weight: 1.0 + scale: 16 + type: CrossEntropyLoss + log_prefix: softmax + log_accuracy: true + +model_cfg: + model: DeepGaitV2 + Backbone: + mode: p3d + in_channels: 1 + layers: + - 1 + - 1 + - 1 + - 1 + channels: + - 64 + - 128 + - 256 + - 512 + SeparateBNNecks: + class_num: 74 + +optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 0.0005 + +scheduler_cfg: + gamma: 0.1 + milestones: # Learning Rate Reduction at each milestones + - 20000 + - 40000 + - 50000 + scheduler: MultiStepLR + +trainer_cfg: + enable_float16: true # half_percesion float for memory reduction and speedup + fix_BN: false + log_iter: 100 + with_test: false + restore_ckpt_strict: true + restore_hint: 0 + save_iter: 30000 + save_name: DeepGaitV2 + sync_BN: true + total_iter: 60000 + sampler: + batch_shuffle: true + batch_size: + - 8 # TripletSampler, batch_size[0] indicates Number of Identity + - 16 # batch_size[1] indicates Samples sequqnce for each Identity + frames_num_fixed: 30 # fixed frames number for training + frames_skip_num: 4 + sample_type: fixed_ordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered + type: TripletSampler + transform: + - type: Compose + trf_cfg: + - type: RandomPerspective + prob: 0.2 + - type: BaseSilCuttingTransform + - type: RandomHorizontalFlip + prob: 0.2 + - type: RandomRotate + prob: 0.2 diff --git a/configs/deepgaitv2/DeepGaitV2_ccpg.yaml b/configs/deepgaitv2/DeepGaitV2_ccpg.yaml new file mode 100644 index 0000000..81e29dc --- /dev/null +++ b/configs/deepgaitv2/DeepGaitV2_ccpg.yaml @@ -0,0 +1,97 @@ +data_cfg: + dataset_name: CCPG + dataset_root: your_path + dataset_partition: ./datasets/CCPG/CCPG.json + num_workers: 1 + data_in_use: [True, False, False, False] + remove_no_gallery: false # Remove probe 
if no gallery for it + test_dataset_name: CCPG + +evaluator_cfg: + enable_float16: true + restore_ckpt_strict: true + restore_hint: 60000 + save_name: DeepGaitV2 + sampler: + batch_shuffle: false + batch_size: 4 + sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered + frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory + eval_func: evaluate_CCPG + metric: euc # cos + transform: + - type: BaseSilCuttingTransform + +loss_cfg: + - loss_term_weight: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + - loss_term_weight: 1.0 + scale: 16 + type: CrossEntropyLoss + log_prefix: softmax + log_accuracy: true + +model_cfg: + model: DeepGaitV2 + Backbone: + mode: p3d + in_channels: 1 + layers: + - 1 + - 1 + - 1 + - 1 + channels: + - 64 + - 128 + - 256 + - 512 + SeparateBNNecks: + class_num: 100 + +optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 0.0005 + +scheduler_cfg: + gamma: 0.1 + milestones: # Learning Rate Reduction at each milestones + - 20000 + - 40000 + - 50000 + scheduler: MultiStepLR + +trainer_cfg: + enable_float16: true # half_percesion float for memory reduction and speedup + fix_BN: false + log_iter: 100 + with_test: false + restore_ckpt_strict: true + restore_hint: 0 + save_iter: 30000 + save_name: DeepGaitV2 + sync_BN: true + total_iter: 60000 + sampler: + batch_shuffle: true + batch_size: + - 8 # TripletSampler, batch_size[0] indicates Number of Identity + - 16 # batch_size[1] indicates Samples sequqnce for each Identity + frames_num_fixed: 30 # fixed frames number for training + frames_skip_num: 4 + sample_type: fixed_ordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered + type: TripletSampler + transform: + - type: Compose + trf_cfg: + - type: RandomPerspective + prob: 0.2 + - type: BaseSilCuttingTransform 
+ - type: RandomHorizontalFlip + prob: 0.2 + - type: RandomRotate + prob: 0.2 diff --git a/configs/deepgaitv2/DeepGaitV2_gait3d.yaml b/configs/deepgaitv2/DeepGaitV2_gait3d.yaml new file mode 100644 index 0000000..c6acc8d --- /dev/null +++ b/configs/deepgaitv2/DeepGaitV2_gait3d.yaml @@ -0,0 +1,96 @@ +data_cfg: + dataset_name: Gait3D + dataset_root: your_path + dataset_partition: ./datasets/Gait3D/Gait3D.json + num_workers: 1 + remove_no_gallery: false # Remove probe if no gallery for it + test_dataset_name: Gait3D + +evaluator_cfg: + enable_float16: true + restore_ckpt_strict: true + restore_hint: 60000 + save_name: DeepGaitV2 + eval_func: evaluate_Gait3D + sampler: + batch_shuffle: false + batch_size: 4 + sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered + frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory + metric: euc # cos + transform: + - type: BaseSilTransform + +loss_cfg: + - loss_term_weight: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + - loss_term_weight: 1.0 + scale: 16 + type: CrossEntropyLoss + log_prefix: softmax + log_accuracy: true + +model_cfg: + model: DeepGaitV2 + Backbone: + in_channels: 1 + mode: p3d + layers: + - 1 + - 4 + - 4 + - 1 + channels: + - 64 + - 128 + - 256 + - 512 + SeparateBNNecks: + class_num: 3000 + +optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 0.0005 + +scheduler_cfg: + gamma: 0.1 + milestones: # Learning Rate Reduction at each milestones + - 20000 + - 40000 + - 50000 + scheduler: MultiStepLR + +trainer_cfg: + enable_float16: true # half_percesion float for memory reduction and speedup + fix_BN: false + log_iter: 100 + with_test: false + restore_ckpt_strict: true + restore_hint: 0 + save_iter: 30000 + save_name: DeepGaitV2 + sync_BN: true + total_iter: 60000 + sampler: + batch_shuffle: true + batch_size: + - 32 # TripletSampler, batch_size[0] 
indicates Number of Identity + - 4 # batch_size[1] indicates Samples sequqnce for each Identity + frames_num_fixed: 30 # fixed frames number for training + frames_skip_num: 4 + sample_type: fixed_ordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered + type: TripletSampler + transform: + - type: Compose + trf_cfg: + - type: RandomPerspective + prob: 0.2 + - type: BaseSilCuttingTransform + - type: RandomHorizontalFlip + prob: 0.2 + - type: RandomRotate + prob: 0.2 \ No newline at end of file diff --git a/configs/deepgaitv2/DeepGaitV2_grew.yaml b/configs/deepgaitv2/DeepGaitV2_grew.yaml new file mode 100644 index 0000000..71cac91 --- /dev/null +++ b/configs/deepgaitv2/DeepGaitV2_grew.yaml @@ -0,0 +1,96 @@ +data_cfg: + dataset_name: GREW + dataset_root: your_path + dataset_partition: ./datasets/GREW/GREW.json + num_workers: 1 + remove_no_gallery: false # Remove probe if no gallery for it + test_dataset_name: GREW + +evaluator_cfg: + enable_float16: true + restore_ckpt_strict: true + restore_hint: 180000 + save_name: DeepGaitV2 + eval_func: GREW_submission + sampler: + batch_shuffle: false + batch_size: 4 + sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered + frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory + metric: euc # cos + transform: + - type: BaseSilCuttingTransform + +loss_cfg: + - loss_term_weight: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + - loss_term_weight: 1.0 + scale: 16 + type: CrossEntropyLoss + log_prefix: softmax + log_accuracy: true + +model_cfg: + model: DeepGaitV2 + Backbone: + in_channels: 1 + mode: p3d + layers: + - 1 + - 4 + - 4 + - 1 + channels: + - 64 + - 128 + - 256 + - 512 + SeparateBNNecks: + class_num: 20000 + +optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 0.0005 + 
+scheduler_cfg: + gamma: 0.1 + milestones: # Learning Rate Reduction at each milestones + - 80000 + - 120000 + - 150000 + scheduler: MultiStepLR + +trainer_cfg: + enable_float16: true # half_percesion float for memory reduction and speedup + fix_BN: false + log_iter: 100 + with_test: false + restore_ckpt_strict: true + restore_hint: 0 + save_iter: 30000 + save_name: DeepGaitV2 + sync_BN: true + total_iter: 180000 + sampler: + batch_shuffle: true + batch_size: + - 32 # TripletSampler, batch_size[0] indicates Number of Identity + - 4 # batch_size[1] indicates Samples sequqnce for each Identity + frames_num_fixed: 30 # fixed frames number for training + frames_skip_num: 4 + sample_type: fixed_ordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered + type: TripletSampler + transform: + - type: Compose + trf_cfg: + - type: RandomPerspective + prob: 0.2 + - type: BaseSilCuttingTransform + - type: RandomHorizontalFlip + prob: 0.2 + - type: RandomRotate + prob: 0.2 diff --git a/configs/deepgaitv2/DeepGaitV2_oumvlp.yaml b/configs/deepgaitv2/DeepGaitV2_oumvlp.yaml new file mode 100644 index 0000000..808ec0c --- /dev/null +++ b/configs/deepgaitv2/DeepGaitV2_oumvlp.yaml @@ -0,0 +1,95 @@ +data_cfg: + dataset_name: OUMVLP + dataset_root: your_path + dataset_partition: ./datasets/OUMVLP/OUMVLP.json + num_workers: 1 + remove_no_gallery: false # Remove probe if no gallery for it + test_dataset_name: OUMVLP + +evaluator_cfg: + enable_float16: true + restore_ckpt_strict: true + restore_hint: 120000 + save_name: DeepGaitV2 + sampler: + batch_shuffle: false + batch_size: 4 + sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered + frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory + metric: euc # cos + transform: + - type: BaseSilCuttingTransform + +loss_cfg: + - 
loss_term_weight: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + - loss_term_weight: 1.0 + scale: 16 + type: CrossEntropyLoss + log_prefix: softmax + log_accuracy: true + +model_cfg: + model: DeepGaitV2 + Backbone: + in_channels: 1 + mode: p3d + layers: + - 1 + - 1 + - 1 + - 1 + channels: + - 64 + - 128 + - 256 + - 512 + SeparateBNNecks: + class_num: 5153 + +optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 0.0005 + +scheduler_cfg: + gamma: 0.1 + milestones: # Learning Rate Reduction at each milestones + - 60000 + - 80000 + - 100000 + scheduler: MultiStepLR + +trainer_cfg: + enable_float16: true # half_percesion float for memory reduction and speedup + fix_BN: false + log_iter: 100 + with_test: false + restore_ckpt_strict: true + restore_hint: 0 + save_iter: 30000 + save_name: DeepGaitV2 + sync_BN: true + total_iter: 120000 + sampler: + batch_shuffle: true + batch_size: + - 32 # TripletSampler, batch_size[0] indicates Number of Identity + - 8 # batch_size[1] indicates Samples sequqnce for each Identity + frames_num_fixed: 30 # fixed frames number for training + frames_skip_num: 4 + sample_type: fixed_ordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered + type: TripletSampler + transform: + - type: Compose + trf_cfg: + - type: RandomPerspective + prob: 0.2 + - type: BaseSilCuttingTransform + - type: RandomHorizontalFlip + prob: 0.2 + - type: RandomRotate + prob: 0.2 diff --git a/configs/deepgaitv2/DeepGaitV2_sustech1k.yaml b/configs/deepgaitv2/DeepGaitV2_sustech1k.yaml new file mode 100644 index 0000000..52fed25 --- /dev/null +++ b/configs/deepgaitv2/DeepGaitV2_sustech1k.yaml @@ -0,0 +1,96 @@ +data_cfg: + dataset_name: SUSTech1K + dataset_root: your_path + dataset_partition: ./datasets/SUSTech1K/SUSTech1K.json + num_workers: 4 + data_in_use: [false, false, false, false, false, false, true, false, false, false, false, false, false, false, false, 
false] + remove_no_gallery: false # Remove probe if no gallery for it + test_dataset_name: SUSTech1K + +evaluator_cfg: + enable_float16: true + restore_ckpt_strict: true + restore_hint: 50000 + save_name: DeepGaitV2 + eval_func: evaluate_indoor_dataset #evaluate_Gait3D + sampler: + batch_shuffle: false + batch_size: 4 + sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered + frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory + metric: euc # cos + transform: + - type: BaseSilCuttingTransform + +loss_cfg: + - loss_term_weight: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + - loss_term_weight: 1.0 + scale: 16 + type: CrossEntropyLoss + log_prefix: softmax + log_accuracy: true + +model_cfg: + model: DeepGaitV2 + Backbone: + mode: p3d + in_channels: 1 + layers: + - 1 + - 1 + - 1 + - 1 + channels: + - 64 + - 128 + - 256 + - 512 + SeparateBNNecks: + class_num: 250 + +optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 0.0005 + +scheduler_cfg: + gamma: 0.1 + milestones: # Learning Rate Reduction at each milestones + - 20000 + - 30000 + - 40000 + scheduler: MultiStepLR + +trainer_cfg: + enable_float16: true # half_percesion float for memory reduction and speedup + fix_BN: false + with_test: true #true + log_iter: 100 + restore_ckpt_strict: true + restore_hint: 0 + save_iter: 10000 + save_name: DeepGaitV2 + sync_BN: true + total_iter: 50000 + sampler: + batch_shuffle: true + batch_size: + - 8 # TripletSampler, batch_size[0] indicates Number of Identity + - 8 # batch_size[1] indicates Samples sequqnce for each Identity + frames_num_fixed: 10 # fixed frames number for training + sample_type: fixed_unordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered + type: TripletSampler + transform: + - type: Compose + trf_cfg: + - 
# NOTE(review): SOURCE is a `git format-patch` e-mail; this block reconstructs
# its Python payload in properly formatted form:
#   * opengait/modeling/models/deepgaitv2.py  (new file)
#   * definitions appended to opengait/modeling/modules.py
# The YAML configs carried by the same patch are untouched here.

# ---------------------------------------------------------------------------
# opengait/modeling/models/deepgaitv2.py
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn

from einops import rearrange

from ..base_model import BaseModel
from ..modules import (
    SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper,
    SeparateFCs, SeparateBNNecks, conv1x1, conv3x3,
    BasicBlock2D, BasicBlockP3D, BasicBlock3D,
)
# Removed unused imports from the original: os, os.path, numpy, matplotlib —
# none of them are referenced anywhere in this file.

# Maps the config's `Backbone.mode` string onto the residual-block class used
# for backbone stages 2-4 (stage 1 is always 2D, see build_network).
blocks_map = {
    '2d': BasicBlock2D,
    'p3d': BasicBlockP3D,
    '3d': BasicBlock3D,
}


class DeepGaitV2(BaseModel):
    """DeepGaitV2 silhouette-based gait recognition model.

    Backbone: a 3x3 conv stem plus four residual stages built from 2D, P3D or
    3D basic blocks (selected by ``model_cfg['Backbone']['mode']``), followed
    by set-level temporal max-pooling, a 16-bin horizontal pooling pyramid,
    part-wise separate FCs and BNNeck classification heads.
    """

    def build_network(self, model_cfg):
        """Instantiate the backbone and heads from ``model_cfg``.

        Expected keys: ``Backbone.mode`` ('2d' | 'p3d' | '3d'),
        ``Backbone.in_channels``, ``Backbone.layers`` (4 block counts),
        ``Backbone.channels`` (4 widths), ``SeparateBNNecks.class_num``.
        """
        mode = model_cfg['Backbone']['mode']
        assert mode in blocks_map
        block = blocks_map[mode]

        in_channels = model_cfg['Backbone']['in_channels']
        layers = model_cfg['Backbone']['layers']
        channels = model_cfg['Backbone']['channels']

        if mode == '3d':
            # 3D strides carry an explicit temporal component (always 1 here);
            # only the spatial dims are downsampled in stages 2 and 3.
            strides = [
                [1, 1],
                [1, 2, 2],
                [1, 2, 2],
                [1, 1, 1],
            ]
        else:
            strides = [
                [1, 1],
                [2, 2],
                [2, 2],
                [1, 1],
            ]

        self.inplanes = channels[0]
        # Stem: frame-wise 2D conv; SetBlockWrapper folds the temporal axis
        # into the batch axis around the 2D ops.
        self.layer0 = SetBlockWrapper(nn.Sequential(
            conv3x3(in_channels, self.inplanes, 1),
            nn.BatchNorm2d(self.inplanes),
            nn.ReLU(inplace=True),
        ))
        # Stage 1 is always built from 2D blocks, regardless of `mode`.
        self.layer1 = SetBlockWrapper(self.make_layer(
            BasicBlock2D, channels[0], strides[0], blocks_num=layers[0], mode=mode))

        self.layer2 = self.make_layer(block, channels[1], strides[1], blocks_num=layers[1], mode=mode)
        self.layer3 = self.make_layer(block, channels[2], strides[2], blocks_num=layers[2], mode=mode)
        self.layer4 = self.make_layer(block, channels[3], strides[3], blocks_num=layers[3], mode=mode)

        if mode == '2d':
            # Pure-2D stages still need the set wrapper to handle the
            # [n, c, s, h, w] input layout.
            self.layer2 = SetBlockWrapper(self.layer2)
            self.layer3 = SetBlockWrapper(self.layer3)
            self.layer4 = SetBlockWrapper(self.layer4)

        self.FCs = SeparateFCs(16, channels[3], channels[2])
        self.BNNecks = SeparateBNNecks(
            16, channels[2], class_num=model_cfg['SeparateBNNecks']['class_num'])

        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = HorizontalPoolingPyramid(bin_num=[16])

    def make_layer(self, block, planes, stride, blocks_num, mode='2d'):
        """Build one residual stage of ``blocks_num`` blocks.

        Only the first block applies `stride` and/or the channel change (with
        a matching 1x1-conv downsample on the shortcut); the remaining blocks
        run at stride 1.
        """
        if max(stride) > 1 or self.inplanes != planes * block.expansion:
            if mode == '3d':
                downsample = nn.Sequential(
                    nn.Conv3d(self.inplanes, planes * block.expansion,
                              kernel_size=[1, 1, 1], stride=stride,
                              padding=[0, 0, 0], bias=False),
                    nn.BatchNorm3d(planes * block.expansion))
            elif mode == '2d':
                downsample = nn.Sequential(
                    conv1x1(self.inplanes, planes * block.expansion, stride=stride),
                    nn.BatchNorm2d(planes * block.expansion))
            elif mode == 'p3d':
                # P3D stages receive a 2D (h, w) stride; prepend temporal
                # stride 1 for the 3D shortcut conv.
                downsample = nn.Sequential(
                    nn.Conv3d(self.inplanes, planes * block.expansion,
                              kernel_size=[1, 1, 1], stride=[1, *stride],
                              padding=[0, 0, 0], bias=False),
                    nn.BatchNorm3d(planes * block.expansion))
            else:
                # was: raise TypeError('xxx') — opaque message
                raise TypeError(f"Unsupported backbone mode: {mode!r}")
        else:
            # was an unregistered `lambda x: x`; nn.Identity behaves the same
            # (parameter-free) and plays nicely with nn.Module bookkeeping.
            downsample = nn.Identity()

        layers = [block(self.inplanes, planes, stride=stride, downsample=downsample)]
        self.inplanes = planes * block.expansion
        s = [1, 1] if mode in ['2d', 'p3d'] else [1, 1, 1]
        for _ in range(1, blocks_num):
            layers.append(block(self.inplanes, planes, stride=s))
        return nn.Sequential(*layers)

    def forward(self, inputs):
        """Run a batch of silhouette sequences through the network.

        `inputs` is the OpenGait 5-tuple (ipts, labels, types, views, seqL);
        ipts[0] is a [n, s, h, w] silhouette tensor with width 44 or 88.
        """
        ipts, labs, typs, vies, seqL = inputs

        sils = ipts[0].unsqueeze(1)  # [n, 1, s, h, w]
        assert sils.size(-1) in [44, 88]

        del ipts
        out0 = self.layer0(sils)
        out1 = self.layer1(out0)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)  # [n, c, s, h, w]

        # Temporal Pooling, TP: set-level max over the sequence axis.
        outs = self.TP(out4, seqL, options={"dim": 2})[0]  # [n, c, h, w]

        # Horizontal Pooling Matching, HPM: 16 horizontal strips.
        feat = self.HPP(outs)  # [n, c, p]

        embed_1 = self.FCs(feat)  # [n, c, p]
        # The BNNeck embedding is intentionally discarded: only its logits
        # feed the softmax loss, while embed_1 serves both the triplet loss
        # and inference (was `embed_2, logits = ...` with embed_2 unused).
        _, logits = self.BNNecks(embed_1)  # [n, c, p]

        embed = embed_1

        retval = {
            'training_feat': {
                'triplet': {'embeddings': embed_1, 'labels': labs},
                'softmax': {'logits': logits, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': rearrange(sils, 'n c s h w -> (n s) c h w'),
            },
            'inference_feat': {
                'embeddings': embed
            }
        }

        return retval


# ---------------------------------------------------------------------------
# appended to opengait/modeling/modules.py
# ---------------------------------------------------------------------------

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding (no bias)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution (no bias)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock2D(nn.Module):
    """torchvision-style 2D residual basic block (two 3x3 convs + shortcut)."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock2D, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")
        # conv1 (and `downsample`, when provided) carries the spatial stride.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class BasicBlockP3D(nn.Module):
    """Pseudo-3D residual block: frame-wise 2D convs plus a temporal-only
    (3, 1, 1) 3D branch blended back in after conv1."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlockP3D, self).__init__()
        # BUGFIX: the original bound norm_layer2d/norm_layer3d only inside
        # `if norm_layer is None`, so passing an explicit norm_layer raised
        # NameError below. Bind both names on every path.
        if norm_layer is None:
            norm_layer2d = nn.BatchNorm2d
            norm_layer3d = nn.BatchNorm3d
        else:
            # Caller-supplied layer is used for both the 2D and 3D norms;
            # it must accept a channel count like the BatchNorm classes do.
            norm_layer2d = norm_layer3d = norm_layer
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")
        self.relu = nn.ReLU(inplace=True)

        # Frame-wise 2D convs; SetBlockWrapper folds [n, c, s, h, w] so the
        # 2D ops see [(n s), c, h, w]. conv1 carries the spatial stride.
        self.conv1 = SetBlockWrapper(
            nn.Sequential(
                conv3x3(inplanes, planes, stride),
                norm_layer2d(planes),
                nn.ReLU(inplace=True),
            )
        )

        self.conv2 = SetBlockWrapper(
            nn.Sequential(
                conv3x3(planes, planes),
                norm_layer2d(planes),
            )
        )

        # Temporal-only 3D conv (kernel 3 on the sequence axis, 1x1 spatial).
        self.shortcut3d = nn.Conv3d(planes, planes, (3, 1, 1), (1, 1, 1), (1, 0, 0), bias=False)
        self.sbn = norm_layer3d(planes)

        self.downsample = downsample

    def forward(self, x):
        '''
        x: [n, c, s, h, w]
        '''
        identity = x

        out = self.conv1(x)
        # Mix the temporal branch into the frame-wise features.
        out = self.relu(out + self.sbn(self.shortcut3d(out)))
        out = self.conv2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class BasicBlock3D(nn.Module):
    """Plain 3D residual basic block with (3, 3, 3) convolutions."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=(1, 1, 1), downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        # NOTE: default `stride` changed from a mutable list [1, 1, 1] to a
        # tuple — same values, no shared-mutable-default hazard.
        super(BasicBlock3D, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm3d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")
        # Temporal padding: "same"-style pad 1 for temporal stride 1 or 2,
        # no pad for stride 3.
        assert stride[0] in [1, 2, 3]
        tp = 1 if stride[0] in [1, 2] else 0
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=(3, 3, 3),
                               stride=stride, padding=[tp, 1, 1], bias=False)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=(3, 3, 3),
                               stride=[1, 1, 1], padding=[1, 1, 1], bias=False)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample

    def forward(self, x):
        '''
        x: [n, c, s, h, w]
        '''
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out