import logging

import torch
import torch.nn as nn

from ..modules import TemporalBasicBlock, TemporalBottleneckBlock, SpatialBasicBlock, SpatialBottleneckBlock


def _config_error(message):
    """Log ``message`` the same way the rest of this file does and raise ``ValueError``."""
    logging.info('')
    logging.error(message)
    raise ValueError(message)


class ResGCNModule(nn.Module):
    """
    ResGCNModule

    One spatial-temporal unit: a spatial block (``self.scn``) applied to the input,
    followed by a temporal block (``self.tcn``) that also receives the module-level
    residual ``self.residual(x)``.

    Arxiv: https://arxiv.org/abs/2010.09978
    Github: https://github.com/Thomas-yx/ResGCNv1
            https://github.com/BNU-IVC/FastPoseGait

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        block: One of ``'initial'``, ``'Basic'`` or ``'Bottleneck'``.
            ``'initial'`` and ``'Basic'`` use the Basic spatial/temporal blocks;
            ``'Bottleneck'`` uses the Bottleneck ones. Only ``'Basic'`` enables the
            module-level residual; only ``'Bottleneck'`` sets ``block_res=True`` on
            the inner blocks.
        A: Graph tensor. A learnable edge-weight parameter of the same shape
            (initialised to ones) multiplies it element-wise in ``forward``.
        stride: Stride of the residual 1x1 projection along dim 2 (presumably the
            temporal axis — confirm against the temporal blocks); also forwarded
            to the temporal block.
        kernel_size: ``(temporal_window_size, max_graph_distance)``. The temporal
            window size must be odd.
        reduction: Forwarded to both the spatial and the temporal block.
        get_res: Forwarded to the temporal block.
        is_main: When true and ``in_channels == out_channels``, the temporal block
            is built with ``tcn_stride=True``.

    Raises:
        ValueError: If ``kernel_size`` does not have exactly two entries, if its
            first entry is even, or if ``block`` is not a recognised block type.
    """

    def __init__(self, in_channels, out_channels, block, A, stride=1, kernel_size=(9, 2), reduction=4,
                 get_res=False, is_main=False):
        super(ResGCNModule, self).__init__()

        if len(kernel_size) != 2:
            _config_error('Error: Please check whether len(kernel_size) == 2')
        if kernel_size[0] % 2 != 1:
            _config_error('Error: Please check whether kernel_size[0] % 2 == 1')
        temporal_window_size, max_graph_distance = kernel_size

        # (module-level residual, block-level residual) per block type.
        if block == 'initial':
            module_res, block_res = False, False
        elif block == 'Basic':
            module_res, block_res = True, False
        elif block == 'Bottleneck':
            module_res, block_res = False, True
        else:
            # Previously an unknown name crashed later with UnboundLocalError.
            _config_error(
                "Error: Unknown block type {!r}; expected 'initial', 'Basic' or 'Bottleneck'".format(block)
            )

        if not module_res:
            # No module-level shortcut: the temporal block receives a constant 0.
            self.residual = lambda x: 0
        elif stride == 1 and in_channels == out_channels:
            # Shapes already agree, so the shortcut is the identity.
            self.residual = lambda x: x
        else:
            # 1x1 conv with stride (stride, 1) changes the channel count and/or
            # downsamples dim 2 so the shortcut can be combined with the main path.
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, (stride, 1)),
                nn.BatchNorm2d(out_channels),
            )

        if block == 'Bottleneck':
            spatial_block, temporal_block = SpatialBottleneckBlock, TemporalBottleneckBlock
        else:  # 'initial' or 'Basic' (validated above)
            spatial_block, temporal_block = SpatialBasicBlock, TemporalBasicBlock

        self.scn = spatial_block(in_channels, out_channels, max_graph_distance, block_res, reduction)

        tcn_stride = bool(in_channels == out_channels and is_main)
        self.tcn = temporal_block(out_channels, temporal_window_size, stride, block_res, reduction,
                                  get_res=get_res, tcn_stride=tcn_stride)

        # Learnable element-wise weighting of the graph, initialised to ones.
        self.edge = nn.Parameter(torch.ones_like(A))

    def forward(self, x, A):
        """Apply the spatial block with the re-weighted graph, then the temporal block.

        ``A`` is moved to ``x``'s device first; unlike the former
        ``A.cuda(x.get_device())`` this also works for CPU inputs.
        """
        A = A.to(x.device)
        return self.tcn(self.scn(x, A * self.edge), self.residual(x))


