Files
OpenGait/opengait/modeling/models/parsinggait.py
T
Zzier 609aa0e9aa Update ParsingGait (#160)
* Update ParsingGait

* Clear up the confusion

Clear up the confusion about gait3d and gait3d-parsing.

* Update 0.get_started.md

* Add BaseParsingCuttingTransform

* Update gcn.py

* Create gaitbase_gait3d_parsing_btz32x2_fixed.yaml

* Add gait3d_parsing config file

* Update 1.model_zoo.md

Update Gait3D-Parsing checkpoints

* Update 1.model_zoo.md

add configuration

* Update 1.model_zoo.md

center text

---------

Co-authored-by: Junhao Liang <43094337+darkliang@users.noreply.github.com>
2023-10-29 21:53:02 +08:00

269 lines
10 KiB
Python

import torch
from ..base_model import BaseModel
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks
from torch.nn import functional as F
import numpy as np
from ..backbones.gcn import GCN
def L_Matrix(adj_npy, adj_size):
    """Symmetrically normalise an adjacency matrix as ``D^-1/2 @ A @ D^-1/2``.

    The "degree" of row ``i`` is the number of entries of ``adj_npy[i, :]``
    that are exactly equal to 1, so a 1 on the diagonal (self-loop) is
    counted. A row with no such entry gets a zero scaling factor instead of
    a division by zero, which zeroes that node's row and column.

    Args:
        adj_npy: 2-D adjacency array; expected to be ``adj_size x adj_size``
            (the matrix products below require it).
        adj_size: number of nodes.

    Returns:
        Floating-point ndarray of shape ``(adj_size, adj_size)``.
    """
    inv_sqrt_degree = np.zeros((adj_size, adj_size))
    ones_per_row = [np.sum(adj_npy[row, :] == 1) for row in range(adj_size)]
    for row, n_ones in enumerate(ones_per_row):
        if n_ones > 0:
            inv_sqrt_degree[row, row] = n_ones ** (-1 / 2)
    return np.matmul(np.matmul(inv_sqrt_degree, adj_npy), inv_sqrt_degree)
def get_fine_adj_npy():
    """Return the normalised 11x11 adjacency of the fine body-part graph.

    Row/column ``k`` (0-based) corresponds to parsing label ``k + 1``; a 1
    marks an edge and the diagonal holds self-loops. The raw matrix is
    normalised with ``L_Matrix``.
    """
    adjacency = np.array([
        # 1  2  3  4  5  6  7  8  9 10 11
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1],  # 1  Head
        [1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0],  # 2  Torso
        [0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1],  # 3  Left-arm
        [0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1],  # 4  Right-arm
        [0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],  # 5  Left-hand
        [0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0],  # 6  Right-hand
        [0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1],  # 7  Left-leg
        [0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1],  # 8  Right-leg
        [0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0],  # 9  Left-foot
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0],  # 10 Right-foot
        [1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1],  # 11 Dress
    ])
    # shape[0] is the number of rows (nodes), same as len() of a 2-D array.
    return L_Matrix(adjacency, adjacency.shape[0])
def get_coarse_adj_npy():
    """Return the normalised 5x5 adjacency of the coarse body-part graph.

    Node 1 (head/torso/dress group) is connected to every other node; the
    four limb groups are connected only to node 1 and to themselves. The
    raw matrix is normalised with ``L_Matrix``.
    """
    adjacency = np.array([
        # 1  2  3  4  5
        [1, 1, 1, 1, 1],  # 1 Head, Torso, Dress
        [1, 1, 0, 0, 0],  # 2 Left-arm, Left-hand
        [1, 0, 1, 0, 0],  # 3 Right-arm, Right-hand
        [1, 0, 0, 1, 0],  # 4 Left-leg, Left-foot
        [1, 0, 0, 0, 1],  # 5 Right-leg, Right-foot
    ])
    # shape[0] is the number of rows (nodes), same as len() of a 2-D array.
    return L_Matrix(adjacency, adjacency.shape[0])
