GaitSSB@Pretrain release
This commit is contained in:
@@ -0,0 +1,93 @@
|
|||||||
|
data_cfg:
|
||||||
|
dataset_name: GaitLU-1M
|
||||||
|
dataset_root: /your/path/to/GaitLU-1M
|
||||||
|
dataset_partition: ./datasets/GaitLU-1M/GaitLU-1M.json
|
||||||
|
num_workers: 1
|
||||||
|
remove_no_gallery: false # Remove probe if no gallery for it
|
||||||
|
|
||||||
|
evaluator_cfg:
|
||||||
|
enable_float16: true
|
||||||
|
restore_ckpt_strict: true
|
||||||
|
restore_hint: 150000
|
||||||
|
save_name: GaitSSB_Pretrain
|
||||||
|
sampler:
|
||||||
|
batch_shuffle: false
|
||||||
|
batch_size: 16
|
||||||
|
sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered
|
||||||
|
frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory
|
||||||
|
metric: euc # cos
|
||||||
|
transform:
|
||||||
|
- type: BaseSilTransform
|
||||||
|
|
||||||
|
loss_cfg:
|
||||||
|
- loss_term_weight: 1.0
|
||||||
|
scale: 16
|
||||||
|
type: CrossEntropyLoss
|
||||||
|
log_prefix: softmax1
|
||||||
|
log_accuracy: true
|
||||||
|
- loss_term_weight: 1.0
|
||||||
|
scale: 16
|
||||||
|
type: CrossEntropyLoss
|
||||||
|
log_prefix: softmax2
|
||||||
|
log_accuracy: true
|
||||||
|
|
||||||
|
model_cfg:
|
||||||
|
model: GaitSSB_Pretrain
|
||||||
|
backbone_cfg:
|
||||||
|
type: ResNet9
|
||||||
|
block: BasicBlock
|
||||||
|
channels: # Layers configuration for automatically model construction
|
||||||
|
- 64
|
||||||
|
- 128
|
||||||
|
- 256
|
||||||
|
- 512
|
||||||
|
layers:
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
strides:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 2
|
||||||
|
- 1
|
||||||
|
maxpool: false
|
||||||
|
parts_num: 31
|
||||||
|
|
||||||
|
optimizer_cfg:
|
||||||
|
lr: 0.05
|
||||||
|
momentum: 0.9
|
||||||
|
solver: SGD
|
||||||
|
weight_decay: 0.0005
|
||||||
|
# weight_decay: 0.
|
||||||
|
|
||||||
|
scheduler_cfg:
|
||||||
|
gamma: 0.1
|
||||||
|
milestones: # Learning Rate Reduction at each milestones
|
||||||
|
- 80000
|
||||||
|
- 120000
|
||||||
|
scheduler: MultiStepLR
|
||||||
|
|
||||||
|
trainer_cfg:
|
||||||
|
enable_float16: true # half-precision float for memory reduction and speedup
|
||||||
|
fix_BN: false
|
||||||
|
with_test: false
|
||||||
|
log_iter: 100
|
||||||
|
restore_ckpt_strict: true
|
||||||
|
restore_hint: 0
|
||||||
|
save_iter: 10000
|
||||||
|
save_name: GaitSSB_Pretrain
|
||||||
|
sync_BN: true
|
||||||
|
total_iter: 150000
|
||||||
|
sampler:
|
||||||
|
batch_shuffle: true
|
||||||
|
batch_size:
|
||||||
|
- 8 # TripletSampler, batch_size[0] indicates Number of Identity
|
||||||
|
- 64 # batch_size[1] indicates the number of sampled sequences for each Identity
|
||||||
|
frames_num_fixed: 16 # fixed frames number for training
|
||||||
|
sample_type: fixed_ordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered
|
||||||
|
frames_skip_num: 4
|
||||||
|
type: BilateralSampler
|
||||||
|
transform:
|
||||||
|
- type: DA4GaitSSB
|
||||||
|
cutting: 10
|
||||||
@@ -0,0 +1,105 @@
|
|||||||
|
{
|
||||||
|
"TRAIN_SET": [
|
||||||
|
"000",
|
||||||
|
"001",
|
||||||
|
"002",
|
||||||
|
"003",
|
||||||
|
"004",
|
||||||
|
"005",
|
||||||
|
"006",
|
||||||
|
"007",
|
||||||
|
"008",
|
||||||
|
"009",
|
||||||
|
"010",
|
||||||
|
"011",
|
||||||
|
"012",
|
||||||
|
"013",
|
||||||
|
"014",
|
||||||
|
"015",
|
||||||
|
"016",
|
||||||
|
"017",
|
||||||
|
"018",
|
||||||
|
"019",
|
||||||
|
"020",
|
||||||
|
"021",
|
||||||
|
"022",
|
||||||
|
"023",
|
||||||
|
"024",
|
||||||
|
"025",
|
||||||
|
"026",
|
||||||
|
"027",
|
||||||
|
"028",
|
||||||
|
"029",
|
||||||
|
"030",
|
||||||
|
"031",
|
||||||
|
"032",
|
||||||
|
"033",
|
||||||
|
"034",
|
||||||
|
"035",
|
||||||
|
"036",
|
||||||
|
"037",
|
||||||
|
"038",
|
||||||
|
"039",
|
||||||
|
"040",
|
||||||
|
"041",
|
||||||
|
"042",
|
||||||
|
"043",
|
||||||
|
"044",
|
||||||
|
"045",
|
||||||
|
"046",
|
||||||
|
"047",
|
||||||
|
"048",
|
||||||
|
"049",
|
||||||
|
"050",
|
||||||
|
"051",
|
||||||
|
"052",
|
||||||
|
"053",
|
||||||
|
"054",
|
||||||
|
"055",
|
||||||
|
"056",
|
||||||
|
"057",
|
||||||
|
"058",
|
||||||
|
"059",
|
||||||
|
"060",
|
||||||
|
"061",
|
||||||
|
"062",
|
||||||
|
"063",
|
||||||
|
"064",
|
||||||
|
"065",
|
||||||
|
"066",
|
||||||
|
"067",
|
||||||
|
"068",
|
||||||
|
"069",
|
||||||
|
"070",
|
||||||
|
"071",
|
||||||
|
"072",
|
||||||
|
"073",
|
||||||
|
"074",
|
||||||
|
"075",
|
||||||
|
"076",
|
||||||
|
"077",
|
||||||
|
"078",
|
||||||
|
"079",
|
||||||
|
"080",
|
||||||
|
"081",
|
||||||
|
"082",
|
||||||
|
"083",
|
||||||
|
"084",
|
||||||
|
"085",
|
||||||
|
"086",
|
||||||
|
"087",
|
||||||
|
"088",
|
||||||
|
"089",
|
||||||
|
"090",
|
||||||
|
"091",
|
||||||
|
"092",
|
||||||
|
"093",
|
||||||
|
"094",
|
||||||
|
"095",
|
||||||
|
"096",
|
||||||
|
"097",
|
||||||
|
"098",
|
||||||
|
"099"
|
||||||
|
],
|
||||||
|
"TEST_SET": []
|
||||||
|
}
|
||||||
@@ -134,3 +134,41 @@ class CommonSampler(tordata.sampler.Sampler):
|
|||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.dataset)
|
return len(self.dataset)
|
||||||
|
|
||||||
|
# **************** For GaitSSB ****************
|
||||||
|
# Fan, et al: Learning Gait Representation from Massive Unlabelled Walking Videos: A Benchmark, T-PAMI2023
|
||||||
|
import random
|
||||||
|
class BilateralSampler(tordata.sampler.Sampler):
    """Infinite distributed sampler for GaitSSB bilateral (two-view) pretraining.

    Draws a fresh random batch of dataset indices on each step, splits it
    evenly across DDP ranks, and yields each rank's share duplicated
    (``indices * 2``) so the pipeline can build query/key augmented views
    of the same samples.

    Args:
        dataset: indexable dataset; only its length is used here.
        batch_size: pair [num_identities, sequences_per_identity]; the
            effective batch is their product.
        batch_shuffle: kept for interface compatibility; shuffling is
            always performed via random.shuffle / sync_random_sample_list.
    """

    def __init__(self, dataset, batch_size, batch_shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.batch_shuffle = batch_shuffle

        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()

        self.dataset_length = len(self.dataset)
        self.total_indices = list(range(self.dataset_length))

    def __iter__(self):
        random.shuffle(self.total_indices)
        count = 0
        batch_size = self.batch_size[0] * self.batch_size[1]
        while True:
            # Reshuffle once the remaining indices cannot fill a whole batch.
            if (count + 1) * batch_size >= self.dataset_length:
                count = 0
                random.shuffle(self.total_indices)

            sampled_indices = self.total_indices[count*batch_size:(count+1)*batch_size]
            # Keep the permutation identical on every rank.
            sampled_indices = sync_random_sample_list(sampled_indices, len(sampled_indices))

            total_size = int(math.ceil(batch_size / self.world_size)) * self.world_size
            # BUGFIX: pad up to total_size so the batch divides evenly across
            # ranks. The original padded by (batch_size - len(sampled_indices)),
            # which is always 0, leaving the last rank short whenever
            # world_size does not divide batch_size.
            sampled_indices += sampled_indices[:(total_size - len(sampled_indices))]

            # Round-robin split: rank r takes indices r, r+W, r+2W, ...
            sampled_indices = sampled_indices[self.rank:total_size:self.world_size]
            count += 1

            # Duplicate for the two augmented views (query/key).
            yield sampled_indices * 2

    def __len__(self):
        return len(self.dataset)
|
||||||
+112
-3
@@ -35,7 +35,8 @@ class BaseParsingCuttingTransform():
|
|||||||
cutting = self.cutting
|
cutting = self.cutting
|
||||||
else:
|
else:
|
||||||
cutting = int(x.shape[-1] // 64) * 10
|
cutting = int(x.shape[-1] // 64) * 10
|
||||||
x = x[..., cutting:-cutting]
|
if cutting != 0:
|
||||||
|
x = x[..., cutting:-cutting]
|
||||||
if x.max() == 255 or x.max() == 255.:
|
if x.max() == 255 or x.max() == 255.:
|
||||||
return x / self.divsor
|
return x / self.divsor
|
||||||
else:
|
else:
|
||||||
@@ -52,7 +53,8 @@ class BaseSilCuttingTransform():
|
|||||||
cutting = self.cutting
|
cutting = self.cutting
|
||||||
else:
|
else:
|
||||||
cutting = int(x.shape[-1] // 64) * 10
|
cutting = int(x.shape[-1] // 64) * 10
|
||||||
x = x[..., cutting:-cutting]
|
if cutting != 0:
|
||||||
|
x = x[..., cutting:-cutting]
|
||||||
return x / self.divsor
|
return x / self.divsor
|
||||||
|
|
||||||
|
|
||||||
@@ -214,8 +216,115 @@ def get_transform(trf_cfg=None):
|
|||||||
return transform
|
return transform
|
||||||
raise "Error type for -Transform-Cfg-"
|
raise "Error type for -Transform-Cfg-"
|
||||||
|
|
||||||
|
# **************** For GaitSSB ****************
|
||||||
|
# Fan, et al: Learning Gait Representation from Massive Unlabelled Walking Videos: A Benchmark, T-PAMI2023
|
||||||
|
|
||||||
# **************** For pose ****************
|
class RandomPartDilate():
    """Randomly morphologically dilate a horizontal band of each frame.

    Simulates clothing-change cases by dilating the silhouette region
    between a randomly chosen top and bottom row.
    """

    def __init__(self, prob=0.5, top_range=(12, 16), bot_range=(36, 40)):
        self.prob = prob
        self.top_range = top_range
        self.bot_range = bot_range
        # Morphology mode -> candidate kernel sizes for that mode.
        self.modes_and_kernels = {
            'RECT': [[5, 3], [5, 5], [3, 5]],
            'CROSS': [[3, 3], [3, 5], [5, 3]],
            'ELLIPSE': [[3, 3], [3, 5], [5, 3]]}
        self.modes = list(self.modes_and_kernels.keys())

    def __call__(self, seq):
        '''
        Use image dilation to simulate clothing-change cases.
        Input:
            seq: a sequence of silhouette frames, [s, h, w]
        Output:
            seq: a sequence of augmented frames, [s, h, w]
        '''
        if random.uniform(0, 1) < self.prob:
            mode = random.choice(self.modes)
            kernel_size = random.choice(self.modes_and_kernels[mode])
            top = random.randint(self.top_range[0], self.top_range[1])
            bot = random.randint(self.bot_range[0], self.bot_range[1])

            seq = seq.transpose(1, 2, 0)      # [s, h, w] -> [h, w, s]
            band = seq.copy()[top:bot, ...]   # work on a copy of the band
            seq[top:bot, ...] = self.dilate(band, kernel_size=kernel_size, mode=mode)
            seq = seq.transpose(2, 0, 1)      # [h, w, s] -> [s, h, w]
        return seq

    def dilate(self, img, kernel_size=[3, 3], mode='RECT'):
        '''
        Dilate with a MORPH_RECT / MORPH_CROSS / MORPH_ELLIPSE kernel.
        Input:
            img: [h, w]
        Output:
            img: [h, w]
        '''
        assert mode in ['RECT', 'CROSS', 'ELLIPSE']
        kernel = cv2.getStructuringElement(getattr(cv2, 'MORPH_'+mode), kernel_size)
        return cv2.dilate(img, kernel)
|
||||||
|
|
||||||
|
class RandomPartBlur():
    """Randomly Gaussian-blur (then re-binarize) a horizontal band of frames.

    Args:
        prob: probability of applying the augmentation.
        top_range / bot_range: inclusive ranges for the band's top/bottom rows.
        per_frame: if True, the augmentation is decided and applied
            independently for every frame of the sequence.
    """

    def __init__(self, prob=0.5, top_range=(9, 20), bot_range=(29, 40), per_frame=False):
        self.prob = prob
        self.top_range = top_range
        self.bot_range = bot_range
        self.per_frame = per_frame

    def __call__(self, seq):
        '''
        Input:
            seq: a sequence of silhouette frames, [s, h, w]
        Output:
            seq: a sequence of augmented frames, [s, h, w]
        '''
        if not self.per_frame:
            if random.uniform(0, 1) >= self.prob:
                return seq
            else:
                top = random.randint(self.top_range[0], self.top_range[1])
                bot = random.randint(self.bot_range[0], self.bot_range[1])

                seq = seq.transpose(1, 2, 0)  # [s, h, w] -> [h, w, s]
                _seq_ = seq.copy()
                _seq_ = _seq_[top:bot, ...]
                _seq_ = cv2.GaussianBlur(_seq_, ksize=(3, 3), sigmaX=0)
                # Re-binarize the blurred band.
                # BUGFIX: np.float was removed in NumPy 1.24; the builtin
                # float is the exact type the old alias referred to.
                _seq_ = (_seq_ > 0.2).astype(float)
                seq[top:bot, ...] = _seq_
                seq = seq.transpose(2, 0, 1)  # [h, w, s] -> [s, h, w]

            return seq
        else:
            # Temporarily switch to whole-sequence mode and recurse per frame.
            # NOTE(review): this mutates shared state, so it is not thread-safe.
            self.per_frame = False
            frame_num = seq.shape[0]
            ret = [self.__call__(seq[k][np.newaxis, ...]) for k in range(frame_num)]
            self.per_frame = True
            return np.concatenate(ret, 0)
|
||||||
|
|
||||||
|
def DA4GaitSSB(
    cutting = None,
    ra_prob = 0.2,
    rp_prob = 0.2,
    rhf_prob = 0.5,
    rpd_prob = 0.2,
    rpb_prob = 0.2,
    top_range = (9, 20),
    bot_range = (39, 50),
):
    """Build the GaitSSB data-augmentation pipeline.

    Composes affine/perspective warps, side cutting, horizontal flip,
    part dilation, and part blur into a single transform; each ``*_prob``
    argument controls the probability of the corresponding stage.
    """
    stages = [
        RandomAffine(prob=ra_prob),
        RandomPerspective(prob=rp_prob),
        BaseSilCuttingTransform(cutting=cutting),
        RandomHorizontalFlip(prob=rhf_prob),
        RandomPartDilate(prob=rpd_prob, top_range=top_range, bot_range=bot_range),
        RandomPartBlur(prob=rpb_prob, top_range=top_range, bot_range=bot_range),
    ]
    return T.Compose(stages)
|
||||||
|
|
||||||
|
# **************** For pose-based methods ****************
|
||||||
class RandomSelectSequence(object):
|
class RandomSelectSequence(object):
|
||||||
"""
|
"""
|
||||||
Randomly select different subsequences
|
Randomly select different subsequences
|
||||||
|
|||||||
@@ -0,0 +1,142 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from ..base_model import BaseModel
|
||||||
|
from ..modules import PackSequenceWrapper, HorizontalPoolingPyramid, SetBlockWrapper, ParallelBN1d, SeparateFCs
|
||||||
|
|
||||||
|
from utils import np2var, list2var, get_valid_args, ddp_all_gather
|
||||||
|
from data.transform import get_transform
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
# Modified from https://github.com/PatrickHua/SimSiam/blob/main/models/simsiam.py
|
||||||
|
class GaitSSB_Pretrain(BaseModel):
    """SimSiam-style self-supervised pretraining model for GaitSSB.

    Two augmented views (query/key) of each silhouette sequence are encoded,
    projected, and predicted; each prediction is scored against the
    stop-gradient projection of the other view across the DDP-gathered
    batch (see `D`), producing two softmax losses.
    """

    def __init__(self, cfgs, training=True):
        super(GaitSSB_Pretrain, self).__init__(cfgs, training=training)

    def build_network(self, model_cfg):
        # Number of horizontal body parts produced by the pooling pyramid.
        self.p = model_cfg['parts_num']
        self.Backbone = self.get_backbone(model_cfg['backbone_cfg'])
        self.Backbone = SetBlockWrapper(self.Backbone)

        # Temporal pooling: max over the sequence (frame) dimension.
        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = HorizontalPoolingPyramid([16, 8, 4, 2, 1])

        out_channels = model_cfg['backbone_cfg']['channels'][-1]
        hidden_dim = out_channels
        # Projector: per-part FC -> BN -> ReLU -> FC -> BN (SimSiam-style).
        self.projector = nn.Sequential(SeparateFCs(self.p, out_channels, hidden_dim),
                                       ParallelBN1d(self.p, hidden_dim),
                                       nn.ReLU(inplace=True),
                                       SeparateFCs(self.p, hidden_dim, out_channels),
                                       ParallelBN1d(self.p, out_channels))
        # Predictor: per-part FC -> BN -> ReLU -> FC (no final BN).
        self.predictor = nn.Sequential(SeparateFCs(self.p, out_channels, hidden_dim),
                                       ParallelBN1d(self.p, hidden_dim),
                                       nn.ReLU(inplace=True),
                                       SeparateFCs(self.p, hidden_dim, out_channels))

    def inputs_pretreament(self, inputs):
        """Split each training batch into query/key halves and apply transforms.

        The bilateral sampler delivers every index twice, so the first half
        of the batch becomes the query view and the second half the key view
        (each transformed independently). At test time falls back to the
        base implementation.
        """
        if self.training:
            seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
            trf_cfgs = self.engine_cfg['transform']
            seq_trfs = get_transform(trf_cfgs)

            requires_grad = True if self.training else False
            # The sampler yields each sample twice; halve to recover
            # the logical batch size.
            batch_size = int(len(seqs_batch[0]) / 2)
            img_q = [np2var(np.asarray([trf(fra) for fra in seq[:batch_size]]), requires_grad=requires_grad).float() for trf, seq in zip(seq_trfs, seqs_batch)]
            img_k = [np2var(np.asarray([trf(fra) for fra in seq[batch_size:]]), requires_grad=requires_grad).float() for trf, seq in zip(seq_trfs, seqs_batch)]
            seqs = [img_q, img_k]

            typs = typs_batch
            vies = vies_batch

            if self.training:
                labs = list2var(labs_batch).long()
            else:
                labs = None

            if seqL_batch is not None:
                seqL_batch = np2var(seqL_batch).int()
            seqL = seqL_batch

            ipts = seqs
            del seqs

            # seqL is duplicated so forward() can unpack one per view.
            return ipts, labs, typs, vies, (seqL, seqL)
        else:
            return super().inputs_pretreament(inputs)

    def encoder(self, inputs):
        """Encode a silhouette sequence into per-part set features.

        Input: (sils, seqL) with sils [n, c, s, h, w]; returns [n, c, p].
        """
        sils, seqL = inputs
        # Silhouettes are expected pre-cut to width 44 (64x44) or 88 (128x88).
        assert sils.size(-1) in [44, 88]
        outs = self.Backbone(sils)  # [n, c, s, h, w]
        outs = self.TP(outs, seqL, options={"dim": 2})[0]  # [n, c, h, w]
        feat = self.HPP(outs)  # [n, c, p], Horizontal Pooling, HP
        return feat

    def forward(self, inputs):
        '''
        Input:
            sils_q: a batch of query images, [n, s, h, w]
            sils_k: a batch of key images, [n, s, h, w]
        Output:
            logits, targets
        '''
        if self.training:
            (sils_q, sils_k), labs, typs, vies, (seqL_q, seqL_k) = inputs

            # Add the channel dimension: [n, s, h, w] -> [n, 1, s, h, w].
            sils_q, sils_k = sils_q[0].unsqueeze(1), sils_k[0].unsqueeze(1)

            q_input = (sils_q, seqL_q)
            q_feat = self.encoder(q_input)  # [n, c, p]
            z1 = self.projector(q_feat)
            p1 = self.predictor(z1)

            k_input = (sils_k, seqL_k)
            k_feat = self.encoder(k_input)  # [n, c, p]
            z2 = self.projector(k_feat)
            p2 = self.predictor(z2)

            # Symmetric loss: each prediction vs. the other view's
            # (stop-gradient) projection.
            logits1, labels1 = self.D(p1, z2)
            logits2, labels2 = self.D(p2, z1)

            retval = {
                'training_feat': {'softmax1': {'logits': logits1, 'labels': labels1},
                                  'softmax2': {'logits': logits2, 'labels': labels2}
                                  },
                'visual_summary': {'image/encoder_q': rearrange(sils_q, 'n c s h w -> (n s) c h w'),
                                   'image/encoder_k': rearrange(sils_k, 'n c s h w -> (n s) c h w'),
                                   },
                'inference_feat': None
            }
            return retval
        else:
            sils, labs, typs, vies, seqL = inputs
            sils = sils[0].unsqueeze(1)
            # Inference embedding: encoder + projector + predictor,
            # L2-normalized along the channel axis.
            feat = self.encoder((sils, seqL))  # [n, c, p]
            feat = self.projector(feat)  # [n, c, p]
            feat = self.predictor(feat)  # [n, c, p]
            retval = {
                'training_feat': None,
                'visual_summary': None,
                'inference_feat': {'embeddings': F.normalize(feat, dim=1)}
            }
            return retval

    def D(self, p, z):  # negative cosine similarity
        """
        p: [n, c, p]
        z: [n, c, p]

        Returns per-part similarity logits [n, m, p] against the
        all-rank-gathered z, and the target labels: sample i on this rank
        matches gathered row rank*n + i.
        """
        z = z.detach()  # stop gradient
        n = p.size(0)

        p = F.normalize(p, dim=1)  # l2-normalize, [n, c, p]
        z = F.normalize(z, dim=1)  # l2-normalize, [n, c, p]
        z = ddp_all_gather(z, dim=0, requires_grad=False)  # [m, c, p], m = n * the number of GPUs

        logits = torch.einsum('ncp, mcp->nmp', [p, z])  # [n, m, p]
        rank = torch.distributed.get_rank()
        # Positives sit at this rank's offset in the gathered batch.
        labels = torch.arange(rank*n, (rank+1)*n, dtype=torch.long).cuda()
        return logits, labels
|
||||||
@@ -691,3 +691,18 @@ class SpatialAttention(nn.Module):
|
|||||||
ret_shape = (batch, Nh * dv, T, V)
|
ret_shape = (batch, Nh * dv, T, V)
|
||||||
return torch.reshape(x, ret_shape)
|
return torch.reshape(x, ret_shape)
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
class ParallelBN1d(nn.Module):
    """BatchNorm1d applied jointly over all horizontal parts.

    Flattens the (channel, part) axes into one feature axis so a single
    BatchNorm1d normalizes each (channel, part) pair with its own
    statistics and affine parameters.

    Args:
        parts_num: number of parts p in the input [n, c, p].
        in_channels: number of channels c per part.
        **kwargs: forwarded to nn.BatchNorm1d.
    """

    def __init__(self, parts_num, in_channels, **kwargs):
        super(ParallelBN1d, self).__init__()
        self.parts_num = parts_num
        # One BN feature per (channel, part) combination.
        self.bn1d = nn.BatchNorm1d(in_channels * parts_num, **kwargs)

    def forward(self, x):
        '''
        x: [n, c, p]
        '''
        n, c, p = x.size()
        # Native reshapes replace the einops rearrange calls
        # ('n c p -> n (c p)' and back), dropping the third-party
        # dependency; the memory layout is identical (p innermost).
        x = x.reshape(n, c * p)
        x = self.bn1d(x)
        return x.reshape(n, c, p)
|
||||||
Reference in New Issue
Block a user