import torch
import torch.nn as nn
import torch.utils.checkpoint
from einops import rearrange
from ...modules import (
    SetBlockWrapper,
    SeparateFCs,
    SeparateBNNecks,
    PackSequenceWrapper,
    HorizontalPoolingPyramid,
)
from torch.nn import functional as F

# ######################################## GaitBase ###########################################


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


class AttentionFusion(nn.Module):
    """Fuse ``feat_len`` feature maps with per-position softmax weights.

    The input maps are concatenated on the channel axis, passed through a
    small conv stack that emits one score per (channel, input map) pair, and
    the scores are softmax-normalised across the input maps before being
    used as weights of a weighted sum.

    Args:
        in_channels: channel count of each individual input feature map.
        squeeze_ratio: ``in_channels / squeeze_ratio`` (truncated to int) is
            the hidden width of the scoring conv stack.
        feat_len: number of feature maps fused per call.
    """

    def __init__(self, in_channels, squeeze_ratio, feat_len):
        super().__init__()
        hidden_dim = int(in_channels / squeeze_ratio)
        self.feat_len = feat_len
        self.conv = SetBlockWrapper(
            nn.Sequential(
                conv1x1(in_channels * feat_len, hidden_dim),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                conv3x3(hidden_dim, hidden_dim),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                conv1x1(hidden_dim, in_channels * feat_len),
            )
        )

    def forward(self, feat_list):
        """Return the attention-weighted sum of the maps in ``feat_list``.

        Args:
            feat_list: sequence of ``feat_len`` tensors, each [n, c, s, h, w].

        Returns:
            Tensor with the same shape as a single input map.
        """
        feats = torch.cat(feat_list, dim=1)
        score = self.conv(feats)  # [n, feat_len * c, s, h, w]
        # Split the channel axis so that axis 2 indexes the input map.
        score = rearrange(score, 'n (c d) s h w -> n c d s h w', d=self.feat_len)
        score = F.softmax(score, dim=2)

        fused = feat_list[0] * score[:, :, 0]
        for i in range(1, self.feat_len):
            fused += feat_list[i] * score[:, :, i]
        return fused


from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet
from ...modules import BasicConv2d

block_map = {'BasicBlock': BasicBlock, 'Bottleneck': Bottleneck}


class Pre_ResNet9(ResNet):
    """Stem of the split ResNet9: conv1 -> bn1 -> relu -> (maxpool) -> layer1.

    Args:
        type: accepted so a config dict can be splatted in; not used here.
            (The name shadows the builtin but is part of the config
            interface, so it is kept.)
        block: key of ``block_map`` ('BasicBlock' or 'Bottleneck').
        channels: per-stage widths; only ``channels[0]`` is used here.
        in_channel: channel count of the input images.
        layers: per-stage block counts; only ``layers[0]`` is kept.
        strides: per-stage strides; only ``strides[0]`` is used here.
        maxpool: whether to apply the ResNet max-pool after the stem.

    Raises:
        ValueError: if ``block`` is not a key of ``block_map``.
    """

    def __init__(self, type, block, channels=(32, 64, 128, 256), in_channel=1,
                 layers=(1, 2, 2, 1), strides=(1, 2, 2, 1), maxpool=True):
        if block not in block_map:
            raise ValueError(
                "Error type for -block-Cfg-, supported: 'BasicBlock' or 'Bottleneck'.")
        block = block_map[block]
        self.maxpool_flag = maxpool
        super().__init__(block, layers)

        # Not used #
        self.fc = None
        self.layer2 = None
        self.layer3 = None
        self.layer4 = None
        ############

        # Rebuild the stem and layer1 with the configured widths instead of
        # the ones the torchvision constructor created.
        self.inplanes = channels[0]
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.conv1 = BasicConv2d(in_channel, self.inplanes, 3, 1, 1)
        self.layer1 = self._make_layer(
            block, channels[0], layers[0], stride=strides[0], dilate=False)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build a residual stage, or a pass-through when ``blocks < 1``."""
        if blocks >= 1:
            return super()._make_layer(block, planes, blocks, stride=stride, dilate=dilate)
        # nn.Identity (rather than a bare function) stays assignable to an
        # attribute slot that already holds a registered sub-module.
        return nn.Identity()

    def forward(self, x):
        """Run the stem and ``layer1`` on ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.maxpool_flag:
            x = self.maxpool(x)
        return self.layer1(x)


