Support skeleton (#155)
* pose * pose * pose * pose * 你的提交消息 * pose * pose * Delete train1.sh * pretreatment * configs * pose * reference * Update gaittr.py * naming * naming * Update transform.py * update for datasets * update README * update name and README * update * Update transform.py
This commit is contained in:
@@ -0,0 +1,135 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..modules import TemporalBasicBlock, TemporalBottleneckBlock, SpatialBasicBlock, SpatialBottleneckBlock
|
||||
|
||||
class ResGCNModule(nn.Module):
    """
    ResGCNModule

    Arxiv: https://arxiv.org/abs/2010.09978
    Github: https://github.com/Thomas-yx/ResGCNv1
            https://github.com/BNU-IVC/FastPoseGait

    One spatial-graph-conv + temporal-conv stage with an optional
    module-level residual branch.
    """

    def __init__(self, in_channels, out_channels, block, A, stride=1,
                 kernel_size=(9, 2), reduction=4, get_res=False, is_main=False):
        """
        Args:
            in_channels:  input feature channels.
            out_channels: output feature channels.
            block:        'initial', 'Basic' or 'Bottleneck' (any other value raises).
            A:            graph adjacency tensor; its shape sizes the learnable edge weights.
            stride:       temporal stride of the temporal block.
            kernel_size:  (temporal_window_size, max_graph_distance); temporal size must be odd.
            reduction:    bottleneck reduction factor forwarded to the sub-blocks.
            get_res:      forwarded to the temporal block.
            is_main:      True for main-stream modules (enables tcn_stride when channels match).

        Raises:
            ValueError: on malformed kernel_size or unknown block type.
            (The original code called logging.* without importing logging,
            turning these validation failures into NameErrors.)
        """
        super(ResGCNModule, self).__init__()

        if len(kernel_size) != 2:
            raise ValueError(
                'kernel_size must be [temporal_window_size, max_graph_distance], '
                'got {}'.format(kernel_size))
        if kernel_size[0] % 2 != 1:
            raise ValueError(
                'temporal_window_size (kernel_size[0]) must be odd, '
                'got {}'.format(kernel_size[0]))
        temporal_window_size, max_graph_distance = kernel_size

        # 'initial' -> no residual at all; 'Basic' -> module-level residual;
        # anything else (e.g. 'Bottleneck') -> block-level residual.
        if block == 'initial':
            module_res, block_res = False, False
        elif block == 'Basic':
            module_res, block_res = True, False
        else:
            module_res, block_res = False, True

        if not module_res:
            self.residual = lambda x: 0
        elif stride == 1 and in_channels == out_channels:
            self.residual = lambda x: x
        else:
            # channel count or temporal resolution changes: project with a 1x1 conv
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, (stride, 1)),
                nn.BatchNorm2d(out_channels),
            )

        if block in ['Basic', 'initial']:
            spatial_block = SpatialBasicBlock
            temporal_block = TemporalBasicBlock
        elif block == 'Bottleneck':
            spatial_block = SpatialBottleneckBlock
            temporal_block = TemporalBottleneckBlock
        else:
            # the original silently fell through and NameError'd later
            raise ValueError("block must be 'initial', 'Basic' or 'Bottleneck', "
                             "got {!r}".format(block))

        self.scn = spatial_block(in_channels, out_channels, max_graph_distance,
                                 block_res, reduction)
        # only main-stream modules with unchanged channels enable tcn_stride
        tcn_stride = bool(in_channels == out_channels and is_main)
        self.tcn = temporal_block(out_channels, temporal_window_size, stride,
                                  block_res, reduction, get_res=get_res,
                                  tcn_stride=tcn_stride)
        # learnable per-edge importance weights, same shape as A
        self.edge = nn.Parameter(torch.ones_like(A))

    def forward(self, x, A):
        # move the adjacency onto the input's device; the original used
        # .cuda(x.get_device()), which broke CPU execution
        A = A.to(x.device)
        return self.tcn(self.scn(x, A * self.edge), self.residual(x))
|
||||
|
||||
class ResGCNInputBranch(nn.Module):
    """
    ResGCNInputBranch_Module

    Arxiv: https://arxiv.org/abs/2010.09978
    Github: https://github.com/Thomas-yx/ResGCNv1

    One per-modality input branch: BatchNorm followed by a chain of
    ResGCNModules whose widths follow `input_branch`.
    """

    def __init__(self, input_branch, block, A, input_num, reduction=4):
        super(ResGCNInputBranch, self).__init__()

        # adjacency moves with the module but is not trained
        self.register_buffer('A', A)

        # the first stage is always the non-residual 'initial' module,
        # subsequent stages use the configured block type
        stages = [
            ResGCNModule(
                c_in,
                c_out,
                'initial' if idx == 0 else block,
                A,
                reduction=reduction,
            )
            for idx, (c_in, c_out) in enumerate(zip(input_branch[:-1], input_branch[1:]))
        ]

        self.bn = nn.BatchNorm2d(input_branch[0])
        self.layers = nn.ModuleList(stages)

    def forward(self, x):
        x = self.bn(x)
        for stage in self.layers:
            x = stage(x, self.A)
        return x
|
||||
|
||||
|
||||
class ResGCN(nn.Module):
    """
    ResGCN

    Arxiv: https://arxiv.org/abs/2010.09978

    Multi-branch input head (one ResGCNInputBranch per modality) whose
    outputs are concatenated and fed to a main-stream stack of
    ResGCNModules, followed by global pooling and a linear classifier.
    """

    def __init__(self, input_num, input_branch, main_stream, num_class,
                 reduction, block, graph):
        super(ResGCN, self).__init__()
        self.graph = graph
        # one input branch per input modality
        self.head = nn.ModuleList(
            ResGCNInputBranch(input_branch, block, graph, input_num, reduction)
            for _ in range(input_num)
        )

        main_stream_list = []
        for i in range(len(main_stream) - 1):
            # downsample temporally whenever the channel width changes
            stride = 1 if main_stream[i] == main_stream[i + 1] else 2
            if i == 0:
                # the first main-stream module fuses the concatenated branch outputs
                main_stream_list.append(ResGCNModule(
                    main_stream[i] * input_num, main_stream[i + 1], block, graph,
                    stride=1, reduction=reduction, get_res=True, is_main=True))
            else:
                main_stream_list.append(ResGCNModule(
                    main_stream[i], main_stream[i + 1], block, graph,
                    stride=stride, reduction=reduction, is_main=True))
        self.backbone = nn.ModuleList(main_stream_list)
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        # NOTE(review): 256 is assumed to equal main_stream[-1] — confirm configs
        self.fcn = nn.Linear(256, num_class)

    def forward(self, x):
        # input branches: each branch consumes its slice x[:, i]
        branch_outs = [branch(x[:, i]) for i, branch in enumerate(self.head)]
        x = torch.cat(branch_outs, dim=1)

        # main stream
        for layer in self.backbone:
            x = layer(x, self.graph)

        # output head: pool to [N, C, 1, 1] then drop both singleton dims.
        # (The original squeezed -1 twice via a redundant x.squeeze((-1)) call.)
        x = self.global_pooling(x)
        x = x.squeeze(-1).squeeze(-1)
        x = self.fcn(x)
        return x
|
||||
@@ -144,6 +144,7 @@ class BaseModel(MetaModel, nn.Module):
|
||||
|
||||
self.build_network(cfgs['model_cfg'])
|
||||
self.init_parameters()
|
||||
self.seq_trfs = get_transform(self.engine_cfg['transform'])
|
||||
|
||||
self.msg_mgr.log_info(cfgs['data_cfg'])
|
||||
if training:
|
||||
@@ -299,8 +300,7 @@ class BaseModel(MetaModel, nn.Module):
|
||||
tuple: training data including inputs, labels, and some meta data.
|
||||
"""
|
||||
seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
|
||||
trf_cfgs = self.engine_cfg['transform']
|
||||
seq_trfs = get_transform(trf_cfgs)
|
||||
seq_trfs = self.seq_trfs
|
||||
if len(seqs_batch) != len(seq_trfs):
|
||||
raise ValueError(
|
||||
"The number of types of input data and transform should be same. But got {} and {}".format(len(seqs_batch), len(seq_trfs)))
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
'''
|
||||
Modified from https://github.com/BNU-IVC/FastPoseGait/blob/main/fastposegait/modeling/losses/supconloss.py
|
||||
'''
|
||||
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from .base import BaseLoss, gather_and_scale_wrapper
|
||||
|
||||
class SupConLoss_Re(BaseLoss):
    """Wrapper exposing SupConLoss through the project's BaseLoss API."""

    def __init__(self, temperature=0.01):
        super(SupConLoss_Re, self).__init__()
        # delegate the actual computation to the reference implementation below
        self.train_loss = SupConLoss(temperature=temperature)

    @gather_and_scale_wrapper
    def forward(self, features, labels=None, mask=None):
        loss = self.train_loss(features, labels)
        # record the detached loss value for the training monitor
        self.info.update({'loss': loss.detach().clone()})
        return loss, self.info
