OpenGait release(pre-beta version).
This commit is contained in:
@@ -0,0 +1,127 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..base_model import BaseModel
|
||||
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs
|
||||
from utils import clones
|
||||
|
||||
|
||||
class BasicConv1d(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
|
||||
super(BasicConv1d, self).__init__()
|
||||
self.conv = nn.Conv1d(in_channels, out_channels,
|
||||
kernel_size, bias=False, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
ret = self.conv(x)
|
||||
return ret
|
||||
|
||||
|
||||
class TemporalFeatureAggregator(nn.Module):
|
||||
def __init__(self, in_channels, squeeze=4, parts_num=16):
|
||||
super(TemporalFeatureAggregator, self).__init__()
|
||||
hidden_dim = int(in_channels // squeeze)
|
||||
self.parts_num = parts_num
|
||||
|
||||
# MTB1
|
||||
conv3x1 = nn.Sequential(
|
||||
BasicConv1d(in_channels, hidden_dim, 3, padding=1),
|
||||
nn.LeakyReLU(inplace=True),
|
||||
BasicConv1d(hidden_dim, in_channels, 1))
|
||||
self.conv1d3x1 = clones(conv3x1, parts_num)
|
||||
self.avg_pool3x1 = nn.AvgPool1d(3, stride=1, padding=1)
|
||||
self.max_pool3x1 = nn.MaxPool1d(3, stride=1, padding=1)
|
||||
|
||||
# MTB1
|
||||
conv3x3 = nn.Sequential(
|
||||
BasicConv1d(in_channels, hidden_dim, 3, padding=1),
|
||||
nn.LeakyReLU(inplace=True),
|
||||
BasicConv1d(hidden_dim, in_channels, 3, padding=1))
|
||||
self.conv1d3x3 = clones(conv3x3, parts_num)
|
||||
self.avg_pool3x3 = nn.AvgPool1d(5, stride=1, padding=2)
|
||||
self.max_pool3x3 = nn.MaxPool1d(5, stride=1, padding=2)
|
||||
|
||||
# Temporal Pooling, TP
|
||||
self.TP = torch.max
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Input: x, [n, s, c, p]
|
||||
Output: ret, [n, p, c]
|
||||
"""
|
||||
n, s, c, p = x.size()
|
||||
x = x.permute(3, 0, 2, 1).contiguous() # [p, n, c, s]
|
||||
feature = x.split(1, 0) # [[n, c, s], ...]
|
||||
x = x.view(-1, c, s)
|
||||
|
||||
# MTB1: ConvNet1d & Sigmoid
|
||||
logits3x1 = torch.cat([conv(_.squeeze(0)).unsqueeze(0)
|
||||
for conv, _ in zip(self.conv1d3x1, feature)], 0)
|
||||
scores3x1 = torch.sigmoid(logits3x1)
|
||||
# MTB1: Template Function
|
||||
feature3x1 = self.avg_pool3x1(x) + self.max_pool3x1(x)
|
||||
feature3x1 = feature3x1.view(p, n, c, s)
|
||||
feature3x1 = feature3x1 * scores3x1
|
||||
|
||||
# MTB2: ConvNet1d & Sigmoid
|
||||
logits3x3 = torch.cat([conv(_.squeeze(0)).unsqueeze(0)
|
||||
for conv, _ in zip(self.conv1d3x3, feature)], 0)
|
||||
scores3x3 = torch.sigmoid(logits3x3)
|
||||
# MTB2: Template Function
|
||||
feature3x3 = self.avg_pool3x3(x) + self.max_pool3x3(x)
|
||||
feature3x3 = feature3x3.view(p, n, c, s)
|
||||
feature3x3 = feature3x3 * scores3x3
|
||||
|
||||
# Temporal Pooling
|
||||
ret = self.TP(feature3x1 + feature3x3, dim=-1)[0] # [p, n, c]
|
||||
ret = ret.permute(1, 0, 2).contiguous() # [n, p, c]
|
||||
return ret
|
||||
|
||||
|
||||
class GaitPart(BaseModel):
|
||||
def __init__(self, *args, **kargs):
|
||||
super(GaitPart, self).__init__(*args, **kargs)
|
||||
"""
|
||||
GaitPart: Temporal Part-based Model for Gait Recognition
|
||||
Paper: https://openaccess.thecvf.com/content_CVPR_2020/papers/Fan_GaitPart_Temporal_Part-Based_Model_for_Gait_Recognition_CVPR_2020_paper.pdf
|
||||
Github: https://github.com/ChaoFan96/GaitPart
|
||||
"""
|
||||
|
||||
def build_network(self, model_cfg):
|
||||
|
||||
self.Backbone = self.get_backbone(model_cfg)
|
||||
head_cfg = model_cfg['SeparateFCs']
|
||||
self.Head = SeparateFCs(**model_cfg['SeparateFCs'])
|
||||
self.Backbone = SetBlockWrapper(self.Backbone)
|
||||
self.HPP = SetBlockWrapper(
|
||||
HorizontalPoolingPyramid(bin_num=model_cfg['bin_num']))
|
||||
self.TFA = PackSequenceWrapper(TemporalFeatureAggregator(
|
||||
in_channels=head_cfg['in_channels'], parts_num=head_cfg['parts_num']))
|
||||
|
||||
def forward(self, inputs):
|
||||
ipts, labs, _, _, seqL = inputs
|
||||
|
||||
sils = ipts[0]
|
||||
if len(sils.size()) == 4:
|
||||
sils = sils.unsqueeze(2)
|
||||
|
||||
del ipts
|
||||
out = self.Backbone(sils) # [n, s, c, h, w]
|
||||
out = self.HPP(out) # [n, s, c, p]
|
||||
out = self.TFA(out, seqL) # [n, p, c]
|
||||
|
||||
embs = self.Head(out.permute(1, 0, 2).contiguous()) # [p, n, c]
|
||||
embs = embs.permute(1, 0, 2).contiguous() # [n, p, c]
|
||||
|
||||
n, s, _, h, w = sils.size()
|
||||
retval = {
|
||||
'training_feat': {
|
||||
'triplet': {'embeddings': embs, 'labels': labs}
|
||||
},
|
||||
'visual_summary': {
|
||||
'image/sils': sils.view(n*s, 1, h, w)
|
||||
},
|
||||
'inference_feat': {
|
||||
'embeddings': embs
|
||||
}
|
||||
}
|
||||
return retval
|
||||
Reference in New Issue
Block a user