Update multigait++.py

This commit is contained in:
Dongyang Jin
2025-02-12 14:18:34 +08:00
committed by GitHub
parent bc9b17e6c5
commit 85887ad93b
+4 -4
View File
@@ -54,7 +54,7 @@ class MultiGaitpp(BaseModel):
         self.part2_layer3 = copy.deepcopy(self.part1_layer3)
         self.layer3 = copy.deepcopy(self.part1_layer3)
         self.layer4 = self.make_layer(BasicBlockP3D, 256 * C, stride=[1, 1], blocks_num=B[3], mode='p3d')
-        self.crossattn1 = CrossAttention(64)
+        self.csquare = CSquare(64)
         self.FCs = SeparateFCs(16, 256*C, 128*C)
@@ -101,7 +101,7 @@ class MultiGaitpp(BaseModel):
         part2 = self.part2_layer1(part2)
         part1 = self.part1_layer0(part1)
         part1 = self.part1_layer1(part1)
-        out, attn1, attn2, attn_co = self.crossattn1(part2,part1)
+        out, attn1, attn2, attn_co = self.csquare(part2,part1)
         part2 = self.part2_layer2(part2*attn1)
         part1 = self.part1_layer2(part1*attn2)
@@ -157,9 +157,9 @@ class CatFusion(nn.Module):
         return retun
-class CrossAttention(nn.Module):
+class CSquare(nn.Module):
     def __init__(self, in_channels=64, squeeze_ratio=16, h=32, w=22):
-        super(CrossAttention, self).__init__()
+        super(CSquare, self).__init__()
         hidden_dim = int(in_channels / squeeze_ratio)
         self.TP_mean = PackSequenceWrapper(torch.mean)
         self.conv2 = SetBlockWrapper(nn.Sequential(