|
||||
|
||||
|
||||
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.01, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature  # scales the similarity logits
        self.contrast_mode = contrast_mode  # 'all': every view anchors; 'one': only view 0
        self.base_temperature = base_temperature  # reference scale for the final loss

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        # NOTE(review): device is inferred from features only; assumes mask /
        # labels can be moved to the same device — confirm callers.
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            # flatten all trailing feature dims into one
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            # unsupervised (SimCLR) case: positives are other views of the same sample
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            # mask[i, j] = 1 iff samples i and j share a label
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        # stack all views along the batch dimension: [bsz * n_views, d]
        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        # NOTE(review): mask.sum(1) is zero for an anchor with no positives,
        # which yields NaN — assumes every sample has >= 2 views; confirm.
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
|
||||
@@ -0,0 +1,19 @@
|
||||
'''
|
||||
Modified from https://github.com/BNU-IVC/FastPoseGait/blob/main/fastposegait/modeling/losses/supconloss_Lp.py
|
||||
'''
|
||||
|
||||
from .base import BaseLoss, gather_and_scale_wrapper
|
||||
from pytorch_metric_learning import losses, distances
|
||||
|
||||
class SupConLoss_Lp(BaseLoss):
    """SupCon loss computed with an Lp (Euclidean) distance via
    pytorch_metric_learning, wrapped in the project's BaseLoss API."""

    def __init__(self, temperature=0.01):
        super(SupConLoss_Lp, self).__init__()
        # use an Lp distance instead of the default cosine similarity
        self.distance = distances.LpDistance()
        self.train_loss = losses.SupConLoss(temperature=temperature,
                                            distance=self.distance)

    @gather_and_scale_wrapper
    def forward(self, features, labels=None, mask=None):
        loss = self.train_loss(features, labels)
        # record the detached loss value for the training monitor
        self.info.update({'loss': loss.detach().clone()})
        return loss, self.info
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
import torch
|
||||
from ..base_model import BaseModel
|
||||
from ..backbones.resgcn import ResGCN
|
||||
from ..modules import Graph
|
||||
import torch.nn.functional as F
|
||||
|
||||
class GaitGraph1(BaseModel):
    """
    GaitGraph1: Gaitgraph: Graph Convolutional Network for Skeleton-Based Gait Recognition
    Paper: https://ieeexplore.ieee.org/document/9506717
    Github: https://github.com/tteepe/GaitGraph
    """

    def build_network(self, model_cfg):
        """Read the model config and instantiate the graph + ResGCN backbone."""
        self.joint_format = model_cfg['joint_format']
        self.input_num = model_cfg['input_num']
        self.block = model_cfg['block']
        self.input_branch = model_cfg['input_branch']
        self.main_stream = model_cfg['main_stream']
        self.num_class = model_cfg['num_class']
        self.reduction = model_cfg['reduction']
        self.tta = model_cfg['tta']  # test-time augmentation flag

        ## Graph Init ##
        self.graph = Graph(joint_format=self.joint_format, max_hop=3)
        self.A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
        ## Network ##
        self.ResGCN = ResGCN(input_num=self.input_num, input_branch=self.input_branch,
                             main_stream=self.main_stream, num_class=self.num_class,
                             reduction=self.reduction, block=self.block, graph=self.A)

    def forward(self, inputs):
        ipts, labs, type_, view_, seqL = inputs
        x_input = ipts[0]  # N T C V I
        # NOTE(review): the original comment says N,T,C,V,M -> N,C,T,V,M, but
        # the unpack below labels the dims N,T,V,I,C — confirm the actual
        # layout produced upstream.
        x_input = x_input.permute(0, 2, 3, 4, 1).contiguous()
        N, T, V, I, C = x_input.size()

        pose = x_input
        if self.training:
            # split each sequence in half -> two views for contrastive training
            x_input = torch.cat([x_input[:, :int(T/2), ...], x_input[:, int(T/2):, ...]], dim=0)  # [8, 60, 17, 1, 3]
        elif self.tta:
            # test-time augmentation: also run the temporally flipped sequence
            data_flipped = torch.flip(x_input, dims=[1])
            x_input = torch.cat([x_input, data_flipped], dim=0)

        x = x_input.permute(0, 3, 4, 1, 2).contiguous()

        # resgcn
        x = self.ResGCN(x)
        x = F.normalize(x, dim=1, p=2)  # norm # only for GaitGraph1 # Remove from GaitGraph2

        if self.training:
            # two views per sample -> [N, 2, C] for SupConLoss
            f1, f2 = torch.split(x, [N, N], dim=0)
            embed = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)  # [4, 2, 128]
        elif self.tta:
            # average the original and flipped embeddings
            f1, f2 = torch.split(x, [N, N], dim=0)
            embed = torch.mean(torch.stack([f1, f2]), dim=0)
            embed = embed.unsqueeze(-1)
        else:
            # BUGFIX: the original did `embed = embed.unsqueeze(-1)` here,
            # reading `embed` before assignment (UnboundLocalError); use the
            # network output directly.
            embed = x.unsqueeze(-1)

        retval = {
            'training_feat': {
                'SupConLoss': {'features': embed, 'labels': labs},  # loss
            },
            'visual_summary': {
                'image/pose': pose.view(N*T, 1, I*V, C).contiguous()  # visualization
            },
            'inference_feat': {
                'embeddings': embed  # for metric
            }
        }
        return retval
|
||||
@@ -0,0 +1,110 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..base_model import BaseModel
|
||||
from ..backbones.resgcn import ResGCN
|
||||
from ..modules import Graph
|
||||
import numpy as np
|
||||
|
||||
|
||||
class GaitGraph2(BaseModel):
    """
    GaitGraph2: Towards a Deeper Understanding of Skeleton-based Gait Recognition
    Paper: https://openaccess.thecvf.com/content/CVPR2022W/Biometrics/papers/Teepe_Towards_a_Deeper_Understanding_of_Skeleton-Based_Gait_Recognition_CVPRW_2022_paper
    Github: https://github.com/tteepe/GaitGraph2
    """
    def build_network(self, model_cfg):
        # model configuration
        self.joint_format = model_cfg['joint_format']
        self.input_num = model_cfg['input_num']
        self.block = model_cfg['block']
        self.input_branch = model_cfg['input_branch']
        self.main_stream = model_cfg['main_stream']
        self.num_class = model_cfg['num_class']
        self.reduction = model_cfg['reduction']
        self.tta = model_cfg['tta']  # test-time augmentation flag
        ## Graph Init ##
        self.graph = Graph(joint_format=self.joint_format,max_hop=3)
        self.A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
        ## Network ##
        self.ResGCN = ResGCN(input_num=self.input_num, input_branch=self.input_branch,
                             main_stream=self.main_stream, num_class=self.num_class,
                             reduction=self.reduction, block=self.block,graph=self.A)

    def forward(self, inputs):
        ipts, labs, type_, view_, seqL = inputs
        x_input = ipts[0]  # assumed [N, T, V, I, C] — matches the unpack below
        N, T, V, I, C = x_input.size()
        pose = x_input
        flip_idx = self.graph.flip_idx  # left/right joint swap indices

        if not self.training and self.tta:
            # test-time augmentation: add a time-reversed copy and a
            # left/right mirrored copy of each sequence, re-expanded through
            # the MultiInput preprocessing
            multi_input = MultiInput(self.graph.connect_joint, self.graph.center)
            x1 = []
            x2 = []
            for i in range(N):
                x1.append(multi_input(x_input[i,:,:,0,:3].flip(0)))
                x2.append(multi_input(x_input[i,:,flip_idx,0,:3]))
            x_input = torch.cat([x_input, torch.stack(x1,0), torch.stack(x2,0)], dim=0)

        # [N', T, V, I, C] -> [N', I, C, T, V] for the backbone
        x = x_input.permute(0, 3, 4, 1, 2).contiguous()

        # resgcn
        x = self.ResGCN(x)

        if not self.training and self.tta:
            # concatenate the embeddings of the three augmented copies
            f1, f2, f3 = torch.split(x, [N, N, N], dim=0)
            x = torch.cat((f1, f2, f3), dim=1)

        embed = torch.unsqueeze(x,-1)

        retval = {
            'training_feat': {
                'SupConLoss': {'features': x , 'labels': labs}, # loss
            },
            'visual_summary': {
                'image/pose': pose.view(N*T, 1, I*V, C).contiguous() # visualization
            },
            'inference_feat': {
                'embeddings': embed # for metric
            }
        }
        return retval