class ParsingGait(BaseModel):
    """Gait model fusing horizontal-pyramid features with parsing-guided GCN part features.

    The backbone is applied to the parsing input. Its feature map is split
    per parsing label into gated part feature maps, each part is pooled to a
    vector, refined by a GCN over a fine (11-part) graph, a coarse (5-part)
    graph, or both, pooled over time, and concatenated with the HPP features
    before the SeparateFCs / SeparateBNNecks heads.
    """

    def build_network(self, model_cfg):
        """Create the backbone, heads and the GCN branch(es) selected by ``model_cfg``.

        ``model_cfg['gcn_cfg']`` selects the graph mode; the flags are checked
        in the order ``only_fine_graph``, ``only_coarse_graph``,
        ``combine_fine_coarse_graph`` and the first true one is used.

        Raises:
            ValueError: if none of the three graph flags is set.
        """
        self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
        self.Backbone = SetBlockWrapper(self.Backbone)
        self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
        self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])

        # The GCN width is taken from the FC input width because the GCN part
        # features are concatenated with the HPP features before self.FCs.
        nfeat = model_cfg['SeparateFCs']['in_channels']
        gcn_cfg = model_cfg['gcn_cfg']
        self.fine_parts = gcn_cfg['fine_parts']
        coarse_parts = gcn_cfg['coarse_parts']
        self.only_fine_graph = gcn_cfg['only_fine_graph']
        self.only_coarse_graph = gcn_cfg['only_coarse_graph']
        self.combine_fine_coarse_graph = gcn_cfg['combine_fine_coarse_graph']

        # Fine is always built before coarse; creation order fixes the
        # parameter registration order.
        if self.only_fine_graph:
            self._build_fine_graph(nfeat)
        elif self.only_coarse_graph:
            self._build_coarse_graph(nfeat, coarse_parts)
        elif self.combine_fine_coarse_graph:
            self._build_fine_graph(nfeat)
            self._build_coarse_graph(nfeat, coarse_parts)
        else:
            raise ValueError("You should choose fine/coarse graph, or combine both of them.")

    def _build_fine_graph(self, nfeat):
        """Create the fine-graph adjacency, GCN and per-part gating weights."""
        self.fine_adj_npy = torch.from_numpy(get_fine_adj_npy()).float()
        self.gcn_fine = GCN(self.fine_parts, nfeat, nfeat, isMeanPooling=True)
        # One learnable gate per fine part, initialised to 0.75 (see _gate_by_masks).
        self.gammas_fine = torch.nn.Parameter(torch.ones(self.fine_parts) * 0.75)

    def _build_coarse_graph(self, nfeat, coarse_parts):
        """Create the coarse-graph adjacency, GCN and per-part gating weights."""
        self.coarse_adj_npy = torch.from_numpy(get_coarse_adj_npy()).float()
        self.gcn_coarse = GCN(coarse_parts, nfeat, nfeat, isMeanPooling=True)
        # One learnable gate per coarse part, initialised to 0.75 (see _gate_by_masks).
        self.gammas_coarse = torch.nn.Parameter(torch.ones(coarse_parts) * 0.75)

    def PPforGCN(self, x):
        """Part pooling for the GCN: spatial mean plus spatial max.

        Args:
            x: tensor of shape [n, p, c, h, w].

        Returns:
            Tensor of shape [n, p, c].
        """
        n, p, c, h, w = x.size()
        z = x.view(n, p, c, -1)  # [n, p, c, h*w]
        return z.mean(-1) + z.max(-1)[0]  # [n, p, c]

    @staticmethod
    def _gate_by_masks(masks, z, gammas):
        """Re-weight ``z`` once per mask and stack the results along dim 1.

        For part ``k``, positions inside ``masks[k]`` are scaled by
        ``gammas[k]`` and positions outside by ``1 - gammas[k]``.

        Args:
            masks: sequence of boolean tensors of shape [n*s, h, w].
            z: feature map of shape [n*s, c, h, w].
            gammas: 1-D tensor with at least ``len(masks)`` entries.

        Returns:
            Tensor of shape [n*s, len(masks), c, h, w].
        """
        gated = []
        for k, mask in enumerate(masks):
            mask = mask.unsqueeze(1)  # [n*s, 1, h, w]; broadcasts over channels
            gamma = gammas[k]
            gated.append((mask.float() * z * gamma + (~mask).float() * z * (1.0 - gamma)).unsqueeze(1))
        return torch.cat(gated, dim=1)

    def ParsPartforFineGraph(self, mask_resize, z):
        """Split ``z`` into one gated feature map per fine parsing label.

        Args:
            mask_resize: parsing labels resized to the feature resolution,
                shape [n*s, h, w].
            z: backbone features, shape [n*s, c, h, w].

        Returns:
            Tensor of shape [n*s, fine_parts, c, h, w]; part ``i - 1``
            corresponds to parsing label ``i``.

        Fine parsing labels (0 is background and is not used as a part):
            1: Head, 2: Torso, 3: Left-arm, 4: Right-arm, 5: Left-hand,
            6: Right-hand, 7: Left-leg, 8: Right-leg, 9: Left-foot,
            10: Right-foot, 11: Dress
        """
        labels = mask_resize.long()
        fine_masks = [labels == i for i in range(1, self.fine_parts + 1)]
        return self._gate_by_masks(fine_masks, z, self.gammas_fine)

    def ParsPartforCoarseGraph(self, mask_resize, z):
        """Split ``z`` into one gated feature map per coarse part (group of labels).

        Args:
            mask_resize: parsing labels resized to the feature resolution,
                shape [n*s, h, w].
            z: backbone features, shape [n*s, c, h, w].

        Returns:
            Tensor of shape [n*s, 5, c, h, w].

        Coarse parts (fine labels merged into each part):
            1: [1, 2, 11] Head, Torso, Dress
            2: [3, 5] Left-arm, Left-hand
            3: [4, 6] Right-arm, Right-hand
            4: [7, 9] Left-leg, Left-foot
            5: [8, 10] Right-leg, Right-foot
        """
        labels = mask_resize.long()
        coarse_parts = [[1, 2, 11], [3, 5], [4, 6], [7, 9], [8, 10]]
        coarse_masks = []
        for group in coarse_parts:
            part = torch.zeros_like(labels, dtype=torch.bool)
            for i in group:
                part |= (labels == i)
            coarse_masks.append(part)
        return self._gate_by_masks(coarse_masks, z, self.gammas_coarse)

    def ParsPartforGCN(self, x, pars):
        """Split backbone features into parsing-guided part feature maps.

        Args:
            x: backbone features, shape [n, c, s, h, w].
            pars: parsing maps, shape [n, 1, s, H, W].

        Returns:
            ``(fine_z_feat, coarse_z_feat)`` with shapes
            [n*s, fine_parts, c, h, w] and [n*s, 5, c, h, w]; the entry for a
            graph that is not enabled is ``None``.

        Raises:
            ValueError: if no graph mode is enabled.
        """
        n, c, s, h, w = x.size()
        # Nearest-neighbour resize so no blended (non-existent) label values appear.
        mask_resize = F.interpolate(input=pars.squeeze(1), size=(h, w), mode='nearest')  # [n, s, h, w]
        mask_resize = mask_resize.view(n * s, h, w)
        z = x.transpose(1, 2).reshape(n * s, c, h, w)  # [n*s, c, h, w]
        if self.only_fine_graph:
            return self.ParsPartforFineGraph(mask_resize, z), None
        if self.only_coarse_graph:
            return None, self.ParsPartforCoarseGraph(mask_resize, z)
        if self.combine_fine_coarse_graph:
            fine_z_feat = self.ParsPartforFineGraph(mask_resize, z)
            coarse_z_feat = self.ParsPartforCoarseGraph(mask_resize, z)
            return fine_z_feat, coarse_z_feat
        raise ValueError("You should choose fine/coarse graph, or combine both of them.")

    def get_gcn_feat(self, n, input, adj_np, is_cuda, seqL):
        """Pool part feature maps, run the matching GCN and pool over time.

        Args:
            n: number of sequences in the batch.
            input: part feature maps, shape [n*s, p, c, h, w].
            adj_np: normalised adjacency tensor of shape [p, p].
            is_cuda: kept for backward compatibility; the adjacency is moved
                to the device of ``input`` regardless of this flag.
            seqL: sequence lengths forwarded to the temporal pooling.

        Returns:
            Tensor of shape [n, p, c].

        Raises:
            ValueError: if ``p`` is neither 11 (fine graph) nor 5 (coarse graph).
        """
        input_ps = self.PPforGCN(input)  # [n*s, p, c]
        n_s, p, c = input_ps.size()
        # The adjacency is created on the CPU in build_network. Move it to the
        # features' device unconditionally: `adj` used to be assigned only when
        # is_cuda was True, which raised UnboundLocalError on CPU.
        adj = adj_np.to(input_ps.device)
        adj = adj.repeat(n_s, 1, 1)  # [n*s, p, p]
        if p == 11:
            output_ps = self.gcn_fine(input_ps, adj)  # [n*s, 11, c]
        elif p == 5:
            output_ps = self.gcn_coarse(input_ps, adj)  # [n*s, 5, c]
        else:
            raise ValueError(f"The parsing parts should be 11 or 5, but got {p}")
        output_ps = output_ps.view(n, n_s // n, p, c)  # [n, s, p, c]
        # Temporal pooling over the frame dimension via PackSequenceWrapper(torch.max).
        output_ps = self.TP(output_ps, seqL, dim=1, options={"dim": 1})[0]  # [n, p, c]
        return output_ps

    def forward(self, inputs):
        """Run the model on one batch.

        Args:
            inputs: tuple ``(ipts, labs, _, _, seqL)``; ``ipts[0]`` holds the
                parsing maps, shape [n, s, H, W] or [n, 1, s, H, W].

        Returns:
            Dict with ``training_feat`` (triplet embeddings and softmax
            logits), ``visual_summary`` (the parsing maps) and
            ``inference_feat`` (the FC embeddings).

        Raises:
            ValueError: if no graph mode is enabled.
        """
        ipts, labs, _, _, seqL = inputs
        pars = ipts[0]
        if pars.dim() == 4:
            pars = pars.unsqueeze(1)  # [n, 1, s, H, W]
        del ipts

        outs = self.Backbone(pars)  # [n, c, s, h, w]
        outs_n = outs.size(0)

        # Split features by parsing classes:
        # outs_ps_fine: [n*s, 11, c, h, w]; outs_ps_coarse: [n*s, 5, c, h, w]
        outs_ps_fine, outs_ps_coarse = self.ParsPartforGCN(outs, pars)
        is_cuda = pars.is_cuda
        if self.only_fine_graph:
            outs_ps = self.get_gcn_feat(outs_n, outs_ps_fine, self.fine_adj_npy, is_cuda, seqL)  # [n, 11, c]
        elif self.only_coarse_graph:
            outs_ps = self.get_gcn_feat(outs_n, outs_ps_coarse, self.coarse_adj_npy, is_cuda, seqL)  # [n, 5, c]
        elif self.combine_fine_coarse_graph:
            outs_fine = self.get_gcn_feat(outs_n, outs_ps_fine, self.fine_adj_npy, is_cuda, seqL)  # [n, 11, c]
            outs_coarse = self.get_gcn_feat(outs_n, outs_ps_coarse, self.coarse_adj_npy, is_cuda, seqL)  # [n, 5, c]
            outs_ps = torch.cat([outs_fine, outs_coarse], 1)  # [n, 16, c]
        else:
            raise ValueError("You should choose fine/coarse graph, or combine both of them.")
        outs_ps = outs_ps.transpose(1, 2).contiguous()  # [n, c, ps]

        # Temporal Pooling, TP
        outs = self.TP(outs, seqL, options={"dim": 2})[0]  # [n, c, h, w]
        # Horizontal Pooling Matching, HPM
        feat = self.HPP(outs)  # [n, c, p]
        feat = torch.cat([feat, outs_ps], dim=-1)  # [n, c, p+ps]

        embed_1 = self.FCs(feat)  # [n, c, p+ps]
        _, logits = self.BNNecks(embed_1)  # BNNeck embedding is not used
        embed = embed_1

        n, _, s, h, w = pars.size()
        retval = {
            'training_feat': {
                'triplet': {'embeddings': embed_1, 'labels': labs},
                'softmax': {'logits': logits, 'labels': labs}
            },
            'visual_summary': {
                'image/pars': pars.view(n*s, 1, h, w)
            },
            'inference_feat': {
                'embeddings': embed
            }
        }
        return retval