deepgaitv2
This commit is contained in:
@@ -705,4 +705,159 @@ class ParallelBN1d(nn.Module):
|
||||
x = rearrange(x, 'n c p -> n (c p)')
|
||||
x = self.bn1d(x)
|
||||
x = rearrange(x, 'n (c p) -> n c p', p=self.parts_num)
|
||||
return x
|
||||
return x
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
class BasicBlock2D(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None):
|
||||
super(BasicBlock2D, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError(
|
||||
'BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError(
|
||||
"Dilation > 1 not supported in BasicBlock")
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = norm_layer(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = norm_layer(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
class BasicBlockP3D(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None):
|
||||
super(BasicBlockP3D, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer2d = nn.BatchNorm2d
|
||||
norm_layer3d = nn.BatchNorm3d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError(
|
||||
'BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError(
|
||||
"Dilation > 1 not supported in BasicBlock")
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
self.conv1 = SetBlockWrapper(
|
||||
nn.Sequential(
|
||||
conv3x3(inplanes, planes, stride),
|
||||
norm_layer2d(planes),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
)
|
||||
|
||||
self.conv2 = SetBlockWrapper(
|
||||
nn.Sequential(
|
||||
conv3x3(planes, planes),
|
||||
norm_layer2d(planes),
|
||||
)
|
||||
)
|
||||
|
||||
self.shortcut3d = nn.Conv3d(planes, planes, (3, 1, 1), (1, 1, 1), (1, 0, 0), bias=False)
|
||||
self.sbn = norm_layer3d(planes)
|
||||
|
||||
self.downsample = downsample
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
x: [n, c, s, h, w]
|
||||
'''
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.relu(out + self.sbn(self.shortcut3d(out)))
|
||||
out = self.conv2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
class BasicBlock3D(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=[1, 1, 1], downsample=None, groups=1,
|
||||
base_width=64, dilation=1, norm_layer=None):
|
||||
super(BasicBlock3D, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm3d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError(
|
||||
'BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError(
|
||||
"Dilation > 1 not supported in BasicBlock")
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
assert stride[0] in [1, 2, 3]
|
||||
if stride[0] in [1, 2]:
|
||||
tp = 1
|
||||
else:
|
||||
tp = 0
|
||||
self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=(3, 3, 3), stride=stride, padding=[tp, 1, 1], bias=False)
|
||||
self.bn1 = norm_layer(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = nn.Conv3d(planes, planes, kernel_size=(3, 3, 3), stride=[1, 1, 1], padding=[1, 1, 1], bias=False)
|
||||
self.bn2 = norm_layer(planes)
|
||||
self.downsample = downsample
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
x: [n, c, s, h, w]
|
||||
'''
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user