|
||||
|
||||
class MultiInput:
    """Expand raw (T, V, C) keypoints into three input channels
    (joints, velocities, bones), as used by GaitGraph2's preprocessing."""

    def __init__(self, connect_joint, center):
        self.connect_joint = connect_joint  # parent joint index for each joint
        self.center = center                # index of the body-center joint

    def __call__(self, data):
        # T, V, C -> T, V, I=3, C + 2
        num_t, num_v, num_c = data.shape
        out = torch.zeros((num_t, num_v, 3, num_c + 2), device=data.device)

        # Joints: raw values plus the (x, y) offset from the center joint
        out[:, :, 0, :num_c] = data
        for j in range(num_v):
            out[:, j, 0, num_c:] = data[:, j, :2] - data[:, self.center, :2]

        # Velocity: one-frame and two-frame (x, y) displacements
        for t in range(num_t - 2):
            out[t, :, 1, :2] = data[t + 1, :, :2] - data[t, :, :2]
            out[t, :, 1, 3:] = data[t + 2, :, :2] - data[t, :, :2]
        # NOTE(review): this write clobbers channel 3 of the velocity slot
        # (half of the two-frame displacement) with the confidence value —
        # kept as-is to match the reference implementation.
        out[:, :, 1, 3] = data[:, :, 2]

        # Bones: vector to the connected joint, then its angles
        for j in range(num_v):
            out[:, j, 2, :2] = data[:, j, :2] - data[:, self.connect_joint[j], :2]
        bone_length = 0
        for c in range(num_c - 1):
            bone_length += torch.pow(out[:, :, 2, c], 2)
        bone_length = torch.sqrt(bone_length) + 0.0001
        for c in range(num_c - 1):
            out[:, :, 2, num_c + c] = torch.acos(out[:, :, 2, c] / bone_length)
        # confidence copied into the bone slot as well (overwrites channel 3)
        out[:, :, 2, 3] = data[:, :, 2]

        return out
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from ..base_model import BaseModel
|
||||
from ..modules import Graph, SpatialAttention
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
|
||||
class Mish(nn.Module):
    """Mish activation: x * tanh(softplus(x)). https://arxiv.org/abs/1908.08681"""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))
|
||||
|
||||
class STModule(nn.Module):
    """
    Augmented graph spatial convolution used by GaitTR's spatial transformer.
    Function adapted from: https://github.com/Chiaraplizz/ST-TR/blob/master/code/st_gcn/net/gcn_attention.py
    """

    def __init__(self, in_channels, out_channels, incidence, num_point):
        """
        Args:
            in_channels:  input feature channels.
            out_channels: output feature channels.
            incidence:    skeleton adjacency (incidence) matrix.
            num_point:    number of skeleton joints (V).
        """
        super(STModule, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.incidence = incidence
        self.num_point = num_point
        self.relu = Mish()
        self.bn = nn.BatchNorm2d(out_channels)
        # per-frame normalization over all (channel, joint) pairs
        self.data_bn = nn.BatchNorm1d(self.in_channels * self.num_point)
        self.attention_conv = SpatialAttention(in_channels=in_channels,
                                               out_channel=out_channels,
                                               A=self.incidence,
                                               num_point=self.num_point)

    def forward(self, x):
        N, C, T, V = x.size()
        # data normalization
        x = x.permute(0, 1, 3, 2).reshape(N, C * V, T)
        x = self.data_bn(x)
        x = x.reshape(N, C, V, T).permute(0, 1, 3, 2)
        # adjacency matrix: move to the input's device. The original used
        # .cuda(x.get_device()), which broke CPU execution; .to() is
        # identical on GPU and also works on CPU.
        self.incidence = self.incidence.to(x.device)
        # N, C, T, V -> NT, C, 1, V (each frame attended independently)
        xa = x.permute(0, 2, 1, 3).reshape(-1, C, 1, V)
        # spatial attention
        attn_out = self.attention_conv(xa)
        # NT, C', 1, V -> N, C', T, V
        attn_out = attn_out.reshape(N, T, -1, V).permute(0, 2, 1, 3)
        y = attn_out
        # NOTE: BN applied after the activation, matching the reference code
        y = self.bn(self.relu(y))
        return y
|
||||
|
||||
class UnitConv2D(nn.Module):
    '''
    Temporal unit convolution used in GaitTR's TCN_ST block:
    dropout -> (k, 1) convolution over time -> Mish -> BatchNorm.
    '''

    def __init__(self, D_in, D_out, kernel_size=9, stride=1, dropout=0.1, bias=True):
        super(UnitConv2D, self).__init__()
        # symmetric temporal padding keeps the sequence length unchanged
        padding = ((kernel_size - 1) // 2, 0)
        self.conv = nn.Conv2d(D_in, D_out, kernel_size=(kernel_size, 1),
                              padding=padding, stride=(stride, 1), bias=bias)
        self.bn = nn.BatchNorm2d(D_out)
        self.relu = Mish()
        self.dropout = nn.Dropout(dropout, inplace=False)
        # initialize the convolution weights (He-style, fan-out)
        self.conv_init(self.conv)

    def forward(self, x):
        out = self.dropout(x)
        out = self.bn(self.relu(self.conv(out)))
        return out

    def conv_init(self, module):
        # fan_out = out_channels * prod(kernel_size)
        fan_out = module.out_channels
        for k in module.kernel_size:
            fan_out = fan_out * k
        module.weight.data.normal_(0, math.sqrt(2. / fan_out))
|
||||
|
||||
class TCN_ST(nn.Module):
    """
    Block of GaitTR: https://arxiv.org/pdf/2204.03873.pdf
    TCN: temporal convolution network.
    ST:  spatial-temporal graph convolution network.
    """

    def __init__(self, in_channel, out_channel, A, num_point):
        super(TCN_ST, self).__init__()
        # params
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.A = A
        self.num_point = num_point
        # network
        self.tcn = UnitConv2D(D_in=in_channel, D_out=in_channel, kernel_size=9)
        self.st = STModule(in_channels=in_channel, out_channels=out_channel,
                           incidence=A, num_point=num_point)
        # identity residual around the TCN (its channel count is unchanged)
        self.residual = lambda x: x
        if in_channel == out_channel:
            self.residual_s = lambda x: x
            self.down = None
        else:
            # project channels for the ST residual and the block-level skip
            self.residual_s = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, 1),
                nn.BatchNorm2d(out_channel),
            )
            self.down = UnitConv2D(D_in=in_channel, D_out=out_channel,
                                   kernel_size=1, dropout=0)

    def forward(self, x):
        x0 = self.tcn(x) + self.residual(x)
        y = self.st(x0) + self.residual_s(x0)
        # block-level skip straight from the input
        y = y + (x if self.down is None else self.down(x))
        return y
