Support skeleton (#155)

* pose

* pose

* pose

* pose

* 你的提交消息

* pose

* pose

* Delete train1.sh

* pretreatment

* configs

* pose

* reference

* Update gaittr.py

* naming

* naming

* Update transform.py

* update for datasets

* update README

* update name and README

* update

* Update transform.py
This commit is contained in:
Dongyang Jin
2023-09-27 16:20:00 +08:00
committed by GitHub
parent 853bb1821d
commit 2c29afadf3
41 changed files with 4251 additions and 12 deletions
+135
View File
@@ -0,0 +1,135 @@
import torch
import torch.nn as nn
from ..modules import TemporalBasicBlock, TemporalBottleneckBlock, SpatialBasicBlock, SpatialBottleneckBlock
class ResGCNModule(nn.Module):
    """ResGCN spatial-temporal module.

    Applies a spatial graph-convolution block followed by a temporal
    convolution block, with an optional module-level residual path.

    Arxiv: https://arxiv.org/abs/2010.09978
    Github: https://github.com/Thomas-yx/ResGCNv1
            https://github.com/BNU-IVC/FastPoseGait

    Args:
        in_channels (int): input feature channels.
        out_channels (int): output feature channels.
        block (str): 'initial', 'Basic' or 'Bottleneck'; selects the
            spatial/temporal block types and the residual layout.
        A (Tensor): graph adjacency matrix; used here only to size the
            learnable per-edge importance weights.
        stride (int): temporal stride of the temporal block.
        kernel_size (sequence): [temporal_window_size, max_graph_distance];
            the temporal window must be odd.
        reduction (int): channel-reduction ratio inside bottleneck blocks.
        get_res (bool): forwarded to the temporal block.
        is_main (bool): marks main-stream modules; enables tcn_stride when
            in/out channel counts match.

    Raises:
        ValueError: if kernel_size is malformed (wrong length or even
            temporal window).
    """
    def __init__(self, in_channels, out_channels, block, A, stride=1, kernel_size=[9,2], reduction=4, get_res=False, is_main=False):
        super(ResGCNModule, self).__init__()
        # NOTE(fix): the original code called `logging` here without importing
        # it, so invalid kernel sizes raised NameError instead of the intended
        # ValueError. Raise ValueError directly with the diagnostic message.
        if not len(kernel_size) == 2:
            raise ValueError('Error: Please check whether len(kernel_size) == 2')
        if not kernel_size[0] % 2 == 1:
            raise ValueError('Error: Please check whether kernel_size[0] % 2 == 1')
        temporal_window_size, max_graph_distance = kernel_size
        # 'initial' -> no residuals at all; 'Basic' -> module-level residual;
        # otherwise ('Bottleneck') the residual lives inside the blocks.
        if block == 'initial':
            module_res, block_res = False, False
        elif block == 'Basic':
            module_res, block_res = True, False
        else:
            module_res, block_res = False, True
        if not module_res:
            self.residual = lambda x: 0
        elif stride == 1 and in_channels == out_channels:
            self.residual = lambda x: x
        else:
            # stride = 2: downsample the residual path to match output shape
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, (stride, 1)),
                nn.BatchNorm2d(out_channels),
            )
        if block in ['Basic', 'initial']:
            spatial_block = SpatialBasicBlock
            temporal_block = TemporalBasicBlock
        if block == 'Bottleneck':
            spatial_block = SpatialBottleneckBlock
            temporal_block = TemporalBottleneckBlock
        self.scn = spatial_block(in_channels, out_channels, max_graph_distance, block_res, reduction)
        # tcn_stride is only enabled on main-stream modules whose channel
        # count is unchanged.
        if in_channels == out_channels and is_main:
            tcn_stride = True
        else:
            tcn_stride = False
        self.tcn = temporal_block(out_channels, temporal_window_size, stride, block_res, reduction, get_res=get_res, tcn_stride=tcn_stride)
        # Learnable per-edge importance weights, same shape as A.
        self.edge = nn.Parameter(torch.ones_like(A))
    def forward(self, x, A):
        # Move A onto x's GPU; assumes CUDA execution (x.get_device() fails on CPU).
        A = A.cuda(x.get_device())
        return self.tcn(self.scn(x, A * self.edge), self.residual(x))
class ResGCNInputBranch(nn.Module):
    """
    ResGCNInputBranch_Module
    Arxiv: https://arxiv.org/abs/2010.09978
    Github: https://github.com/Thomas-yx/ResGCNv1

    One per-input branch of ResGCN: batch-norm on the raw input followed by
    a stack of ResGCNModules whose widths come from `input_branch`.
    """
    def __init__(self, input_branch, block, A, input_num, reduction=4):
        super(ResGCNInputBranch, self).__init__()
        self.register_buffer('A', A)
        # The first stage is always an 'initial' module; later stages use
        # the configured block type.
        stages = [
            ResGCNModule(
                input_branch[idx],
                input_branch[idx + 1],
                'initial' if idx == 0 else block,
                A,
                reduction=reduction,
            )
            for idx in range(len(input_branch) - 1)
        ]
        self.bn = nn.BatchNorm2d(input_branch[0])
        self.layers = nn.ModuleList(stages)
    def forward(self, x):
        # Normalize the raw input, then run the GCN stack on the shared graph.
        out = self.bn(x)
        for gcn in self.layers:
            out = gcn(out, self.A)
        return out
class ResGCN(nn.Module):
    """
    ResGCN
    Arxiv: https://arxiv.org/abs/2010.09978

    Multi-branch input network: each of `input_num` branches is processed by
    its own ResGCNInputBranch, the branch outputs are concatenated on the
    channel axis and fed through the main stream, then globally pooled and
    projected to `num_class` dimensions.

    Args:
        input_num (int): number of input branches.
        input_branch (list[int]): channel spec of each input branch.
        main_stream (list[int]): channel spec of the main stream.
        num_class (int): output embedding/class dimension.
        reduction (int): bottleneck reduction ratio.
        block (str): block type ('Basic' / 'Bottleneck').
        graph (Tensor): graph adjacency matrix.
    """
    def __init__(self, input_num, input_branch, main_stream, num_class, reduction, block, graph):
        super(ResGCN, self).__init__()
        self.graph = graph
        self.head = nn.ModuleList(
            ResGCNInputBranch(input_branch, block, graph, input_num, reduction)
            for _ in range(input_num)
        )
        main_stream_list = []
        for i in range(len(main_stream) - 1):
            # Downsample temporally (stride 2) whenever the channel count changes.
            if main_stream[i] == main_stream[i + 1]:
                stride = 1
            else:
                stride = 2
            if i == 0:
                # The first main-stream module fuses the concatenated branches,
                # hence the input_num-times-wider input.
                main_stream_list.append(ResGCNModule(main_stream[i] * input_num, main_stream[i + 1], block, graph, stride=1, reduction=reduction, get_res=True, is_main=True))
            else:
                main_stream_list.append(ResGCNModule(main_stream[i], main_stream[i + 1], block, graph, stride=stride, reduction=reduction, is_main=True))
        self.backbone = nn.ModuleList(main_stream_list)
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        # NOTE(fix): the output width was hard-coded to 256; derive it from
        # main_stream[-1] so other channel configurations work unchanged
        # (identical behavior for the shipped configs where main_stream[-1] == 256).
        self.fcn = nn.Linear(main_stream[-1], num_class)
    def forward(self, x):
        # input branch: x is [N, input_num, C, T, V]; run each branch and
        # concatenate the outputs along the channel axis.
        x_cat = []
        for i, branch in enumerate(self.head):
            x_cat.append(branch(x[:, i]))
        x = torch.cat(x_cat, dim=1)
        # main stream
        for layer in self.backbone:
            x = layer(x, self.graph)
        # output: global average pool to [N, C, 1, 1], flatten, project.
        x = self.global_pooling(x)
        x = x.squeeze(-1).squeeze(-1)
        x = self.fcn(x)
        return x
