# Source: OpenGait/opengait/modeling/modules.py
# Snapshot: T, 2022-04-12 11:28:09 +08:00 (194 lines, 6.5 KiB, Python)

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from utils import clones, is_list_or_tuple
class HorizontalPoolingPyramid():
    """
    Horizontal Pyramid Matching for Person Re-identification
    Arxiv: https://arxiv.org/abs/1804.05275
    Github: https://github.com/SHI-Labs/Horizontal-Pyramid-Matching

    For every ``b`` in ``bin_num``, the flattened (h, w) map is cut into ``b``
    equal consecutive chunks (horizontal strips when ``b`` divides ``h``).
    Each chunk is summarised by its mean plus its max.
    """

    def __init__(self, bin_num=None):
        """
        Args:
            bin_num: number of chunks per pyramid level; defaults to
                [16, 8, 4, 2, 1]. Every entry must divide h * w.
        """
        if bin_num is None:
            bin_num = [16, 8, 4, 2, 1]
        self.bin_num = bin_num

    def __call__(self, x):
        """
        x : [n, c, h, w]
        ret: [n, c, p]  where p = sum(bin_num)

        Raises:
            RuntimeError: if h * w is not divisible by an entry of bin_num.
        """
        n, c = x.size()[:2]
        features = []
        for b in self.bin_num:
            # reshape (not view) so non-contiguous inputs, e.g. permuted
            # feature maps, are accepted instead of raising.
            z = x.reshape(n, c, b, -1)
            z = z.mean(-1) + z.max(-1)[0]
            features.append(z)
        return torch.cat(features, -1)
class SetBlockWrapper(nn.Module):
    """Apply a frame-level block to every frame of a set/sequence.

    The set dimension ``s`` is folded into the batch dimension before
    ``forward_block`` is called, and unfolded again afterwards.
    """

    def __init__(self, forward_block):
        super(SetBlockWrapper, self).__init__()
        self.forward_block = forward_block

    def forward(self, x, *args, **kwargs):
        """
        In x: [n, s, c, h, w]
        Out x: [n, s, ...]  (trailing dims are whatever forward_block returns)

        Extra positional and keyword arguments are passed to forward_block.
        """
        n, s, c, h, w = x.size()
        # reshape rather than view: x may be non-contiguous (e.g. after a
        # permute/transpose), in which case view() raises.
        x = self.forward_block(x.reshape(-1, c, h, w), *args, **kwargs)
        output_size = [n, s] + [*x.size()[1:]]
        return x.reshape(*output_size)
class PackSequenceWrapper(nn.Module):
    """Apply a pooling function to each packed sequence independently.

    When sequences of different lengths are packed end-to-end along
    ``seq_dim``, ``seqL`` gives their lengths. Each sequence is pooled on its
    own, and the results are concatenated along dim 0.
    """

    def __init__(self, pooling_func):
        super(PackSequenceWrapper, self).__init__()
        self.pooling_func = pooling_func

    def forward(self, seqs, seqL, seq_dim=1, **kwargs):
        """
        In  seqs: [n, s, ...]
        Out rets: [n, ...]

        Args:
            seqs: input tensor; when seqL is given, sequences are packed along
                seq_dim.
            seqL: None (pool seqs directly), or an indexable whose first element
                is a tensor of per-sequence lengths.
            seq_dim: dimension along which the sequences are packed.
            **kwargs: forwarded to pooling_func.
        Returns:
            The concatenated pooled tensor. If the pooled outputs are lists or
            tuples (as judged by is_list_or_tuple), the result is a list whose
            j-th entry concatenates the j-th field of every output.
        """
        if seqL is None:
            return self.pooling_func(seqs, **kwargs)
        seqL = seqL[0].data.cpu().numpy().tolist()
        # start offset of each packed sequence = exclusive prefix sum of seqL
        start = [0] + np.cumsum(seqL).tolist()[:-1]
        rets = [self.pooling_func(seqs.narrow(seq_dim, curr_start, curr_seqL), **kwargs)
                for curr_start, curr_seqL in zip(start, seqL)]
        if len(rets) > 0 and is_list_or_tuple(rets[0]):
            return [torch.cat([ret[j] for ret in rets])
                    for j in range(len(rets[0]))]
        return torch.cat(rets)