|
||||
|
||||
|
||||
|
||||
class GaitTR(BaseModel):
    """
    GaitTR: Spatial Transformer Network on Skeleton-based Gait Recognition
    Arxiv : https://arxiv.org/abs/2204.03873.pdf
    """

    def build_network(self, model_cfg):
        """Read the model config and build the TCN_ST backbone + classifier."""
        in_c = model_cfg['in_channels']
        self.num_class = model_cfg['num_class']
        self.joint_format = model_cfg['joint_format']
        self.graph = Graph(joint_format=self.joint_format, max_hop=3)

        #### Network Define ####

        # adjacency matrix
        self.A = torch.from_numpy(self.graph.A.astype(np.float32))

        # data normalization over all (channel, joint) pairs per frame
        num_point = self.A.shape[-1]
        self.data_bn = nn.BatchNorm1d(in_c[0] * num_point)

        # backbone: a stack of TCN_ST blocks following the channel schedule
        backbone = []
        for i in range(len(in_c) - 1):
            backbone.append(TCN_ST(in_channel=in_c[i], out_channel=in_c[i + 1],
                                   A=self.A, num_point=num_point))
        self.backbone = nn.ModuleList(backbone)

        # 1x1 conv acting as a per-channel classifier head
        self.fcn = nn.Conv1d(in_c[-1], self.num_class, kernel_size=1)

    def forward(self, inputs):
        ipts, labs, _, _, seqL = inputs

        x = ipts[0]
        pose = x
        # x = N, T, C, V, M -> N, C, T, V, M
        x = x.permute(0, 2, 1, 3, 4)
        N, C, T, V, M = x.size()
        # NOTE: the original had `if len(x.size()) == 4: x = x.unsqueeze(1)`
        # here; that branch was unreachable (the 5-way unpack above would
        # already have raised for a 4-D tensor) and has been removed.
        del ipts

        x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)

        x = self.data_bn(x)
        x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(
            N * M, C, T, V)
        # backbone
        for m in self.backbone:
            x = m(x)
        # V pooling
        x = F.avg_pool2d(x, kernel_size=(1, V))
        # M pooling
        c = x.size(1)
        t = x.size(2)
        x = x.view(N, M, c, t).mean(dim=1).view(N, c, t)  # [n, c, t]
        # T pooling
        x = F.avg_pool1d(x, kernel_size=x.size()[2])  # [n, c, 1]
        # C fcn
        x = self.fcn(x)  # [n, num_class, 1]
        # averages over the already-length-1 temporal dim (no-op, kept for parity)
        x = F.avg_pool1d(x, x.size()[2:])
        x = x.view(N, self.num_class)  # n, c
        embed = x.unsqueeze(-1)  # n, c, 1

        retval = {
            'training_feat': {
                'triplet': {'embeddings': embed, 'labels': labs}
            },
            'visual_summary': {
                'image/pose': pose.view(N*T, M, V, C)
            },
            'inference_feat': {
                'embeddings': embed
            }
        }
        return retval
|
||||
@@ -0,0 +1,484 @@
|
||||
import torch
|
||||
import copy
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from ..base_model import BaseModel
|
||||
|
||||
class MultiScaleGaitGraph(BaseModel):
    """
    Learning Rich Features for Gait Recognition by Integrating Skeletons and Silhouettes
    Github: https://github.com/YunjiePeng/BimodalFusion

    Three parallel ST-GCN stacks at increasing semantic levels
    (joints -> limbs -> body parts) with cross-scale message passing.
    """

    def build_network(self, model_cfg):
        in_c = model_cfg['in_channels']
        out_c = model_cfg['out_channels']
        num_id = model_cfg['num_id']

        temporal_kernel_size = model_cfg['temporal_kernel_size']

        # load spatial graph
        # NOTE(review): SpatialGraph is expected to be defined elsewhere in
        # this module; it is not among the visible imports — confirm.
        self.graph = SpatialGraph(**model_cfg['graph_cfg'])
        A_lowSemantic = torch.tensor(self.graph.get_adjacency(semantic_level=0), dtype=torch.float32, requires_grad=False)
        A_mediumSemantic = torch.tensor(self.graph.get_adjacency(semantic_level=1), dtype=torch.float32, requires_grad=False)
        A_highSemantic = torch.tensor(self.graph.get_adjacency(semantic_level=2), dtype=torch.float32, requires_grad=False)

        # buffers: move with the module but are not trained
        self.register_buffer('A_lowSemantic', A_lowSemantic)
        self.register_buffer('A_mediumSemantic', A_mediumSemantic)
        self.register_buffer('A_highSemantic', A_highSemantic)

        # build networks
        spatial_kernel_size = self.graph.num_A
        temporal_kernel_size = temporal_kernel_size
        kernel_size = (temporal_kernel_size, spatial_kernel_size)

        self.st_gcn_networks_lowSemantic = nn.ModuleList()
        self.st_gcn_networks_mediumSemantic = nn.ModuleList()
        self.st_gcn_networks_highSemantic = nn.ModuleList()
        for i in range(len(in_c)-1):
            if i == 0:
                # first layer of each stack has no residual branch
                self.st_gcn_networks_lowSemantic.append(st_gcn_block(in_c[i], in_c[i+1], kernel_size, 1, residual=False))
                self.st_gcn_networks_mediumSemantic.append(st_gcn_block(in_c[i], in_c[i+1], kernel_size, 1, residual=False))
                self.st_gcn_networks_highSemantic.append(st_gcn_block(in_c[i], in_c[i+1], kernel_size, 1, residual=False))
            else:
                self.st_gcn_networks_lowSemantic.append(st_gcn_block(in_c[i], in_c[i+1], kernel_size, 1))
                self.st_gcn_networks_mediumSemantic.append(st_gcn_block(in_c[i], in_c[i+1], kernel_size, 1))
                self.st_gcn_networks_highSemantic.append(st_gcn_block(in_c[i], in_c[i+1], kernel_size, 1))

        # one extra block at the final width for each semantic level
        # (reuses the loop variable i from above)
        self.st_gcn_networks_lowSemantic.append(st_gcn_block(in_c[i+1], in_c[i+1], kernel_size, 1))
        self.st_gcn_networks_mediumSemantic.append(st_gcn_block(in_c[i+1], in_c[i+1], kernel_size, 1))
        self.st_gcn_networks_highSemantic.append(st_gcn_block(in_c[i+1], in_c[i+1], kernel_size, 1))

        # learnable per-edge importance weights, one tensor per st-gcn block
        self.edge_importance_lowSemantic = nn.ParameterList([
            nn.Parameter(torch.ones(self.A_lowSemantic.size()))
            for i in self.st_gcn_networks_lowSemantic])

        self.edge_importance_mediumSemantic = nn.ParameterList([
            nn.Parameter(torch.ones(self.A_mediumSemantic.size()))
            for i in self.st_gcn_networks_mediumSemantic])

        self.edge_importance_highSemantic = nn.ParameterList([
            nn.Parameter(torch.ones(self.A_highSemantic.size()))
            for i in self.st_gcn_networks_highSemantic])

        # classification head on the highest semantic level
        self.fc = nn.Linear(in_c[-1], out_c)
        self.bn_neck = nn.BatchNorm1d(out_c)
        self.encoder_cls = nn.Linear(out_c, num_id, bias=False)

    def semantic_pooling(self, x):
        # halve the node dimension by averaging the two halves pairwise
        cur_node_num = x.size()[-1]
        half_x_1, half_x_2 = torch.split(x, int(cur_node_num / 2), dim=-1)
        x_sp = torch.add(half_x_1, half_x_2) / 2
        return x_sp

    def forward(self, inputs):
        ipts, labs, _, _, seqL = inputs

        x = ipts[0]  # [N, T, V, C]
        del ipts
        """
        N - the number of videos.
        T - the number of frames in one video.
        V - the number of keypoints.
        C - the number of features for one keypoint.
        """
        N, T, V, C = x.size()
        x = x.permute(0, 3, 1, 2).contiguous()
        x = x.view(N, C, T, V)

        # derive the coarser-scale inputs by pooling node pairs
        y = self.semantic_pooling(x)
        z = self.semantic_pooling(y)
        for gcn_lowSemantic, importance_lowSemantic, gcn_mediumSemantic, importance_mediumSemantic, gcn_highSemantic, importance_highSemantic in zip(self.st_gcn_networks_lowSemantic, self.edge_importance_lowSemantic, self.st_gcn_networks_mediumSemantic, self.edge_importance_mediumSemantic, self.st_gcn_networks_highSemantic, self.edge_importance_highSemantic):
            x, _ = gcn_lowSemantic(x, self.A_lowSemantic * importance_lowSemantic)
            y, _ = gcn_mediumSemantic(y, self.A_mediumSemantic * importance_mediumSemantic)
            z, _ = gcn_highSemantic(z, self.A_highSemantic * importance_highSemantic)

            # Cross-scale Message Passing
            # NOTE(review): source indentation was mangled; per the MSGG
            # design this runs once per layer, inside the loop — confirm.
            x_sp = self.semantic_pooling(x)
            y = torch.add(y, x_sp)
            y_sp = self.semantic_pooling(y)
            z = torch.add(z, y_sp)

        # global pooling for each layer
        x_sp = F.avg_pool2d(x, x.size()[2:])
        N, C, T, V = x_sp.size()
        x_sp = x_sp.view(N, C, T*V).contiguous()

        y_sp = F.avg_pool2d(y, y.size()[2:])
        N, C, T, V = y_sp.size()
        y_sp = y_sp.view(N, C, T*V).contiguous()

        z = F.avg_pool2d(z, z.size()[2:])
        N, C, T, V = z.size()
        z = z.permute(0, 2, 3, 1).contiguous()
        z = z.view(N, T*V, C)

        # classification head on the pooled high-semantic features
        z_fc = self.fc(z.view(N, -1))
        bn_z_fc = self.bn_neck(z_fc)
        z_cls_score = self.encoder_cls(bn_z_fc)

        z_fc = z_fc.unsqueeze(-1).contiguous() # [n, c, p]
        z_cls_score = z_cls_score.unsqueeze(-1).contiguous() # [n, c, p]

        retval = {
            'training_feat': {
                'triplet_joints': {'embeddings': x_sp, 'labels': labs},
                'triplet_limbs': {'embeddings': y_sp, 'labels': labs},
                'triplet_bodyparts': {'embeddings': z_fc, 'labels': labs},
                'softmax': {'logits': z_cls_score, 'labels': labs}
            },
            'visual_summary': {},
            'inference_feat': {
                'embeddings': z_fc
            }
        }
        return retval