+2 -2
View File
@@ -144,6 +144,7 @@ class BaseModel(MetaModel, nn.Module):
self.build_network(cfgs['model_cfg'])
self.init_parameters()
self.seq_trfs = get_transform(self.engine_cfg['transform'])
self.msg_mgr.log_info(cfgs['data_cfg'])
if training:
@@ -299,8 +300,7 @@ class BaseModel(MetaModel, nn.Module):
tuple: training data including inputs, labels, and some meta data.
"""
seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
trf_cfgs = self.engine_cfg['transform']
seq_trfs = get_transform(trf_cfgs)
seq_trfs = self.seq_trfs
if len(seqs_batch) != len(seq_trfs):
raise ValueError(
"The number of types of input data and transform should be same. But got {} and {}".format(len(seqs_batch), len(seq_trfs)))
+107
View File
@@ -0,0 +1,107 @@
'''
Modified from https://github.com/BNU-IVC/FastPoseGait/blob/main/fastposegait/modeling/losses/supconloss.py
'''
import torch.nn as nn
import torch
from .base import BaseLoss, gather_and_scale_wrapper
class SupConLoss_Re(BaseLoss):
    """Framework wrapper exposing SupConLoss through the BaseLoss API."""
    def __init__(self, temperature=0.01):
        super(SupConLoss_Re, self).__init__()
        # Underlying supervised-contrastive criterion.
        self.train_loss = SupConLoss(temperature=temperature)
    @gather_and_scale_wrapper
    def forward(self, features, labels=None, mask=None):
        # Delegate to the wrapped loss and record the detached scalar
        # for logging before returning.
        value = self.train_loss(features, labels)
        self.info.update({'loss': value.detach().clone()})
        return value, self.info
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.

    Also covers the unsupervised SimCLR contrastive loss
    (https://arxiv.org/pdf/2002.05709.pdf) when neither labels nor a mask
    is supplied.
    """
    def __init__(self, temperature=0.01, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature
    def forward(self, features, labels=None, mask=None):
        """Compute the (supervised) contrastive loss.

        Args:
            features: hidden vectors of shape [bsz, n_views, ...].
            labels: ground-truth labels of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz]; mask_{i,j}=1 iff
                sample j shares the class of sample i. Mutually exclusive
                with `labels`.
        Returns:
            A scalar loss tensor.
        """
        if features.is_cuda:
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')
        if features.dim() < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if features.dim() > 3:
            # Flatten any trailing feature dimensions.
            features = features.view(features.shape[0], features.shape[1], -1)
        batch_size = features.shape[0]
        # Build the positive-pair mask from labels unless given explicitly.
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        if labels is None and mask is None:
            # SimCLR: only other views of the same sample count as positives.
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)
        n_views = features.shape[1]
        # Stack all views into the batch axis: [bsz * n_views, dim].
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature, anchor_count = features[:, 0], 1
        elif self.contrast_mode == 'all':
            anchor_feature, anchor_count = contrast_feature, n_views
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
        # Scaled pairwise similarities, shifted by the row max for
        # numerical stability.
        anchor_dot_contrast = torch.matmul(anchor_feature, contrast_feature.T) / self.temperature
        row_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - row_max.detach()
        # Expand the mask to all view pairs and zero out self-comparisons.
        mask = mask.repeat(anchor_count, n_views)
        diag_idx = torch.arange(batch_size * anchor_count).view(-1, 1).to(device)
        logits_mask = torch.scatter(torch.ones_like(mask), 1, diag_idx, 0)
        mask = mask * logits_mask
        # Log-probabilities with self-contrast excluded from the denominator.
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
        # Mean log-likelihood over the positives of each anchor.
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        return loss.view(anchor_count, batch_size).mean()
+19
View File
@@ -0,0 +1,19 @@
'''
Modified from https://github.com/BNU-IVC/FastPoseGait/blob/main/fastposegait/modeling/losses/supconloss_Lp.py
'''
from .base import BaseLoss, gather_and_scale_wrapper
from pytorch_metric_learning import losses, distances
class SupConLoss_Lp(BaseLoss):
    """SupCon loss from pytorch-metric-learning with an Lp distance metric."""
    def __init__(self, temperature=0.01):
        super(SupConLoss_Lp, self).__init__()
        # Use an Lp (Euclidean by default) distance instead of cosine similarity.
        lp_distance = distances.LpDistance()
        self.distance = lp_distance
        self.train_loss = losses.SupConLoss(temperature=temperature, distance=lp_distance)
    @gather_and_scale_wrapper
    def forward(self, features, labels=None, mask=None):
        # Compute the loss, record the detached value for logging, return both.
        value = self.train_loss(features, labels)
        self.info.update({'loss': value.detach().clone()})
        return value, self.info
+75
View File
@@ -0,0 +1,75 @@
import torch
from ..base_model import BaseModel
from ..backbones.resgcn import ResGCN
from ..modules import Graph
import torch.nn.functional as F
class GaitGraph1(BaseModel):
    """
    GaitGraph1: Gaitgraph: Graph Convolutional Network for Skeleton-Based Gait Recognition
    Paper: https://ieeexplore.ieee.org/document/9506717
    Github: https://github.com/tteepe/GaitGraph
    """
    def build_network(self, model_cfg):
        # Model hyper-parameters from the config.
        self.joint_format = model_cfg['joint_format']
        self.input_num = model_cfg['input_num']
        self.block = model_cfg['block']
        self.input_branch = model_cfg['input_branch']
        self.main_stream = model_cfg['main_stream']
        self.num_class = model_cfg['num_class']
        self.reduction = model_cfg['reduction']
        self.tta = model_cfg['tta']  # test-time augmentation flag
        ## Graph Init ##
        self.graph = Graph(joint_format=self.joint_format, max_hop=3)
        self.A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
        ## Network ##
        self.ResGCN = ResGCN(input_num=self.input_num, input_branch=self.input_branch,
                             main_stream=self.main_stream, num_class=self.num_class,
                             reduction=self.reduction, block=self.block, graph=self.A)
    def forward(self, inputs):
        ipts, labs, type_, view_, seqL = inputs
        x_input = ipts[0]  # N T C V I
        # x = N, T, C, V, M -> N, C, T, V, M
        x_input = x_input.permute(0, 2, 3, 4, 1).contiguous()
        N, T, V, I, C = x_input.size()
        pose = x_input
        if self.training:
            # Two "views" per sequence: the temporal halves, stacked on batch.
            x_input = torch.cat([x_input[:, :int(T/2), ...], x_input[:, int(T/2):, ...]], dim=0)  # [8, 60, 17, 1, 3]
        elif self.tta:
            # Test-time augmentation: also run the temporally flipped sequence.
            data_flipped = torch.flip(x_input, dims=[1])
            x_input = torch.cat([x_input, data_flipped], dim=0)
        x = x_input.permute(0, 3, 4, 1, 2).contiguous()
        # resgcn
        x = self.ResGCN(x)
        x = F.normalize(x, dim=1, p=2)  # norm #only for GaitGraph1 # Remove from GaitGraph2
        if self.training:
            f1, f2 = torch.split(x, [N, N], dim=0)
            embed = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)  # [4, 2, 128]
        elif self.tta:
            # Average the embeddings of the original and flipped sequences.
            f1, f2 = torch.split(x, [N, N], dim=0)
            embed = torch.mean(torch.stack([f1, f2]), dim=0)
            embed = embed.unsqueeze(-1)
        else:
            # NOTE(fix): this branch previously read `embed` before assignment
            # (NameError during plain evaluation with tta disabled); the
            # embedding is the normalized network output itself.
            embed = x.unsqueeze(-1)
        retval = {
            'training_feat': {
                'SupConLoss': {'features': embed, 'labels': labs},  # loss
            },
            'visual_summary': {
                'image/pose': pose.view(N*T, 1, I*V, C).contiguous()  # visualization
            },
            'inference_feat': {
                'embeddings': embed  # for metric
            }
        }
        return retval
