GaitSSB@Pretrain release

This commit is contained in:
jdyjjj
2023-11-20 20:28:09 +08:00
parent 476c4adbe3
commit b24e797486
6 changed files with 506 additions and 4 deletions
+16 -1
View File
@@ -690,4 +690,19 @@ class SpatialAttention(nn.Module):
batch, Nh, dv, T, V = x.size()
ret_shape = (batch, Nh * dv, T, V)
return torch.reshape(x, ret_shape)
from einops import rearrange
class ParallelBN1d(nn.Module):
def __init__(self, parts_num, in_channels, **kwargs):
super(ParallelBN1d, self).__init__()
self.parts_num = parts_num
self.bn1d = nn.BatchNorm1d(in_channels * parts_num, **kwargs)
def forward(self, x):
'''
x: [n, c, p]
'''
x = rearrange(x, 'n c p -> n (c p)')
x = self.bn1d(x)
x = rearrange(x, 'n (c p) -> n c p', p=self.parts_num)
return x