|
||||
|
||||
class st_gcn_block(nn.Module):
    r"""Spatial-temporal graph convolution block (ST-GCN).

    Applies a graph convolution over the joints followed by a temporal
    convolution over the frames, with an optional residual connection.

    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (tuple): (temporal kernel size, spatial kernel size)
        stride (int, optional): Stride of the temporal convolution. Default: 1
        dropout (int, optional): Dropout rate of the final output. Default: 0
        residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True``

    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format
        where
            :math:`N` is a batch size (number of videos),
            :math:`K` is the spatial kernel size (:math:`K == kernel_size[1]`),
            :math:`T_{in}/T_{out}` is the input/output sequence length,
            :math:`V` is the number of graph nodes.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dropout=0,
                 residual=True):
        super().__init__()

        assert len(kernel_size) == 2
        assert kernel_size[0] % 2 == 1
        # symmetric temporal padding preserves the sequence length
        pad = ((kernel_size[0] - 1) // 2, 0)

        # spatial graph convolution
        self.gcn = SCN(in_channels, out_channels, kernel_size[1])

        # temporal convolution: BN -> ReLU -> (k, 1) conv -> BN -> dropout
        self.tcn = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                out_channels,
                (kernel_size[0], 1),
                (stride, 1),
                pad,
            ),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout, inplace=True),
        )

        # residual branch: disabled, identity, or a 1x1 projection
        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=(stride, 1)),
                nn.BatchNorm2d(out_channels),
            )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, A):
        res = self.residual(x)
        x, A = self.gcn(x, A)
        x = self.tcn(x) + res
        return self.relu(x), A
|
||||
|
||||
class SCN(nn.Module):
    r"""The basic module for applying a graph convolution.

    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int): Size of the graph convolving kernel
        t_kernel_size (int): Size of the temporal convolving kernel
        t_stride (int, optional): Stride of the temporal convolution. Default: 1
        t_padding (int, optional): Temporal zero-padding added to both sides of
            the input. Default: 0
        t_dilation (int, optional): Spacing between temporal kernel elements.
            Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output.
            Default: ``True``

    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format

        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 t_kernel_size=1,
                 t_stride=1,
                 t_padding=0,
                 t_dilation=1,
                 bias=True):
        super().__init__()
        # The defined module SCN is responsible only for the Spatial Graph (i.e. the graph in one frame),
        # and the parameter t_kernel_size in this situation is always set to 1.

        self.kernel_size = kernel_size
        self.conv = nn.Conv2d(in_channels,
                              out_channels * kernel_size,
                              kernel_size=(t_kernel_size, 1),
                              padding=(t_padding, 0),
                              stride=(t_stride, 1),
                              dilation=(t_dilation, 1),
                              bias=bias)
        """
        The 1x1 conv operation here stands for the weight matrix W.
        The kernel_size here stands for the number of different adjacency matrices,
        which are defined according to the partitioning strategy.
        Because for neighbor nodes in the same subset (in one adjacency matrix), the weights are shared.
        It is reasonable to apply 1x1 conv as the implementation of the weight function.
        """

    def forward(self, x, A):
        """Per-subset 1x1 convolution, then neighbor aggregation with A.

        Returns the transformed features and the (unchanged) adjacency tensor.
        """
        # one adjacency matrix per kernel subset
        assert A.size(0) == self.kernel_size

        x = self.conv(x)

        # split channels into K subsets and contract with the K adjacency matrices
        n, kc, t, v = x.size()
        x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v)
        x = torch.einsum('nkctv,kvw->nctw', (x, A))

        return x.contiguous(), A
|
||||
|
||||
class SpatialGraph():
    """Use skeleton sequences extracted by Openpose/HRNet to construct a Spatial-Temporal Graph.

    Args:
        strategy (string): must be one of the following candidates
            - uniform: Uniform Labeling
            - distance: Distance Partitioning
            - spatial: Spatial Configuration Partitioning
            - gait_temporal: Gait Temporal Configuration Partitioning
            For more information, please refer to the section 'Partition Strategies' in PGG.
        layout (string): must be one of the following candidates
            - body_12: consists of 12 joints
              (right shoulder, right elbow, right knee, right hip, left elbow, left knee,
              left shoulder, right wrist, right ankle, left hip, left wrist, left ankle).
              For more information, please refer to the section 'Data Processing' in PGG.
        semantic_level (int): level of the joint-merging hierarchy;
            level k keeps num_node / 2**k nodes.
        max_hop (int): the maximal distance between two connected nodes  # 1-neighbor
        dilation (int): controls the spacing between the kernel points
    """

    def __init__(self,
                 layout='body_12',  # Openpose here represents for body_12
                 strategy='spatial',
                 semantic_level=0,
                 max_hop=1,
                 dilation=1):
        self.layout = layout
        self.strategy = strategy
        self.max_hop = max_hop
        self.dilation = dilation
        self.num_node, self.neighbor_link_dic = self.get_layout_info(layout)
        self.num_A = self.get_A_num(strategy)

    def __str__(self):
        # Bug fix: __str__ must return a str; the previous version returned the
        # raw ndarray, which raises TypeError whenever str()/print is applied.
        # NOTE(review): self.A is only assigned inside get_adjacency() under the
        # 'spatial' strategy -- calling str() before that raises AttributeError.
        return str(self.A)

    def get_A_num(self, strategy):
        """Return the number of adjacency matrices the strategy produces."""
        # NOTE: intentionally reads self.strategy (set in __init__) rather than
        # the argument, matching the original behavior.
        if self.strategy == 'uniform':
            return 1
        elif self.strategy == 'distance':
            return 2
        elif (self.strategy == 'spatial') or (self.strategy == 'gait_temporal'):
            return 3
        else:
            raise ValueError("Do Not Exist This Strategy")

    def get_layout_info(self, layout):
        """Return (num_node, neighbor links per semantic level) for the layout."""
        if layout == 'body_12':
            num_node = 12
            # (child, parent) bone links for semantic levels 0 / 1 / 2
            neighbor_link_dic = {
                0: [(7, 1), (1, 0), (10, 4), (4, 6),
                    (8, 2), (2, 3), (11, 5), (5, 9),
                    (9, 3), (3, 0), (9, 6), (6, 0)],
                1: [(1, 0), (4, 0), (0, 3), (2, 3), (5, 3)],
                2: [(1, 0), (2, 0)]
            }
            return num_node, neighbor_link_dic
        else:
            raise ValueError("Do Not Exist This Layout.")

    def get_edge(self, semantic_level):
        """Return (edge, center) for the given semantic level."""
        # edge is a list of [child, parent] pairs, regarding the center node as root node
        self_link = [(i, i) for i in range(int(self.num_node / (2 ** semantic_level)))]
        neighbor_link = self.neighbor_link_dic[semantic_level]
        edge = self_link + neighbor_link
        center = []
        if self.layout == 'body_12':
            if semantic_level == 0:
                center = [0, 3, 6, 9]
            elif semantic_level == 1:
                center = [0, 3]
            elif semantic_level == 2:
                center = [0]
        return edge, center

    def get_gait_temporal_partitioning(self, semantic_level):
        """Split nodes into 'positive' and 'negative' sets for gait_temporal.

        NOTE(review): assumes semantic_level in {0, 1, 2} and layout 'body_12'
        (validated in get_layout_info); otherwise the locals stay unbound.
        """
        if semantic_level == 0:
            if self.layout == 'body_12':
                positive_node = {1, 2, 4, 5, 7, 8, 10, 11}
                negative_node = {0, 3, 6, 9}
        elif semantic_level == 1:
            if self.layout == 'body_12':
                positive_node = {1, 2, 4, 5}
                negative_node = {0, 3}
        elif semantic_level == 2:
            if self.layout == 'body_12':
                positive_node = {1, 2}
                negative_node = {0}
        return positive_node, negative_node

    def get_adjacency(self, semantic_level):
        """Build the (num_A, V, V) normalized adjacency tensor for the strategy."""
        edge, center = self.get_edge(semantic_level)
        num_node = int(self.num_node / (2 ** semantic_level))
        hop_dis = get_hop_distance(num_node, edge, max_hop=self.max_hop)

        # binary adjacency over all hops kept by the dilation
        valid_hop = range(0, self.max_hop + 1, self.dilation)
        adjacency = np.zeros((num_node, num_node))
        for hop in valid_hop:
            adjacency[hop_dis == hop] = 1

        normalize_adjacency = normalize_digraph(adjacency)
        # normalize_adjacency = adjacency # withoutNodeNorm

        # normalize_adjacency[a][b] = x
        # when x = 0, node b has no connection with node a within valid hop.
        # when x != 0, the normalized adjacency from node b to node a is x.
        # the value of x is normalized by the number of adjacent neighbor nodes around node b.

        if self.strategy == 'uniform':
            # a single adjacency matrix covering every valid hop
            A = np.zeros((1, num_node, num_node))
            A[0] = normalize_adjacency
            return A
        elif self.strategy == 'distance':
            # one adjacency matrix per hop distance
            A = np.zeros((len(valid_hop), num_node, num_node))
            for i, hop in enumerate(valid_hop):
                A[i][hop_dis == hop] = normalize_adjacency[hop_dis == hop]
            return A
        elif self.strategy == 'spatial':
            # partition neighbors by their distance to the center nodes:
            # root (same distance), centripetal (closer), centrifugal (further)
            A = []
            for hop in valid_hop:
                a_root = np.zeros((num_node, num_node))
                a_close = np.zeros((num_node, num_node))
                a_further = np.zeros((num_node, num_node))
                for i in range(num_node):
                    for j in range(num_node):
                        if hop_dis[j, i] == hop:
                            j_hop_dis = min([hop_dis[j, _center] for _center in center])
                            i_hop_dis = min([hop_dis[i, _center] for _center in center])
                            if j_hop_dis == i_hop_dis:
                                a_root[j, i] = normalize_adjacency[j, i]
                            elif j_hop_dis > i_hop_dis:
                                a_close[j, i] = normalize_adjacency[j, i]
                            else:
                                a_further[j, i] = normalize_adjacency[j, i]
                if hop == 0:
                    A.append(a_root)
                else:
                    A.append(a_root + a_close)
                    A.append(a_further)
            A = np.stack(A)
            self.A = A
            return A
        elif self.strategy == 'gait_temporal':
            # partition neighbors into root / positive (swing) / negative (stance)
            A = []
            positive_node, negative_node = self.get_gait_temporal_partitioning(semantic_level)
            for hop in valid_hop:
                a_root = np.zeros((num_node, num_node))
                a_positive = np.zeros((num_node, num_node))
                a_negative = np.zeros((num_node, num_node))
                for i in range(num_node):
                    for j in range(num_node):
                        if hop_dis[j, i] == hop:
                            if i == j:
                                a_root[j, i] = normalize_adjacency[j, i]
                            elif j in positive_node:
                                a_positive[j, i] = normalize_adjacency[j, i]
                            else:
                                a_negative[j, i] = normalize_adjacency[j, i]

                if hop == 0:
                    A.append(a_root)
                else:
                    A.append(a_negative)
                    A.append(a_positive)
            A = np.stack(A)
            return A
        else:
            raise ValueError("Do Not Exist This Strategy")