class Post_ResNet9(ResNet):
    """Tail of the split ResNet9: layer2 -> layer3 -> layer4.

    Takes the same keyword arguments as ``Pre_ResNet9`` so one config dict
    can build both halves; ``type``, ``in_channel`` and ``maxpool`` are
    accepted but unused here.

    Raises:
        ValueError: if ``block`` is not a key of ``block_map``.
    """

    def __init__(self, type, block, channels=(32, 64, 128, 256), in_channel=1,
                 layers=(1, 2, 2, 1), strides=(1, 2, 2, 1), maxpool=True):
        if block not in block_map:
            raise ValueError(
                "Error type for -block-Cfg-, supported: 'BasicBlock' or 'Bottleneck'.")
        block = block_map[block]
        super().__init__(block, layers)

        # Not used #
        self.fc = None
        self.conv1 = None
        self.bn1 = None
        self.relu = None
        self.layer1 = None
        ############

        # NOTE(review): layer2 is built for an input of channels[0] channels.
        # That presumably matches Pre_ResNet9 only for blocks with
        # expansion 1 (BasicBlock) - confirm before using 'Bottleneck'.
        self.inplanes = channels[0]
        self.layer2 = self._make_layer(
            block, channels[1], layers[1], stride=strides[1], dilate=False)
        self.layer3 = self._make_layer(
            block, channels[2], layers[2], stride=strides[2], dilate=False)
        self.layer4 = self._make_layer(
            block, channels[3], layers[3], stride=strides[3], dilate=False)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build a residual stage, or a pass-through when ``blocks < 1``."""
        if blocks >= 1:
            return super()._make_layer(block, planes, blocks, stride=stride, dilate=dilate)
        # nn.Identity (rather than a bare function) stays assignable to an
        # attribute slot that already holds a registered sub-module.
        return nn.Identity()

    def forward(self, x):
        """Run ``layer2``, ``layer3`` and ``layer4`` on ``x``."""
        x = self.layer2(x)
        x = self.layer3(x)
        return self.layer4(x)


from opengait.utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from ... import backbones


class Baseline(nn.Module):
    """Two-branch gait model: separate stems, attention fusion, shared tail.

    Args:
        model_cfg: dict read for the keys ``backbone_cfg``,
            ``Denoising_Branch.target_dim``, ``Appearance_Branch.target_dim``,
            ``SeparateFCs``, ``SeparateBNNecks``, ``bin_num`` and
            ``AttentionFusion``. It is not modified.
    """

    def __init__(self, model_cfg):
        super().__init__()
        backbone_cfg = model_cfg['backbone_cfg']
        # Per-branch copies: the two stems differ only in their input
        # channel count, and the caller's config dict is left untouched.
        denoising_cfg = dict(
            backbone_cfg, in_channel=model_cfg['Denoising_Branch']['target_dim'])
        appearance_cfg = dict(
            backbone_cfg, in_channel=model_cfg['Appearance_Branch']['target_dim'])

        self.pre_part = SetBlockWrapper(Pre_ResNet9(**denoising_cfg))
        self.pre_rgb = SetBlockWrapper(Pre_ResNet9(**appearance_cfg))
        # Post_ResNet9 ignores in_channel; either branch config would do.
        self.post_backbone = SetBlockWrapper(Post_ResNet9(**appearance_cfg))
        self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
        self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])
        self.fusion = AttentionFusion(**model_cfg['AttentionFusion'])

    def get_backbone(self, backbone_cfg):
        """Get the backbone of the model.

        Args:
            backbone_cfg: a dict with a ``type`` key naming an attribute of
                ``backbones``, or a list of such dicts.

        Returns:
            The instantiated backbone, or an ``nn.ModuleList`` of backbones
            when a list was given.

        Raises:
            ValueError: if ``backbone_cfg`` is neither a dict nor a list.
        """
        if is_dict(backbone_cfg):
            Backbone = get_attr_from([backbones], backbone_cfg['type'])
            valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
            return Backbone(**valid_args)
        if is_list(backbone_cfg):
            return nn.ModuleList([self.get_backbone(cfg) for cfg in backbone_cfg])
        raise ValueError(
            "Error type for -Backbone-Cfg-, supported: (A list of) dict.")

    def vis_forward(self, denosing, appearance, seqL):
        """Return both stem outputs and their fusion (``seqL`` is unused)."""
        denosing = self.pre_part(denosing)  # [n, c, s, h, w]
        appearance = self.pre_rgb(appearance)  # [n, c, s, h, w]
        outs = self.fusion([denosing, appearance])
        return denosing, appearance, outs

    def forward(self, denosing, appearance, seqL):
        """Compute part embeddings and logits from the two input streams.

        Args:
            denosing: input of the denoising branch.
            appearance: input of the appearance branch.
            seqL: sequence-length information forwarded to the temporal
                pooling wrapper.

        Returns:
            Tuple ``(embed_1, logits)``.
        """
        denosing = self.pre_part(denosing)  # [n, c, s, h, w]
        appearance = self.pre_rgb(appearance)  # [n, c, s, h, w]
        outs = self.fusion([denosing, appearance])
        # Release the local references to the branch features before the
        # tail of the backbone runs.
        del denosing, appearance
        outs = self.post_backbone(outs)

        # Temporal Pooling, TP
        outs = self.TP(outs, seqL, options={"dim": 2})[0]  # [n, c, h, w]
        # Horizontal Pooling Matching, HPM
        outs = self.HPP(outs)  # [n, c, p]

        embed_1 = self.FCs(outs)  # [n, c, p]
        _, logits = self.BNNecks(embed_1)  # [n, c, p]
        return embed_1, logits


class Baseline_Single(nn.Module):
    """Single-stream gait model: stem -> tail -> pooling -> FC / BNNeck heads.

    Args:
        model_cfg: dict read for the keys ``backbone_cfg``, ``SeparateFCs``,
            ``SeparateBNNecks`` and ``bin_num``.
    """

    def __init__(self, model_cfg):
        super().__init__()
        self.pre_rgb = SetBlockWrapper(Pre_ResNet9(**model_cfg['backbone_cfg']))
        self.post_backbone = SetBlockWrapper(Post_ResNet9(**model_cfg['backbone_cfg']))
        self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
        self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])

    def get_backbone(self, backbone_cfg):
        """Get the backbone of the model.

        Args:
            backbone_cfg: a dict with a ``type`` key naming an attribute of
                ``backbones``, or a list of such dicts.

        Returns:
            The instantiated backbone, or an ``nn.ModuleList`` of backbones
            when a list was given.

        Raises:
            ValueError: if ``backbone_cfg`` is neither a dict nor a list.
        """
        if is_dict(backbone_cfg):
            Backbone = get_attr_from([backbones], backbone_cfg['type'])
            valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
            return Backbone(**valid_args)
        if is_list(backbone_cfg):
            return nn.ModuleList([self.get_backbone(cfg) for cfg in backbone_cfg])
        raise ValueError(
            "Error type for -Backbone-Cfg-, supported: (A list of) dict.")

    def pre_forward(self, appearance, *args, **kwargs):
        """Run only the backbone; same computation as ``test_1``."""
        return self.test_1(appearance, *args, **kwargs)

    def forward(self, appearance, seqL, *args, **kwargs):
        """Run the backbone (``test_1``) followed by the heads (``test_2``).

        Extra positional and keyword arguments are forwarded to both
        backbone wrappers.

        Returns:
            Tuple ``(embed_1, logits)``.
        """
        outs = self.test_1(appearance, *args, **kwargs)
        return self.test_2(outs, seqL)

    def test_1(self, appearance, *args, **kwargs):
        """Backbone stage: ``pre_rgb`` then ``post_backbone``."""
        outs = self.pre_rgb(appearance, *args, **kwargs)  # [n, c, s, h, w]
        outs = self.post_backbone(outs, *args, **kwargs)
        return outs

    def test_2(self, outs, seqL):
        """Head stage: temporal pooling, horizontal pooling, FCs, BNNecks.

        Returns:
            Tuple ``(embed_1, logits)``.
        """
        # Temporal Pooling, TP
        outs = self.TP(outs, seqL, options={"dim": 2})[0]  # [n, c, h, w]
        # Horizontal Pooling Matching, HPM
        outs = self.HPP(outs)  # [n, c, p]
        embed_1 = self.FCs(outs)  # [n, c, p]
        _, logits = self.BNNecks(embed_1)  # [n, c, p]
        return embed_1, logits


class Baseline_Share(nn.Module):
    """Apply ``Baseline_Single`` heads to channel groups of one input.

    The input is split on the channel axis into ``num_FPN`` groups.
    ``head_num`` distinct ``Baseline_Single`` networks are created and each
    one serves ``num_FPN // head_num`` consecutive groups.

    Args:
        model_cfg: dict read for ``head_num``, ``total_layer_num`` and
            ``group_layer_num``; the whole dict is also passed to every
            ``Baseline_Single``.

    Raises:
        ValueError: if the number of groups is not a positive multiple of
            ``head_num``.
    """

    def __init__(self, model_cfg):
        super().__init__()
        self.head_num = model_cfg['head_num']
        self.num_FPN = model_cfg['total_layer_num'] // model_cfg['group_layer_num']
        # Any other combination cannot be mapped onto the heads: it would
        # divide by zero or index past the last head below.
        if (self.head_num < 1 or self.num_FPN < self.head_num
                or self.num_FPN % self.head_num != 0):
            raise ValueError(
                f"total_layer_num // group_layer_num (= {self.num_FPN}) must be "
                f"a positive multiple of head_num (= {self.head_num}).")
        groups_per_head = self.num_FPN // self.head_num

        self.real_gait = nn.ModuleList([
            Baseline_Single(model_cfg) for _ in range(self.head_num)
        ])
        # Gait_List[i] is the head serving channel group i. Its entries are
        # the same module objects as real_gait, not copies.
        self.Gait_List = nn.ModuleList([
            self.real_gait[group // groups_per_head]
            for group in range(self.num_FPN)
        ])

    def forward(self, x, seqL):
        """Return per-group lists of embeddings and logits for ``x``."""
        x = self.test_1(x)
        embed_list, log_list = self.test_2(x, seqL)
        return embed_list, log_list

    def test_1(self, x, *args, **kwargs):
        """Run every group's backbone and re-concatenate on the channel axis.

        Args:
            x: [n, c, s, h, w]; split into ``num_FPN`` chunks along dim 1.
        """
        x_list = torch.chunk(x, self.num_FPN, dim=1)
        outs = [
            self.Gait_List[i].test_1(x_list[i], *args, **kwargs)
            for i in range(self.num_FPN)
        ]
        return torch.cat(outs, dim=1)

    def test_2(self, x, seqL):
        """Run every group's head on its chunk of ``x``.

        Args:
            x: [n, c, s, h, w]; split into ``num_FPN`` chunks along dim 1.
            seqL: forwarded to each head's temporal pooling.

        Returns:
            Tuple ``(embed_list, log_list)`` with one entry per group; each
            embedding is [n, c, p].
        """
        x_list = torch.chunk(x, self.num_FPN, dim=1)
        embed_list = []
        log_list = []
        for i in range(self.num_FPN):
            embed_1, logits = self.Gait_List[i].test_2(x_list[i], seqL)
            embed_list.append(embed_1)
            log_list.append(logits)
        return embed_list, log_list