+110
View File
@@ -0,0 +1,110 @@
import torch
import torch.nn as nn
from ..base_model import BaseModel
from ..backbones.resgcn import ResGCN
from ..modules import Graph
import numpy as np
class GaitGraph2(BaseModel):
    """
    GaitGraph2: Towards a Deeper Understanding of Skeleton-based Gait Recognition
    Paper: https://openaccess.thecvf.com/content/CVPR2022W/Biometrics/papers/Teepe_Towards_a_Deeper_Understanding_of_Skeleton-Based_Gait_Recognition_CVPRW_2022_paper
    Github: https://github.com/tteepe/GaitGraph2
    """
    def build_network(self, model_cfg):
        # Model hyper-parameters from the config.
        self.joint_format = model_cfg['joint_format']
        self.input_num = model_cfg['input_num']
        self.block = model_cfg['block']
        self.input_branch = model_cfg['input_branch']
        self.main_stream = model_cfg['main_stream']
        self.num_class = model_cfg['num_class']
        self.reduction = model_cfg['reduction']
        self.tta = model_cfg['tta']  # test-time augmentation flag
        ## Graph Init ##
        self.graph = Graph(joint_format=self.joint_format, max_hop=3)
        self.A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
        ## Network ##
        self.ResGCN = ResGCN(input_num=self.input_num, input_branch=self.input_branch,
                             main_stream=self.main_stream, num_class=self.num_class,
                             reduction=self.reduction, block=self.block, graph=self.A)
    def forward(self, inputs):
        ipts, labs, type_, view_, seqL = inputs
        x_input = ipts[0]
        # N: batch, T: frames, V: joints, I: input branches, C: channels per branch
        N, T, V, I, C = x_input.size()
        pose = x_input
        flip_idx = self.graph.flip_idx
        if not self.training and self.tta:
            # Test-time augmentation: append a time-reversed copy and a
            # left/right-mirrored copy of every sequence, tripling the batch.
            # Both copies are re-expanded through MultiInput from the raw
            # coordinates — assumes branch 0, channels :3, holds (x, y, conf)
            # and that I == 3, C == 5 so shapes line up; TODO confirm.
            multi_input = MultiInput(self.graph.connect_joint, self.graph.center)
            x1 = []
            x2 = []
            for i in range(N):
                x1.append(multi_input(x_input[i, :, :, 0, :3].flip(0)))
                x2.append(multi_input(x_input[i, :, flip_idx, 0, :3]))
            x_input = torch.cat([x_input, torch.stack(x1, 0), torch.stack(x2, 0)], dim=0)
        # N, T, V, I, C -> N, I, C, T, V for the GCN backbone.
        x = x_input.permute(0, 3, 4, 1, 2).contiguous()
        # resgcn
        x = self.ResGCN(x)
        if not self.training and self.tta:
            # Concatenate the three augmented embeddings of each sample.
            f1, f2, f3 = torch.split(x, [N, N, N], dim=0)
            x = torch.cat((f1, f2, f3), dim=1)
        embed = torch.unsqueeze(x, -1)
        retval = {
            'training_feat': {
                'SupConLoss': {'features': x, 'labels': labs},  # loss
            },
            'visual_summary': {
                'image/pose': pose.view(N*T, 1, I*V, C).contiguous()  # visualization
            },
            'inference_feat': {
                'embeddings': embed  # for metric
            }
        }
        return retval
class MultiInput:
    """Expand a raw pose sequence into joint / velocity / bone branches.

    Input:  data of shape (T, V, C) — per-joint coordinates (+ confidence).
    Output: tensor of shape (T, V, 3, C + 2), where index 0/1/2 of dim 2
            selects the joint, velocity and bone branch respectively.
    Note: the fixed channel slots assume C == 3 (x, y, confidence).
    """
    def __init__(self, connect_joint, center):
        self.connect_joint = connect_joint  # parent joint index for each joint
        self.center = center                # index of the body-center joint
    def __call__(self, data):
        # T, V, C -> T, V, I=3, C + 2
        T, V, C = data.shape
        out = torch.zeros((T, V, 3, C + 2), device=data.device)
        # --- joints: raw values plus (x, y) offset from the center joint ---
        out[:, :, 0, :C] = data
        out[:, :, 0, C:] = data[:, :, :2] - data[:, self.center, :2].unsqueeze(1)
        # --- velocities: one- and two-frame forward differences
        #     (the last two frames stay zero) ---
        out[:T - 2, :, 1, :2] = data[1:T - 1, :, :2] - data[:T - 2, :, :2]
        out[:T - 2, :, 1, 3:] = data[2:, :, :2] - data[:T - 2, :, :2]
        out[:, :, 1, 3] = data[:, :, 2]  # confidence overwrites slot 3
        # --- bones: (x, y) offset to the connected (parent) joint ---
        out[:, :, 2, :2] = data[:, :, :2] - data[:, self.connect_joint, :2]
        # Bone length with a small epsilon to avoid division by zero.
        bone_length = torch.sqrt((out[:, :, 2, :C - 1] ** 2).sum(dim=-1)) + 0.0001
        # Angles of the bone vector components against the bone length.
        out[:, :, 2, C:2 * C - 1] = torch.acos(out[:, :, 2, :C - 1] / bone_length.unsqueeze(-1))
        out[:, :, 2, 3] = data[:, :, 2]  # confidence again overwrites slot 3
        return out
+186
View File
@@ -0,0 +1,186 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..base_model import BaseModel
from ..modules import Graph, SpatialAttention
import numpy as np
import math
class Mish(nn.Module):
    """Mish activation: x * tanh(softplus(x)) (https://arxiv.org/abs/1908.08681)."""
    def __init__(self):
        super().__init__()
    def forward(self, x):
        # softplus keeps the tanh gate smooth and strictly positive.
        return x.mul(torch.tanh(F.softplus(x)))
class STModule(nn.Module):
    """Augmented graph spatial convolution used by the Spatial Transformer.

    Function adapted from:
    https://github.com/Chiaraplizz/ST-TR/blob/master/code/st_gcn/net/gcn_attention.py

    Args:
        in_channels (int): input feature channels.
        out_channels (int): output feature channels.
        incidence (Tensor): graph adjacency/incidence matrix.
        num_point (int): number of skeleton joints (V).
    """
    def __init__(self, in_channels, out_channels, incidence, num_point):
        super(STModule, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.incidence = incidence
        self.num_point = num_point
        self.relu = Mish()
        self.bn = nn.BatchNorm2d(out_channels)
        # Normalizes over the flattened (C * V) features of each frame.
        self.data_bn = nn.BatchNorm1d(self.in_channels * self.num_point)
        self.attention_conv = SpatialAttention(in_channels=in_channels, out_channel=out_channels, A=self.incidence, num_point=self.num_point)
    def forward(self, x):
        # x: (N, C, T, V)
        N, C, T, V = x.size()
        # data normalization across channels and joints, per frame
        x = x.permute(0, 1, 3, 2).reshape(N, C * V, T)
        x = self.data_bn(x)
        x = x.reshape(N, C, V, T).permute(0, 1, 3, 2)
        # adjacency matrix; assumes CUDA execution (get_device() fails on CPU)
        self.incidence = self.incidence.cuda(x.get_device())
        # N, T, C, V > NT, C, 1, V — each frame becomes one attention input
        xa = x.permute(0, 2, 1, 3).reshape(-1, C, 1, V)
        # spatial attention
        attn_out = self.attention_conv(xa)
        # N, T, C, V > N, C, T, V
        attn_out = attn_out.reshape(N, T, -1, V).permute(0, 2, 1, 3)
        y = attn_out
        # NOTE(review): activation is applied before batch-norm here
        # (bn(relu(y))), the reverse of the usual order — presumably
        # intentional, matching the upstream implementation.
        y = self.bn(self.relu(y))
        return y
class UnitConv2D(nn.Module):
    '''
    Temporal unit convolution used in the GaitTR [TCN_ST] block:
    dropout -> conv over the time axis only (kernel (k, 1)) -> Mish -> BN.
    '''
    def __init__(self, D_in, D_out, kernel_size=9, stride=1, dropout=0.1, bias=True):
        super(UnitConv2D, self).__init__()
        # "same" padding along time for odd kernel sizes.
        pad = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(D_in, D_out, kernel_size=(kernel_size, 1),
                              padding=(pad, 0), stride=(stride, 1), bias=bias)
        self.bn = nn.BatchNorm2d(D_out)
        self.relu = Mish()
        self.dropout = nn.Dropout(dropout, inplace=False)
        # He-style initialization of the conv weights.
        self.conv_init(self.conv)
    def forward(self, x):
        out = self.dropout(x)
        out = self.conv(out)
        # NOTE: activation precedes batch-norm in this architecture.
        return self.bn(self.relu(out))
    def conv_init(self, module):
        # fan-out = out_channels * prod(kernel dims); weights ~ N(0, sqrt(2/fan)).
        fan = module.out_channels
        for k in module.kernel_size:
            fan = fan * k
        module.weight.data.normal_(0, math.sqrt(2. / fan))
class TCN_ST(nn.Module):
    """
    Block of GaitTR: https://arxiv.org/pdf/2204.03873.pdf
    TCN: Temporal Convolution Network
    ST: Spatial Temporal Graph Convolution Network
    """
    def __init__(self, in_channel, out_channel, A, num_point):
        super(TCN_ST, self).__init__()
        # params
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.A = A
        self.num_point = num_point
        # network: the temporal conv keeps the channel count; the spatial
        # transformer block may change it.
        self.tcn = UnitConv2D(D_in=in_channel, D_out=in_channel, kernel_size=9)
        self.st = STModule(in_channels=in_channel, out_channels=out_channel,
                           incidence=A, num_point=num_point)
        self.residual = lambda x: x
        if in_channel != out_channel:
            # Project both residual paths to the new channel count.
            self.residual_s = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, 1),
                nn.BatchNorm2d(out_channel),
            )
            self.down = UnitConv2D(D_in=in_channel, D_out=out_channel,
                                   kernel_size=1, dropout=0)
        else:
            self.residual_s = lambda x: x
            self.down = None
    def forward(self, x):
        # Temporal conv with identity residual, then spatial transformer
        # with its own (possibly projected) residual.
        x0 = self.tcn(x) + self.residual(x)
        y = self.st(x0) + self.residual_s(x0)
        # skip residual from the block input
        y = y + (x if self.down is None else self.down(x))
        return y