class BasicConv2d(nn.Module):
    """Bias-free 2-D convolution.

    Thin wrapper around ``nn.Conv2d`` with ``bias=False``. Any extra keyword
    arguments are passed straight through to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
            **kwargs,
        )

    def forward(self, x):
        return self.conv(x)
class SeparateFCs(nn.Module):
    """An independent bias-free fully-connected layer for every part.

    The per-part weight matrices are stacked in ``fc_bin`` with shape
    [parts_num, in_channels, out_channels].
    """

    def __init__(self, parts_num, in_channels, out_channels, norm=False):
        super(SeparateFCs, self).__init__()
        self.p = parts_num
        weight = torch.zeros(parts_num, in_channels, out_channels)
        self.fc_bin = nn.Parameter(nn.init.xavier_uniform_(weight))
        self.norm = norm

    def forward(self, x):
        """
        x: [p, n, c]
        return: [p, n, out_channels]

        When norm is set, each weight column is L2-normalised over the
        in_channels axis before the product.
        """
        weight = F.normalize(self.fc_bin, dim=1) if self.norm else self.fc_bin
        return x.matmul(weight)
class SeparateBNNecks(nn.Module):
    """
    Part-wise BNNeck: batch-normalise each part's feature, then classify it
    with a part-specific, bias-free linear classifier.

    Bag of Tricks and a Strong Baseline for Deep Person Re-Identification
    CVPR Workshop: https://openaccess.thecvf.com/content_CVPRW_2019/papers/TRMTMCT/Luo_Bag_of_Tricks_and_a_Strong_Baseline_for_Deep_Person_CVPRW_2019_paper.pdf
    Github: https://github.com/michuanhaohao/reid-strong-baseline
    """

    def __init__(self, parts_num, in_channels, class_num, norm=True, parallel_BN1d=True):
        """
        Args:
            parts_num: number of parts p.
            in_channels: feature size c of each part.
            class_num: number of classes (classifier output size).
            norm: if True, features and classifier weights are L2-normalised,
                so the logits are cosine similarities.
            parallel_BN1d: if True, use one BatchNorm1d over all p*c channels.
                Otherwise build per-part BatchNorm1d(c) layers with ``clones``
                (presumably p independent copies; see utils).
        """
        super(SeparateBNNecks, self).__init__()
        self.p = parts_num
        self.class_num = class_num
        self.norm = norm
        self.fc_bin = nn.Parameter(
            nn.init.xavier_uniform_(
                torch.zeros(parts_num, in_channels, class_num)))
        if parallel_BN1d:
            self.bn1d = nn.BatchNorm1d(in_channels * parts_num)
        else:
            self.bn1d = clones(nn.BatchNorm1d(in_channels), parts_num)
        self.parallel_BN1d = parallel_BN1d

    def forward(self, x):
        """
        x: [p, n, c]
        Returns:
            feature: [p, n, c]  batch-normalised (and L2-normalised if norm)
            logits:  [p, n, class_num]
        """
        if self.parallel_BN1d:
            p, n, c = x.size()
            x = x.transpose(0, 1).contiguous().view(n, -1)  # [n, p*c]
            x = self.bn1d(x)
            x = x.view(n, p, c).permute(1, 0, 2).contiguous()  # [p, n, c]
        else:
            # normalise each part slice [n, c] with its own BatchNorm1d
            x = torch.cat([bn(part.squeeze(0)).unsqueeze(0)
                           for part, bn in zip(x.split(1, 0), self.bn1d)], 0)  # [p, n, c]
        if self.norm:
            feature = F.normalize(x, dim=-1)  # [p, n, c]
            # weights normalised over in_channels -> cosine-similarity logits
            logits = feature.matmul(F.normalize(
                self.fc_bin, dim=1))  # [p, n, class_num]
        else:
            feature = x
            logits = feature.matmul(self.fc_bin)  # [p, n, class_num]
        return feature, logits
class FocalConv2d(nn.Module):
    """
    Focal convolution: the input is cut into strips along the height axis,
    the same bias-free convolution is applied to each strip independently,
    and the outputs are stacked back along height.
    """

    def __init__(self, in_channels, out_channels, kernel_size, halving, **kwargs):
        super(FocalConv2d, self).__init__()
        self.halving = halving
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size, bias=False, **kwargs)

    def forward(self, x):
        """
        x: [n, c, h, w]

        With halving == 0 this is a plain convolution. Otherwise each strip is
        h // 2**halving rows tall; the last strip is shorter when h is not
        divisible by that height.
        """
        if self.halving == 0:
            return self.conv(x)
        strip_height = int(x.size(2) // 2 ** self.halving)
        strips = x.split(strip_height, 2)
        return torch.cat([self.conv(strip) for strip in strips], 2)
class BasicConv3d(nn.Module):
    """3-D convolution wrapper, bias-free by default.

    Defaults to a 3x3x3 kernel with unit stride and unit padding, which keeps
    the (s, h, w) extent unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False, **kwargs):
        super(BasicConv3d, self).__init__()
        self.conv3d = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
            **kwargs,
        )

    def forward(self, ipts):
        '''
        ipts: [n, c, s, h, w]
        outs: [n, out_channels, s', h', w']  (s', h', w' == s, h, w with the
              default kernel/stride/padding)
        '''
        return self.conv3d(ipts)
def RmBN2dAffine(model):
    """Freeze the affine parameters of every nn.BatchNorm2d in ``model``.

    Sets ``requires_grad = False`` on the weight and bias of each
    nn.BatchNorm2d instance (including subclasses) found by
    ``model.modules()``, so no gradients are computed for them. Running
    statistics are not touched, and other BatchNorm variants (1d/3d) are left
    alone. Layers built with ``affine=False`` have no weight/bias and are
    skipped.

    Args:
        model: an nn.Module, searched recursively.
    """
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            # weight/bias are None when the layer was built with affine=False
            if m.weight is not None:
                m.weight.requires_grad = False
            if m.bias is not None:
                m.bias.requires_grad = False