2c29afadf3
* pose * pose * pose * pose * 你的提交消息 * pose * pose * Delete train1.sh * pretreatment * configs * pose * reference * Update gaittr.py * naming * naming * Update transform.py * update for datasets * update README * update name and README * update * Update transform.py
187 lines
6.2 KiB
Python
187 lines
6.2 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from ..base_model import BaseModel
|
|
from ..modules import Graph, SpatialAttention
|
|
import numpy as np
|
|
import math
|
|
|
|
|
|
class Mish(nn.Module):
    """Mish activation, computed element-wise as ``x * tanh(softplus(x))``."""

    def forward(self, x):
        # Smooth gate: tanh(softplus(x)) tends to 1 for large positive x and
        # to 0 for large negative x.
        gate = torch.tanh(F.softplus(x))
        return x * gate
|
|
|
|
class STModule(nn.Module):
    """
    Augmented graph spatial convolution for the Spatial Transformer.

    The input is batch-normalised per (channel, joint) pair, every frame is
    passed through ``SpatialAttention`` independently, and the result goes
    through Mish followed by 2-D batch norm.

    Function adapted from: https://github.com/Chiaraplizz/ST-TR/blob/master/code/st_gcn/net/gcn_attention.py

    Args:
        in_channels: channel count C of the input tensor.
        out_channels: channel count fed to the final ``BatchNorm2d``.
        incidence: adjacency tensor handed to ``SpatialAttention`` as ``A``.
        num_point: number of joints V (sizes the data batch norm as C * V).
    """

    def __init__(self, in_channels, out_channels, incidence, num_point):
        super(STModule, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.incidence = incidence
        self.num_point = num_point
        self.relu = Mish()
        self.bn = nn.BatchNorm2d(out_channels)
        self.data_bn = nn.BatchNorm1d(self.in_channels * self.num_point)
        self.attention_conv = SpatialAttention(
            in_channels=in_channels,
            out_channel=out_channels,
            A=self.incidence,
            num_point=self.num_point,
        )

    def forward(self, x):
        """
        Apply per-frame spatial attention.

        Args:
            x: 4-D tensor unpacked as (N, C, T, V).

        Returns:
            Tensor laid out as (N, C_out, T, V), where C_out is whatever channel
            count ``attention_conv`` produces (presumably ``out_channels`` --
            confirm against ``SpatialAttention``).
        """
        N, C, T, V = x.size()

        # data normalization over the flattened (C * V) axis
        x = x.permute(0, 1, 3, 2).reshape(N, C * V, T)
        x = self.data_bn(x)
        x = x.reshape(N, C, V, T).permute(0, 1, 3, 2)

        # adjacency matrix: keep it on the same device as the input.
        # `.to(x.device)` also works for CPU tensors, whereas
        # `.cuda(x.get_device())` fails there because get_device() returns -1.
        self.incidence = self.incidence.to(x.device)

        # N, C, T, V > NT, C, 1, V  (each frame becomes its own sample)
        xa = x.permute(0, 2, 1, 3).reshape(-1, C, 1, V)

        # spatial attention
        attn_out = self.attention_conv(xa)

        # NT, C', 1, V > N, C', T, V
        attn_out = attn_out.reshape(N, T, -1, V).permute(0, 2, 1, 3)

        return self.bn(self.relu(attn_out))
|
|
|
|
class UnitConv2D(nn.Module):
    """
    Temporal convolution unit used in the GaitTR [TCN_ST] block.

    Pipeline: dropout -> (kernel_size x 1) convolution -> Mish -> batch norm.

    Args:
        D_in: input channel count.
        D_out: output channel count.
        kernel_size: extent of the kernel along the first spatial axis.
        stride: stride along the first spatial axis.
        dropout: dropout probability applied to the input.
        bias: whether the convolution has a bias term.
    """

    def __init__(self, D_in, D_out, kernel_size=9, stride=1, dropout=0.1, bias=True):
        super(UnitConv2D, self).__init__()
        # Padding chosen so that an odd kernel keeps the length at stride 1.
        half_kernel = int((kernel_size - 1) / 2)
        self.conv = nn.Conv2d(
            D_in,
            D_out,
            kernel_size=(kernel_size, 1),
            padding=(half_kernel, 0),
            stride=(stride, 1),
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(D_out)
        self.relu = Mish()
        self.dropout = nn.Dropout(dropout, inplace=False)
        # initialize
        self.conv_init(self.conv)

    def forward(self, x):
        dropped = self.dropout(x)
        activated = self.relu(self.conv(dropped))
        return self.bn(activated)

    def conv_init(self, module):
        """Re-draw ``module.weight`` from N(0, 2 / (out_channels * prod(kernel_size)))."""
        fan = module.out_channels * torch.Size(module.kernel_size).numel()
        std = math.sqrt(2. / fan)
        module.weight.data.normal_(0, std)
|
|
|
|
class TCN_ST(nn.Module):
    """
    Block of GaitTR: https://arxiv.org/pdf/2204.03873.pdf
    TCN: Temporal Convolution Network
    ST: Spatial Temporal Graph Convolution Network

    Computes, for an input ``x``::

        x0 = tcn(x) + x
        y  = st(x0) + residual_s(x0) + skip(x)

    where ``residual_s`` and ``skip`` are identities when
    ``in_channel == out_channel`` and learned 1x1 projections otherwise.

    Args:
        in_channel: channel count of the input.
        out_channel: channel count of the output.
        A: adjacency tensor forwarded to ``STModule`` as ``incidence``.
        num_point: number of joints, forwarded to ``STModule``.
    """

    def __init__(self, in_channel, out_channel, A, num_point):
        super(TCN_ST, self).__init__()
        # params
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.A = A
        self.num_point = num_point

        # network
        self.tcn = UnitConv2D(D_in=self.in_channel, D_out=self.in_channel, kernel_size=9)
        self.st = STModule(
            in_channels=self.in_channel,
            out_channels=self.out_channel,
            incidence=self.A,
            num_point=self.num_point,
        )

        # nn.Identity instead of a lambda: it has no parameters or buffers, so
        # state_dict keys are unchanged, but the module stays picklable.
        self.residual = nn.Identity()
        if in_channel != out_channel:
            # Channel counts differ: project both shortcuts with 1x1 convs.
            self.residual_s = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, 1),
                nn.BatchNorm2d(out_channel),
            )
            self.down = UnitConv2D(D_in=self.in_channel, D_out=out_channel, kernel_size=1, dropout=0)
        else:
            self.residual_s = nn.Identity()
            self.down = None

    def forward(self, x):
        x0 = self.tcn(x) + self.residual(x)
        y = self.st(x0) + self.residual_s(x0)
        # skip residual from the block input
        skip = x if self.down is None else self.down(x)
        return y + skip