class GaitTR(BaseModel):
    """
    GaitTR: Spatial Transformer Network on Skeleton-based Gait Recognition
    Arxiv : https://arxiv.org/abs/2204.03873.pdf
    """
    def build_network(self, model_cfg):
        # in_channels is the list of per-stage channel widths.
        in_c = model_cfg['in_channels']
        self.num_class = model_cfg['num_class']
        self.joint_format = model_cfg['joint_format']
        self.graph = Graph(joint_format=self.joint_format, max_hop=3)
        #### Network Define ####
        # adjacency matrix
        self.A = torch.from_numpy(self.graph.A.astype(np.float32))
        # data normalization over the flattened (C * V) input features
        num_point = self.A.shape[-1]
        self.data_bn = nn.BatchNorm1d(in_c[0] * num_point)
        # backbone: one TCN_ST block per consecutive channel pair
        backbone = []
        for i in range(len(in_c)-1):
            backbone.append(TCN_ST(in_channel=in_c[i], out_channel=in_c[i+1], A=self.A, num_point=num_point))
        self.backbone = nn.ModuleList(backbone)
        # 1x1 conv maps the final channels to the class/embedding dimension
        self.fcn = nn.Conv1d(in_c[-1], self.num_class, kernel_size=1)
    def forward(self, inputs):
        ipts, labs, _, _, seqL = inputs
        x = ipts[0]
        pose = x
        # x = N, T, C, V, M -> N, C, T, V, M
        x = x.permute(0, 2, 1, 3, 4)
        N, C, T, V, M = x.size()
        # NOTE(review): dead branch — x always has 5 dims after the unpack above.
        if len(x.size()) == 4:
            x = x.unsqueeze(1)
        del ipts
        # Fold bodies/joints/channels into one axis for BatchNorm1d.
        x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
        x = self.data_bn(x)
        # Unfold back and merge bodies into the batch axis: (N*M, C, T, V).
        x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(
            N * M, C, T, V)
        # backbone
        for _, m in enumerate(self.backbone):
            x = m(x)
        # V pooling: average over joints
        x = F.avg_pool2d(x, kernel_size=(1, V))
        # M pooling: average over bodies
        c = x.size(1)
        t = x.size(2)
        x = x.view(N, M, c, t).mean(dim=1).view(N, c, t)  # [n,c,t]
        # T pooling: average over time
        x = F.avg_pool1d(x, kernel_size=x.size()[2])  # [n,c]
        # C fcn: project to class dimension
        x = self.fcn(x)  # [n,c']
        # NOTE(review): the time axis is already length 1 here, so this pool
        # is a no-op kept from the upstream implementation.
        x = F.avg_pool1d(x, x.size()[2:])  # [n,c']
        x = x.view(N, self.num_class)  # n,c
        embed = x.unsqueeze(-1)  # n,c,1
        retval = {
            'training_feat': {
                'triplet': {'embeddings': embed, 'labels': labs}
            },
            'visual_summary': {
                'image/pose': pose.view(N*T, M, V, C)
            },
            'inference_feat': {
                'embeddings': embed
            }
        }
        return retval
+484
View File
@@ -0,0 +1,484 @@
import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..base_model import BaseModel
class MultiScaleGaitGraph(BaseModel):
    """
    Learning Rich Features for Gait Recognition by Integrating Skeletons and Silhouettes
    Github: https://github.com/YunjiePeng/BimodalFusion

    Runs three ST-GCN streams over the same skeleton sequence at three
    semantic levels (joints / limbs / body-parts).  The coarser levels are
    produced by `semantic_pooling`, which halves the node set by averaging
    node pairs; after every layer the finer stream's pooled output is added
    into the coarser stream (cross-scale message passing).
    """
    def build_network(self, model_cfg):
        in_c = model_cfg['in_channels']
        out_c = model_cfg['out_channels']
        num_id = model_cfg['num_id']
        temporal_kernel_size = model_cfg['temporal_kernel_size']
        # load spatial graph: one adjacency stack per semantic level
        self.graph = SpatialGraph(**model_cfg['graph_cfg'])
        A_lowSemantic = torch.tensor(self.graph.get_adjacency(semantic_level=0), dtype=torch.float32, requires_grad=False)
        A_mediumSemantic = torch.tensor(self.graph.get_adjacency(semantic_level=1), dtype=torch.float32, requires_grad=False)
        A_highSemantic = torch.tensor(self.graph.get_adjacency(semantic_level=2), dtype=torch.float32, requires_grad=False)
        self.register_buffer('A_lowSemantic', A_lowSemantic)
        self.register_buffer('A_mediumSemantic', A_mediumSemantic)
        self.register_buffer('A_highSemantic', A_highSemantic)
        # build networks; spatial kernel size = number of adjacency matrices
        spatial_kernel_size = self.graph.num_A
        temporal_kernel_size = temporal_kernel_size
        kernel_size = (temporal_kernel_size, spatial_kernel_size)
        self.st_gcn_networks_lowSemantic = nn.ModuleList()
        self.st_gcn_networks_mediumSemantic = nn.ModuleList()
        self.st_gcn_networks_highSemantic = nn.ModuleList()
        # One st_gcn_block per channel transition; the first block of each
        # stream has no residual.  A final width-preserving block is appended
        # to every stream after the loop.
        for i in range(len(in_c)-1):
            if i == 0:
                self.st_gcn_networks_lowSemantic.append(st_gcn_block(in_c[i], in_c[i+1], kernel_size, 1, residual=False))
                self.st_gcn_networks_mediumSemantic.append(st_gcn_block(in_c[i], in_c[i+1], kernel_size, 1, residual=False))
                self.st_gcn_networks_highSemantic.append(st_gcn_block(in_c[i], in_c[i+1], kernel_size, 1, residual=False))
            else:
                self.st_gcn_networks_lowSemantic.append(st_gcn_block(in_c[i], in_c[i+1], kernel_size, 1))
                self.st_gcn_networks_mediumSemantic.append(st_gcn_block(in_c[i], in_c[i+1], kernel_size, 1))
                self.st_gcn_networks_highSemantic.append(st_gcn_block(in_c[i], in_c[i+1], kernel_size, 1))
        self.st_gcn_networks_lowSemantic.append(st_gcn_block(in_c[i+1], in_c[i+1], kernel_size, 1))
        self.st_gcn_networks_mediumSemantic.append(st_gcn_block(in_c[i+1], in_c[i+1], kernel_size, 1))
        self.st_gcn_networks_highSemantic.append(st_gcn_block(in_c[i+1], in_c[i+1], kernel_size, 1))
        # Learnable per-edge importance weights, one tensor per layer per stream.
        self.edge_importance_lowSemantic = nn.ParameterList([
            nn.Parameter(torch.ones(self.A_lowSemantic.size()))
            for i in self.st_gcn_networks_lowSemantic])
        self.edge_importance_mediumSemantic = nn.ParameterList([
            nn.Parameter(torch.ones(self.A_mediumSemantic.size()))
            for i in self.st_gcn_networks_mediumSemantic])
        self.edge_importance_highSemantic = nn.ParameterList([
            nn.Parameter(torch.ones(self.A_highSemantic.size()))
            for i in self.st_gcn_networks_highSemantic])
        # Embedding head on the coarsest stream + BNNeck classifier.
        self.fc = nn.Linear(in_c[-1], out_c)
        self.bn_neck = nn.BatchNorm1d(out_c)
        self.encoder_cls = nn.Linear(out_c, num_id, bias=False)
    def semantic_pooling(self, x):
        # Halve the node axis by averaging the first and second half of the
        # nodes pairwise: (..., V) -> (..., V/2).
        cur_node_num = x.size()[-1]
        half_x_1, half_x_2 = torch.split(x, int(cur_node_num / 2), dim=-1)
        x_sp = torch.add(half_x_1, half_x_2) / 2
        return x_sp
    def forward(self, inputs):
        ipts, labs, _, _, seqL = inputs
        x = ipts[0]  # [N, T, V, C]
        del ipts
        """
        N - the number of videos.
        T - the number of frames in one video.
        V - the number of keypoints.
        C - the number of features for one keypoint.
        """
        N, T, V, C = x.size()
        x = x.permute(0, 3, 1, 2).contiguous()
        x = x.view(N, C, T, V)
        # Coarser-level inputs via repeated semantic pooling.
        y = self.semantic_pooling(x)
        z = self.semantic_pooling(y)
        for gcn_lowSemantic, importance_lowSemantic, gcn_mediumSemantic, importance_mediumSemantic, gcn_highSemantic, importance_highSemantic in zip(self.st_gcn_networks_lowSemantic, self.edge_importance_lowSemantic, self.st_gcn_networks_mediumSemantic, self.edge_importance_mediumSemantic, self.st_gcn_networks_highSemantic, self.edge_importance_highSemantic):
            x, _ = gcn_lowSemantic(x, self.A_lowSemantic * importance_lowSemantic)
            y, _ = gcn_mediumSemantic(y, self.A_mediumSemantic * importance_mediumSemantic)
            z, _ = gcn_highSemantic(z, self.A_highSemantic * importance_highSemantic)
            # Cross-scale Message Passing: add the pooled finer stream into
            # the next coarser stream after every layer.
            x_sp = self.semantic_pooling(x)
            y = torch.add(y, x_sp)
            y_sp = self.semantic_pooling(y)
            z = torch.add(z, y_sp)
        # global pooling for each layer (over the full T x V grid)
        x_sp = F.avg_pool2d(x, x.size()[2:])
        N, C, T, V = x_sp.size()
        x_sp = x_sp.view(N, C, T*V).contiguous()
        y_sp = F.avg_pool2d(y, y.size()[2:])
        N, C, T, V = y_sp.size()
        y_sp = y_sp.view(N, C, T*V).contiguous()
        z = F.avg_pool2d(z, z.size()[2:])
        N, C, T, V = z.size()
        z = z.permute(0, 2, 3, 1).contiguous()
        z = z.view(N, T*V, C)
        # Embedding + BNNeck classification logits on the coarsest stream.
        z_fc = self.fc(z.view(N, -1))
        bn_z_fc = self.bn_neck(z_fc)
        z_cls_score = self.encoder_cls(bn_z_fc)
        z_fc = z_fc.unsqueeze(-1).contiguous()  # [n, c, p]
        z_cls_score = z_cls_score.unsqueeze(-1).contiguous()  # [n, c, p]
        retval = {
            'training_feat': {
                'triplet_joints': {'embeddings': x_sp, 'labels': labs},
                'triplet_limbs': {'embeddings': y_sp, 'labels': labs},
                'triplet_bodyparts': {'embeddings': z_fc, 'labels': labs},
                'softmax': {'logits': z_cls_score, 'labels': labs}
            },
            'visual_summary': {},
            'inference_feat': {
                'embeddings': z_fc
            }
        }
        return retval
