from torch.nn import functional as F
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet
from ..modules import BasicConv2d

# Config name -> torchvision residual block class accepted by ResNet9.
block_map = {'BasicBlock': BasicBlock, 'Bottleneck': Bottleneck}


def _resolve_block(block):
    """Return the torchvision residual block class selected by ``block``.

    Args:
        block: A key of ``block_map`` (``'BasicBlock'`` or ``'Bottleneck'``)
            or one of the block classes themselves.

    Returns:
        The matching block class.

    Raises:
        ValueError: If ``block`` is neither a known name nor a known class.
    """
    if block in block_map:
        return block_map[block]
    if block in block_map.values():
        return block
    raise ValueError(
        "Error type for -block-Cfg-, supported: 'BasicBlock' or 'Bottleneck'.")


class ResNet9(ResNet):
    """Four-stage ResNet feature extractor with configurable widths and strides.

    Reuses torchvision's ``ResNet`` machinery (``_make_layer``, ``relu``,
    ``maxpool``). The stem (``conv1``/``bn1``) and the four stages
    (``layer1``..``layer4``) are rebuilt from the arguments. The classification
    head is dropped, and ``forward`` returns the ``layer4`` feature map.

    Args:
        block: ``'BasicBlock'`` / ``'Bottleneck'``, or the corresponding
            torchvision block class.
        channels: Base width of each stage. A stage's output width is
            ``channels[i] * block.expansion``.
        in_channel: Number of channels of the input tensor.
        layers: Number of residual blocks per stage. A value below 1 turns
            that stage into an identity.
        strides: Stride of the first block of each stage.
        maxpool: Whether to apply torchvision's ``maxpool`` after the stem.

    ``channels``, ``layers`` and ``strides`` are indexed 0..3, so each needs at
    least four entries.
    """

    def __init__(self, block, channels=(32, 64, 128, 256), in_channel=1,
                 layers=(1, 2, 2, 1), strides=(1, 2, 2, 1), maxpool=True):
        block = _resolve_block(block)

        # torchvision first builds its own default network. Its stem and stages
        # are replaced below with the requested sizes.
        super().__init__(block, layers)
        self.maxpool_flag = maxpool

        # The classification head is not used by this backbone. Assigning None
        # drops the registered submodule and its parameters.
        self.fc = None

        # NOTE(review): torchvision's weight initialisation ran inside
        # super().__init__() on the modules it built. The modules created
        # below keep PyTorch's default initialisation instead; confirm this
        # is intended.
        self.inplanes = channels[0]
        self.bn1 = nn.BatchNorm2d(self.inplanes)

        # Project-local helper. The arguments presumably mean kernel 3,
        # stride 1, padding 1 -- TODO confirm against ..modules.BasicConv2d.
        self.conv1 = BasicConv2d(in_channel, self.inplanes, 3, 1, 1)

        # torchvision's _make_layer advances self.inplanes after each stage,
        # so the stages must be built in this order.
        self.layer1 = self._make_layer(
            block, channels[0], layers[0], stride=strides[0], dilate=False)
        self.layer2 = self._make_layer(
            block, channels[1], layers[1], stride=strides[1], dilate=False)
        self.layer3 = self._make_layer(
            block, channels[2], layers[2], stride=strides[2], dilate=False)
        self.layer4 = self._make_layer(
            block, channels[3], layers[3], stride=strides[3], dilate=False)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage, or an identity stage when ``blocks`` < 1.

        This overrides ``ResNet._make_layer``, so it is also used while
        ``super().__init__()`` builds the default stages. torchvision's own
        implementation always creates at least one block, so the empty case
        is handled here.

        Returns:
            torchvision's ``nn.Sequential`` of residual blocks, or
            ``nn.Identity()`` when ``blocks`` < 1. A module (rather than a
            plain closure) keeps the stage a registered submodule. This keeps
            the model picklable (e.g. ``torch.save(model)``) and visible in
            ``print(model)``. Identity has no parameters, so state_dict keys
            are unaffected.
        """
        if blocks < 1:
            return nn.Identity()
        return super()._make_layer(block, planes, blocks, stride=stride, dilate=dilate)

    def forward(self, x):
        """Run the stem and the four stages, and return the ``layer4`` output.

        Args:
            x: Input batch. Presumably ``(N, in_channel, H, W)``, since the stem
                feeds ``nn.BatchNorm2d`` -- TODO confirm.

        Returns:
            The feature map after ``layer4``. No global pooling or fc is applied.
        """
        x = self.relu(self.bn1(self.conv1(x)))
        if self.maxpool_flag:
            x = self.maxpool(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return x