GaitSSB@Pretrain release

This commit is contained in:
jdyjjj
2023-11-20 20:28:09 +08:00
parent 476c4adbe3
commit b24e797486
6 changed files with 506 additions and 4 deletions
+142
View File
@@ -0,0 +1,142 @@
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from ..base_model import BaseModel
from ..modules import PackSequenceWrapper, HorizontalPoolingPyramid, SetBlockWrapper, ParallelBN1d, SeparateFCs
from utils import np2var, list2var, get_valid_args, ddp_all_gather
from data.transform import get_transform
from einops import rearrange
# Modified from https://github.com/PatrickHua/SimSiam/blob/main/models/simsiam.py
class GaitSSB_Pretrain(BaseModel):
def __init__(self, cfgs, training=True):
super(GaitSSB_Pretrain, self).__init__(cfgs, training=training)
def build_network(self, model_cfg):
self.p = model_cfg['parts_num']
self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
self.Backbone = SetBlockWrapper(self.Backbone)
self.TP = PackSequenceWrapper(torch.max)
self.HPP = HorizontalPoolingPyramid([16, 8, 4, 2, 1])
out_channels = model_cfg['backbone_cfg']['channels'][-1]
hidden_dim = out_channels
self.projector = nn.Sequential(SeparateFCs(self.p, out_channels, hidden_dim),
ParallelBN1d(self.p, hidden_dim),
nn.ReLU(inplace=True),
SeparateFCs(self.p, hidden_dim, out_channels),
ParallelBN1d(self.p, out_channels))
self.predictor = nn.Sequential(SeparateFCs(self.p, out_channels, hidden_dim),
ParallelBN1d(self.p, hidden_dim),
nn.ReLU(inplace=True),
SeparateFCs(self.p, hidden_dim, out_channels))
def inputs_pretreament(self, inputs):
if self.training:
seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
trf_cfgs = self.engine_cfg['transform']
seq_trfs = get_transform(trf_cfgs)
requires_grad = True if self.training else False
batch_size = int(len(seqs_batch[0]) / 2)
img_q = [np2var(np.asarray([trf(fra) for fra in seq[:batch_size]]), requires_grad=requires_grad).float() for trf, seq in zip(seq_trfs, seqs_batch)]
img_k = [np2var(np.asarray([trf(fra) for fra in seq[batch_size:]]), requires_grad=requires_grad).float() for trf, seq in zip(seq_trfs, seqs_batch)]
seqs = [img_q, img_k]
typs = typs_batch
vies = vies_batch
if self.training:
labs = list2var(labs_batch).long()
else:
labs = None
if seqL_batch is not None:
seqL_batch = np2var(seqL_batch).int()
seqL = seqL_batch
ipts = seqs
del seqs
return ipts, labs, typs, vies, (seqL, seqL)
else:
return super().inputs_pretreament(inputs)
def encoder(self, inputs):
sils, seqL = inputs
assert sils.size(-1) in [44, 88]
outs = self.Backbone(sils) # [n, c, s, h, w]
outs = self.TP(outs, seqL, options={"dim": 2})[0] # [n, c, h, w]
feat = self.HPP(outs) # [n, c, p], Horizontal Pooling, HP
return feat
def forward(self, inputs):
'''
Input:
sils_q: a batch of query images, [n, s, h, w]
sils_k: a batch of key images, [n, s, h, w]
Output:
logits, targets
'''
if self.training:
(sils_q, sils_k), labs, typs, vies, (seqL_q, seqL_k) = inputs
sils_q, sils_k = sils_q[0].unsqueeze(1), sils_k[0].unsqueeze(1)
q_input = (sils_q, seqL_q)
q_feat = self.encoder(q_input) # [n, c, p]
z1 = self.projector(q_feat)
p1 = self.predictor(z1)
k_input = (sils_k, seqL_k)
k_feat = self.encoder(k_input) # [n, c, p]
z2 = self.projector(k_feat)
p2 = self.predictor(z2)
logits1, labels1 = self.D(p1, z2)
logits2, labels2 = self.D(p2, z1)
retval = {
'training_feat': {'softmax1': {'logits': logits1, 'labels': labels1},
'softmax2': {'logits': logits2, 'labels': labels2}
},
'visual_summary': {'image/encoder_q': rearrange(sils_q, 'n c s h w -> (n s) c h w'),
'image/encoder_k': rearrange(sils_k, 'n c s h w -> (n s) c h w'),
},
'inference_feat': None
}
return retval
else:
sils, labs, typs, vies, seqL = inputs
sils = sils[0].unsqueeze(1)
feat = self.encoder((sils, seqL)) # [n, c, p]
feat = self.projector(feat) # [n, c, p]
feat = self.predictor(feat) # [n, c, p]
retval = {
'training_feat': None,
'visual_summary': None,
'inference_feat': {'embeddings': F.normalize(feat, dim=1)}
}
return retval
def D(self, p, z): # negative cosine similarity
"""
p: [n, c, p]
z: [n, c, p]
"""
z = z.detach() # stop gradient
n = p.size(0)
p = F.normalize(p, dim=1) # l2-normalize, [n, c, p]
z = F.normalize(z, dim=1) # l2-normalize, [n, c, p]
z = ddp_all_gather(z, dim=0, requires_grad=False) # [m, c, p], m = n * the number of GPUs
logits = torch.einsum('ncp, mcp->nmp', [p, z]) # [n, m, p]
rank = torch.distributed.get_rank()
labels = torch.arange(rank*n, (rank+1)*n, dtype=torch.long).cuda()
return logits, labels