class st_gcn_block(nn.Module):
    r"""Spatial-temporal graph convolution block.

    A graph convolution (SCN) followed by a temporal convolution, wrapped
    with a residual connection and a final ReLU.

    Args:
        in_channels (int): channels of the input sequence.
        out_channels (int): channels produced by the block.
        kernel_size (tuple): (temporal kernel size, graph kernel size).
        stride (int, optional): temporal stride. Default: 1.
        dropout (int, optional): dropout rate on the final output. Default: 0.
        residual (bool, optional): enable the residual path. Default: ``True``.

    Shape:
        - Input[0]: graph sequence, :math:`(N, in_channels, T_{in}, V)`
        - Input[1]: adjacency, :math:`(K, V, V)` with :math:`K == kernel_size[1]`
        - Output[0]: graph sequence, :math:`(N, out_channels, T_{out}, V)`
        - Output[1]: adjacency, returned unchanged
        where :math:`N` is the batch size, :math:`T` the frame count and
        :math:`V` the number of graph nodes.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dropout=0,
                 residual=True):
        super().__init__()
        assert len(kernel_size) == 2
        assert kernel_size[0] % 2 == 1  # odd temporal kernel -> symmetric pad
        pad_t = ((kernel_size[0] - 1) // 2, 0)
        self.gcn = SCN(in_channels, out_channels, kernel_size[1])
        self.tcn = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                out_channels,
                (kernel_size[0], 1),
                (stride, 1),
                pad_t,
            ),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout, inplace=True),
        )
        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            # 1x1 conv matches channels/stride on the residual path.
            self.residual = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=(stride, 1)),
                nn.BatchNorm2d(out_channels),
            )
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x, A):
        shortcut = self.residual(x)
        x, A = self.gcn(x, A)
        out = self.tcn(x) + shortcut
        return self.relu(out), A
class SCN(nn.Module):
    r"""Spatial graph convolution over one frame at a time.

    A 1x1 (temporal) convolution expands the channels into ``kernel_size``
    groups — one weight set per adjacency matrix / partition subset — and
    the groups are then mixed across graph nodes with the adjacency stack
    via einsum.  Weights are shared within a subset, which is why a 1x1
    convolution implements the weight function W.

    Args:
        in_channels (int): channels of the input sequence.
        out_channels (int): channels produced by the convolution.
        kernel_size (int): number of adjacency matrices (graph kernel size).
        t_kernel_size (int): temporal kernel size (always 1 here).
        t_stride (int, optional): temporal stride. Default: 1.
        t_padding (int, optional): temporal zero-padding. Default: 0.
        t_dilation (int, optional): temporal dilation. Default: 1.
        bias (bool, optional): add a learnable bias. Default: ``True``.

    Shape:
        - Input[0]: graph sequence, :math:`(N, in_channels, T_{in}, V)`
        - Input[1]: adjacency, :math:`(K, V, V)` with :math:`K == kernel_size`
        - Output[0]: graph sequence, :math:`(N, out_channels, T_{out}, V)`
        - Output[1]: adjacency, returned unchanged
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 t_kernel_size=1,
                 t_stride=1,
                 t_padding=0,
                 t_dilation=1,
                 bias=True):
        super().__init__()
        self.kernel_size = kernel_size
        self.conv = nn.Conv2d(in_channels,
                              out_channels * kernel_size,
                              kernel_size=(t_kernel_size, 1),
                              padding=(t_padding, 0),
                              stride=(t_stride, 1),
                              dilation=(t_dilation, 1),
                              bias=bias)
    def forward(self, x, A):
        assert A.size(0) == self.kernel_size
        feat = self.conv(x)
        n, kc, t, v = feat.size()
        # Split channels into K groups and contract each with its adjacency.
        feat = feat.view(n, self.kernel_size, kc // self.kernel_size, t, v)
        out = torch.einsum('nkctv,kvw->nctw', (feat, A))
        return out.contiguous(), A
class SpatialGraph():
""" Use skeleton sequences extracted by Openpose/HRNet to construct Spatial-Temporal Graph
Args:
strategy (string): must be one of the follow candidates
- uniform: Uniform Labeling
- distance: Distance Partitioning
- spatial: Spatial Configuration Partitioning
- gait_temporal: Gait Temporal Configuration Partitioning
For more information, please refer to the section 'Partition Strategies' in PGG.
layout (string): must be one of the follow candidates
- body_12: Is consists of 12 joints.
(right shoulder, right elbow, right knee, right hip, left elbow, left knee,
left shoulder, right wrist, right ankle, left hip, left wrist, left ankle).
For more information, please refer to the section 'Data Processing' in PGG.
max_hop (int): the maximal distance between two connected nodes # 1-neighbor
dilation (int): controls the spacing between the kernel points
"""
def __init__(self,
layout='body_12', # Openpose here represents for body_12
strategy='spatial',
semantic_level=0,
max_hop=1,
dilation=1):
self.layout = layout
self.strategy = strategy
self.max_hop = max_hop
self.dilation = dilation
self.num_node, self.neighbor_link_dic = self.get_layout_info(layout)
self.num_A = self.get_A_num(strategy)
def __str__(self):
return self.A
def get_A_num(self, strategy):
if self.strategy == 'uniform':
return 1
elif self.strategy == 'distance':
return 2
elif (self.strategy == 'spatial') or (self.strategy == 'gait_temporal'):
return 3
else:
raise ValueError("Do Not Exist This Strategy")
def get_layout_info(self, layout):
if layout == 'body_12':
num_node = 12
neighbor_link_dic = {
0: [(7, 1), (1, 0), (10, 4), (4, 6),
(8, 2), (2, 3), (11, 5), (5, 9),
(9, 3), (3, 0), (9, 6), (6, 0)],
1: [(1, 0), (4, 0), (0, 3), (2, 3), (5, 3)],
2: [(1, 0), (2, 0)]
}
return num_node, neighbor_link_dic
else:
raise ValueError("Do Not Exist This Layout.")
def get_edge(self, semantic_level):
# edge is a list of [child, parent] pairs, regarding the center node as root node
self_link = [(i, i) for i in range(int(self.num_node / (2 ** semantic_level)))]
neighbor_link = self.neighbor_link_dic[semantic_level]
edge = self_link + neighbor_link
center = []
if self.layout == 'body_12':
if semantic_level == 0:
center = [0, 3, 6, 9]
elif semantic_level == 1:
center = [0, 3]
elif semantic_level == 2:
center = [0]
return edge, center
def get_gait_temporal_partitioning(self, semantic_level):
if semantic_level == 0:
if self.layout == 'body_12':
positive_node = {1, 2, 4, 5, 7, 8, 10, 11}
negative_node = {0, 3, 6, 9}
elif semantic_level == 1:
if self.layout == 'body_12':
positive_node = {1, 2, 4, 5}
negative_node = {0, 3}
elif semantic_level == 2:
if self.layout == 'body_12':
positive_node = {1, 2}
negative_node = {0}
return positive_node, negative_node
    def get_adjacency(self, semantic_level):
        """Build the stacked, normalized adjacency tensor for one level.

        Args:
            semantic_level (int): pooling level; the node count is
                ``num_node / 2**semantic_level``.

        Returns:
            np.ndarray: shape ``(num_A, N, N)``, one matrix per partition
            as determined by ``self.strategy`` (see ``get_A_num``).

        Raises:
            ValueError: for an unknown strategy.
        """
        edge, center = self.get_edge(semantic_level)
        # Node count shrinks by half per semantic level.
        num_node = int(self.num_node / (2 ** semantic_level))
        hop_dis = get_hop_distance(num_node, edge, max_hop=self.max_hop)
        valid_hop = range(0, self.max_hop + 1, self.dilation)
        # Binary adjacency over all valid hops, then in-degree normalization.
        adjacency = np.zeros((num_node, num_node))
        for hop in valid_hop:
            adjacency[hop_dis == hop] = 1
        normalize_adjacency = normalize_digraph(adjacency)
        # normalize_adjacency = adjacency # withoutNodeNorm
        # normalize_adjacency[a][b] = x
        # when x = 0, node b has no connection with node a within valid hop.
        # when x ≠ 0, the normalized adjacency from node b to node a is x.
        # the value of x is normalized by the number of adjacent neighbor nodes around the node b.
        if self.strategy == 'uniform':
            # Single partition: all valid hops merged into one matrix.
            A = np.zeros((1, num_node, num_node))
            A[0] = normalize_adjacency
            return A
        elif self.strategy == 'distance':
            # One partition per hop distance.
            A = np.zeros((len(valid_hop), num_node, num_node))
            for i, hop in enumerate(valid_hop):
                A[i][hop_dis == hop] = normalize_adjacency[hop_dis == hop]
            return A
        elif self.strategy == 'spatial':
            # ST-GCN spatial partitioning: root / centripetal (closer to a
            # center) / centrifugal (farther from a center) per hop.
            A = []
            for hop in valid_hop:
                a_root = np.zeros((num_node, num_node))
                a_close = np.zeros((num_node, num_node))
                a_further = np.zeros((num_node, num_node))
                for i in range(num_node):
                    for j in range(num_node):
                        if hop_dis[j, i] == hop:
                            # Compare each node's distance to its nearest center joint.
                            j_hop_dis = min([hop_dis[j, _center] for _center in center])
                            i_hop_dis = min([hop_dis[i, _center] for _center in center])
                            if j_hop_dis == i_hop_dis:
                                a_root[j, i] = normalize_adjacency[j, i]
                            elif j_hop_dis > i_hop_dis:
                                a_close[j, i] = normalize_adjacency[j, i]
                            else:
                                a_further[j, i] = normalize_adjacency[j, i]
                if hop == 0:
                    A.append(a_root)
                else:
                    A.append(a_root + a_close)
                    A.append(a_further)
            A = np.stack(A)
            # NOTE(review): self.A is cached only for this branch, so __str__
            # works only after a 'spatial' call — confirm this is intended.
            self.A = A
            return A
        elif self.strategy == 'gait_temporal':
            # Gait-temporal partitioning: root (diagonal) / positive (limb
            # joints) / negative (torso joints) per hop.
            A = []
            positive_node, negative_node = self.get_gait_temporal_partitioning(semantic_level)
            for hop in valid_hop:
                a_root = np.zeros((num_node, num_node))
                a_positive = np.zeros((num_node, num_node))
                a_negative = np.zeros((num_node, num_node))
                for i in range(num_node):
                    for j in range(num_node):
                        if hop_dis[j, i] == hop:
                            if i == j:
                                a_root[j, i] = normalize_adjacency[j, i]
                            elif j in positive_node:
                                a_positive[j, i] = normalize_adjacency[j, i]
                            else:
                                a_negative[j, i] = normalize_adjacency[j, i]
                if hop == 0:
                    A.append(a_root)
                else:
                    A.append(a_negative)
                    A.append(a_positive)
            A = np.stack(A)
            return A
        else:
            raise ValueError("Do Not Exist This Strategy")
def get_hop_distance(num_node, edge, max_hop=1):
    """Shortest hop count between every pair of nodes.

    Args:
        num_node (int): number of graph nodes.
        edge (list[tuple]): undirected (i, j) links.
        max_hop (int): distances beyond this stay ``np.inf``.

    Returns:
        np.ndarray: (num_node, num_node) matrix of hop distances.
    """
    adj = np.zeros((num_node, num_node))
    for i, j in edge:
        # The graph is undirected: mark both directions.
        adj[i, j] = adj[j, i] = 1
    # adj**d has a positive (i, j) entry iff j is reachable from i in d steps.
    reachable = np.stack([np.linalg.matrix_power(adj, d)
                          for d in range(max_hop + 1)]) > 0
    hop_dis = np.full((num_node, num_node), np.inf)
    # Fill from the largest power down so the smallest reachable d wins.
    for d in range(max_hop, -1, -1):
        hop_dis[reachable[d]] = d
    return hop_dis
def normalize_digraph(A):
    """Column-normalize an adjacency matrix: ``A @ D^-1``.

    Each column j is divided by node j's in-degree, so the weight of an
    edge is shared among the neighbors of its source node. Columns with
    zero degree are left as zeros.
    """
    in_degree = A.sum(axis=0)
    inv_degree = np.zeros(A.shape[0])
    has_neighbors = in_degree > 0
    inv_degree[has_neighbors] = 1.0 / in_degree[has_neighbors]
    # Right-multiplying by diag(inv_degree) is a per-column scaling,
    # which broadcasting expresses directly.
    return A * inv_degree
def normalize_undigraph(A):
    """Symmetrically normalize an adjacency matrix: ``D^-1/2 A D^-1/2``.

    The standard normalization for undirected graphs; nodes with zero
    degree contribute zero rows/columns.
    """
    degree = A.sum(axis=0)
    inv_sqrt = np.zeros(A.shape[0])
    connected = degree > 0
    inv_sqrt[connected] = degree[connected] ** (-0.5)
    # diag(s) @ A @ diag(s) == outer row/column scaling via broadcasting.
    return inv_sqrt[:, None] * A * inv_sqrt[None, :]
+440
View File
@@ -253,3 +253,443 @@ def RmBN2dAffine(model):
if isinstance(m, nn.BatchNorm2d):
m.weight.requires_grad = False
m.bias.requires_grad = False
'''
Modified from https://github.com/BNU-IVC/FastPoseGait/blob/main/fastposegait/modeling/components/units
'''
class Graph():
    """
    Skeleton graph for pose-based models: joints are nodes, bones are
    edges, and the hop-partitioned, normalized adjacency tensor ``A`` is
    precomputed at construction time.

    # Thanks to YAN Sijie for the released code on Github (https://github.com/yysijie/st-gcn)
    """
    def __init__(self, joint_format='coco', max_hop=2, dilation=1):
        """
        Args:
            joint_format (str): 'coco', 'coco-no-head', 'alphapose' or
                'openpose'; selects the joint layout.
            max_hop (int): maximal graph distance still treated as connected.
            dilation (int): spacing between the hops kept in valid_hop.
        """
        self.joint_format = joint_format
        self.max_hop = max_hop
        self.dilation = dilation
        # get edges
        self.num_node, self.edge, self.connect_joint, self.parts = self._get_edge()
        # get adjacency matrix
        self.A = self._get_adjacency()

    def __str__(self):
        # Bug fix: __str__ must return a str; returning the ndarray itself
        # made print(graph) raise TypeError.
        return str(self.A)

    def _get_edge(self):
        """Return ``(num_node, edge, connect_joint, parts)`` for the layout.

        Also sets ``self.center`` (root joint index) and, where defined,
        ``self.flip_idx`` (left/right mirror permutation of the joints).
        """
        if self.joint_format == 'coco':
            # keypoints = {
            #     0: "nose",
            #     1: "left_eye",
            #     2: "right_eye",
            #     3: "left_ear",
            #     4: "right_ear",
            #     5: "left_shoulder",
            #     6: "right_shoulder",
            #     7: "left_elbow",
            #     8: "right_elbow",
            #     9: "left_wrist",
            #     10: "right_wrist",
            #     11: "left_hip",
            #     12: "right_hip",
            #     13: "left_knee",
            #     14: "right_knee",
            #     15: "left_ankle",
            #     16: "right_ankle"
            # }
            num_node = 17
            neighbor_link = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 6),
                             (5, 7), (7, 9), (6, 8), (8, 10), (5, 11), (6, 12), (11, 12),
                             (11, 13), (13, 15), (12, 14), (14, 16)]
            self.center = 0
            self.flip_idx = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
            # connect_joint[i] = the parent joint i hangs from (bone vector source).
            connect_joint = np.array([5,0,0,1,2,0,0,5,6,7,8,5,6,11,12,13,14])
            parts = [
                np.array([5, 7, 9]),       # left_arm
                np.array([6, 8, 10]),      # right_arm
                np.array([11, 13, 15]),    # left_leg
                np.array([12, 14, 16]),    # right_leg
                np.array([0, 1, 2, 3, 4]), # head
            ]
        elif self.joint_format == 'coco-no-head':
            num_node = 12
            neighbor_link = [(0, 1),
                             (0, 2), (2, 4), (1, 3), (3, 5), (0, 6), (1, 7), (6, 7),
                             (6, 8), (8, 10), (7, 9), (9, 11)]
            self.center = 0
            connect_joint = np.array([3,1,0,2,4,0,6,8,10,7,9,11])
            parts = [
                np.array([0, 2, 4]),   # left_arm
                np.array([1, 3, 5]),   # right_arm
                np.array([6, 8, 10]),  # left_leg
                np.array([7, 9, 11])   # right_leg
            ]
        elif self.joint_format == 'alphapose' or self.joint_format == 'openpose':
            num_node = 18
            neighbor_link = [(0, 1), (0, 14), (0, 15), (14, 16), (15, 17),
                             (1, 2), (2, 3), (3, 4), (1, 5), (5, 6), (6, 7),
                             (1, 8), (8, 9), (9, 10), (1, 11), (11, 12), (12, 13)]
            self.center = 1
            self.flip_idx = [0, 1, 5, 6, 7, 2, 3, 4, 11, 12, 13, 8, 9, 10, 15, 14, 17, 16]
            connect_joint = np.array([1,1,1,2,3,1,5,6,2,8,9,5,11,12,0,0,14,15])
            parts = [
                np.array([5, 6, 7]),              # left_arm
                np.array([2, 3, 4]),              # right_arm
                np.array([11, 12, 13]),           # left_leg
                np.array([8, 9, 10]),             # right_leg
                np.array([0, 1, 14, 15, 16, 17]), # head
            ]
        else:
            logging.info('')
            # Bug fix: the original formatted ``self.dataset`` — an attribute
            # that does not exist — so this error path itself raised
            # AttributeError instead of the intended log + ValueError.
            logging.error('Error: Do NOT exist this dataset: {}!'.format(self.joint_format))
            raise ValueError()
        # Self-loops plus skeleton bones form the edge list.
        self_link = [(i, i) for i in range(num_node)]
        edge = self_link + neighbor_link
        return num_node, edge, connect_joint, parts

    def _get_hop_distance(self):
        """Shortest hop distance between all node pairs (inf beyond max_hop)."""
        A = np.zeros((self.num_node, self.num_node))
        for i, j in self.edge:
            A[j, i] = 1
            A[i, j] = 1
        hop_dis = np.zeros((self.num_node, self.num_node)) + np.inf
        # A**d has a positive (i, j) entry iff j is reachable in d steps;
        # filling from max_hop down makes the smallest distance win.
        transfer_mat = [np.linalg.matrix_power(A, d) for d in range(self.max_hop + 1)]
        arrive_mat = (np.stack(transfer_mat) > 0)
        for d in range(self.max_hop, -1, -1):
            hop_dis[arrive_mat[d]] = d
        return hop_dis

    def _get_adjacency(self):
        """Stacked distance-partitioned adjacency, shape (len(valid_hop), N, N)."""
        hop_dis = self._get_hop_distance()
        valid_hop = range(0, self.max_hop + 1, self.dilation)
        adjacency = np.zeros((self.num_node, self.num_node))
        for hop in valid_hop:
            adjacency[hop_dis == hop] = 1
        normalize_adjacency = self._normalize_digraph(adjacency)
        # One slice per hop distance (distance partitioning).
        A = np.zeros((len(valid_hop), self.num_node, self.num_node))
        for i, hop in enumerate(valid_hop):
            A[i][hop_dis == hop] = normalize_adjacency[hop_dis == hop]
        return A

    def _normalize_digraph(self, A):
        """Column-normalize A by in-degree (A @ D^-1); zero-degree columns stay zero."""
        Dl = np.sum(A, 0)
        num_node = A.shape[0]
        Dn = np.zeros((num_node, num_node))
        for i in range(num_node):
            if Dl[i] > 0:
                Dn[i, i] = Dl[i]**(-1)
        AD = np.dot(A, Dn)
        return AD