class ResGCNInputBranch(nn.Module):
    """
    ResGCNInputBranch_Module

    Batch-normalises the input, then runs a chain of ResGCNModules. The first
    module is always of type ``'initial'``; the rest use ``block``.

    Arxiv: https://arxiv.org/abs/2010.09978
    Github: https://github.com/Thomas-yx/ResGCNv1

    Args:
        input_branch: Channel sizes; module ``i`` maps ``input_branch[i]`` to
            ``input_branch[i + 1]`` channels, and the input BatchNorm uses
            ``input_branch[0]`` channels.
        block: Block type for every module after the first.
        A: Graph tensor, registered as buffer ``'A'`` and passed to every module.
        input_num: Not used by this class; kept for interface compatibility.
        reduction: Forwarded to every ResGCNModule.
    """

    def __init__(self, input_branch, block, A, input_num, reduction=4):
        super(ResGCNInputBranch, self).__init__()

        self.register_buffer('A', A)

        module_list = []
        for i, (c_in, c_out) in enumerate(zip(input_branch[:-1], input_branch[1:])):
            block_type = 'initial' if i == 0 else block
            module_list.append(ResGCNModule(c_in, c_out, block_type, A, reduction=reduction))

        self.bn = nn.BatchNorm2d(input_branch[0])
        self.layers = nn.ModuleList(module_list)

    def forward(self, x):
        """Batch-normalise ``x`` and pass it through every module with the stored graph."""
        x = self.bn(x)
        for layer in self.layers:
            x = layer(x, self.A)
        return x


class ResGCN(nn.Module):
    """
    ResGCN

    Arxiv: https://arxiv.org/abs/2010.09978

    Runs ``input_num`` independent input branches, concatenates their outputs
    along dim 1, passes the result through the main-stream modules, applies
    global average pooling and classifies with a linear layer.

    Args:
        input_num: Number of input branches; ``forward`` feeds ``x[:, i]`` to
            branch ``i`` (so presumably x is (N, input_num, C, T, V) — confirm).
        input_branch: Channel sizes of each input branch (see ResGCNInputBranch).
        main_stream: Channel sizes of the main stream. The first main-stream
            module takes ``main_stream[0] * input_num`` channels (the concatenated
            branches) and always uses stride 1. Later modules use stride 2
            whenever the channel count changes, stride 1 otherwise.
        num_class: Output size of the final linear layer.
        reduction: Forwarded to every ResGCNModule.
        block: Block type for the non-initial modules.
        graph: Graph tensor shared by all modules.
    """

    def __init__(self, input_num, input_branch, main_stream, num_class, reduction, block, graph):
        super(ResGCN, self).__init__()

        # Plain attribute (not a buffer), so Module.to() does not move it;
        # ResGCNModule.forward moves it to the input's device on every call.
        self.graph = graph
        self.head = nn.ModuleList([
            ResGCNInputBranch(input_branch, block, graph, input_num, reduction)
            for _ in range(input_num)
        ])

        main_stream_list = []
        for i, (c_in, c_out) in enumerate(zip(main_stream[:-1], main_stream[1:])):
            if i == 0:
                main_stream_list.append(ResGCNModule(c_in * input_num, c_out, block, graph, stride=1,
                                                     reduction=reduction, get_res=True, is_main=True))
            else:
                stride = 1 if c_in == c_out else 2
                main_stream_list.append(ResGCNModule(c_in, c_out, block, graph, stride=stride,
                                                     reduction=reduction, is_main=True))
        self.backbone = nn.ModuleList(main_stream_list)

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        # The pooled feature has the last main-stream module's out_channels;
        # this width used to be hard-coded to 256.
        self.fcn = nn.Linear(main_stream[-1], num_class)

    def forward(self, x):
        """Run the input branches, the main stream and the classifier head."""
        # Input branches: branch i consumes x[:, i]; outputs are concatenated on dim 1.
        x = torch.cat([branch(x[:, i]) for i, branch in enumerate(self.head)], dim=1)

        # Main stream.
        for layer in self.backbone:
            x = layer(x, self.graph)

        # Output: pool the last two dims to size 1, drop them, then classify.
        x = self.global_pooling(x)
        x = x.squeeze(-1).squeeze(-1)
        return self.fcn(x)