|
||||
|
||||
|
||||
def get_hop_distance(num_node, edge, max_hop=1):
    """Hop (shortest-path) distance between every pair of graph nodes.

    Distances beyond ``max_hop`` are left as ``inf``. The distance is the
    minimum number of steps needed to walk from one node to another.
    """
    # symmetric adjacency matrix from the undirected edge list
    adjacency = np.zeros((num_node, num_node))
    for u, w in edge:
        adjacency[u, w] = adjacency[w, u] = 1

    # A**d has a non-zero entry (j, i) iff i is reachable from j in d steps;
    # writing larger d first lets smaller d overwrite with the true minimum.
    reachable = np.stack(
        [np.linalg.matrix_power(adjacency, d) for d in range(max_hop + 1)]
    ) > 0
    hop_dis = np.full((num_node, num_node), np.inf)
    for d in range(max_hop, -1, -1):
        hop_dis[reachable[d]] = d
    return hop_dis
|
||||
|
||||
|
||||
def normalize_digraph(A):
    """Column-normalize adjacency matrix A (right-multiply by D^-1).

    Each column of the result sums to 1 unless the corresponding node has
    no incoming links, in which case the column stays all-zero.
    """
    in_degree = np.sum(A, 0)
    n = A.shape[0]
    degree_inv = np.zeros((n, n))
    for idx in range(n):
        if in_degree[idx] > 0:
            degree_inv[idx, idx] = in_degree[idx] ** (-1)
    return np.dot(A, degree_inv)
|
||||
|
||||
|
||||
def normalize_undigraph(A):
    """Symmetrically normalize adjacency matrix A: D^-1/2 * A * D^-1/2.

    Nodes with zero degree contribute all-zero rows/columns unchanged.
    """
    degree = np.sum(A, 0)
    n = A.shape[0]
    d_inv_sqrt = np.zeros((n, n))
    for idx in range(n):
        if degree[idx] > 0:
            d_inv_sqrt[idx, idx] = degree[idx] ** (-0.5)
    return np.dot(np.dot(d_inv_sqrt, A), d_inv_sqrt)