class TemporalBasicBlock(nn.Module):
    """
    TemporalConv_Res_Block
    Arxiv: https://arxiv.org/abs/2010.09978
    Github: https://github.com/Thomas-yx/ResGCNv1

    A (temporal_window_size, 1) convolution over the time axis plus BN,
    fused with a block-level shortcut and an externally supplied module
    residual before the final ReLU.
    """
    def __init__(self, channels, temporal_window_size, stride=1, residual=False, reduction=0, get_res=False, tcn_stride=False):
        super(TemporalBasicBlock, self).__init__()
        pad = ((temporal_window_size - 1) // 2, 0)
        # Shortcut path: disabled, identity, or a strided 1x1 projection.
        if not residual:
            self.residual = lambda x: 0
        elif stride == 1:
            self.residual = lambda x: x
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(channels, channels, 1, (stride, 1)),
                nn.BatchNorm2d(channels),
            )
        self.conv = nn.Conv2d(channels, channels, (temporal_window_size, 1), (stride, 1), pad)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, res_module):
        """Apply temporal conv + BN; fuse block and module residuals."""
        shortcut = self.residual(x)
        out = self.bn(self.conv(x))
        return self.relu(out + shortcut + res_module)
class TemporalBottleneckBlock(nn.Module):
    """
    TemporalConv_Res_Bottleneck
    Arxiv: https://arxiv.org/abs/2010.09978
    Github: https://github.com/Thomas-yx/ResGCNv1

    Bottleneck temporal block: 1x1 channel reduction -> (t, 1) temporal
    convolution -> 1x1 expansion, each followed by BN/ReLU, fused with a
    block shortcut and an externally supplied module residual.
    """
    def __init__(self, channels, temporal_window_size, stride=1, residual=False, reduction=4,get_res=False, tcn_stride=False):
        super(TemporalBottleneckBlock, self).__init__()
        # NOTE(review): the incoming ``tcn_stride`` argument is overwritten
        # here, so the parameter is effectively ignored — confirm intended.
        tcn_stride =False
        padding = ((temporal_window_size - 1) // 2, 0)
        # Bottleneck width: channels reduced by ``reduction``.
        inter_channels = channels // reduction
        if get_res:
            # Forced downsampling residual path (stride 2 on the time axis).
            # NOTE(review): the inner ``if tcn_stride`` can never fire since
            # tcn_stride was just forced False — confirm.
            if tcn_stride:
                stride =2
            self.residual = nn.Sequential(
                nn.Conv2d(channels, channels, 1, (2,1)),
                nn.BatchNorm2d(channels),
            )
            tcn_stride= True
        else:
            if not residual:
                self.residual = lambda x: 0
            elif stride == 1:
                self.residual = lambda x: x
            else:
                self.residual = nn.Sequential(
                    nn.Conv2d(channels, channels, 1, (2,1)),
                    nn.BatchNorm2d(channels),
                )
                tcn_stride= True
        self.conv_down = nn.Conv2d(channels, inter_channels, 1)
        self.bn_down = nn.BatchNorm2d(inter_channels)
        # tcn_stride is True iff a strided residual was built above; in that
        # case the temporal conv also runs with stride 2 to match shapes.
        if tcn_stride:
            stride=2
        self.conv = nn.Conv2d(inter_channels, inter_channels, (temporal_window_size,1), (stride,1), padding)
        self.bn = nn.BatchNorm2d(inter_channels)
        self.conv_up = nn.Conv2d(inter_channels, channels, 1)
        self.bn_up = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
    def forward(self, x, res_module):
        # res_module: residual supplied by the enclosing module (may be 0).
        res_block = self.residual(x)
        x = self.conv_down(x)
        x = self.bn_down(x)
        x = self.relu(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.conv_up(x)
        x = self.bn_up(x)
        # Fuse the bottleneck output with both residuals before the ReLU.
        x = self.relu(x + res_block + res_module)
        return x
class SpatialGraphConv(nn.Module):
    """
    SpatialGraphConv_Basic_Block
    Arxiv: https://arxiv.org/abs/1801.07455
    Github: https://github.com/yysijie/st-gcn

    ST-GCN style spatial graph convolution: a 1x1 conv produces one set of
    channels per spatial class (hop distance), then each class is
    aggregated through its own adjacency matrix.
    """
    def __init__(self, in_channels, out_channels, max_graph_distance):
        super(SpatialGraphConv, self).__init__()
        # One spatial class per hop distance: 0 .. max_graph_distance.
        self.s_kernel_size = max_graph_distance + 1
        # A single 1x1 conv holds the weights of every spatial class at once.
        self.gcn = nn.Conv2d(in_channels, out_channels * self.s_kernel_size, 1)

    def forward(self, x, A):
        # Project features, then split the channels into the spatial classes.
        feats = self.gcn(x)
        n, kc, t, v = feats.size()
        feats = feats.view(n, self.s_kernel_size, kc // self.s_kernel_size, t, v).contiguous()
        # Aggregate over the graph: one adjacency matrix per spatial class.
        out = torch.einsum('nkctv,kvw->nctw', (feats, A[:self.s_kernel_size])).contiguous()
        return out
class SpatialBasicBlock(nn.Module):
    """
    SpatialGraphConv_Res_Block
    Arxiv: https://arxiv.org/abs/2010.09978
    Github: https://github.com/Thomas-yx/ResGCNv1

    Spatial graph convolution + BN with an optional residual shortcut,
    followed by a ReLU.
    """
    def __init__(self, in_channels, out_channels, max_graph_distance, residual=False, reduction=0):
        super(SpatialBasicBlock, self).__init__()
        # Shortcut path: disabled, identity, or a 1x1 channel projection.
        if not residual:
            self.residual = lambda x: 0
        elif in_channels == out_channels:
            self.residual = lambda x: x
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels),
            )
        self.conv = SpatialGraphConv(in_channels, out_channels, max_graph_distance)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, A):
        shortcut = self.residual(x)
        out = self.bn(self.conv(x, A))
        return self.relu(out + shortcut)
