SkeletonGait
This commit is contained in:
@@ -27,6 +27,7 @@ class DeepGaitV2(BaseModel):
|
||||
in_channels = model_cfg['Backbone']['in_channels']
|
||||
layers = model_cfg['Backbone']['layers']
|
||||
channels = model_cfg['Backbone']['channels']
|
||||
self.inference_use_emb2 = model_cfg['use_emb2'] if 'use_emb2' in model_cfg else False
|
||||
|
||||
if mode == '3d':
|
||||
strides = [
|
||||
@@ -92,7 +93,11 @@ class DeepGaitV2(BaseModel):
|
||||
def forward(self, inputs):
|
||||
ipts, labs, typs, vies, seqL = inputs
|
||||
|
||||
sils = ipts[0].unsqueeze(1)
|
||||
if len(ipts[0].size()) == 4:
|
||||
sils = ipts[0].unsqueeze(1)
|
||||
else:
|
||||
sils = ipts[0]
|
||||
sils = sils.transpose(1, 2).contiguous()
|
||||
assert sils.size(-1) in [44, 88]
|
||||
|
||||
del ipts
|
||||
@@ -111,7 +116,10 @@ class DeepGaitV2(BaseModel):
|
||||
embed_1 = self.FCs(feat) # [n, c, p]
|
||||
embed_2, logits = self.BNNecks(embed_1) # [n, c, p]
|
||||
|
||||
embed = embed_1
|
||||
if self.inference_use_emb2:
|
||||
embed = embed_2
|
||||
else:
|
||||
embed = embed_1
|
||||
|
||||
retval = {
|
||||
'training_feat': {
|
||||
|
||||
Reference in New Issue
Block a user