638c657763
* add resnet9 backbone and regular da ops * add gait3d config * fix invalid path CASIA-B* in windows * add gaitbase config for all datasets * rm unused OpenGait transform
59 lines
1.9 KiB
Python
59 lines
1.9 KiB
Python
from torch.nn import functional as F
|
|
import torch.nn as nn
|
|
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet
|
|
from ..modules import BasicConv2d
|
|
|
|
|
|
# Lookup from the block name accepted by ResNet9 to the residual block class
# imported from torchvision.
block_map = {
    'BasicBlock': BasicBlock,
    'Bottleneck': Bottleneck,
}
|
|
|
|
|
|
class ResNet9(ResNet):
    """Configurable ResNet-style backbone returning the last stage's output.

    The parent ``ResNet`` constructor runs first; the stem (``conv1``/``bn1``)
    and the four stages (``layer1`` .. ``layer4``) are then rebuilt from
    ``channels``, ``layers`` and ``strides``. ``relu`` and ``maxpool`` are
    the ones provided by the parent class, and ``fc`` is cleared because
    ``forward`` never uses it.

    Args:
        block: Name of the residual block type; must be a key of
            ``block_map`` (``'BasicBlock'`` or ``'Bottleneck'``).
        channels: ``planes`` value for each of the four stages;
            ``channels[0]`` is also the stem's output width.
        in_channel: Channel count handed to the stem convolution as its
            input width.
        layers: Number of blocks in each of the four stages. A value below
            1 makes that stage a pass-through.
        strides: Stride passed to each of the four stages.
        maxpool: Whether ``forward`` applies ``self.maxpool`` after the stem.

    Raises:
        ValueError: If ``block`` is not a key of ``block_map``.
    """

    def __init__(self, block, channels=[32, 64, 128, 256], in_channel=1,
                 layers=[1, 2, 2, 1], strides=[1, 2, 2, 1], maxpool=True):
        # NOTE: the list defaults are only indexed below, never mutated.
        if block not in block_map:
            raise ValueError(
                "Error type for -block-Cfg-, supported: 'BasicBlock' or 'Bottleneck'.")
        block = block_map[block]

        # The parent constructor calls the overridden ``_make_layer`` with
        # ``layers``; everything stage- or stem-related it sets up is
        # replaced below.
        super().__init__(block, layers)
        self.maxpool_flag = maxpool

        # Not used #
        self.fc = None
        ############

        # ``inplanes`` must be reset before the stages are rebuilt, since the
        # parent's ``_make_layer`` reads it as the incoming channel count.
        self.inplanes = channels[0]
        self.bn1 = nn.BatchNorm2d(self.inplanes)

        # Positional args after the channel counts are 3, 1, 1 -- presumably
        # kernel_size, stride, padding; confirm against BasicConv2d.
        self.conv1 = BasicConv2d(in_channel, self.inplanes, 3, 1, 1)

        self.layer1 = self._make_layer(
            block, channels[0], layers[0], stride=strides[0], dilate=False)

        self.layer2 = self._make_layer(
            block, channels[1], layers[1], stride=strides[1], dilate=False)
        self.layer3 = self._make_layer(
            block, channels[2], layers[2], stride=strides[2], dilate=False)
        self.layer4 = self._make_layer(
            block, channels[3], layers[3], stride=strides[3], dilate=False)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage, or a pass-through when ``blocks`` is below 1.

        Args:
            block: Residual block class forwarded to the parent builder.
            planes: ``planes`` value forwarded to the parent builder.
            blocks: Number of blocks requested for this stage.
            stride: Stride forwarded to the parent builder.
            dilate: Dilation flag forwarded to the parent builder.

        Returns:
            The stage built by the parent ``_make_layer`` when
            ``blocks >= 1``; otherwise an ``nn.Identity`` module.
        """
        if blocks >= 1:
            return super()._make_layer(
                block, planes, blocks, stride=stride, dilate=dilate)
        # An nn.Identity (rather than a bare function) keeps the empty stage
        # a registered submodule; it has no parameters, so state_dict keys
        # are unaffected.
        return nn.Identity()

    def forward(self, x):
        """Run the stem, the optional max-pool and the four stages on ``x``.

        Returns:
            The output of ``layer4``.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.maxpool_flag:
            x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x
|
|
|