class SpatialBottleneckBlock(nn.Module):
    """
    SpatialGraphConv_Res_Bottleneck
    Arxiv: https://arxiv.org/abs/2010.09978
    Github: https://github.com/Thomas-yx/ResGCNv1

    Bottleneck spatial block: 1x1 channel reduction -> spatial graph conv
    -> 1x1 expansion, each with BN/ReLU, plus an optional shortcut around
    the whole bottleneck.
    """
    def __init__(self, in_channels, out_channels, max_graph_distance, residual=False, reduction=4):
        super(SpatialBottleneckBlock, self).__init__()
        # Bottleneck width: output channels reduced by ``reduction``.
        inter_channels = out_channels // reduction
        # Shortcut path: disabled, identity, or a 1x1 channel projection.
        if not residual:
            self.residual = lambda x: 0
        elif in_channels == out_channels:
            self.residual = lambda x: x
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels),
            )
        self.conv_down = nn.Conv2d(in_channels, inter_channels, 1)
        self.bn_down = nn.BatchNorm2d(inter_channels)
        self.conv = SpatialGraphConv(inter_channels, inter_channels, max_graph_distance)
        self.bn = nn.BatchNorm2d(inter_channels)
        self.conv_up = nn.Conv2d(inter_channels, out_channels, 1)
        self.bn_up = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, A):
        shortcut = self.residual(x)
        out = self.relu(self.bn_down(self.conv_down(x)))
        out = self.relu(self.bn(self.conv(out, A)))
        out = self.bn_up(self.conv_up(out))
        return self.relu(out + shortcut)
