Support skeleton (#155)
* pose * pose * pose * pose * 你的提交消息 * pose * pose * Delete train1.sh * pretreatment * configs * pose * reference * Update gaittr.py * naming * naming * Update transform.py * update for datasets * update README * update name and README * update * Update transform.py
This commit is contained in:
@@ -0,0 +1,186 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..base_model import BaseModel
|
||||
from ..modules import Graph, SpatialAttention
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
|
||||
class Mish(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def forward(self,x):
|
||||
return x * (torch.tanh(F.softplus(x)))
|
||||
|
||||
class STModule(nn.Module):
|
||||
def __init__(self,in_channels, out_channels, incidence, num_point):
|
||||
super(STModule, self).__init__()
|
||||
"""
|
||||
This class implements augmented graph spatial convolution in case of Spatial Transformer
|
||||
Fucntion adapated from: https://github.com/Chiaraplizz/ST-TR/blob/master/code/st_gcn/net/gcn_attention.py
|
||||
"""
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.incidence = incidence
|
||||
self.num_point = num_point
|
||||
self.relu = Mish()
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
self.data_bn = nn.BatchNorm1d(self.in_channels * self.num_point)
|
||||
self.attention_conv = SpatialAttention(in_channels=in_channels,out_channel=out_channels,A=self.incidence,num_point=self.num_point)
|
||||
def forward(self,x):
|
||||
N, C, T, V = x.size()
|
||||
# data normlization
|
||||
x = x.permute(0, 1, 3, 2).reshape(N, C * V, T)
|
||||
x = self.data_bn(x)
|
||||
x = x.reshape(N, C, V, T).permute(0, 1, 3, 2)
|
||||
# adjacency matrix
|
||||
self.incidence = self.incidence.cuda(x.get_device())
|
||||
# N, T, C, V > NT, C, 1, V
|
||||
xa = x.permute(0, 2, 1, 3).reshape(-1, C, 1, V)
|
||||
# spatial attention
|
||||
attn_out = self.attention_conv(xa)
|
||||
# N, T, C, V > N, C, T, V
|
||||
attn_out = attn_out.reshape(N, T, -1, V).permute(0, 2, 1, 3)
|
||||
y = attn_out
|
||||
y = self.bn(self.relu(y))
|
||||
return y
|
||||
|
||||
class UnitConv2D(nn.Module):
|
||||
'''
|
||||
This class is used in GaitTR[TCN_ST] block.
|
||||
'''
|
||||
|
||||
def __init__(self, D_in, D_out, kernel_size=9, stride=1, dropout=0.1, bias=True):
|
||||
super(UnitConv2D,self).__init__()
|
||||
pad = int((kernel_size-1)/2)
|
||||
self.conv = nn.Conv2d(D_in,D_out,kernel_size=(kernel_size,1)
|
||||
,padding=(pad,0),stride=(stride,1),bias=bias)
|
||||
self.bn = nn.BatchNorm2d(D_out)
|
||||
self.relu = Mish()
|
||||
self.dropout = nn.Dropout(dropout, inplace=False)
|
||||
#initalize
|
||||
self.conv_init(self.conv)
|
||||
|
||||
def forward(self,x):
|
||||
x = self.dropout(x)
|
||||
x = self.bn(self.relu(self.conv(x)))
|
||||
return x
|
||||
|
||||
def conv_init(self,module):
|
||||
n = module.out_channels
|
||||
for k in module.kernel_size:
|
||||
n = n*k
|
||||
module.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
|
||||
class TCN_ST(nn.Module):
|
||||
"""
|
||||
Block of GaitTR: https://arxiv.org/pdf/2204.03873.pdf
|
||||
TCN: Temporal Convolution Network
|
||||
ST: Sptail Temporal Graph Convolution Network
|
||||
"""
|
||||
def __init__(self,in_channel,out_channel,A,num_point):
|
||||
super(TCN_ST, self).__init__()
|
||||
#params
|
||||
self.in_channel = in_channel
|
||||
self.out_channel = out_channel
|
||||
self.A = A
|
||||
self.num_point = num_point
|
||||
#network
|
||||
self.tcn = UnitConv2D(D_in=self.in_channel,D_out=self.in_channel,kernel_size=9)
|
||||
self.st = STModule(in_channels=self.in_channel,out_channels=self.out_channel,incidence=self.A,num_point=self.num_point)
|
||||
self.residual = lambda x: x
|
||||
if (in_channel != out_channel):
|
||||
self.residual_s = nn.Sequential(
|
||||
nn.Conv2d(in_channel, out_channel, 1),
|
||||
nn.BatchNorm2d(out_channel),
|
||||
)
|
||||
self.down = UnitConv2D(D_in=self.in_channel,D_out=out_channel,kernel_size=1,dropout=0)
|
||||
else:
|
||||
self.residual_s = lambda x: x
|
||||
self.down = None
|
||||
|
||||
def forward(self,x):
|
||||
x0 = self.tcn(x) + self.residual(x)
|
||||
y = self.st(x0) + self.residual_s(x0)
|
||||
# skip residual
|
||||
y = y + (x if(self.down is None) else self.down(x))
|
||||
return y
|
||||
|
||||
|
||||
|
||||
class GaitTR(BaseModel):
|
||||
"""
|
||||
GaitTR: Spatial Transformer Network on Skeleton-based Gait Recognition
|
||||
Arxiv : https://arxiv.org/abs/2204.03873.pdf
|
||||
"""
|
||||
def build_network(self, model_cfg):
|
||||
|
||||
in_c = model_cfg['in_channels']
|
||||
self.num_class = model_cfg['num_class']
|
||||
self.joint_format = model_cfg['joint_format']
|
||||
self.graph = Graph(joint_format=self.joint_format,max_hop=3)
|
||||
|
||||
#### Network Define ####
|
||||
|
||||
# ajaceny matrix
|
||||
self.A = torch.from_numpy(self.graph.A.astype(np.float32))
|
||||
|
||||
#data normalization
|
||||
num_point = self.A.shape[-1]
|
||||
self.data_bn = nn.BatchNorm1d(in_c[0] * num_point)
|
||||
|
||||
#backbone
|
||||
backbone = []
|
||||
for i in range(len(in_c)-1):
|
||||
backbone.append(TCN_ST(in_channel= in_c[i],out_channel= in_c[i+1],A=self.A,num_point=num_point))
|
||||
self.backbone = nn.ModuleList(backbone)
|
||||
|
||||
self.fcn = nn.Conv1d(in_c[-1], self.num_class, kernel_size=1)
|
||||
|
||||
def forward(self, inputs):
|
||||
ipts, labs, _, _, seqL = inputs
|
||||
|
||||
x= ipts[0]
|
||||
pose = x
|
||||
# x = N, T, C, V, M -> N, C, T, V, M
|
||||
x = x.permute(0, 2, 1, 3, 4)
|
||||
N, C, T, V, M = x.size()
|
||||
if len(x.size()) == 4:
|
||||
x = x.unsqueeze(1)
|
||||
del ipts
|
||||
|
||||
x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
|
||||
|
||||
x = self.data_bn(x)
|
||||
x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(
|
||||
N * M, C, T, V)
|
||||
#backbone
|
||||
for _,m in enumerate(self.backbone):
|
||||
x = m(x)
|
||||
# V pooling
|
||||
x = F.avg_pool2d(x, kernel_size=(1,V))
|
||||
# M pooling
|
||||
c = x.size(1)
|
||||
t = x.size(2)
|
||||
x = x.view(N, M, c, t).mean(dim=1).view(N, c, t)#[n,c,t]
|
||||
# T pooling
|
||||
x = F.avg_pool1d(x, kernel_size=x.size()[2]) #[n,c]
|
||||
# C fcn
|
||||
x = self.fcn(x) #[n,c']
|
||||
x = F.avg_pool1d(x, x.size()[2:]) # [n,c']
|
||||
x = x.view(N, self.num_class) # n,c
|
||||
embed = x.unsqueeze(-1) # n,c,1
|
||||
|
||||
retval = {
|
||||
'training_feat': {
|
||||
'triplet': {'embeddings': embed, 'labels': labs}
|
||||
},
|
||||
'visual_summary': {
|
||||
'image/pose': pose.view(N*T, M, V, C)
|
||||
},
|
||||
'inference_feat': {
|
||||
'embeddings': embed
|
||||
}
|
||||
}
|
||||
return retval
|
||||
Reference in New Issue
Block a user