Files
Dongyang Jin 2c29afadf3 Support skeleton (#155)
* pose

* pose

* pose

* pose

* 你的提交消息

* pose

* pose

* Delete train1.sh

* pretreatment

* configs

* pose

* reference

* Update gaittr.py

* naming

* naming

* Update transform.py

* update for datasets

* update README

* update name and README

* update

* Update transform.py
2023-09-27 16:20:00 +08:00

187 lines
6.2 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
from ..base_model import BaseModel
from ..modules import Graph, SpatialAttention
import numpy as np
import math
class Mish(nn.Module):
def __init__(self):
super().__init__()
def forward(self,x):
return x * (torch.tanh(F.softplus(x)))
class STModule(nn.Module):
def __init__(self,in_channels, out_channels, incidence, num_point):
super(STModule, self).__init__()
"""
This class implements augmented graph spatial convolution in case of Spatial Transformer
Fucntion adapated from: https://github.com/Chiaraplizz/ST-TR/blob/master/code/st_gcn/net/gcn_attention.py
"""
self.in_channels = in_channels
self.out_channels = out_channels
self.incidence = incidence
self.num_point = num_point
self.relu = Mish()
self.bn = nn.BatchNorm2d(out_channels)
self.data_bn = nn.BatchNorm1d(self.in_channels * self.num_point)
self.attention_conv = SpatialAttention(in_channels=in_channels,out_channel=out_channels,A=self.incidence,num_point=self.num_point)
def forward(self,x):
N, C, T, V = x.size()
# data normlization
x = x.permute(0, 1, 3, 2).reshape(N, C * V, T)
x = self.data_bn(x)
x = x.reshape(N, C, V, T).permute(0, 1, 3, 2)
# adjacency matrix
self.incidence = self.incidence.cuda(x.get_device())
# N, T, C, V > NT, C, 1, V
xa = x.permute(0, 2, 1, 3).reshape(-1, C, 1, V)
# spatial attention
attn_out = self.attention_conv(xa)
# N, T, C, V > N, C, T, V
attn_out = attn_out.reshape(N, T, -1, V).permute(0, 2, 1, 3)
y = attn_out
y = self.bn(self.relu(y))
return y
class UnitConv2D(nn.Module):
'''
This class is used in GaitTR[TCN_ST] block.
'''
def __init__(self, D_in, D_out, kernel_size=9, stride=1, dropout=0.1, bias=True):
super(UnitConv2D,self).__init__()
pad = int((kernel_size-1)/2)
self.conv = nn.Conv2d(D_in,D_out,kernel_size=(kernel_size,1)
,padding=(pad,0),stride=(stride,1),bias=bias)
self.bn = nn.BatchNorm2d(D_out)
self.relu = Mish()
self.dropout = nn.Dropout(dropout, inplace=False)
#initalize
self.conv_init(self.conv)
def forward(self,x):
x = self.dropout(x)
x = self.bn(self.relu(self.conv(x)))
return x
def conv_init(self,module):
n = module.out_channels
for k in module.kernel_size:
n = n*k
module.weight.data.normal_(0, math.sqrt(2. / n))
class TCN_ST(nn.Module):
"""
Block of GaitTR: https://arxiv.org/pdf/2204.03873.pdf
TCN: Temporal Convolution Network
ST: Sptail Temporal Graph Convolution Network
"""
def __init__(self,in_channel,out_channel,A,num_point):
super(TCN_ST, self).__init__()
#params
self.in_channel = in_channel
self.out_channel = out_channel
self.A = A
self.num_point = num_point
#network
self.tcn = UnitConv2D(D_in=self.in_channel,D_out=self.in_channel,kernel_size=9)
self.st = STModule(in_channels=self.in_channel,out_channels=self.out_channel,incidence=self.A,num_point=self.num_point)
self.residual = lambda x: x
if (in_channel != out_channel):
self.residual_s = nn.Sequential(
nn.Conv2d(in_channel, out_channel, 1),
nn.BatchNorm2d(out_channel),
)
self.down = UnitConv2D(D_in=self.in_channel,D_out=out_channel,kernel_size=1,dropout=0)
else:
self.residual_s = lambda x: x
self.down = None
def forward(self,x):
x0 = self.tcn(x) + self.residual(x)
y = self.st(x0) + self.residual_s(x0)
# skip residual
y = y + (x if(self.down is None) else self.down(x))
return y
class GaitTR(BaseModel):
"""
GaitTR: Spatial Transformer Network on Skeleton-based Gait Recognition
Arxiv : https://arxiv.org/abs/2204.03873.pdf
"""
def build_network(self, model_cfg):
in_c = model_cfg['in_channels']
self.num_class = model_cfg['num_class']
self.joint_format = model_cfg['joint_format']
self.graph = Graph(joint_format=self.joint_format,max_hop=3)
#### Network Define ####
# ajaceny matrix
self.A = torch.from_numpy(self.graph.A.astype(np.float32))
#data normalization
num_point = self.A.shape[-1]
self.data_bn = nn.BatchNorm1d(in_c[0] * num_point)
#backbone
backbone = []
for i in range(len(in_c)-1):
backbone.append(TCN_ST(in_channel= in_c[i],out_channel= in_c[i+1],A=self.A,num_point=num_point))
self.backbone = nn.ModuleList(backbone)
self.fcn = nn.Conv1d(in_c[-1], self.num_class, kernel_size=1)
def forward(self, inputs):
ipts, labs, _, _, seqL = inputs
x= ipts[0]
pose = x
# x = N, T, C, V, M -> N, C, T, V, M
x = x.permute(0, 2, 1, 3, 4)
N, C, T, V, M = x.size()
if len(x.size()) == 4:
x = x.unsqueeze(1)
del ipts
x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
x = self.data_bn(x)
x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(
N * M, C, T, V)
#backbone
for _,m in enumerate(self.backbone):
x = m(x)
# V pooling
x = F.avg_pool2d(x, kernel_size=(1,V))
# M pooling
c = x.size(1)
t = x.size(2)
x = x.view(N, M, c, t).mean(dim=1).view(N, c, t)#[n,c,t]
# T pooling
x = F.avg_pool1d(x, kernel_size=x.size()[2]) #[n,c]
# C fcn
x = self.fcn(x) #[n,c']
x = F.avg_pool1d(x, x.size()[2:]) # [n,c']
x = x.view(N, self.num_class) # n,c
embed = x.unsqueeze(-1) # n,c,1
retval = {
'training_feat': {
'triplet': {'embeddings': embed, 'labels': labs}
},
'visual_summary': {
'image/pose': pose.view(N*T, M, V, C)
},
'inference_feat': {
'embeddings': embed
}
}
return retval