class SpatialAttention(nn.Module):
    """
    This class implements Spatial Transformer.
    Function adapted from: https://github.com/leaderj1001/Attention-Augmented-Conv2d

    Multi-head self-attention over the joint (spatial) axis: q, k, v are
    produced by one convolution, attention weights are softmax(q^T k), and
    the heads are recombined and projected by a final 1x1 convolution.
    """
    def __init__(self, in_channels, out_channel, A, num_point, dk_factor=0.25, kernel_size=1, Nh=8, num=4, stride=1):
        super(SpatialAttention, self).__init__()
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        # Key/query width is a fraction of the output width; values use the
        # full output width.
        self.dk = int(dk_factor * out_channel)
        self.dv = int(out_channel)
        # NOTE(review): ``num`` and ``num_point`` are stored but not used in
        # this class's forward pass — confirm whether callers rely on them.
        self.num = num
        self.Nh = Nh
        self.num_point=num_point
        # NOTE(review): the summed adjacency ``self.A`` is never read in
        # forward(); presumably kept for subclasses — confirm.
        self.A = A[0] + A[1] + A[2]
        self.stride = stride
        self.padding = (self.kernel_size - 1) // 2
        assert self.Nh != 0, "integer division or modulo by zero, Nh >= 1"
        assert self.dk % self.Nh == 0, "dk should be divided by Nh. (example: out_channels: 20, dk: 40, Nh: 4)"
        assert self.dv % self.Nh == 0, "dv should be divided by Nh. (example: out_channels: 20, dv: 4, Nh: 4)"
        assert stride in [1, 2], str(stride) + " Up to 2 strides are allowed."
        # One convolution produces q, k and v in a single pass.
        self.qkv_conv = nn.Conv2d(self.in_channels, 2 * self.dk + self.dv, kernel_size=self.kernel_size,
                                  stride=stride,
                                  padding=self.padding)
        self.attn_out = nn.Conv2d(self.dv, self.dv, kernel_size=1, stride=1)
    def forward(self, x):
        """Self-attention over joints; returns (batch, dv, T, V)."""
        # Input x
        # (batch_size, channels, 1, joints)
        B, _, T, V = x.size()
        # flat_q, flat_k, flat_v
        # (batch_size, Nh, dvh or dkh, joints)
        # dvh = dv / Nh, dkh = dk / Nh
        # q, k, v obtained by doing 2D convolution on the input (q=XWq, k=XWk, v=XWv)
        flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(x, self.dk, self.dv, self.Nh)
        # Calculate the scores, obtained by doing q*k
        # (batch_size, Nh, joints, dkh)*(batch_size, Nh, dkh, joints) = (batch_size, Nh, joints,joints)
        # The multiplication can also be divided (multi_matmul) in case of space problems
        logits = torch.matmul(flat_q.transpose(2, 3), flat_k)
        weights = F.softmax(logits, dim=-1)
        # attn_out
        # (batch, Nh, joints, dvh)
        # weights*V
        # (batch, Nh, joints, joints)*(batch, Nh, joints, dvh)=(batch, Nh, joints, dvh)
        attn_out = torch.matmul(weights, flat_v.transpose(2, 3))
        attn_out = torch.reshape(attn_out, (B, self.Nh, T, V, self.dv // self.Nh))
        attn_out = attn_out.permute(0, 1, 4, 2, 3)
        # combine_heads_2d, combine heads only after having calculated each Z separately
        # (batch, Nh*dv, 1, joints)
        attn_out = self.combine_heads_2d(attn_out)
        # Multiply for W0 (batch, out_channels, 1, joints) with out_channels=dv
        attn_out = self.attn_out(attn_out)
        return attn_out
    def compute_flat_qkv(self, x, dk, dv, Nh):
        """Split the qkv convolution output into flattened per-head q, k, v."""
        qkv = self.qkv_conv(x)
        # T=1 in this case, because we are considering each frame separately
        N, _, T, V = qkv.size()
        q, k, v = torch.split(qkv, [dk, dk, dv], dim=1)
        q = self.split_heads_2d(q, Nh)
        k = self.split_heads_2d(k, Nh)
        v = self.split_heads_2d(v, Nh)
        dkh = dk // Nh
        # Scale queries by 1/sqrt(dkh) (scaled dot-product attention).
        q = q*(dkh ** -0.5)
        flat_q = torch.reshape(q, (N, Nh, dkh, T * V))
        flat_k = torch.reshape(k, (N, Nh, dkh, T * V))
        flat_v = torch.reshape(v, (N, Nh, dv // self.Nh, T * V))
        return flat_q, flat_k, flat_v, q, k, v
    def split_heads_2d(self, x, Nh):
        """Reshape (B, C, T, V) -> (B, Nh, C // Nh, T, V)."""
        B, channels, T, V = x.size()
        ret_shape = (B, Nh, channels // Nh, T, V)
        split = torch.reshape(x, ret_shape)
        return split
    def combine_heads_2d(self, x):
        """Reshape (B, Nh, dv, T, V) -> (B, Nh * dv, T, V)."""
        batch, Nh, dv, T, V = x.size()
        ret_shape = (batch, Nh * dv, T, V)
        return torch.reshape(x, ret_shape)