|
||||
@@ -253,3 +253,443 @@ def RmBN2dAffine(model):
|
||||
if isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.requires_grad = False
|
||||
m.bias.requires_grad = False
|
||||
|
||||
|
||||
'''
Modified from https://github.com/BNU-IVC/FastPoseGait/blob/main/fastposegait/modeling/components/units
'''
|
||||
|
||||
class Graph():
    """Skeleton graph and its normalized adjacency tensor.

    # Thanks to YAN Sijie for the released code on Github (https://github.com/yysijie/st-gcn)

    Args:
        joint_format (string): keypoint layout, one of
            'coco' (17 joints), 'coco-no-head' (12 joints),
            'alphapose' / 'openpose' (18 joints).
        max_hop (int): maximal hop distance between two connected nodes.
        dilation (int): spacing between the valid hops used to build A.
    """
    def __init__(self, joint_format='coco', max_hop=2, dilation=1):
        self.joint_format = joint_format
        self.max_hop = max_hop
        self.dilation = dilation

        # get edges
        self.num_node, self.edge, self.connect_joint, self.parts = self._get_edge()

        # get adjacency matrix
        self.A = self._get_adjacency()

    def __str__(self):
        # Bug fix: __str__ must return a str; the previous version returned
        # the raw ndarray, which raises TypeError on str()/print.
        return str(self.A)

    def _get_edge(self):
        """Return (num_node, edge, connect_joint, parts) for the layout.

        ``edge`` is self-links plus skeleton bone links; ``connect_joint[i]``
        is the joint that joint ``i`` attaches to; ``parts`` groups joints
        into limbs/head for part-based processing.
        """
        if self.joint_format == 'coco':
            # keypoints = {
            #     0: "nose",
            #     1: "left_eye",
            #     2: "right_eye",
            #     3: "left_ear",
            #     4: "right_ear",
            #     5: "left_shoulder",
            #     6: "right_shoulder",
            #     7: "left_elbow",
            #     8: "right_elbow",
            #     9: "left_wrist",
            #     10: "right_wrist",
            #     11: "left_hip",
            #     12: "right_hip",
            #     13: "left_knee",
            #     14: "right_knee",
            #     15: "left_ankle",
            #     16: "right_ankle"
            # }
            num_node = 17
            self_link = [(i, i) for i in range(num_node)]
            neighbor_link = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 6),
                             (5, 7), (7, 9), (6, 8), (8, 10), (5, 11), (6, 12), (11, 12),
                             (11, 13), (13, 15), (12, 14), (14, 16)]
            self.edge = self_link + neighbor_link
            self.center = 0
            # index map swapping left/right joints (horizontal flip augmentation)
            self.flip_idx = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
            connect_joint = np.array([5,0,0,1,2,0,0,5,6,7,8,5,6,11,12,13,14])
            parts = [
                np.array([5, 7, 9]),        # left_arm
                np.array([6, 8, 10]),       # right_arm
                np.array([11, 13, 15]),     # left_leg
                np.array([12, 14, 16]),     # right_leg
                np.array([0, 1, 2, 3, 4]),  # head
            ]

        elif self.joint_format == 'coco-no-head':
            num_node = 12
            self_link = [(i, i) for i in range(num_node)]
            neighbor_link = [(0, 1),
                             (0, 2), (2, 4), (1, 3), (3, 5), (0, 6), (1, 7), (6, 7),
                             (6, 8), (8, 10), (7, 9), (9, 11)]
            self.edge = self_link + neighbor_link
            self.center = 0
            connect_joint = np.array([3,1,0,2,4,0,6,8,10,7,9,11])
            parts = [
                np.array([0, 2, 4]),    # left_arm
                np.array([1, 3, 5]),    # right_arm
                np.array([6, 8, 10]),   # left_leg
                np.array([7, 9, 11])    # right_leg
            ]

        elif self.joint_format == 'alphapose' or self.joint_format == 'openpose':
            num_node = 18
            self_link = [(i, i) for i in range(num_node)]
            neighbor_link = [(0, 1), (0, 14), (0, 15), (14, 16), (15, 17),
                             (1, 2), (2, 3), (3, 4), (1, 5), (5, 6), (6, 7),
                             (1, 8), (8, 9), (9, 10), (1, 11), (11, 12), (12, 13)]
            self.edge = self_link + neighbor_link
            self.center = 1
            self.flip_idx = [0, 1, 5, 6, 7, 2, 3, 4, 11, 12, 13, 8, 9, 10, 15, 14, 17, 16]
            connect_joint = np.array([1,1,1,2,3,1,5,6,2,8,9,5,11,12,0,0,14,15])
            parts = [
                np.array([5, 6, 7]),                # left_arm
                np.array([2, 3, 4]),                # right_arm
                np.array([11, 12, 13]),             # left_leg
                np.array([8, 9, 10]),               # right_leg
                np.array([0, 1, 14, 15, 16, 17]),   # head
            ]

        else:
            num_node, neighbor_link, connect_joint, parts = 0, [], [], []
            logging.info('')
            # Bug fix: this class has no ``self.dataset`` attribute; the
            # original code raised AttributeError here instead of emitting
            # the intended error message before the ValueError.
            logging.error('Error: Do NOT exist this dataset: {}!'.format(self.joint_format))
            raise ValueError()
        self_link = [(i, i) for i in range(num_node)]
        edge = self_link + neighbor_link
        return num_node, edge, connect_joint, parts

    def _get_hop_distance(self):
        """Shortest hop distance between all node pairs (inf beyond max_hop)."""
        A = np.zeros((self.num_node, self.num_node))
        for i, j in self.edge:
            A[j, i] = 1
            A[i, j] = 1
        hop_dis = np.zeros((self.num_node, self.num_node)) + np.inf
        # A**d reaches node pairs d steps apart; writing larger d first lets
        # smaller d overwrite with the true minimum distance
        transfer_mat = [np.linalg.matrix_power(A, d) for d in range(self.max_hop + 1)]
        arrive_mat = (np.stack(transfer_mat) > 0)
        for d in range(self.max_hop, -1, -1):
            hop_dis[arrive_mat[d]] = d
        return hop_dis

    def _get_adjacency(self):
        """Distance-partitioned adjacency tensor of shape (len(valid_hop), V, V)."""
        hop_dis = self._get_hop_distance()
        valid_hop = range(0, self.max_hop + 1, self.dilation)
        adjacency = np.zeros((self.num_node, self.num_node))
        for hop in valid_hop:
            adjacency[hop_dis == hop] = 1
        normalize_adjacency = self._normalize_digraph(adjacency)
        A = np.zeros((len(valid_hop), self.num_node, self.num_node))
        for i, hop in enumerate(valid_hop):
            A[i][hop_dis == hop] = normalize_adjacency[hop_dis == hop]
        return A

    def _normalize_digraph(self, A):
        """Column-normalize A (right-multiply by the inverse degree matrix)."""
        Dl = np.sum(A, 0)
        num_node = A.shape[0]
        Dn = np.zeros((num_node, num_node))
        for i in range(num_node):
            if Dl[i] > 0:
                Dn[i, i] = Dl[i]**(-1)
        AD = np.dot(A, Dn)
        return AD
|
||||
|
||||
|
||||
class TemporalBasicBlock(nn.Module):
    """
    TemporalConv_Res_Block
    Arxiv: https://arxiv.org/abs/2010.09978
    Github: https://github.com/Thomas-yx/ResGCNv1
    """
    def __init__(self, channels, temporal_window_size, stride=1, residual=False, reduction=0, get_res=False, tcn_stride=False):
        # channels: feature width (unchanged by this block)
        # temporal_window_size: temporal kernel size (odd, for "same" padding)
        # reduction / get_res / tcn_stride: accepted for interface parity with
        # TemporalBottleneckBlock but unused in this block
        super(TemporalBasicBlock, self).__init__()

        # "same" padding along time, none along the joint axis
        padding = ((temporal_window_size - 1) // 2, 0)

        # residual branch: disabled, identity, or strided 1x1 conv + BN
        if not residual:
            self.residual = lambda x: 0
        elif stride == 1:
            self.residual = lambda x: x
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(channels, channels, 1, (stride, 1)),
                nn.BatchNorm2d(channels),
            )

        self.conv = nn.Conv2d(channels, channels, (temporal_window_size, 1), (stride, 1), padding)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, res_module):
        """Temporal conv -> BN -> ReLU, adding both the block-local residual
        and the module-level residual passed in by the caller."""

        res_block = self.residual(x)

        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x + res_block + res_module)

        return x
|
||||
|
||||
|
||||
class TemporalBottleneckBlock(nn.Module):
    """
    TemporalConv_Res_Bottleneck
    Arxiv: https://arxiv.org/abs/2010.09978
    Github: https://github.com/Thomas-yx/ResGCNv1
    """
    def __init__(self, channels, temporal_window_size, stride=1, residual=False, reduction=4, get_res=False, tcn_stride=False):
        super(TemporalBottleneckBlock, self).__init__()
        # NOTE(review): the incoming `tcn_stride` argument is unconditionally
        # overwritten here, so the parameter has no effect and the
        # `if tcn_stride:` test inside the get_res branch below is always
        # False -- confirm whether this matches the intended configuration.
        tcn_stride = False
        # "same" padding along time, none along the joint axis
        padding = ((temporal_window_size - 1) // 2, 0)
        # bottleneck width: channels are squeezed by `reduction` around the conv
        inter_channels = channels // reduction
        if get_res:
            if tcn_stride:
                stride = 2
            # forced downsampling residual: temporally-strided 1x1 conv + BN
            self.residual = nn.Sequential(
                nn.Conv2d(channels, channels, 1, (2, 1)),
                nn.BatchNorm2d(channels),
            )
            tcn_stride = True
        else:
            # residual branch: disabled, identity, or strided 1x1 conv + BN
            if not residual:
                self.residual = lambda x: 0
            elif stride == 1:
                self.residual = lambda x: x
            else:
                self.residual = nn.Sequential(
                    nn.Conv2d(channels, channels, 1, (2, 1)),
                    nn.BatchNorm2d(channels),
                )
                tcn_stride = True

        self.conv_down = nn.Conv2d(channels, inter_channels, 1)
        self.bn_down = nn.BatchNorm2d(inter_channels)
        if tcn_stride:
            # keep the main path's temporal stride in sync with the
            # downsampling residual branch
            stride = 2
        self.conv = nn.Conv2d(inter_channels, inter_channels, (temporal_window_size, 1), (stride, 1), padding)
        self.bn = nn.BatchNorm2d(inter_channels)
        self.conv_up = nn.Conv2d(inter_channels, channels, 1)
        self.bn_up = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, res_module):
        """1x1 squeeze -> temporal conv -> 1x1 expand, adding both the
        block-local residual and the module-level residual from the caller."""

        res_block = self.residual(x)

        x = self.conv_down(x)
        x = self.bn_down(x)
        x = self.relu(x)

        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)

        x = self.conv_up(x)
        x = self.bn_up(x)
        x = self.relu(x + res_block + res_module)
        return x
