Update multigait++.py
This commit is contained in:
@@ -54,7 +54,7 @@ class MultiGaitpp(BaseModel):
|
||||
self.part2_layer3 = copy.deepcopy(self.part1_layer3)
|
||||
self.layer3 = copy.deepcopy(self.part1_layer3)
|
||||
self.layer4 = self.make_layer(BasicBlockP3D, 256 * C, stride=[1, 1], blocks_num=B[3], mode='p3d')
|
||||
self.crossattn1 = CrossAttention(64)
|
||||
self.csquare = CSquare(64)
|
||||
|
||||
|
||||
self.FCs = SeparateFCs(16, 256*C, 128*C)
|
||||
@@ -101,7 +101,7 @@ class MultiGaitpp(BaseModel):
|
||||
part2 = self.part2_layer1(part2)
|
||||
part1 = self.part1_layer0(part1)
|
||||
part1 = self.part1_layer1(part1)
|
||||
out, attn1, attn2, attn_co = self.crossattn1(part2,part1)
|
||||
out, attn1, attn2, attn_co = self.csquare(part2,part1)
|
||||
|
||||
part2 = self.part2_layer2(part2*attn1)
|
||||
part1 = self.part1_layer2(part1*attn2)
|
||||
@@ -157,9 +157,9 @@ class CatFusion(nn.Module):
|
||||
return retun
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
class CSquare(nn.Module):
|
||||
def __init__(self, in_channels=64, squeeze_ratio=16, h=32, w=22):
|
||||
super(CrossAttention, self).__init__()
|
||||
super(CSquare, self).__init__()
|
||||
hidden_dim = int(in_channels / squeeze_ratio)
|
||||
self.TP_mean = PackSequenceWrapper(torch.mean)
|
||||
self.conv2 = SetBlockWrapper(nn.Sequential(
|
||||
|
||||
Reference in New Issue
Block a user