OpenGait release(pre-beta version).
This commit is contained in:
@@ -0,0 +1,87 @@
|
||||
import torch
|
||||
import copy
|
||||
import torch.nn as nn
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from ..modules import SeparateFCs, BasicConv2d, SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper
|
||||
|
||||
|
||||
class GaitSet(BaseModel):
|
||||
"""
|
||||
GaitSet: Regarding Gait as a Set for Cross-View Gait Recognition
|
||||
Arxiv: https://arxiv.org/abs/1811.06186
|
||||
Github: https://github.com/AbnerHqC/GaitSet
|
||||
"""
|
||||
|
||||
def build_network(self, model_cfg):
|
||||
in_c = model_cfg['in_channels']
|
||||
self.set_block1 = nn.Sequential(BasicConv2d(in_c[0], in_c[1], 5, 1, 2),
|
||||
nn.LeakyReLU(inplace=True),
|
||||
BasicConv2d(in_c[1], in_c[1], 3, 1, 1),
|
||||
nn.LeakyReLU(inplace=True),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2))
|
||||
|
||||
self.set_block2 = nn.Sequential(BasicConv2d(in_c[1], in_c[2], 3, 1, 1),
|
||||
nn.LeakyReLU(inplace=True),
|
||||
BasicConv2d(in_c[2], in_c[2], 3, 1, 1),
|
||||
nn.LeakyReLU(inplace=True),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2))
|
||||
|
||||
self.set_block3 = nn.Sequential(BasicConv2d(in_c[2], in_c[3], 3, 1, 1),
|
||||
nn.LeakyReLU(inplace=True),
|
||||
BasicConv2d(in_c[3], in_c[3], 3, 1, 1),
|
||||
nn.LeakyReLU(inplace=True))
|
||||
|
||||
self.gl_block2 = copy.deepcopy(self.set_block2)
|
||||
self.gl_block3 = copy.deepcopy(self.set_block3)
|
||||
|
||||
self.set_block1 = SetBlockWrapper(self.set_block1)
|
||||
self.set_block2 = SetBlockWrapper(self.set_block2)
|
||||
self.set_block3 = SetBlockWrapper(self.set_block3)
|
||||
|
||||
self.set_pooling = PackSequenceWrapper(torch.max)
|
||||
|
||||
self.Head = SeparateFCs(**model_cfg['SeparateFCs'])
|
||||
|
||||
self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])
|
||||
|
||||
def forward(self, inputs):
|
||||
ipts, labs, _, _, seqL = inputs
|
||||
sils = ipts[0] # [n, s, h, w]
|
||||
if len(sils.size()) == 4:
|
||||
sils = sils.unsqueeze(2)
|
||||
|
||||
del ipts
|
||||
outs = self.set_block1(sils)
|
||||
gl = self.set_pooling(outs, seqL, dim=1)[0]
|
||||
gl = self.gl_block2(gl)
|
||||
|
||||
outs = self.set_block2(outs)
|
||||
gl = gl + self.set_pooling(outs, seqL, dim=1)[0]
|
||||
gl = self.gl_block3(gl)
|
||||
|
||||
outs = self.set_block3(outs)
|
||||
outs = self.set_pooling(outs, seqL, dim=1)[0]
|
||||
gl = gl + outs
|
||||
|
||||
# Horizontal Pooling Matching, HPM
|
||||
feature1 = self.HPP(outs) # [n, c, p]
|
||||
feature2 = self.HPP(gl) # [n, c, p]
|
||||
feature = torch.cat([feature1, feature2], -1) # [n, c, p]
|
||||
feature = feature.permute(2, 0, 1).contiguous() # [p, n, c]
|
||||
embs = self.Head(feature)
|
||||
embs = embs.permute(1, 0, 2).contiguous() # [n, p, c]
|
||||
|
||||
n, s, _, h, w = sils.size()
|
||||
retval = {
|
||||
'training_feat': {
|
||||
'triplet': {'embeddings': embs, 'labels': labs}
|
||||
},
|
||||
'visual_summary': {
|
||||
'image/sils': sils.view(n*s, 1, h, w)
|
||||
},
|
||||
'inference_feat': {
|
||||
'embeddings': embs
|
||||
}
|
||||
}
|
||||
return retval
|
||||
Reference in New Issue
Block a user