102 lines
3.5 KiB
Python
102 lines
3.5 KiB
Python
'''
|
|
Modified from https://github.com/Gait3D/Gait3D-Benchmark/blob/72beab994c137b902d826f4b9f9e95b107bebd78/lib/modeling/models/smplgait.py
|
|
'''
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.autograd import Variable
|
|
|
|
from ..base_model import BaseModel
|
|
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks
|
|
|
|
|
|
class SMPLGait(BaseModel):
    """Silhouette gait model whose features are transformed by an SMPL-regressed matrix.

    Per frame, an MLP maps the SMPL vector to a ``16 x 16`` matrix (plus the
    identity). Each silhouette feature map from the backbone, zero-padded to a
    square, is right-multiplied by its frame's matrix. The result is then
    temporally pooled, horizontally pooled and projected by SeparateFCs /
    SeparateBNNecks.
    """

    # Length of the per-frame SMPL vector consumed by ``fc1``.
    _SMPL_DIM = 85
    # Side of the square matrix regressed from each SMPL vector; ``fc3`` emits
    # ``_SMPL_SIDE ** 2`` features that are reshaped into that matrix.
    _SMPL_SIDE = 16

    def __init__(self, cfgs, is_training):
        super().__init__(cfgs, is_training)

    def build_network(self, model_cfg):
        """Create the silhouette branch and the SMPL branch.

        Args:
            model_cfg: config mapping read with the keys ``'backbone_cfg'``,
                ``'SeparateFCs'``, ``'SeparateBNNecks'`` and ``'bin_num'``.
        """
        # Baseline (silhouette) branch.
        self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
        self.Backbone = SetBlockWrapper(self.Backbone)
        self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
        self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])

        # SMPL branch: per-frame MLP producing a flattened side x side matrix.
        # Modules are created in the same order as before so parameter
        # initialisation and state_dict keys are unchanged.
        side_sq = self._SMPL_SIDE * self._SMPL_SIDE
        self.fc1 = nn.Linear(self._SMPL_DIM, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, side_sq)
        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(side_sq)
        self.dropout2 = nn.Dropout(p=0.2)
        self.dropout3 = nn.Dropout(p=0.2)

    def forward(self, inputs):
        """Run one batch through both branches.

        Args:
            inputs: tuple ``(ipts, labs, _, _, seqL)``. ``ipts[0]`` holds the
                silhouettes, ``[n, s, h, w]`` (or already ``[n, 1, s, h, w]``).
                ``ipts[1]`` holds the per-frame SMPL vectors, ``[n, s, d]``.
                ``seqL`` is forwarded to temporal pooling.

        Returns:
            dict with ``'training_feat'`` (triplet embeddings and softmax logits,
            each paired with ``labs``), ``'visual_summary'`` (silhouettes as
            ``[n*s, 1, h, w]``) and ``'inference_feat'`` (the triplet embeddings).

        Raises:
            ValueError: if the backbone output does not keep the SMPL sequence
                length, is not ``_SMPL_SIDE`` rows tall, or is wider than tall.
                Each feature map must be padded to a square and multiplied by
                the ``_SMPL_SIDE x _SMPL_SIDE`` SMPL matrix.
        """
        ipts, labs, _, _, seqL = inputs

        sils = ipts[0]   # [n, s, h, w]
        smpls = ipts[1]  # [n, s, d]

        # --- SMPL branch: one side x side matrix per frame -----------------
        n, s, d = smpls.size()
        side = self._SMPL_SIDE
        sps = smpls.reshape(-1, d)  # [n*s, d]
        sps = F.relu(self.bn1(self.fc1(sps)))                  # [n*s, 128]
        sps = F.relu(self.bn2(self.dropout2(self.fc2(sps))))   # [n*s, 256]
        sps = F.relu(self.bn3(self.dropout3(self.fc3(sps))))   # [n*s, side*side]
        sps = sps.reshape(n, 1, s, side, side)
        # Add the identity. With it, a zero SMPL output maps features to
        # themselves. Broadcasting replaces an explicit repeat. Creating eye()
        # on sps' device/dtype avoids a CPU->GPU copy per step, a wrong-device
        # error when sps is not on the current CUDA device, and dtype
        # promotion under mixed precision.
        sps_trans = sps + torch.eye(side, device=sps.device, dtype=sps.dtype)

        # --- Silhouette branch ----------------------------------------------
        if sils.dim() == 4:
            sils = sils.unsqueeze(1)  # [n, 1, s, h, w]

        outs = self.Backbone(sils)  # [n, c, s, h, w]
        _, _, outs_s, outs_h, outs_w = outs.size()
        if outs_s != s or outs_h != side or outs_w > outs_h:
            raise ValueError(
                f"SMPLGait expects backbone features of shape "
                f"[n, c, {s}, {side}, w<={side}], got {tuple(outs.size())}"
            )

        # Zero-pad the width on the right up to the height so every feature
        # map is square: [n, c, s, h, h]. Pads inherit outs' device/dtype.
        outs = F.pad(outs, (0, outs_h - outs_w))

        # Right-multiply each [h, h] map by its frame's SMPL matrix. The
        # singleton channel axis of sps_trans broadcasts over c, which matches
        # repeating it per channel and calling bmm.
        outs_trans = torch.matmul(outs, sps_trans)  # [n, c, s, h, h]

        # Temporal Pooling, TP: PackSequenceWrapper(torch.max) over dim 2;
        # [0] keeps the values part of the result.
        outs_trans = self.TP(outs_trans, seqL, options={"dim": 2})[0]  # [n, c, h, w]
        # Horizontal Pooling Matching, HPM
        feat = self.HPP(outs_trans)  # [n, c, p]
        embed_1 = self.FCs(feat)     # [n, c, p]

        # Only the logits of the BNNeck are used; its embedding output is not.
        _, logits = self.BNNecks(embed_1)

        n, _, s, h, w = sils.size()
        retval = {
            'training_feat': {
                'triplet': {'embeddings': embed_1, 'labels': labs},
                'softmax': {'logits': logits, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': sils.reshape(n * s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': embed_1
            }
        }
        return retval
|