Update deepgaitv2.py

This commit is contained in:
Chao Fan
2024-03-19 14:53:22 +08:00
committed by GitHub
parent 45db009009
commit 6dfe5681ad
+5
View File
@@ -61,6 +61,7 @@ class DeepGaitV2(BaseModel):
self.layer3 = SetBlockWrapper(self.layer3)
self.layer4 = SetBlockWrapper(self.layer4)
self.mode = mode
self.FCs = SeparateFCs(16, channels[3], channels[2])
self.BNNecks = SeparateBNNecks(16, channels[2], class_num=model_cfg['SeparateBNNecks']['class_num'])
@@ -93,6 +94,10 @@ class DeepGaitV2(BaseModel):
def forward(self, inputs):
ipts, labs, typs, vies, seqL = inputs
if not self.training and len(labs) != 1 and self.mode != '2d':
raise ValueError(
'The input size of each GPU must be 1 in testing mode, but got {}!'.format(len(labs)))
if len(ipts[0].size()) == 4:
sils = ipts[0].unsqueeze(1)
else: