GaitSSB@Pretrain release
This commit is contained in:
@@ -690,4 +690,19 @@ class SpatialAttention(nn.Module):
|
||||
batch, Nh, dv, T, V = x.size()
|
||||
ret_shape = (batch, Nh * dv, T, V)
|
||||
return torch.reshape(x, ret_shape)
|
||||
|
||||
|
||||
from einops import rearrange
|
||||
class ParallelBN1d(nn.Module):
|
||||
def __init__(self, parts_num, in_channels, **kwargs):
|
||||
super(ParallelBN1d, self).__init__()
|
||||
self.parts_num = parts_num
|
||||
self.bn1d = nn.BatchNorm1d(in_channels * parts_num, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
x: [n, c, p]
|
||||
'''
|
||||
x = rearrange(x, 'n c p -> n (c p)')
|
||||
x = self.bn1d(x)
|
||||
x = rearrange(x, 'n (c p) -> n c p', p=self.parts_num)
|
||||
return x
|
||||
Reference in New Issue
Block a user