OpenGait release (pre-beta version).
This commit is contained in:
@@ -0,0 +1,17 @@
|
||||
"""Auto-import every class defined in this package's modules.

Walks the sibling ``*.py`` modules of this package, imports each one, and
re-exports every class it finds into this package's namespace, so that
``package.ClassName`` works without listing the classes by hand.
"""
from inspect import isclass

from pkgutil import iter_modules

from pathlib import Path

from importlib import import_module

# iterate through the modules in the current package
package_dir = Path(__file__).resolve().parent
# pkgutil.iter_modules expects plain string path entries, so convert the
# Path explicitly instead of passing the Path object itself.
for (_, module_name, _) in iter_modules([str(package_dir)]):

    # import the module and iterate through its attributes
    module = import_module(f"{__name__}.{module_name}")
    for attribute_name in dir(module):
        attribute = getattr(module, attribute_name)

        if isclass(attribute):
            # Add the class to this package's variables
            globals()[attribute_name] = attribute
|
||||
@@ -0,0 +1,56 @@
|
||||
import torch
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks
|
||||
|
||||
|
||||
class Baseline(BaseModel):
    """A minimal silhouette-based gait model.

    Pipeline: backbone -> temporal max pooling -> horizontal pyramid
    pooling -> part-wise FCs -> BNNeck (embedding + logits).
    """

    def __init__(self, cfgs, is_training):
        super().__init__(cfgs, is_training)

    def build_network(self, model_cfg):
        """Instantiate the backbone and the pooling/embedding heads."""
        backbone = self.get_backbone(model_cfg)
        self.Backbone = SetBlockWrapper(backbone)
        self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
        self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])

    def forward(self, inputs):
        """Return training/visual/inference feature dicts for one batch."""
        seq_inputs, labels, _, _, seq_lens = inputs

        silhouettes = seq_inputs[0]
        if silhouettes.dim() == 4:
            # Add a channel dimension: [n, s, h, w] -> [n, s, 1, h, w]
            silhouettes = silhouettes.unsqueeze(2)

        del seq_inputs
        frame_feats = self.Backbone(silhouettes)  # [n, s, c, h, w]

        # Temporal Pooling, TP
        set_feats = self.TP(frame_feats, seq_lens, dim=1)[0]  # [n, c, h, w]
        # Horizontal Pooling Matching, HPM
        part_feats = self.HPP(set_feats)  # [n, c, p]
        part_feats = part_feats.permute(2, 0, 1).contiguous()  # [p, n, c]

        embed_1 = self.FCs(part_feats)  # [p, n, c]
        embed_2, logits = self.BNNecks(embed_1)  # [p, n, c]

        embed_1 = embed_1.permute(1, 0, 2).contiguous()  # [n, p, c]
        embed_2 = embed_2.permute(1, 0, 2).contiguous()  # [n, p, c]
        logits = logits.permute(1, 0, 2).contiguous()  # [n, p, c]
        embed = embed_1

        n, s, _, h, w = silhouettes.size()
        return {
            'training_feat': {
                'triplet': {'embeddings': embed_1, 'labels': labels},
                'softmax': {'logits': logits, 'labels': labels}
            },
            'visual_summary': {
                'image/sils': silhouettes.view(n * s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': embed
            }
        }
|
||||
@@ -0,0 +1,151 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from ..modules import SeparateFCs, BasicConv3d, PackSequenceWrapper
|
||||
|
||||
|
||||
class GLConv(nn.Module):
    """Global-and-Local 3D convolution block used by GaitGL.

    Runs one 3D conv over the whole feature map and a second one over
    height-wise strips, then fuses the two responses either by addition
    or by height-axis concatenation.
    """

    def __init__(self, in_channels, out_channels, halving, fm_sign=False, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False, **kwargs):
        super(GLConv, self).__init__()
        # Height axis is split into 2**halving strips for the local branch.
        self.halving = halving
        # True: concat global/local along height; False: add their activations.
        self.fm_sign = fm_sign
        self.global_conv3d = BasicConv3d(
            in_channels, out_channels, kernel_size, stride, padding, bias, **kwargs)
        self.local_conv3d = BasicConv3d(
            in_channels, out_channels, kernel_size, stride, padding, bias, **kwargs)

    def forward(self, x):
        """x: [n, c, s, h, w] -> fused global/local feature map."""
        global_feat = self.global_conv3d(x)

        if self.halving == 0:
            local_feat = self.local_conv3d(x)
        else:
            strip_h = int(x.size(3) // 2 ** self.halving)
            strips = x.split(strip_h, 3)
            local_feat = torch.cat(
                [self.local_conv3d(strip) for strip in strips], 3)

        if self.fm_sign:
            return F.leaky_relu(torch.cat([global_feat, local_feat], dim=3))
        return F.leaky_relu(global_feat) + F.leaky_relu(local_feat)
|
||||
|
||||
|
||||
class GeMHPP(nn.Module):
    """Horizontal Pyramid Pooling with Generalized-Mean (GeM) pooling.

    Splits the feature map height-wise into `bin_num` bins and GeM-pools
    each bin; the GeM exponent ``p`` is a learnable parameter.
    """

    def __init__(self, bin_num=[64], p=6.5, eps=1.0e-6):
        super(GeMHPP, self).__init__()
        self.bin_num = bin_num
        # Learnable GeM exponent, initialised at p.
        self.p = nn.Parameter(torch.ones(1) * p)
        # Clamp floor keeping pow() away from zero/negative inputs.
        self.eps = eps

    def gem(self, ipts):
        """Generalized-mean pool over the last dimension."""
        clamped = ipts.clamp(min=self.eps).pow(self.p)
        pooled = F.avg_pool2d(clamped, (1, ipts.size(-1)))
        return pooled.pow(1. / self.p)

    def forward(self, x):
        """
        x : [n, c, h, w]
        ret: [n, c, p]
        """
        n, c = x.size()[:2]
        pooled_bins = [
            self.gem(x.view(n, c, b, -1)).squeeze(-1)
            for b in self.bin_num
        ]
        return torch.cat(pooled_bins, -1)
|
||||
|
||||
|
||||
class GaitGL(BaseModel):
    """
    GaitGL: Gait Recognition via Effective Global-Local Feature Representation and Local Temporal Aggregation
    Arxiv : https://arxiv.org/pdf/2011.01461.pdf
    """

    def __init__(self, *args, **kargs):
        super(GaitGL, self).__init__(*args, **kargs)

    def build_network(self, model_cfg):
        # Per-stage channel widths from the config; class_num sizes the
        # classifier head (Head1).
        in_c = model_cfg['channels']
        class_num = model_cfg['class_num']

        # For CASIA-B
        # Stage 0: lift the single-channel silhouette volume to in_c[0].
        self.conv3d = nn.Sequential(
            BasicConv3d(1, in_c[0], kernel_size=(3, 3, 3),
                        stride=(1, 1, 1), padding=(1, 1, 1)),
            nn.LeakyReLU(inplace=True)
        )
        # Local Temporal Aggregation: temporal stride-3 conv collapses each
        # group of 3 neighbouring frames into one.
        self.LTA = nn.Sequential(
            BasicConv3d(in_c[0], in_c[0], kernel_size=(
                3, 1, 1), stride=(3, 1, 1), padding=(0, 0, 0)),
            nn.LeakyReLU(inplace=True)
        )

        # Global-local conv stages; halving=3 means the local branch runs on
        # 2**3 height strips. The last stage concatenates instead of adding
        # (fm_sign=True), doubling the height of the feature map.
        self.GLConvA0 = GLConv(in_c[0], in_c[1], halving=3, fm_sign=False, kernel_size=(
            3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.MaxPool0 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))

        self.GLConvA1 = GLConv(in_c[1], in_c[2], halving=3, fm_sign=False, kernel_size=(
            3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        self.GLConvB2 = GLConv(in_c[2], in_c[2], halving=3, fm_sign=True, kernel_size=(
            3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))

        # Part-wise embedding (64 horizontal parts), BN neck, and classifier.
        self.Head0 = SeparateFCs(64, in_c[2], in_c[2])
        self.Bn = nn.BatchNorm1d(in_c[2])
        self.Head1 = SeparateFCs(64, in_c[2], class_num)

        # Temporal max pooling over the (optionally packed) sequence axis.
        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = GeMHPP()

    def forward(self, inputs):
        ipts, labs, _, _, seqL = inputs
        # Sequence lengths only matter for packed training batches; at eval
        # time the whole sequence is pooled as a single segment.
        seqL = None if not self.training else seqL

        sils = ipts[0].unsqueeze(1)  # [n, s, h, w] -> [n, 1, s, h, w]
        del ipts
        n, _, s, h, w = sils.size()
        if s < 3:
            # The 3x3x3 convolutions need at least 3 frames; tile short clips.
            repeat = 3 if s == 1 else 2
            sils = sils.repeat(1, 1, repeat, 1, 1)

        outs = self.conv3d(sils)
        outs = self.LTA(outs)

        outs = self.GLConvA0(outs)
        outs = self.MaxPool0(outs)

        outs = self.GLConvA1(outs)
        outs = self.GLConvB2(outs)  # [n, c, s, h, w]

        outs = self.TP(outs, dim=2, seq_dim=2, seqL=seqL)[0]  # [n, c, h, w]
        outs = self.HPP(outs)  # [n, c, p]
        outs = outs.permute(2, 0, 1).contiguous()  # [p, n, c]

        gait = self.Head0(outs)  # [p, n, c]
        gait = gait.permute(1, 2, 0).contiguous()  # [n, c, p]
        # BatchNorm1d treats the part axis as the "length" dimension here.
        bnft = self.Bn(gait)  # [n, c, p]
        logi = self.Head1(bnft.permute(2, 0, 1).contiguous())  # [p, n, c]

        gait = gait.permute(0, 2, 1).contiguous()  # [n, p, c]
        bnft = bnft.permute(0, 2, 1).contiguous()  # [n, p, c]
        logi = logi.permute(1, 0, 2).contiguous()  # [n, p, c]

        # Re-read the shape: s may have changed by the short-clip repeat above.
        n, _, s, h, w = sils.size()
        retval = {
            'training_feat': {
                # NOTE(review): unlike Baseline, the triplet loss here is fed
                # the post-BN feature (bnft), not the raw embedding — confirm
                # this matches the paper's intent.
                'triplet': {'embeddings': bnft, 'labels': labs},
                'softmax': {'logits': logi, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': bnft
            }
        }
        return retval
|
||||
@@ -0,0 +1,127 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..base_model import BaseModel
|
||||
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs
|
||||
from utils import clones
|
||||
|
||||
|
||||
class BasicConv1d(nn.Module):
    """A bias-free 1-D convolution; thin wrapper over ``nn.Conv1d``."""

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super(BasicConv1d, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels,
                              kernel_size, bias=False, **kwargs)

    def forward(self, x):
        return self.conv(x)
|
||||
|
||||
|
||||
class TemporalFeatureAggregator(nn.Module):
    """Per-part temporal aggregation (the MTB modules from GaitPart).

    For each of ``parts_num`` horizontal parts, two Micro-motion Template
    Builders (window sizes 3 and 5) produce sigmoid attention scores that
    re-weight smoothed temporal features; the two templates are summed and
    max-pooled over time.
    """

    def __init__(self, in_channels, squeeze=4, parts_num=16):
        super(TemporalFeatureAggregator, self).__init__()
        hidden_dim = int(in_channels // squeeze)  # bottleneck width
        self.parts_num = parts_num

        # MTB1: one independent conv stack per part (clones = deep copies).
        conv3x1 = nn.Sequential(
            BasicConv1d(in_channels, hidden_dim, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            BasicConv1d(hidden_dim, in_channels, 1))
        self.conv1d3x1 = clones(conv3x1, parts_num)
        self.avg_pool3x1 = nn.AvgPool1d(3, stride=1, padding=1)
        self.max_pool3x1 = nn.MaxPool1d(3, stride=1, padding=1)

        # MTB2: same structure with a wider temporal window.
        conv3x3 = nn.Sequential(
            BasicConv1d(in_channels, hidden_dim, 3, padding=1),
            nn.LeakyReLU(inplace=True),
            BasicConv1d(hidden_dim, in_channels, 3, padding=1))
        self.conv1d3x3 = clones(conv3x3, parts_num)
        self.avg_pool3x3 = nn.AvgPool1d(5, stride=1, padding=2)
        self.max_pool3x3 = nn.MaxPool1d(5, stride=1, padding=2)

        # Temporal Pooling, TP
        self.TP = torch.max

    def forward(self, x):
        """
        Input: x, [n, s, c, p]
        Output: ret, [n, p, c]
        """
        n, s, c, p = x.size()
        x = x.permute(3, 0, 2, 1).contiguous()  # [p, n, c, s]
        feature = x.split(1, 0)  # [[1, n, c, s], ...] one slice per part
        # Fold parts into the batch so the shared pools run in one call.
        x = x.view(-1, c, s)

        # MTB1: ConvNet1d & Sigmoid — per-part attention over time
        logits3x1 = torch.cat([conv(_.squeeze(0)).unsqueeze(0)
                               for conv, _ in zip(self.conv1d3x1, feature)], 0)
        scores3x1 = torch.sigmoid(logits3x1)
        # MTB1: Template Function — avg+max smoothing, then re-weighting
        feature3x1 = self.avg_pool3x1(x) + self.max_pool3x1(x)
        feature3x1 = feature3x1.view(p, n, c, s)
        feature3x1 = feature3x1 * scores3x1

        # MTB2: ConvNet1d & Sigmoid
        logits3x3 = torch.cat([conv(_.squeeze(0)).unsqueeze(0)
                               for conv, _ in zip(self.conv1d3x3, feature)], 0)
        scores3x3 = torch.sigmoid(logits3x3)
        # MTB2: Template Function
        feature3x3 = self.avg_pool3x3(x) + self.max_pool3x3(x)
        feature3x3 = feature3x3.view(p, n, c, s)
        feature3x3 = feature3x3 * scores3x3

        # Temporal Pooling: max over time of the fused templates.
        ret = self.TP(feature3x1 + feature3x3, dim=-1)[0]  # [p, n, c]
        ret = ret.permute(1, 0, 2).contiguous()  # [n, p, c]
        return ret
|
||||
|
||||
|
||||
class GaitPart(BaseModel):
    """
    GaitPart: Temporal Part-based Model for Gait Recognition
    Paper: https://openaccess.thecvf.com/content_CVPR_2020/papers/Fan_GaitPart_Temporal_Part-Based_Model_for_Gait_Recognition_CVPR_2020_paper.pdf
    Github: https://github.com/ChaoFan96/GaitPart
    """
    # NOTE: the docstring above was previously a discarded string expression
    # inside __init__; as a class docstring it is now exposed via __doc__.

    def __init__(self, *args, **kargs):
        super(GaitPart, self).__init__(*args, **kargs)

    def build_network(self, model_cfg):
        """Build backbone, part pooling, temporal aggregator and head."""
        self.Backbone = self.get_backbone(model_cfg)
        head_cfg = model_cfg['SeparateFCs']
        # Reuse the already-fetched head config instead of a second lookup.
        self.Head = SeparateFCs(**head_cfg)
        self.Backbone = SetBlockWrapper(self.Backbone)
        self.HPP = SetBlockWrapper(
            HorizontalPoolingPyramid(bin_num=model_cfg['bin_num']))
        self.TFA = PackSequenceWrapper(TemporalFeatureAggregator(
            in_channels=head_cfg['in_channels'], parts_num=head_cfg['parts_num']))

    def forward(self, inputs):
        """Return training/visual/inference feature dicts for one batch."""
        ipts, labs, _, _, seqL = inputs

        sils = ipts[0]  # [n, s, h, w] or [n, s, c, h, w]
        if len(sils.size()) == 4:
            sils = sils.unsqueeze(2)  # add channel dim -> [n, s, 1, h, w]

        del ipts
        out = self.Backbone(sils)  # [n, s, c, h, w]
        out = self.HPP(out)  # [n, s, c, p]
        out = self.TFA(out, seqL)  # [n, p, c]

        embs = self.Head(out.permute(1, 0, 2).contiguous())  # [p, n, c]
        embs = embs.permute(1, 0, 2).contiguous()  # [n, p, c]

        n, s, _, h, w = sils.size()
        retval = {
            'training_feat': {
                'triplet': {'embeddings': embs, 'labels': labs}
            },
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': embs
            }
        }
        return retval
|
||||
@@ -0,0 +1,87 @@
|
||||
import torch
|
||||
import copy
|
||||
import torch.nn as nn
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from ..modules import SeparateFCs, BasicConv2d, SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper
|
||||
|
||||
|
||||
class GaitSet(BaseModel):
    """
    GaitSet: Regarding Gait as a Set for Cross-View Gait Recognition
    Arxiv: https://arxiv.org/abs/1811.06186
    Github: https://github.com/AbnerHqC/GaitSet
    """

    def build_network(self, model_cfg):
        """Three frame-level conv stages plus a parallel set-level path."""
        channels = model_cfg['in_channels']

        self.set_block1 = nn.Sequential(
            BasicConv2d(channels[0], channels[1], 5, 1, 2),
            nn.LeakyReLU(inplace=True),
            BasicConv2d(channels[1], channels[1], 3, 1, 1),
            nn.LeakyReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2))

        self.set_block2 = nn.Sequential(
            BasicConv2d(channels[1], channels[2], 3, 1, 1),
            nn.LeakyReLU(inplace=True),
            BasicConv2d(channels[2], channels[2], 3, 1, 1),
            nn.LeakyReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2))

        self.set_block3 = nn.Sequential(
            BasicConv2d(channels[2], channels[3], 3, 1, 1),
            nn.LeakyReLU(inplace=True),
            BasicConv2d(channels[3], channels[3], 3, 1, 1),
            nn.LeakyReLU(inplace=True))

        # The global (set-level) path reuses copies of stages 2 and 3.
        self.gl_block2 = copy.deepcopy(self.set_block2)
        self.gl_block3 = copy.deepcopy(self.set_block3)

        self.set_block1 = SetBlockWrapper(self.set_block1)
        self.set_block2 = SetBlockWrapper(self.set_block2)
        self.set_block3 = SetBlockWrapper(self.set_block3)

        self.set_pooling = PackSequenceWrapper(torch.max)

        self.Head = SeparateFCs(**model_cfg['SeparateFCs'])

        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])

    def forward(self, inputs):
        """Return training/visual/inference feature dicts for one batch."""
        ipts, labels, _, _, seq_lens = inputs
        sils = ipts[0]  # [n, s, h, w]
        if sils.dim() == 4:
            sils = sils.unsqueeze(2)
        del ipts

        # Frame path and set path are interleaved: after each stage the
        # frame features are pooled over the sequence and merged into `gl`.
        frame_feats = self.set_block1(sils)
        gl = self.set_pooling(frame_feats, seq_lens, dim=1)[0]
        gl = self.gl_block2(gl)

        frame_feats = self.set_block2(frame_feats)
        gl = gl + self.set_pooling(frame_feats, seq_lens, dim=1)[0]
        gl = self.gl_block3(gl)

        frame_feats = self.set_block3(frame_feats)
        set_feat = self.set_pooling(frame_feats, seq_lens, dim=1)[0]
        gl = gl + set_feat

        # Horizontal Pooling Matching, HPM
        feature = torch.cat(
            [self.HPP(set_feat), self.HPP(gl)], -1)  # [n, c, p]
        feature = feature.permute(2, 0, 1).contiguous()  # [p, n, c]
        embs = self.Head(feature)
        embs = embs.permute(1, 0, 2).contiguous()  # [n, p, c]

        n, s, _, h, w = sils.size()
        return {
            'training_feat': {
                'triplet': {'embeddings': embs, 'labels': labels}
            },
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': embs
            }
        }
|
||||
@@ -0,0 +1,172 @@
|
||||
import torch
|
||||
import copy
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from ..modules import SeparateFCs, BasicConv2d, SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper
|
||||
|
||||
|
||||
class GLN(BaseModel):
    """
    http://home.ustc.edu.cn/~saihui/papers/eccv2020_gln.pdf
    Gait Lateral Network: Learning Discriminative and Compact Representations for Gait Recognition
    """

    def build_network(self, model_cfg):
        # Per-stage channel widths; the lateral layers map every stage onto
        # a common lateral_dim (FPN-style).
        in_channels = model_cfg['in_channels']
        self.bin_num = model_cfg['bin_num']
        self.hidden_dim = model_cfg['hidden_dim']
        lateral_dim = model_cfg['lateral_dim']
        reduce_dim = self.hidden_dim
        # Two-phase training flag: when True only the lateral pyramid is
        # trained; the compact block below is built only when False.
        self.pretrain = model_cfg['Lateral_pretraining']

        # Frame-level (silhouette) conv stages 0..2.
        self.sil_stage_0 = nn.Sequential(BasicConv2d(in_channels[0], in_channels[1], 5, 1, 2),
                                         nn.LeakyReLU(inplace=True),
                                         BasicConv2d(
                                             in_channels[1], in_channels[1], 3, 1, 1),
                                         nn.LeakyReLU(inplace=True))

        self.sil_stage_1 = nn.Sequential(BasicConv2d(in_channels[1], in_channels[2], 3, 1, 1),
                                         nn.LeakyReLU(inplace=True),
                                         BasicConv2d(
                                             in_channels[2], in_channels[2], 3, 1, 1),
                                         nn.LeakyReLU(inplace=True))

        self.sil_stage_2 = nn.Sequential(BasicConv2d(in_channels[2], in_channels[3], 3, 1, 1),
                                         nn.LeakyReLU(inplace=True),
                                         BasicConv2d(
                                             in_channels[3], in_channels[3], 3, 1, 1),
                                         nn.LeakyReLU(inplace=True))

        # Set-level branches reuse the architecture of stages 1 and 2.
        self.set_stage_1 = copy.deepcopy(self.sil_stage_1)
        self.set_stage_2 = copy.deepcopy(self.sil_stage_2)

        self.set_pooling = PackSequenceWrapper(torch.max)

        self.MaxP_sil = SetBlockWrapper(nn.MaxPool2d(kernel_size=2, stride=2))
        self.MaxP_set = nn.MaxPool2d(kernel_size=2, stride=2)

        self.sil_stage_0 = SetBlockWrapper(self.sil_stage_0)
        self.sil_stage_1 = SetBlockWrapper(self.sil_stage_1)
        self.sil_stage_2 = SetBlockWrapper(self.sil_stage_2)

        # Lateral convs take the (sil ++ set) concatenation — hence the *2
        # input channels — and map it to lateral_dim.
        self.lateral_layer1 = nn.Conv2d(
            in_channels[1]*2, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.lateral_layer2 = nn.Conv2d(
            in_channels[2]*2, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.lateral_layer3 = nn.Conv2d(
            in_channels[3]*2, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)

        self.smooth_layer1 = nn.Conv2d(
            lateral_dim, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.smooth_layer2 = nn.Conv2d(
            lateral_dim, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.smooth_layer3 = nn.Conv2d(
            lateral_dim, lateral_dim, kernel_size=3, stride=1, padding=1, bias=False)

        self.HPP = HorizontalPoolingPyramid()
        self.Head = SeparateFCs(**model_cfg['SeparateFCs'])

        if not self.pretrain:
            # Compact block: BN -> dropout -> ReLU -> FC reduce -> BN -> cls.
            self.encoder_bn = nn.BatchNorm1d(sum(self.bin_num)*3*self.hidden_dim)
            self.encoder_bn.bias.requires_grad_(False)

            self.reduce_dp = nn.Dropout(p=model_cfg['dropout'])
            self.reduce_ac = nn.ReLU(inplace=True)
            self.reduce_fc = nn.Linear(sum(self.bin_num)*3*self.hidden_dim, reduce_dim, bias=False)

            self.reduce_bn = nn.BatchNorm1d(reduce_dim)
            self.reduce_bn.bias.requires_grad_(False)

            self.reduce_cls = nn.Linear(reduce_dim, model_cfg['class_num'], bias=False)

    def upsample_add(self, x, y):
        # FPN-style top-down merge: 2x nearest-neighbour upsample then add.
        return F.interpolate(x, scale_factor=2, mode='nearest') + y

    def forward(self, inputs):
        ipts, labs, _, _, seqL = inputs
        # Sequence lengths only matter for packed training batches.
        seqL = None if not self.training else seqL
        sils = ipts[0]  # [n, s, h, w]
        del ipts
        if len(sils.size()) == 4:
            sils = sils.unsqueeze(2)  # add channel dim -> [n, s, 1, h, w]
        n, s, _, h, w = sils.size()

        ### stage 0 sil ###
        sil_0_outs = self.sil_stage_0(sils)
        stage_0_sil_set = self.set_pooling(sil_0_outs, seqL, dim=1)[0]

        ### stage 1 sil ###
        sil_1_ipts = self.MaxP_sil(sil_0_outs)
        sil_1_outs = self.sil_stage_1(sil_1_ipts)

        ### stage 2 sil ###
        sil_2_ipts = self.MaxP_sil(sil_1_outs)
        sil_2_outs = self.sil_stage_2(sil_2_ipts)

        ### stage 1 set ###
        set_1_ipts = self.set_pooling(sil_1_ipts, seqL, dim=1)[0]
        stage_1_sil_set = self.set_pooling(sil_1_outs, seqL, dim=1)[0]
        set_1_outs = self.set_stage_1(set_1_ipts) + stage_1_sil_set

        ### stage 2 set ###
        set_2_ipts = self.MaxP_set(set_1_outs)
        stage_2_sil_set = self.set_pooling(sil_2_outs, seqL, dim=1)[0]
        set_2_outs = self.set_stage_2(set_2_ipts) + stage_2_sil_set

        # NOTE(review): stage 0 has no set branch, so its feature is
        # concatenated with itself to match the *2 lateral channel count —
        # looks intentional, but worth confirming against the paper.
        set1 = torch.cat((stage_0_sil_set, stage_0_sil_set), dim=1)
        set2 = torch.cat((stage_1_sil_set, set_1_outs), dim=1)
        set3 = torch.cat((stage_2_sil_set, set_2_outs), dim=1)

        # print(set1.shape,set2.shape,set3.shape,"***\n")

        # lateral: top-down pathway (upsample+add), then per-level smoothing.
        set3 = self.lateral_layer3(set3)
        set2 = self.upsample_add(set3, self.lateral_layer2(set2))
        set1 = self.upsample_add(set2, self.lateral_layer1(set1))

        set3 = self.smooth_layer3(set3)
        set2 = self.smooth_layer2(set2)
        set1 = self.smooth_layer1(set1)

        # Horizontal pooling of every pyramid level, then part-wise FCs.
        set1 = self.HPP(set1)
        set2 = self.HPP(set2)
        set3 = self.HPP(set3)

        feature = torch.cat([set1, set2, set3], -1).permute(2, 0, 1).contiguous()

        feature = self.Head(feature)
        feature = feature.permute(1, 0, 2).contiguous()  # n p c

        # compact block
        if not self.pretrain:
            bn_feature = self.encoder_bn(feature.view(n, -1))
            bn_feature = bn_feature.view(*feature.shape).contiguous()

            reduce_feature = self.reduce_dp(bn_feature)
            reduce_feature = self.reduce_ac(reduce_feature)
            reduce_feature = self.reduce_fc(reduce_feature.view(n, -1))

            bn_reduce_feature = self.reduce_bn(reduce_feature)
            logits = self.reduce_cls(bn_reduce_feature).unsqueeze(1)  # n c

            reduce_feature = reduce_feature.unsqueeze(1).contiguous()
            bn_reduce_feature = bn_reduce_feature.unsqueeze(1).contiguous()

        retval = {
            'training_feat': {},
            'visual_summary': {
                'image/sils': sils.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': feature  # reduce_feature # bn_reduce_feature
            }
        }
        # Both phases train with triplet loss; phase two additionally trains
        # the softmax head on the compact-block logits.
        if self.pretrain:
            retval['training_feat']['triplet'] = {'embeddings': feature, 'labels': labs}
        else:
            retval['training_feat']['triplet'] = {'embeddings': feature, 'labels': labs}
            retval['training_feat']['softmax'] = {'logits': logits, 'labels': labs}
        return retval
|
||||
Reference in New Issue
Block a user