|
||||
|
||||
|
||||
|
||||
class SpatialGraphConv(nn.Module):
    """
    SpatialGraphConv_Basic_Block
    Arxiv: https://arxiv.org/abs/1801.07455
    Github: https://github.com/yysijie/st-gcn
    """
    def __init__(self, in_channels, out_channels, max_graph_distance):
        super(SpatialGraphConv, self).__init__()

        # spatial class number (distance = 0 for class 0, distance = 1 for class 1, ...)
        self.s_kernel_size = max_graph_distance + 1

        # weights of different spatial classes
        self.gcn = nn.Conv2d(in_channels, out_channels*self.s_kernel_size, 1)

    def forward(self, x, A):
        """Apply per-class 1x1 weights, then aggregate over graph neighbors.

        x: (N, C, T, V) features; A: stacked adjacency matrices (one per class).
        """

        # neighbors in the same class share the same weight
        x = self.gcn(x)

        # divide nodes into different classes
        n, kc, t, v = x.size()
        x = x.view(n, self.s_kernel_size, kc//self.s_kernel_size, t, v).contiguous()

        # spatial graph convolution
        x = torch.einsum('nkctv,kvw->nctw', (x, A[:self.s_kernel_size])).contiguous()

        return x
|
||||
|
||||
class SpatialBasicBlock(nn.Module):
    """
    SpatialGraphConv_Res_Block
    Arxiv: https://arxiv.org/abs/2010.09978
    Github: https://github.com/Thomas-yx/ResGCNv1
    """
    def __init__(self, in_channels, out_channels, max_graph_distance, residual=False, reduction=0):
        # `reduction` is accepted for interface parity with
        # SpatialBottleneckBlock but unused in this block
        super(SpatialBasicBlock, self).__init__()

        # residual branch: disabled, identity, or 1x1 conv for channel match
        if not residual:
            self.residual = lambda x: 0
        elif in_channels == out_channels:
            self.residual = lambda x: x
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels),
            )

        self.conv = SpatialGraphConv(in_channels, out_channels, max_graph_distance)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, A):
        """Graph conv -> BN -> ReLU with a block-local residual connection."""

        res_block = self.residual(x)

        x = self.conv(x, A)
        x = self.bn(x)
        x = self.relu(x + res_block)

        return x
|
||||
|
||||
class SpatialBottleneckBlock(nn.Module):
    """
    SpatialGraphConv_Res_Bottleneck
    Arxiv: https://arxiv.org/abs/2010.09978
    Github: https://github.com/Thomas-yx/ResGCNv1
    """

    def __init__(self, in_channels, out_channels, max_graph_distance, residual=False, reduction=4):
        super(SpatialBottleneckBlock, self).__init__()

        # bottleneck width: channels are squeezed by `reduction` around the GCN
        inter_channels = out_channels // reduction

        # residual branch: disabled, identity, or 1x1 conv for channel match
        if not residual:
            self.residual = lambda x: 0
        elif in_channels == out_channels:
            self.residual = lambda x: x
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1),
                nn.BatchNorm2d(out_channels),
            )

        self.conv_down = nn.Conv2d(in_channels, inter_channels, 1)
        self.bn_down = nn.BatchNorm2d(inter_channels)
        self.conv = SpatialGraphConv(inter_channels, inter_channels, max_graph_distance)
        self.bn = nn.BatchNorm2d(inter_channels)
        self.conv_up = nn.Conv2d(inter_channels, out_channels, 1)
        self.bn_up = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, A):
        """1x1 squeeze -> graph conv -> 1x1 expand, with a residual connection."""

        res_block = self.residual(x)

        x = self.conv_down(x)
        x = self.bn_down(x)
        x = self.relu(x)

        x = self.conv(x, A)
        x = self.bn(x)
        x = self.relu(x)

        x = self.conv_up(x)
        x = self.bn_up(x)
        x = self.relu(x + res_block)

        return x
|
||||
|
||||
class SpatialAttention(nn.Module):
    """
    This class implements Spatial Transformer.
    Function adapted from: https://github.com/leaderj1001/Attention-Augmented-Conv2d
    """
    def __init__(self, in_channels, out_channel, A, num_point, dk_factor=0.25, kernel_size=1, Nh=8, num=4, stride=1):
        super(SpatialAttention, self).__init__()
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        # key/query width (dk) and value width (dv), split across Nh heads
        self.dk = int(dk_factor * out_channel)
        self.dv = int(out_channel)
        self.num = num
        self.Nh = Nh
        self.num_point=num_point
        # sum of the three partition adjacency matrices
        # NOTE(review): self.A, self.num and self.num_point are stored but not
        # read anywhere in this class -- confirm whether callers rely on them.
        self.A = A[0] + A[1] + A[2]
        self.stride = stride
        self.padding = (self.kernel_size - 1) // 2

        assert self.Nh != 0, "integer division or modulo by zero, Nh >= 1"
        assert self.dk % self.Nh == 0, "dk should be divided by Nh. (example: out_channels: 20, dk: 40, Nh: 4)"
        assert self.dv % self.Nh == 0, "dv should be divided by Nh. (example: out_channels: 20, dv: 4, Nh: 4)"
        assert stride in [1, 2], str(stride) + " Up to 2 strides are allowed."

        # single conv producing q, k and v stacked along the channel axis
        self.qkv_conv = nn.Conv2d(self.in_channels, 2 * self.dk + self.dv, kernel_size=self.kernel_size,
                                  stride=stride,
                                  padding=self.padding)

        # output projection W0 applied after heads are recombined
        self.attn_out = nn.Conv2d(self.dv, self.dv, kernel_size=1, stride=1)

    def forward(self, x):
        """Multi-head self-attention over the joint axis.

        Input x: (batch_size, channels, T, joints);
        output: (batch_size, dv, T, joints).
        """
        B, _, T, V = x.size()

        # flat_q, flat_k, flat_v
        # (batch_size, Nh, dvh or dkh, joints)
        # dvh = dv / Nh, dkh = dk / Nh
        # q, k, v obtained by doing 2D convolution on the input (q=XWq, k=XWk, v=XWv)
        flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(x, self.dk, self.dv, self.Nh)

        # Calculate the scores, obtained by doing q*k
        # (batch_size, Nh, joints, dkh)*(batch_size, Nh, dkh, joints) = (batch_size, Nh, joints, joints)
        # The multiplication can also be divided (multi_matmul) in case of space problems
        logits = torch.matmul(flat_q.transpose(2, 3), flat_k)

        weights = F.softmax(logits, dim=-1)

        # attn_out
        # (batch, Nh, joints, dvh)
        # weights*V
        # (batch, Nh, joints, joints)*(batch, Nh, joints, dvh)=(batch, Nh, joints, dvh)
        attn_out = torch.matmul(weights, flat_v.transpose(2, 3))

        attn_out = torch.reshape(attn_out, (B, self.Nh, T, V, self.dv // self.Nh))

        attn_out = attn_out.permute(0, 1, 4, 2, 3)

        # combine_heads_2d, combine heads only after having calculated each Z separately
        # (batch, Nh*dv, 1, joints)
        attn_out = self.combine_heads_2d(attn_out)

        # Multiply for W0 (batch, out_channels, 1, joints) with out_channels=dv
        attn_out = self.attn_out(attn_out)
        return attn_out

    def compute_flat_qkv(self, x, dk, dv, Nh):
        """Project x to q/k/v, split into heads, scale q, and flatten T*V."""
        qkv = self.qkv_conv(x)
        # T=1 in this case, because we are considering each frame separately
        N, _, T, V = qkv.size()

        q, k, v = torch.split(qkv, [dk, dk, dv], dim=1)
        q = self.split_heads_2d(q, Nh)
        k = self.split_heads_2d(k, Nh)
        v = self.split_heads_2d(v, Nh)

        # scaled dot-product attention: scale queries by 1/sqrt(dkh)
        dkh = dk // Nh
        q = q*(dkh ** -0.5)
        flat_q = torch.reshape(q, (N, Nh, dkh, T * V))
        flat_k = torch.reshape(k, (N, Nh, dkh, T * V))
        flat_v = torch.reshape(v, (N, Nh, dv // self.Nh, T * V))
        return flat_q, flat_k, flat_v, q, k, v

    def split_heads_2d(self, x, Nh):
        """Reshape (B, C, T, V) -> (B, Nh, C // Nh, T, V)."""
        B, channels, T, V = x.size()
        ret_shape = (B, Nh, channels // Nh, T, V)
        split = torch.reshape(x, ret_shape)
        return split

    def combine_heads_2d(self, x):
        """Reshape (B, Nh, dv, T, V) -> (B, Nh * dv, T, V)."""
        batch, Nh, dv, T, V = x.size()
        ret_shape = (batch, Nh * dv, T, V)
        return torch.reshape(x, ret_shape)
|
||||
|
||||
Reference in New Issue
Block a user