173 lines
7.1 KiB
Python
173 lines
7.1 KiB
Python
import torch
|
|
import copy
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from ..base_model import BaseModel
|
|
from ..modules import SeparateFCs, BasicConv2d, SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper
|
|
|
|
|
|
class GLN(BaseModel):
    """Gait Lateral Network (GLN).

    Paper: "Gait Lateral Network: Learning Discriminative and Compact
    Representations for Gait Recognition" (ECCV 2020)
    http://home.ustc.edu.cn/~saihui/papers/eccv2020_gln.pdf

    The network runs a silhouette-level branch (applied per frame) and a
    set-level branch (applied after pooling over the sequence dimension) over
    three stages, merges the stages top-down through lateral 1x1 convolutions,
    and pools each merged map with a horizontal pooling pyramid.  When
    ``Lateral_pretraining`` is disabled, an additional "compact block"
    (BN -> dropout -> ReLU -> FC -> BN -> classifier) is built on top.
    """

    def build_network(self, model_cfg):
        """Create all sub-modules from ``model_cfg``.

        Keys read from ``model_cfg``: ``in_channels`` (indexed 0..3),
        ``bin_num``, ``hidden_dim``, ``lateral_dim``, ``Lateral_pretraining``,
        ``SeparateFCs`` and, only when ``Lateral_pretraining`` is falsy,
        ``dropout`` and ``class_num``.
        """
        in_channels = model_cfg['in_channels']
        self.bin_num = model_cfg['bin_num']
        self.hidden_dim = model_cfg['hidden_dim']
        lateral_dim = model_cfg['lateral_dim']
        reduce_dim = self.hidden_dim
        self.pretrain = model_cfg['Lateral_pretraining']

        # Silhouette-level backbone: three stages of two conv + LeakyReLU each.
        self.sil_stage_0 = nn.Sequential(
            BasicConv2d(in_channels[0], in_channels[1], 5, 1, 2),
            nn.LeakyReLU(inplace=True),
            BasicConv2d(in_channels[1], in_channels[1], 3, 1, 1),
            nn.LeakyReLU(inplace=True))

        self.sil_stage_1 = nn.Sequential(
            BasicConv2d(in_channels[1], in_channels[2], 3, 1, 1),
            nn.LeakyReLU(inplace=True),
            BasicConv2d(in_channels[2], in_channels[2], 3, 1, 1),
            nn.LeakyReLU(inplace=True))

        self.sil_stage_2 = nn.Sequential(
            BasicConv2d(in_channels[2], in_channels[3], 3, 1, 1),
            nn.LeakyReLU(inplace=True),
            BasicConv2d(in_channels[3], in_channels[3], 3, 1, 1),
            nn.LeakyReLU(inplace=True))

        # Set-level stages start as independent copies of the silhouette-level
        # stages.  The copies must be taken BEFORE the silhouette stages are
        # wrapped below, so that the set stages stay plain nn.Sequential.
        self.set_stage_1 = copy.deepcopy(self.sil_stage_1)
        self.set_stage_2 = copy.deepcopy(self.sil_stage_2)

        # Max over the sequence dimension turns per-frame maps into set maps.
        self.set_pooling = PackSequenceWrapper(torch.max)

        # Spatial 2x2 pooling: wrapped for per-frame tensors, plain for set
        # tensors (which no longer carry a sequence dimension).
        self.MaxP_sil = SetBlockWrapper(nn.MaxPool2d(kernel_size=2, stride=2))
        self.MaxP_set = nn.MaxPool2d(kernel_size=2, stride=2)

        self.sil_stage_0 = SetBlockWrapper(self.sil_stage_0)
        self.sil_stage_1 = SetBlockWrapper(self.sil_stage_1)
        self.sil_stage_2 = SetBlockWrapper(self.sil_stage_2)

        # 1x1 lateral projections; inputs are two concatenated feature maps
        # per stage, hence the doubled channel counts.
        self.lateral_layer1 = nn.Conv2d(
            in_channels[1]*2, lateral_dim, kernel_size=1, stride=1, padding=0, bias=False)
        self.lateral_layer2 = nn.Conv2d(
            in_channels[2]*2, lateral_dim, kernel_size=1, stride=1, padding=0, bias=False)
        self.lateral_layer3 = nn.Conv2d(
            in_channels[3]*2, lateral_dim, kernel_size=1, stride=1, padding=0, bias=False)

        # 3x3 smoothing convolutions applied after the top-down merge.
        self.smooth_layer1 = nn.Conv2d(
            lateral_dim, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.smooth_layer2 = nn.Conv2d(
            lateral_dim, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.smooth_layer3 = nn.Conv2d(
            lateral_dim, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)

        # NOTE(review): HPP is built with its default bins while the compact
        # block below is sized from self.bin_num -- confirm the two agree.
        self.HPP = HorizontalPoolingPyramid()
        self.Head = SeparateFCs(**model_cfg['SeparateFCs'])

        if not self.pretrain:
            # Compact block.  The flattened feature concatenates three pyramid
            # levels, each contributing sum(bin_num) parts of hidden_dim.
            self.encoder_bn = nn.BatchNorm1d(sum(self.bin_num)*3*self.hidden_dim)
            self.encoder_bn.bias.requires_grad_(False)

            self.reduce_dp = nn.Dropout(p=model_cfg['dropout'])
            self.reduce_ac = nn.ReLU(inplace=True)
            self.reduce_fc = nn.Linear(sum(self.bin_num)*3*self.hidden_dim, reduce_dim, bias=False)

            self.reduce_bn = nn.BatchNorm1d(reduce_dim)
            self.reduce_bn.bias.requires_grad_(False)

            self.reduce_cls = nn.Linear(reduce_dim, model_cfg['class_num'], bias=False)

    def upsample_add(self, x, y):
        """Upsample ``x`` by a factor of 2 (nearest neighbour) and add ``y``.

        ``y`` must therefore be addable to a tensor with twice the spatial
        size of ``x``.
        """
        return F.interpolate(x, scale_factor=2, mode='nearest') + y

    def forward(self, inputs):
        """Run the network on one batch.

        Args:
            inputs: a 5-tuple ``(ipts, labs, _, _, seqL)``.  ``ipts[0]`` holds
                the silhouettes as ``[n, s, h, w]`` (a channel axis is added
                at dim 2) or as a 5-D tensor unpacked as ``n, s, _, h, w``.
                ``seqL`` is only used in training mode.

        Returns:
            dict with ``training_feat`` (``triplet`` always, plus ``softmax``
            when not in lateral pre-training), ``visual_summary`` and
            ``inference_feat``.
        """
        ipts, labs, _, _, seqL = inputs
        # Sequence lengths are only forwarded to the set pooling in training.
        seqL = None if not self.training else seqL
        sils = ipts[0]  # [n, s, h, w]
        del ipts
        if sils.dim() == 4:
            sils = sils.unsqueeze(2)
        n, s, _, h, w = sils.size()

        ### stage 0 sil ###
        sil_0_outs = self.sil_stage_0(sils)
        # torch.max returns (values, indices); [0] keeps the values.
        stage_0_sil_set = self.set_pooling(sil_0_outs, seqL, dim=1)[0]

        ### stage 1 sil ###
        sil_1_ipts = self.MaxP_sil(sil_0_outs)
        sil_1_outs = self.sil_stage_1(sil_1_ipts)

        ### stage 2 sil ###
        sil_2_ipts = self.MaxP_sil(sil_1_outs)
        sil_2_outs = self.sil_stage_2(sil_2_ipts)

        ### stage 1 set ###
        set_1_ipts = self.set_pooling(sil_1_ipts, seqL, dim=1)[0]
        stage_1_sil_set = self.set_pooling(sil_1_outs, seqL, dim=1)[0]
        set_1_outs = self.set_stage_1(set_1_ipts) + stage_1_sil_set

        ### stage 2 set ###
        set_2_ipts = self.MaxP_set(set_1_outs)
        stage_2_sil_set = self.set_pooling(sil_2_outs, seqL, dim=1)[0]
        set_2_outs = self.set_stage_2(set_2_ipts) + stage_2_sil_set

        # Stage 0 has no separate set-level branch here, so its pooled
        # silhouette feature is concatenated with itself to provide the
        # in_channels[1]*2 channels that lateral_layer1 expects.
        set1 = torch.cat((stage_0_sil_set, stage_0_sil_set), dim=1)
        set2 = torch.cat((stage_1_sil_set, set_1_outs), dim=1)
        set3 = torch.cat((stage_2_sil_set, set_2_outs), dim=1)

        # lateral: top-down merge from the deepest (smallest) map upwards
        set3 = self.lateral_layer3(set3)
        set2 = self.upsample_add(set3, self.lateral_layer2(set2))
        set1 = self.upsample_add(set2, self.lateral_layer1(set1))

        set3 = self.smooth_layer3(set3)
        set2 = self.smooth_layer2(set2)
        set1 = self.smooth_layer1(set1)

        set1 = self.HPP(set1)
        set2 = self.HPP(set2)
        set3 = self.HPP(set3)

        # Concatenate the three pyramid outputs along the last axis and move
        # that axis to the front for the Head.
        feature = torch.cat([set1, set2, set3], -1).permute(2, 0, 1).contiguous()

        feature = self.Head(feature)
        feature = feature.permute(1, 0, 2).contiguous()  # n p c

        # compact block
        if not self.pretrain:
            bn_feature = self.encoder_bn(feature.view(n, -1))
            bn_feature = bn_feature.view(*feature.shape).contiguous()

            reduce_feature = self.reduce_dp(bn_feature)
            reduce_feature = self.reduce_ac(reduce_feature)
            reduce_feature = self.reduce_fc(reduce_feature.view(n, -1))

            bn_reduce_feature = self.reduce_bn(reduce_feature)
            # [n, 1, class_num]
            logits = self.reduce_cls(bn_reduce_feature).unsqueeze(1)

        retval = {
            'training_feat': {},
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                # The pre-compact feature is what gets exported; the compact
                # features (reduce_feature / bn_reduce_feature) are only used
                # to produce the softmax logits.
                'embeddings': feature
            }
        }
        # The triplet loss is always fed the pre-compact feature; the softmax
        # branch only exists once the compact block has been built.
        retval['training_feat']['triplet'] = {'embeddings': feature, 'labels': labs}
        if not self.pretrain:
            retval['training_feat']['softmax'] = {'logits': logits, 'labels': labs}
        return retval
|