|
|
|
|
|
|
|
|
class GaitTR(BaseModel):
    """
    GaitTR: Spatial Transformer Network on Skeleton-based Gait Recognition
    Arxiv : https://arxiv.org/abs/2204.03873.pdf
    """

    def build_network(self, model_cfg):
        """
        Build the network from ``model_cfg``.

        Keys read from ``model_cfg``:
            in_channels: sequence of channel widths; consecutive pairs define
                the ``TCN_ST`` blocks, and the last entry feeds the classifier.
            num_class: output width of the final 1x1 ``Conv1d``.
            joint_format: passed to ``Graph`` to build the adjacency.
        """
        in_c = model_cfg['in_channels']
        self.num_class = model_cfg['num_class']
        self.joint_format = model_cfg['joint_format']
        self.graph = Graph(joint_format=self.joint_format, max_hop=3)

        #### Network Define ####

        # adjacency matrix
        self.A = torch.from_numpy(self.graph.A.astype(np.float32))

        # data normalization
        num_point = self.A.shape[-1]
        self.data_bn = nn.BatchNorm1d(in_c[0] * num_point)

        # backbone: one TCN_ST block per consecutive pair of channel widths
        self.backbone = nn.ModuleList([
            TCN_ST(in_channel=c_in, out_channel=c_out, A=self.A, num_point=num_point)
            for c_in, c_out in zip(in_c[:-1], in_c[1:])
        ])

        self.fcn = nn.Conv1d(in_c[-1], self.num_class, kernel_size=1)

    def forward(self, inputs):
        """
        Compute embeddings for a batch of pose sequences.

        Args:
            inputs: 5-tuple ``(ipts, labs, _, _, seqL)``; only ``ipts[0]`` and
                ``labs`` are used. ``ipts[0]`` is laid out as (N, T, C, V, M);
                a 4-D (N, T, C, V) tensor is treated as M == 1.

        Returns:
            Dict with ``training_feat``, ``visual_summary`` and
            ``inference_feat`` entries; embeddings have shape
            (N, num_class, 1).
        """
        ipts, labs, _, _, _ = inputs

        x = ipts[0]
        del ipts

        # The 4-D check must run before the 5-axis permute/unpack below,
        # otherwise it can never be reached. A missing M axis is appended
        # last to match the N, T, C, V, M layout.
        # NOTE(review): assumes 4-D inputs are (N, T, C, V) -- confirm with the data loader.
        if x.dim() == 4:
            x = x.unsqueeze(-1)
        pose = x

        # x = N, T, C, V, M -> N, C, T, V, M
        x = x.permute(0, 2, 1, 3, 4)
        N, C, T, V, M = x.size()

        # data_bn was built with in_c[0] * num_point features, so the view
        # below only matches when M * V * C equals that number.
        x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
        x = self.data_bn(x)
        x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(
            N * M, C, T, V)

        # backbone
        for block in self.backbone:
            x = block(x)

        # V pooling
        x = F.avg_pool2d(x, kernel_size=(1, V))
        # M pooling
        c = x.size(1)
        t = x.size(2)
        x = x.view(N, M, c, t).mean(dim=1).view(N, c, t)  # [n,c,t]
        # T pooling
        x = F.avg_pool1d(x, kernel_size=x.size()[2])  # [n,c,1]
        # C fcn
        x = self.fcn(x)  # [n,c',1]
        x = F.avg_pool1d(x, x.size()[2:])  # [n,c',1]
        x = x.view(N, self.num_class)  # n,c
        embed = x.unsqueeze(-1)  # n,c,1

        retval = {
            'training_feat': {
                'triplet': {'embeddings': embed, 'labels': labs}
            },
            'visual_summary': {
                # NOTE(review): view() reinterprets the (N, T, C, V, M) memory
                # as (N*T, M, V, C) without transposing axes -- confirm this
                # is what the visualiser expects.
                'image/pose': pose.view(N*T, M, V, C)
            },
            'inference_feat': {
                'embeddings': embed
            }
        }
        return retval
|